How to select rows of a tensor based on a condition (TensorFlow)

Issue

Only using tensorflow, how can I select rows of a tensor that satisfy a condition?

Example tensor x:

<tf.Tensor: shape=(3, 3), dtype=int32, numpy=
array([[0, 1, 2],
       [1, 1, 2],
       [0, 1, 4]], dtype=int32)>

I’d like to create a new tensor that only includes the rows of x whose first element equals 0.

Solution

import tensorflow as tf

x = tf.constant([[0, 1, 2], [1, 1, 2], [0, 1, 4]])
# Round-trip through NumPy (eager mode): keep rows whose first element is 0
x = tf.constant([row for row in x.numpy() if row[0] == 0])
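If you go through NumPy anyway, NumPy boolean indexing does the same filtering without a Python loop. The sketch below assumes eager execution (`.numpy()` is not available inside a `tf.function` graph); the names `x_np` and `filtered` are just for illustration. Unlike the list comprehension, it keeps the shape `(0, 3)` when no row matches, instead of producing an empty 1-D float tensor.

```python
import tensorflow as tf

x = tf.constant([[0, 1, 2], [1, 1, 2], [0, 1, 4]])

# Convert once to a NumPy array (eager mode only).
x_np = x.numpy()

# Boolean-index the rows whose first column equals 0, then wrap the
# result back into a tensor; dtype (int32) and column count are preserved.
filtered = tf.constant(x_np[x_np[:, 0] == 0])

print(filtered.numpy())
```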

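Or, using only TensorFlow operations (no NumPy round-trip):

a = tf.constant([[0, 1, 2], [1, 1, 2], [0, 1, 4]])
# Element-wise comparison on the first column gives one bool per row
mask = a[:, 0] == 0
# Keep only the rows where the mask is True
a = tf.boolean_mask(a, mask)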
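Because this version uses only TensorFlow ops, it also works inside a `tf.function`. Below is a minimal sketch that wraps it in a reusable function, plus an equivalent variant that first collects the indices of matching rows with single-argument `tf.where` and then picks them with `tf.gather`. The function names `select_rows` and `select_rows_gather` and the `value` parameter are hypothetical, chosen for this example.

```python
import tensorflow as tf


@tf.function
def select_rows(a, value):
    """Return the rows of the 2-D tensor `a` whose first element equals `value`."""
    # One bool per row: True where the first column equals `value`.
    mask = tf.equal(a[:, 0], value)
    return tf.boolean_mask(a, mask)


@tf.function
def select_rows_gather(a, value):
    """Same selection, done with row indices instead of a boolean mask."""
    # With a single argument, tf.where returns the coordinates of True
    # elements as an int64 tensor of shape (num_true, 1); take column 0.
    idx = tf.where(tf.equal(a[:, 0], value))[:, 0]
    return tf.gather(a, idx)


a = tf.constant([[0, 1, 2], [1, 1, 2], [0, 1, 4]])
print(select_rows(a, 0).numpy())         # [[0 1 2]
                                         #  [0 1 4]]
print(select_rows_gather(a, 1).numpy())  # [[1 1 2]]
```

`tf.boolean_mask` is the more direct choice; the index-based variant is handy when you also need the positions of the matching rows.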

Answered By – Djinn

This answer was collected from Stack Overflow and is licensed under CC BY-SA 2.5, CC BY-SA 3.0 and CC BY-SA 4.0.
