I am very confused by how PyTorch deals with one-hot vectors. In this tutorial, as far as I understand, the neural network generates a one-hot vector as its output. However, the `labels` are not in one-hot vector format. I get the following:
    print(labels.size())
    print(outputs.size())

    output>>> torch.Size([4])
    output>>> torch.Size([4, 10])
Surprisingly, when they pass `outputs` and `labels` to `criterion = CrossEntropyLoss()`, there is no error at all:

    loss = criterion(outputs, labels)  # How come it has no error?
Maybe PyTorch automatically converts the `labels` to one-hot vector form. So I tried converting `labels` to one-hot vectors before passing them to the loss function:
    import numpy as np
    import torch

    def to_one_hot_vector(num_class, label):
        # One row per sample, one column per class
        b = np.zeros((label.shape[0], num_class))
        # Put a 1 in each row at that sample's class index
        b[np.arange(label.shape[0]), label] = 1
        return b

    labels_one_hot = to_one_hot_vector(10, labels)
    labels_one_hot = torch.Tensor(labels_one_hot)
    labels_one_hot = labels_one_hot.type(torch.LongTensor)

    loss = criterion(outputs, labels_one_hot)  # Now it gives me an error
However, I got the following error:
RuntimeError: multi-target not supported at
So, are one-hot vectors not supported in PyTorch? How does PyTorch calculate the cross entropy for the two tensors `outputs = [[1,0,0],[0,0,1]]` and `labels = [0,2]`? It doesn't make sense to me at all at the moment.
PyTorch states in its documentation for `CrossEntropyLoss`:

"This criterion expects a class index (0 to C-1) as the target for each value of a 1D tensor of size minibatch."
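In other words, your `to_one_hot_vector` function is conceptually built into `CrossEntropyLoss` (CEL), which does not expose a one-hot API. Note that one-hot vectors are memory-inefficient compared to storing class labels: a batch of N labels needs N integers, while the one-hot form needs N × C values.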
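To make the question's example concrete, here is a minimal sketch that treats the question's `outputs = [[1,0,0],[0,0,1]]` as raw scores (logits), which is what `CrossEntropyLoss` expects rather than one-hot vectors. It computes the loss three ways: with `torch.nn.functional.cross_entropy`, by hand using the class index to pick one entry of `log_softmax` per row, and with an explicit one-hot mask. The variable names are illustrative only.

```python
import torch
import torch.nn.functional as F

# The question's tensors: two samples, three classes
outputs = torch.tensor([[1.0, 0.0, 0.0],
                        [0.0, 0.0, 1.0]])  # raw scores (logits)
labels = torch.tensor([0, 2])               # class indices, shape (N,)

# Built-in: log_softmax over the class dimension, then negative log-likelihood
loss_builtin = F.cross_entropy(outputs, labels)

# By hand: the class index selects one log-probability per row,
# and the default reduction averages over the batch
log_probs = F.log_softmax(outputs, dim=1)
loss_index = -log_probs[torch.arange(len(labels)), labels].mean()

# Same thing with an explicit one-hot mask: multiplying by a one-hot row
# and summing keeps exactly the entry the class index points to
one_hot = F.one_hot(labels, num_classes=3).float()
loss_one_hot = -(one_hot * log_probs).sum(dim=1).mean()

print(loss_builtin.item(), loss_index.item(), loss_one_hot.item())
```

All three values agree (roughly 0.5514 here), which is why passing class indices is enough: indexing with the label does the same job the one-hot vector would.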
If you are given one-hot vectors and need to convert them to class-label format (for instance, to be compatible with CEL), you can use `argmax` like below:
    import torch

    labels = torch.tensor([1, 2, 3, 5])
    one_hot = torch.zeros(4, 6)
    one_hot[torch.arange(4), labels] = 1

    reverted = torch.argmax(one_hot, dim=1)
    assert (labels == reverted).all().item()
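For the opposite direction, PyTorch also provides `torch.nn.functional.one_hot`, so a NumPy helper is not required. In addition, assuming PyTorch 1.10 or later, `CrossEntropyLoss` also accepts a float target of class probabilities with the same shape as the logits; a one-hot float target is a special case and gives the same loss as the matching class indices. The question's conversion to `LongTensor` is what sends the target down the integer class-index path, where a 2D target is rejected. A minimal sketch (random logits, illustrative names):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

labels = torch.tensor([1, 2, 3, 5])

# Class indices -> one-hot (int64 tensor of shape (4, 6))
one_hot = F.one_hot(labels, num_classes=6)

# One-hot -> class indices again with argmax
reverted = one_hot.argmax(dim=1)
assert torch.equal(labels, reverted)

# Assumes PyTorch >= 1.10: a *float* target with the same shape as the
# logits is read as class probabilities instead of class indices
criterion = nn.CrossEntropyLoss()
logits = torch.randn(4, 6)
loss_from_indices = criterion(logits, labels)
loss_from_one_hot = criterion(logits, one_hot.float())
assert torch.allclose(loss_from_indices, loss_from_one_hot)
```

Class indices remain the cheaper option in memory; the probability form is mainly useful when targets are genuinely soft (for example, label smoothing or mixed labels).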
Answered By – Jatentaki