Converting tf.data.Dataset.from_tensor_slices to PyTorch

Issue

I am trying to convert this model from TensorFlow to PyTorch. Unfortunately, I don't know TensorFlow very well. I am having trouble porting the data loading from here, and with converting this function to PyTorch in general:

def _components_train_step(self, importance_weights, old_means, old_chol_precisions):
    for i in range(self._model.num_components):
        dt = (self._train_contexts, importance_weights[:, i], old_means, old_chol_precisions)
        data = tf.data.Dataset.from_tensor_slices(dt)
        data = data.shuffle(self._train_contexts.shape[0]).batch(self.c.components_batch_size)

        for context_batch, iw_batch, old_means_batch, old_chol_precisions_batch in data:
            iw_batch = iw_batch / tf.reduce_sum(iw_batch)
            with tf.GradientTape() as tape:
                samples = self._model.components[i].sample(context_batch)
                losses = - tf.squeeze(self._dre(tf.concat([context_batch, samples], axis=-1)))
                kls = self._model.components[i].kls_other_chol_inv(context_batch, old_means_batch[:, i],
                                                                   old_chol_precisions_batch[:, i])
                loss = tf.reduce_mean(iw_batch * (losses + kls))
            gradients = tape.gradient(loss, self._model.components[i].trainable_variables)
            self._c_opts[i].apply_gradients(zip(gradients, self._model.components[i].trainable_variables))

I implemented this function as follows:

def _components_train_step(self, importance_weights, old_means, old_chol_precisions):
    self._c_opts = [torch.optim.Adam(self._model.components[i].trainable_variables,
                                     lr=self.c.components_learning_rate, betas=(0.5, 0.999))
                    for i in range(self._model.num_components)]

    for i in range(self._model.num_components):
        dataset = torch.utils.data.TensorDataset(self._train_contexts, importance_weights[:, i],
                                                 old_means, old_chol_precisions)
        loader = torch.utils.data.DataLoader(dataset, shuffle=True,
                                             batch_size=self.c.components_batch_size)

        for context_batch, iw_batch, old_means_batch, old_chol_precisions_batch in loader:
            iw_batch = iw_batch / torch.sum(iw_batch)
            samples = self._model.components[i].sample(context_batch)
            losses = - torch.squeeze(self._dre(torch.cat([context_batch, samples], dim=-1)))
            kls = self._model.components[i].kls_other_chol_inv(context_batch, old_means_batch[:, i],
                                                               old_chol_precisions_batch[:, i])
            loss = torch.mean(iw_batch * (losses + kls))
            self._c_opts[i].zero_grad()  # clear gradients before the backward pass
            loss.backward()
            self._c_opts[i].step()

Any suggestions or help?

Solution

I believe you can achieve a result comparable to tf.data.Dataset.from_tensor_slices using PyTorch's torch.utils.data.TensorDataset, which takes one or more tensors as positional arguments. This has the effect of zipping the tensors together into a single dataset, where each yielded element is a tuple with one entry per input tensor.

Here is a minimal example:

import tensorflow as tf
import torch
from torch.utils import data

feats = torch.tensor([[[1, 3], [2, 3]], [[2, 1], [1, 2]], [[3, 3], [3, 2]]])
tf_feats = tf.convert_to_tensor(feats.numpy())

labels = torch.tensor([[10, 10], [20, 20], [10, 20]])
tf_labels = tf.convert_to_tensor(labels.numpy())

Using TensorFlow:

>>> dataset = tf.data.Dataset.from_tensor_slices((tf_feats, tf_labels))
>>> for x in dataset.as_numpy_iterator():
...     print(x)
(array([[1, 3],
        [2, 3]]), array([10, 10]))
(array([[2, 1],
        [1, 2]]), array([20, 20]))
(array([[3, 3],
        [3, 2]]), array([10, 20]))
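
The shuffle and batch calls from the original function chain onto this dataset in the same way. As a quick sketch (reusing tf_feats and tf_labels from above, with a buffer size of 3 and a batch size of 2):

>>> batched = tf.data.Dataset.from_tensor_slices((tf_feats, tf_labels)).shuffle(3).batch(2)
>>> for x, y in batched.as_numpy_iterator():
...     print(x.shape, y.shape)
(2, 2, 2) (2, 2)
(1, 2, 2) (1, 2)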

Using PyTorch:

>>> dataset = data.TensorDataset(feats, labels)
>>> for x in dataset:
...     print(x)
(tensor([[1, 3],
         [2, 3]]), tensor([10, 10]))
(tensor([[2, 1],
         [1, 2]]), tensor([20, 20]))
(tensor([[3, 3],
         [3, 2]]), tensor([10, 20]))
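
On the PyTorch side, the counterpart of chaining .shuffle(...).batch(...) is wrapping the TensorDataset in a torch.utils.data.DataLoader: shuffle=True reshuffles the indices every epoch, much like tf.data's default reshuffle_each_iteration=True. A sketch with the same dataset as above:

>>> loader = data.DataLoader(dataset, shuffle=True, batch_size=2)
>>> for x, y in loader:
...     print(x.shape, y.shape)
torch.Size([2, 2, 2]) torch.Size([2, 2])
torch.Size([1, 2, 2]) torch.Size([1, 2])

Each iteration yields a tuple of batched tensors, which is exactly what the for context_batch, iw_batch, ... loop in the question unpacks.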

Answered By – Ivan

This answer, collected from Stack Overflow, is licensed under CC BY-SA 2.5, CC BY-SA 3.0 and CC BY-SA 4.0.
