Unable to get log_prob from a TransformedDistribution in tensorflow

Issue

I was following a tutorial on TransformedDistribution. In previous versions of TensorFlow Probability we could pass batch_shape=[2] and event_shape=[4] to TransformedDistribution, but that is no longer possible. How can I make the last line of the code below work without going back to an older version?

The error raised was:

ValueError: event_ndims must be at least 0. Saw: 1

Code:

import tensorflow as tf
import tensorflow_probability as tfp

tfd = tfp.distributions
tfb = tfp.bijectors

# Parameters
n = 10000
loc = 0
scale = 0.5

# Scalar Normal distribution (no batch or event dimensions)
normal = tfd.Normal(loc=loc, scale=scale)

# A batch of two 4x4 matrices; only the lower triangle is used
tril = tf.random.normal((2, 4, 4))
scale_low_tri = tf.linalg.LinearOperatorLowerTriangular(tril)

# Bijector y = L @ x, a matrix-vector product acting on length-4 vectors
scale_lin_op = tfb.ScaleMatvecLinearOperator(scale_low_tri)

# Intended: a transformed distribution with batch shape [2] and event shape [4]
mvn = tfd.TransformedDistribution(distribution=normal, bijector=scale_lin_op)

xn = normal.sample((n, 2, 4))

mvn.log_prob(xn)

Solution

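I believe that in version 0.12 you need to build the shapes into the base distribution with tfd.Sample: a Normal with two loc/scale values has batch shape [2], and sample_shape=[4] turns four i.i.d. draws into an event of shape [4]: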

mvn = tfd.TransformedDistribution(
    distribution=tfd.Sample(
        tfd.Normal(loc=[loc, loc], scale=[scale, scale]),  # --> batch shape [2]
        sample_shape=[4]),                                  # --> event shape [4]
    bijector=scale_lin_op)

mvn.log_prob(xn)

Output:

<tf.Tensor: shape=(10000, 2), dtype=float32, numpy=
array([[-4.5943561e+00, -6.5238861e+01],
       [-8.6548815e+00, -2.1378198e+05],
       [-3.1688419e+01, -3.3126004e+04],
       ...,
       [-1.8664089e+00, -3.8012810e+03],
       [-1.8821844e+01, -1.3414998e+04],
       [-2.3645339e+00, -2.4178730e+05]], dtype=float32)>
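To see where these numbers come from, the standard change-of-variables formula for y = L x can be evaluated by hand: log p(y) = sum_i log N(x_i; loc, scale) − log|det L| with x = L⁻¹ y, and for a triangular L the determinant is the product of its diagonal. The helper below is a hypothetical NumPy sketch of that formula for a single event vector, not TFP's implementation. It also suggests (an assumption on my part) why some entries are around -1e5: tril comes from tf.random.normal, so a diagonal entry close to zero makes L nearly singular, x = L⁻¹ y becomes large, and the Gaussian term becomes very negative.

```python
import numpy as np


def transformed_log_prob(y, tril, loc=0.0, scale=0.5):
    """Hypothetical helper: log p(y) for y = L @ x, with x_i ~ Normal(loc, scale) i.i.d."""
    L = np.tril(tril)           # like LinearOperatorLowerTriangular, ignore the upper triangle
    x = np.linalg.solve(L, y)   # inverse of the bijector: x = L^{-1} y
    base = -0.5 * ((x - loc) / scale) ** 2 - np.log(scale) - 0.5 * np.log(2 * np.pi)
    log_det = np.sum(np.log(np.abs(np.diag(L))))  # log|det L| = sum of log|diagonal|
    return base.sum() - log_det


y = np.array([1.0, -2.0, 0.5, 3.0])
print(transformed_log_prob(y, np.eye(4)))          # identity: plain sum of Normal log-densities
print(transformed_log_prob(y, 1e-3 * np.eye(4)))   # near-singular L: very negative value
```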

Answered By – Frightera

This Answer collected from stackoverflow, is licensed under cc by-sa 2.5, cc by-sa 3.0 and cc by-sa 4.0
