Can I apply tf.map_fn(…) to multiple inputs/outputs?

Issue

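import tensorflow as tf
sess = tf.InteractiveSession()  # TF 1.x: installs a default session so .eval() works

a = tf.constant([[1,2,3],[4,5,6]])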
b = tf.constant([True, False], dtype=tf.bool)

a.eval()
array([[1, 2, 3],
       [4, 5, 6]], dtype=int32)
b.eval()
array([ True, False], dtype=bool)

I want to apply a function to the inputs above, a and b, using tf.map_fn. For each element it should take both [1,2,3] and True as input and output similar values.

Let’s say our function is simply the identity on the pair, lambda x: (x[0], x[1]) (the Python 2 form lambda (x, y): (x, y) is no longer valid syntax). Given an input of [1,2,3], True, it outputs those identical tensors.

I know how to use tf.map_fn(...) with one variable, but not with two. In this case I also have mixed data types (int32 and bool), so I can’t simply concatenate the tensors and split them after the call.

Can I use tf.map_fn(...) with multiple inputs/outputs of different data types?

Solution

Figured it out. Pass the tensors to map_fn as a tuple and give dtype a matching tuple with the data type of each output tensor. Your map function then receives a tuple of inputs (one slice of each tensor), and map_fn returns a tuple of tensors.
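To make the tuple-in, tuple-out behaviour concrete, here is a simplified pure-Python model of what map_fn does with a tuple of inputs. map_fn_model is a hypothetical helper written only for illustration: tensors are modeled as Python lists, and it ignores dtypes, parallelism and gradients, so it is not how TensorFlow actually implements map_fn.

```python
# Simplified, pure-Python model of how tf.map_fn treats a tuple of inputs.
# Illustration of the semantics only, not TensorFlow's implementation.

def map_fn_model(fn, elems):
    """Apply fn to matching rows of every list in the tuple `elems`.

    fn receives one tuple per row and must return a tuple; the results are
    regrouped so output i collects the i-th item of every return value,
    mirroring map_fn returning a tuple of stacked tensors.
    """
    n = len(elems[0])
    # All inputs must agree on the size of the first dimension.
    if any(len(t) != n for t in elems):
        raise ValueError("all elems must have the same first dimension")
    # Call fn once per row with a tuple of slices, e.g. ([1, 2, 3], True).
    results = [fn(tuple(t[i] for t in elems)) for i in range(n)]
    # Regroup into a tuple of lists, one list per output component.
    return tuple(list(component) for component in zip(*results))


a = [[1, 2, 3], [4, 5, 6]]
b = [True, False]

c = map_fn_model(lambda x: (x[0], x[1]), (a, b))
print(c[0])
print(c[1])
```

The key point the model captures is that fn is called with a single argument, the tuple of row slices, which is why the lambda indexes x[0] and x[1] instead of taking two parameters.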

Example that works:

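import tensorflow as tf
sess = tf.InteractiveSession()  # TF 1.x: installs a default session so .eval() works

a = tf.constant([[1,2,3],[4,5,6]])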
b = tf.constant([True, False], dtype=tf.bool)

c = tf.map_fn(lambda x: (x[0], x[1]), (a,b), dtype=(tf.int32, tf.bool))

c[0].eval()
array([[1, 2, 3],
       [4, 5, 6]], dtype=int32)
c[1].eval()
array([ True, False], dtype=bool)
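On TensorFlow 2.x the same pattern runs eagerly, without sessions or .eval(). The sketch below assumes TensorFlow 2.3 or later, where dtype is deprecated in favour of fn_output_signature. The second call shows when the signature is genuinely needed: fn returns a different structure from elems, here a single int32 per row computed by a made-up rule (the row sum if the flag is True, else 0) chosen only for illustration.

```python
# Sketch assuming TensorFlow 2.3+ (eager execution, fn_output_signature).
import tensorflow as tf

a = tf.constant([[1, 2, 3], [4, 5, 6]])
b = tf.constant([True, False], dtype=tf.bool)

# Identity over the pair: one signature entry per output tensor,
# in the same order as the tuple fn returns.
c = tf.map_fn(lambda x: (x[0], x[1]), (a, b),
              fn_output_signature=(tf.int32, tf.bool))

# fn returns one int32 scalar per row instead of a (int32, bool) pair,
# so the output structure differs from elems and the signature is required.
d = tf.map_fn(lambda x: tf.reduce_sum(x[0]) * tf.cast(x[1], tf.int32),
              (a, b), fn_output_signature=tf.int32)

print(c[0].numpy())
print(c[1].numpy())
print(d.numpy())
```

Outputs come back in the order given by the signature, so c[0] holds the int32 rows and c[1] the bool flags, while d is a single int32 tensor.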

Answered By – David Parks

This answer, collected from Stack Overflow, is licensed under CC BY-SA 2.5, CC BY-SA 3.0 and CC BY-SA 4.0.
