How to deal with NotImplementedError when training a Unet model?

Issue

def train_fn(data_loader, model, optimizer):
    """Run one training epoch and return the mean per-batch loss.

    Args:
        data_loader: Iterable of ``(images, masks)`` batches that also
            supports ``len()``; ``len()`` is used as the batch count when
            averaging.
        model: Module whose call ``model(images, masks)`` returns
            ``(logits, loss)``.
        optimizer: Optimizer that updates ``model``'s parameters.

    Returns:
        Sum of ``loss.item()`` over all batches divided by
        ``len(data_loader)``.

    Raises:
        ZeroDivisionError: if ``data_loader`` has length 0.
    """
    model.train()  # put the module in training mode before the epoch
    total_loss = 0.0

    for images, masks in tqdm(data_loader):
        # DEVICE is a notebook-level setting defined outside this snippet.
        images = images.to(DEVICE)
        masks = masks.to(DEVICE)

        optimizer.zero_grad()  # clear gradients left from the previous step
        _, loss = model(images, masks)  # logits are not needed for training
        loss.backward()
        optimizer.step()

        # Accumulate a Python float rather than the loss tensor itself.
        total_loss += loss.item()

    return total_loss / len(data_loader)


def eval_fn(data_loader, model):
    """Evaluate ``model`` on ``data_loader`` and return the mean per-batch loss.

    Args:
        data_loader: Iterable of ``(images, masks)`` batches that also
            supports ``len()``; ``len()`` is used as the batch count when
            averaging.
        model: Module whose call ``model(images, masks)`` returns
            ``(logits, loss)``.

    Returns:
        Sum of ``loss.item()`` over all batches divided by
        ``len(data_loader)``.

    Raises:
        ZeroDivisionError: if ``data_loader`` has length 0.
    """
    model.eval()  # put the module in evaluation mode
    total_loss = 0.0

    # Gradients are never needed during evaluation, so disable autograd.
    with torch.no_grad():
        for images, masks in tqdm(data_loader):
            # DEVICE is a notebook-level setting defined outside this snippet.
            images = images.to(DEVICE)
            masks = masks.to(DEVICE)

            _, loss = model(images, masks)  # logits are not needed here
            total_loss += loss.item()

    return total_loss / len(data_loader)

# Adam over all model parameters; LR, EPOCHS, model and the loaders are
# notebook-level objects defined outside this snippet.
optimizer = torch.optim.Adam(model.parameters(), lr=LR)

# float("inf") instead of np.Inf: the np.Inf alias was removed in NumPy 2.0.
best_valid_loss = float("inf")

for i in range(EPOCHS):
    train_loss = train_fn(trainloader, model, optimizer)
    valid_loss = eval_fn(validloader, model)

    # Checkpoint the weights whenever validation loss beats the best so far.
    if valid_loss < best_valid_loss:
        torch.save(model.state_dict(), 'best_model.pt')
        print("SAVED_MODEL")
        best_valid_loss = valid_loss

    print(f"Epoch : {i+1} Train_loss: {train_loss} Valid_loss: {valid_loss}")

I get the following error when I try to train the model:

0%| | 0/15 [00:00<?, ?it/s]

NotImplementedError Traceback (most recent call last)
in ()
4
5
----> 6 train_loss = train_fn(trainloader, model, optimizer)
7 valid_loss = eval_fn(validloader, model)
8

2 frames
/usr/local/lib/python3.7/dist-packages/torch/nn/modules/module.py in _forward_unimplemented(self, *input)
199 registered hooks while the latter silently ignores them.
200 """
--> 201 raise NotImplementedError
202
203

NotImplementedError:

How do I deal with this?

Solution

Looking at the link you provided in the comment, your model definition looks like this:

class SegmentationModel(nn.Module):
  # Model definition as posted in the question. It is BROKEN: see the comment
  # directly above `def forward` below.

  def __init__(self):
    super(SegmentationModel,self).__init__()

    # U-Net from `smp` (presumably segmentation_models_pytorch): 3 input
    # channels, 1 output channel, activation=None so it emits raw logits.
    self.arc = smp.Unet(
        encoder_name = ENCODER,
        encoder_weights = WEIGHTS,
        in_channels = 3,
        classes = 1,
        activation = None
    )

    # BUG: this `def` is indented one level too deep, so it defines a local
    # function inside __init__ (discarded when __init__ returns) instead of a
    # method of SegmentationModel. The class therefore never overrides
    # forward(), and calling the model falls through to nn.Module's default
    # _forward_unimplemented, which raises NotImplementedError.
    def forward(self, images, masks = None):
      logits = self.arc(images)

      # With masks: return (logits, Dice loss + BCE-with-logits loss).
      if masks != None:
        loss1 = DiceLoss(mode = 'binary')(logits, masks)
        loss2 = nn.BCEWithLogitsLoss()(logits,masks)
        return logits, loss1 + loss2

      # Without masks: return logits only.
      return logits

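If you look closely, you'll see that forward() has an extra level of indentation, which makes it a nested function inside __init__() rather than a method of SegmentationModel. Because the class then has no forward() method of its own, calling the model falls back to nn.Module's default _forward_unimplemented(), which raises NotImplementedError (the last frame in your traceback). Dedent forward() by one level so it sits at the same level as __init__(), and it should work: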

class SegmentationModel(nn.Module):
  """Binary segmentation model: an smp U-Net plus a Dice + BCE training loss.

  ``forward`` must be defined at class level (same indentation as
  ``__init__``); if it is nested inside ``__init__``, nn.Module falls back to
  its default forward and raises NotImplementedError.
  """

  def __init__(self):
    super().__init__()

    # 3-channel input, 1-channel output, activation=None: the network emits
    # raw logits, which is what BCEWithLogitsLoss expects.
    self.arc = smp.Unet(
        encoder_name = ENCODER,
        encoder_weights = WEIGHTS,
        in_channels = 3,
        classes = 1,
        activation = None
    )

  def forward(self, images, masks = None):
    """Run the U-Net on ``images``; also compute the loss when ``masks`` is given.

    Returns:
      ``(logits, loss)`` when ``masks`` is not None, where ``loss`` is the sum
      of a binary Dice loss and a BCE-with-logits loss; otherwise ``logits``.
    """
    logits = self.arc(images)

    # Identity check on None: `masks != None` would go through Tensor.__ne__
    # instead of simply testing whether masks were passed.
    if masks is not None:
      loss1 = DiceLoss(mode = 'binary')(logits, masks)
      loss2 = nn.BCEWithLogitsLoss()(logits, masks)
      return logits, loss1 + loss2

    return logits

Answered By – dx2-66

This answer, collected from Stack Overflow, is licensed under CC BY-SA 2.5, CC BY-SA 3.0 and CC BY-SA 4.0.

Leave a Reply

(*) Required, Your email will not be published