Example #1
0
             loss = -elbo + 1e-5*reg_loss
             score = batch_dice(F.softmax(net.sample(data,mean=False)))
             running_train_loss.append(loss.item())
             running_train_score(score.item())
     
 epoch_train_loss,epoch_train_score = np.mean(running_train_loss),np.mean(running_train_score)
 print('Train loss : {} Dice score : {}'.format(epoch_train_loss,epoch_train_score)
 train_loss.append(epoch_train_loss)
 dice_score_train.append(epoch_train_score)
     
 epoch_val_loss,epoch_val_score = np.mean(running_val_loss),np.mean(running_val_score)
 print('Train loss : {} Dice score : {}'.format(epoch_val_loss,epoch_val_score)
 val_loss.append(epoch_val_loss)
 dice_score_val.append(epoch_val_score)
       
 checkpoint = { 'epoch': epoch +1,
               'valid_loss_min':epoch_val_loss,
               'state_dict':net.state_dict(),
               'optimizer':optimizer.state_dict(),
     
 }
 save_ckp(checkpoint, False,checkpoint_path,best_model_path)
  
 if epoch_val_loss <= valid_loss_min:
       print('Validation loss decreased ({:.6f} =======> {:.6f}). Saving model ...'.format(valid_loss_min,epoch_val_loss))
       
       save_ckp(checkpoint, True,checkpoint_path,best_model_path)
       valid_loss_min=epoch_val_loss
       
 time_passed = time.time() - started
 print('{:.0f}m {:.0f}s'.format(time_passed//60, time_passed%60))
            
    # end of validation
    loss_val /= len(test_loader)
    
    train_loss.append(loss_train)
    test_loss.append(loss_val)
    
    print('End of epoch ', epoch+1, ' , Train loss: ', loss_train, ', val loss: ', loss_val)   
    
    secheduler.step()
    
    # save best model checkpoint
    if loss_val < best_val_loss:
        best_val_loss = loss_val
        fname = 'model_dict.pth'
        torch.save(net.state_dict(), os.path.join(out_dir, fname))
        print('model saved at epoch: ', epoch+1)

print('Finished training')
# save loss curves        
plt.figure()
plt.plot(train_loss)
plt.title('train loss')
fname = os.path.join(out_dir,'loss_train.png')
plt.savefig(fname)
plt.close()

plt.figure()
plt.plot(test_loss)
plt.title('val loss')
fname = os.path.join(out_dir,'loss_val.png')
Example #3
0
def _to_batch(array):
    """Turn one numpy sample into a tensor with a leading batch axis of size 1.

    The tensor is moved to the same ``device`` the model is placed on in
    ``train`` (instead of a hard-coded ``.cuda()``), so inputs and weights
    always live on the same device.
    """
    return torch.from_numpy(array).unsqueeze(0).to(device)


def _batch_loss(model, x, y):
    """Run one model pass on ``(x, y)`` and return ``(elbo, reg_loss, loss)``.

    The objective is the negative ELBO plus an L2 penalty (weight 1e-5) on
    the posterior, prior and fcomb layers.
    """
    # training=True is used in both phases, as in the original code.
    # NOTE(review): presumably model.elbo needs the posterior that this mode
    # builds — confirm against ProbabilisticUnet.
    model(x, y, training=True)
    elbo = model.elbo(y)
    reg_loss = (
        l2_regularisation(model.posterior)
        + l2_regularisation(model.prior)
        + l2_regularisation(model.fcomb.layers)
    )
    loss = -elbo + 1e-5 * reg_loss
    return elbo, reg_loss, loss


def _mean(total, count):
    """Return ``total / count``, or NaN when no batches were processed."""
    return total / count if count else float('nan')


def train(args):
    """Train a Probabilistic U-Net on ``args.task`` and save its final weights.

    Args:
        args: namespace providing ``epoch`` (number of epochs),
            ``learning_rate`` (Adam step size), ``task`` (dataset task
            directory), ``iteration`` (training batches per epoch) and
            ``test_iteration`` (validation batches per epoch).

    After every epoch, prints the mean training elbo, regularisation and loss,
    plus the mean validation loss (NaN for a phase that ran no batches).
    The final ``state_dict`` is written to
    ``'./save/' + trainset.task_dir + 'model.pth'``.
    """
    num_epoch = args.epoch
    learning_rate = args.learning_rate
    task_dir = args.task

    trainset = MedicalDataset(task_dir=task_dir, mode='train')
    validset = MedicalDataset(task_dir=task_dir, mode='valid')

    model = ProbabilisticUnet(input_channels=1, num_classes=1, num_filters=[32, 64, 128, 192],
                              latent_dim=2, no_convs_fcomb=4, beta=10.0)
    # `device` is not defined in this block; presumably a module-level global.
    model.to(device)

    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, weight_decay=0)

    for epoch in range(num_epoch):
        model.train()
        elbo_sum = reg_sum = train_sum = 0.0
        n_train = 0
        # The dataset keeps its own cursor: presumably next() advances
        # `iteration`, and the loop resets it to 0 after each pass.
        while trainset.iteration < args.iteration:
            x, y = trainset.next()
            elbo, reg_loss, loss = _batch_loss(model, _to_batch(x), _to_batch(y))
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            elbo_sum += elbo.item()
            reg_sum += reg_loss.item()
            train_sum += loss.item()
            n_train += 1
        trainset.iteration = 0

        model.eval()
        valid_sum = 0.0
        n_valid = 0
        with torch.no_grad():
            while validset.iteration < args.test_iteration:
                x, y = validset.next()
                _, _, valid_loss = _batch_loss(model, _to_batch(x), _to_batch(y))
                valid_sum += valid_loss.item()
                n_valid += 1
            validset.iteration = 0

        # Report epoch means rather than whichever batch happened to run last.
        print('Epoch: {}, elbo: {:.4f}, regloss: {:.4f}, loss: {:.4f}, valid loss: {:.4f}'.format(
            epoch + 1,
            _mean(elbo_sum, n_train),
            _mean(reg_sum, n_train),
            _mean(train_sum, n_train),
            _mean(valid_sum, n_valid),
        ))
    # NOTE(review): no separator is inserted between task_dir and 'model.pth';
    # kept byte-identical so existing checkpoint paths do not move.
    torch.save(model.state_dict(), './save/' + trainset.task_dir + 'model.pth')