If/else statement executes both branches in tf.function

Issue

I want to write a function in Python, using TensorFlow, that accepts both floats and vectors as input. I defined the following function:

import tensorflow as tf

def g(t):
    if tf.rank(t) == 0:
        print('Rank=0')
        return tf.math.reduce_sum(tf.math.exp(t))
    else:
        print('Rank=higher')
        return tf.math.reduce_sum(tf.math.exp(t),1)

I also want to call this function from inside another tf.function. As a test, I wrote the following:

@tf.function
def Test(t):
    return g(t)

Calling g(0.5) gives:

Rank=0
Out[218]: <tf.Tensor: shape=(), dtype=float32, numpy=2.7182817>

Calling Test(0.5) gives:

rank=0
rank=higher
Traceback (most recent call last):

  Input In [219] in <cell line: 1>
    Test(0.5)

  File ~\Anaconda3\lib\site-packages\tensorflow\python\util\traceback_utils.py:153 in error_handler
    raise e.with_traceback(filtered_tb) from None

  File ~\AppData\Local\Temp\__autograph_generated_filegb02ol08.py:12 in tf__Test
    retval_ = ag__.converted_call(ag__.ld(gn), (ag__.ld(t),), None, fscope)

  File ~\AppData\Local\Temp\__autograph_generated_filegnzfdu42.py:37 in tf__gn
    ag__.if_stmt(ag__.converted_call(ag__.ld(int), (ag__.converted_call(ag__.ld(tf).rank, (ag__.ld(t),), None, fscope),), None, fscope) == 0, if_body, else_body, get_state, set_state, ('do_return', 'retval_'), 2)

  File ~\AppData\Local\Temp\__autograph_generated_filegnzfdu42.py:33 in else_body
    retval_ = ag__.ld(V0) + ag__.ld(labda) * ag__.ld(theta) * ag__.converted_call(ag__.ld(tf).math.reduce_sum, (ag__.ld(c) / ag__.ld(gamma) * (1 - ag__.converted_call(ag__.ld(tf).math.exp, (-ag__.ld(gamma) * ag__.ld(t),), None, fscope)), 1), None, fscope)

ValueError: in user code:

    File "C:\Users\jgrou\AppData\Local\Temp\ipykernel_11872\3135092574.py", line 11, in Test  *
        return gn(t)
    File "C:\Users\jgrou\AppData\Local\Temp\ipykernel_11872\3135092574.py", line 7, in gn  *
        return V0 + labda * theta * tf.math.reduce_sum(c / gamma * (1 - tf.math.exp(-gamma * t)),1)

    ValueError: Invalid reduction dimension 1 for input with 1 dimensions. for '{{node cond/Sum}} = Sum[T=DT_FLOAT, Tidx=DT_INT32, keep_dims=false](cond/mul_1, cond/Sum/reduction_indices)' with input shapes: [1], [] and with computed input tensors: input[1] = <1>.

Why do both branches of the if/else statement get executed inside the tf.function? And how can I make g work inside a tf.function?

Solution

This behavior was raised in a fairly recent GitHub issue. Here is the response from one of the TensorFlow developers before the issue was closed:

The cause of this problem is due to the behavior of condition tracing in TensorFlow: the same input is applied to both true and false sides for graph tracing, when the condition is based on a non-static value (i.e. tf.rank(v) == 2).
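To see the behavior the developer describes, here is a minimal sketch, assuming TensorFlow 2.x. The names both_valid, naive_g and traced_branches are hypothetical and used only for illustration. A Python list records which branch bodies run while the graph is traced, and a copy of the question's g reproduces the ValueError for a scalar input:

```python
import tensorflow as tf

# Hypothetical log: Python side effects run only while the graph is traced.
traced_branches = []

@tf.function
def both_valid(t):
    # tf.rank(t) == 0 is a symbolic Tensor, so AutoGraph rewrites this
    # `if` into tf.cond, which traces BOTH branches into the graph.
    if tf.rank(t) == 0:
        traced_branches.append("rank0")
        out = t * 2.0
    else:
        traced_branches.append("higher")
        out = t * 3.0
    return out

@tf.function
def naive_g(t):
    # Same structure as the question's g: the else branch is invalid for a
    # scalar, and because it is still traced, graph construction fails.
    if tf.rank(t) == 0:
        return tf.math.reduce_sum(tf.math.exp(t))
    else:
        return tf.math.reduce_sum(tf.math.exp(t), 1)

result = both_valid(tf.constant(0.5))
print(traced_branches)  # both branch names appear
print(float(result))    # only the true branch's value is used at run time

failed = False
try:
    naive_g(tf.constant(0.5))
except ValueError as err:
    failed = True
    print("ValueError:", err)
```

Because the condition is a Tensor, tf.cond has to build a graph for each branch and only selects between their results at run time. That is why the else branch must be valid for a scalar input even though its value is never used.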

There are two viable solutions.

Use Constant Value

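Wrapping the 0-D Tensor returned by tf.rank in tf.get_static_value prevents the condition from being traced as a graph op. When the value can be computed at trace time, tf.get_static_value evaluates the Tensor and returns it as a NumPy value (an int, float, array, etc., depending on its shape and dtype), so the if statement becomes an ordinary Python conditional and only the branch that is actually taken gets traced. If the value cannot be determined statically, tf.get_static_value returns None, and the comparison with 0 then falls through to the else branch.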

import tensorflow as tf

def g(t):
    # get_static_value returns a NumPy int here, so this is a plain Python if
    if tf.get_static_value(tf.rank(t)) == 0:
        print('Rank=0')
        return tf.math.reduce_sum(tf.math.exp(t))
    else:
        print('Rank=higher')
        return tf.math.reduce_sum(tf.math.exp(t), 1)

Calling g(0.5) and Test(0.5) now gives the expected results:

Rank=0
tf.Tensor(1.6487212, shape=(), dtype=float32)
Rank=0
tf.Tensor(1.6487212, shape=(), dtype=float32)
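A minimal sketch of this fix inside a tf.function, assuming TensorFlow 2.x. The trace_log list is a hypothetical stand-in for the print calls, so the traced branches can be checked. Each trace takes only the branch that matches the input's rank, and a new input shape triggers a retrace that takes the other branch:

```python
import tensorflow as tf

# Hypothetical log replacing print(), recording which branch each trace takes.
trace_log = []

def g(t):
    # tf.get_static_value yields a NumPy int when the rank is known at trace
    # time (None otherwise), so AutoGraph leaves this as a Python `if`.
    if tf.get_static_value(tf.rank(t)) == 0:
        trace_log.append("Rank=0")
        return tf.math.reduce_sum(tf.math.exp(t))
    else:
        trace_log.append("Rank=higher")
        return tf.math.reduce_sum(tf.math.exp(t), 1)

@tf.function
def Test(t):
    return g(t)

scalar_out = Test(tf.constant(0.5))   # first trace: scalar input, Rank=0 only
matrix_out = Test(tf.zeros([2, 3]))   # new shape, so a retrace: Rank=higher only
print(trace_log)
print(scalar_out.numpy(), matrix_out.numpy())
```

Each trace appends exactly one entry, which confirms that the untaken branch is never built into the graph.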

Direct Shape Evaluation

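Rather than calling tf.rank, read the rank directly from the tensor's static shape. t.shape.ndims is a plain Python int at trace time, so the if statement again stays an ordinary Python conditional. Because a Python float has no shape attribute, any non-Tensor input must first be converted to a Tensor: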

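import tensorflow as tf

def g(t):
    # Python floats have no .shape, so convert non-Tensor inputs first
    if not isinstance(t, tf.Tensor):
        t = tf.convert_to_tensor(t)
    # t.shape.ndims is the static rank: a plain Python int at trace time
    if t.shape.ndims == 0: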
        print('Rank=0')
        return tf.math.reduce_sum(tf.math.exp(t))
    else:
        print('Rank=higher')
        return tf.math.reduce_sum(tf.math.exp(t), 1)

This implementation also yields the expected results:

Rank=0
tf.Tensor(1.6487212, shape=(), dtype=float32)
Rank=0
tf.Tensor(1.6487212, shape=(), dtype=float32)
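A small sketch of what t.shape.ndims returns at trace time, assuming TensorFlow 2.x. static_rank, record and ranks_seen are hypothetical names used only for this illustration. Both a Python float and a rank-2 Tensor yield a plain Python int, which is why the if in g never becomes a tf.cond:

```python
import tensorflow as tf

# Hypothetical log of the static rank observed while each graph is traced.
ranks_seen = []

def static_rank(x):
    # Python floats have no .shape, so convert first; for a tf.Tensor this
    # conversion is effectively a no-op.
    t = tf.convert_to_tensor(x)
    # TensorShape.ndims is a Python int (None only if the rank is unknown).
    return t.shape.ndims

@tf.function
def record(x):
    ranks_seen.append(static_rank(x))
    return tf.identity(tf.convert_to_tensor(x))

record(0.5)                        # Python float argument
record(tf.constant([[1.0, 2.0]]))  # rank-2 Tensor argument
print(ranks_seen)
```

If a tf.function is given an input_signature with an unknown rank (for example tf.TensorSpec(shape=None)), ndims is None at trace time, so this version, like the tf.get_static_value one, falls through to the else branch.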

Answered By – danielcahall

This answer was collected from Stack Overflow and is licensed under CC BY-SA 2.5, CC BY-SA 3.0 and CC BY-SA 4.0.
