# Behavior of Dataset.map in TensorFlow

## Issue

I’m trying to take variable length tensors and split them up into tensors of length 4, discarding any extra elements (if the length is not divisible by four).

I’ve therefore written the following function:

```python
import tensorflow as tf

def batches_of_four(tokens):
    token_length = tokens.shape[0]
    splits = token_length // 4
    tokens = tokens[0 : splits * 4]
    return tf.split(tokens, num_or_size_splits=splits)

dataset = tf.data.Dataset.from_tensor_slices(
    tf.ragged.constant([[1, 2, 3, 4, 5], [4, 5, 6, 7]]))

print(batches_of_four(next(iter(dataset))))
```

This produces the output `[<tf.Tensor: shape=(4,), dtype=int32, numpy=array([1, 2, 3, 4], dtype=int32)>]`, as expected.

If I now run the same function using `Dataset.map`:

```python
for item in dataset.map(batches_of_four):
    print(item)
```

Instead, I get the following error:

```
    File "<ipython-input-173-a09c55117ea2>", line 5, in batches_of_four  *
        splits = token_length // 4

    TypeError: unsupported operand type(s) for //: 'NoneType' and 'int'
```

I see that this is because `token_length` is `None`, but I don’t understand why. I assume this has something to do with graph vs eager execution, but the function works if I call it outside of `.map` even if I annotate it with `@tf.function`.

Why is the behavior different inside `.map`? (Also: is there a better way of writing the `batches_of_four` function?)

## Solution

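Inside `Dataset.map`, your function is traced once (not once per element) against the dataset's element spec. Because the elements have different lengths, that spec only says each element is a 1-D tensor of unknown length, so the static shape `tokens.shape[0]` is `None`. When you call a `@tf.function` directly on a concrete tensor, it is traced for that tensor's actual shape `(5,)`, which is why it works there. In graph mode you should use `tf.shape` to get the dynamic shape of a tensor, which is evaluated for every element at run time: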

```python
token_length = tf.shape(tokens)[0]
```
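To see the difference concretely, here is a small sketch, assuming TensorFlow 2.x with eager execution enabled. The `make_probe` helper and the `*_log` lists are hypothetical names used only for illustration. The probe records the static length `tokens.shape[0]` while a function is being traced, and returns the dynamic length from `tf.shape`.

```python
import tensorflow as tf

def make_probe(log):
    """Return a function that records its input's static length at trace time."""
    def probe(tokens):
        log.append(tokens.shape[0])  # Python side effect: runs only while tracing
        return tf.shape(tokens)[0]   # graph op: runs for every call / element
    return probe

# 1. tf.function called directly on a concrete tensor: traced for shape (5,).
direct_log = []
tf.function(make_probe(direct_log))(tf.constant([1, 2, 3, 4, 5]))

# 2. tf.function with a length-agnostic input signature: traced for shape (None,).
generic_log = []
generic = tf.function(
    make_probe(generic_log),
    input_signature=[tf.TensorSpec(shape=[None], dtype=tf.int32)])
generic(tf.constant([1, 2, 3, 4, 5]))

# 3. Dataset.map: traced against the element spec, whose length is unknown.
map_log = []
dataset = tf.data.Dataset.from_tensor_slices(
    tf.ragged.constant([[1, 2, 3, 4, 5], [4, 5, 6, 7]]))
lengths = [int(n) for n in dataset.map(make_probe(map_log))]

print(direct_log)   # [5]
print(generic_log)  # [None]
print(map_log)      # only None entries
print(lengths)      # [5, 4]
```

The `append` only runs while the function is traced, so it sees the static shape; `tf.shape(tokens)[0]` is part of the graph and yields the real length of every element.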

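Another problem is that you would then be passing a scalar tensor as `num_or_size_splits` to `tf.split`. That won't work either: `tf.split` does not accept a rank-0 tensor there, because the number of output tensors has to be known when the graph is built, and a tensor's value is only known at run time.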

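Try this instead. It writes each chunk into a `tf.TensorArray` inside a `tf.while_loop`, stacks the chunks, and then uses `flat_map` to turn them back into individual dataset elements: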

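```python
import tensorflow as tf

def split_data(data, chunk_size):
    # tf.shape gives the dynamic length, which is only known at run time.
    length = tf.shape(data)[0]
    # Drop the trailing elements that do not fill a whole chunk.
    x = data[:(length // chunk_size) * chunk_size]

    def body(i, m, n):
        # Append the chunk m[i:i + chunk_size], then move i to the next chunk.
        n = n.write(n.size(), m[i:i + chunk_size])
        return tf.add(i, chunk_size), m, n

    ta = tf.TensorArray(dtype=tf.int32, size=0, dynamic_size=True)
    i0 = tf.constant(0)
    # Loop while i still points inside the trimmed tensor.
    c = lambda i, m, n: tf.less(i, tf.shape(x)[0])
    _, _, out = tf.while_loop(c, body, loop_vars=[i0, x, ta])
    # Stack the chunks into a [num_chunks, chunk_size] tensor.
    return out.stack()

chunk_size = 4

dataset = tf.data.Dataset.from_tensor_slices(
    tf.ragged.constant([[1, 2, 3, 4, 5], [4, 5, 6, 7], [1, 2, 3, 4, 5, 6, 7, 8, 9]]))
dataset = dataset.map(lambda x: split_data(x, chunk_size))
# Turn each [num_chunks, chunk_size] element into num_chunks separate elements.
dataset = dataset.flat_map(tf.data.Dataset.from_tensor_slices)

for item in dataset:
    print(item)
```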
```
tf.Tensor([1 2 3 4], shape=(4,), dtype=int32)
tf.Tensor([4 5 6 7], shape=(4,), dtype=int32)
tf.Tensor([1 2 3 4], shape=(4,), dtype=int32)
tf.Tensor([5 6 7 8], shape=(4,), dtype=int32)
```
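For the second part of the question, a simpler alternative (a sketch of a different technique, not part of the answer above; `split_into_chunks` is a hypothetical name) avoids both `tf.split` and the loop. It trims the tensor using the dynamic length and then calls `tf.reshape` with `[-1, chunk_size]`, so TensorFlow infers the number of chunks at run time. `Dataset.unbatch` then splits each `[num_chunks, 4]` element into separate rows.

```python
import tensorflow as tf

def split_into_chunks(tokens, chunk_size=4):
    """Return a [num_chunks, chunk_size] tensor, dropping leftover elements."""
    num_chunks = tf.shape(tokens)[0] // chunk_size  # dynamic, so it works inside map
    trimmed = tokens[: num_chunks * chunk_size]
    # -1 lets reshape infer the number of chunks at run time, so the chunk
    # count never has to be known when the graph is built.
    return tf.reshape(trimmed, [-1, chunk_size])

dataset = (
    tf.data.Dataset.from_tensor_slices(
        tf.ragged.constant([[1, 2, 3, 4, 5], [4, 5, 6, 7], [1, 2, 3, 4, 5, 6, 7, 8, 9]]))
    .map(split_into_chunks)
    .unbatch()  # each [num_chunks, 4] element becomes num_chunks elements of shape [4]
)

for item in dataset:
    print(item)
```

This should print the same four chunks as the `tf.while_loop` version.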

And see my other answer here. 
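Another option, assuming you only need the chunks as dataset elements (rather than as a tensor returned from your own function), stays entirely within `tf.data`. Again, this is a different technique from the answer above. Each variable-length sequence is sliced into scalars with `from_tensor_slices`, and the scalars are regrouped with `batch(4, drop_remainder=True)`, which discards a final group shorter than four.

```python
import tensorflow as tf

dataset = tf.data.Dataset.from_tensor_slices(
    tf.ragged.constant([[1, 2, 3, 4, 5], [4, 5, 6, 7], [1, 2, 3, 4, 5, 6, 7, 8, 9]]))

chunks = dataset.flat_map(
    # Slice one sequence into scalars, then regroup them in fours;
    # drop_remainder=True drops the incomplete trailing group.
    lambda tokens: tf.data.Dataset.from_tensor_slices(tokens).batch(4, drop_remainder=True))

for item in chunks:
    print(item)
```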