Convert a TensorFlow script to PyTorch (TransformedDistribution)


I am trying to rewrite a TensorFlow script in PyTorch, and I cannot find the PyTorch equivalent of the following lines from the script:

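import tensorflow_probability as tfp
tfd = tfp.distributions
a_distribution = tfd.TransformedDistribution(
        distribution=tfd.Normal(loc=0.0, scale=1.0),
        bijector=tfp.bijectors.Chain([
            tfp.bijectors.AffineScalar(shift=self._action_means, scale=self._action_mag),
            tfp.bijectors.Tanh(),
            tfp.bijectors.AffineScalar(shift=mean, scale=std),
        ]))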

In particular, I am having trouble replacing the tfp.bijectors.Chain component.
I wrote the following lines in PyTorch, but I am not sure whether they are equivalent to the TensorFlow code above, or whether I can specify the batch_shape somewhere:

base_distribution = torch.normal(0.0, 1.0)
transforms = torch.distributions.transforms.ComposeTransform([
    torch.distributions.transforms.AffineTransform(loc=self._action_means, scale=self._action_mag, event_dim=mean.shape[-1]),
    torch.nn.Tanh(),
    torch.distributions.transforms.AffineTransform(loc=mean, scale=std, event_dim=mean.shape[-1])])
a_distribution = torch.distributions.transformed_distribution.TransformedDistribution(base_distribution, transforms)

Any solution?


In PyTorch, the base class Distribution expects both a batch_shape and an event_shape parameter. Notice, however, that its subclass TransformedDistribution does not take such parameters. That’s because they are inferred on initialization from the base distribution you provide, together with the event_dim of the transforms.
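To make the shape inference concrete, here is a small sketch (the shape (4, 3) and the affine parameters are arbitrary illustrative choices). There is no batch_shape argument anywhere: the base Normal takes its shape from loc and scale, and an event_dim=1 transform makes TransformedDistribution treat the last axis as the event.

```python
import torch
from torch.distributions import Normal, TransformedDistribution
from torch.distributions.transforms import AffineTransform

# Illustrative parameters: a batch of 4 items, each with 3 dimensions.
loc = torch.zeros(4, 3)
scale = torch.ones(4, 3)

base = Normal(loc=loc, scale=scale)
# Normal infers its shapes from loc/scale: every element is a separate batch entry.
print(base.batch_shape, base.event_shape)  # torch.Size([4, 3]) torch.Size([])

# event_dim=1 tells the transform that the last axis forms one event, so
# TransformedDistribution reinterprets that axis of the base as the event.
dist = TransformedDistribution(base, AffineTransform(loc=0.0, scale=2.0, event_dim=1))
print(dist.batch_shape, dist.event_shape)  # torch.Size([4]) torch.Size([3])
```

So to get a particular batch_shape, you choose the shape of the base distribution's parameters (and the event_dim of the transforms); wrapping the base in torch.distributions.Independent is another way to move batch axes into the event.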

You already found out about AffineTransform and ComposeTransform. Keep in mind that you must stick with classes from torch.distributions.

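  • This holds for torch.normal, which should be replaced with Normal: torch.normal draws samples and returns a plain tensor, not a distribution object. With Normal, the shape is inferred from the provided loc and scale tensors.

  • It also holds for nn.Tanh, which should be replaced with TanhTransform: TransformedDistribution needs each transform's inverse and log_abs_det_jacobian to evaluate log_prob, and nn.Tanh is a network layer that provides neither.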
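The difference between the two kinds of objects can be checked directly. This sketch uses arbitrary three-element tensors; it shows that torch.normal yields a tensor while Normal yields a Distribution, and that TanhTransform exposes the inverse and log-determinant that nn.Tanh lacks.

```python
import torch
from torch.distributions import Distribution, Normal
from torch.distributions.transforms import TanhTransform

# torch.normal draws samples: the result is a plain tensor, not a distribution.
samples = torch.normal(mean=torch.zeros(3), std=torch.ones(3))
# Normal is a Distribution object that TransformedDistribution can wrap.
base = Normal(loc=torch.zeros(3), scale=torch.ones(3))

# A Transform provides the forward map, its inverse and the log|det J| term
# used by TransformedDistribution.log_prob; nn.Tanh only has a forward pass.
tanh = TanhTransform()
x = torch.tensor([0.0, 0.5, -1.0])
y = tanh(x)
x_back = tanh.inv(y)                       # atanh(y), recovers x
log_det = tanh.log_abs_det_jacobian(x, y)  # equals log(1 - tanh(x)**2)
```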

Here is a minimal example using your transformation pipeline:

import torch
from torch.distributions.normal import Normal
from torch.distributions import transforms as tT
from torch.distributions.transformed_distribution import TransformedDistribution

mean = torch.rand(2, 2)
std = 1
_action_means, _action_mag = 0, 1
event_dim = 1  # number of rightmost dims forming one event (a count, not the event size)

Distribution definition:

a_distribution = TransformedDistribution(
    base_distribution=Normal(loc=torch.full_like(mean, 0),
                             scale=torch.full_like(mean, 1)),
    transforms=tT.ComposeTransform([
        tT.AffineTransform(loc=_action_means, scale=_action_mag, event_dim=event_dim),
        tT.TanhTransform(),
        tT.AffineTransform(loc=mean, scale=std, event_dim=event_dim)]))
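One caveat when porting the Chain itself: according to the TFP documentation, tfp.bijectors.Chain applies its bijectors starting from the end of the list, whereas ComposeTransform applies its parts from first to last. Assuming the TF script lists the bijectors in the same order as the PyTorch attempt in the question (action scaling, Tanh, then mean/std), the PyTorch list has to be reversed. The sketch below first checks the left-to-right order on a single value, then builds the reversed pipeline; the values of mean, std, action_means and action_mag are hypothetical placeholders for the script's tensors.

```python
import torch
from torch.distributions import Normal, TransformedDistribution
from torch.distributions import transforms as tT

# ComposeTransform applies its parts left to right: here 2 * x first, then tanh.
x = torch.tensor([1.0])
scale_then_tanh = tT.ComposeTransform([tT.AffineTransform(loc=0.0, scale=2.0), tT.TanhTransform()])
print(scale_then_tanh(x))  # tanh(2.0), approx. 0.9640

# Hypothetical policy outputs and action bounds (placeholders for illustration).
mean = torch.rand(2, 2)
std = torch.full_like(mean, 0.5)
action_means = torch.zeros(2)
action_mag = torch.full((2,), 2.0)

# Reversed relative to the Chain list: Gaussian first, then tanh, then rescaling.
squashed = TransformedDistribution(
    base_distribution=Normal(loc=torch.zeros_like(mean), scale=torch.ones_like(mean)),
    transforms=tT.ComposeTransform([
        tT.AffineTransform(loc=mean, scale=std, event_dim=1),
        tT.TanhTransform(),
        tT.AffineTransform(loc=action_means, scale=action_mag, event_dim=1)]))

action = squashed.rsample()           # reparameterized sample, shape (2, 2)
log_prob = squashed.log_prob(action)  # one log-density per batch entry, shape (2,)
```

With this order, every sample lies within action_means ± action_mag, as expected of a tanh-squashed Gaussian.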

Answered By – Ivan

This answer, collected from Stack Overflow, is licensed under CC BY-SA 2.5, CC BY-SA 3.0 and CC BY-SA 4.0.
