Filtering a tf.data ZipDataset

Issue

I have an image dataset with 4 classes, loaded with image_dataset_from_directory, whose elements are (image, label) tuples.
I created an 'id' column with the same length as the dataset, converted it to a tf.data dataset, and zipped the two datasets together with:

dataset = tf.data.Dataset.zip((dataset, client_id))

This results in a dataset with the following element spec:

<ZipDataset element_spec=((TensorSpec(shape=(128, 128, 3), dtype=tf.float32, name=None), TensorSpec(shape=(), dtype=tf.int32, name=None)), TensorSpec(shape=(), dtype=tf.int64, name=None))>
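A minimal, self-contained sketch that reproduces this structure, assuming a handful of random tensors in place of the images that image_dataset_from_directory would load (the sizes, labels and client id values are made up for illustration):

```python
import tensorflow as tf

# Hypothetical stand-in for the image dataset: 6 random 128x128 RGB images
# with integer labels in [0, 4), yielding (image, label) tuples.
images = tf.random.uniform((6, 128, 128, 3), dtype=tf.float32)
labels = tf.constant([0, 1, 2, 3, 0, 1], dtype=tf.int32)
dataset = tf.data.Dataset.from_tensor_slices((images, labels))

# One client id per element, same length as the dataset (illustrative values).
client_id = tf.data.Dataset.from_tensor_slices(
    tf.constant([15, 3, 15, 7, 3, 15], dtype=tf.int64))

# zip() pairs the elements, so each element becomes ((image, label), client_id).
dataset = tf.data.Dataset.zip((dataset, client_id))
print(dataset.element_spec)
```

The printed spec has the same nesting as above: an inner (image, label) tuple plus a scalar int64 id.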

Now I would like to be able to filter this zipped dataset by client id whenever I want. I tried:

dataset = dataset.filter(lambda x: x[1] == 15)

but I get:

TypeError: 'ZipDataset' object is not subscriptable

However, this:

for x in dataset.take(1):
  print(x[1])

prints the client id correctly:

tf.Tensor(15, shape=(), dtype=int64)
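When iterating, the loop variable can also be unpacked directly into the inner (image, label) tuple and the client id, which makes the two-level structure explicit. A short sketch, assuming a small hypothetical zipped dataset with the same nesting (scalars stand in for the images to keep it compact):

```python
import tensorflow as tf

# Hypothetical zipped dataset with the same nesting: ((image, label), client_id).
# Scalars stand in for the images to keep the example small.
pairs = tf.data.Dataset.from_tensor_slices(
    (tf.constant([0.1, 0.2, 0.3]), tf.constant([0, 1, 2], dtype=tf.int32)))
ids = tf.data.Dataset.from_tensor_slices(tf.constant([15, 3, 15], dtype=tf.int64))
dataset = tf.data.Dataset.zip((pairs, ids))

# Each element is a 2-tuple: the inner (image, label) tuple and the client id.
seen_ids = []
for (image, label), client in dataset:
    seen_ids.append(int(client))
print(seen_ids)  # [15, 3, 15]
```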

How could this be done?

Solution

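When you iterate over your dataset, x[0] is the (image, label) tuple and x[1] is the client id. Dataset.filter, however, unpacks each top-level tuple into separate arguments before calling the predicate, so the predicate should take two parameters instead of indexing a single one. Maybe try: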

import tensorflow as tf

a = tf.data.Dataset.range(1, 4)  # ==> [ 1, 2, 3 ]
b = tf.data.Dataset.range(4, 7)  # ==> [ 4, 5, 6 ]
# Same nesting as your dataset: ((a, b), id), where the ids are 4, 5, 6
dataset = tf.data.Dataset.zip((tf.data.Dataset.zip((a, b)), tf.data.Dataset.range(4, 7)))
# filter() unpacks each ((a, b), id) element into two arguments:
# x is the inner (a, b) tuple and y is the id
dataset = dataset.filter(lambda x, y: y == 4)
for x, y in dataset:
  print(x, y)

Output:

(<tf.Tensor: shape=(), dtype=int64, numpy=1>, <tf.Tensor: shape=(), dtype=int64, numpy=4>) tf.Tensor(4, shape=(), dtype=int64)

Note that y holds the client ids in this case, so for your dataset the predicate would be lambda x, y: y == 15.
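To filter by any client id on demand, the predicate can be wrapped in a small function. This is a minimal sketch: filter_by_client is a hypothetical helper name and the dataset is made-up data with the question's ((image, label), client_id) nesting. The final map, which unpacks its arguments the same way filter does, drops the id so the result has the original (image, label) structure.

```python
import tensorflow as tf

def filter_by_client(dataset, client_id):
    """Keep elements whose client id equals `client_id` (hypothetical helper).

    Expects elements shaped ((image, label), client_id). filter() unpacks the
    outer tuple, so the predicate receives two arguments.
    """
    return dataset.filter(lambda pair, cid: tf.equal(cid, client_id))

# Made-up data with the question's nesting; scalars stand in for images.
pairs = tf.data.Dataset.from_tensor_slices(
    (tf.constant([0.1, 0.2, 0.3, 0.4]), tf.constant([0, 1, 2, 3], dtype=tf.int32)))
ids = tf.data.Dataset.from_tensor_slices(tf.constant([15, 3, 15, 7], dtype=tf.int64))
dataset = tf.data.Dataset.zip((pairs, ids))

client_15 = filter_by_client(dataset, 15)
# Drop the id to get back plain (image, label) tuples.
client_15_pairs = client_15.map(lambda pair, cid: pair)

labels = [int(label) for _, label in client_15_pairs]
print(labels)  # [0, 2]
```

Passing the id as a Python int lets TensorFlow convert it to the int64 dtype of the id tensor; a tensor of a different dtype (for example tf.int32) would need a tf.cast first.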

Answered By – AloneTogether

This answer, collected from Stack Overflow, is licensed under CC BY-SA 2.5, CC BY-SA 3.0 and CC BY-SA 4.0.
