TensorFlow dataset pipeline with specific classes


I would like to build a dataset pipeline that contains only specific class indexes.

  • For example:

If I use the CIFAR-10 dataset, I can load it as follows:

(x_train, y_train), (x_test, y_test) = tf.keras.datasets.cifar10.load_data()

This loads the samples for all 10 classes. I can create a pipeline using the following code:

train_dataset = tf.data.Dataset.from_tensor_slices((x_train,y_train)).batch(64)
test_dataset = tf.data.Dataset.from_tensor_slices((x_test,y_test)).batch(64)

This works well for training a Keras model.

  • Now I want to create a pipeline that contains only some of the classes (for example, samples from 5 classes instead of all 10). Is there any way to make a pipeline like this?


You can use tf.data.Dataset.filter to keep only the samples whose label is one of the class indexes you want:

import tensorflow as tf

class_indexes_to_keep = tf.constant([0, 3, 4, 6, 8], dtype=tf.int64)

(x_train, y_train), (x_test, y_test) = tf.keras.datasets.cifar10.load_data()

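# CIFAR-10 labels are loaded as uint8; cast to int64 so they match the dtype of class_indexes_to_keep.
y_train = y_train.astype("int64")
y_test = y_test.astype("int64")

# Each label has shape (1,), so it broadcasts against the 5 kept indexes;
# a sample is kept when its label equals any of them.
def keep_sample(x, y):
    return tf.reduce_any(y == class_indexes_to_keep)

train_dataset = tf.data.Dataset.from_tensor_slices((x_train, y_train)).filter(keep_sample).batch(64)
test_dataset = tf.data.Dataset.from_tensor_slices((x_test, y_test)).filter(keep_sample).batch(64)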
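
Because load_data() returns plain NumPy arrays that already sit in memory, another option is to select the samples before building the dataset. The following is a minimal sketch of that alternative, not the approach shown above; select_classes and the synthetic x_demo / y_demo arrays are hypothetical names used only for illustration, and the arrays merely mimic the CIFAR-10 shapes.

```python
import numpy as np

def select_classes(x, y, classes_to_keep):
    """Return only the samples whose label is in classes_to_keep."""
    # CIFAR-10 style labels have shape (N, 1); flatten them to (N,) for the mask.
    labels = np.asarray(y).reshape(-1)
    mask = np.isin(labels, classes_to_keep)
    return x[mask], y[mask]

# Synthetic stand-in for CIFAR-10: 100 images, labels cycling through 0..9.
rng = np.random.default_rng(0)
x_demo = rng.integers(0, 256, size=(100, 32, 32, 3), dtype=np.uint8)
y_demo = (np.arange(100) % 10).reshape(-1, 1).astype(np.uint8)

x_sub, y_sub = select_classes(x_demo, y_demo, [0, 3, 4, 6, 8])
print(x_sub.shape, y_sub.shape)
print(np.unique(y_sub))
```

The filtered arrays can then be passed to tf.data.Dataset.from_tensor_slices((x_sub, y_sub)).batch(64) as in the question. With this approach, every batch except possibly the last one contains exactly 64 samples, and the filtering cost is paid once instead of on every epoch.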

If you need one-hot (categorical) labels, you could one-hot encode the labels first and then filter on the encoded vectors:

import tensorflow as tf

one_hot_encode = tf.keras.utils.to_categorical(tf.range(10, dtype=tf.int64), num_classes=10)
class_indexes_to_keep = tf.gather(one_hot_encode, tf.constant([0, 3, 4, 6, 8], dtype=tf.int64))

(x_train, y_train), (x_test, y_test) = tf.keras.datasets.cifar10.load_data()
y_train = y_train.astype(int)

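# One-hot encode each label; y has shape (1,), so [0] drops the extra axis and leaves a vector of shape (10,).
train_dataset = tf.data.Dataset.from_tensor_slices((x_train,y_train)).map(lambda x, y: (x, tf.one_hot(y, 10)[0]))
# Keep a sample when its one-hot vector matches any row of class_indexes_to_keep.
train_dataset = train_dataset.filter(lambda x, y: tf.reduce_any(tf.reduce_all(y == class_indexes_to_keep, axis=-1))).batch(64)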
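
The one-hot filter relies on broadcasting: a label of shape (10,) is compared against the (5, 10) matrix of kept classes, reduced to one match flag per row, and then reduced to a single boolean. The sketch below reproduces that logic in NumPy for two individual labels, purely to illustrate the shapes involved; keep_rows and matches_kept_class are hypothetical names, not part of the code above.

```python
import numpy as np

# One-hot rows for the kept classes, shape (5, 10).
classes_to_keep = [0, 3, 4, 6, 8]
keep_rows = np.eye(10, dtype=np.float32)[classes_to_keep]

def matches_kept_class(one_hot_label):
    """Return True when a (10,) one-hot label equals any row of keep_rows."""
    # (10,) == (5, 10) broadcasts to (5, 10); all() over the last axis gives (5,).
    row_matches = np.all(one_hot_label == keep_rows, axis=-1)
    # any() collapses the per-row flags into a single keep/drop decision.
    return bool(np.any(row_matches))

label_3 = np.eye(10, dtype=np.float32)[3]  # class 3 is in the kept list
label_5 = np.eye(10, dtype=np.float32)[5]  # class 5 is not

print(matches_kept_class(label_3))  # True
print(matches_kept_class(label_5))  # False
```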

Answered By – AloneTogether

This Answer collected from stackoverflow, is licensed under cc by-sa 2.5 , cc by-sa 3.0 and cc by-sa 4.0
