# Zero out the next i values of a tensor every time a 1 appears in it

## Issue

I have two tensors with the same size:

```python
a = [1, 2, 3, 4, 5, 10, 11, 12, 13, 20, 21, 22, 23, 24, 25, 26, 27, 28]
b = [0, 1, 1, 1, 1,  1,  1,  1,  0,  1,  1,  1,  1,  1,  0,  0,  0,  1]
```

Tensor `a` has three regions, delimited by runs of consecutive values: region 1 is `[1, 2, 3, 4, 5]`, region 2 is `[10, 11, 12, 13]`, and region 3 is `[20, 21, 22, 23, 24, 25, 26, 27, 28]`.

For each of those regions, I want to apply the following logic: whenever a value of `b` is 1, the next `i` values of `b` in the same region are set to 0 (values that are already 0 stay 0). Once those `i` values have been changed, nothing happens until another value of `b` is 1, at which point the next `i` values are forced to 0 again, and so on.
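
To pin the rule down, here is a minimal pure-Python sketch of the intended behaviour. It is not TensorFlow and not part of either the question or the answer; the name `modify_reference` is hypothetical, and it assumes a region is a maximal run in which each value of `a` is exactly one more than the previous one. Note that it reads the already-modified values, so a 1 that gets zeroed does not trigger anything itself.

```python
def modify_reference(a, b, i):
    """Hypothetical reference: within each region of consecutive values in `a`,
    every surviving 1 in `b` forces the next `i` entries of that region to 0."""
    out = list(b)  # work on a copy so the input list is left untouched
    n = len(a)
    start = 0
    while start < n:
        # Find the exclusive end of the current run of consecutive values.
        end = start + 1
        while end < n and a[end] == a[end - 1] + 1:
            end += 1
        # Scan the region left to right, using the values modified so far.
        for j in range(start, end):
            if out[j] == 1:
                for t in range(j + 1, min(j + 1 + i, end)):
                    out[t] = 0
        start = end
    return out


a = [1, 2, 3, 4, 5, 10, 11, 12, 13, 20, 21, 22, 23, 24, 25, 26, 27, 28]
b = [0, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 0, 0, 0, 1]
print(modify_reference(a, b, 2))
# [0, 1, 0, 0, 1, 1, 0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0, 1]
```

Clipping the range at `end` is what keeps a 1 near the end of one region from zeroing the start of the next.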

Some examples:

```python
# i = 1

a     = [1, 2, 3, 4, 5, 10, 11, 12, 13, 20, 21, 22, 23, 24, 25, 26, 27, 28]
b_new = [0, 1, 0, 1, 0,  1,  0,  1,  0,  1,  0,  1,  0,  1,  0,  0,  0,  1]

# i = 2

a     = [1, 2, 3, 4, 5, 10, 11, 12, 13, 20, 21, 22, 23, 24, 25, 26, 27, 28]
b_new = [0, 1, 0, 0, 1,  1,  0,  0,  0,  1,  0,  0,  1,  0,  0,  0,  0,  1]

# i = 4

a     = [1, 2, 3, 4, 5, 10, 11, 12, 13, 20, 21, 22, 23, 24, 25, 26, 27, 28]
b_new = [0, 1, 0, 0, 0,  1,  0,  0,  0,  1,  0,  0,  0,  0,  0,  0,  0,  1]
```

Not sure if this would help, but I was able to separate the regions into segments by doing:

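```python
import tensorflow as tf

a = tf.constant([1, 2, 3, 4, 5, 10, 11, 12, 13, 20, 21, 22, 23, 24, 25, 26, 27, 28])

# Element j of a_shifted is a[j + 1] - 1 (wrapping around at the end),
# so a_shifted != a is True exactly at the last element of each region.
a_shifted = tf.roll(a - 1, shift=-1, axis=0)
a_shifted_segs = tf.math.cumsum(tf.cast(a_shifted != a, dtype=tf.int64), exclusive=True)

# a_shifted_segs = [0, 0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2]
```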
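
The same segment ids can be reproduced without TensorFlow, which makes it easier to see why the roll plus exclusive cumsum trick works. This is a hypothetical plain-Python helper (`segment_ids` is not part of any library), written only to mirror the two TensorFlow calls above, assuming `a` is a flat list of integers.

```python
def segment_ids(a):
    """Hypothetical plain-Python mirror of tf.roll + exclusive tf.math.cumsum."""
    n = len(a)
    # Like tf.roll(a - 1, shift=-1): element j holds a[j + 1] - 1, wrapping around.
    shifted = [a[(j + 1) % n] - 1 for j in range(n)]
    # 1 marks the last element of a region (the next value is not a[j] + 1).
    is_end = [int(s != x) for s, x in zip(shifted, a)]
    # Exclusive cumulative sum: the id at j counts region ends strictly before j,
    # so the wrapped-around comparison at the last index never affects the result.
    ids, total = [], 0
    for flag in is_end:
        ids.append(total)
        total += flag
    return ids


print(segment_ids([1, 2, 5, 6, 9]))
# [0, 0, 1, 1, 2]
```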

Do you know any way of doing this efficiently?

## Solution

Here is a pure TensorFlow approach, which works both in eager execution and in graph mode (for example inside `tf.function`):

```python
import tensorflow as tf

def split_regions_and_modify(a, b, i):
    # Start index of every region after the first: positions where a[j + 1] != a[j] + 1.
    indices = tf.squeeze(tf.where(a[:-1] != a[1:] - 1), axis=-1) + 1
    # Exclusive end index of every region, e.g. [5, 9, 18] for the example below.
    # Appending len(a) also handles the single-region case (empty `indices`).
    row_splits = tf.concat([tf.cast(indices, tf.int32), tf.shape(a)], axis=0)

    def body(i, j, k, tensor, row_splits):
        # Move on to the next region once j reaches the end of the current one.
        k = tf.cond(tf.equal(row_splits[k], j), lambda: tf.add(k, 1), lambda: k)
        # Up to i positions after j, clipped to the end of the current region.
        current_indices = tf.range(j + 1, tf.minimum(j + 1 + i, row_splits[k]), dtype=tf.int32)

        # If the (already updated) value at j is 1, force those positions to 0.
        tensor = tf.cond(
            tf.equal(tensor[j], 1),
            lambda: tf.tensor_scatter_nd_update(
                tensor, current_indices[..., None],
                tf.zeros_like(current_indices, dtype=tensor.dtype)),
            lambda: tensor)
        return i, tf.add(j, 1), k, tensor, row_splits

    j0 = tf.constant(0)
    k0 = tf.constant(0)
    # Visit every position j of b, left to right.
    loop_cond = lambda i, j, k, tensor, row_splits: tf.logical_and(
        tf.less(j, tf.shape(tensor)[0]), tf.less(k, tf.shape(row_splits)[0]))
    _, _, _, output, _ = tf.while_loop(loop_cond, body, loop_vars=[i, j0, k0, b, row_splits])
    return output
```
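
The while loop carries the whole tensor plus two indices as state. As a sketch under the same region definition, the only state really needed is a countdown of how many positions are still to be forced to 0, reset at every region boundary. `modify_countdown` is a hypothetical name and this is plain Python, not TensorFlow; a single-counter recurrence like this is the kind of computation `tf.scan` is built for, but that port is not shown or tested here.

```python
def modify_countdown(a, b, i):
    """Hypothetical single-pass version of the same rule using a countdown."""
    out = []
    remaining = 0  # how many upcoming positions must still be forced to 0
    for j, (value, flag) in enumerate(zip(a, b)):
        # A new region starts when the value is not the previous value + 1.
        if j > 0 and value != a[j - 1] + 1:
            remaining = 0
        if remaining > 0:
            # Inside the window opened by an earlier 1: force 0, ignore b here.
            out.append(0)
            remaining -= 1
        else:
            out.append(flag)
            if flag == 1:
                remaining = i  # open a new window of i positions
    return out
```

Because each position is visited once and the state is a single integer, this runs in O(len(a)) time, which may be a useful target when looking for a more efficient TensorFlow formulation.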

Usage:

```python
a = tf.constant([1, 2, 3, 4, 5, 10, 11, 12, 13, 20, 21, 22, 23, 24, 25, 26, 27, 28])
b = tf.constant([0, 1, 1, 1, 1,  1,  1,  1,  0,  1,  1,  1,  1,  1,  0,  0,  0,  1])

split_regions_and_modify(a, b, 1)
# <tf.Tensor: shape=(18,), dtype=int32, numpy=array([0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 0, 0, 1], dtype=int32)>

split_regions_and_modify(a, b, 2)
# <tf.Tensor: shape=(18,), dtype=int32, numpy=array([0, 1, 0, 0, 1, 1, 0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0, 1], dtype=int32)>

split_regions_and_modify(a, b, 4)
# <tf.Tensor: shape=(18,), dtype=int32, numpy=array([0, 1, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1], dtype=int32)>
```