tf.data.Dataset apply() doesn't update dataset

Issue

I’m loading a dataset of images with image_dataset_from_directory, which gives me a PrefetchDataset containing my images and their associated one-hot encoded labels.

To build a binary image classifier, I want to transform the labels in my PrefetchDataset so that they indicate whether an image is a photo or something else.

Here’s how I wrote it:

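import numpy as np
import tensorflow as tf
from tensorflow.keras.utils import image_dataset_from_directory

batch_size = 32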
img_height = 250
img_width = 250

train_ds = image_dataset_from_directory(
  data_dir,
  validation_split=0.2,
  color_mode="rgb",
  subset="training",
  seed=69,
  crop_to_aspect_ratio=False,
  image_size=(img_height, img_width),
  batch_size=batch_size)

class_names = train_ds.class_names
# ['Painting', 'Photo', 'Schematics', 'Sketch', 'Text'] in my case

# Convert each label to 1 if it is a photo, else 0
i = 1 # class_names.index('Photo')

def is_photo(batch):
    for images, labels in batch:
        bool_labels = tf.constant([int(l == 1) for l in labels],
                                  dtype=np.int32)
        labels = bool_labels
    return batch

new_train_ds = train_ds.apply(is_photo)

My problem is that new_train_ds doesn’t differ from train_ds, which leads me to think there must be an issue with the apply method.
I also checked bool_labels, and it is computed correctly.

Does anyone have an idea how to solve this issue?

Solution

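apply() calls your function once with the whole dataset and uses whatever the function returns. is_photo only rebinds the local name labels inside the loop and then returns the original dataset, so nothing changes. To transform every batch, use map() instead. Maybe try something like this: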

train_ds = train_ds.map(lambda x, y: (x, tf.cast(y == 1, dtype=tf.int64)))
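To check the idea without the image folder, here is a minimal sketch on a small synthetic dataset. The array shapes, the label values, and the names photo_index and to_binary_label are made up for illustration. It assumes integer class labels, which is what image_dataset_from_directory returns with its default label_mode="int"; with one-hot labels the comparison would have to be adapted.

```python
import numpy as np
import tensorflow as tf

# Stand-in for the batched (images, labels) dataset from the question.
# Shapes and label values are invented for this sketch.
images = np.zeros((6, 4, 4, 3), dtype=np.float32)
labels = np.array([0, 1, 2, 1, 4, 3], dtype=np.int32)
train_ds = tf.data.Dataset.from_tensor_slices((images, labels)).batch(3)

photo_index = 1  # class_names.index('Photo') in the question


def to_binary_label(image_batch, label_batch):
    # Return a NEW (images, labels) pair; map() builds the output
    # dataset from the returned values instead of mutating anything.
    is_photo = tf.equal(label_batch, photo_index)
    return image_batch, tf.cast(is_photo, tf.int32)


binary_ds = train_ds.map(to_binary_label)

# Collect the labels from every batch to inspect the result.
binary_labels = np.concatenate([y.numpy() for _, y in binary_ds])
original_labels = np.concatenate([y.numpy() for _, y in train_ds])
print(binary_labels)    # [0 1 0 1 0 0]
print(original_labels)  # [0 1 2 1 4 3]
```

Note that map() returns a new dataset and leaves train_ds untouched, so the result has to be assigned to a variable, as in the one-liner above.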

Answered By – AloneTogether

This answer, collected from Stack Overflow, is licensed under CC BY-SA 2.5, CC BY-SA 3.0 and CC BY-SA 4.0.
