Issue
My network has multiple inputs, one of which is an index that is used inside the network to index into other tensors. I am having trouble using that tensor as an index.
```python
class MemoryLayer(tf.keras.layers.Layer):
    def __init__(self, memory_size, k, **kwargs):
        super().__init__(**kwargs)
        self.memory_size = memory_size
        self.k = k

    def build(self, input_shape):
        # Set up the memory_var
        # Shape of input is [(1,3,6), (1,3)]
        ...

    def call(self, input):
        for i in range(3):
            statement = input[0][0, i]
            cluster = input[1][0, i]
            old_sub_mem = self.memory_var[cluster, :-1]  # Error here
            # Here should be a bunch of stuff I removed because it's not relevant
        return tf.expand_dims(self.memory_var, axis=0)
```
I get a `TypeError` saying that `<tf.Tensor 'memory_layer_19/strided_slice_2:0' shape=() dtype=float32>` isn't a valid index. I tried calling `.numpy()` on `input[1]`, but this doesn't work as the tensor has no shape. From the data I input, `cluster` should be a single number.
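The error mirrors an ordinary Python indexing rule: a float is not a valid index, even when it holds a whole number. A minimal stdlib-only sketch (no TensorFlow needed) of the same failure:

```python
# Plain Python enforces the same constraint TensorFlow does:
# indexing with a float raises a TypeError.
memory = [[1, 2, 3], [4, 5, 6]]
cluster = 1.0  # behaves like a float32 scalar tensor

try:
    row = memory[cluster]      # floats are not valid indices
except TypeError as e:
    error_name = type(e).__name__

print(error_name)  # TypeError
```

Converting `cluster` to an integer first (`memory[int(cluster)]`) succeeds, which is exactly what the cast-based fix below does at the tensor level.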
Solution
By default, inputs to layers are `tf.float32`. However, to index a tensor, you need integers. You can either cast the inputs of your layer to integers, or specify that the input of that layer should be of an integer type.
With casting

```python
cluster = tf.cast(input[1], dtype=tf.int32)[0, i]
old_sub_mem = self.memory_var[cluster, :-1]
```
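As a rough NumPy analogue of the cast-then-index fix (assumption: `np.ndarray` stands in for the layer's tensors and `astype` for `tf.cast`), the same pattern looks like this:

```python
import numpy as np

memory_var = np.arange(20.0).reshape(4, 5)   # stand-in for self.memory_var
inputs = np.array([[2.0, 0.0, 3.0]])         # float "cluster" input, shape (1, 3)

i = 0
cluster = inputs.astype(np.int32)[0, i]      # cast first, then index
old_sub_mem = memory_var[cluster, :-1]       # integer index is now valid

print(old_sub_mem)  # row 2 without its last column: [10. 11. 12. 13.]
```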
Specifying the type of the input
This example uses the functional API:

```python
inp_statement = tf.keras.Input(shape=(3, 6))
inp_cluster = tf.keras.Input(shape=(3,), dtype=tf.int32)
memory = MemoryLayer(memory_size=10, k=2)([inp_statement, inp_cluster])  # k value is illustrative
```
Note: I don't exactly understand what you are trying to achieve, but this `for` loop could probably be optimized away with a call to `tf.gather` or `tf.gather_nd`.
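To illustrate the `tf.gather` suggestion, here is a NumPy sketch (assumed analogue: fancy indexing plays the role of `tf.gather`) that pulls all addressed memory rows in one vectorized step instead of looping:

```python
import numpy as np

memory_var = np.arange(20.0).reshape(4, 5)   # stand-in for self.memory_var
clusters = np.array([2, 0, 3])               # integer indices, shape (3,)

# One gather replaces indexing memory_var row by row in a loop;
# equivalent to tf.gather(memory_var, clusters) in TensorFlow.
gathered = memory_var[clusters]

print(gathered.shape)  # (3, 5)
```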
Answered By – Lescurel
This answer, collected from Stack Overflow, is licensed under CC BY-SA 2.5, CC BY-SA 3.0, and CC BY-SA 4.0.