import os

import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
from torch.utils import data
from tqdm import tqdm

# Repo-local modules (assumed layout -- adjust to your checkout): the model
# and loader registries, the gradient/unwarp losses, and the tensorboard
# visualization helpers. SummaryWriter may instead come from
# torch.utils.tensorboard.
from models import get_model
from loaders import get_loader
from utils import show_unwarp_tnsboard, show_wc_tnsboard
import grad_loss
import recon_lossc
from tensorboardX import SummaryWriter


def train(args):
    # Setup dataloaders for the shape network (world-coordinate regression)
    wc_data_loader = get_loader('doc3dwc')
    data_path = args.data_path
    wc_t_loader = wc_data_loader(data_path, is_transform=True,
                                 img_size=(args.wc_img_rows, args.wc_img_cols),
                                 augmentations=args.augmentation)
    wc_v_loader = wc_data_loader(data_path, is_transform=True, split='val',
                                 img_size=(args.wc_img_rows, args.wc_img_cols))
    wc_n_classes = wc_t_loader.n_classes
    wc_trainloader = data.DataLoader(wc_t_loader, batch_size=args.batch_size,
                                     num_workers=8, shuffle=True)
    wc_valloader = data.DataLoader(wc_v_loader, batch_size=args.batch_size,
                                   num_workers=8)

    # Shape network: image -> world coordinates
    model_wc = get_model('unetnc', wc_n_classes, in_channels=3)
    model_wc = torch.nn.DataParallel(
        model_wc, device_ids=range(torch.cuda.device_count()))
    model_wc.cuda()

    # Setup dataloaders for the texture-mapping network (backward map)
    bm_data_loader = get_loader('doc3dbmnic')
    bm_t_loader = bm_data_loader(data_path, is_transform=True,
                                 img_size=(args.bm_img_rows, args.bm_img_cols))
    bm_v_loader = bm_data_loader(data_path, is_transform=True, split='val',
                                 img_size=(args.bm_img_rows, args.bm_img_cols))
    bm_n_classes = bm_t_loader.n_classes
    bm_trainloader = data.DataLoader(bm_t_loader, batch_size=args.batch_size,
                                     num_workers=8, shuffle=True)
    bm_valloader = data.DataLoader(bm_v_loader, batch_size=args.batch_size,
                                   num_workers=8)

    # Texture-mapping network: world coordinates -> backward map
    model_bm = get_model('dnetccnl', bm_n_classes, in_channels=3)
    model_bm = torch.nn.DataParallel(
        model_bm, device_ids=range(torch.cuda.device_count()))
    model_bm.cuda()

    # Joint training requires both pretrained checkpoints.
    if os.path.isfile(args.shape_net_loc):
        print("Loading model_wc from checkpoint '{}'".format(args.shape_net_loc))
        checkpoint = torch.load(args.shape_net_loc)
        model_wc.load_state_dict(checkpoint['model_state'])
        print("Loaded checkpoint '{}' (epoch {})".format(
            args.shape_net_loc, checkpoint['epoch']))
    else:
        print("No model_wc checkpoint found at '{}'".format(args.shape_net_loc))
        exit(1)

    if os.path.isfile(args.texture_mapping_net_loc):
        print("Loading model_bm from checkpoint '{}'".format(
            args.texture_mapping_net_loc))
        checkpoint = torch.load(args.texture_mapping_net_loc)
        model_bm.load_state_dict(checkpoint['model_state'])
        print("Loaded checkpoint '{}' (epoch {})".format(
            args.texture_mapping_net_loc, checkpoint['epoch']))
    else:
        print("No model_bm checkpoint found at '{}'".format(
            args.texture_mapping_net_loc))
        exit(1)

    # Activation: clamp predicted world coordinates to [0, 1]
    htan = nn.Hardtanh(0, 1.0)

    # One optimizer over the parameters of both networks
    optimizer = torch.optim.Adam(
        list(model_wc.parameters()) + list(model_bm.parameters()),
        lr=args.l_rate, weight_decay=5e-4, amsgrad=True)

    # LR scheduler
    sched = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode='min', factor=0.5, patience=5, verbose=True)

    # Losses
    MSE = nn.MSELoss()
    loss_fn = nn.L1Loss()
    gloss = grad_loss.Gradloss(window_size=5, padding=2)
    reconst_loss = recon_lossc.Unwarploss()

    epoch_start = 0

    # Log file
    if not os.path.exists(args.logdir):
        os.makedirs(args.logdir)
    experiment_name = 'joint train'
    log_file_name = os.path.join(args.logdir, experiment_name + '.txt')
    if os.path.isfile(log_file_name):
        log_file = open(log_file_name, 'a')
    else:
        log_file = open(log_file_name, 'w+')
    log_file.write('\n--------------- ' + experiment_name + ' ---------------\n')
    log_file.close()

    # Setup tensorboard for visualization
    if args.tboard:
        # save logs in runs/<experiment_name>
        writer = SummaryWriter(comment=experiment_name)

    best_val_mse = 99999.0
    global_step = 0
    LClambda = 0.2            # weight of the gradient loss on the shape net
    bm_img_size = (128, 128)  # input resolution of the texture-mapping net
    alpha = 0.5               # weight of the shape-net term in the joint loss
    beta = 0.5                # weight of the texture-mapping term

    for epoch in range(epoch_start, args.n_epoch):
        avg_loss = 0.0
        wc_avg_l1loss = 0.0
        wc_avg_gloss = 0.0
        wc_train_mse = 0.0
        bm_avgl1loss = 0.0
        bm_avgrloss = 0.0
        bm_avgssimloss = 0.0
        bm_train_mse = 0.0
        model_wc.train()
        model_bm.train()

        # Bump the gradient-loss weight once at epoch 50.
        if epoch == 50 and LClambda < 1.0:
            LClambda += 0.2

        for i, ((wc_images, wc_labels), (bm_images, bm_labels)) in enumerate(
                zip(wc_trainloader, bm_trainloader)):
            wc_images = Variable(wc_images.cuda())
            wc_labels = Variable(wc_labels.cuda())
            optimizer.zero_grad()

            # Shape net forward pass
            wc_outputs = model_wc(wc_images)
            pred_wc = htan(wc_outputs)
            g_loss = gloss(pred_wc, wc_labels)
            wc_l1loss = loss_fn(pred_wc, wc_labels)
            loss = alpha * (wc_l1loss + LClambda * g_loss)

            # Texture-mapping net forward pass, fed with the predicted
            # world coordinates
            bm_images = Variable(bm_images.cuda())
            bm_labels = Variable(bm_labels.cuda())
            bm_input = F.interpolate(pred_wc, bm_img_size)
            target = model_bm(bm_input)
            target_nhwc = target.transpose(1, 2).transpose(2, 3)
            bm_l1loss = loss_fn(target_nhwc, bm_labels)
            rloss, ssim, uworg, uwpred = reconst_loss(
                bm_images[:, :-1, :, :], target_nhwc, bm_labels)
            loss += beta * ((10.0 * bm_l1loss) + (0.5 * rloss))

            avg_loss += float(loss)
            wc_avg_l1loss += float(wc_l1loss)
            wc_avg_gloss += float(g_loss)
            wc_train_mse += MSE(pred_wc, wc_labels).item()
            bm_avgl1loss += float(bm_l1loss)
            bm_avgrloss += float(rloss)
            bm_avgssimloss += float(ssim)
            bm_train_mse += MSE(target_nhwc, bm_labels).item()

            loss.backward()
            optimizer.step()
            global_step += 1

            if (i + 1) % 50 == 0:
                print("Epoch[%d/%d] Batch [%d/%d] Loss: %.4f" %
                      (epoch + 1, args.n_epoch, i + 1, len(wc_trainloader),
                       avg_loss / 50.0))
                avg_loss = 0.0

            if args.tboard and (i + 1) % 20 == 0:
                show_wc_tnsboard(global_step, writer, wc_images, wc_labels,
                                 pred_wc, 8, 'Train Inputs', 'Train WCs',
                                 'Train Pred. WCs')
                writer.add_scalar('WC: L1 Loss/train',
                                  wc_avg_l1loss / (i + 1), global_step)
                writer.add_scalar('WC: Grad Loss/train',
                                  wc_avg_gloss / (i + 1), global_step)
                show_unwarp_tnsboard(bm_images, global_step, writer, uwpred,
                                     uworg, 8, 'Train GT unwarp',
                                     'Train Pred Unwarp')
                writer.add_scalar('BM: L1 Loss/train',
                                  bm_avgl1loss / (i + 1), global_step)
                writer.add_scalar('CB: Recon Loss/train',
                                  bm_avgrloss / (i + 1), global_step)
                writer.add_scalar('CB: SSIM Loss/train',
                                  bm_avgssimloss / (i + 1), global_step)

        # Epoch-level training statistics
        wc_train_mse = wc_train_mse / len(wc_trainloader)
        wc_avg_l1loss = wc_avg_l1loss / len(wc_trainloader)
        wc_avg_gloss = wc_avg_gloss / len(wc_trainloader)
        print("wc Training L1: %.4f" % wc_avg_l1loss)
        print("wc Training MSE: {}".format(wc_train_mse))
        wc_train_losses = [wc_avg_l1loss, wc_train_mse, wc_avg_gloss]
        lrate = get_lr(optimizer)
        write_log_file(log_file_name, wc_train_losses, epoch + 1, lrate,
                       'Train', 'wc')

        bm_avgssimloss = bm_avgssimloss / len(bm_trainloader)
        bm_avgrloss = bm_avgrloss / len(bm_trainloader)
        bm_avgl1loss = bm_avgl1loss / len(bm_trainloader)
        bm_train_mse = bm_train_mse / len(bm_trainloader)
        print("bm Training L1: %.4f" % bm_avgl1loss)
        print("bm Training MSE: {}".format(bm_train_mse))
        bm_train_losses = [bm_avgl1loss, bm_train_mse, bm_avgrloss,
                           bm_avgssimloss]
        write_log_file(log_file_name, bm_train_losses, epoch + 1, lrate,
                       'Train', 'bm')

        # Validation
        model_wc.eval()
        model_bm.eval()
        val_mse = 0.0
        val_loss = 0.0
        wc_val_loss = 0.0
        wc_val_gloss = 0.0
        wc_val_mse = 0.0
        bm_val_l1loss = 0.0
        val_rloss = 0.0
        val_ssimloss = 0.0
        bm_val_mse = 0.0

        for i_val, ((wc_images_val, wc_labels_val),
                    (bm_images_val, bm_labels_val)) in tqdm(
                enumerate(zip(wc_valloader, bm_valloader))):
            with torch.no_grad():
                wc_images_val = Variable(wc_images_val.cuda())
                wc_labels_val = Variable(wc_labels_val.cuda())
                wc_outputs = model_wc(wc_images_val)
                pred_val = htan(wc_outputs)
                wc_g_loss = gloss(pred_val, wc_labels_val).cpu()

                # Score the shape net on CPU copies; keep pred_val on the GPU
                # because it is fed to model_bm below.
                pred_cpu = pred_val.cpu()
                labels_cpu = wc_labels_val.cpu()
                batch_wc_l1 = float(loss_fn(pred_cpu, labels_cpu))
                batch_wc_mse = float(MSE(pred_cpu, labels_cpu))
                wc_val_loss += batch_wc_l1
                wc_val_mse += batch_wc_mse
                wc_val_gloss += float(wc_g_loss)

                bm_images_val = Variable(bm_images_val.cuda())
                bm_labels_val = Variable(bm_labels_val.cuda())
                bm_input = F.interpolate(pred_val, bm_img_size)
                target = model_bm(bm_input)
                target_nhwc = target.transpose(1, 2).transpose(2, 3)
                pred = target_nhwc.data.cpu()
                gt = bm_labels_val.cpu()
                batch_bm_l1 = float(loss_fn(target_nhwc, bm_labels_val))
                bm_val_l1loss += batch_bm_l1
                rloss, ssim, uworg, uwpred = reconst_loss(
                    bm_images_val[:, :-1, :, :], target_nhwc, bm_labels_val)
                val_rloss += float(rloss.cpu())
                val_ssimloss += float(ssim.cpu())
                batch_bm_mse = float(MSE(pred, gt))
                bm_val_mse += batch_bm_mse

                # Accumulate the joint objective from the per-batch terms
                val_loss += alpha * batch_wc_l1 + beta * batch_bm_l1
                val_mse += batch_wc_mse + batch_bm_mse

                if args.tboard:
                    show_unwarp_tnsboard(bm_images_val, epoch + 1, writer,
                                         uwpred, uworg, 8, 'Val GT unwarp',
                                         'Val Pred Unwarp')

        # Epoch-level validation statistics
        wc_val_loss = wc_val_loss / len(wc_valloader)
        wc_val_mse = wc_val_mse / len(wc_valloader)
        wc_val_gloss = wc_val_gloss / len(wc_valloader)
        print("wc val loss at epoch {}: {}".format(epoch + 1, wc_val_loss))
        print("wc val MSE: {}".format(wc_val_mse))

        bm_val_l1loss = bm_val_l1loss / len(bm_valloader)
        bm_val_mse = bm_val_mse / len(bm_valloader)
        val_ssimloss = val_ssimloss / len(bm_valloader)
        val_rloss = val_rloss / len(bm_valloader)
        print("bm val loss at epoch {}: {}".format(epoch + 1, bm_val_l1loss))
        print("bm val mse: {}".format(bm_val_mse))

        val_loss /= len(wc_valloader)
        val_mse /= len(wc_valloader)
        print("val loss at epoch {}: {}".format(epoch + 1, val_loss))
        print("val mse: {}".format(val_mse))

        if args.tboard:
            show_wc_tnsboard(epoch + 1, writer, wc_images_val, wc_labels_val,
                             pred_val, 8, 'Val Inputs', 'Val WCs',
                             'Val Pred. WCs')
            writer.add_scalar('WC: L1 Loss/val', wc_val_loss, epoch + 1)
            writer.add_scalar('WC: Grad Loss/val', wc_val_gloss, epoch + 1)
            writer.add_scalar('BM: L1 Loss/val', bm_val_l1loss, epoch + 1)
            writer.add_scalar('CB: Recon Loss/val', val_rloss, epoch + 1)
            writer.add_scalar('CB: SSIM Loss/val', val_ssimloss, epoch + 1)
            writer.add_scalar('total val loss', val_loss, epoch + 1)

        bm_val_losses = [bm_val_l1loss, bm_val_mse, val_rloss, val_ssimloss]
        wc_val_losses = [wc_val_loss, wc_val_mse, wc_val_gloss]
        total_val_losses = [val_loss, val_mse]
        write_log_file(log_file_name, wc_val_losses, epoch + 1, lrate,
                       'Val', 'wc')
        write_log_file(log_file_name, bm_val_losses, epoch + 1, lrate,
                       'Val', 'bm')
        write_log_file(log_file_name, total_val_losses, epoch + 1, lrate,
                       'Val', 'total')

        # Reduce learning rate on plateau
        sched.step(val_mse)

        # Save the best joint checkpoint (both networks)
        if val_mse < best_val_mse:
            best_val_mse = val_mse
            state_wc = {
                'epoch': epoch + 1,
                'model_state': model_wc.state_dict(),
                'optimizer_state': optimizer.state_dict(),
            }
            torch.save(state_wc, os.path.join(
                args.logdir, "{}_{}_{}_{}_{}_best_wc_model.pkl".format(
                    'unetnc', epoch + 1, wc_val_mse, wc_train_mse,
                    experiment_name)))
            state_bm = {
                'epoch': epoch + 1,
                'model_state': model_bm.state_dict(),
                'optimizer_state': optimizer.state_dict(),
            }
            torch.save(state_bm, os.path.join(
                args.logdir, "{}_{}_{}_{}_{}_best_bm_model.pkl".format(
                    'dnetccnl', epoch + 1, bm_val_mse, bm_train_mse,
                    experiment_name)))

        # Periodic checkpoints late in training
        if (epoch + 1) % 10 == 0 and epoch > 70:
            state_wc = {
                'epoch': epoch + 1,
                'model_state': model_wc.state_dict(),
                'optimizer_state': optimizer.state_dict(),
            }
            torch.save(state_wc, os.path.join(
                args.logdir, "{}_{}_{}_{}_{}_wc_model.pkl".format(
                    'unetnc', epoch + 1, wc_val_mse, wc_train_mse,
                    experiment_name)))
            state_bm = {
                'epoch': epoch + 1,
                'model_state': model_bm.state_dict(),
                'optimizer_state': optimizer.state_dict(),
            }
            torch.save(state_bm, os.path.join(
                args.logdir, "{}_{}_{}_{}_{}_bm_model.pkl".format(
                    'dnetccnl', epoch + 1, bm_val_mse, bm_train_mse,
                    experiment_name)))
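
# A minimal entry point for the joint trainer above -- a sketch, not part of
# the original script; in the original layout it would sit at the bottom of
# the joint-training file. The flag names mirror the attributes `train(args)`
# reads; the defaults (image sizes, lr, epoch count) are illustrative
# assumptions.
if __name__ == '__main__':
    import argparse

    parser = argparse.ArgumentParser(description='Joint wc+bm training')
    parser.add_argument('--data_path', default='', help='path to doc3d data')
    parser.add_argument('--wc_img_rows', type=int, default=256)
    parser.add_argument('--wc_img_cols', type=int, default=256)
    parser.add_argument('--bm_img_rows', type=int, default=128)
    parser.add_argument('--bm_img_cols', type=int, default=128)
    parser.add_argument('--batch_size', type=int, default=32)
    parser.add_argument('--l_rate', type=float, default=1e-4)
    parser.add_argument('--n_epoch', type=int, default=100)
    parser.add_argument('--logdir', default='./checkpoints/')
    parser.add_argument('--tboard', action='store_true',
                        help='log visualizations to tensorboard')
    parser.add_argument('--augmentation', default=None,
                        help='augmentations passed to the wc loader')
    parser.add_argument('--shape_net_loc', default='',
                        help='pretrained wc (shape net) checkpoint')
    parser.add_argument('--texture_mapping_net_loc', default='',
                        help='pretrained bm (texture-mapping net) checkpoint')
    train(parser.parse_args())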
def train(args):
    # Setup dataloader
    data_loader = get_loader('doc3dbmnic')
    data_path = args.data_path
    print('Starting . . .')
    t_loader = data_loader(data_path, is_transform=True,
                           img_size=(args.img_rows, args.img_cols))
    v_loader = data_loader(data_path, is_transform=True, split='val',
                           img_size=(args.img_rows, args.img_cols))
    n_classes = t_loader.n_classes
    print('Loading training data . . .')
    trainloader = data.DataLoader(t_loader, batch_size=args.batch_size,
                                  num_workers=8, shuffle=True)
    print('Loading validation data . . .')
    valloader = data.DataLoader(v_loader, batch_size=args.batch_size,
                                num_workers=8)

    # Setup model
    print('Loading model . . .')
    model = get_model(args.arch, n_classes, in_channels=3)
    model = torch.nn.DataParallel(
        model, device_ids=range(torch.cuda.device_count()))
    model.cuda()

    # Optimizer
    optimizer = torch.optim.Adam(model.parameters(), lr=args.l_rate,
                                 weight_decay=5e-4, amsgrad=True)

    # LR scheduler
    sched = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode='min', factor=0.5, patience=3, verbose=True)

    # Losses
    MSE = nn.MSELoss()
    loss_fn = nn.L1Loss()
    reconst_loss = recon_lossc.Unwarploss()

    epoch_start = 0
    if args.resume is not None:
        if os.path.isfile(args.resume):
            print("Loading model and optimizer from checkpoint '{}'".format(
                args.resume))
            checkpoint = torch.load(args.resume)
            model.load_state_dict(checkpoint['model_state'])
            optimizer.load_state_dict(checkpoint['optimizer_state'])
            print("Loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))
            epoch_start = checkpoint['epoch']
        else:
            print("No checkpoint found at '{}'".format(args.resume))

    # Log file
    if not os.path.exists(args.logdir):
        os.makedirs(args.logdir)
    # network_activation(t=[-1,1])_dataset_lossparams_augmentations_trainstart
    experiment_name = 'dnetccnl_htan_swat3dmini1kbm_l1_noaug_scratch'
    log_file_name = os.path.join(args.logdir, experiment_name + '.txt')
    if os.path.isfile(log_file_name):
        log_file = open(log_file_name, 'a')
    else:
        log_file = open(log_file_name, 'w+')
    log_file.write('\n--------------- ' + experiment_name + ' ---------------\n')
    log_file.close()

    # Setup tensorboard for visualization
    if args.tboard:
        # save logs in runs/<experiment_name>
        writer = SummaryWriter(comment=experiment_name)

    best_val_uwarpssim = 99999.0
    best_val_mse = 99999.0
    global_step = 0

    for epoch in range(epoch_start, args.n_epoch):
        avg_loss = 0.0
        avgl1loss = 0.0
        avgrloss = 0.0
        avgssimloss = 0.0
        train_mse = 0.0
        model.train()

        for i, (images, labels) in enumerate(trainloader):
            images = Variable(images.cuda())
            labels = Variable(labels.cuda())
            optimizer.zero_grad()

            # The loader concatenates its inputs channel-wise: channels 3:
            # feed the network, while the leading channels hold the image
            # used by the reconstruction loss below.
            target = model(images[:, 3:, :, :])
            target_nhwc = target.transpose(1, 2).transpose(2, 3)
            l1loss = loss_fn(target_nhwc, labels)
            rloss, ssim, uworg, uwpred = reconst_loss(images[:, :-1, :, :],
                                                      target_nhwc, labels)
            loss = (10.0 * l1loss) + (0.5 * rloss)  # + (0.3*ssim)

            avgl1loss += float(l1loss)
            avg_loss += float(loss)
            avgrloss += float(rloss)
            avgssimloss += float(ssim)
            train_mse += MSE(target_nhwc, labels).item()

            loss.backward()
            optimizer.step()
            global_step += 1

            if (i + 1) % 10 == 0:
                avg_loss = avg_loss / 10
                print("Epoch[%d/%d] Batch [%d/%d] Loss: %.4f" %
                      (epoch + 1, args.n_epoch, i + 1, len(trainloader),
                       avg_loss))
                avg_loss = 0.0

            if args.tboard and (i + 1) % 10 == 0:
                show_unwarp_tnsboard(global_step, writer, uwpred, uworg, 8,
                                     'Train GT unwarp', 'Train Pred Unwarp')
                writer.add_scalars(
                    'Train', {
                        'BM_L1 Loss/train': avgl1loss / (i + 1),
                        'CB_Recon Loss/train': avgrloss / (i + 1),
                        'CB_SSIM Loss/train': avgssimloss / (i + 1)
                    }, global_step)

        # Epoch-level training statistics
        avgssimloss = avgssimloss / len(trainloader)
        avgrloss = avgrloss / len(trainloader)
        avgl1loss = avgl1loss / len(trainloader)
        train_mse = train_mse / len(trainloader)
        print("Training L1: %.4f" % avgl1loss)
        print("Training MSE: {}".format(train_mse))
        train_losses = [avgl1loss, train_mse, avgrloss, avgssimloss]
        lrate = get_lr(optimizer)
        write_log_file(log_file_name, train_losses, epoch + 1, lrate, 'Train')

        if args.tboard:
            writer.add_scalar('BM: L1 Loss/train', avgl1loss, epoch + 1)
            writer.add_scalar('CB: Recon Loss/train', avgrloss, epoch + 1)
            writer.add_scalar('CB: SSIM Loss/train', avgssimloss, epoch + 1)
            writer.add_scalar('MSE: MSE/train', train_mse, epoch + 1)

        # Validation
        model.eval()
        val_loss = 0.0
        val_l1loss = 0.0
        val_mse = 0.0
        val_rloss = 0.0
        val_ssimloss = 0.0

        for i_val, (images_val, labels_val) in tqdm(enumerate(valloader)):
            with torch.no_grad():
                images_val = Variable(images_val.cuda())
                labels_val = Variable(labels_val.cuda())
                target = model(images_val[:, 3:, :, :])
                target_nhwc = target.transpose(1, 2).transpose(2, 3)
                pred = target_nhwc.data.cpu()
                gt = labels_val.cpu()
                l1loss = loss_fn(target_nhwc, labels_val)
                rloss, ssim, uworg, uwpred = reconst_loss(
                    images_val[:, :-1, :, :], target_nhwc, labels_val)
                val_l1loss += float(l1loss.cpu())
                val_rloss += float(rloss.cpu())
                val_ssimloss += float(ssim.cpu())
                val_mse += float(MSE(pred, gt))

            if args.tboard:
                show_unwarp_tnsboard(epoch + 1, writer, uwpred, uworg, 8,
                                     'Val GT unwarp', 'Val Pred Unwarp')

        # Epoch-level validation statistics
        val_l1loss = val_l1loss / len(valloader)
        val_mse = val_mse / len(valloader)
        val_ssimloss = val_ssimloss / len(valloader)
        val_rloss = val_rloss / len(valloader)
        print("val loss at epoch {}: {}".format(epoch + 1, val_l1loss))
        print("val mse: {}".format(val_mse))
        val_losses = [val_l1loss, val_mse, val_rloss, val_ssimloss]
        write_log_file(log_file_name, val_losses, epoch + 1, lrate, 'Val')

        if args.tboard:
            # log the val losses
            writer.add_scalar('BM: L1 Loss/val', val_l1loss, epoch + 1)
            writer.add_scalar('CB: Recon Loss/val', val_rloss, epoch + 1)
            writer.add_scalar('CB: SSIM Loss/val', val_ssimloss, epoch + 1)
            writer.add_scalar('MSE: MSE/val', val_mse, epoch + 1)

            # plot train against val
            writer.add_scalars('BM_L1_Loss',
                               {'train': avgl1loss, 'val': val_l1loss},
                               epoch + 1)
            writer.add_scalars('CB_Recon_Loss',
                               {'train': avgrloss, 'val': val_rloss},
                               epoch + 1)
            writer.add_scalars('CB_SSIM_Loss',
                               {'train': avgssimloss, 'val': val_ssimloss},
                               epoch + 1)
            writer.add_scalars('MSE_Mean_square_error',
                               {'train': train_mse, 'val': val_mse},
                               epoch + 1)

        # Reduce learning rate on plateau
        sched.step(val_mse)

        if val_mse < best_val_mse:
            best_val_mse = val_mse
            state = {
                'epoch': epoch + 1,
                'model_state': model.state_dict(),
                'optimizer_state': optimizer.state_dict(),
            }
            torch.save(state, os.path.join(
                args.logdir, "{}_{}_{}_{}_{}_best_model.pkl".format(
                    args.arch, epoch + 1, val_mse, train_mse,
                    experiment_name)))

        if (epoch + 1) % 10 == 0:
            state = {
                'epoch': epoch + 1,
                'model_state': model.state_dict(),
                'optimizer_state': optimizer.state_dict(),
            }
            torch.save(state, os.path.join(
                args.logdir, "{}_{}_{}_{}_{}_model.pkl".format(
                    args.arch, epoch + 1, val_mse, train_mse,
                    experiment_name)))
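
# Both scripts above assume small logging helpers from the repo's utils
# module whose bodies are not shown here. The sketch below is a plausible
# minimal implementation, assuming `write_log_file` appends one line of
# losses per call; the trailing tag argument the joint script passes
# (e.g. 'wc', 'bm', 'total') marks which network the row belongs to.
def get_lr(optimizer):
    # Return the learning rate of the first param group.
    for param_group in optimizer.param_groups:
        return param_group['lr']


def write_log_file(log_file_name, losses, epoch, lrate, phase, tag=''):
    # Append one row per call: phase, epoch, lr, optional tag, loss values.
    with open(log_file_name, 'a') as f:
        f.write('{} epoch {} lr {} {} losses: {}\n'.format(
            phase, epoch, lrate, tag, losses))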
def train(n_epoch=50, batch_size=32, resume=False, wc_path='', bm_path=''):
    wc_model_name = 'unetnc'
    bm_model_name = 'dnetccnl'

    # Setup dataloader
    data_path = 'C:/Users/yuttapichai.lam/dev-environment/doc3d'
    data_loader = get_loader('doc3djoint')
    t_loader = data_loader(data_path, is_transform=True, img_size=(256, 256),
                           bm_size=(128, 128))
    v_loader = data_loader(data_path, split='val', is_transform=True,
                           img_size=(256, 256), bm_size=(128, 128))
    trainloader = data.DataLoader(t_loader, batch_size=batch_size,
                                  num_workers=8, shuffle=True)
    valloader = data.DataLoader(v_loader, batch_size=batch_size,
                                num_workers=8)

    # Last-layer activation
    htan = nn.Hardtanh(0, 1.0)

    # Load models
    print('Loading')
    wc_model = get_model(wc_model_name, n_classes=3, in_channels=3)
    wc_model = torch.nn.DataParallel(
        wc_model, device_ids=range(torch.cuda.device_count()))
    wc_model.cuda()
    bm_model = get_model(bm_model_name, n_classes=2, in_channels=3)
    bm_model = torch.nn.DataParallel(
        bm_model, device_ids=range(torch.cuda.device_count()))
    bm_model.cuda()

    # Setup optimizer and learning-rate reduction
    print('Setting optimizer')
    optimizer = torch.optim.Adam([
        {'params': wc_model.parameters()},
        {'params': bm_model.parameters()},
    ], lr=1e-4, weight_decay=5e-4, amsgrad=True)
    schedule = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode='min', factor=0.5, patience=3, verbose=True)

    # Setup losses
    MSE = nn.MSELoss()
    loss_fn = nn.L1Loss()
    reconst_loss = recon_lossc.Unwarploss()
    g_loss = grad_loss.Gradloss(window_size=5, padding=2)

    epoch_start = 0
    if resume:
        print('Resume from previous state')
        wc_chkpnt = torch.load(wc_path)
        wc_model.load_state_dict(wc_chkpnt['model_state'])
        bm_chkpnt = torch.load(bm_path)
        bm_model.load_state_dict(bm_chkpnt['model_state'])
        # optimizer.load_state_dict(
        #     [wc_chkpnt['optimizer_state'], bm_chkpnt['optimizer_state']])
        epoch_start = bm_chkpnt['epoch']

    best_valwc_mse = 9999999.0
    best_valbm_mse = 9999999.0

    print(f'Start from epoch {epoch_start} of {n_epoch}')
    print('Starting')
    for epoch in range(epoch_start, n_epoch):
        print(f'Epoch: {epoch}')

        # Loss initialization
        avg_loss = 0.0
        avg_wcloss = 0.0
        avgwcl1loss = 0.0
        avg_gloss = 0.0
        train_wcmse = 0.0
        avg_bmloss = 0.0
        avgbml1loss = 0.0
        avgrloss = 0.0
        avgssimloss = 0.0
        train_bmmse = 0.0
        avg_const_l1 = 0.0
        avg_const_mse = 0.0

        # Start training
        wc_model.train()
        bm_model.train()
        print('Training')
        for i, (imgs, wcs, bms, recons, ims, lbls) in enumerate(trainloader):
            images = Variable(imgs.cuda())
            wc_labels = Variable(wcs.cuda())
            bm_labels = Variable(bms.cuda())
            recon_labels = Variable(recons.cuda())
            im_inputs = Variable(ims.cuda())
            labels = Variable(lbls.cuda())
            optimizer.zero_grad()

            # Train WC network
            wc_out = wc_model(images)
            wc_out = F.interpolate(wc_out, size=(256, 256), mode='bilinear',
                                   align_corners=True)
            bm_inp = F.interpolate(wc_out, size=(128, 128), mode='bilinear',
                                   align_corners=True)
            bm_inp = htan(bm_inp)
            wc_pred = htan(wc_out)
            wc_l1loss = loss_fn(wc_pred, wc_labels)
            wc_gloss = g_loss(wc_pred, wc_labels)
            wc_mse = MSE(wc_pred, wc_labels)
            wc_loss = wc_l1loss + (0.2 * wc_gloss)

            # WC loss statistics
            avgwcl1loss += float(wc_l1loss)
            avg_gloss += float(wc_gloss)
            train_wcmse += float(wc_mse)
            avg_wcloss += float(wc_loss)

            # Train BM network
            bm_out = bm_model(bm_inp)
            bm_out = bm_out.transpose(1, 2).transpose(2, 3)
            bm_l1loss = loss_fn(bm_out, bm_labels)
            rloss, ssim, _, _ = reconst_loss(recon_labels, bm_out, bm_labels)
            bm_mse = MSE(bm_out, bm_labels)
            bm_loss = (10.0 * bm_l1loss) + (0.5 * rloss)

            # Loss between unwarped GT and unwarped prediction
            im_ins = im_inputs[:, :3, :, :]
            bm_out = bm_out.double()
            label_in = labels[:, :3, :, :]
            bm_labels = bm_labels.double()
            uwpred = unwarp(im_ins, bm_out)
            uworg = unwarp(label_in, bm_labels)
            const_l1 = loss_fn(uwpred, uworg)
            const_mse = MSE(uwpred, uworg)

            # BM loss statistics
            avg_const_l1 += float(const_l1)
            avg_const_mse += float(const_mse)
            avgbml1loss += float(bm_l1loss)
            avgrloss += float(rloss)
            avgssimloss += float(ssim)
            train_bmmse += float(bm_mse)
            avg_bmloss += float(bm_loss)

            # Step loss: equal weighting of the two networks
            loss = (0.5 * wc_loss) + (0.5 * bm_loss)
            avg_loss += float(loss)

            if (i + 1) % 10 == 0:
                # Show the unwarped GT next to the unwarped prediction.
                # Note: plt.show() blocks until the window is closed.
                _, ax = plt.subplots(1, 2)
                ax[0].imshow(
                    uworg[0].cpu().detach().numpy().transpose((1, 2, 0)))
                ax[1].imshow(
                    uwpred[0].cpu().detach().numpy().transpose((1, 2, 0)))
                plt.show()
                print(f'Epoch[{epoch}/{n_epoch}] '
                      f'Batch[{i + 1}/{len(trainloader)}] '
                      f'Loss: {avg_loss / (i + 1):.6f} '
                      f'Const Loss: {avg_const_l1 / (i + 1):.6f}')

            loss.backward()
            # const_l1.backward()
            optimizer.step()

        # Epoch-level training statistics
        len_trainset = len(trainloader)
        avg_const_l1 = avg_const_l1 / len_trainset
        train_wcmse = train_wcmse / len_trainset
        train_bmmse = train_bmmse / len_trainset
        train_losses = [
            avgwcl1loss / len_trainset, train_wcmse,
            avg_gloss / len_trainset, avgbml1loss / len_trainset,
            train_bmmse, avgrloss / len_trainset,
            avgssimloss / len_trainset, avg_const_l1,
            avg_const_mse / len_trainset
        ]
        print(f'WC L1 loss: {train_losses[0]} WC MSE: {train_losses[1]} '
              f'WC GLoss: {train_losses[2]}')
        print(f'BM L1 Loss: {train_losses[3]} BM MSE: {train_losses[4]} '
              f'BM RLoss: {train_losses[5]} BM SSIM Loss: {train_losses[6]}')
        print(f'Reconstruction against GT => Loss: {train_losses[7]} '
              f'MSE: {train_losses[8]}')

        # Validation
        wc_model.eval()
        bm_model.eval()
        wc_val_l1 = 0.0
        wc_val_mse = 0.0
        wc_val_gloss = 0.0
        bm_val_l1 = 0.0
        bm_val_mse = 0.0
        bm_val_rloss = 0.0
        bm_val_ssim = 0.0
        avg_const_l1_val = 0.0
        avg_const_mse_val = 0.0
        print('Validating')
        for i_val, (imgs_val, wcs_val, bms_val, recons_val, ims_val,
                    lbls_val) in tqdm(enumerate(valloader)):
            with torch.no_grad():
                images_val = Variable(imgs_val.cuda())
                wc_labels_val = Variable(wcs_val.cuda())
                bm_labels_val = Variable(bms_val.cuda())
                recon_labels_val = Variable(recons_val.cuda())
                ims_labels_val = Variable(ims_val.cuda())
                labels_val = Variable(lbls_val.cuda())

                # Val WC network
                wc_out_val = wc_model(images_val)
                wc_out_val = F.interpolate(wc_out_val, size=(256, 256),
                                           mode='bilinear',
                                           align_corners=True)
                bm_inp_val = F.interpolate(wc_out_val, size=(128, 128),
                                           mode='bilinear',
                                           align_corners=True)
                bm_inp_val = htan(bm_inp_val)
                wc_pred_val = htan(wc_out_val)
                wc_l1 = loss_fn(wc_pred_val, wc_labels_val)
                wc_gloss = g_loss(wc_pred_val, wc_labels_val)
                wc_mse = MSE(wc_pred_val, wc_labels_val)

                # Val BM network
                bm_out_val = bm_model(bm_inp_val)
                bm_out_val = bm_out_val.transpose(1, 2).transpose(2, 3)
                bm_l1 = loss_fn(bm_out_val, bm_labels_val)
                rloss, ssim, _, _ = reconst_loss(recon_labels_val, bm_out_val,
                                                 bm_labels_val)
                bm_mse = MSE(bm_out_val, bm_labels_val)

                # Loss between unwarped GT and unwarped prediction
                im_ins_val = ims_labels_val[:, :3, :, :]
                bm_out_val = bm_out_val.double()
                lbl_ins_val = labels_val[:, :3, :, :]
                bm_labels_val = bm_labels_val.double()
                uwpred_val = unwarp(im_ins_val, bm_out_val)
                uworg_val = unwarp(lbl_ins_val, bm_labels_val)
                const_l1_val = loss_fn(uwpred_val, uworg_val)
                const_mse_val = MSE(uwpred_val, uworg_val)

                # Val loss statistics
                avg_const_l1_val += float(const_l1_val)
                avg_const_mse_val += float(const_mse_val)
                wc_val_l1 += float(wc_l1.cpu())
                wc_val_gloss += float(wc_gloss.cpu())
                wc_val_mse += float(wc_mse.cpu())
                bm_val_l1 += float(bm_l1.cpu())
                bm_val_mse += float(bm_mse.cpu())
                bm_val_rloss += float(rloss.cpu())
                bm_val_ssim += float(ssim.cpu())

        # Epoch-level validation statistics
        len_valset = len(valloader)
        avg_const_l1_val = avg_const_l1_val / len_valset
        wc_val_mse = wc_val_mse / len_valset
        bm_val_mse = bm_val_mse / len_valset
        val_losses = [
            wc_val_l1 / len_valset, wc_val_mse, wc_val_gloss / len_valset,
            bm_val_l1 / len_valset, bm_val_mse, bm_val_rloss / len_valset,
            bm_val_ssim / len_valset, avg_const_l1_val,
            avg_const_mse_val / len_valset
        ]
        print(f'WC L1 loss: {val_losses[0]} WC MSE: {val_losses[1]} '
              f'WC GLoss: {val_losses[2]}')
        print(f'BM L1 Loss: {val_losses[3]} BM MSE: {val_losses[4]} '
              f'BM RLoss: {val_losses[5]} BM SSIM Loss: {val_losses[6]}')
        print(f'Reconstruction against GT => Loss: {val_losses[7]} '
              f'MSE: {val_losses[8]}')

        # Reduce learning rate on plateau
        schedule.step(bm_val_mse)

        # Checkpoint whichever network improved
        if wc_val_mse < best_valwc_mse:
            best_valwc_mse = wc_val_mse
            state = {'epoch': epoch, 'model_state': wc_model.state_dict()}
            torch.save(state,
                       f'./checkpoints-wc/unetnc_{epoch}_wc_{wc_val_mse}_'
                       f'{train_wcmse}_best_model.pkl')
        if bm_val_mse < best_valbm_mse:
            best_valbm_mse = bm_val_mse
            state = {'epoch': epoch, 'model_state': bm_model.state_dict()}
            torch.save(state,
                       f'./checkpoints-bm/dnetccnl_{epoch}_bm_{bm_val_mse}_'
                       f'{train_bmmse}_best_model.pkl')
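
# The joint trainer above relies on a repo-local `unwarp` helper whose body
# is not shown. A minimal sketch, assuming the backward map is an NHWC grid
# already normalised to [-1, 1] as F.grid_sample expects; if the loader emits
# maps in [0, 1], rescale with bm * 2 - 1 first.
def unwarp(img, bm):
    # img: (N, C, H, W) warped input; bm: (N, H, W, 2) backward map.
    # grid_sample requires matching dtypes, hence the cast.
    return F.grid_sample(img.to(bm.dtype), bm, align_corners=True)


# Typical invocation (values are the function's own defaults; the checkpoint
# paths are illustrative):
# train(n_epoch=50, batch_size=32, resume=True,
#       wc_path='./checkpoints-wc/unetnc_best_model.pkl',
#       bm_path='./checkpoints-bm/dnetccnl_best_model.pkl')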