def train(model, train_loader, device, optimizer, epoch, train_dict, logger):
    """Train ``model`` for one epoch and record epoch-averaged metrics.

    Each batch from ``train_loader`` has its leading dimension removed with
    ``torch.squeeze(..., dim=0)``, is cast to float and moved to ``device``.
    The model is optimised with ``metrics.DiceMeanLoss`` and one progress line
    (loss, Dice for class indices 0-3, ``metrics.T`` / ``P`` / ``TP``) is
    printed per batch.

    After the epoch, the mean loss and mean Dice scores over all batches are
    appended to ``train_dict`` and reported via ``logger.scalar_summary`` under
    the ``train_loss`` / ``train_dice0``..``train_dice3`` tags.

    Args:
        model: network to train; switched to training mode with ``model.train()``.
        train_loader: sized iterable yielding ``(data, target)`` pairs.
        device: target passed to ``Tensor.to`` for every batch.
        optimizer: optimizer over ``model``'s parameters.
        epoch: epoch number shown in the progress line and used as the
            logger step.
        train_dict: mapping whose ``'loss'`` and ``'dice0'``..``'dice3'``
            entries are lists; one float is appended to each.
        logger: object providing ``scalar_summary(tag, value, step)``.
    """
    model.train()
    class_ids = range(4)  # Dice is reported for label indices 0..3
    total_batches = len(train_loader)
    loss_sum = 0.0
    dice_sums = [0.0] * len(class_ids)
    seen = 0
    for batch_idx, (data, target) in enumerate(train_loader):
        data = torch.squeeze(data, dim=0).float().to(device)
        target = torch.squeeze(target, dim=0).float().to(device)
        output = model(data)
        optimizer.zero_grad()
        loss = metrics.DiceMeanLoss()(output, target)
        loss.backward()
        optimizer.step()

        # These values are for reporting only, so no autograd graph is needed.
        with torch.no_grad():
            batch_loss = loss.item()
            batch_dice = [float(metrics.dice(output, target, c)) for c in class_ids]
            t_val = metrics.T(output, target)
            p_val = metrics.P(output, target)
            tp_val = metrics.TP(output, target)

        # Accumulate plain floats so the epoch summary is a true mean and
        # no loss tensor (with its graph) is kept alive past this batch.
        loss_sum += batch_loss
        dice_sums = [s + d for s, d in zip(dice_sums, batch_dice)]
        seen += 1

        print(
            'Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}\tdice0: {:.6f}\tdice1: {:.6f}\tdice2: {:.6f}\tdice3: {:.6f}\tT: {:.6f}\tP: {:.6f}\tTP: {:.6f}'
            .format(epoch, batch_idx, total_batches,
                    100. * batch_idx / total_batches, batch_loss,
                    *batch_dice, t_val, p_val, tp_val))

    # Mean over the batches actually processed; max(..., 1) keeps an empty
    # loader from raising and records 0.0 as before.
    denom = max(seen, 1)
    epoch_loss = loss_sum / denom
    epoch_dice = [s / denom for s in dice_sums]

    train_dict['loss'].append(epoch_loss)
    for c, value in zip(class_ids, epoch_dice):
        train_dict['dice{}'.format(c)].append(value)
    logger.scalar_summary('train_loss', epoch_loss, epoch)
    for c, value in zip(class_ids, epoch_dice):
        logger.scalar_summary('train_dice{}'.format(c), value, epoch)
def val(model, val_loader, device, epoch, val_dict, logger):
    """Evaluate ``model`` on ``val_loader`` and record batch-averaged metrics.

    Runs with ``model.eval()`` inside ``torch.no_grad()``. For every batch the
    leading dimension is squeezed off, tensors are cast to float and moved to
    ``device``; ``metrics.DiceMeanLoss`` and ``metrics.dice`` for class
    indices 0-3 are summed, then divided by ``len(val_loader)``. The averages
    are appended to ``val_dict``, sent to ``logger.scalar_summary`` under
    ``val_*`` tags, and printed.

    Args:
        model: network to evaluate; switched to eval mode.
        val_loader: sized iterable yielding ``(data, target)`` pairs.
        device: target passed to ``Tensor.to`` for every batch.
        epoch: step value passed to the logger.
        val_dict: mapping whose ``'loss'`` and ``'dice0'``..``'dice3'``
            entries are lists; one float is appended to each.
        logger: object providing ``scalar_summary(tag, value, step)``.

    Raises:
        ZeroDivisionError: if ``len(val_loader)`` is 0.
    """
    model.eval()
    keys = ('loss', 'dice0', 'dice1', 'dice2', 'dice3')
    totals = dict.fromkeys(keys, 0)
    with torch.no_grad():
        for inputs, labels in val_loader:
            inputs = torch.squeeze(inputs, dim=0)
            labels = torch.squeeze(labels, dim=0)
            inputs = inputs.float().to(device)
            labels = labels.float().to(device)
            prediction = model(inputs)
            batch_scores = {'loss': metrics.DiceMeanLoss()(prediction, labels)}
            for class_idx in range(4):
                batch_scores['dice{}'.format(class_idx)] = metrics.dice(prediction, labels, class_idx)
            for key in keys:
                totals[key] += float(batch_scores[key])

    batch_count = len(val_loader)
    means = {key: totals[key] / batch_count for key in keys}

    for key in keys:
        val_dict[key].append(float(means[key]))
    for key in keys:
        logger.scalar_summary('val_' + key, means[key], epoch)
    print(
        '\nVal set: Average loss: {:.6f}, dice0: {:.6f}\tdice1: {:.6f}\tdice2: {:.6f}\tdice3: {:.6f}\t\n'
        .format(*(means[key] for key in keys)))
# model_dict.update(pretrained_dict) # Flownet.load_state_dict(model_dict) # # pretrained_dict = torch.load('./pkl/net_epoch_126-Seg-Network.pkl') # model_dict = Segnet.state_dict() # pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict} # model_dict.update(pretrained_dict) # Segnet.load_state_dict(model_dict) ## criterion_L1 = torch.nn.L1Loss() criterion_MSE = torch.nn.MSELoss() criterion_BCE = torch.nn.BCEWithLogitsLoss() criterion_CE = criterion.crossentry() criterion_ncc = criterion.NCC().loss criterion_grad = criterion.Grad('l2', 2).loss criterion_dice = criterion.DiceMeanLoss() opt_flow = torch.optim.Adam(Flownet.parameters(), lr=0.0001) opt_seg = torch.optim.Adam(Segnet.parameters(), lr=0.0001) for epoch in range(200): loss_continous_motion = 0 meanncc = 0 meanregdice = 0 meansegdice = 0 for step, (img_ed, img_pre, img_mid, img_aft, img_es, labeled, labeles) in enumerate(dataloder): img_ed = img_ed.to(device).float() img_pre = img_pre.to(device).float() ##ed-0.25 img_mid = img_mid.to(device).float()