Set the next i tensor values to 0 every time a value of 1 appears in a tensor

Issue

I have two tensors with the same size:

a = [1, 2, 3, 4, 5, 10, 11, 12, 13, 20, 21, 22, 23, 24, 25, 26, 27, 28]
b = [0, 1, 1, 1, 1,  1,  1,  1,  0,  1,  1,  1,  1,  1,  0,  0,  0,  1]

Tensor a has three regions of consecutive values: region 1 is [1, 2, 3, 4, 5], region 2 is [10, 11, 12, 13] and region 3 is [20, 21, 22, 23, 24, 25, 26, 27, 28].

For each of those regions, I want to apply the following logic: whenever a value of b is 1, the following i values of b are set to 0 (values that are already 0 stay 0), without going past the end of the region. After those i values have been changed, nothing happens until the next value of b that is still 1, which again forces the following i values to 0, and so on.

Some examples:

# i = 1

a     = [1, 2, 3, 4, 5, 10, 11, 12, 13, 20, 21, 22, 23, 24, 25, 26, 27, 28]
b_new = [0, 1, 0, 1, 0,  1,  0,  1,  0,  1,  0,  1,  0,  1,  0,  0,  0,  1]


# i = 2

a     = [1, 2, 3, 4, 5, 10, 11, 12, 13, 20, 21, 22, 23, 24, 25, 26, 27, 28]
b_new = [0, 1, 0, 0, 1,  1,  0,  0,  0,  1,  0,  0,  1,  0,  0,  0,  0,  1]


# i = 4

a     = [1, 2, 3, 4, 5, 10, 11, 12, 13, 20, 21, 22, 23, 24, 25, 26, 27, 28]
b_new = [0, 1, 0, 0, 0,  1,  0,  0,  0,  1,  0,  0,  0,  0,  0,  0,  0,  1]
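The rule can be pinned down with a plain-Python reference that reproduces the three examples above. This is a sketch, not TensorFlow: it assumes a region is a maximal run in which each value of a is the previous one plus 1, and the helper name zero_after_ones is made up for illustration. A single counter of pending zeros, reset at every region boundary, replaces the explicit "zero the next i values" step.

```python
def zero_after_ones(a, b, i):
    """Hypothetical reference helper: within each run of consecutive values
    in `a`, every 1 in `b` that is still 1 zeroes the next `i` entries of `b`,
    never crossing into the next run."""
    out = list(b)
    remaining = 0  # how many more positions must still be forced to 0
    for t in range(len(out)):
        # A new region starts whenever a[t] does not continue a[t - 1] + 1.
        if t > 0 and a[t] != a[t - 1] + 1:
            remaining = 0
        if remaining > 0:
            out[t] = 0          # inside the window of an earlier 1
            remaining -= 1
        elif out[t] == 1:
            remaining = i       # this 1 survives and opens a new window
    return out


a = [1, 2, 3, 4, 5, 10, 11, 12, 13, 20, 21, 22, 23, 24, 25, 26, 27, 28]
b = [0, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 0, 0, 0, 1]
for i in (1, 2, 4):
    print(i, zero_after_ones(a, b, i))
```

Every position is visited once, so this runs in O(len(b)) time; it is mainly useful as ground truth when testing a vectorized TensorFlow version.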

Not sure if this would help, but I was able to separate the regions into segments by doing:

a_shifted = tf.roll(a - 1, shift=-1, axis=0)
a_shifted_segs = tf.math.cumsum(tf.cast(a_shifted != a, dtype=tf.int64), exclusive=True)

# a_shifted_segs = [0, 0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2]
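To see what those two TensorFlow lines compute, the same steps can be mirrored in plain Python. This is only an illustration: the list comprehension stands in for tf.roll and itertools.accumulate for the exclusive tf.math.cumsum.

```python
from itertools import accumulate

a = [1, 2, 3, 4, 5, 10, 11, 12, 13, 20, 21, 22, 23, 24, 25, 26, 27, 28]
n = len(a)

# Like tf.roll(a - 1, shift=-1, axis=0): position t holds a[t + 1] - 1,
# and the last position wraps around to a[0] - 1.
a_shifted = [a[(t + 1) % n] - 1 for t in range(n)]

# 1 where the next value does not continue the run (a[t + 1] != a[t] + 1).
breaks = [int(s != x) for s, x in zip(a_shifted, a)]

# Exclusive cumulative sum: the segment id of t counts the breaks strictly
# before t, so the wrapped-around comparison at the last position never counts.
a_shifted_segs = [0] + list(accumulate(breaks))[:-1]
print(a_shifted_segs)  # [0, 0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2]
```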

Do you know any way of doing this efficiently?

Solution

Here is a pure TensorFlow approach, which will work in both eager execution and graph mode:

import tensorflow as tf

def split_regions_and_modify(a, b, i):
  # Start index of each region after the first: a[t + 1] does not continue a[t] + 1.
  indices = tf.squeeze(tf.where(a[:-1] != a[1:] - 1), axis=-1) + 1
  # Append len(a) so that row_splits[k] is the exclusive end of region k.
  # Both parts are int64 (tf.cond branches with different dtypes fail in graph mode),
  # and with no boundaries at all (a single region) the result is simply [len(a)].
  n = tf.shape(a, out_type=tf.int64)[0]
  row_splits = tf.cast(tf.concat([indices, n[None]], axis=0), dtype=tf.int32)

  def body(i, j, k, tensor, row_splits):
    # Advance to the next region once position j reaches the end of region k.
    k = tf.cond(tf.equal(row_splits[k], j), lambda: tf.add(k, 1), lambda: k)
    # Positions j+1 .. j+i, clipped so they never cross into the next region.
    current_indices = tf.range(j + 1, tf.minimum(j + 1 + i, row_splits[k]), dtype=tf.int32)

    # If tensor[j] is still 1 (it was not zeroed by an earlier 1), zero those positions.
    tensor = tf.cond(
        tf.logical_and(tf.equal(tensor[j], 1), tf.not_equal(j, row_splits[k])),
        lambda: tf.tensor_scatter_nd_update(
            tensor, current_indices[..., None],
            tf.zeros_like(current_indices, dtype=tensor.dtype)),
        lambda: tensor)
    return i, tf.add(j, 1), k, tensor, row_splits

  j0 = tf.constant(0)
  k0 = tf.constant(0)
  # Visit every position j while there is still a region left.
  c = lambda i, j, k, tensor, row_splits: tf.logical_and(
      tf.less(j, tf.shape(tensor)[0]), tf.less(k, tf.shape(row_splits)[0]))
  _, _, _, output, _ = tf.while_loop(c, body, loop_vars=[i, j0, k0, b, row_splits])
  return output
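The region bookkeeping in the function can be traced by hand. The following plain-Python sketch (an illustration, not part of the answer) mirrors how indices and row_splits are built for the example a: row_splits ends up holding the exclusive end index of each region, and the loop variable k indexes into it, moving on to the next region whenever j reaches row_splits[k].

```python
a = [1, 2, 3, 4, 5, 10, 11, 12, 13, 20, 21, 22, 23, 24, 25, 26, 27, 28]
n = len(a)

# Same test as a[:-1] != a[1:] - 1; adding 1 turns "last index of a region"
# into "first index of the next region".
indices = [t + 1 for t in range(n - 1) if a[t] != a[t + 1] - 1]

# Appending len(a) closes the last region: row_splits[k] is the exclusive end of region k.
row_splits = indices + [n]

# Region k covers positions starts[k] .. row_splits[k] - 1.
starts = [0] + row_splits[:-1]
regions = [a[s:e] for s, e in zip(starts, row_splits)]
print(row_splits)  # [5, 9, 18]
print(regions)
```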

Usage:

a = tf.constant([1, 2, 3, 4, 5, 10, 11, 12, 13, 20, 21, 22, 23, 24, 25, 26, 27, 28])
b = tf.constant([0, 1, 1, 1, 1,  1,  1,  1,  0,  1,  1,  1,  1,  1,  0,  0,  0,  1])

split_regions_and_modify(a, b, 1)
# <tf.Tensor: shape=(18,), dtype=int32, numpy=array([0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 0, 0, 1], dtype=int32)>

split_regions_and_modify(a, b, 2)
# <tf.Tensor: shape=(18,), dtype=int32, numpy=array([0, 1, 0, 0, 1, 1, 0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0, 1], dtype=int32)>

split_regions_and_modify(a, b, 4)
# <tf.Tensor: shape=(18,), dtype=int32, numpy=array([0, 1, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1], dtype=int32)>

Answered By – AloneTogether

This answer, collected from Stack Overflow, is licensed under CC BY-SA 2.5, CC BY-SA 3.0 and CC BY-SA 4.0.
