Retrieving the next element from tf.data.Dataset in TensorFlow 2.0 beta

Issue

Before TensorFlow 2.0 beta, we could retrieve the first element of a tf.data.Dataset with a one-shot iterator, as shown below:

#!/usr/bin/python

import tensorflow as tf

train_dataset = tf.data.Dataset.from_tensor_slices([1.0, 2.0, 3.0, 4.0])
iterator = train_dataset.make_one_shot_iterator()
with tf.Session() as sess:
    # 1.0 will be printed.
    print(sess.run(iterator.get_next()))

In TensorFlow 2.0 beta, the one-shot iterator above appears to be deprecated. To print all of the elements, we can iterate over the dataset with a for loop instead:

#!/usr/bin/python

import tensorflow as tf

train_dataset = tf.data.Dataset.from_tensor_slices([1.0, 2.0, 3.0, 4.0])

for data in train_dataset:
    # 1.0, 2.0, 3.0, and 4.0 will be printed.
    print(data.numpy())

However, if we only want to retrieve exactly one element from a tf.data.Dataset, how can we do that in TensorFlow 2.0 beta? It seems that next(train_dataset) is not supported. This was easy with the one-shot iterator shown above, but it is not obvious with the new for-loop approach.

Any suggestions are welcome.
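A likely reason the `next(train_dataset)` attempt fails: in eager mode, a `tf.data.Dataset` is an iterable rather than an iterator. It defines `__iter__` (which is what the for loop above relies on) but not `__next__`. A minimal sketch, assuming eager execution (the TF 2.x default), is to wrap the dataset in Python's built-in `iter()` before calling `next()`. This is an alternative to the accepted answer, not part of it.

```python
import tensorflow as tf

train_dataset = tf.data.Dataset.from_tensor_slices([1.0, 2.0, 3.0, 4.0])

# iter() creates a fresh iterator over the dataset; next() pulls one element.
first = next(iter(train_dataset))
print(first.numpy())  # 1.0

# Reusing the same iterator fetches successive elements, similar to get_next() in TF 1.x.
it = iter(train_dataset)
a = next(it)
b = next(it)
print(a.numpy(), b.numpy())  # 1.0 2.0
```

Each call to `iter(train_dataset)` starts again from the first element, and `next()` raises `StopIteration` once the iterator is exhausted.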

Solution

You can call .take(1) on the dataset and iterate over the result:

for elem in train_dataset.take(1):
  print(elem.numpy())
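Note that `take(1)` returns a new one-element `Dataset`, not the element itself, which is why the snippet above still loops over it. If you want the element as a plain value, one option (an addition to the answer, assuming TF 2.0's `tf.data.experimental.get_single_element`, which later releases also expose as the `Dataset.get_single_element()` method) is sketched below:

```python
import tensorflow as tf

train_dataset = tf.data.Dataset.from_tensor_slices([1.0, 2.0, 3.0, 4.0])

# take(1) yields a Dataset containing only the first element.
first_ds = train_dataset.take(1)

# get_single_element() returns that lone element as a tensor; it raises
# an error if the dataset does not contain exactly one element.
first = tf.data.experimental.get_single_element(first_ds)
print(first.numpy())  # 1.0
```

Passing `train_dataset` directly (without `take(1)`) would fail here, because it has four elements.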

Answered By – Stewart_R

This answer was collected from Stack Overflow and is licensed under CC BY-SA 2.5, CC BY-SA 3.0 and CC BY-SA 4.0.
