# Stand-alone training script (DataLoader variant).
# NOTE(review): relies on names defined elsewhere in the file/notebook:
# ProbabilisticUnet, torch, lr, l2_reg, lr_decay_every, lr_decay, epochs,
# train_loader — confirm they are defined before this runs.

# network
net = ProbabilisticUnet(
    input_channels=1,
    num_classes=1,
    num_filters=[32, 64, 128, 192],
    latent_dim=2,
    no_convs_fcomb=4,
    beta=10.0,
)
net.cuda()

# optimizer: Adam with L2 weight decay; learning rate decays by `lr_decay`
# every `lr_decay_every` epochs via StepLR.
optimizer = torch.optim.Adam(net.parameters(), lr=lr, weight_decay=l2_reg)
scheduler = torch.optim.lr_scheduler.StepLR(
    optimizer, step_size=lr_decay_every, gamma=lr_decay
)
# Backward-compatible alias for the original (misspelled) module-level name.
secheduler = scheduler

# logging
train_loss = []  # mean training loss per epoch
test_loss = []
best_val_loss = 999.0

for epoch in range(epochs):
    net.train()
    loss_train = 0
    loss_segmentation = 0
    num_batches = 0
    # training loop
    for step, (patch, mask, _) in enumerate(train_loader):
        patch = patch.cuda()
        mask = mask.cuda()
        # Add a channel axis to the mask; presumably the loader yields
        # (batch, H, W) masks — confirm.
        mask = torch.unsqueeze(mask, 1)
        # forward() stores intermediate state on `net` that elbo() reads.
        net.forward(patch, mask, training=True)
        elbo = net.elbo(mask)
        loss = -elbo
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        loss_train += loss.detach().cpu().item()
        num_batches = step + 1
    # Record the epoch's mean batch loss (nan if the loader was empty).
    train_loss.append(loss_train / num_batches if num_batches else float('nan'))
    # Advance the LR schedule once per epoch, after the optimizer steps;
    # without this call StepLR never decays the learning rate.
    scheduler.step()
def _to_batch(array):
    """Wrap one numpy sample as a batch-of-one tensor on the global ``device``."""
    # unsqueeze(0) adds the leading batch dimension; the per-sample layout
    # returned by MedicalDataset.next() is not visible here — confirm.
    return torch.from_numpy(array).unsqueeze(0).to(device)


def _elbo_loss(model, x, y):
    """Compute the regularised negative-ELBO objective for one batch.

    Args:
        model: ProbabilisticUnet-like module exposing ``forward``, ``elbo``,
            ``posterior``, ``prior`` and ``fcomb.layers``.
        x: Input image batch.
        y: Target mask batch.

    Returns:
        Tuple ``(elbo, reg_loss, loss)`` of tensors, where
        ``loss = -elbo + 1e-5 * reg_loss``.
    """
    # training=True is also used during validation; presumably elbo()
    # needs the posterior built from `y` — confirm against ProbabilisticUnet.
    model.forward(x, y, training=True)
    elbo = model.elbo(y)
    reg_loss = (
        l2_regularisation(model.posterior)
        + l2_regularisation(model.prior)
        + l2_regularisation(model.fcomb.layers)
    )
    loss = -elbo + 1e-5 * reg_loss
    return elbo, reg_loss, loss


def _run_epoch(model, dataset, num_iterations, optimizer=None):
    """Run one pass over ``dataset`` and return mean ``(elbo, reg_loss, loss)``.

    When ``optimizer`` is given, every batch also takes a gradient step.
    All three means are ``nan`` when no batch was processed.
    """
    totals = [0.0, 0.0, 0.0]
    count = 0
    dataset.iteration = 0
    # NOTE(review): assumes dataset.next() advances dataset.iteration;
    # otherwise this loop never terminates.
    while dataset.iteration < num_iterations:
        x, y = dataset.next()
        x, y = _to_batch(x), _to_batch(y)
        elbo, reg_loss, loss = _elbo_loss(model, x, y)
        if optimizer is not None:
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        totals = [t + v.item() for t, v in zip(totals, (elbo, reg_loss, loss))]
        count += 1
    dataset.iteration = 0
    if count == 0:
        return float('nan'), float('nan'), float('nan')
    return tuple(total / count for total in totals)


def train(args):
    """Train a ProbabilisticUnet on ``args.task`` and save its weights.

    Args:
        args: Namespace-like object providing ``epoch`` (number of epochs),
            ``learning_rate``, ``task`` (dataset task directory),
            ``iteration`` (training samples per epoch) and
            ``test_iteration`` (validation samples per epoch).

    Side effects:
        Prints per-epoch mean losses and writes the final ``state_dict`` to
        ``'./save/' + trainset.task_dir + 'model.pth'``, creating the parent
        directory if it does not exist.
    """
    import os

    num_epoch = args.epoch
    learning_rate = args.learning_rate
    task_dir = args.task

    trainset = MedicalDataset(task_dir=task_dir, mode='train')
    validset = MedicalDataset(task_dir=task_dir, mode='valid')

    model = ProbabilisticUnet(
        input_channels=1,
        num_classes=1,
        num_filters=[32, 64, 128, 192],
        latent_dim=2,
        no_convs_fcomb=4,
        beta=10.0,
    )
    model.to(device)
    # Regularisation is applied explicitly in the loss, so no weight decay.
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, weight_decay=0)

    for epoch in range(num_epoch):
        model.train()
        elbo, reg_loss, loss = _run_epoch(model, trainset, args.iteration, optimizer)

        model.eval()
        with torch.no_grad():
            _, _, valid_loss = _run_epoch(model, validset, args.test_iteration)

        # elbo/regloss/loss are training-epoch means; valid loss is the
        # validation-epoch mean, so both phases are reported independently.
        print('Epoch: {}, elbo: {:.4f}, regloss: {:.4f}, loss: {:.4f}, valid loss: {:.4f}'.format(
            epoch + 1, elbo, reg_loss, loss, valid_loss))

    save_path = './save/' + trainset.task_dir + 'model.pth'
    os.makedirs(os.path.dirname(save_path), exist_ok=True)
    torch.save(model.state_dict(), save_path)