# TensorFlow: filter out tensors that contain a zero

## Issue

I have batch tensors of `X` and `Y` like this

``````
import tensorflow as tf

X = tf.constant([[[1, -2], [2, 0], [-2, 2], [4, -1]],
                 [[3, 1],  [4, 1], [0, 1],  [-5, 3]],   # <- the 0 in [0, 1] is the problem
                 [[5, -4], [6, -2], [-2, 1], [-2, 2]]], dtype=tf.float16)
Y = tf.constant([[1], [43], [2]])
``````

In reality, `X` has shape `TensorShape([512, 30, 57])`.

I want to remove every element along dimension 0 (every batch entry) in which any row has a zero as its first value along dimension 2 (see the zero marked in the second batch entry above). The expected result is:

``````
X = tf.constant([[[1, -2], [2, 0], [-2, 2], [4, -1]],
                 [[5, -4], [6, -2], [-2, 1], [-2, 2]]], dtype=tf.float16)
Y = tf.constant([[1], [2]])
``````

For now, I have the following code:

``````
# Indices of the batch entries whose first column contains no zero
idx = [k for k, v in enumerate(X) if 0 not in v[:, 0]]
X_clean = [X[x, :, :] for x in idx]
X_clean = tf.stack(X_clean)
Y_clean = tf.stack([Y[x] for x in idx])
``````

This is very slow: each iteration takes about 2 seconds. How can I make it faster?

## Solution

You can make this much faster by vectorizing it with `tf.where`, `tf.reduce_all` and `tf.gather`:

``````
# Get the indices of the valid batch entries:
# X[..., 0] != 0 checks that the first element along the last dimension is not 0
# reduce_all checks that this holds for every row along dimension 1
# where returns the indices of the entries for which it holds, with shape (n, 1)
# squeeze(axis=-1) flattens them to shape (n,), even when only one entry is valid
valid_element_idxs = tf.squeeze(tf.where(tf.reduce_all(X[..., 0] != 0, axis=-1)), axis=-1)
X_clean = tf.gather(X, valid_element_idxs)
Y_clean = tf.gather(Y, valid_element_idxs)
``````
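To see what each step computes, here is a small sketch that runs the same ops on the example tensors from the question and prints the intermediate values. The intermediate variable names (`nonzero_first`, `keep`, `idx_2d`, `idx`) are only for illustration:

```python
import tensorflow as tf

X = tf.constant([[[1, -2], [2, 0], [-2, 2], [4, -1]],
                 [[3, 1], [4, 1], [0, 1], [-5, 3]],
                 [[5, -4], [6, -2], [-2, 1], [-2, 2]]], dtype=tf.float16)
Y = tf.constant([[1], [43], [2]])

# Shape (3, 4): True wherever the first feature of a row is non-zero
nonzero_first = X[..., 0] != 0
# Shape (3,): True for batch entries whose rows are all non-zero in feature 0
keep = tf.reduce_all(nonzero_first, axis=-1)
# Shape (2, 1): int64 indices of the True entries, one row per match
idx_2d = tf.where(keep)
# Shape (2,): drop only the trailing axis, so a single match still gives a 1-D index
idx = tf.squeeze(idx_2d, axis=-1)

print(keep.numpy())                        # [ True False  True]
print(idx.numpy())                         # [0 2]
print(tf.gather(Y, idx).numpy().ravel())   # [1 2]
```

Passing `axis=-1` to `tf.squeeze` matters: without it, a batch with exactly one valid entry would yield a scalar index, and `tf.gather` would then drop the batch dimension.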

Comparing your approach with this one using `%timeit` on the two small example tensors from the question:

``````
>>> %timeit list_comp(X,Y)
2.82 ms ± 15.4 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
>>> %timeit tf_native(X,Y)
263 µs ± 2.19 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
``````
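If you are not in IPython, the standard-library `timeit` module gives a comparable measurement. This is a sketch: the two functions are condensed versions of the ones defined at the end of this answer, repeated so the block runs on its own, and the `repeat`/`number` values are arbitrary choices. Your timings will depend on your hardware and TensorFlow version.

```python
import timeit

import tensorflow as tf

X = tf.constant([[[1, -2], [2, 0], [-2, 2], [4, -1]],
                 [[3, 1], [4, 1], [0, 1], [-5, 3]],
                 [[5, -4], [6, -2], [-2, 1], [-2, 2]]], dtype=tf.float16)
Y = tf.constant([[1], [43], [2]])

def list_comp(X, Y):
    # Python loop over the batch, as in the question
    idx = [k for k, v in enumerate(X) if 0 not in v[:, 0]]
    return tf.stack([X[i] for i in idx]), tf.stack([Y[i] for i in idx])

def tf_native(X, Y):
    # Vectorized mask -> indices -> gather
    idx = tf.squeeze(tf.where(tf.reduce_all(X[..., 0] != 0, axis=-1)), axis=-1)
    return tf.gather(X, idx), tf.gather(Y, idx)

for fn in (list_comp, tf_native):
    # Best of 5 repeats of 100 calls each, reported per call
    best = min(timeit.repeat(lambda: fn(X, Y), repeat=5, number=100)) / 100
    print(f"{fn.__name__}: {best * 1e6:.1f} µs per call")
```

Taking the minimum over repeats, rather than the mean, reduces the influence of background noise on the comparison.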

You can squeeze out a bit more performance with `tf.function`:

``````
>>> %timeit tf_native_decorated(X,Y)
206 µs ± 6.31 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
``````
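One caveat with `tf.function`: by default it traces a graph per input shape, so batches of different sizes (for example a smaller final batch) may trigger retracing. A sketch that pins a shape-agnostic `input_signature`, assuming `X` is always a rank-3 `float16` tensor and `Y` a rank-2 `int32` tensor as in the question (adjust the specs to your real dtypes). The name `tf_native_fixed_sig` is hypothetical:

```python
import tensorflow as tf

@tf.function(input_signature=[
    tf.TensorSpec(shape=[None, None, None], dtype=tf.float16),  # X: (batch, rows, features)
    tf.TensorSpec(shape=[None, None], dtype=tf.int32),          # Y: (batch, 1)
])
def tf_native_fixed_sig(X, Y):
    # Keep batch entries whose rows all have a non-zero first feature
    keep = tf.reduce_all(X[..., 0] != 0, axis=-1)
    idx = tf.squeeze(tf.where(keep), axis=-1)
    return tf.gather(X, idx), tf.gather(Y, idx)

# Two different batch sizes are served by the same traced graph
xa_clean, _ = tf_native_fixed_sig(tf.ones([4, 30, 57], tf.float16), tf.zeros([4, 1], tf.int32))
xb_clean, _ = tf_native_fixed_sig(tf.ones([7, 30, 57], tf.float16), tf.zeros([7, 1], tf.int32))
```

The trade-off is that the signature rejects inputs with other dtypes or ranks instead of silently tracing a new graph for them.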

Function definitions for reference:

``````
import tensorflow as tf

def list_comp(X, Y):
    # Original approach: a Python loop over the batch
    idx = [k for k, v in enumerate(X) if 0 not in v[:, 0]]
    X_clean = [X[x, :, :] for x in idx]
    X_clean = tf.stack(X_clean)
    Y_clean = tf.stack([Y[x] for x in idx])
    return X_clean, Y_clean

def tf_native(X, Y):
    # Vectorized approach: boolean mask -> indices -> gather
    valid_elements_idx = tf.squeeze(tf.where(tf.reduce_all(X[..., 0] != 0, axis=-1)), axis=-1)
    X_clean = tf.gather(X, valid_elements_idx)
    Y_clean = tf.gather(Y, valid_elements_idx)
    return X_clean, Y_clean

@tf.function
def tf_native_decorated(X, Y):
    # Same as tf_native, but traced into a graph by tf.function
    valid_elements_idx = tf.squeeze(tf.where(tf.reduce_all(X[..., 0] != 0, axis=-1)), axis=-1)
    X_clean = tf.gather(X, valid_elements_idx)
    Y_clean = tf.gather(Y, valid_elements_idx)
    return X_clean, Y_clean
``````
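As an alternative (not benchmarked here), `tf.boolean_mask` takes the boolean mask directly and skips the `tf.where`/`tf.squeeze` step. Below is a sketch that checks it against the `tf.gather` version on random data with the question's real shape `[512, 30, 57]`. The random data, its value range and the function names `tf_boolean_mask`/`tf_gather` are assumptions for illustration:

```python
import numpy as np
import tensorflow as tf

def tf_boolean_mask(X, Y):
    # True for batch entries whose rows all have a non-zero first feature
    keep = tf.reduce_all(X[..., 0] != 0, axis=-1)
    # boolean_mask keeps only the batch entries where keep is True
    return tf.boolean_mask(X, keep), tf.boolean_mask(Y, keep)

def tf_gather(X, Y):
    idx = tf.squeeze(tf.where(tf.reduce_all(X[..., 0] != 0, axis=-1)), axis=-1)
    return tf.gather(X, idx), tf.gather(Y, idx)

rng = np.random.default_rng(0)
# Integers in [-50, 50] (exact in float16), so about 1 value in 101 is zero
X_np = rng.integers(-50, 51, size=(512, 30, 57)).astype(np.float16)
Y_np = rng.integers(0, 100, size=(512, 1)).astype(np.int32)
X, Y = tf.constant(X_np), tf.constant(Y_np)

xm, ym = tf_boolean_mask(X, Y)
xg, yg = tf_gather(X, Y)
print(xm.shape, ym.shape)  # number of kept entries depends on the random data
```

Both variants build the same mask; which one is faster on your data is worth measuring rather than assuming.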