def __init__(self, fix_im, **kwargs):
    super(SSIMRegularization, self).__init__(fix_im)
    # Optional window size for the SSIM kernel.
    if 'window_size' in kwargs:
        self.ssim_instance = ssim.SSIM(window_size=kwargs['window_size'])
    else:
        self.ssim_instance = ssim.SSIM()
    # Allow callers to force GPU use on/off; otherwise autodetect.
    manual_gpu = kwargs.get('manual_gpu', None)
    if manual_gpu is not None:
        self.use_gpu = manual_gpu
    else:
        self.use_gpu = utils.use_gpu()
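# Hedged usage sketch for the constructor above. It assumes the enclosing
# SSIMRegularization class (and the `ssim` / `utils` modules it relies on) are
# importable from the surrounding project; tensor shape and variable names are
# illustrative, not from the source.
import torch

fix_im = torch.rand(1, 1, 64, 64)  # fixed reference image
# explicit window size and manual GPU override
reg = SSIMRegularization(fix_im, window_size=7, manual_gpu=False)
# defaults: ssim.SSIM() window and utils.use_gpu() for device selection
reg_default = SSIMRegularization(fix_im)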
def SSIM(output, target):
    # Mean per-frame SSIM for video tensors shaped (B, T, C, H, W).
    ssim = pytorch_ssim.SSIM(window_size=11)
    total_ssim = 0.
    n_frames = target.shape[1]
    for f in range(n_frames):
        total_ssim += ssim(output[:, f], target[:, f])
    return total_ssim / n_frames
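# Hedged usage sketch for the per-frame SSIM above: it expects 5-D video
# tensors shaped (batch, frames, channels, height, width) and averages the
# per-frame scores. Shapes below are illustrative.
import torch
import pytorch_ssim

output = torch.rand(2, 8, 3, 64, 64)  # (B, T, C, H, W)
target = torch.rand(2, 8, 3, 64, 64)
score = SSIM(output, target)  # scalar tensor: mean SSIM over the T frames
print(score.item())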
def __init__(self, recon_loss_name):
    # Select a reconstruction loss; each variant sums over (C, H, W) and
    # averages over the batch.
    if recon_loss_name == "L1":
        self.recon_loss_func = lambda x, y: torch.mean(
            torch.sum(torch.abs(x - y), dim=(1, 2, 3)), dim=0)
    elif recon_loss_name == "MSE":
        self.recon_loss_func = lambda x, y: torch.mean(
            torch.sum(torch.abs(x - y) ** 2, dim=(1, 2, 3)), dim=0)
    elif recon_loss_name == "BCE":
        raise NotImplementedError
    elif recon_loss_name == "SSIM":
        # pytorch_ssim.SSIM with size_average=True (the default) already
        # returns a scalar averaged over the batch, so the original
        # torch.sum(..., dim=(1, 2, 3)) reduction would fail on it.
        ssim = pytorch_ssim.SSIM(window_size=3).cuda()
        self.recon_loss_func = lambda x, y: 1 - ssim(x, y)
    elif recon_loss_name == "custom":
        raise NotImplementedError
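# Hedged usage sketch for the loss selector above. The enclosing class is not
# shown in the source, so `ReconLoss` is a hypothetical stand-in whose only
# behavior is the __init__ above; the batched image tensors are illustrative.
import torch

x = torch.rand(4, 3, 32, 32)
y = torch.rand(4, 3, 32, 32)
loss_module = ReconLoss("L1")             # or "MSE" / "SSIM"
loss = loss_module.recon_loss_func(x, y)  # scalar: per-sample sums, batch mean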
def main(args):
    # use gpu
    os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu
    cur_device = torch.device('cuda:{}'.format(args.gpu))
    if args.loss == 'bayes':
        root = '/home/datamining/Datasets/CrowdCounting/sha_bayes_512/'
        train_path = root + 'train/'
        test_path = root + 'test/'
    elif args.bn:
        root = '/home/datamining/Datasets/CrowdCounting/sha_512_a/'
        train_path = root + 'train/'
        test_path = root + 'test/'
    else:
        if args.dataset == 'sha':
            root = '/home/datamining/Datasets/CrowdCounting/shanghaitech/part_A_final/'
            train_path = root + 'train_data/images'
            test_path = root + 'test_data/images/'
        elif args.dataset == 'shb':
            root = '/home/datamining/Datasets/CrowdCounting/shb_1024_f15/'
            train_path = root + 'train/'
            test_path = root + 'test/'
        elif args.dataset == 'qnrf':
            root = '/home/datamining/Datasets/CrowdCounting/qnrf_1024_a/'
            train_path = root + 'train/'
            test_path = root + 'test/'
    downsample_ratio = args.downsample
    train_loader, test_loader, train_img_paths, test_img_paths = get_loader(
        train_path, test_path, downsample_ratio, args)
    model_dict = {
        'VGG16_13': M_CSRNet,
        'DefCcNet': DefCcNet,
        'Res50_back3': Res50,
        'InceptionV3': Inception3CC,
        'CAN': CANNet,
    }
    model_name = args.model
    dataset_name = args.dataset
    net = model_dict[model_name](downsample=args.downsample,
                                 bn=args.bn > 0,
                                 objective=args.objective,
                                 sp=(args.sp > 0),
                                 se=(args.se > 0),
                                 NL=args.nl)
    net.cuda()
    if args.bn > 0:
        # NOTE: the original format string had six placeholders for five
        # arguments; one placeholder is dropped so the call is valid.
        save_name = '{}_{}_bn{}_ps{}_{}'.format(
            model_name, dataset_name, str(int(args.bn)),
            str(args.crop_size), args.loss)
    else:
        save_name = '{}_d{}{}{}{}{}_{}_{}_cr{}_{}{}{}{}{}{}'.format(
            model_name, str(args.downsample),
            '_sp' if args.sp else '',
            '_se' if args.se else '',
            '_' + args.nl if args.nl != 'relu' else '',
            '_vp' if args.val_patch else '',
            dataset_name, args.crop_mode, str(args.crop_scale), args.loss,
            '_wu' if args.warm_up else '',
            '_cl' if args.curriculum == 'W' else '',
            '_v' + str(int(args.value_factor)) if args.value_factor != 1 else '',
            '_amp' + str(args.amp_k) if args.objective == 'dmp+amp' else '',
            '_bg' if args.use_bg else '')
    save_path = "/home/datamining/Models/CrowdCounting/" + save_name + ".pth"
    logger = get_logger('logs/' + save_name + '.txt')
    for k, v in args.__dict__.items():  # log all args
        logger.info("{}: {}".format(k, v))
    if os.path.exists(save_path) and args.resume:
        net.load_state_dict(torch.load(save_path))
        print('{} loaded!'.format(save_path))
    value_factor = args.value_factor
    freq = 100
    if args.optimizer == 'Adam':
        optimizer = torch.optim.Adam(net.parameters(), lr=args.lr,
                                     weight_decay=args.decay)
    elif args.optimizer == 'SGD':  # does not converge
        optimizer = torch.optim.SGD(net.parameters(), lr=args.lr,
                                    momentum=0.95, weight_decay=args.decay)
    if args.loss == 'bayes':
        bayes_criterion = Bay_Loss(True, cur_device)
        post_prob = Post_Prob(sigma=8.0, c_size=args.crop_size, stride=1,
                              background_ratio=0.15, use_background=True,
                              device=cur_device)
    else:
        mse_criterion = nn.MSELoss().cuda()
    if args.scheduler == 'plt':
        scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, mode='min',
                                                   factor=0.9, patience=10,
                                                   verbose=True)
    elif args.scheduler == 'cos':
        scheduler = lr_scheduler.CosineAnnealingLR(optimizer, T_max=50, eta_min=0)
    elif args.scheduler == 'step':
        scheduler = lr_scheduler.StepLR(optimizer, step_size=100, gamma=0.8)
    elif args.scheduler == 'exp':
        scheduler = lr_scheduler.ExponentialLR(optimizer, gamma=0.99)
    elif args.scheduler == 'cyclic' and args.optimizer == 'SGD':
        scheduler = lr_scheduler.CyclicLR(optimizer, base_lr=args.lr * 0.01,
                                          max_lr=args.lr, step_size_up=25)
    elif args.scheduler == 'None':
        scheduler = None
    else:
        print('scheduler name error!')
    if args.val_patch:
        best_mae, best_rmse = val_patch(net, test_loader, value_factor)
    elif args.loss == 'bayes':
        best_mae, best_rmse = val_bayes(net, test_loader, value_factor)
    else:
        best_mae, best_rmse = val(net, test_loader, value_factor)
    if args.scheduler == 'plt':
        scheduler.step(best_mae)
    ssim_loss = pytorch_ssim.SSIM(window_size=11)
    for epoch in range(args.epochs):
        if args.crop_mode == 'curriculum':
            # every 20% of training, rebuild the dataset with a larger crop scale
            if (epoch + 1) % (args.epochs // 5) == 0:
                print('change dataset')
                single_dataset = RawDataset(
                    train_img_paths, transform, args.crop_mode,
                    downsample_ratio, args.crop_scale,
                    (epoch + 1.0 + args.epochs // 5) / args.epochs)
                train_loader = torch.utils.data.DataLoader(single_dataset,
                                                           shuffle=True,
                                                           batch_size=1,
                                                           num_workers=8)
        train_loss = 0.0
        if args.loss == 'bayes':
            epoch_mae = AverageMeter()
            epoch_mse = AverageMeter()
        net.train()
        if args.warm_up and epoch < args.warm_up_steps:
            linear_warm_up_lr(optimizer, epoch, args.warm_up_steps, args.lr)
        for it, data in enumerate(train_loader):
            if args.loss == 'bayes':
                inputs, points, targets, st_sizes = data
                img = inputs.to(cur_device)
                st_sizes = st_sizes.to(cur_device)
                gd_count = np.array([len(p) for p in points], dtype=np.float32)
                points = [p.to(cur_device) for p in points]
                targets = [t.to(cur_device) for t in targets]
            else:
                img, target, _, amp_gt = data
                img = img.cuda()
                target = value_factor * target.float().unsqueeze(1).cuda()
                amp_gt = amp_gt.cuda()
            optimizer.zero_grad()
            if args.objective == 'dmp+amp':
                output, amp = net(img)
                output = output * amp
            else:
                output = net(img)
            if args.curriculum == 'W':
                # curriculum weighting: pixels predicted above the
                # epoch-dependent threshold T are down-weighted by T / output
                delta = (output - target) ** 2
                k_w = 2e-3 * args.value_factor * args.downsample ** 2
                b_w = 5e-3 * args.value_factor * args.downsample ** 2
                T = torch.ones_like(target, dtype=torch.float32) * epoch * k_w + b_w
                W = T / torch.max(T, output)
                delta = delta * W
                mse_loss = torch.mean(delta)
            else:
                mse_loss = mse_criterion(output, target)
            if args.loss == 'mse+lc':
                loss = mse_loss + 1e2 * cal_lc_loss(output, target) * args.downsample
            elif args.loss == 'ssim':
                loss = 1 - ssim_loss(output, target)
            elif args.loss == 'mse+ssim':
                loss = 100 * mse_loss + 1e-2 * (1 - ssim_loss(output, target))
            elif args.loss == 'mse+la':
                loss = mse_loss + cal_spatial_abstraction_loss(output, target)
            elif args.loss == 'la':
                loss = cal_spatial_abstraction_loss(output, target)
            elif args.loss == 'ms-ssim':
                # TODO: multi-scale SSIM loss not implemented; the original
                # `pass` would leave `loss` undefined and crash at backward()
                raise NotImplementedError
            elif args.loss == 'adversial':
                # TODO: adversarial loss not implemented
                raise NotImplementedError
            elif args.loss == 'bayes':
                prob_list = post_prob(points, st_sizes)
                loss = bayes_criterion(prob_list, targets, output)
            else:
                loss = mse_loss
            # add the cross-entropy loss for the attention map
            if args.objective == 'dmp+amp':
                cross_entropy = (amp_gt * torch.log(amp) +
                                 (1 - amp_gt) * torch.log(1 - amp)) * -1
                cross_entropy_loss = torch.mean(cross_entropy)
                loss = loss + cross_entropy_loss * args.amp_k
            loss.backward()
            optimizer.step()
            data_loss = loss.item()
            train_loss += data_loss
            if args.loss == 'bayes':
                N = inputs.size(0)
                pre_count = torch.sum(output.view(N, -1), dim=1).detach().cpu().numpy()
                res = pre_count - gd_count
                epoch_mse.update(np.mean(res * res), N)
                epoch_mae.update(np.mean(abs(res)), N)
            if args.loss != 'bayes' and it % freq == 0:
                print('[ep:{}], [it:{}], [loss:{:.8f}], [output:{:.2f}, target:{:.2f}]'
                      .format(epoch + 1, it, data_loss,
                              output[0].sum().item(), target[0].sum().item()))
        if args.val_patch:
            mae, rmse = val_patch(net, test_loader, value_factor)
        elif args.loss == 'bayes':
            mae, rmse = val_bayes(net, test_loader, value_factor)
        else:
            mae, rmse = val(net, test_loader, value_factor)
        if not (args.warm_up and epoch < args.warm_up_steps):
            if args.scheduler == 'plt':
                # ReduceLROnPlateau should track the current validation
                # metric; the original stepped on best_mae, which only
                # changes when a new best is found
                scheduler.step(mae)
            elif args.scheduler != 'None':
                scheduler.step()
        if mae + 0.1 * rmse < best_mae + 0.1 * best_rmse:
            best_mae, best_rmse = mae, rmse
            torch.save(net.state_dict(), save_path)
        if args.loss == 'bayes':
            logger.info(
                '{} Epoch {}/{} Loss:{:.8f},MAE:{:.2f},RMSE:{:.2f} lr:{:.8f}, '
                '[CUR]:{mae:.1f}, {rmse:.1f}, [Best]:{b_mae:.1f}, {b_rmse:.1f}'
                .format(model_name, epoch + 1, args.epochs,
                        train_loss / len(train_loader), epoch_mae.get_avg(),
                        np.sqrt(epoch_mse.get_avg()),
                        optimizer.param_groups[0]['lr'],
                        mae=mae, rmse=rmse, b_mae=best_mae, b_rmse=best_rmse))
        else:
            logger.info(
                '{} Epoch {}/{} Loss:{:.8f}, lr:{:.8f}, [CUR]:{mae:.1f}, '
                '{rmse:.1f}, [Best]:{b_mae:.1f}, {b_rmse:.1f}'
                .format(model_name, epoch + 1, args.epochs,
                        train_loss / len(train_loader),
                        optimizer.param_groups[0]['lr'],
                        mae=mae, rmse=rmse, b_mae=best_mae, b_rmse=best_rmse))
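# Standalone sketch of the curriculum weighting used above when
# args.curriculum == 'W': the squared error at each pixel is scaled by
# W = T / max(T, output), so pixels predicted above the epoch-dependent
# threshold T contribute less early in training. The constants mirror the
# training loop; value_factor=1 and downsample=8 are illustrative defaults.
import torch

def curriculum_weighted_mse(output, target, epoch, value_factor=1, downsample=8):
    delta = (output - target) ** 2
    k_w = 2e-3 * value_factor * downsample ** 2
    b_w = 5e-3 * value_factor * downsample ** 2
    T = torch.ones_like(target) * epoch * k_w + b_w  # threshold grows with epoch
    W = T / torch.max(T, output)                     # <= 1 where output exceeds T
    return torch.mean(delta * W)

out = torch.rand(1, 1, 64, 64)
gt = torch.rand(1, 1, 64, 64)
print(curriculum_weighted_mse(out, gt, epoch=10))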
print('Preparing data done.')

# net
print('==> Building model..')
net = network.UNet(channels=args.channels)
net = net.to(device)
writer = SummaryWriter('runs/eventcamera_experiment_' + str(args.channels) +
                       ('_fixed' if args.fixed else ''))
print('Building model done.')
test_output_image = np.zeros((len(test_label), 180, 240), dtype='float')
if not os.path.exists('result'):
    os.mkdir('result')
criterion = pytorch_ssim.SSIM(window_size=11)
optimizer = torch.optim.Adam(net.parameters(), lr=args.lr, weight_decay=0.001)
if args.resume:
    # Load checkpoint.
    print('==> Resuming from checkpoint..')
    assert os.path.isdir('checkpoint'), 'Error: no checkpoint directory found!'
    checkpoint = torch.load('checkpoint/ckpt' + str(args.channels) +
                            ('_fixed' if args.fixed else '') + '.pth')
    net.load_state_dict(checkpoint['net_params'])
    optimizer.load_state_dict(checkpoint['optimizer'])
    best_psnr = checkpoint['psnr']
    best_ssim = checkpoint['ssim']
    start_epoch = checkpoint['epoch']
scheduler = torch.optim.lr_scheduler.MultiStepLR(
    optimizer, milestones=[20, 35],
    gamma=0.1)  # NOTE: the source snippet is truncated here; gamma=0.1 is an assumed value
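# Hedged counterpart to the resume block above: a save helper writing the same
# checkpoint keys the loader reads ('net_params', 'optimizer', 'psnr', 'ssim',
# 'epoch'). The helper name, arguments, and call site are assumptions; only
# the dict keys come from the source.
import torch

def save_checkpoint(net, optimizer, psnr, ssim_val, epoch, path):
    torch.save({
        'net_params': net.state_dict(),
        'optimizer': optimizer.state_dict(),
        'psnr': psnr,
        'ssim': ssim_val,
        'epoch': epoch,
    }, path)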
def train(args):
    writer = SummaryWriter(comment=args.writer)
    os.makedirs(args.checkpoint_save_path, exist_ok=True)
    argsDict = args.__dict__
    for k, v in argsDict.items():
        writer.add_text('hyperparameter', '{} : {}'.format(str(k), str(v)))
    print_freq = args.print_freq
    test_freq = 1
    global device
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    print(device)
    input_shape = (3, args.img_height, args.img_width)
    net = dispnetcorr(args.maxdisp)
    G_AB = GeneratorResNet(input_shape, 2)
    G_BA = GeneratorResNet(input_shape, 2)
    D_A = Discriminator(3)
    D_B = Discriminator(3)
    if args.load_checkpoints:
        if args.load_from_mgpus_model:
            if args.load_dispnet_path:
                net = load_multi_gpu_checkpoint(net, args.load_dispnet_path, 'model')
            else:
                net.apply(weights_init_normal)
            G_AB = load_multi_gpu_checkpoint(G_AB, args.load_gan_path, 'G_AB')
            G_BA = load_multi_gpu_checkpoint(G_BA, args.load_gan_path, 'G_BA')
            D_A = load_multi_gpu_checkpoint(D_A, args.load_gan_path, 'D_A')
            D_B = load_multi_gpu_checkpoint(D_B, args.load_gan_path, 'D_B')
        else:
            if args.load_dispnet_path:
                net = load_checkpoint(net, args.load_checkpoint_path, device)
            else:
                net.apply(weights_init_normal)
            G_AB = load_checkpoint(G_AB, args.load_gan_path, 'G_AB')
            G_BA = load_checkpoint(G_BA, args.load_gan_path, 'G_BA')
            D_A = load_checkpoint(D_A, args.load_gan_path, 'D_A')
            D_B = load_checkpoint(D_B, args.load_gan_path, 'D_B')
    else:
        net.apply(weights_init_normal)
        G_AB.apply(weights_init_normal)
        G_BA.apply(weights_init_normal)
        D_A.apply(weights_init_normal)
        D_B.apply(weights_init_normal)
    # optimizer = optim.SGD(params, momentum=0.9)
    optimizer = optim.Adam(net.parameters(), lr=args.lr_rate, betas=(0.9, 0.999))
    optimizer_G = optim.Adam(itertools.chain(G_AB.parameters(), G_BA.parameters()),
                             lr=args.lr_gan, betas=(0.5, 0.999))
    optimizer_D_A = optim.Adam(D_A.parameters(), lr=args.lr_gan, betas=(0.5, 0.999))
    optimizer_D_B = optim.Adam(D_B.parameters(), lr=args.lr_gan, betas=(0.5, 0.999))
    if args.use_multi_gpu:
        print("Let's use", torch.cuda.device_count(), "GPUs!")
        net = nn.DataParallel(net, device_ids=list(range(args.use_multi_gpu)))
        G_AB = nn.DataParallel(G_AB, device_ids=list(range(args.use_multi_gpu)))
        G_BA = nn.DataParallel(G_BA, device_ids=list(range(args.use_multi_gpu)))
        D_A = nn.DataParallel(D_A, device_ids=list(range(args.use_multi_gpu)))
        D_B = nn.DataParallel(D_B, device_ids=list(range(args.use_multi_gpu)))
    net.to(device)
    G_AB.to(device)
    G_BA.to(device)
    D_A.to(device)
    D_B.to(device)
    criterion_GAN = torch.nn.MSELoss().cuda()
    criterion_identity = torch.nn.L1Loss().cuda()
    ssim_loss = pytorch_ssim.SSIM()
    # data loader
    if args.source_dataset == 'driving':
        dataset = ImageDataset(height=args.img_height, width=args.img_width)
    elif args.source_dataset == 'synthia':
        dataset = ImageDataset2(height=args.img_height, width=args.img_width)
    else:
        # raising a bare string is a TypeError in Python 3
        raise ValueError('Unsupported dataset')
    trainloader = torch.utils.data.DataLoader(dataset, batch_size=args.batch_size,
                                              shuffle=True, num_workers=4)
    valdataset = ValJointImageDataset()
    valloader = torch.utils.data.DataLoader(valdataset,
                                            batch_size=args.test_batch_size,
                                            shuffle=False, num_workers=1)
    train_loss_meter = AverageMeter()
    val_loss_meter = AverageMeter()
    ## debug only
    #with torch.no_grad():
    #    l1_test_loss, out_val = val(valloader, net, G_AB, None, writer, epoch=0, board_save=True)
    #    val_loss_meter.update(l1_test_loss)
    #    print('Val epoch[{}/{}] loss: {}'.format(0, args.total_epochs, l1_test_loss))
    print('begin training...')
    best_val_d1 = 1.
    best_val_epe = 100.
    for epoch in range(args.total_epochs):
        #net.train()
        #G_AB.train()
        n_iter = 0
        running_loss = 0.
        t = time.time()
        # custom lr decay, or warm-up
        lr = args.lr_rate
        if epoch >= int(args.lrepochs.split(':')[0]):
            lr = lr / int(args.lrepochs.split(':')[1])
        for param_group in optimizer.param_groups:
            param_group['lr'] = lr
        for i, batch in enumerate(trainloader):
            n_iter += 1
            leftA = batch['leftA'].to(device)
            rightA = batch['rightA'].to(device)
            leftB = batch['leftB'].to(device)
            rightB = batch['rightB'].to(device)
            dispA = batch['dispA'].unsqueeze(1).float().to(device)
            dispB = batch['dispB'].to(device)
            out_shape = (leftA.size(0), 1, args.img_height // 16, args.img_width // 16)
            valid = torch.cuda.FloatTensor(np.ones(out_shape))
            fake = torch.cuda.FloatTensor(np.zeros(out_shape))
            if i % args.train_ratio_gan == 0:
                # train generators
                G_AB.train()
                G_BA.train()
                net.eval()
                optimizer_G.zero_grad()
                # Identity loss
                loss_id_A = (criterion_identity(G_BA(leftA), leftA) +
                             criterion_identity(G_BA(rightA), rightA)) / 2
                loss_id_B = (criterion_identity(G_AB(leftB), leftB) +
                             criterion_identity(G_AB(rightB), rightB)) / 2
                loss_id = (loss_id_A + loss_id_B) / 2
                if args.lambda_warp_inv:
                    fake_leftB, fake_leftB_feats = G_AB(leftA, extract_feat=True)
                    fake_leftA, fake_leftA_feats = G_BA(leftB, extract_feat=True)
                else:
                    fake_leftB = G_AB(leftA)
                    fake_leftA = G_BA(leftB)
                if args.lambda_warp:
                    fake_rightB, fake_rightB_feats = G_AB(rightA, extract_feat=True)
                    fake_rightA, fake_rightA_feats = G_BA(rightB, extract_feat=True)
                else:
                    fake_rightB = G_AB(rightA)
                    fake_rightA = G_BA(rightB)
                loss_GAN_AB = criterion_GAN(D_B(fake_leftB), valid)
                loss_GAN_BA = criterion_GAN(D_A(fake_leftA), valid)
                loss_GAN = (loss_GAN_AB + loss_GAN_BA) / 2
                if args.lambda_warp_inv:
                    rec_leftA, rec_leftA_feats = G_BA(fake_leftB, extract_feat=True)
                else:
                    rec_leftA = G_BA(fake_leftB)
                if args.lambda_warp:
                    rec_rightA, rec_rightA_feats = G_BA(fake_rightB, extract_feat=True)
                else:
                    rec_rightA = G_BA(fake_rightB)
                rec_leftB = G_AB(fake_leftA)
                rec_rightB = G_AB(fake_rightA)
                loss_cycle_A = (criterion_identity(rec_leftA, leftA) +
                                criterion_identity(rec_rightA, rightA)) / 2
                loss_ssim_A = 1. - (ssim_loss(rec_leftA, leftA) +
                                    ssim_loss(rec_rightA, rightA)) / 2
                loss_cycle_B = (criterion_identity(rec_leftB, leftB) +
                                criterion_identity(rec_rightB, rightB)) / 2
                loss_ssim_B = 1. - (ssim_loss(rec_leftB, leftB) +
                                    ssim_loss(rec_rightB, rightB)) / 2
                loss_cycle = (loss_cycle_A + loss_cycle_B) / 2
                loss_ssim = (loss_ssim_A + loss_ssim_B) / 2
                # mode seeking loss
                if args.lambda_ms:
                    loss_ms = G_AB(leftA, zx=True, zx_relax=True).mean()
                else:
                    loss_ms = 0
                # warping loss
                if args.lambda_warp_inv:
                    fake_leftB_warp, loss_warp_inv_feat1 = G_AB(
                        rightA, -dispA, True, [x.detach() for x in fake_leftB_feats])
                    rec_leftA_warp, loss_warp_inv_feat2 = G_BA(
                        fake_rightB, -dispA, True, [x.detach() for x in rec_leftA_feats])
                    loss_warp_inv1 = warp_loss(
                        [(G_BA(fake_leftB_warp[0]), fake_leftB_warp[1])],
                        [leftA], weights=[1])
                    loss_warp_inv2 = warp_loss([rec_leftA_warp], [leftA], weights=[1])
                    loss_warp_inv = (loss_warp_inv1 + loss_warp_inv2 +
                                     loss_warp_inv_feat1.mean() +
                                     loss_warp_inv_feat2.mean())
                else:
                    loss_warp_inv = 0
                if args.lambda_warp:
                    fake_rightB_warp, loss_warp_feat1 = G_AB(
                        leftA, dispA, True, [x.detach() for x in fake_rightB_feats])
                    rec_rightA_warp, loss_warp_feat2 = G_BA(
                        fake_leftB, dispA, True, [x.detach() for x in rec_rightA_feats])
                    loss_warp1 = warp_loss(
                        [(G_BA(fake_rightB_warp[0]), fake_rightB_warp[1])],
                        [rightA], weights=[1])
                    loss_warp2 = warp_loss([rec_rightA_warp], [rightA], weights=[1])
                    loss_warp = (loss_warp1 + loss_warp2 +
                                 loss_warp_feat1.mean() + loss_warp_feat2.mean())
                else:
                    loss_warp = 0
                # corr loss
                if args.lambda_corr:
                    corrB = net(leftB, rightB, extract_feat=True)
                    corrB1 = net(leftB, rec_rightB, extract_feat=True)
                    corrB2 = net(rec_leftB, rightB, extract_feat=True)
                    corrB3 = net(rec_leftB, rec_rightB, extract_feat=True)
                    loss_corr = (criterion_identity(corrB1, corrB) +
                                 criterion_identity(corrB2, corrB) +
                                 criterion_identity(corrB3, corrB)) / 3
                else:
                    loss_corr = 0.
                # linearly anneal the mode-seeking weight over training
                lambda_ms = args.lambda_ms * (args.total_epochs - epoch) / args.total_epochs
                loss_G = (loss_GAN
                          + args.lambda_cycle * (args.alpha_ssim * loss_ssim +
                                                 (1 - args.alpha_ssim) * loss_cycle)
                          + args.lambda_id * loss_id
                          + args.lambda_warp * loss_warp
                          + args.lambda_warp_inv * loss_warp_inv
                          + args.lambda_corr * loss_corr
                          + lambda_ms * loss_ms)
                loss_G.backward()
                optimizer_G.step()
                # train discriminators. A: real, B: synthetic
                optimizer_D_A.zero_grad()
                loss_real_A = criterion_GAN(D_A(leftA), valid)
                fake_leftA.detach_()
                loss_fake_A = criterion_GAN(D_A(fake_leftA), fake)
                loss_D_A = (loss_real_A + loss_fake_A) / 2
                loss_D_A.backward()
                optimizer_D_A.step()
                optimizer_D_B.zero_grad()
                #loss_real_B = criterion_GAN(D_B(torch.cat([syn_left_img, syn_right_img], 0)), valid)
                #fake_syn_left.detach_()
                #fake_syn_right.detach_()
                #loss_fake_B = criterion_GAN(D_B(torch.cat([fake_syn_left, fake_syn_right], 0)), fake)
                loss_real_B = criterion_GAN(D_B(leftB), valid)
                fake_leftB.detach_()
                loss_fake_B = criterion_GAN(D_B(fake_leftB), fake)
                loss_D_B = (loss_real_B + loss_fake_B) / 2
                loss_D_B.backward()
                optimizer_D_B.step()
            # train disp net
            net.train()
            G_AB.eval()
            G_BA.eval()
            optimizer.zero_grad()
            disp_ests = net(G_AB(leftA), G_AB(rightA))
            mask = (dispA < args.maxdisp) & (dispA > 0)
            loss0 = model_loss0(disp_ests, dispA, mask)
            if args.lambda_disp_warp_inv:
                disp_warp = [-disp_ests[i] for i in range(3)]
                loss_disp_warp_inv = G_BA(rightB, disp_warp, True,
                                          [x.detach() for x in fake_leftA_feats])
                loss_disp_warp_inv = loss_disp_warp_inv.mean()
            else:
                loss_disp_warp_inv = 0
            if args.lambda_disp_warp:
                disp_warp = [disp_ests[i] for i in range(3)]
                loss_disp_warp = G_BA(leftB, disp_warp, True,
                                      [x.detach() for x in fake_rightA_feats])
                loss_disp_warp = loss_disp_warp.mean()
            else:
                loss_disp_warp = 0
            loss = (loss0 + args.lambda_disp_warp * loss_disp_warp +
                    args.lambda_disp_warp_inv * loss_disp_warp_inv)
            loss.backward()
            optimizer.step()
            # accumulate for the averaged meter below (missing in the original,
            # which left running_loss at zero)
            running_loss += loss.item()
            if i % print_freq == print_freq - 1:
                print('epoch[{}/{}] step[{}/{}] loss: {}'.format(
                    epoch, args.total_epochs, i, len(trainloader), loss.item()))
                train_loss_meter.update(running_loss / print_freq)
                running_loss = 0.  # reset the window after logging
                #writer.add_scalar('loss/trainloss avg_meter', train_loss_meter.val, train_loss_meter.count * print_freq)
                writer.add_scalar('loss/loss_disp', loss0,
                                  train_loss_meter.count * print_freq)
                writer.add_scalar('loss/loss_disp_warp', loss_disp_warp,
                                  train_loss_meter.count * print_freq)
                writer.add_scalar('loss/loss_disp_warp_inv', loss_disp_warp_inv,
                                  train_loss_meter.count * print_freq)
                if i % args.train_ratio_gan == 0:
                    writer.add_scalar('loss/loss_G', loss_G, train_loss_meter.count * print_freq)
                    writer.add_scalar('loss/loss_gan', loss_GAN, train_loss_meter.count * print_freq)
                    writer.add_scalar('loss/loss_cycle', loss_cycle, train_loss_meter.count * print_freq)
                    writer.add_scalar('loss/loss_id', loss_id, train_loss_meter.count * print_freq)
                    writer.add_scalar('loss/loss_warp', loss_warp, train_loss_meter.count * print_freq)
                    writer.add_scalar('loss/loss_warp_inv', loss_warp_inv, train_loss_meter.count * print_freq)
                    writer.add_scalar('loss/loss_corr', loss_corr, train_loss_meter.count * print_freq)
                    writer.add_scalar('loss/loss_ms', loss_ms, train_loss_meter.count * print_freq)
                    writer.add_scalar('loss/loss_D_A', loss_D_A, train_loss_meter.count * print_freq)
                    writer.add_scalar('loss/loss_D_B', loss_D_B, train_loss_meter.count * print_freq)
                    imgA_visual = vutils.make_grid(leftA[:4, :, :, :], nrow=1, normalize=True, scale_each=True)
                    fakeB_visual = vutils.make_grid(fake_leftB[:4, :, :, :], nrow=1, normalize=True, scale_each=True)
                    recA_visual = vutils.make_grid(rec_leftA[:4, :, :, :], nrow=1, normalize=True, scale_each=True)
                    rightA_visual = vutils.make_grid(rightA[:4, :, :, :], nrow=1, normalize=True, scale_each=True)
                    fakeB_R_visual = vutils.make_grid(fake_rightB[:4, :, :, :], nrow=1, normalize=True, scale_each=True)
                    recA_R_visual = vutils.make_grid(rec_rightA[:4, :, :, :], nrow=1, normalize=True, scale_each=True)
                    imgB_visual = vutils.make_grid(leftB[:4, :, :, :], nrow=1, normalize=True, scale_each=True)
                    fakeA_visual = vutils.make_grid(fake_leftA[:4, :, :, :], nrow=1, normalize=True, scale_each=True)
                    recB_visual = vutils.make_grid(rec_leftB[:4, :, :, :], nrow=1, normalize=True, scale_each=True)
                    rightB_visual = vutils.make_grid(rightB[:4, :, :, :], nrow=1, normalize=True, scale_each=True)
                    fakeA_R_visual = vutils.make_grid(fake_rightA[:4, :, :, :], nrow=1, normalize=True, scale_each=True)
                    recB_R_visual = vutils.make_grid(rec_rightB[:4, :, :, :], nrow=1, normalize=True, scale_each=True)
                    writer.add_image('ABA_L/imgA', imgA_visual, i)
                    writer.add_image('ABA_L/fakeB', fakeB_visual, i)
                    writer.add_image('ABA_L/recA', recA_visual, i)
                    writer.add_image('ABA_R/imgA', rightA_visual, i)
                    writer.add_image('ABA_R/fakeB', fakeB_R_visual, i)
                    writer.add_image('ABA_R/recA', recA_R_visual, i)
                    writer.add_image('BAB_L/imgB', imgB_visual, i)
                    writer.add_image('BAB_L/fakeA', fakeA_visual, i)
                    writer.add_image('BAB_L/recB', recB_visual, i)
                    writer.add_image('BAB_R/imgB', rightB_visual, i)
                    writer.add_image('BAB_R/fakeA', fakeA_R_visual, i)
                    writer.add_image('BAB_R/recB', recB_R_visual, i)
                    if args.lambda_warp_inv:
                        recA_warp_visual = vutils.make_grid(
                            rec_leftA_warp[0][:4, :, :, :], nrow=1,
                            normalize=True, scale_each=True)
                        fakeB_warp_visual = vutils.make_grid(
                            fake_leftB_warp[0][:4, :, :, :], nrow=1,
                            normalize=True, scale_each=True)
                        writer.add_image('warp/recA_L_warp', recA_warp_visual, i)
                        writer.add_image('warp/fakeB_L_warp', fakeB_warp_visual, i)
                    if args.lambda_warp:
                        # build the grids before logging them (the original
                        # logged these two images before defining them)
                        recA_warp_R_visual = vutils.make_grid(
                            rec_rightA_warp[0][:4, :, :, :], nrow=1,
                            normalize=True, scale_each=True)
                        fakeB_warp_R_visual = vutils.make_grid(
                            fake_rightB_warp[0][:4, :, :, :], nrow=1,
                            normalize=True, scale_each=True)
                        writer.add_image('warp/recA_R_warp', recA_warp_R_visual, i)
                        writer.add_image('warp/fakeB_R_warp', fakeB_warp_R_visual, i)
        with torch.no_grad():
            EPE, D1 = val(valloader, net, writer, epoch=epoch, board_save=True)
        t1 = time.time()
        print('epoch:{}, D1:{:.6f}, EPE:{:.6f}, cost time:{} '.format(
            epoch, D1, EPE, t1 - t))
        if (epoch % args.save_interval == 0) or D1 < best_val_d1 or EPE < best_val_epe:
            # track the best metrics without letting a periodic save overwrite
            # them with worse values (the original assigned unconditionally)
            best_val_d1 = min(best_val_d1, D1)
            best_val_epe = min(best_val_epe, EPE)
            torch.save(
                {
                    'epoch': epoch,
                    'G_AB': G_AB.state_dict(),
                    'G_BA': G_BA.state_dict(),
                    'D_A': D_A.state_dict(),
                    'D_B': D_B.state_dict(),
                    'model': net.state_dict(),
                    'optimizer_DA_state_dict': optimizer_D_A.state_dict(),
                    'optimizer_DB_state_dict': optimizer_D_B.state_dict(),
                    'optimizer_G_state_dict': optimizer_G.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                },
                args.checkpoint_save_path + '/ep' + str(epoch) +
                '_D1_{:.4f}_EPE{:.4f}'.format(D1, EPE) + '.pth.rar')
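# Hedged restore sketch for the checkpoint dict saved above: it reloads each
# network and optimizer from the keys written by torch.save. It assumes the
# models and optimizers from train() are in scope; the file path is a
# placeholder and map_location='cpu' is a conservative choice, not from the
# source.
import torch

ckpt = torch.load('checkpoints/ep0_D1_0.0000_EPE0.0000.pth.rar',
                  map_location='cpu')
net.load_state_dict(ckpt['model'])
G_AB.load_state_dict(ckpt['G_AB'])
G_BA.load_state_dict(ckpt['G_BA'])
D_A.load_state_dict(ckpt['D_A'])
D_B.load_state_dict(ckpt['D_B'])
optimizer.load_state_dict(ckpt['optimizer_state_dict'])
optimizer_G.load_state_dict(ckpt['optimizer_G_state_dict'])
start_epoch = ckpt['epoch'] + 1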