Convert a tensorflow script to pytorch (TransformedDistribution)

Issue

I am trying to rewrite a TensorFlow script in PyTorch. I am having trouble finding the PyTorch equivalent of the following lines from that script:

import tensorflow_probability as tfp
tfd = tfp.distributions
a_distribution = tfd.TransformedDistribution(
        distribution=tfd.Normal(loc=0.0, scale=1.0),
        bijector=tfp.bijectors.Chain([
            tfp.bijectors.AffineScalar(shift=self._means,
                                       scale=self._mags),
            tfp.bijectors.Tanh(),
            tfp.bijectors.AffineScalar(shift=mean, scale=std),
        ]),
        event_shape=[mean.shape[-1]],
        batch_shape=[mean.shape[0]])

In particular, I am having a lot of trouble replacing the tfp.bijectors.Chain component.
I wrote the following lines in PyTorch, but I am wondering whether they are equivalent to the TensorFlow code above, and whether I can specify the batch_shape somewhere.

base_distribution = torch.normal(0.0, 1.0)
transforms = torch.distributions.transforms.ComposeTransform([torch.distributions.transforms.AffineTransform(loc=self._action_means, scale=self._action_mag, event_dim=mean.shape[-1]), torch.nn.Tanh(),torch.distributions.transforms.AffineTransform(loc=mean, scale=std, event_dim=mean.shape[-1])])
a_distribution = torch.distributions.transformed_distribution.TransformedDistribution(base_distribution, transforms)
 

Any solution?

Solution

In PyTorch, the base class Distribution accepts both a batch_shape and an event_shape parameter. The subclass TransformedDistribution takes neither, because both shapes are inferred from the base distribution provided at initialization.
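The shape inference described above can be checked directly. The following is a minimal sketch, not part of the original answer: the sizes 4 and 3 are hypothetical, and wrapping the base in Independent is one possible way (an assumption here, not the answer's method) to obtain the batch_shape/event_shape split that the TensorFlow code asks for.

```python
import torch
from torch.distributions import Independent, Normal, TransformedDistribution
from torch.distributions.transforms import TanhTransform

# Hypothetical sizes, chosen only for illustration.
batch, event = 4, 3

# Normal takes its shape from loc and scale: every dimension is a batch dimension.
base = Normal(loc=torch.zeros(batch, event), scale=torch.ones(batch, event))
print(base.batch_shape, base.event_shape)  # torch.Size([4, 3]) torch.Size([])

# TransformedDistribution has no shape arguments; it inherits them from the base.
squashed = TransformedDistribution(base, [TanhTransform()])
print(squashed.batch_shape, squashed.event_shape)  # torch.Size([4, 3]) torch.Size([])

# Independent reinterprets the last batch dimension as an event dimension,
# giving batch_shape=[4] and event_shape=[3].
vector_base = Independent(base, 1)
vector_squashed = TransformedDistribution(vector_base, [TanhTransform()])
print(vector_squashed.batch_shape, vector_squashed.event_shape)  # torch.Size([4]) torch.Size([3])
```

With an event dimension in place, log_prob returns one value per batch entry rather than one per element.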

You already found AffineTransform and ComposeTransform. Keep in mind that every component must come from torch.distributions:

  • torch.normal is a sampling function that returns a tensor, so it should be replaced with the Normal distribution class. With this class, the shape is inferred from the provided loc and scale tensors.

  • nn.Tanh is a layer, not a transform, so it should be replaced with TanhTransform.
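The difference between the two pairs can be seen in a short sketch. This is an illustration rather than part of the original answer, and the tensor values are arbitrary.

```python
import torch
from torch.distributions import Normal
from torch.distributions.transforms import TanhTransform, Transform

# torch.normal draws values immediately and returns a plain tensor.
drawn = torch.normal(mean=torch.zeros(3), std=torch.ones(3))
print(type(drawn).__name__)  # Tensor

# Normal is a distribution object that offers sample() and log_prob().
dist = Normal(loc=torch.zeros(3), scale=torch.ones(3))
print(dist.log_prob(drawn).shape)  # torch.Size([3])

# nn.Tanh is a layer: it maps values forward but is not a Transform.
layer = torch.nn.Tanh()
print(isinstance(layer, Transform))  # False

# TanhTransform also provides the inverse and the log-determinant of the
# Jacobian, which TransformedDistribution needs to compute log_prob.
squash = TanhTransform()
x = torch.tensor([-0.5, 0.0, 0.5])
y = squash(x)
print(torch.allclose(squash.inv(y), x, atol=1e-6))  # True
print(squash.log_abs_det_jacobian(x, y))
```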


Here is a minimal example using your transformation pipeline:

Imports:
import torch
from torch.distributions.normal import Normal
from torch.distributions import transforms as tT
from torch.distributions.transformed_distribution import TransformedDistribution
Parameters:
mean = torch.rand(2, 2)  # shape (batch, event)
std = 1
_action_means, _action_mag = 0, 1
# event_dim is the number of event dimensions (1 for vectors), not the event size
event_dim = 1
Distribution definition:
# ComposeTransform applies its parts first to last, whereas tfp.bijectors.Chain
# applies its list last to first, so the order is reversed relative to the TF code.
a_distribution = TransformedDistribution(
    base_distribution=Normal(loc=torch.zeros_like(mean),
                             scale=torch.ones_like(mean)),
    transforms=tT.ComposeTransform([
        tT.AffineTransform(loc=mean, scale=std, event_dim=event_dim),
        tT.TanhTransform(),
        tT.AffineTransform(loc=_action_means, scale=_action_mag, event_dim=event_dim)]))
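One way to sanity-check a port like this is to sample from it and confirm that the results respect the intended bounds. The sketch below is a hedged illustration with hypothetical values standing in for mean, std, self._means and self._mags. It assumes the shift and scale by mean and std is applied first, then tanh, then the rescaling to the action range. That is the order in which tfp.bijectors.Chain applies the list in the question (last entry first), and ComposeTransform needs it written first to last.

```python
import torch
from torch.distributions import Normal, TransformedDistribution
from torch.distributions.transforms import AffineTransform, TanhTransform

torch.manual_seed(0)

# Hypothetical values standing in for mean, std, self._means and self._mags.
mean = torch.rand(5, 3)  # (batch, event)
std = torch.full((5, 3), 0.5)
action_means = torch.tensor([0.0, 1.0, -2.0])
action_mags = torch.tensor([1.0, 0.5, 3.0])

# Applied in list order: shift and scale the noise, squash it with tanh,
# then rescale to the action range.
dist = TransformedDistribution(
    Normal(torch.zeros_like(mean), torch.ones_like(mean)),
    [
        AffineTransform(loc=mean, scale=std, event_dim=1),
        TanhTransform(),
        AffineTransform(loc=action_means, scale=action_mags, event_dim=1),
    ],
)

actions = dist.sample()
print(actions.shape)  # torch.Size([5, 3])
print(dist.batch_shape, dist.event_shape)  # torch.Size([5]) torch.Size([3])

# Every sampled action lies within action_means +/- action_mags.
low, high = action_means - action_mags, action_means + action_mags
print(bool(((actions >= low) & (actions <= high)).all()))  # True

# log_prob sums over the event dimension: one value per batch row.
log_prob = dist.log_prob(action_means.expand(5, 3))
print(log_prob.shape)  # torch.Size([5])
```

If the two affine transforms were swapped, the samples would no longer be confined to the action range, which makes the bounds check a quick test of the ordering.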

Answered By – Ivan

This answer, collected from Stack Overflow, is licensed under CC BY-SA 2.5, CC BY-SA 3.0 and CC BY-SA 4.0.
