TensorFlow: filter out batch elements that contain a zero

Issue

I have batched tensors X and Y like this:

X = tf.constant([[[1,-2], [2,0],  [-2,2], [4,-1]],
                 [[3,1],  [4,1],  [0,1], [-5,3]],
                 [[5,-4], [6,-2], [-2,1], [-2,2]]], dtype=tf.float16)
Y = tf.constant([[1], [43], [2]])

In reality, X has shape TensorShape([512, 30, 57]).

I want to drop every element along dimension 0 that has a zero as the first entry (along dimension 2) of any of its rows, together with the matching element of Y. In the example above, the second element of X contains the row [0,1], so it should be removed, giving:

X = tf.constant([[[1,-2], [2,0],  [-2,2], [4,-1]],
                 [[5,-4], [6,-2], [-2,1], [-2,2]]], dtype=tf.float16)
Y = tf.constant([[1], [2]])

For now, I have the following code:

idx = [k for k, v in enumerate(X) if 0 not in v[:, 0]]
X_clean = [X[x, :, :] for x in idx]
X_clean = tf.stack(X_clean)
Y_clean = tf.stack([Y[x] for x in idx])

This is very slow: each iteration takes about 2 seconds. How can I make it faster?

Solution

You can achieve a more efficient solution using tf.where, tf.reduce_all and tf.gather:

# Get the batch indices of the valid elements:
# X[...,0]!=0 checks that the first entry along the last dimension is not 0,
# reduce_all checks that this is true for every row along dimension 1,
# where gives the indices of the elements that pass,
# squeeze(axis=1) flattens them to a 1-D index vector, even if only one element passes.
valid_element_idxs = tf.squeeze(tf.where(tf.reduce_all(X[...,0]!=0,axis=-1)), axis=1)
X_clean = tf.gather(X, valid_element_idxs)
Y_clean = tf.gather(Y, valid_element_idxs)
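
The same filter can also be written with a boolean mask instead of explicit indices. The following is a minimal sketch, not part of the original answer, that applies tf.boolean_mask to the example tensors from the question; the variable name keep is only illustrative. Because no index vector is squeezed, the batch dimension is preserved even if a single element survives.

```python
import tensorflow as tf

# Example tensors from the question.
X = tf.constant([[[1, -2], [2, 0], [-2, 2], [4, -1]],
                 [[3, 1], [4, 1], [0, 1], [-5, 3]],
                 [[5, -4], [6, -2], [-2, 1], [-2, 2]]], dtype=tf.float16)
Y = tf.constant([[1], [43], [2]])

# keep[i] is True when no row of X[i] starts with a zero.
keep = tf.reduce_all(tf.not_equal(X[..., 0], 0), axis=-1)

# Select the surviving batch elements from both tensors.
X_clean = tf.boolean_mask(X, keep)
Y_clean = tf.boolean_mask(Y, keep)

print(X_clean.shape)    # (2, 4, 2)
print(Y_clean.numpy())  # [[1] [2]]
```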

Comparing your approach with this one using %timeit on the two small tensors you gave as an example:

>>> %timeit list_comp(X,Y)
2.82 ms ± 15.4 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
>>> %timeit tf_native(X,Y)
263 µs ± 2.19 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)

You can squeeze out a bit more performance using tf.function:

>>> %timeit tf_native_decorated(X,Y)
206 µs ± 6.31 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)

Function definitions for reference:

def list_comp(X,Y):
    idx = [k for k, v in enumerate(X) if 0 not in v[:, 0]]
    X_clean = [X[x, :, :] for x in idx]
    X_clean = tf.stack(X_clean)
    Y_clean = tf.stack([Y[x] for x in idx])
    return X_clean, Y_clean

def tf_native(X,Y):
    # squeeze only axis 1 so the index stays 1-D even when a single element is valid
    valid_elements_idx = tf.squeeze(tf.where(tf.reduce_all(X[...,0]!=0,axis=-1)), axis=1)
    X_clean = tf.gather(X, valid_elements_idx)
    Y_clean = tf.gather(Y, valid_elements_idx)
    return X_clean, Y_clean

@tf.function
def tf_native_decorated(X,Y):
    # squeeze only axis 1 so the index stays 1-D even when a single element is valid
    valid_elements_idx = tf.squeeze(tf.where(tf.reduce_all(X[...,0]!=0,axis=-1)), axis=1)
    X_clean = tf.gather(X, valid_elements_idx)
    Y_clean = tf.gather(Y, valid_elements_idx)
    return X_clean, Y_clean
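
If the batch size changes between calls, a function decorated with a bare tf.function may be retraced for each new input shape. The following is a minimal sketch, not part of the original answer, that declares an input_signature with unknown dimensions so that one trace can serve different batch sizes. The name filter_batch and the chosen signature (float16 X of rank 3, int32 Y of rank 2) are assumptions based on the example tensors in the question.

```python
import tensorflow as tf

# Assumed signature: X is float16 with rank 3, Y is int32 with rank 2.
# None leaves a dimension unspecified, so the batch size may vary.
@tf.function(input_signature=[
    tf.TensorSpec(shape=[None, None, None], dtype=tf.float16),
    tf.TensorSpec(shape=[None, None], dtype=tf.int32),
])
def filter_batch(X, Y):  # hypothetical name, for illustration only
    # True for batch elements in which no row starts with a zero.
    keep = tf.reduce_all(tf.not_equal(X[..., 0], 0), axis=-1)
    # tf.where returns indices of shape (n, 1); drop only the last axis.
    idx = tf.squeeze(tf.where(keep), axis=1)
    return tf.gather(X, idx), tf.gather(Y, idx)

X = tf.constant([[[1, -2], [2, 0], [-2, 2], [4, -1]],
                 [[3, 1], [4, 1], [0, 1], [-5, 3]],
                 [[5, -4], [6, -2], [-2, 1], [-2, 2]]], dtype=tf.float16)
Y = tf.constant([[1], [43], [2]])

# Full batch of 3: the second element is dropped.
X_clean, Y_clean = filter_batch(X, Y)

# Smaller batch of 2: only the first element survives, and the
# batch dimension is still present in the result.
X_small, Y_small = filter_batch(X[:2], Y[:2])
```

Both calls go through the same declared signature, so the second batch size does not require a new specialization of the function.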

Answered By – Lescurel

This answer, collected from Stack Overflow, is licensed under CC BY-SA 2.5, CC BY-SA 3.0 and CC BY-SA 4.0.
