def train(args, train_loader, pose_exp_net, optimizer, epoch_size, logger, tb_writer):
    """Train the pose network for one epoch with supervised L1 pose loss.

    Args:
        args: namespace with at least ``batch_size`` and ``print_freq``.
        train_loader: yields (tgt_img, tgt_lf, ref_imgs, ref_lfs, intrinsics,
            intrinsics_inv, pose_gt) tuples.
        pose_exp_net: network returning (explainability_mask, pose).
        optimizer: optimizer over the network parameters.
        epoch_size: number of batches to run before stopping.
        logger: progress-bar/terminal logger.
        tb_writer: tensorboard SummaryWriter.

    Returns:
        Average training loss over the epoch.
    """
    global n_iter, device
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter(precision=4)

    # switch to train mode
    pose_exp_net.train()

    end = time.time()
    logger.train_bar.update(0)

    for i, (tgt_img, tgt_lf, ref_imgs, ref_lfs, intrinsics, intrinsics_inv, pose_gt) in enumerate(train_loader):
        # measure data loading time
        data_time.update(time.time() - end)
        tgt_lf = tgt_lf.to(device)
        ref_lfs = [lf.to(device) for lf in ref_lfs]
        pose_gt = pose_gt.to(device)

        # compute output; only the pose output is supervised here
        explainability_mask, pose = pose_exp_net(tgt_lf, ref_lfs)

        # L1 loss against ground-truth pose
        loss = (pose - pose_gt).abs().mean()
        losses.update(loss.item(), args.batch_size)

        # compute gradient and do optimizer step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()
        logger.train_bar.update(i + 1)

        # FIX: log the scalar value, not the tensor — passing the tensor keeps
        # the autograd graph referenced by the writer.
        tb_writer.add_scalar('loss/train', loss.item(), n_iter)
        if i % args.print_freq == 0:
            logger.train_writer.write('Train: Time {} Data {} Loss {}'.format(batch_time, data_time, losses))
        if i >= epoch_size - 1:
            break
        n_iter += 1
    return losses.avg[0]
@torch.no_grad()  # FIX: validation previously built autograd graphs for every batch
def validate_without_gt(args, val_loader, disp_net, pose_exp_net, epoch, logger, output_writers=[]):
    """Validate disp/pose networks without ground truth, via photometric loss.

    Logs sample outputs to the tensorboard writers (first ``len(output_writers)``
    batches) and pose/disparity histograms at the end of the epoch.

    Returns:
        (losses.avg, loss names) — averaged [total, photometric, explainability].
    """
    global device
    batch_time = AverageMeter()
    losses = AverageMeter(i=3, precision=4)
    log_outputs = len(output_writers) > 0  # NOTE: list default is read-only here
    w1, w2, w3 = args.photo_loss_weight, args.mask_loss_weight, args.smooth_loss_weight
    # last (possibly partial) batch is excluded from the histogram buffers
    poses = np.zeros(((len(val_loader) - 1) * args.batch_size * (args.sequence_length - 1), 6))
    disp_values = np.zeros(((len(val_loader) - 1) * args.batch_size * 3))

    # switch to evaluate mode
    disp_net.eval()
    pose_exp_net.eval()

    end = time.time()
    logger.valid_bar.update(0)
    for i, (tgt_img, ref_imgs, intrinsics, intrinsics_inv) in enumerate(val_loader):
        tgt_img = tgt_img.to(device)
        ref_imgs = [img.to(device) for img in ref_imgs]
        intrinsics = intrinsics.to(device)
        intrinsics_inv = intrinsics_inv.to(device)

        # compute output
        disp = disp_net(tgt_img)
        depth = 1 / disp
        explainability_mask, pose = pose_exp_net(tgt_img, ref_imgs)

        loss_1, warped, diff = photometric_reconstruction_loss(
            tgt_img, ref_imgs, intrinsics, depth, explainability_mask, pose,
            args.rotation_mode, args.padding_mode)
        loss_1 = loss_1.item()
        if w2 > 0:
            loss_2 = explainability_loss(explainability_mask).item()
        else:
            loss_2 = 0
        loss_3 = smooth_loss(depth).item()

        if log_outputs and i < len(output_writers):
            # log first output of first batches
            if epoch == 0:
                for j, ref in enumerate(ref_imgs):
                    # target at step 0, reference at step 1, under the same tag
                    output_writers[i].add_image('val Input {}'.format(j), tensor2array(tgt_img[0]), 0)
                    output_writers[i].add_image('val Input {}'.format(j), tensor2array(ref[0]), 1)

            log_output_tensorboard(output_writers[i], 'val', '', epoch,
                                   1. / disp, disp, warped, diff, explainability_mask)

        if log_outputs and i < len(val_loader) - 1:
            step = args.batch_size * (args.sequence_length - 1)
            poses[i * step:(i + 1) * step] = pose.cpu().view(-1, 6).numpy()
            step = args.batch_size * 3
            disp_unraveled = disp.cpu().view(args.batch_size, -1)
            # per-sample min / median / max of the disparity map
            disp_values[i * step:(i + 1) * step] = torch.cat([
                disp_unraveled.min(-1)[0],
                disp_unraveled.median(-1)[0],
                disp_unraveled.max(-1)[0]
            ]).numpy()

        loss = w1 * loss_1 + w2 * loss_2 + w3 * loss_3
        losses.update([loss, loss_1, loss_2])

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()
        logger.valid_bar.update(i + 1)
        if i % args.print_freq == 0:
            logger.valid_writer.write('valid: Time {} Loss {}'.format(batch_time, losses))

    if log_outputs:
        prefix = 'valid poses'
        coeffs_names = ['tx', 'ty', 'tz']
        if args.rotation_mode == 'euler':
            coeffs_names.extend(['rx', 'ry', 'rz'])
        elif args.rotation_mode == 'quat':
            coeffs_names.extend(['qx', 'qy', 'qz'])
        for i in range(poses.shape[1]):
            output_writers[0].add_histogram('{} {}'.format(prefix, coeffs_names[i]), poses[:, i], epoch)
        output_writers[0].add_histogram('disp_values', disp_values, epoch)
    logger.valid_bar.update(len(val_loader))
    return losses.avg, ['Total loss', 'Photo loss', 'Exp loss']
def train(train_loader, alice_net, bob_net, mod_net, optimizer, epoch_size,
          logger=None, train_writer=None, mode='compete'):
    """Train Alice/Bob classifiers plus a moderator network for one epoch.

    In ``'compete'`` mode the moderator's (detached) sigmoid output weighs the
    two per-sample cross-entropy losses; in ``'collaborate'`` mode the detached
    losses train the moderator instead.

    Returns:
        Average training loss over the epoch.
    """
    global args, n_iter
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter(precision=4)

    # switch to train mode
    alice_net.train()
    bob_net.train()
    mod_net.train()

    end = time.time()

    for i, (img, target) in enumerate(train_loader):
        # measure data loading time
        data_time.update(time.time() - end)
        img_var = Variable(img.cuda())
        target_var = Variable(target.cuda())

        pred_alice = alice_net(img_var)
        pred_bob = bob_net(img_var)
        pred_mod = mod_net(img_var)

        # per-sample losses (no reduction) so the moderator can weigh them
        loss_alice = F.cross_entropy(pred_alice, target_var, reduce=False)
        loss_bob = F.cross_entropy(pred_bob, target_var, reduce=False)

        if mode == 'compete':
            if args.fix_bob:
                if args.DEBUG:
                    print("Training Alice Only")
                loss = loss_alice.mean()
            elif args.fix_alice:
                loss = loss_bob.mean()
            else:
                if args.DEBUG:
                    print("Training Both Alice and Bob")
                # moderator output detached: it is not trained in this mode
                pred_mod_soft = Variable(F.sigmoid(pred_mod).data, requires_grad=False)
                loss = pred_mod_soft * loss_alice + (1 - pred_mod_soft) * loss_bob
                loss = loss.mean()
        elif mode == 'collaborate':
            # classifier losses detached: only the moderator is trained here
            loss_alice2 = Variable(loss_alice.data, requires_grad=False)
            loss_bob2 = Variable(loss_bob.data, requires_grad=False)
            loss1 = F.sigmoid(pred_mod) * loss_alice2 + (1 - F.sigmoid(pred_mod)) * loss_bob2
            loss2 = collaboration_loss(pred_mod, loss_alice2, loss_bob2)
            loss = loss1.mean() + loss2.mean() + args.wr * mod_regularization_loss(pred_mod)
        else:
            # FIX: previously an unknown mode left `loss` unbound and crashed
            # later with a confusing NameError.
            raise ValueError("unknown mode: {!r}".format(mode))

        # FIX: train_writer defaults to None — guard before logging.
        if train_writer is not None and i > 0 and n_iter % args.print_freq == 0:
            train_writer.add_scalar('loss_alice', loss_alice.mean().item(), n_iter)
            train_writer.add_scalar('loss_bob', loss_bob.mean().item(), n_iter)
            train_writer.add_scalar('mod_mean', F.sigmoid(pred_mod).mean().item(), n_iter)
            train_writer.add_scalar('mod_var', F.sigmoid(pred_mod).var().item(), n_iter)
            train_writer.add_scalar('loss_regularization', mod_regularization_loss(pred_mod).item(), n_iter)
            if mode == 'compete':
                train_writer.add_scalar('competetion_loss', loss.item(), n_iter)
            elif mode == 'collaborate':
                train_writer.add_scalar('collaboration_loss', loss.item(), n_iter)

        # record loss
        losses.update(loss.item(), args.batch_size)

        # compute gradient and do Adam step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()
        with open(args.save_path / args.log_full, 'a') as csvfile:
            writer = csv.writer(csvfile, delimiter='\t')
            writer.writerow([
                loss.item(),
                loss_alice.mean().item(),
                loss_bob.mean().item()
            ])
        if args.log_terminal:
            logger.train_bar.update(i + 1)
            if i % args.print_freq == 0:
                logger.train_writer.write(
                    'Train: Time {} Data {} Loss {}'.format(batch_time, data_time, losses))
        if i >= epoch_size - 1:
            break
        n_iter += 1
    return losses.avg[0]
def validate_with_gt(args, val_loader, mvdnet, depth_cons, epoch, output_writers=[]):
    """Validate MVDNet (optionally refined by depth_cons) against GT depth/normals.

    Tracks errors both before (``*1`` meters) and after the consistency
    refinement so the two can be compared. Optionally dumps depth maps to
    ``args.output_dir``.

    Returns:
        (errors.avg, error_names) for the refined (or raw) predictions.
    """
    batch_time = AverageMeter()
    error_names = [
        'abs_rel', 'abs_diff', 'sq_rel', 'a1', 'a2', 'a3', 'mean_angle'
    ]
    test_error_names = [
        'abs_rel', 'abs_diff', 'sq_rel', 'rms', 'log_rms', 'a1', 'a2', 'a3',
        'mean_angle'
    ]
    test_error_names1 = [
        'abs_rel', 'abs_diff', 'sq_rel', 'rms', 'log_rms', 'a1', 'a2', 'a3',
        'mean_angle'
    ]
    errors = AverageMeter(i=len(error_names))
    test_errors = AverageMeter(i=len(test_error_names))
    test_errors1 = AverageMeter(i=len(test_error_names1))
    log_outputs = len(output_writers) > 0

    # switch to evaluate mode
    # FIX: mvdnet is run in BOTH branches below, so it must always be in eval
    # mode; previously it stayed in train mode whenever args.train_cons was set.
    mvdnet.eval()
    if args.train_cons:
        depth_cons.eval()

    end = time.time()
    with torch.no_grad():
        for i, (tgt_img, ref_imgs, gt_nmap, ref_poses, intrinsics, intrinsics_inv, tgt_depth) in enumerate(val_loader):
            tgt_img_var = Variable(tgt_img.cuda())
            ref_imgs_var = [Variable(img.cuda()) for img in ref_imgs]
            gt_nmap_var = Variable(gt_nmap.cuda())
            ref_poses_var = [Variable(pose.cuda()) for pose in ref_poses]
            intrinsics_var = Variable(intrinsics.cuda())
            intrinsics_inv_var = Variable(intrinsics_inv.cuda())
            tgt_depth_var = Variable(tgt_depth.cuda())

            pose = torch.cat(ref_poses_var, 1)
            # skip batches with NaN poses (NaN != NaN)
            if (pose != pose).any():
                continue

            outputs = mvdnet(tgt_img_var, ref_imgs_var, pose, intrinsics_var, intrinsics_inv_var)
            output_depth = outputs[0]
            # keep unrefined copies for before/after comparison
            # (FIX: removed a duplicated `output_depth1 = output_depth.clone()`)
            output_depth1 = output_depth.clone()
            nmap = outputs[1]
            nmap1 = nmap.clone()
            if args.train_cons:
                outputs = depth_cons(output_depth, nmap.permute(0, 3, 1, 2))
                nmap = outputs[:, 1:].permute(0, 2, 3, 1)
                output_depth = outputs[:, 0].unsqueeze(1)

            # valid GT: inside depth range and not NaN
            mask = (tgt_depth <= args.nlabel * args.mindepth) & (
                tgt_depth >= args.mindepth) & (tgt_depth == tgt_depth)
            #mask = (tgt_depth <= 10) & (tgt_depth >= args.mindepth) & (tgt_depth == tgt_depth) #for DeMoN testing, to compare against DPSNet you might need to turn on this for fair comparison
            if not mask.any():
                continue

            output_depth1_ = torch.squeeze(output_depth1.data.cpu(), 1)
            output_depth_ = torch.squeeze(output_depth.data.cpu(), 1)

            errors_ = compute_errors_train(tgt_depth, output_depth_, mask)
            test_errors_ = list(
                compute_errors_test(tgt_depth[mask], output_depth_[mask]))
            test_errors1_ = list(
                compute_errors_test(tgt_depth[mask], output_depth1_[mask]))

            # normal-map mask: pixels where any GT normal channel is non-zero
            n_mask = (gt_nmap_var.permute(0, 2, 3, 1)[0, :, :] != 0)
            n_mask = n_mask[:, :, 0] | n_mask[:, :, 1] | n_mask[:, :, 2]

            total_angles_m = compute_angles(
                gt_nmap_var.permute(0, 2, 3, 1)[0], nmap[0])
            total_angles_m1 = compute_angles(
                gt_nmap_var.permute(0, 2, 3, 1)[0], nmap1[0])

            mask_angles = total_angles_m[n_mask]
            mask_angles1 = total_angles_m1[n_mask]
            total_angles_m[~n_mask] = 0
            total_angles_m1[~n_mask] = 0
            # mean angular error appended as the last component of each meter
            errors_.append(torch.mean(mask_angles).item())
            test_errors_.append(torch.mean(mask_angles).item())
            test_errors1_.append(torch.mean(mask_angles1).item())

            errors.update(errors_)
            test_errors.update(test_errors_)
            test_errors1.update(test_errors1_)

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()
            if i % args.print_freq == 0 or i == len(val_loader) - 1:
                if args.train_cons:
                    print(
                        'valid: Time {} Prev Error {:.4f}({:.4f}) Curr Error {:.4f} ({:.4f}) Prev angle Error {:.4f} ({:.4f}) Curr angle Error {:.4f} ({:.4f}) Iter {}/{}'
                        .format(batch_time, test_errors1.val[0], test_errors1.avg[0],
                                test_errors.val[0], test_errors.avg[0],
                                test_errors1.val[-1], test_errors1.avg[-1],
                                test_errors.val[-1], test_errors.avg[-1], i,
                                len(val_loader)))
                else:
                    print(
                        'valid: Time {} Rel Error {:.4f} ({:.4f}) Angle Error {:.4f} ({:.4f}) Iter {}/{}'
                        .format(batch_time, test_errors.val[0], test_errors.avg[0],
                                test_errors.val[-1], test_errors.avg[-1], i,
                                len(val_loader)))
            if args.output_print:
                output_dir = Path(args.output_dir)
                if not os.path.isdir(output_dir):
                    os.mkdir(output_dir)
                plt.imsave(output_dir / '{:04d}_map{}'.format(i, '_dps.png'),
                           output_depth_.numpy()[0], cmap='rainbow')
                np.save(output_dir / '{:04d}{}'.format(i, '_dps.npy'),
                        output_depth_.numpy()[0])
                if args.train_cons:
                    plt.imsave(output_dir / '{:04d}_map{}'.format(i, '_prev.png'),
                               output_depth1_.numpy()[0], cmap='rainbow')
                    np.save(output_dir / '{:04d}{}'.format(i, '_prev.npy'),
                            output_depth1_.numpy()[0])

    if args.output_print:
        np.savetxt(output_dir / args.ttype + 'errors.csv', test_errors.avg,
                   fmt='%1.4f', delimiter=',')
        np.savetxt(output_dir / args.ttype + 'prev_errors.csv', test_errors1.avg,
                   fmt='%1.4f', delimiter=',')
    return errors.avg, error_names
def train(args, train_loader, disp_net, pose_exp_net, optimizer, epoch_size, logger, train_writer):
    """Train disparity and pose/explainability networks jointly for one epoch.

    The total loss is a weighted sum of photometric reconstruction,
    explainability-mask regularization and disparity smoothness
    (weights from args.photo_loss_weight / mask_loss_weight / smooth_loss_weight).

    Returns the average total loss over the epoch.
    """
    global n_iter, device
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter(precision=4)
    w1, w2, w3 = args.photo_loss_weight, args.mask_loss_weight, args.smooth_loss_weight

    # switch to train mode
    disp_net.train()
    pose_exp_net.train()

    end = time.time()
    logger.train_bar.update(0)

    for i, (tgt_img, ref_imgs, intrinsics, intrinsics_inv) in enumerate(train_loader):
        # measure data loading time
        data_time.update(time.time() - end)
        tgt_img = tgt_img.to(device)
        ref_imgs = [img.to(device) for img in ref_imgs]
        intrinsics = intrinsics.to(device)

        # compute output
        # disp_net returns multi-scale disparities; depth is their inverse
        disparities = disp_net(tgt_img)
        depth = [1 / disp for disp in disparities]
        explainability_mask, pose = pose_exp_net(tgt_img, ref_imgs)

        loss_1 = photometric_reconstruction_loss(tgt_img, ref_imgs, intrinsics,
                                                 depth, explainability_mask, pose,
                                                 args.rotation_mode, args.padding_mode)
        if w2 > 0:
            loss_2 = explainability_loss(explainability_mask)
        else:
            # mask loss disabled; keep 0 so the weighted sum still works
            loss_2 = 0
        loss_3 = smooth_loss(depth)

        loss = w1 * loss_1 + w2 * loss_2 + w3 * loss_3

        # scalar logging (skipped on the very first batch)
        if i > 0 and n_iter % args.print_freq == 0:
            train_writer.add_scalar('photometric_error', loss_1.item(), n_iter)
            if w2 > 0:
                train_writer.add_scalar('explanability_loss', loss_2.item(), n_iter)
            train_writer.add_scalar('disparity_smoothness_loss', loss_3.item(), n_iter)
            train_writer.add_scalar('total_loss', loss.item(), n_iter)

        # periodically log images: inputs, per-scale disparity/depth, warps
        if args.training_output_freq > 0 and n_iter % args.training_output_freq == 0:
            train_writer.add_image('train Input', tensor2array(tgt_img[0]), n_iter)

            with torch.no_grad():
                for k, scaled_depth in enumerate(depth):
                    train_writer.add_image(
                        'train Dispnet Output Normalized {}'.format(k),
                        tensor2array(disparities[k][0], max_value=None, colormap='magma'),
                        n_iter)
                    train_writer.add_image(
                        'train Depth Output Normalized {}'.format(k),
                        tensor2array(1 / disparities[k][0], max_value=None),
                        n_iter)
                    b, _, h, w = scaled_depth.size()
                    downscale = tgt_img.size(2) / h

                    # rescale inputs and intrinsics to this pyramid level
                    tgt_img_scaled = F.interpolate(tgt_img, (h, w), mode='area')
                    ref_imgs_scaled = [
                        F.interpolate(ref_img, (h, w), mode='area')
                        for ref_img in ref_imgs
                    ]
                    # only the first two rows (fx, fy, cx, cy) scale with resolution
                    intrinsics_scaled = torch.cat(
                        (intrinsics[:, 0:2] / downscale, intrinsics[:, 2:]), dim=1)

                    # log warped images along with explainability mask
                    for j, ref in enumerate(ref_imgs_scaled):
                        ref_warped = inverse_warp(
                            ref, scaled_depth[:, 0], pose[:, j],
                            intrinsics_scaled,
                            rotation_mode=args.rotation_mode,
                            padding_mode=args.padding_mode)[0]

                        train_writer.add_image(
                            'train Warped Outputs {} {}'.format(k, j),
                            tensor2array(ref_warped), n_iter)
                        train_writer.add_image(
                            'train Diff Outputs {} {}'.format(k, j),
                            tensor2array(
                                0.5 * (tgt_img_scaled[0] - ref_warped).abs()),
                            n_iter)
                        if explainability_mask[k] is not None:
                            train_writer.add_image(
                                'train Exp mask Outputs {} {}'.format(k, j),
                                tensor2array(explainability_mask[k][0, j],
                                             max_value=1, colormap='bone'),
                                n_iter)

        # record loss and EPE
        losses.update(loss.item(), args.batch_size)

        # compute gradient and do Adam step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()
        # append per-batch losses to the full CSV log
        with open(args.save_path / args.log_full, 'a') as csvfile:
            writer = csv.writer(csvfile, delimiter='\t')
            writer.writerow([
                loss.item(),
                loss_1.item(),
                loss_2.item() if w2 > 0 else 0,
                loss_3.item()
            ])
        logger.train_bar.update(i + 1)
        if i % args.print_freq == 0:
            logger.train_writer.write('Train: Time {} Data {} Loss {}'.format(
                batch_time, data_time, losses))
        if i >= epoch_size - 1:
            break
        # NOTE: n_iter is not incremented on the final (break) iteration
        n_iter += 1
    return losses.avg[0]
def train(train_loader, model, optimizer, epoch, args, log):
    '''Train the given model for one epoch.

    Supports three modes via args.train: 'vanilla' (clean images),
    'mixup' (input mixup, optionally Puzzle Mix with saliency and
    adversarial noise when args.graph is set) and 'mixup_hidden'
    (manifold mixup). Returns (top1.avg, top5.avg, losses.avg).
    '''
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()
    mixing_avg = []

    # switch to train mode
    model.train()

    end = time.time()
    for i, (input, target) in enumerate(train_loader):
        data_time.update(time.time() - end)
        optimizer.zero_grad()

        input = input.cuda()
        target = target.long().cuda()

        unary = None
        noise = None
        adv_mask1 = 0
        adv_mask2 = 0

        # train with clean images
        if args.train == 'vanilla':
            input_var, target_var = Variable(input), Variable(target)
            output, reweighted_target = model(input_var, target_var)
            loss = bce_loss(softmax(output), reweighted_target)

        # train with mixup images
        elif args.train == 'mixup':
            # process for Puzzle Mix
            if args.graph:
                # whether to add adversarial noise or not
                if args.adv_p > 0:
                    adv_mask1 = np.random.binomial(n=1, p=args.adv_p)
                    adv_mask2 = np.random.binomial(n=1, p=args.adv_p)
                else:
                    adv_mask1 = 0
                    adv_mask2 = 0

                # random start: uniform noise in the un-normalized [0, 1] image
                # space, clamped, then mapped back to normalized space.
                # assumes args.mean / args.std are broadcastable tensors — set in main()
                if (adv_mask1 == 1 or adv_mask2 == 1):
                    noise = torch.zeros_like(input).uniform_(
                        -args.adv_eps / 255., args.adv_eps / 255.)
                    input_orig = input * args.std + args.mean
                    input_noise = input_orig + noise
                    input_noise = torch.clamp(input_noise, 0, 1)
                    noise = input_noise - input_orig
                    input_noise = (input_noise - args.mean) / args.std
                    input_var = Variable(input_noise, requires_grad=True)
                else:
                    input_var = Variable(input, requires_grad=True)
                target_var = Variable(target)

                # calculate saliency (unary): gradient magnitude of the loss
                # w.r.t. the input. clean_lam == 0 uses eval mode (pure saliency
                # pass); otherwise the clean loss is scaled and its graph kept.
                if args.clean_lam == 0:
                    model.eval()
                    output = model(input_var)
                    loss_batch = criterion_batch(output, target_var)
                else:
                    model.train()
                    output = model(input_var)
                    loss_batch = 2 * args.clean_lam * criterion_batch(
                        output, target_var) / args.num_classes
                loss_batch_mean = torch.mean(loss_batch, dim=0)
                loss_batch_mean.backward(retain_graph=True)
                # per-pixel saliency: RMS of gradient over channels
                unary = torch.sqrt(torch.mean(input_var.grad**2, dim=1))

                # calculate adversarial noise (FGSM-style step on the sign of
                # the input gradient, re-clamped to the epsilon ball)
                if (adv_mask1 == 1 or adv_mask2 == 1):
                    noise += (args.adv_eps + 2) / 255. * input_var.grad.sign()
                    noise = torch.clamp(noise, -args.adv_eps / 255.,
                                        args.adv_eps / 255.)
                    adv_mix_coef = np.random.uniform(0, 1)
                    noise = adv_mix_coef * noise

                # reset: the saliency pass used eval mode / accumulated grads
                if args.clean_lam == 0:
                    model.train()
                    optimizer.zero_grad()

            input_var, target_var = Variable(input), Variable(target)

            # perform mixup and calculate loss
            output, reweighted_target = model(input_var,
                                              target_var,
                                              mixup=True,
                                              args=args,
                                              grad=unary,
                                              noise=noise,
                                              adv_mask1=adv_mask1,
                                              adv_mask2=adv_mask2)
            loss = bce_loss(softmax(output), reweighted_target)

        # for manifold mixup
        elif args.train == 'mixup_hidden':
            input_var, target_var = Variable(input), Variable(target)
            output, reweighted_target = model(input_var,
                                              target_var,
                                              mixup_hidden=True,
                                              args=args)
            loss = bce_loss(softmax(output), reweighted_target)
        else:
            raise AssertionError('wrong train type!!')

        # measure accuracy and record loss
        prec1, prec5 = accuracy(output, target, topk=(1, 5))
        losses.update(loss.item(), input.size(0))
        top1.update(prec1.item(), input.size(0))
        top5.update(prec5.item(), input.size(0))

        # compute gradient and do SGD step
        loss.backward()
        optimizer.step()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

    print_log(
        '  **Train** Prec@1 {top1.avg:.3f} Prec@5 {top5.avg:.3f} Error@1 {error1:.3f}'
        .format(top1=top1, top5=top5, error1=100 - top1.avg), log)
    return top1.avg, top5.avg, losses.avg
def main():
    """Entry point: set up experiment dirs, data, model; run train/eval loop."""
    # set up the experiment directories
    if not args.log_off:
        exp_name = experiment_name_non_mnist()
        exp_dir = os.path.join(args.root_dir, exp_name)
        if not os.path.exists(exp_dir):
            os.makedirs(exp_dir)
        copy_script_to_folder(os.path.abspath(__file__), exp_dir)
        result_png_path = os.path.join(exp_dir, 'results.png')
        log = open(os.path.join(exp_dir, 'log.txt'), 'w')
        print_log('save path : {}'.format(exp_dir), log)
    else:
        log = None

    global best_acc
    state = {k: v for k, v in args._get_kwargs()}
    print("")
    print_log(state, log)
    print("")
    print_log("Random Seed: {}".format(args.seed), log)
    print_log("python version : {}".format(sys.version.replace('\n', ' ')), log)
    print_log("torch version : {}".format(torch.__version__), log)
    print_log("cudnn version : {}".format(torch.backends.cudnn.version()), log)

    # dataloader
    train_loader, valid_loader, _, test_loader, num_classes = load_data_subset(
        args.batch_size,
        2,
        args.dataset,
        args.data_dir,
        labels_per_class=args.labels_per_class,
        valid_labels_per_class=args.valid_labels_per_class,
        mixup_alpha=args.mixup_alpha)

    # per-dataset normalization constants and first-conv stride
    if args.dataset == 'tiny-imagenet-200':
        stride = 2
        args.mean = torch.tensor([0.5] * 3, dtype=torch.float32).view(1, 3, 1, 1).cuda()
        args.std = torch.tensor([0.5] * 3, dtype=torch.float32).view(1, 3, 1, 1).cuda()
        args.labels_per_class = 500
    elif args.dataset == 'cifar10':
        stride = 1
        args.mean = torch.tensor([x / 255 for x in [125.3, 123.0, 113.9]],
                                 dtype=torch.float32).view(1, 3, 1, 1).cuda()
        args.std = torch.tensor([x / 255 for x in [63.0, 62.1, 66.7]],
                                dtype=torch.float32).view(1, 3, 1, 1).cuda()
        args.labels_per_class = 5000
    elif args.dataset == 'cifar100':
        stride = 1
        args.mean = torch.tensor([x / 255 for x in [129.3, 124.1, 112.4]],
                                 dtype=torch.float32).view(1, 3, 1, 1).cuda()
        args.std = torch.tensor([x / 255 for x in [68.2, 65.4, 70.4]],
                                dtype=torch.float32).view(1, 3, 1, 1).cuda()
        args.labels_per_class = 500
    else:
        raise AssertionError('Given Dataset is not supported!')

    # create model
    print_log("=> creating model '{}'".format(args.arch), log)
    net = models.__dict__[args.arch](num_classes, args.dropout, stride).cuda()
    args.num_classes = num_classes
    net = torch.nn.DataParallel(net, device_ids=list(range(args.ngpu)))
    optimizer = torch.optim.SGD(net.parameters(),
                                state['learning_rate'],
                                momentum=state['momentum'],
                                weight_decay=state['decay'],
                                nesterov=True)

    recorder = RecorderMeter(args.epochs)
    # optionally resume from a checkpoint
    if args.resume:
        if os.path.isfile(args.resume):
            print_log("=> loading checkpoint '{}'".format(args.resume), log)
            checkpoint = torch.load(args.resume)
            recorder = checkpoint['recorder']
            args.start_epoch = checkpoint['epoch']
            net.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            best_acc = recorder.max_accuracy(False)
            print_log(
                "=> loaded checkpoint '{}' accuracy={} (epoch {})".format(
                    args.resume, best_acc, checkpoint['epoch']), log)
        else:
            print_log("=> no checkpoint found at '{}'".format(args.resume), log)
    else:
        print_log(
            "=> do not use any checkpoint for {} model".format(args.arch), log)

    if args.evaluate:
        # NOTE(review): other validate() calls below take (loader, net, log) —
        # confirm this variant's signature actually accepts a criterion here.
        validate(test_loader, net, criterion, log)
        return

    start_time = time.time()
    epoch_time = AverageMeter()
    train_loss = []
    train_acc = []
    test_loss = []
    test_acc = []

    for epoch in range(args.start_epoch, args.epochs):
        current_learning_rate = adjust_learning_rate(optimizer, epoch,
                                                     args.gammas, args.schedule)
        # FIX: was `args.clean_lam == 0`, a no-op comparison; the intent is to
        # disable the clean-loss term after the first schedule milestone.
        if epoch == args.schedule[0]:
            args.clean_lam = 0

        need_hour, need_mins, need_secs = convert_secs2time(
            epoch_time.avg * (args.epochs - epoch))
        need_time = '[Need: {:02d}:{:02d}:{:02d}]'.format(
            need_hour, need_mins, need_secs)

        print_log('\n==>>{:s} [Epoch={:03d}/{:03d}] {:s} [learning_rate={:6.4f}]'.format(time_string(), epoch, args.epochs, need_time, current_learning_rate) \
            + ' [Best : Accuracy={:.2f}, Error={:.2f}]'.format(recorder.max_accuracy(False), 100-recorder.max_accuracy(False)), log)

        # train for one epoch
        tr_acc, tr_acc5, tr_los = train(train_loader, net, optimizer, epoch,
                                        args, log)

        # evaluate on validation set
        val_acc, val_los = validate(test_loader, net, log)
        if (epoch % 50) == 0 and args.adv_p > 0:
            # periodic FGSM robustness checks at two epsilons
            _, _ = validate(test_loader, net, log, fgsm=True, eps=4,
                            mean=args.mean, std=args.std)
            _, _ = validate(test_loader, net, log, fgsm=True, eps=8,
                            mean=args.mean, std=args.std)

        train_loss.append(tr_los)
        train_acc.append(tr_acc)
        test_loss.append(val_los)
        test_acc.append(val_acc)

        is_best = False
        if val_acc > best_acc:
            is_best = True
            best_acc = val_acc

        # measure elapsed time
        epoch_time.update(time.time() - start_time)
        start_time = time.time()

        if args.log_off:
            continue

        # save log
        save_checkpoint(
            {
                'epoch': epoch + 1,
                'arch': args.arch,
                'state_dict': net.state_dict(),
                'recorder': recorder,
                'optimizer': optimizer.state_dict(),
            }, is_best, exp_dir, 'checkpoint.pth.tar')
        dummy = recorder.update(epoch, tr_los, tr_acc, val_los, val_acc)
        if (epoch + 1) % 100 == 0:
            recorder.plot_curve(result_png_path)

        train_log = OrderedDict()
        train_log['train_loss'] = train_loss
        train_log['train_acc'] = train_acc
        train_log['test_loss'] = test_loss
        train_log['test_acc'] = test_acc

        # FIX: close the pickle file deterministically (was an unclosed open())
        with open(os.path.join(exp_dir, 'log.pkl'), 'wb') as pkl_file:
            pickle.dump(train_log, pkl_file)
        plotting(exp_dir)

    # spread of the last-10-epoch accuracies around the median
    acc_var = np.maximum(
        np.max(test_acc[-10:]) - np.median(test_acc[-10:]),
        np.median(test_acc[-10:]) - np.min(test_acc[-10:]))
    print_log(
        "\nfinal 10 epoch acc (median) : {:.2f} (+- {:.2f})".format(
            np.median(test_acc[-10:]), acc_var), log)

    if not args.log_off:
        log.close()
def validate_with_gt(args, val_loader, mvdnet, epoch, output_writers=[]):
    """Validate MVDNet depth and normal predictions against ground truth.

    For the 'sceneflow' dataset errors are computed in disparity space
    (scaled by focal length); otherwise directly in depth space. Mean
    angular error of the predicted normals is appended as the last metric.
    Optionally dumps predictions to args.output_dir.

    Returns (errors.avg, error_names).
    """
    batch_time = AverageMeter()
    error_names = [
        'abs_rel', 'abs_diff', 'sq_rel', 'a1', 'a2', 'a3', 'mean_angle'
    ]
    test_error_names = [
        'abs_rel', 'abs_diff', 'sq_rel', 'rms', 'log_rms', 'a1', 'a2', 'a3',
        'mean_angle'
    ]
    errors = AverageMeter(i=len(error_names))
    test_errors = AverageMeter(i=len(test_error_names))
    log_outputs = len(output_writers) > 0

    output_dir = Path(args.output_dir)
    if not os.path.isdir(output_dir):
        os.mkdir(output_dir)

    # switch to evaluate mode
    mvdnet.eval()

    end = time.time()
    with torch.no_grad():
        for i, (tgt_img, ref_imgs, gt_nmap, ref_poses, intrinsics,
                intrinsics_inv, tgt_depth) in enumerate(val_loader):
            tgt_img_var = Variable(tgt_img.cuda())
            ref_imgs_var = [Variable(img.cuda()) for img in ref_imgs]
            gt_nmap_var = Variable(gt_nmap.cuda())
            ref_poses_var = [Variable(pose.cuda()) for pose in ref_poses]
            intrinsics_var = Variable(intrinsics.cuda())
            intrinsics_inv_var = Variable(intrinsics_inv.cuda())
            tgt_depth_var = Variable(tgt_depth.cuda())

            pose = torch.cat(ref_poses_var, 1)
            # skip batches containing NaN poses (NaN != NaN)
            if (pose != pose).any():
                continue

            # depth/disparity scale factor; sceneflow uses the focal length
            if args.dataset == 'sceneflow':
                factor = (1.0 / args.scale) * intrinsics_var[:, 0, 0] / 1050.0
                factor = factor.view(-1, 1, 1)
            else:
                factor = torch.ones(
                    (tgt_depth_var.size(0), 1, 1)).type_as(tgt_depth_var)

            # get mask: GT depth in valid range and not NaN
            mask = (tgt_depth_var <= args.nlabel * args.mindepth * factor *
                    3) & (tgt_depth_var >= args.mindepth * factor) & (
                        tgt_depth_var == tgt_depth_var)

            if not mask.any():
                continue

            output_depth, nmap = mvdnet(tgt_img_var,
                                        ref_imgs_var,
                                        pose,
                                        intrinsics_var,
                                        intrinsics_inv_var,
                                        factor=factor.unsqueeze(1))
            output_disp = args.nlabel * args.mindepth / (output_depth)

            if args.dataset == 'sceneflow':
                output_disp = (args.nlabel * args.mindepth) * 3 / (output_depth)
                output_depth = (args.nlabel * 3) * (args.mindepth *
                                                    factor) / output_disp

            # NOTE(review): computed unconditionally but only used in the
            # sceneflow branch below
            tgt_disp_var = ((1.0 / args.scale) *
                            intrinsics_var[:, 0, 0].view(-1, 1, 1) /
                            tgt_depth_var)
            if args.dataset == 'sceneflow':
                output = torch.squeeze(output_disp.data.cpu(), 1)
                errors_ = compute_errors_train(tgt_disp_var.cpu(), output, mask)
                test_errors_ = list(
                    compute_errors_test(tgt_disp_var.cpu()[mask], output[mask]))
            else:
                output = torch.squeeze(output_depth.data.cpu(), 1)
                errors_ = compute_errors_train(tgt_depth, output, mask)
                test_errors_ = list(
                    compute_errors_test(tgt_depth[mask], output[mask]))

            # normal mask: pixels where any GT normal channel is non-zero
            n_mask = (gt_nmap_var.permute(0, 2, 3, 1)[0, :, :] != 0)
            n_mask = n_mask[:, :, 0] | n_mask[:, :, 1] | n_mask[:, :, 2]

            total_angles_m = compute_angles(
                gt_nmap_var.permute(0, 2, 3, 1)[0], nmap[0])

            mask_angles = total_angles_m[n_mask]
            total_angles_m[~n_mask] = 0
            # mean angular error appended as the last metric component
            errors_.append(torch.mean(mask_angles).item())
            test_errors_.append(torch.mean(mask_angles).item())

            errors.update(errors_)
            test_errors.update(test_errors_)

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()

            if args.output_print:
                np.save(output_dir / '{:04d}{}'.format(i, '_depth.npy'),
                        output.numpy()[0])
                plt.imsave(output_dir / '{:04d}_gt{}'.format(i, '.png'),
                           tgt_depth.numpy()[0],
                           cmap='rainbow')
                imsave(output_dir / '{:04d}_aimage{}'.format(i, '.png'),
                       np.transpose(tgt_img.numpy()[0], (1, 2, 0)))
                np.save(output_dir / '{:04d}_cam{}'.format(i, '.npy'),
                        intrinsics_var.cpu().numpy()[0])
                np.save(output_dir / '{:04d}{}'.format(i, '_normal.npy'),
                        nmap.cpu().numpy()[0])

            if i % args.print_freq == 0:
                print(
                    'valid: Time {} Abs Error {:.4f} ({:.4f}) Abs angle Error {:.4f} ({:.4f}) Iter {}/{}'
                    .format(batch_time, test_errors.val[0], test_errors.avg[0],
                            test_errors.val[-1], test_errors.avg[-1], i,
                            len(val_loader)))
    if args.output_print:
        np.savetxt(output_dir / args.ttype + 'errors.csv',
                   test_errors.avg,
                   fmt='%1.4f',
                   delimiter=',')
        # NOTE(review): this writes the same array as errors.csv — possibly the
        # angle column alone was intended; confirm before changing.
        np.savetxt(output_dir / args.ttype + 'angle_errors.csv',
                   test_errors.avg,
                   fmt='%1.4f',
                   delimiter=',')
    return errors.avg, error_names
def validate_without_gt(args, val_loader, disp_net, pose_exp_net, epoch, logger, output_writers=[]):
    """Validate disp/pose networks without ground truth via photometric loss.

    Logs sample outputs every 100 batches (one tensorboard writer per sample)
    and pose/disparity histograms at the end of the epoch.

    Returns:
        (losses.avg, loss names) — averaged [total, photometric, explainability].
    """
    batch_time = AverageMeter()
    losses = AverageMeter(i=3, precision=4)
    log_outputs = len(output_writers) > 0
    w1, w2, w3 = args.photo_loss_weight, args.mask_loss_weight, args.smooth_loss_weight
    # last (possibly partial) batch is excluded from the histogram buffers
    poses = np.zeros(((len(val_loader) - 1) * args.batch_size * (args.sequence_length - 1), 6))
    disp_values = np.zeros(((len(val_loader) - 1) * args.batch_size * 3))

    # switch to evaluate mode
    disp_net.eval()
    pose_exp_net.eval()

    end = time.time()
    logger.valid_bar.update(0)
    for i, (tgt_img, ref_imgs, intrinsics, intrinsics_inv) in enumerate(val_loader):
        tgt_img_var = Variable(tgt_img.cuda(), volatile=True)
        ref_imgs_var = [Variable(img.cuda(), volatile=True) for img in ref_imgs]
        intrinsics_var = Variable(intrinsics.cuda(), volatile=True)
        intrinsics_inv_var = Variable(intrinsics_inv.cuda(), volatile=True)

        # compute output
        disp = disp_net(tgt_img_var)
        depth = 1 / disp
        explainability_mask, pose = pose_exp_net(tgt_img_var, ref_imgs_var)

        loss_1 = photometric_reconstruction_loss(tgt_img_var, ref_imgs_var,
                                                 intrinsics_var, intrinsics_inv_var,
                                                 depth, explainability_mask, pose,
                                                 args.rotation_mode, args.padding_mode)
        loss_1 = loss_1.data[0]
        if w2 > 0:
            loss_2 = explainability_loss(explainability_mask).data[0]
        else:
            loss_2 = 0
        loss_3 = smooth_loss(disp).data[0]

        if log_outputs and i % 100 == 0 and i / 100 < len(output_writers):
            # log first output of every 100 batch
            index = int(i // 100)
            if epoch == 0:
                for j, ref in enumerate(ref_imgs):
                    # target at step 0, reference at step 1, under the same tag
                    output_writers[index].add_image('val Input {}'.format(j),
                                                    tensor2array(tgt_img[0]), 0)
                    output_writers[index].add_image('val Input {}'.format(j),
                                                    tensor2array(ref[0]), 1)

            output_writers[index].add_image(
                'val Dispnet Output Normalized',
                tensor2array(disp.data[0].cpu(), max_value=None, colormap='bone'),
                epoch)
            output_writers[index].add_image(
                'val Depth Output',
                tensor2array(1. / disp.data[0].cpu(), max_value=10), epoch)

            # log warped images along with explainability mask
            for j, ref in enumerate(ref_imgs_var):
                ref_warped = inverse_warp(ref[:1], depth[:1, 0], pose[:1, j],
                                          intrinsics_var[:1], intrinsics_inv_var[:1],
                                          rotation_mode=args.rotation_mode,
                                          padding_mode=args.padding_mode)[0]
                output_writers[index].add_image(
                    'val Warped Outputs {}'.format(j),
                    tensor2array(ref_warped.data.cpu()), epoch)
                output_writers[index].add_image(
                    'val Diff Outputs {}'.format(j),
                    tensor2array(0.5 * (tgt_img_var[0] - ref_warped).abs().data.cpu()),
                    epoch)
                if explainability_mask is not None:
                    output_writers[index].add_image(
                        'val Exp mask Outputs {}'.format(j),
                        tensor2array(explainability_mask[0, j].data.cpu(),
                                     max_value=1, colormap='bone'), epoch)

        if log_outputs and i < len(val_loader) - 1:
            step = args.batch_size * (args.sequence_length - 1)
            poses[i * step:(i + 1) * step] = pose.data.cpu().view(-1, 6).numpy()
            step = args.batch_size * 3
            disp_unraveled = disp.data.cpu().view(args.batch_size, -1)
            # per-sample min / median / max of the disparity map
            disp_values[i * step:(i + 1) * step] = torch.cat([
                disp_unraveled.min(-1)[0],
                disp_unraveled.median(-1)[0],
                disp_unraveled.max(-1)[0]
            ]).numpy()

        loss = w1 * loss_1 + w2 * loss_2 + w3 * loss_3
        losses.update([loss, loss_1, loss_2])

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()
        logger.valid_bar.update(i + 1)
        if i % args.print_freq == 0:
            logger.valid_writer.write('valid: Time {} Loss {}'.format(batch_time, losses))

    if log_outputs:
        prefix = 'valid poses'
        coeffs_names = ['tx', 'ty', 'tz']
        if args.rotation_mode == 'euler':
            coeffs_names.extend(['rx', 'ry', 'rz'])
        elif args.rotation_mode == 'quat':
            coeffs_names.extend(['qx', 'qy', 'qz'])
        for i in range(poses.shape[1]):
            # FIX: was `output_writers.add_histogram(...)` — a method call on
            # the list itself, which raises AttributeError; use the first writer
            # as the disp_values histogram below already does.
            output_writers[0].add_histogram(
                '{} {}'.format(prefix, coeffs_names[i]), poses[:, i], epoch)
        output_writers[0].add_histogram('disp_values', disp_values, epoch)
    logger.valid_bar.update(len(val_loader))
    return losses.avg, ['Total loss', 'Photo loss', 'Exp loss']
def validate(val_loader, model, log):
    '''Evaluate the trained model on a validation loader.

    Computes the BCE loss against one-hot targets plus top-1/top-5
    accuracy, prints a summary line, and returns (top1.avg, losses.avg).
    '''
    losses = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()

    # Switch to evaluate mode
    model.eval()

    for images, labels in val_loader:
        if args.use_cuda:
            images, labels = images.cuda(), labels.cuda()

        with torch.no_grad():
            # forward pass and loss against one-hot encoded targets
            logits = model(images)
            one_hot = to_one_hot(labels, args.num_classes)
            loss = bce_loss(softmax(logits), one_hot)

        # Measure accuracy and record loss
        prec1, prec5 = accuracy(logits.data, labels, topk=(1, 5))
        n = images.size(0)
        losses.update(loss.item(), n)
        top1.update(prec1.item(), n)
        top5.update(prec5.item(), n)

    print_log(
        '**Test ** Prec@1 {top1.avg:.2f} Prec@5 {top5.avg:.2f} Error@1 {error1:.2f} Loss: {losses.avg:.3f} '
        .format(top1=top1, top5=top5, error1=100 - top1.avg, losses=losses), log)
    return top1.avg, losses.avg
def train(args, train_loader, mvdnet, optimizer, epoch_size, train_writer, epoch):
    """Run one training epoch of the multi-view depth/normal network.

    Args:
        args: parsed options; reads dataset, scale, nlabel, mindepth, d_weight,
            n_weight, batch_size, print_freq, save_path, log_full, epochs.
        train_loader: yields (tgt_img, ref_imgs, gt_nmap, ref_poses,
            intrinsics, intrinsics_inv, tgt_depth) batches.
        mvdnet: network whose outputs[0:2] are depth predictions and
            outputs[2] is a normal map (B, H, W, 3 before permute).
        optimizer: stepped once per batch.
        epoch_size: number of batches to process before breaking.
        train_writer: tensorboard-style scalar writer.
        epoch: current epoch index (logging only).

    Returns:
        Average total loss over the epoch.
    """
    global n_iter
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter(precision=4)
    d_losses = AverageMeter(precision=4)
    nmap_losses = AverageMeter(precision=4)

    # switch to training mode
    mvdnet.train()
    print("Training")
    end = time.time()

    for i, (tgt_img, ref_imgs, gt_nmap, ref_poses, intrinsics, intrinsics_inv, tgt_depth) in enumerate(train_loader):
        # measure data loading time
        data_time.update(time.time() - end)

        # Move the batch to the GPU. (Variable is a no-op wrapper on modern
        # PyTorch; the trailing .cuda() on tgt_depth_var is redundant.)
        tgt_img_var = Variable(tgt_img.cuda())
        ref_imgs_var = [Variable(img.cuda()) for img in ref_imgs]
        gt_nmap_var = Variable(gt_nmap.cuda())
        ref_poses_var = [Variable(pose.cuda()) for pose in ref_poses]
        intrinsics_var = Variable(intrinsics.cuda())
        intrinsics_inv_var = Variable(intrinsics_inv.cuda())
        tgt_depth_var = Variable(tgt_depth.cuda()).cuda()

        # compute output: stack reference poses along dim 1
        pose = torch.cat(ref_poses_var, 1)

        # Per-sample depth scale; sceneflow depths are normalized by focal
        # length (1050.0 presumably the nominal sceneflow focal — TODO confirm).
        if args.dataset == 'sceneflow':
            factor = (1.0 / args.scale) * intrinsics_var[:, 0, 0] / 1050.0
            factor = factor.view(-1, 1, 1)
        else:
            factor = torch.ones((tgt_depth_var.size(0), 1, 1)).type_as(tgt_depth_var)

        # Valid-pixel mask: depth within [mindepth*factor, nlabel*mindepth*factor*3]
        # and not NaN (x == x filters NaNs).
        mask = (tgt_depth_var <= args.nlabel * args.mindepth * factor * 3) & (
            tgt_depth_var >= args.mindepth * factor) & (tgt_depth_var == tgt_depth_var)
        mask.detach_()
        # Skip batches with no valid ground-truth pixel.
        # NOTE(review): `mask.any() == 0` compares a bool tensor to 0; works,
        # but `not mask.any()` would be the idiomatic form.
        if mask.any() == 0:
            continue

        # NOTE(review): targetimg is never used below — dead computation? confirm.
        targetimg = inverse_warp(ref_imgs_var[0], tgt_depth_var.unsqueeze(1), pose[:, 0],
                                 intrinsics_var, intrinsics_inv_var)

        #[B,CH,D,H,W,1]
        outputs = mvdnet(tgt_img_var, ref_imgs_var, pose, intrinsics_var,
                         intrinsics_inv_var, factor=factor.unsqueeze(1))
        nmap = outputs[2].permute(0, 3, 1, 2)  # to (B, 3, H, W)
        depths = outputs[0:2]
        disps = [args.mindepth * args.nlabel / (depth) for depth in depths]

        # correct disps for sceneflow's widened depth range
        if args.dataset == 'sceneflow':
            disps = [(args.mindepth * args.nlabel) * 3 / (depth) for depth in depths]
            # correct disps
            depths = [(args.mindepth * factor) * (args.nlabel * 3) / disp for disp in disps]

        loss = 0.
        d_loss = 0.
        nmap_loss = 0.

        # Depth/disparity supervision; later predictions weighted higher via
        # the 0.7^(K-l-1) schedule.
        if args.dataset == 'sceneflow':
            tgt_disp_var = ((1.0 / args.scale) * intrinsics_var[:, 0, 0].view(-1, 1, 1) / tgt_depth_var)
            for l, disp in enumerate(disps):
                output = torch.squeeze(disp, 1)
                d_loss = d_loss + F.smooth_l1_loss(output[mask], tgt_disp_var[mask]) * pow(0.7, len(disps) - l - 1)
        else:
            for l, depth in enumerate(depths):
                output = torch.squeeze(depth, 1)
                d_loss = d_loss + F.smooth_l1_loss(output[mask], tgt_depth_var[mask]) * pow(0.7, len(depths) - l - 1)

        # Normal-map supervision on the same valid pixels, broadcast to 3 channels.
        n_mask = mask.unsqueeze(1).expand(-1, 3, -1, -1)
        nmap_loss = nmap_loss + F.smooth_l1_loss(nmap[n_mask], gt_nmap_var[n_mask])

        loss = loss + args.d_weight * d_loss + args.n_weight * nmap_loss

        if i > 0 and n_iter % args.print_freq == 0:
            train_writer.add_scalar('total_loss', loss.item(), n_iter)

        # record loss and EPE
        losses.update(loss.item(), args.batch_size)
        d_losses.update(d_loss.item(), args.batch_size)
        nmap_losses.update(nmap_loss.item(), args.batch_size)

        # compute gradient and do Adam step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        # Append the raw loss to the per-iteration log file.
        with open(args.save_path / args.log_full, 'a') as csvfile:
            writer = csv.writer(csvfile, delimiter='\t')
            writer.writerow([loss.item()])
        if i % args.print_freq == 0:
            print('Train: Time {} Data {} Loss {} NmapLoss {} DLoss {} Iter {}/{} Epoch {}/{}'
                  .format(batch_time, data_time, losses, nmap_losses, d_losses, i,
                          len(train_loader), epoch, args.epochs))
        if i >= epoch_size - 1:
            break
        n_iter += 1

    return losses.avg[0]
def train(train_loader, model, optimizer, epoch, args, log, mpp=None):
    """Train the given model for one epoch, either on clean images or with
    Co-Mixup mixed inputs (when ``args.comix`` is set).

    In the Co-Mixup branch an extra forward/backward pass computes an input
    saliency map, plus a pairwise distance matrix between each image's most
    salient location, both of which parameterize the mixup.

    Returns:
        (top1.avg, top5.avg, losses.avg)
    """
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()
    mixing_avg = []  # NOTE(review): never filled in this function

    # switch to train mode
    model.train()

    end = time.time()
    for input, target in train_loader:
        data_time.update(time.time() - end)
        optimizer.zero_grad()
        input = input.cuda()
        target = target.long().cuda()

        sc = None  # per-pixel saliency; only computed in the Co-Mixup branch
        # train with clean images
        if not args.comix:
            target_reweighted = to_one_hot(target, args.num_classes)
            output = model(input)
            loss = bce_loss(softmax(output), target_reweighted)
        # train with Co-Mixup images
        else:
            input_var = Variable(input, requires_grad=True)
            target_var = Variable(target)
            A_dist = None

            # Calculate saliency (unary)
            if args.clean_lam == 0:
                model.eval()
                output = model(input_var)
                loss_batch = criterion_batch(output, target_var)
            else:
                model.train()
                output = model(input_var)
                loss_batch = 2 * args.clean_lam * criterion_batch(output, target_var) / args.num_classes
            loss_batch_mean = torch.mean(loss_batch, dim=0)
            # retain_graph so this graph survives until the final backward
            loss_batch_mean.backward(retain_graph=True)
            # saliency = RMS of the input gradient across channels
            sc = torch.sqrt(torch.mean(input_var.grad**2, dim=1))

            # Here, we calculate distance between most salient location (Compatibility)
            # We can try various measurements
            with torch.no_grad():
                z = F.avg_pool2d(sc, kernel_size=8, stride=1)
                z_reshape = z.reshape(args.batch_size, -1)
                z_idx_1d = torch.argmax(z_reshape, dim=1)
                z_idx_2d = torch.zeros((args.batch_size, 2), device=z.device)
                z_idx_2d[:, 0] = z_idx_1d // z.shape[-1]
                z_idx_2d[:, 1] = z_idx_1d % z.shape[-1]
                A_dist = distance(z_idx_2d, dist_type='l1')

            if args.clean_lam == 0:
                # The saliency pass was gradient-only: discard its gradients.
                model.train()
                optimizer.zero_grad()

            # Perform mixup and calculate loss
            target_reweighted = to_one_hot(target, args.num_classes)
            if args.parallel:
                # mpp runs the mixing on CPU workers; round-trip the tensors.
                device = input.device
                out, target_reweighted = mpp(input.cpu(),
                                             target_reweighted.cpu(),
                                             args=args,
                                             sc=sc.cpu(),
                                             A_dist=A_dist.cpu())
                out = out.to(device)
                target_reweighted = target_reweighted.to(device)
            else:
                out, target_reweighted = mixup_process(input,
                                                       target_reweighted,
                                                       args=args,
                                                       sc=sc,
                                                       A_dist=A_dist)
            out = model(out)
            loss = bce_loss(softmax(out), target_reweighted)

        # measure accuracy and record loss
        # NOTE(review): accuracy is computed from `output` (the clean/saliency
        # forward pass), not from the mixed prediction `out` — presumably
        # intentional (clean accuracy); confirm against upstream Co-Mixup.
        prec1, prec5 = accuracy(output, target, topk=(1, 5))
        losses.update(loss.item(), input.size(0))
        top1.update(prec1.item(), input.size(0))
        top5.update(prec5.item(), input.size(0))

        # compute gradient and do SGD step
        loss.backward()
        optimizer.step()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

    print_log(
        '**Train** Prec@1 {top1.avg:.2f} Prec@5 {top5.avg:.2f} Error@1 {error1:.2f}'
        .format(top1=top1, top5=top5, error1=100 - top1.avg), log)
    return top1.avg, top5.avg, losses.avg
def test(val_loader, disp_net, mask_net, pose_net, flow_net, tb_writer, global_vars_dict=None):
    """Run all validation samples through the disp/pose/mask/flow networks and
    log the predictions to tensorboard.

    Args:
        val_loader: yields (tgt_img, ref_imgs, intrinsics, intrinsics_inv).
        disp_net, mask_net, pose_net, flow_net: trained networks (set to eval).
        tb_writer: tensorboard SummaryWriter.
        global_vars_dict: provides 'device', 'n_iter_val' and 'args'.

    Returns:
        (disp_list, disp_arr, flow_list, mask_list): per-sample visualization
        arrays collected over the whole loader.
    """
    # data prepared
    device = global_vars_dict['device']
    n_iter_val = global_vars_dict['n_iter_val']  # NOTE(review): unused here
    args = global_vars_dict['args']

    data_time = AverageMeter()

    # switch every network to eval mode
    disp_net.eval()
    pose_net.eval()
    mask_net.eval()
    flow_net.eval()

    end = time.time()
    # NOTE(review): `poses` is allocated but never written in this function.
    poses = np.zeros(((len(val_loader) - 1) * 1 * (args.sequence_length - 1), 6))  # init
    disp_list = []
    flow_list = []
    mask_list = []

    # 3. validation cycle
    for i, (tgt_img, ref_imgs, intrinsics, intrinsics_inv) in tqdm(enumerate(val_loader)):
        data_time.update(time.time() - end)
        tgt_img = tgt_img.to(device)
        ref_imgs = [img.to(device) for img in ref_imgs]
        intrinsics, intrinsics_inv = intrinsics.to(device), intrinsics_inv.to(device)

        # 3.1 forward pass
        # disparity -> depth
        disp = disp_net(tgt_img)
        if args.spatial_normalize:
            disp = spatial_normalize(disp)
        depth = 1 / disp

        # pose
        pose = pose_net(tgt_img, ref_imgs)

        # flow: compute optical flow to the previous and next frame
        if args.flownet == 'Back2Future':
            flow_fwd, flow_bwd, _ = flow_net(tgt_img, ref_imgs[1:3])
        elif args.flownet == 'FlowNetC6':
            flow_fwd = flow_net(tgt_img, ref_imgs[2])
            flow_bwd = flow_net(tgt_img, ref_imgs[1])
        # flow_fwd: [B,2,H,W]

        # Camera-motion-induced ("rigid background") flow from depth + pose.
        # NOTE(review): flow_cam and flows_cam_fwd are the same computation.
        flow_cam = pose2flow(depth.squeeze(1), pose[:, 2], intrinsics, intrinsics_inv)
        flows_cam_fwd = pose2flow(depth.squeeze(1), pose[:, 2], intrinsics, intrinsics_inv)
        flows_cam_bwd = pose2flow(depth.squeeze(1), pose[:, 1], intrinsics, intrinsics_inv)
        #exp_masks_target = consensus_exp_masks(flows_cam_fwd, flows_cam_bwd, flow_fwd, flow_bwd, tgt_img,
        #                                       ref_imgs[2], ref_imgs[1], wssim=args.wssim, wrig=args.wrig,
        #                                       ws=args.smooth_loss_weight)

        # Residual between rigid flow and predicted flow: large values suggest
        # independently moving (non-rigid) regions.
        rigidity_mask_fwd = (flows_cam_fwd - flow_fwd).abs()  # [b,2,h,w]
        rigidity_mask_bwd = (flows_cam_bwd - flow_bwd).abs()

        # 4. explainability mask (valid-region estimate)
        explainability_mask = mask_net(tgt_img, ref_imgs)
        # list(5): item: tensor [4,4,128,512] ... [4,4,4,16], values ~0.33-0.63

        end = time.time()

        # 3.4 logging: inspect the forward-pass results in tensorboard
        # 2. disparity
        disp_to_show = tensor2array(disp[0].cpu(), max_value=None, colormap='bone')
        # disp_to_show: [1,h,w], roughly 0.5~10
        tb_writer.add_image('Disp/disp0', disp_to_show, i)
        disp_list.append(disp_to_show)

        # Accumulate all disparities into one (N,1,H,W) array.
        if i == 0:
            disp_arr = np.expand_dims(disp_to_show, axis=0)
        else:
            disp_to_show = np.expand_dims(disp_to_show, axis=0)
            disp_arr = np.concatenate([disp_arr, disp_to_show], 0)

        # 3. flow
        tb_writer.add_image('Flow/Flow Output', flow2rgb(flow_fwd[0], max_value=6), i)
        tb_writer.add_image('Flow/cam_Flow Output', flow2rgb(flow_cam[0], max_value=6), i)
        tb_writer.add_image('Flow/rigid_Flow Output', flow2rgb(rigidity_mask_fwd[0], max_value=6), i)
        tb_writer.add_image('Flow/rigidity_mask_fwd', flow2rgb(rigidity_mask_fwd[0], max_value=6), i)
        flow_list.append(flow2rgb(flow_fwd[0], max_value=6))

        # 4. mask
        tb_writer.add_image('Mask /mask0', tensor2array(explainability_mask[0][0], max_value=None, colormap='magma'), i)
        #tb_writer.add_image('Mask Output/mask1 sample{}'.format(i),tensor2array(explainability_mask[1][0], max_value=None, colormap='magma'), epoch)
        #tb_writer.add_image('Mask Output/mask2 sample{}'.format(i),tensor2array(explainability_mask[2][0], max_value=None, colormap='magma'), epoch)
        #tb_writer.add_image('Mask Output/mask3 sample{}'.format(i),tensor2array(explainability_mask[3][0], max_value=None, colormap='magma'), epoch)
        mask_list.append(tensor2array(explainability_mask[0][0], max_value=None, colormap='magma'))

    return disp_list, disp_arr, flow_list, mask_list
def validate(args, val_loader, pose_exp_net, logger, tb_writer):
    """Validate the pose network on light-field sequences.

    Computes the mean absolute error between predicted and ground-truth poses
    over the whole loader, logs progress, and writes the average loss to
    tensorboard.

    Args:
        args: parsed options; reads batch_size and print_freq.
        val_loader: yields (tgt_img, tgt_lf, ref_imgs, ref_lfs, intrinsics,
            intrinsics_inv, pose_gt).
        pose_exp_net: pose/explainability network (set to eval mode here).
        logger: progress-bar logger with valid_bar / valid_writer.
        tb_writer: tensorboard SummaryWriter.

    Returns:
        Average pose L1 loss over the validation set.
    """
    global n_iter, device
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter(precision=4)

    pose_exp_net.eval()

    end = time.time()
    logger.valid_bar.update(0)
    for i, (tgt_img, tgt_lf, ref_imgs, ref_lfs, intrinsics, intrinsics_inv, pose_gt) in enumerate(val_loader):
        data_time.update(time.time() - end)
        tgt_lf = tgt_lf.to(device)
        ref_lfs = [lf.to(device) for lf in ref_lfs]
        pose_gt = pose_gt.to(device)

        # Validation only — no autograd graph needed.
        with torch.no_grad():
            explainability_mask, pose = pose_exp_net(tgt_lf, ref_lfs)
            loss = (pose - pose_gt).abs().mean()
        losses.update(loss.item(), args.batch_size)

        batch_time.update(time.time() - end)
        # BUGFIX: reset the timer each iteration; previously `end` was never
        # updated inside the loop, so batch_time/data_time accumulated the
        # total elapsed time since validation started (cf. train() above).
        end = time.time()
        logger.valid_bar.update(i+1)
        if i % args.print_freq == 0:
            logger.valid_writer.write('Validate: Time {} Data {} Loss {}'.format(batch_time, data_time, losses))
        # NOTE(review): this advances the global *training* iteration counter
        # during validation — kept for backward compatibility; confirm intent.
        n_iter += 1
    tb_writer.add_scalar('loss/valid', losses.avg[0], n_iter)
    return losses.avg[0]
def validate_Make3D(args, val_loader, model, epoch, logger, mode='DtoD'):
    """Evaluate a depth model on the Make3D validation set.

    Depending on ``mode``, the model is fed either the RGB image ('RtoD*') or
    the ground-truth depth ('DtoD*').  Per-sample errors are accumulated, and a
    second "min" average is computed over samples ordered by abs_diff.

    Returns:
        (errors.avg, min_errors.avg, error_names)
    """
    ##global device
    batch_time = AverageMeter()
    error_names = ['abs_diff', 'abs_rel', 'ave_log10', 'rmse']
    errors = AverageMeter(i=len(error_names))
    min_errors = AverageMeter(i=len(error_names))
    min_errors_list = []
    # Per-sample error histories, used for the sorted "min" averaging below.
    abs_diff_tot, abs_rel_tot, ave_log10_tot, rmse_tot = [], [], [], []
    abs_diff_sum, abs_rel_sum, ave_log10_sum, rmse_sum = 0, 0, 0, 0

    # switch to evaluate mode
    #model.eval()
    print("mode: ", args.mode)
    end = time.time()
    logger.valid_bar.update(0)
    for i, (depth, img, depth_np) in enumerate(val_loader):
        img = img.cuda()
        depth = depth.cuda()
        depth_np = depth_np.cuda()

        # compute output: pick the model input according to the mode
        if mode == 'RtoD' or mode == 'RtoD_test':
            input_img = img
        elif mode == 'DtoD' or mode == 'DtoD_test':
            input_img = depth
        with torch.no_grad():
            output_depth = model(input_img, istrain=False)
            err_result = compute_errors_Make3D(depth_np, depth, output_depth)
        errors.update(err_result)
        abs_diff_tot.append(err_result[0])
        abs_rel_tot.append(err_result[1])
        ave_log10_tot.append(err_result[2])
        rmse_tot.append(err_result[3])

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()
        logger.valid_bar.update(i + 1)
        if i % args.print_freq == 0:
            logger.valid_writer.write(
                'valid: Time {} Abs Error {:.4f} ({:.4f})'.format(
                    batch_time, errors.val[0], errors.avg[0]))
    logger.valid_bar.update(len(val_loader))

    # Average errors over the `min_len` samples with the smallest abs_diff.
    sorted_abs_diff = sorted(abs_diff_tot)
    #min_len = 72
    # NOTE(review): with min_len == len(sorted_abs_diff) the "min" selection is
    # a no-op (it averages every sample); the commented 72 suggests it was once
    # a best-72 average — confirm intended behavior.
    min_len = (len(sorted_abs_diff))
    print("scene length: ", min_len)
    print("sorted_abs_diff length: ", len(sorted_abs_diff))
    for i in range(min_len):
        # NOTE(review): list.index returns the FIRST match, so duplicate
        # abs_diff values would pick the same companion errors repeatedly.
        sort_idx = abs_diff_tot.index(sorted_abs_diff[i])
        abs_diff_sum += sorted_abs_diff[i]
        abs_rel_sum += abs_rel_tot[sort_idx]
        ave_log10_sum += ave_log10_tot[sort_idx]
        rmse_sum += rmse_tot[sort_idx]
    min_errors_list.append(abs_diff_sum / min_len)
    min_errors_list.append(abs_rel_sum / min_len)
    min_errors_list.append(ave_log10_sum / min_len)
    min_errors_list.append(rmse_sum / min_len)
    min_errors.update(min_errors_list)
    return errors.avg, min_errors.avg, error_names
def main():
    """Evaluate rigid/non-rigid scene-flow decomposition on KITTI 2015.

    Loads pretrained disp/pose/mask/flow networks, runs them over the
    validation flow set, combines camera-motion flow with predicted flow via
    rigidity masks, accumulates EPE metrics, and (optionally) dumps .npy
    arrays, tensorboard images and visualization PNGs to args.output_dir.
    """
    global args
    args = parser.parse_args()
    args.pretrained_disp = Path(args.pretrained_disp)
    args.pretrained_pose = Path(args.pretrained_pose)
    args.pretrained_mask = Path(args.pretrained_mask)
    args.pretrained_flow = Path(args.pretrained_flow)

    # Prepare the output directory tree (path.py-style makedirs_p).
    if args.output_dir is not None:
        args.output_dir = Path(args.output_dir)
        args.output_dir.makedirs_p()
        image_dir = args.output_dir / 'images'
        gt_dir = args.output_dir / 'gt'
        mask_dir = args.output_dir / 'mask'
        viz_dir = args.output_dir / 'viz'
        image_dir.makedirs_p()
        gt_dir.makedirs_p()
        mask_dir.makedirs_p()
        viz_dir.makedirs_p()
    output_writer = SummaryWriter(args.output_dir)

    normalize = custom_transforms.Normalize(mean=[0.5, 0.5, 0.5],
                                            std=[0.5, 0.5, 0.5])
    flow_loader_h, flow_loader_w = 256, 832
    valid_flow_transform = custom_transforms.Compose([
        custom_transforms.Scale(h=flow_loader_h, w=flow_loader_w),
        custom_transforms.ArrayToTensor(), normalize
    ])
    # NOTE(review): val_flow_set is only bound for dataset == "kitti2015";
    # any other value raises NameError below — confirm supported datasets.
    if args.dataset == "kitti2015":
        val_flow_set = ValidationFlow(root=args.kitti_dir,
                                      sequence_length=5,
                                      transform=valid_flow_transform)
    val_loader = torch.utils.data.DataLoader(val_flow_set,
                                             batch_size=1,
                                             shuffle=False,
                                             num_workers=2,
                                             pin_memory=True,
                                             drop_last=True)

    # Instantiate networks by name and load pretrained weights.
    disp_net = getattr(models, args.dispnet)().cuda()
    pose_net = getattr(models, args.posenet)(nb_ref_imgs=4).cuda()
    mask_net = getattr(models, args.masknet)(nb_ref_imgs=4).cuda()
    flow_net = getattr(models, args.flownet)(nlevels=args.nlevels).cuda()

    dispnet_weights = torch.load(args.pretrained_disp)
    posenet_weights = torch.load(args.pretrained_pose)
    masknet_weights = torch.load(args.pretrained_mask)
    flownet_weights = torch.load(args.pretrained_flow)
    disp_net.load_state_dict(dispnet_weights['state_dict'])
    pose_net.load_state_dict(posenet_weights['state_dict'])
    flow_net.load_state_dict(flownet_weights['state_dict'])
    mask_net.load_state_dict(masknet_weights['state_dict'])

    disp_net.eval()
    pose_net.eval()
    mask_net.eval()
    flow_net.eval()

    error_names = [
        'epe_total', 'epe_sp', 'epe_mv', 'Fl', 'epe_total_gt_mask',
        'epe_sp_gt_mask', 'epe_mv_gt_mask', 'Fl_gt_mask'
    ]
    errors = AverageMeter(i=len(error_names))

    for i, (tgt_img, ref_imgs, intrinsics, intrinsics_inv, flow_gt,
            obj_map_gt) in enumerate(tqdm(val_loader)):
        # NOTE(review): volatile=True is from pre-0.4 PyTorch; modern code
        # would wrap this loop in torch.no_grad() instead.
        tgt_img_var = Variable(tgt_img.cuda(), volatile=True)
        ref_imgs_var = [
            Variable(img.cuda(), volatile=True) for img in ref_imgs
        ]
        intrinsics_var = Variable(intrinsics.cuda(), volatile=True)
        intrinsics_inv_var = Variable(intrinsics_inv.cuda(), volatile=True)
        flow_gt_var = Variable(flow_gt.cuda(), volatile=True)
        obj_map_gt_var = Variable(obj_map_gt.cuda(), volatile=True)

        # Forward passes: depth (from disparity), pose, explainability, flow.
        disp = disp_net(tgt_img_var)
        depth = 1 / disp
        pose = pose_net(tgt_img_var, ref_imgs_var)
        explainability_mask = mask_net(tgt_img_var, ref_imgs_var)

        if args.flownet == 'Back2Future':
            flow_fwd, flow_bwd, _ = flow_net(tgt_img_var, ref_imgs_var[1:3])
        else:
            flow_fwd = flow_net(tgt_img_var, ref_imgs_var[2])

        # Camera-motion ("rigid") flow induced by depth + relative pose.
        flow_cam = pose2flow(depth.squeeze(1), pose[:, 2], intrinsics_var,
                             intrinsics_inv_var)
        flow_cam_bwd = pose2flow(depth.squeeze(1), pose[:, 1], intrinsics_var,
                                 intrinsics_inv_var)

        # Rigidity from the explainability masks.
        # NOTE(review): the comparison binds last, so this is
        # (1 - (1-m1)*(1-m2).unsqueeze(1)) > 0.5 — confirm intended precedence.
        rigidity_mask = 1 - (1 - explainability_mask[:, 1]) * (
            1 - explainability_mask[:, 2]).unsqueeze(1) > 0.5
        # Census-style rigidity: predicted flow close to rigid flow in both u,v.
        rigidity_mask_census_soft = (flow_cam - flow_fwd).abs()  #.normalize()
        rigidity_mask_census_u = rigidity_mask_census_soft[:, 0] < args.THRESH
        rigidity_mask_census_v = rigidity_mask_census_soft[:, 1] < args.THRESH
        rigidity_mask_census = (rigidity_mask_census_u).type_as(flow_fwd) * (
            rigidity_mask_census_v).type_as(flow_fwd)

        # Pixel is rigid if either mask says so (soft OR via De Morgan).
        rigidity_mask_combined = 1 - (
            1 - rigidity_mask.type_as(explainability_mask)) * (
                1 - rigidity_mask_census.type_as(explainability_mask))

        obj_map_gt_var_expanded = obj_map_gt_var.unsqueeze(1).type_as(flow_fwd)

        # Compose the final flow: predicted flow in non-rigid regions,
        # camera-motion flow in rigid regions.
        flow_fwd_non_rigid = (rigidity_mask_combined <= args.THRESH).type_as(
            flow_fwd).expand_as(flow_fwd) * flow_fwd
        flow_fwd_rigid = (rigidity_mask_combined > args.THRESH
                          ).type_as(flow_cam).expand_as(flow_cam) * flow_cam
        total_flow = flow_fwd_rigid + flow_fwd_non_rigid
        rigidity_mask = rigidity_mask.type_as(flow_fwd)

        # EPE with the predicted rigidity mask and with the GT object mask.
        _epe_errors = compute_all_epes(
            flow_gt_var, flow_cam, flow_fwd,
            rigidity_mask_combined) + compute_all_epes(
                flow_gt_var, flow_cam, flow_fwd, (1 - obj_map_gt_var_expanded))
        errors.update(_epe_errors)

        tgt_img_np = tgt_img[0].numpy()
        rigidity_mask_combined_np = rigidity_mask_combined.cpu().data[0].numpy(
        )
        gt_mask_np = obj_map_gt[0].numpy()

        # Dump raw arrays for every sample.
        if args.output_dir is not None:
            np.save(image_dir / str(i).zfill(3), tgt_img_np)
            np.save(gt_dir / str(i).zfill(3), gt_mask_np)
            np.save(mask_dir / str(i).zfill(3), rigidity_mask_combined_np)

        # Every 10th sample: tensorboard images + side-by-side PNG panels.
        if (args.output_dir is not None) and i % 10 == 0:
            ind = int(i // 10)
            output_writer.add_image(
                'val Dispnet Output Normalized',
                tensor2array(disp.data[0].cpu(),
                             max_value=None,
                             colormap='bone'), ind)
            output_writer.add_image('val Input',
                                    tensor2array(tgt_img[0].cpu()), i)
            output_writer.add_image(
                'val Total Flow Output',
                flow_to_image(tensor2array(total_flow.data[0].cpu())), ind)
            output_writer.add_image(
                'val Rigid Flow Output',
                flow_to_image(tensor2array(flow_fwd_rigid.data[0].cpu())), ind)
            output_writer.add_image(
                'val Non-rigid Flow Output',
                flow_to_image(tensor2array(flow_fwd_non_rigid.data[0].cpu())),
                ind)
            output_writer.add_image(
                'val Rigidity Mask',
                tensor2array(rigidity_mask.data[0].cpu(),
                             max_value=1,
                             colormap='bone'), ind)
            output_writer.add_image(
                'val Rigidity Mask Census',
                tensor2array(rigidity_mask_census.data[0].cpu(),
                             max_value=1,
                             colormap='bone'), ind)
            output_writer.add_image(
                'val Rigidity Mask Combined',
                tensor2array(rigidity_mask_combined.data[0].cpu(),
                             max_value=1,
                             colormap='bone'), ind)

            tgt_img_viz = tensor2array(tgt_img[0].cpu())
            depth_viz = tensor2array(disp.data[0].cpu(),
                                     max_value=None,
                                     colormap='bone')
            mask_viz = tensor2array(
                rigidity_mask_census_soft.data[0].prod(dim=0).cpu(),
                max_value=1,
                colormap='bone')
            rigid_flow_viz = flow_to_image(tensor2array(
                flow_cam.data[0].cpu()))
            non_rigid_flow_viz = flow_to_image(
                tensor2array(flow_fwd_non_rigid.data[0].cpu()))
            total_flow_viz = flow_to_image(
                tensor2array(total_flow.data[0].cpu()))
            # Two panels: (image, depth, mask) and the three flow variants.
            row1_viz = np.hstack((tgt_img_viz, depth_viz, mask_viz))
            row2_viz = np.hstack(
                (rigid_flow_viz, non_rigid_flow_viz, total_flow_viz))
            row1_viz_im = Image.fromarray((255 * row1_viz).astype('uint8'))
            row2_viz_im = Image.fromarray((row2_viz).astype('uint8'))
            row1_viz_im.save(viz_dir / str(i).zfill(3) + '01.png')
            row2_viz_im.save(viz_dir / str(i).zfill(3) + '02.png')

    print("Results")
    print("\t {:>10}, {:>10}, {:>10}, {:>6}, {:>10}, {:>10}, {:>10}, {:>10} ".
          format(*error_names))
    print(
        "Errors \t {:10.4f}, {:10.4f}, {:10.4f}, {:10.4f}, {:10.4f}, {:10.4f}, {:10.4f}, {:10.4f}"
        .format(*errors.avg))
def train_AE_RtoD(args, model, DtoD_model, criterion_L2, criterion_L1, optimizer, dataset_loader, val_loader, batch_size, n_epochs, lr, logger, train_writer):
    """Train the RGB-to-Depth autoencoder, distilling from a frozen DtoD model.

    The loss is BerHu-style reconstruction + latent feature matching against
    the Depth-to-Depth teacher's encoder features + depth smoothness.  Sample
    images, feature maps and checkpoints are written periodically; filesystem
    output only happens on ranks with (local_rank + 1) % 4 == 0.

    Returns:
        (loss, output_loss, latent_loss) tensors from the last processed batch.
    """
    global n_iter, best_error
    print("Training for %d epochs..." % n_epochs)
    num = 0        # counter for periodically saved prediction batches
    model_num = 0  # counter for saved checkpoints / result images

    # One fixed batch for reproducible visualizations across the run.
    data_iter = iter(dataset_loader)
    depth_fixed, rgb_fixed, _ = next(data_iter)
    depth_fixed = depth_fixed.cuda()
    rgb_fixed = rgb_fixed.cuda()

    # Output directories encode the dataset and learning rate.
    predicted_dirs = './' + args.dataset + '_AE_RtoD_predicted_lr000%d_color_uNet_gen2_nogradf' % (
        lr * 100000)
    result_dirs = './' + args.dataset + '_AE_RtoD_feat_result_lr000%d_color_uNet_gen2_nogradf/out' % (
        lr * 100000)
    result_gt_dirs = './' + args.dataset + '_AE_RtoD_feat_result_lr000%d_color_uNet_gen2_nogradf/gt' % (
        lr * 100000)
    save_dir = './' + args.dataset + '_AE_RtoD_trained_model_lr000%d_color_uNet_gen2_nogradf' % (
        lr * 100000)
    if ((args.local_rank + 1) % 4 == 0):
        if not os.path.exists(predicted_dirs):
            os.makedirs(predicted_dirs)
        if not os.path.exists(result_dirs):
            os.makedirs(result_dirs)
        if not os.path.exists(result_gt_dirs):
            os.makedirs(result_gt_dirs)
        if not os.path.exists(save_dir):
            os.makedirs(save_dir)

    # Per-level feature-map plotting parameters (7 encoder/decoder stages).
    H = depth_fixed.shape[2]
    W = depth_fixed.shape[3]
    num_sample_list = [16, 64, 64, 64, 64, 64, 16]
    figsize_x_list = [14, 16, 10, 10, 10, 16, 14]
    figsize_y_list = [7, 8, 5, 5, 5, 8, 7]
    if args.dataset != 'NYU':
        d_range_list = [4, 2, 2, 4, 2, 2, 4]
        ftmap_height_list = [
            H, int(H / 2), int(H / 8), int(H / 16), int(H / 8), int(H / 2), H
        ]
        ftmap_width_list = [
            W, int(W / 2), int(W / 8), int(W / 16), int(W / 8), int(W / 2), W
        ]
    else:
        d_range_list = [4, 2, 4, 2, 4, 2, 4]
        ftmap_height_list = [
            H, int(H / 2), int(H / 4), int(H / 8), int(H / 16), int(H / 2), H
        ]
        ftmap_width_list = [
            W, int(W / 2), int(W / 4), int(W / 8), int(W / 16), int(W / 2), W
        ]

    # Plain-text loss-curve log files under args.save_path.
    test_loss_dir = Path(args.save_path)
    test_loss_dir_rmse = str(test_loss_dir / 'test_rmse_list.txt')
    test_loss_dir = str(test_loss_dir / 'test_loss_list.txt')
    train_loss_dir = Path(args.save_path)
    train_loss_dir_rmse = str(train_loss_dir / 'train_rmse_list.txt')
    train_loss_dir = str(train_loss_dir / 'train_loss_list.txt')
    loss_list = []
    rmse_list = []
    train_loss_list = []
    train_rmse_list = []
    num_cnt = 0
    train_loss_cnt = 0

    if args.dataset == "KITTI":
        y1, y2 = int(0.40810811 * depth_fixed.size(2)), int(
            0.99189189 * depth_fixed.size(2))
        x1, x2 = int(0.03594771 * depth_fixed.size(3)), int(
            0.96405229 * depth_fixed.size(3))
        ### Crop used by Garg ECCV 2016
        '''
        y1,y2 = int(0.3324324 * depth_fixed.size(2)), int(0.91351351 * depth_fixed.size(2))
        x1,x2 = int(0.0359477 * depth_fixed.size(3)), int(0.96405229 * depth_fixed.size(3))
        ### Crop used by Godard CVPR 2017
        '''
        print(" - valid y range: %d ~ %d" % (y1, y2))
        print(" - valid x range: %d ~ %d" % (x1, x2))

    for epoch in tqdm(range(n_epochs)):
        # crop_mask marks the KITTI evaluation crop; it starts all-False
        # (x != x) and the crop window is set to True.
        if args.dataset == "KITTI":
            crop_mask = depth_fixed != depth_fixed
            #print('crop_mask size: ',crop_mask.size())
            crop_mask[:, :, y1:y2, x1:x2] = 1
        if logger is not None:
            logger.epoch_bar.update(epoch)
            ####################################### one epoch training #############################################
            logger.reset_train_bar()
        batch_time = AverageMeter()
        data_time = AverageMeter()
        losses = AverageMeter(precision=4)

        ################ train mode ####################
        model.train()
        ################################################
        end = time.time()
        if logger is not None:
            logger.train_bar.update(0)

        for i, (gt_data, rgb_data, gt_data_2) in enumerate(dataset_loader):
            # data loading time
            if logger is not None:
                data_time.update(time.time() - end)
            # get the inputs
            inputs = rgb_data
            depths = gt_data
            if args.dataset != "KITTI":
                gt_data_2 = None  # If gt_data_2 is None ==> NYU dataset!
            if gt_data_2 is not None:
                sparse_depths = gt_data_2.cuda()
                sparse_depths = Variable(sparse_depths)
            origin = depths
            inputs = inputs.cuda()
            depths = depths.cuda()
            # wrap them in Variable
            inputs, depths = Variable(inputs), Variable(depths)

            ########################################
            ### Train the AutoEncoder (Generator) ###
            ########################################
            '''AutoEncoder loss'''
            outputs = model(inputs, istrain=False)
            # Teacher (DtoD) features for both GT depth and the prediction.
            # NOTE(review): both feature passes run under no_grad, so
            # latent_loss carries no gradient through the feature extractor —
            # consistent with the '_nogradf' directory suffix; confirm.
            if args.mode != 'RtoD_single':
                with torch.no_grad():
                    ft_map1_tar, ft_map2_tar, ft_map3_tar, ft_map4_tar, _, _, _, _ = DtoD_model(
                        depths, istrain=True)
            if args.mode != 'RtoD_single':
                with torch.no_grad():
                    ft_map1, ft_map2, ft_map3, ft_map4, _, _, _, _ = DtoD_model(
                        outputs, istrain=True)

            # masking valid area (sparse GT > -1, first channel only)
            if gt_data_2 is not None:
                valid_mask = sparse_depths > -1
                valid_mask = valid_mask[:, 0, :, :].unsqueeze(1)
                if (crop_mask.size(0) != valid_mask.size(0)):
                    # last batch may be smaller than the fixed batch
                    crop_mask = crop_mask[0:valid_mask.size(0), :, :, :]

            # BerHu-style reconstruction: L1 below threshold c, scaled L2 above.
            diff = outputs - depths
            diff_abs = torch.abs(diff)
            diff_2 = torch.pow(outputs - depths, 2)
            c = 0.2 * torch.max(diff_abs.detach())
            mask2 = torch.gt(diff_abs.detach(), c)
            diff_abs[mask2] = (diff_2[mask2] + (c * c)) / (2 * c)
            # Down-weight pixels outside the crop, and invalid pixels inside it.
            if gt_data_2 is not None:
                diff_abs[~crop_mask] = 0.1 * diff_abs[~crop_mask]
                diff_abs[crop_mask
                         & (~valid_mask)] = 0.3 * diff_abs[crop_mask
                                                           & (~valid_mask)]
            output_loss = 3 * diff_abs.mean()
            diff2_clone = diff_2.clone().detach()
            rmse_loss = torch.sqrt(diff2_clone.mean())
            ################# BerHu Loss #########################

            # Latent feature-matching loss against the teacher features.
            latent_loss = torch.tensor(0.).cuda()
            if args.mode != 'RtoD_single':
                latent1 = criterion_L2(ft_map1, ft_map1_tar.detach())
                latent2 = 2.5 * criterion_L2(ft_map2, ft_map2_tar.detach())
                latent3 = 14 * criterion_L2(ft_map3, ft_map3_tar.detach())
                latent4 = 12 * criterion_L2(ft_map4, ft_map4_tar.detach())
                #print("latent1 : ",latent1.item(),"latent2 : ",latent2.item(),"latent3 : ",latent3.item(),"latent4 : ",latent4.item())
                latent_loss = 1.5 * (
                    (latent1 + latent2 + latent3 + latent4) / 4)
            ################# Latent Loss #########################

            #gradient_loss = imgrad_loss(outputs, depths) ## for kitti
            ##gradient_loss = 3.5* imgrad_loss(outputs, depths) ## for NYU
            ################# gradient loss #######################
            '''
            grad_latent_loss = torch.tensor(0.).cuda()
            if args.mode != 'RtoD_single':
                grad_latent1 = imgrad_loss(ft_map1, ft_map1_tar.detach())
                grad_latent2 = 1.5*imgrad_loss(ft_map2, ft_map2_tar.detach())
                grad_latent3 = 4*imgrad_loss(ft_map3, ft_map3_tar.detach()) ##for kitti
                ##grad_latent3 = 2*imgrad_loss(ft_map3, ft_map3_tar.detach()) ##for NYU
                grad_latent4 = 2.3*imgrad_loss(ft_map4, ft_map4_tar.detach())
                #print("g_latent1 : ",grad_latent1.item(),"g_latent2 : ",grad_latent2.item(),"g_latent3 : ",grad_latent3.item(),"g_latent4 : ",grad_latent4.item())
                ##grad_latent_loss = ((grad_latent1 + grad_latent2 + grad_latent3 + grad_latent4)/4.0) ## for kitti
                grad_latent_loss = ((grad_latent1 + grad_latent2 + grad_latent3 + grad_latent4)/4.0) ## for NYU
            ################# gradient latent loss ################
            grad_loss = (gradient_loss + grad_latent_loss)
            '''
            ################# gradient total loss ################

            depth_smoothness_tot = 0.1 * depth_smoothness(outputs, inputs)
            depth_smoothness_loss = torch.mean(torch.abs(depth_smoothness_tot))
            ################# smoothness loss ######################

            #loss = output_loss + latent_loss + grad_loss + depth_smoothness_loss
            loss = output_loss + latent_loss + depth_smoothness_loss

            if logger is not None:
                if i > 0 and n_iter % args.print_freq == 0:
                    train_writer.add_scalar('output_loss', output_loss.item(),
                                            n_iter)
                    train_writer.add_scalar('latent_loss', latent_loss.item(),
                                            n_iter)
                    train_writer.add_scalar('total_loss', loss.item(), n_iter)
                # record loss and EPE
                losses.update(loss.item(), args.batch_size)

            # zero the parameter gradients and backward & optimize
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # measure elapsed time (logger-driven bookkeeping and CSV logging)
            if logger is not None:
                batch_time.update(time.time() - end)
                end = time.time()
                with open(args.save_path / args.log_full, 'a') as csvfile:
                    writer = csv.writer(csvfile, delimiter='\t')
                    writer.writerow(
                        [loss.item(), output_loss.item(), latent_loss.item()])
                logger.train_bar.update(i + 1)
                if i % args.print_freq == 0:
                    logger.train_writer.write(
                        'Train: Time {} Data {} Loss {}'.format(
                            batch_time, data_time, losses))
                n_iter += 1
                if i >= args.epoch_size - 1:
                    break

            ### KITTI's learning decay ###
            # NOTE(review): this decays lr every 2200 *iterations* once past
            # epoch 2; the local `lr` also renames the checkpoint directories'
            # implied rate — confirm intended schedule.
            if (epoch > 2):
                if ((i + 1) % 2200 == 0):
                    if (lr < 0.00002):
                        lr -= (lr / 100)
                    else:
                        lr -= (lr / 60)
                    for param_group in optimizer.param_groups:
                        param_group['lr'] = lr
                    print('Decayed learning rates, lr: {}'.format(lr))
            ### NYU's learning decay ###
            '''
            if (epoch>6):
                if ((i+1) % 1900 == 0):
                    if (lr < 0.00002):
                        lr -= (lr / 200)
                    else :
                        lr -= (lr / 40)
                    for param_group in optimizer.param_groups:
                        param_group['lr'] = lr
                    print ('Decayed learning rates, lr: {}'.format(lr))
            '''

            # Periodic console reporting (every 100 iterations, one rank only).
            if ((i + 1) % 100 == 0):
                if ((args.local_rank + 1) % 4 == 0):
                    print("epoch: %d, %d/%d" %
                          (epoch + 1, i + 1, args.epoch_size))
                    if args.mode == 'RtoD':
                        print(
                            "total_loss: %5f, output_loss: %5f, smoothness_loss: %5f, latent_loss: %5f"
                            % (loss.item(), output_loss.item(),
                               depth_smoothness_loss.item(),
                               latent_loss.item()))
                        #print("grad_loss: %5f, gradient_loss: %5f, grad_latent_loss: %5f"%(grad_loss.item(), gradient_loss.item(), grad_latent_loss.item()))
                    elif args.mode == 'RtoD_single':
                        print(
                            "total_loss: %5f, output_loss: %5f, smoothness_loss: %5f"
                            % (loss.item(), output_loss.item(),
                               depth_smoothness_loss.item()))
                        # NOTE(review): grad_loss is never defined (its
                        # computation is commented out above) — this line
                        # raises NameError in 'RtoD_single' mode; confirm.
                        print("grad_loss: %5f" % (grad_loss.item()))
                    '''
                    total_loss = loss.item()
                    rmse_loss = rmse_loss.item()
                    loss_pdf = "train_loss.pdf"
                    rmse_pdf = "train_rmse.pdf"
                    train_loss_cnt = train_loss_cnt + 1
                    all_plot(args.save_path,total_loss, rmse_loss, train_loss_list, train_rmse_list, train_loss_dir,train_loss_dir_rmse,loss_pdf, rmse_pdf, train_loss_cnt,True)
                    '''
                    print("")

            # Periodic sample dump on the fixed batch (every 700 iterations).
            if ((i + 1) % 700 == 0):
                save_image_batch(model, rgb_fixed, depth_fixed,
                                 predicted_dirs, num)
                num = num + 1
            # Periodic checkpoint + result images (every 700 iterations).
            if ((i + 1) % 700 == 0):
                '''
                test_loss, rmse_test_loss = validate_in_test(args, val_loader, model,DtoD_model,n_epochs, logger,args.mode, crop_mask,criterion_L2)
                loss_pdf = "test_loss.pdf"
                rmse_pdf = "test_rmse.pdf"
                num_cnt = num_cnt + 1
                if((args.local_rank + 1)%4 == 0):
                    print('%d th test_set_loss : %.4f'%(num_cnt,test_loss))
                all_plot(args.save_path,test_loss, rmse_test_loss, loss_list, rmse_list, test_loss_dir,test_loss_dir_rmse,loss_pdf, rmse_pdf, num_cnt,False)
                '''
                if ((args.local_rank + 1) % 4 == 0):
                    output = outputs.cpu().detach().numpy()
                    save_image_tensor(output, result_dirs,
                                      'output_depth_%d.png' % (model_num + 1))
                    save_image_tensor(origin, result_gt_dirs,
                                      'origin_depth_%d.png' % (model_num + 1))
                    torch.save(
                        model.state_dict(), save_dir +
                        '/epoch_%d_AE_depth_loss_%.4f.pkl' %
                        (model_num + 1, loss))
                    model_num = model_num + 1

        if logger is not None:
            ######################################################################################################
            logger.train_writer.write(' * Avg Loss : {:.3f}'.format(
                losses.avg[0]))
            ################################ evalutating on validation set ########################################
            logger.reset_valid_bar()
            errors, error_names = validate(args, val_loader, model, epoch,
                                           logger, args.mode)
            ################# training log ############################
            error_string = ', '.join(
                '{} : {:.3f}'.format(name, error)
                for name, error in zip(error_names, errors))
            logger.valid_writer.write(' * Avg {}'.format(error_string))
            for error, name in zip(errors, error_names):
                train_writer.add_scalar(name, error, epoch)

            # Up to you to chose the most relevant error to measure your model's performance, careful some measures are to maximize (such as a1,a2,a3)
            decisive_error = errors[1]
            if best_error < 0:
                best_error = decisive_error
            # remember lowest error and save checkpoint
            is_best = decisive_error < best_error
            best_error = min(best_error, decisive_error)
            if is_best:
                torch.save(model, args.save_path / 'AE_RtoD_model_best.pth.tar')
            with open(args.save_path / args.log_summary, 'a') as csvfile:
                writer = csv.writer(csvfile, delimiter='\t')
                writer.writerow([loss, decisive_error])
            ###########################################################

        ##if ((epoch+1) % 2) and ((epoch+1) > n_epochs/2):
        #if (((epoch+1) % 2) and epoch>5):
        '''
        if epoch % 1 == 0:
            #print('\n','epoch: ',epoch+1,' loss: ',loss.item())
            print('output_loss: ',output_loss.item(),' latent_loss: ',latent_loss.item())
            output = outputs.cpu().detach().numpy()
            save_image_tensor(output,result_dirs,'output_depth_%d.png'%(model_num+1))
            save_image_tensor(origin,result_gt_dirs,'origin_depth_%d.png'%(model_num+1))
            torch.save(model.state_dict(), save_dir+'/epoch_%d_AE_depth_loss_%.4f.pkl' %(model_num+1,loss))
            model_num = model_num + 1
        '''
        #####################################################################################
        ################### Extracting feature_map ##########################################
        # Every 10th epoch (and epoch 0): dump encoder/decoder feature maps of
        # the student (RtoD), the teacher on GT depth, and the teacher on the
        # student's prediction.
        if ((epoch + 1) % 10 == 0 or epoch == 0):
            if ((args.local_rank + 1) % 4 == 0):
                with torch.no_grad():
                    rft1, rft2, rft3, rft4, rft5, rft6, rft7, rout = model(
                        inputs, istrain=True)
                    ft1_gt, ft2_gt, ft3_gt, ft4_gt, ft5_gt, ft6_gt, ft7_gt, _ = DtoD_model(
                        depths, istrain=True)
                    dft1, dft2, dft3, dft4, dft5, dft6, dft7, _ = DtoD_model(
                        rout, istrain=True)
                rftmap_list = [rft1, rft2, rft3, rft4, rft5, rft6, rft7]
                gt_ftmap_list = [
                    ft1_gt, ft2_gt, ft3_gt, ft4_gt, ft5_gt, ft6_gt, ft7_gt
                ]
                dftmap_list = [dft1, dft2, dft3, dft4, dft5, dft6, dft7]
                result_dir = result_dirs + '/epoch_%d_depth' % (epoch + 1)
                if not os.path.exists(result_dir):
                    os.makedirs(result_dir)
                for kk in range(len(rftmap_list)):
                    ftmap_extract(args, num_sample_list[kk],
                                  figsize_x_list[kk], figsize_y_list[kk],
                                  d_range_list[kk], rftmap_list[kk],
                                  ftmap_height_list[kk], ftmap_width_list[kk],
                                  result_dir + '/RtoD', epoch, kk + 1)
                    ftmap_extract(args, num_sample_list[kk],
                                  figsize_x_list[kk], figsize_y_list[kk],
                                  d_range_list[kk], gt_ftmap_list[kk],
                                  ftmap_height_list[kk], ftmap_width_list[kk],
                                  result_dir + '/DtoD_gt', epoch, kk + 1)
                    ftmap_extract(args, num_sample_list[kk],
                                  figsize_x_list[kk], figsize_y_list[kk],
                                  d_range_list[kk], dftmap_list[kk],
                                  ftmap_height_list[kk], ftmap_width_list[kk],
                                  result_dir + '/DtoD', epoch, kk + 1)
                print("featmap save is finished")
                inputs_ = inputs.cpu().detach().numpy()
                save_image_tensor(origin, result_dir, 'origin_depth.png')
                save_image_tensor(inputs_, result_dir, 'origin_input.png')
                print("origin_depth save is finished")
                print("origin_image save is finished")
        #####################################################################################
        #####################################################################################

    if logger is not None:
        logger.epoch_bar.finish()
    return loss, output_loss, latent_loss
def adjust_shifts(args, train_set, adjust_loader, pose_exp_net, epoch, logger, train_writer):
    """Re-estimate per-sample frame shifts from predicted poses.

    Runs the pose network over ``adjust_loader``, converts the predicted
    translations into displacement magnitudes, and lets ``train_set`` reset
    its frame shifts from them. Pose-coefficient histograms are logged to
    ``train_writer``. Returns the running average of the new shifts.
    """
    batch_time = AverageMeter()
    data_time = AverageMeter()
    new_shifts = AverageMeter(args.sequence_length - 1)
    pose_exp_net.train()

    # Buffer holding every predicted 6-DoF pose except the final
    # (possibly smaller) batch.
    samples_per_batch = args.batch_size * (args.sequence_length - 1)
    poses = np.zeros(((len(adjust_loader) - 1) * samples_per_batch, 6))

    end = time.time()

    for batch_idx, (indices, tgt_img, ref_imgs, intrinsics, intrinsics_inv) in enumerate(adjust_loader):
        # time spent loading this batch
        data_time.update(time.time() - end)

        tgt_img_var = Variable(tgt_img.cuda())
        ref_imgs_var = [Variable(img.cuda()) for img in ref_imgs]

        # forward pass; only the pose output is used here
        explainability_mask, pose_batch = pose_exp_net(tgt_img_var, ref_imgs_var)

        # stash poses of all but the last batch into the flat buffer
        if batch_idx < len(adjust_loader) - 1:
            begin = batch_idx * samples_per_batch
            poses[begin:begin + samples_per_batch] = pose_batch.data.cpu().view(-1, 6).numpy()

        # derive displacement norms and have the dataset recompute shifts
        for index, pose in zip(indices, pose_batch):
            displacements = pose[:, :3].norm(p=2, dim=1).data.cpu().numpy()
            train_set.reset_shifts(index, displacements)
            new_shifts.update(train_set.samples[index]['ref_imgs'])

        # time spent processing this batch
        batch_time.update(time.time() - end)
        end = time.time()

        logger.train_bar.update(batch_idx)
        if batch_idx % args.print_freq == 0:
            logger.train_writer.write('Adjustement:'
                                      'Time {} Data {} shifts {}'.format(batch_time, data_time, new_shifts))

    # histogram every pose coefficient over the whole pass
    prefix = 'train poses'
    coeffs_names = ['tx', 'ty', 'tz']
    rotation_suffixes = {'euler': ['rx', 'ry', 'rz'],
                         'quat': ['qx', 'qy', 'qz']}
    coeffs_names.extend(rotation_suffixes.get(args.rotation_mode, []))
    for col in range(poses.shape[1]):
        train_writer.add_histogram('{} {}'.format(prefix, coeffs_names[col]), poses[:, col], epoch)

    return new_shifts.avg
def validate(val_loader, model, log, fgsm=False, eps=4, rand_init=False, mean=None, std=None):
    '''Evaluate a trained model, optionally under an FGSM attack.

    When ``fgsm`` is True each batch is perturbed with the fast gradient sign
    method (strength ``eps``/255 in un-normalized pixel space; ``mean``/``std``
    undo and redo input normalization) before evaluation.

    Relies on module-level ``args``, ``criterion``, ``accuracy`` and
    ``print_log``.  NOTE(review): ``rand_init`` is accepted but never used.

    Returns (average top-1 accuracy, average loss).
    '''
    losses = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()

    # switch to evaluate mode
    model.eval()

    for i, (input, target) in enumerate(val_loader):
        if args.use_cuda:
            input = input.cuda()
            target = target.cuda()

        # check FGSM for adversarial training
        if fgsm:
            # forward/backward once to get the loss gradient w.r.t. the input
            input_var = Variable(input, requires_grad=True)
            target_var = Variable(target)
            # optimizer over the input only; used here just for zero_grad()
            optimizer_input = torch.optim.SGD([input_var], lr=0.1)
            output = model(input_var)
            loss = criterion(output, target_var)
            optimizer_input.zero_grad()
            loss.backward()
            sign_data_grad = input_var.grad.sign()
            # un-normalize, add the eps-scaled sign perturbation, clamp to
            # the valid pixel range, then re-normalize
            input = input * std + mean + eps / 255. * sign_data_grad
            input = torch.clamp(input, 0, 1)
            input = (input - mean) / std

        with torch.no_grad():
            input_var = Variable(input)
            target_var = Variable(target)

            # compute output
            output = model(input_var)
            loss = criterion(output, target_var)

            # measure accuracy and record loss
            prec1, prec5 = accuracy(output.data, target, topk=(1, 5))
            losses.update(loss.item(), input.size(0))
            top1.update(prec1.item(), input.size(0))
            top5.update(prec5.item(), input.size(0))

    if fgsm:
        print_log('Attack (eps : {}) Prec@1 {top1.avg:.2f}'.format(eps, top1=top1), log)
    else:
        print_log(' **Test** Prec@1 {top1.avg:.3f} Prec@5 {top5.avg:.3f} Error@1 {error1:.3f} Loss: {losses.avg:.3f} '.format(top1=top1, top5=top5, error1=100 - top1.avg, losses=losses), log)

    return top1.avg, losses.avg
def train(args, train_loader, disp_net, pose_net, optimizer, epoch_size, logger, train_writer):
    """Train disp_net and pose_net for one epoch (at most ``epoch_size`` batches).

    Loss = w1 * photometric + w2 * smoothness + w3 * geometry-consistency,
    with weights taken from ``args``.  Scalars are logged to ``train_writer``
    every ``args.print_freq`` global steps and per-batch losses are appended
    to the CSV at ``args.save_path / args.log_full``.

    Uses the module-level globals ``n_iter`` (global step counter, advanced
    here) and ``device``.  Returns the running average of the total loss.
    """
    global n_iter, device
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter(precision=4)
    w1, w2, w3 = args.photo_loss_weight, args.smooth_loss_weight, args.geometry_consistency_weight

    # switch to train mode
    disp_net.train()
    pose_net.train()

    end = time.time()
    logger.train_bar.update(0)

    for i, (tgt_img, ref_imgs, intrinsics, intrinsics_inv) in enumerate(train_loader):
        log_losses = i > 0 and n_iter % args.print_freq == 0

        # measure data loading time
        data_time.update(time.time() - end)
        tgt_img = tgt_img.to(device)
        ref_imgs = [img.to(device) for img in ref_imgs]
        intrinsics = intrinsics.to(device)

        # compute output
        tgt_depth, ref_depths = compute_depth(disp_net, tgt_img, ref_imgs)
        poses, poses_inv = compute_pose_with_inv(pose_net, tgt_img, ref_imgs, intrinsics)
        #if poses is None:

        loss_1, loss_3 = compute_photo_and_geometry_loss(tgt_img, ref_imgs, intrinsics, tgt_depth, ref_depths, poses, poses_inv, args.num_scales, args.with_ssim, args.with_mask, args.with_auto_mask, args.padding_mode)

        loss_2 = compute_smooth_loss(tgt_depth, tgt_img, ref_depths, ref_imgs)

        loss = w1 * loss_1 + w2 * loss_2 + w3 * loss_3

        if log_losses:
            train_writer.add_scalar('photometric_error', loss_1.item(), n_iter)
            train_writer.add_scalar('disparity_smoothness_loss', loss_2.item(), n_iter)
            train_writer.add_scalar('geometry_consistency_loss', loss_3.item(), n_iter)
            train_writer.add_scalar('total_loss', loss.item(), n_iter)

        # record loss and EPE
        losses.update(loss.item(), args.batch_size)

        # compute gradient and do Adam step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        # NOTE(review): the CSV file is re-opened every iteration; matches the
        # other train loops in this file, kept for consistency.
        with open(args.save_path / args.log_full, 'a') as csvfile:
            writer = csv.writer(csvfile, delimiter='\t')
            writer.writerow([loss.item(), loss_1.item(), loss_2.item(), loss_3.item()])
        logger.train_bar.update(i + 1)
        if i % args.print_freq == 0:
            logger.train_writer.write('Train: Time {} Data {} Loss {}'.format(batch_time, data_time, losses))
        if i >= epoch_size - 1:
            break

        n_iter += 1

    return losses.avg[0]
def train(args, train_loader, mvdnet, depth_cons, cons_loss_, optimizer, epoch_size, train_writer, epoch):
    """Train either the depth-consistency head or the MVD network for one epoch.

    When ``args.train_cons`` is set, ``mvdnet`` is frozen (run under
    ``torch.no_grad``) and only ``depth_cons`` is trained, with an extra
    consistency term from ``cons_loss_``.  Otherwise ``mvdnet`` itself is
    trained on smooth-L1 depth and normal-map losses.

    Total loss = d_weight * depth + n_weight * normal + c_weight * consistency.
    Advances the module-level global step ``n_iter``; returns the running
    average of the total loss.
    """
    global n_iter
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter(precision=4)
    d_losses = AverageMeter(precision=4)
    nmap_losses = AverageMeter(precision=4)
    cons_losses = AverageMeter(precision=4)

    # switch to training mode
    if args.train_cons:
        depth_cons.train()
    else:
        mvdnet.train()

    print("Training")
    end = time.time()

    for i, (tgt_img, ref_imgs, gt_nmap, ref_poses, intrinsics, intrinsics_inv, tgt_depth) in enumerate(train_loader):
        # measure data loading time
        data_time.update(time.time() - end)
        tgt_img_var = Variable(tgt_img.cuda())
        ref_imgs_var = [Variable(img.cuda()) for img in ref_imgs]
        gt_nmap_var = Variable(gt_nmap.cuda())
        ref_poses_var = [Variable(pose.cuda()) for pose in ref_poses]
        intrinsics_var = Variable(intrinsics.cuda())
        intrinsics_inv_var = Variable(intrinsics_inv.cuda())
        # NOTE(review): the trailing .cuda() is redundant (tensor already on GPU)
        tgt_depth_var = Variable(tgt_depth.cuda()).cuda()

        # compute output
        pose = torch.cat(ref_poses_var, 1)

        # get mask: keep depths within [mindepth, nlabel*mindepth] and drop
        # NaNs (x == x is False only for NaN)
        mask = (tgt_depth_var <= args.nlabel * args.mindepth) & (tgt_depth_var >= args.mindepth) & (tgt_depth_var == tgt_depth_var)
        mask.detach_()
        # NOTE(review): idiomatic form would be `if not mask.any():`
        if mask.any() == 0:
            continue

        if args.train_cons:
            # backbone frozen: only depth_cons receives gradients below
            with torch.no_grad():
                outputs = mvdnet(tgt_img_var, ref_imgs_var, pose, intrinsics_var, intrinsics_inv_var)
                output_depth1 = outputs[0]
                nmap1 = outputs[1]
        else:
            outputs = mvdnet(tgt_img_var, ref_imgs_var, pose, intrinsics_var, intrinsics_inv_var)
            output_depth1 = outputs[1]
            nmap1 = outputs[2]

        if args.train_cons:
            # depth_cons refines depth + normals; channel 0 is depth, 1: are normals
            outputs = depth_cons(output_depth1, nmap1)
            nmap = outputs[:, 1:]
            depths = [outputs[:, 0]]
        else:
            nmap = nmap1.permute(0, 3, 1, 2)
            depths = [output_depth1.squeeze(1)]

        loss = 0.
        d_loss = 0.
        nmap_loss = 0.
        cons_loss = 0.

        # depth loss over all predicted depth maps (currently a single one)
        for l, depth in enumerate(depths):
            output = torch.squeeze(depth, 1)
            d_loss = d_loss + F.smooth_l1_loss(output[mask], tgt_depth_var[mask])

        # normal-map loss on the same validity mask expanded to 3 channels
        n_mask = mask.unsqueeze(1).expand(-1, 3, -1, -1)
        nmap_loss = nmap_loss + F.smooth_l1_loss(nmap[n_mask], gt_nmap_var[n_mask])

        if args.train_cons:
            cons_loss = cons_loss + cons_loss_(depths[-1].unsqueeze(1), tgt_depth_var.unsqueeze(1), nmap.clone(), intrinsics_var, mask.unsqueeze(1))
            cons_losses.update(cons_loss.item(), args.batch_size)

        loss = loss + args.d_weight * d_loss + args.n_weight * nmap_loss + args.c_weight * cons_loss

        if i > 0 and n_iter % args.print_freq == 0:
            train_writer.add_scalar('total_loss', loss.item(), n_iter)

        # record loss and EPE
        losses.update(loss.item(), args.batch_size)
        d_losses.update(d_loss.item(), args.batch_size)
        nmap_losses.update(nmap_loss.item(), args.batch_size)

        # compute gradient and do Adam step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()
        with open(args.save_path / args.log_full, 'a') as csvfile:
            writer = csv.writer(csvfile, delimiter='\t')
            writer.writerow([loss.item()])
        if i % args.print_freq == 0:
            # NOTE(review): missing space before "Iter" in this format string;
            # left untouched to preserve log output byte-for-byte
            print('Train: Time {} Data {} Loss {} NmapLoss {} DLoss {} ConsLoss {}Iter {}/{} Epoch {}/{}'.format(batch_time, data_time, losses, nmap_losses, d_losses, cons_losses, i, len(train_loader), epoch, args.epochs))
        if i >= epoch_size - 1:
            break

        n_iter += 1
    return losses.avg[0]
def validate_without_gt(args, val_loader, disp_net, pose_net, epoch, logger, output_writers=[]):
    """Validate disp_net/pose_net without ground truth, using photometric losses.

    Computes photometric, smoothness and geometry-consistency losses on the
    validation set; the reported "total" loss is the photometric term alone.
    Images for the first ``len(output_writers)`` batches are logged to
    tensorboard.  Uses the module-level ``device``.

    Returns (losses.avg, loss names).
    """
    global device
    batch_time = AverageMeter()
    losses = AverageMeter(i=4, precision=4)
    log_outputs = len(output_writers) > 0

    # switch to evaluate mode
    disp_net.eval()
    pose_net.eval()

    end = time.time()
    logger.valid_bar.update(0)
    for i, (tgt_img, ref_imgs, intrinsics, intrinsics_inv) in enumerate(val_loader):
        tgt_img = tgt_img.to(device)
        ref_imgs = [img.to(device) for img in ref_imgs]
        intrinsics = intrinsics.to(device)
        intrinsics_inv = intrinsics_inv.to(device)

        # compute output: single-scale depth wrapped in a list to match the
        # multi-scale interface of the loss functions
        tgt_depth = [1 / disp_net(tgt_img)]
        ref_depths = []
        for ref_img in ref_imgs:
            ref_depth = [1 / disp_net(ref_img)]
            ref_depths.append(ref_depth)

        if log_outputs and i < len(output_writers):
            if epoch == 0:
                output_writers[i].add_image('val Input', tensor2array(tgt_img[0]), 0)

            output_writers[i].add_image('val Dispnet Output Normalized', tensor2array(1 / tgt_depth[0][0], max_value=None, colormap='magma'), epoch)
            output_writers[i].add_image('val Depth Output', tensor2array(tgt_depth[0][0], max_value=10), epoch)

        poses, poses_inv = compute_pose_with_inv(pose_net, tgt_img, ref_imgs, intrinsics)

        # auto-mask disabled (False) during validation
        loss_1, loss_3 = compute_photo_and_geometry_loss(tgt_img, ref_imgs, intrinsics, tgt_depth, ref_depths, poses, poses_inv, args.num_scales, args.with_ssim, args.with_mask, False, args.padding_mode)

        loss_2 = compute_smooth_loss(tgt_depth, tgt_img, ref_depths, ref_imgs)

        loss_1 = loss_1.item()
        loss_2 = loss_2.item()
        loss_3 = loss_3.item()

        # the reported validation "total" is the photometric loss only
        loss = loss_1
        losses.update([loss, loss_1, loss_2, loss_3])

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()
        logger.valid_bar.update(i + 1)
        if i % args.print_freq == 0:
            logger.valid_writer.write('valid: Time {} Loss {}'.format(batch_time, losses))

    logger.valid_bar.update(len(val_loader))
    return losses.avg, ['Total loss', 'Photo loss', 'Smooth loss', 'Consistency loss']
def main():
    """Entry point: supervised depth training with ground-truth depth.

    Builds datasets/loaders, creates and optionally restores disp_net and the
    optimizer, then alternates train_depth_gt / validate_depth_with_gt per
    epoch, logging to tensorboard and saving checkpoints (best = lowest
    average training loss).  Reads configuration from the module-level
    ``global_vars_dict['args']``.
    """
    global global_vars_dict
    args = global_vars_dict['args']
    best_error = -1  # sentinel: best model not chosen yet

    # mkdir: checkpoints/<dataset-stem>/<timestamp>
    timestamp = datetime.datetime.now().strftime("%m-%d-%H:%M")
    args.save_path = Path('checkpoints') / Path(args.data_dir).stem / timestamp
    args.save_path.makedirs_p()
    torch.manual_seed(args.seed)
    if args.alternating:
        args.alternating_flags = np.array([False, False, True])

    # make writers
    tb_writer = SummaryWriter(args.save_path)

    # Data loading code and transpose
    # NOTE(review): if data_normalization is neither 'global' nor 'local',
    # `normalize` is unbound and the transform construction below fails
    if args.data_normalization == 'global':
        normalize = custom_transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
    elif args.data_normalization == 'local':
        normalize = custom_transforms.NormalizeLocally()

    valid_transform = custom_transforms.Compose([custom_transforms.ArrayToTensor(), normalize])

    print("=> fetching scenes in '{}'".format(args.data_dir))

    train_transform = custom_transforms.Compose([
        #custom_transforms.RandomRotate(),
        custom_transforms.RandomHorizontalFlip(),
        custom_transforms.RandomScaleCrop(),
        custom_transforms.ArrayToTensor(), normalize
    ])

    # train set: only one loader is built
    from datasets.sequence_mc import SequenceFolder
    train_set = SequenceFolder(  # mc data folder
        args.data_dir,
        transform=train_transform,
        seed=args.seed,
        train=True,
        sequence_length=args.sequence_length,  # 5
        target_transform=None,
        depth_format='png')

    train_loader = torch.utils.data.DataLoader(train_set, batch_size=args.batch_size, shuffle=True, num_workers=args.workers, pin_memory=True, drop_last=True)

    if args.epoch_size == 0:
        args.epoch_size = len(train_loader)

    # val sets/loaders are built one by one
    #if args.val_with_depth_gt:
    from datasets.validation_folders2 import ValidationSet
    val_set_with_depth_gt = ValidationSet(args.data_dir, transform=valid_transform, depth_format='png')
    val_loader_depth = torch.utils.data.DataLoader(val_set_with_depth_gt, batch_size=args.batch_size, shuffle=False, num_workers=args.workers, pin_memory=True, drop_last=True)

    print('{} samples found in {} train scenes'.format(len(train_set), len(train_set.scenes)))

    #1 create model
    print("=> creating model")
    #1.1 disp_net
    disp_net = getattr(models, args.dispnet)().cuda()
    output_exp = True  #args.mask_loss_weight > 0
    if args.pretrained_disp:
        print("=> using pre-trained weights from {}".format(args.pretrained_disp))
        weights = torch.load(args.pretrained_disp)
        disp_net.load_state_dict(weights['state_dict'])
    else:
        disp_net.init_weights()

    if args.resume:
        print("=> resuming from checkpoint")
        dispnet_weights = torch.load(args.save_path / 'dispnet_checkpoint.pth.tar')
        disp_net.load_state_dict(dispnet_weights['state_dict'])

    cudnn.benchmark = True
    disp_net = torch.nn.DataParallel(disp_net)

    print('=> setting adam solver')

    parameters = chain(disp_net.parameters())
    optimizer = torch.optim.Adam(parameters, args.lr, betas=(args.momentum, args.beta), weight_decay=args.weight_decay)

    if args.resume and (args.save_path / 'optimizer_checkpoint.pth.tar').exists():
        print("=> loading optimizer from checkpoint")
        optimizer_weights = torch.load(args.save_path / 'optimizer_checkpoint.pth.tar')
        optimizer.load_state_dict(optimizer_weights['state_dict'])

    #
    if args.log_terminal:
        logger = TermLogger(n_epochs=args.epochs, train_size=min(len(train_loader), args.epoch_size), valid_size=len(val_loader_depth))
        logger.reset_epoch_bar()
    else:
        logger = None

    # pre-evaluation setup
    criterion_train = MaskedL1Loss().to(device)  # L1 loss is easier to optimize
    criterion_val = ComputeErrors().to(device)

    #depth_error_names,depth_errors = validate_depth_with_gt(val_loader_depth, disp_net,criterion=criterion_val, epoch=0, logger=logger,tb_writer=tb_writer,global_vars_dict=global_vars_dict)

    #logger.reset_epoch_bar()
    # logger.epoch_logger_update(epoch=0,time=0,names=depth_error_names,values=depth_errors)
    epoch_time = AverageMeter()
    end = time.time()

    #3. main cycle
    for epoch in range(1, args.epochs):  # epoch 0 was already evaluated before entering the loop
        # NOTE(review): logger may be None when log_terminal is off — these two
        # calls assume it is set; confirm against run configuration
        logger.reset_train_bar()
        logger.reset_valid_bar()
        errors = [0]
        error_names = ['no error names depth']

        #3.2 train for one epoch---------
        loss_names, losses = train_depth_gt(train_loader=train_loader, disp_net=disp_net, optimizer=optimizer, criterion=criterion_train, logger=logger, train_writer=tb_writer, global_vars_dict=global_vars_dict)

        #3.3 evaluate on validation set-----
        depth_error_names, depth_errors = validate_depth_with_gt(val_loader=val_loader_depth, disp_net=disp_net, criterion=criterion_val, epoch=epoch, logger=logger, tb_writer=tb_writer, global_vars_dict=global_vars_dict)

        epoch_time.update(time.time() - end)
        end = time.time()

        #3.5 log_terminal
        #if args.log_terminal:
        if args.log_terminal:
            logger.epoch_logger_update(epoch=epoch, time=epoch_time, names=depth_error_names, values=depth_errors)

        # tensorboard scalars
        # train loss
        for loss_name, loss in zip(loss_names, losses.avg):
            tb_writer.add_scalar('train/' + loss_name, loss, epoch)
        # val_with_gt loss
        for name, error in zip(depth_error_names, depth_errors.avg):
            tb_writer.add_scalar('val/' + name, error, epoch)

        #3.6 save model and remember lowest error and save checkpoint
        total_loss = losses.avg[0]
        if best_error < 0:
            # first epoch: initialize the best error so is_best is True below
            best_error = total_loss

        is_best = total_loss <= best_error
        best_error = min(best_error, total_loss)
        save_checkpoint(args.save_path, {
            'epoch': epoch + 1,
            'state_dict': disp_net.module.state_dict()
        }, {
            'epoch': epoch + 1,
            'state_dict': None
        }, {
            'epoch': epoch + 1,
            'state_dict': None
        }, {
            'epoch': epoch + 1,
            'state_dict': None
        }, is_best)

    if args.log_terminal:
        logger.epoch_bar.finish()
def train(args, train_loader, bio_net, optimizer, epoch_size, logger):
    """Train ``bio_net`` to regress ``value`` from ``sample`` for one epoch.

    Iterates at most ``epoch_size`` batches from ``train_loader``, optimizing
    the mean absolute error between the network output and the target value.
    Uses the module-level globals ``n_iter`` (global step, advanced here) and
    ``device``.

    Returns the running average of the training loss.
    """
    global n_iter, device
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter(precision=4)

    # switch to train mode
    bio_net.train()

    end = time.time()
    logger.train_bar.update(0)

    for i, (sample, value) in enumerate(train_loader):
        # measure data loading time
        data_time.update(time.time() - end)
        data = sample.to(device)

        # compute output (flattened to match the target's shape)
        estimated_y = bio_net(data).view(-1)
        value = value.float().to(device)

        # BUGFIX: the original computed `loss = value - estimated_y`, detached
        # it via `.data` and re-wrapped it in a fresh
        # `Variable(..., requires_grad=True)`.  That severed the autograd
        # graph: backward() only reached the new leaf tensor, so no gradients
        # ever flowed into bio_net and optimizer.step() was a no-op.  A signed
        # difference is also a degenerate objective (minimized by pushing
        # estimated_y to +inf).  Use mean absolute error kept attached to the
        # graph, matching the `.abs().mean()` losses used by the other train()
        # routines in this file.
        loss = (value - estimated_y).abs().mean()

        # record loss
        losses.update(loss.item(), args.batch_size)

        # compute gradient and do Adam step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        logger.train_bar.update(i + 1)
        if i % args.print_freq == 0:
            logger.train_writer.write('Train: Time {} Data {} Loss {}'.format(batch_time, data_time, losses))
        if i >= epoch_size - 1:
            break

        n_iter += 1

    return losses.avg[0]
def train(args, train_loader, disp_net, pose_exp_net, optimizer, epoch_size, logger, train_writer):
    """Train disp_net and pose_exp_net jointly for one epoch.

    Loss = w1 * photometric reconstruction + w2 * explainability
    + w3 * inverse-depth smoothness (weights from ``args``); the
    explainability term is skipped entirely when its weight is zero.
    Scalars/images are logged to ``train_writer`` and per-batch losses are
    appended to the CSV at ``args.save_path / args.log_full``.

    Uses module-level globals ``n_iter`` (advanced here) and ``device``.
    Returns the running average of the total loss.
    """
    global n_iter, device
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter(precision=4)
    w1, w2, w3 = args.photo_loss_weight, args.mask_loss_weight, args.smooth_loss_weight

    # switch to train mode
    disp_net.train()
    pose_exp_net.train()

    end = time.time()
    logger.train_bar.update(0)

    for i, (tgt_img, ref_imgs, intrinsics, intrinsics_inv) in enumerate(train_loader):
        log_losses = i > 0 and n_iter % args.print_freq == 0
        log_output = args.training_output_freq > 0 and n_iter % args.training_output_freq == 0

        # measure data loading time
        data_time.update(time.time() - end)
        tgt_img = tgt_img.to(device)
        ref_imgs = [img.to(device) for img in ref_imgs]
        intrinsics = intrinsics.to(device)

        # compute output: multi-scale disparities, inverted to depths
        disparities = disp_net(tgt_img)
        depth = [1 / disp for disp in disparities]
        explainability_mask, pose = pose_exp_net(tgt_img, ref_imgs)

        loss_1, warped, diff = photometric_reconstruction_loss(tgt_img, ref_imgs, intrinsics, depth, explainability_mask, pose, args.rotation_mode, args.padding_mode)
        if w2 > 0:
            loss_2 = explainability_loss(explainability_mask)
        else:
            loss_2 = 0
        loss_3 = inverse_depth_smooth_loss(depth, tgt_img)

        loss = w1 * loss_1 + w2 * loss_2 + w3 * loss_3

        if log_losses:
            train_writer.add_scalar('photometric_error', loss_1.item(), n_iter)
            if w2 > 0:
                train_writer.add_scalar('explanability_loss', loss_2.item(), n_iter)
            train_writer.add_scalar('disparity_smoothness_loss', loss_3.item(), n_iter)
            train_writer.add_scalar('total_loss', loss.item(), n_iter)

        if log_output:
            train_writer.add_image('train Input', tensor2array(tgt_img[0]), n_iter)
            # one tensorboard entry per scale
            for k, scaled_maps in enumerate(zip(depth, disparities, warped, diff, explainability_mask)):
                log_output_tensorboard(train_writer, "train", k, n_iter, *scaled_maps)

        # record loss and EPE
        losses.update(loss.item(), args.batch_size)

        # compute gradient and do Adam step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        with open(args.save_path / args.log_full, 'a') as csvfile:
            writer = csv.writer(csvfile, delimiter='\t')
            writer.writerow([loss.item(), loss_1.item(), loss_2.item() if w2 > 0 else 0, loss_3.item()])
        logger.train_bar.update(i + 1)
        if i % args.print_freq == 0:
            logger.train_writer.write('Train: Time {} Data {} Loss {}'.format(batch_time, data_time, losses))
        if i >= epoch_size - 1:
            break

        n_iter += 1

    return losses.avg[0]
def validate(val_loader, disp_net, pose_exp_net, epoch, logger, output_writers=[]):
    """Validate disp_net/pose_exp_net with photometric + explainability losses.

    Legacy pre-0.4 PyTorch style (``Variable(..., volatile=True)`` and
    ``.data[0]``). Logs the first output of every 100th batch to the matching
    tensorboard writer and, at the end, histograms of all predicted pose
    coefficients. Relies on the module-level ``args``.

    NOTE(review): the mutable default ``output_writers=[]`` is only read,
    never mutated, so it is harmless here.

    Returns the average [total, photometric, explainability] losses.
    """
    global args
    batch_time = AverageMeter()
    losses = AverageMeter(i=3, precision=4)
    log_outputs = len(output_writers) > 0
    w1, w2, w3 = args.photo_loss_weight, args.mask_loss_weight, args.smooth_loss_weight
    # buffer for all predicted poses except the last (possibly smaller) batch
    poses = np.zeros(((len(val_loader) - 1) * args.batch_size * (args.sequence_length - 1), 6))

    # switch to evaluate mode
    disp_net.eval()
    pose_exp_net.eval()

    end = time.time()

    for i, (tgt_img, ref_imgs, intrinsics, intrinsics_inv) in enumerate(val_loader):
        tgt_img_var = Variable(tgt_img.cuda(), volatile=True)
        ref_imgs_var = [Variable(img.cuda(), volatile=True) for img in ref_imgs]
        intrinsics_var = Variable(intrinsics.cuda(), volatile=True)
        intrinsics_inv_var = Variable(intrinsics_inv.cuda(), volatile=True)

        # compute output
        disp = disp_net(tgt_img_var)
        depth = 1 / disp
        explainability_mask, pose = pose_exp_net(tgt_img_var, ref_imgs_var)

        loss_1 = photometric_reconstruction_loss(tgt_img_var, ref_imgs_var, intrinsics_var, intrinsics_inv_var, depth, explainability_mask, pose)
        loss_1 = loss_1.data[0]
        if w2 > 0:
            loss_2 = explainability_loss(explainability_mask).data[0]
        else:
            loss_2 = 0
        loss_3 = smooth_loss(disp).data[0]

        if log_outputs and i % 100 == 0 and i / 100 < len(output_writers):
            # log first output of every 100 batch
            index = int(i // 100)
            if epoch == 0:
                for j, ref in enumerate(ref_imgs):
                    output_writers[index].add_image('val Input {}'.format(j), tensor2array(tgt_img[0]), 0)
                    output_writers[index].add_image('val Input {}'.format(j), tensor2array(ref[0]), 1)

            output_writers[index].add_image('val Dispnet Output Normalized', tensor2array(disp.data[0].cpu(), max_value=None, colormap='bone'), epoch)
            output_writers[index].add_image('val Depth Output', tensor2array(1. / disp.data[0].cpu(), max_value=10), epoch)
            # log warped images along with explainability mask
            for j, ref in enumerate(ref_imgs_var):
                ref_warped = inverse_warp(ref[:1], depth[:1, 0], pose[:1, j], intrinsics_var[:1], intrinsics_inv_var[:1])[0]
                output_writers[index].add_image('val Warped Outputs {}'.format(j), tensor2array(ref_warped.data.cpu()), epoch)
                output_writers[index].add_image('val Diff Outputs {}'.format(j), tensor2array(0.5 * (tgt_img_var[0] - ref_warped).abs().data.cpu()), epoch)
                if explainability_mask is not None:
                    output_writers[index].add_image('val Exp mask Outputs {}'.format(j), tensor2array(explainability_mask[0, j].data.cpu(), max_value=1, colormap='bone'), epoch)

        if log_outputs and i < len(val_loader) - 1:
            step = args.batch_size * (args.sequence_length - 1)
            poses[i * step:(i + 1) * step] = pose.data.cpu().view(-1, 6).numpy()

        loss = w1 * loss_1 + w2 * loss_2 + w3 * loss_3
        losses.update([loss, loss_1, loss_2])

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()
        logger.valid_bar.update(i)
        if i % args.print_freq == 0:
            logger.valid_writer.write('valid: Time {} Loss {}'.format(batch_time, losses))
    if log_outputs:
        output_writers[0].add_histogram('val poses_tx', poses[:, 0], epoch)
        output_writers[0].add_histogram('val poses_ty', poses[:, 1], epoch)
        output_writers[0].add_histogram('val poses_tz', poses[:, 2], epoch)
        output_writers[0].add_histogram('val poses_rx', poses[:, 3], epoch)
        output_writers[0].add_histogram('val poses_ry', poses[:, 4], epoch)
        output_writers[0].add_histogram('val poses_rz', poses[:, 5], epoch)
    return losses.avg
def validate_with_gt(args, val_loader, disp_net, epoch, logger, output_writers=[]):
    """Run one validation pass of disp_net against ground-truth depth.

    Logs input/target/prediction images for the first ``len(output_writers)``
    batches and accumulates the metrics in ``error_names`` over the whole
    loader.  Uses the module-level ``device``.

    NOTE: the mutable default ``output_writers=[]`` is only read, never
    mutated, so it is safe here.

    Returns (errors.avg, error_names).
    """
    global device
    batch_time = AverageMeter()
    error_names = ['abs_diff', 'abs_rel', 'sq_rel', 'a1', 'a2', 'a3']
    errors = AverageMeter(i=len(error_names))
    log_outputs = len(output_writers) > 0

    # switch to evaluate mode
    disp_net.eval()

    end = time.time()
    logger.valid_bar.update(0)
    for i, (tgt_img, depth) in enumerate(val_loader):
        tgt_img = tgt_img.to(device)
        depth = depth.to(device)

        # compute output
        output_disp = disp_net(tgt_img)
        output_depth = 1 / output_disp[:, 0]

        if log_outputs and i < len(output_writers):
            if epoch == 0:
                output_writers[i].add_image('val Input', tensor2array(tgt_img[0]), 0)
                # BUGFIX: depth[0] is a view into the ground-truth tensor, and
                # the original wrote `depth_to_show[depth_to_show == 0] = 1000`
                # straight into it.  That overwrote the invalid-depth zeros
                # *before* compute_errors() below, corrupting the metrics of
                # the first len(output_writers) batches at epoch 0.  Clone
                # before mutating so the visualization tweak stays local.
                depth_to_show = depth[0].clone()
                output_writers[i].add_image('val target Depth', tensor2array(depth_to_show, max_value=10), epoch)
                # zeros mark missing ground truth; push them far away so they
                # render as (near-)zero disparity in the visualization
                depth_to_show[depth_to_show == 0] = 1000
                disp_to_show = (1 / depth_to_show).clamp(0, 10)
                output_writers[i].add_image('val target Disparity Normalized', tensor2array(disp_to_show, max_value=None, colormap='magma'), epoch)

            output_writers[i].add_image('val Dispnet Output Normalized', tensor2array(output_disp[0], max_value=None, colormap='magma'), epoch)
            output_writers[i].add_image('val Depth Output', tensor2array(output_depth[0], max_value=10), epoch)

        errors.update(compute_errors(depth, output_depth))

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()
        logger.valid_bar.update(i + 1)
        if i % args.print_freq == 0:
            logger.valid_writer.write('valid: Time {} Abs Error {:.4f} ({:.4f})'.format(batch_time, errors.val[0], errors.avg[0]))
    logger.valid_bar.update(len(val_loader))
    return errors.avg, error_names
def main():
    """Evaluate motion-segmentation (rigidity) masks on KITTI.

    Loads pretrained disp/pose/mask/flow networks, combines the predicted
    explainability masks with a flow-consistency ("census") mask, compares
    the result against ground-truth object masks, and reports background/
    foreground IoU for the combined, census-only and explainability-only
    ("bare") variants.  Optionally dumps arrays and visualizations under
    ``args.output_dir``.  Uses the module-level ``args`` and ``parser``.
    """
    global args
    args = parser.parse_args()
    args.pretrained_disp = Path(args.pretrained_disp)
    args.pretrained_pose = Path(args.pretrained_pose)
    args.pretrained_mask = Path(args.pretrained_mask)
    args.pretrained_flow = Path(args.pretrained_flow)

    # output directories for dumped arrays and visualizations
    if args.output_dir is not None:
        args.output_dir = Path(args.output_dir)
        args.output_dir.makedirs_p()
        image_dir = args.output_dir / 'images'
        gt_dir = args.output_dir / 'gt'
        mask_dir = args.output_dir / 'mask'
        viz_dir = args.output_dir / 'viz'
        rigidity_mask_dir = args.output_dir / 'rigidity'
        rigidity_census_mask_dir = args.output_dir / 'rigidity_census'
        explainability_mask_dir = args.output_dir / 'explainability'

        image_dir.makedirs_p()
        gt_dir.makedirs_p()
        mask_dir.makedirs_p()
        viz_dir.makedirs_p()
        rigidity_mask_dir.makedirs_p()
        rigidity_census_mask_dir.makedirs_p()
        explainability_mask_dir.makedirs_p()

    output_writer = SummaryWriter(args.output_dir)

    normalize = custom_transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
    flow_loader_h, flow_loader_w = 256, 832
    valid_flow_transform = custom_transforms.Compose([custom_transforms.Scale(h=flow_loader_h, w=flow_loader_w), custom_transforms.ArrayToTensor(), normalize])
    val_flow_set = ValidationMask(root=args.kitti_dir, sequence_length=5, transform=valid_flow_transform)

    val_loader = torch.utils.data.DataLoader(val_flow_set, batch_size=1, shuffle=False, num_workers=2, pin_memory=True, drop_last=True)

    # instantiate the four networks by name and restore their weights
    disp_net = getattr(models, args.dispnet)().cuda()
    pose_net = getattr(models, args.posenet)(nb_ref_imgs=4).cuda()
    mask_net = getattr(models, args.masknet)(nb_ref_imgs=4).cuda()
    flow_net = getattr(models, args.flownet)(nlevels=args.nlevels).cuda()

    dispnet_weights = torch.load(args.pretrained_disp)
    posenet_weights = torch.load(args.pretrained_pose)
    masknet_weights = torch.load(args.pretrained_mask)
    flownet_weights = torch.load(args.pretrained_flow)
    disp_net.load_state_dict(dispnet_weights['state_dict'])
    pose_net.load_state_dict(posenet_weights['state_dict'])
    flow_net.load_state_dict(flownet_weights['state_dict'])
    mask_net.load_state_dict(masknet_weights['state_dict'])

    disp_net.eval()
    pose_net.eval()
    mask_net.eval()
    flow_net.eval()

    # per-class true/false positives and false negatives (0 = bg, 1 = fg)
    error_names = ['tp_0', 'fp_0', 'fn_0', 'tp_1', 'fp_1', 'fn_1']
    errors = AverageMeter(i=len(error_names))
    errors_census = AverageMeter(i=len(error_names))
    errors_bare = AverageMeter(i=len(error_names))

    for i, (tgt_img, ref_imgs, intrinsics, intrinsics_inv, flow_gt, obj_map_gt, semantic_map_gt) in enumerate(tqdm(val_loader)):
        # legacy pre-0.4 PyTorch inference style (volatile Variables)
        tgt_img_var = Variable(tgt_img.cuda(), volatile=True)
        ref_imgs_var = [Variable(img.cuda(), volatile=True) for img in ref_imgs]
        intrinsics_var = Variable(intrinsics.cuda(), volatile=True)
        intrinsics_inv_var = Variable(intrinsics_inv.cuda(), volatile=True)

        flow_gt_var = Variable(flow_gt.cuda(), volatile=True)
        obj_map_gt_var = Variable(obj_map_gt.cuda(), volatile=True)

        disp = disp_net(tgt_img_var)
        depth = 1 / disp
        pose = pose_net(tgt_img_var, ref_imgs_var)
        explainability_mask = mask_net(tgt_img_var, ref_imgs_var)

        if args.flownet in ['Back2Future']:
            flow_fwd, flow_bwd, _ = flow_net(tgt_img_var, ref_imgs_var[1:3])
        else:
            flow_fwd = flow_net(tgt_img_var, ref_imgs_var[2])
        # rigid flow induced by camera motion and predicted depth
        flow_cam = pose2flow(depth.squeeze(1), pose[:, 2], intrinsics_var, intrinsics_inv_var)

        # "bare" rigidity mask from the explainability channels
        rigidity_mask = 1 - (1 - explainability_mask[:, 1]) * (1 - explainability_mask[:, 2]).unsqueeze(1) > 0.5
        # "census" mask: agreement between camera-induced flow and predicted flow
        rigidity_mask_census_soft = (flow_cam - flow_fwd).pow(2).sum(dim=1).unsqueeze(1).sqrt()  #.normalize()
        rigidity_mask_census_soft = 1 - rigidity_mask_census_soft / rigidity_mask_census_soft.max()
        rigidity_mask_census = rigidity_mask_census_soft > args.THRESH

        # combined mask: rigid wherever either component says rigid
        rigidity_mask_combined = 1 - (1 - rigidity_mask.type_as(explainability_mask)) * (1 - rigidity_mask_census.type_as(explainability_mask))

        # compose the final flow from rigid (camera) and non-rigid (residual) parts
        flow_fwd_non_rigid = (1 - rigidity_mask_combined).type_as(flow_fwd).expand_as(flow_fwd) * flow_fwd
        flow_fwd_rigid = rigidity_mask_combined.type_as(flow_fwd).expand_as(flow_fwd) * flow_cam
        total_flow = flow_fwd_rigid + flow_fwd_non_rigid

        obj_map_gt_var_expanded = obj_map_gt_var.unsqueeze(1).type_as(flow_fwd)

        tgt_img_np = tgt_img[0].numpy()
        rigidity_mask_combined_np = rigidity_mask_combined.cpu().data[0].numpy()
        rigidity_mask_census_np = rigidity_mask_census.cpu().data[0].numpy()
        rigidity_mask_bare_np = rigidity_mask.cpu().data[0].numpy()

        gt_mask_np = obj_map_gt[0].numpy()
        semantic_map_np = semantic_map_gt[0].numpy()

        # compare each mask variant against the ground truth
        _errors = mask_error(gt_mask_np, semantic_map_np, rigidity_mask_combined_np[0])
        _errors_census = mask_error(gt_mask_np, semantic_map_np, rigidity_mask_census_np[0])
        _errors_bare = mask_error(gt_mask_np, semantic_map_np, rigidity_mask_bare_np[0])

        errors.update(_errors)
        errors_census.update(_errors_census)
        errors_bare.update(_errors_bare)

        if args.output_dir is not None:
            np.save(image_dir / str(i).zfill(3), tgt_img_np)
            np.save(gt_dir / str(i).zfill(3), gt_mask_np)
            np.save(mask_dir / str(i).zfill(3), rigidity_mask_combined_np)
            np.save(rigidity_mask_dir / str(i).zfill(3), rigidity_mask.cpu().data[0].numpy())
            np.save(rigidity_census_mask_dir / str(i).zfill(3), rigidity_mask_census.cpu().data[0].numpy())
            np.save(explainability_mask_dir / str(i).zfill(3), explainability_mask[:, 1].cpu().data[0].numpy())
            # rigidity_mask_dir rigidity_mask.numpy()
            # rigidity_census_mask_dir rigidity_mask_census.numpy()

        if (args.output_dir is not None) and i % 10 == 0:
            ind = int(i // 10)
            output_writer.add_image('val Dispnet Output Normalized', tensor2array(disp.data[0].cpu(), max_value=None, colormap='bone'), ind)
            # NOTE(review): uses step `i` while its siblings use `ind` — confirm intent
            output_writer.add_image('val Input', tensor2array(tgt_img[0].cpu()), i)
            output_writer.add_image('val Total Flow Output', flow_to_image(tensor2array(total_flow.data[0].cpu())), ind)
            output_writer.add_image('val Rigid Flow Output', flow_to_image(tensor2array(flow_fwd_rigid.data[0].cpu())), ind)
            output_writer.add_image('val Non-rigid Flow Output', flow_to_image(tensor2array(flow_fwd_non_rigid.data[0].cpu())), ind)
            output_writer.add_image('val Rigidity Mask', tensor2array(rigidity_mask.data[0].cpu(), max_value=1, colormap='bone'), ind)
            output_writer.add_image('val Rigidity Mask Census', tensor2array(rigidity_mask_census.data[0].cpu(), max_value=1, colormap='bone'), ind)
            output_writer.add_image('val Rigidity Mask Combined', tensor2array(rigidity_mask_combined.data[0].cpu(), max_value=1, colormap='bone'), ind)

        if args.output_dir is not None:
            tgt_img_viz = tensor2array(tgt_img[0].cpu())
            depth_viz = tensor2array(disp.data[0].cpu(), max_value=None, colormap='magma')
            mask_viz = tensor2array(rigidity_mask_census_soft.data[0].cpu(), max_value=1, colormap='bone')
            row2_viz = flow_to_image(np.hstack((tensor2array(flow_cam.data[0].cpu()), tensor2array(flow_fwd_non_rigid.data[0].cpu()), tensor2array(total_flow.data[0].cpu()))))

            row1_viz = np.hstack((tgt_img_viz, depth_viz, mask_viz))
            ####### changed the 2 vstacks to hstack ###############
            viz3 = np.hstack((255 * tgt_img_viz, 255 * depth_viz, 255 * mask_viz, flow_to_image(np.hstack((tensor2array(flow_fwd_non_rigid.data[0].cpu()), tensor2array(total_flow.data[0].cpu()))))))
            ########################################################
            ######## manually added code ####################
            # convert CHW -> HWC for PIL
            row1_viz = np.transpose(row1_viz, (1, 2, 0))
            row2_viz = np.transpose(row2_viz, (1, 2, 0))
            viz3 = np.transpose(viz3, (1, 2, 0))
            ##########################################
            row1_viz_im = Image.fromarray((255 * row1_viz).astype('uint8'))
            row2_viz_im = Image.fromarray((row2_viz).astype('uint8'))
            viz3_im = Image.fromarray(viz3.astype('uint8'))

            row1_viz_im.save(viz_dir / str(i).zfill(3) + '01.png')
            row2_viz_im.save(viz_dir / str(i).zfill(3) + '02.png')
            viz3_im.save(viz_dir / str(i).zfill(3) + '03.png')

    # IoU per class from accumulated tp/fp/fn, then macro-average
    bg_iou = errors.sum[0] / (errors.sum[0] + errors.sum[1] + errors.sum[2])
    fg_iou = errors.sum[3] / (errors.sum[3] + errors.sum[4] + errors.sum[5])
    avg_iou = (bg_iou + fg_iou) / 2

    bg_iou_census = errors_census.sum[0] / (errors_census.sum[0] + errors_census.sum[1] + errors_census.sum[2])
    fg_iou_census = errors_census.sum[3] / (errors_census.sum[3] + errors_census.sum[4] + errors_census.sum[5])
    avg_iou_census = (bg_iou_census + fg_iou_census) / 2

    bg_iou_bare = errors_bare.sum[0] / (errors_bare.sum[0] + errors_bare.sum[1] + errors_bare.sum[2])
    fg_iou_bare = errors_bare.sum[3] / (errors_bare.sum[3] + errors_bare.sum[4] + errors_bare.sum[5])
    avg_iou_bare = (bg_iou_bare + fg_iou_bare) / 2

    print("Results Full Model")
    print("\t {:>10}, {:>10}, {:>10} ".format('iou', 'bg_iou', 'fg_iou'))
    print("Errors \t {:10.4f}, {:10.4f} {:10.4f}".format(avg_iou, bg_iou, fg_iou))

    print("Results Census only")
    print("\t {:>10}, {:>10}, {:>10} ".format('iou', 'bg_iou', 'fg_iou'))
    print("Errors \t {:10.4f}, {:10.4f} {:10.4f}".format(avg_iou_census, bg_iou_census, fg_iou_census))

    print("Results Bare")
    print("\t {:>10}, {:>10}, {:>10} ".format('iou', 'bg_iou', 'fg_iou'))
    print("Errors \t {:10.4f}, {:10.4f} {:10.4f}".format(avg_iou_bare, bg_iou_bare, fg_iou_bare))
def train(args, train_loader, disp_net, pose_net, optimizer, epoch_size, logger, tb_writer, w1, w3):
    """Run one training epoch of the joint disparity + pose networks on light-field data.

    Args:
        args: parsed command-line options (lfformat, cameras_epi, without_disp_stack,
            rotation_mode, padding_mode, batch_size, print_freq, training_output_freq,
            save_path, log_full, ...).
        train_loader: iterable of training-sample dicts (see keys read below).
        disp_net: disparity network; may expose ``encode``/``has_encoder``.
        pose_net: pose network; may expose ``encode``/``has_encoder``.
        optimizer: optimizer over both networks' parameters.
        epoch_size: number of batches per epoch (loop breaks after this many).
        logger: progress-bar/console logger.
        tb_writer: TensorBoard SummaryWriter.
        w1, w3: weighting terms for the photometric and smoothness losses.
            NOTE(review): used as ``w + exp(-w) * loss`` — this looks like learned
            homoscedastic-uncertainty weighting (w = log-variance); confirm whether
            w1/w3 are learnable tensors registered with the optimizer.

    Returns:
        Average total loss over the epoch (``losses.avg[0]``).

    Side effects: increments global ``n_iter``, logs scalars/images to TensorBoard,
    and appends one row per iteration to the CSV at ``args.save_path/args.log_full``.
    """
    global n_iter, device
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter(precision=4)

    # Set the networks to training mode, batch norm and dropout are handled accordingly
    disp_net.train()
    pose_net.train()

    end = time.time()
    logger.train_bar.start()
    logger.train_bar.update(0)

    for i, trainingdata in enumerate(train_loader):
        # Throttle scalar logging (skip the very first batch) and image logging.
        log_losses = i > 0 and n_iter % args.print_freq == 0
        log_output = args.training_output_freq > 0 and n_iter % args.training_output_freq == 0

        # measure data loading time
        data_time.update(time.time() - end)
        tgt_lf = trainingdata['tgt_lf'].to(device)
        ref_lfs = [img.to(device) for img in trainingdata['ref_lfs']]

        if args.lfformat == "epi" and args.cameras_epi == "full":
            # in this case we have separate horizontal and vertical epis
            tgt_lf_formatted_h = trainingdata['tgt_lf_formatted_h'].to(device)
            tgt_lf_formatted_v = trainingdata['tgt_lf_formatted_v'].to(device)
            ref_lfs_formatted_h = [lf.to(device) for lf in trainingdata['ref_lfs_formatted_h']]
            ref_lfs_formatted_v = [lf.to(device) for lf in trainingdata['ref_lfs_formatted_v']]
            # stacked images
            tgt_stack = trainingdata['tgt_stack'].to(device)
            ref_stacks = [lf.to(device) for lf in trainingdata['ref_stacks']]
            # Encode the epi images further
            if args.without_disp_stack:
                # Stacked images should not be concatenated with the encoded EPI images
                tgt_lf_encoded_d = disp_net.encode(tgt_lf_formatted_v, None, tgt_lf_formatted_h)
            else:
                # Stacked images should be concatenated with the encoded EPI images
                tgt_lf_encoded_d = disp_net.encode(tgt_lf_formatted_v, tgt_stack, tgt_lf_formatted_h)
            tgt_lf_encoded_p, ref_lfs_encoded_p = pose_net.encode(tgt_lf_formatted_v, tgt_stack,
                                                                  ref_lfs_formatted_v, ref_stacks,
                                                                  tgt_lf_formatted_h, ref_lfs_formatted_h)
        else:
            tgt_lf_formatted = trainingdata['tgt_lf_formatted'].to(device)
            ref_lfs_formatted = [lf.to(device) for lf in trainingdata['ref_lfs_formatted']]
            # Encode the images if necessary
            if disp_net.has_encoder():
                # This will only be called for epi with horizontal or vertical only encoding
                if args.without_disp_stack:
                    # Stacked images should not be concatenated with the encoded EPI images
                    tgt_lf_encoded_d = disp_net.encode(tgt_lf_formatted, None)
                else:
                    # Stacked images should be concatenated with the encoded EPI images
                    # NOTE: Here we stack all 17 images, not 5. Here the images missing from the encoding,
                    # are covered in the stack. We are not using this case in the paper at all.
                    tgt_lf_encoded_d = disp_net.encode(tgt_lf_formatted, tgt_lf)
            else:
                # This will be called for focal stack and stack, where there is no encoding
                tgt_lf_encoded_d = tgt_lf_formatted

            if pose_net.has_encoder():
                tgt_lf_encoded_p, ref_lfs_encoded_p = pose_net.encode(tgt_lf_formatted, tgt_lf,
                                                                      ref_lfs_formatted, ref_lfs)
            else:
                tgt_lf_encoded_p = tgt_lf_formatted
                ref_lfs_encoded_p = ref_lfs_formatted

        # compute output of networks
        disparities = disp_net(tgt_lf_encoded_d)
        depth = [1/disp for disp in disparities]  # multi-scale: one depth map per disparity scale
        pose = pose_net(tgt_lf_encoded_p, ref_lfs_encoded_p)

        # if i==0:
        #     tb_writer.add_graph(disp_net, tgt_lf_encoded_d)
        #     tb_writer.add_graph(pose_net, (tgt_lf_encoded_p, ref_lfs_encoded_p))

        # compute photometric error
        intrinsics = trainingdata['intrinsics'].to(device)
        pose_gt_tgt_refs = trainingdata['pose_gt_tgt_refs'].to(device)
        metadata = trainingdata['metadata']

        photometric_error, warped, diff = multiwarp_photometric_loss(
            tgt_lf, ref_lfs, intrinsics, depth, pose, metadata, args.rotation_mode, args.padding_mode
        )

        # smoothness_error = smooth_loss(depth)  # smoothness error
        smoothness_error = total_variation_loss(depth, sum_or_mean="mean")  # total variation error
        # smoothness_error = total_variation_squared_loss(depth)  # total variation error squared version

        # Pose errors are computed against ground truth but only logged, never
        # added to the training loss (training stays self-supervised).
        mean_distance_error, mean_angle_error = pose_loss(pose, pose_gt_tgt_refs)

        # Uncertainty-style weighting: each term contributes w + exp(-w) * error.
        loss = w1 + torch.exp(-1.0 * w1) * photometric_error + w3 + torch.exp(-1.0 * w3) * smoothness_error

        if log_losses:
            tb_writer.add_scalar(tag='train/photometric_error', scalar_value=photometric_error.item(), global_step=n_iter)
            tb_writer.add_scalar(tag='train/smoothness_loss', scalar_value=smoothness_error.item(), global_step=n_iter)
            tb_writer.add_scalar(tag='train/total_loss', scalar_value=loss.item(), global_step=n_iter)
            tb_writer.add_scalar(tag='train/mean_distance_error', scalar_value=mean_distance_error.item(), global_step=n_iter)
            tb_writer.add_scalar(tag='train/mean_angle_error', scalar_value=mean_angle_error.item(), global_step=n_iter)

        if log_output:
            # Visualize the first sub-aperture/EPI channel of the first batch element.
            # Inputs are assumed normalized to [-1, 1]; * 0.5 + 0.5 maps them to [0, 1].
            if args.lfformat == "epi" and args.cameras_epi == "full":
                b, n, h, w = tgt_lf_formatted_v.shape
                vis_img = tgt_lf_formatted_v[0, 0, :, :].detach().cpu().numpy().reshape(1, h, w) * 0.5 + 0.5
            else:
                b, n, h, w = tgt_lf_formatted.shape
                vis_img = tgt_lf_formatted[0, 0, :, :].detach().cpu().numpy().reshape(1, h, w) * 0.5 + 0.5
            # NOTE(review): h/w are re-bound to the depth map's spatial size here and
            # then used to reshape slices of tgt_lf_encoded_d — this assumes the
            # encoded features share the depth map's resolution; confirm.
            b, n, h, w = depth[0].shape
            vis_depth = tensor2array(depth[0][0, 0, :, :], colormap='magma')
            vis_disp = tensor2array(disparities[0][0, 0, :, :], colormap='magma')
            vis_enc_f = tgt_lf_encoded_d[0, 0, :, :].detach().cpu().numpy().reshape(1, h, w) * 0.5 + 0.5
            vis_enc_b = tgt_lf_encoded_d[0, -1, :, :].detach().cpu().numpy().reshape(1, h, w) * 0.5 + 0.5
            tb_writer.add_image('train/input', vis_img, n_iter)
            tb_writer.add_image('train/encoded_front', vis_enc_f, n_iter)
            tb_writer.add_image('train/encoded_back', vis_enc_b, n_iter)
            tb_writer.add_image('train/depth', vis_depth, n_iter)
            tb_writer.add_image('train/disp', vis_disp, n_iter)

        # record loss and EPE
        losses.update(loss.item(), args.batch_size)

        # compute gradient and do Adam step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        # Re-opened in append mode every iteration so each row is flushed
        # immediately (log survives a crash mid-epoch).
        with open(args.save_path + "/" + args.log_full, 'a') as csvfile:
            writer = csv.writer(csvfile, delimiter='\t')
            writer.writerow([loss.item(), photometric_error.item(), smoothness_error.item(),
                             mean_distance_error.item(), mean_angle_error.item()])

        logger.train_bar.update(i+1)
        if i % args.print_freq == 0:
            logger.train_writer.write('Train: Time {} Data {} Loss {}'.format(batch_time, data_time, losses))
        if i >= epoch_size - 1:
            break

        # NOTE: incremented after the break, so the final batch of the epoch does
        # not advance n_iter (matches the sibling train() in this file).
        n_iter += 1
    logger.train_bar.finish()
    return losses.avg[0]
@torch.no_grad()  # pure evaluation: no backward pass anywhere below, so skip graph construction
def validate_with_gt(args, val_loader, depth_net, pose_net, epoch, logger, output_writers=None, **env):
    """Validate depth and pose networks against ground-truth depth and pose.

    For every non-target frame in each sequence, the prior frame is rotation-
    compensated twice — once with the ground-truth pose and once with the
    predicted pose — and the depth network is evaluated on both stabilized
    pairs, yielding separate "stab" and "unstab" depth error statistics.

    Args:
        args: options (network_input_size, sequence_length, rotation_mode,
            nominal_displacement, max_depth, print_freq, ...).
        val_loader: yields dicts with 'imgs', 'intrinsics', 'intrinsics_inv',
            'depth', 'pose'.
        depth_net: takes a 6-channel stabilized image pair, returns depth.
        pose_net: takes the image sequence, returns pose vectors.
        epoch: current epoch (input images are only logged when epoch == 1).
        logger: progress-bar/console logger.
        output_writers: optional list of TensorBoard writers; images are logged
            for the first ``len(output_writers)`` batches. Defaults to no logging.
            (Was a mutable default ``[]``; changed to None-sentinel idiom.)
        **env: ignored; absorbs extra keyword arguments from the caller.

    Returns:
        OrderedDict mapping error names (pose errors, then unstab/stab depth
        errors) to their epoch averages.
    """
    global device
    if output_writers is None:
        output_writers = []
    batch_time = AverageMeter()
    depth_error_names = ['abs diff', 'abs rel', 'sq rel', 'a1', 'a2', 'a3']
    stab_depth_errors = AverageMeter(i=len(depth_error_names))
    unstab_depth_errors = AverageMeter(i=len(depth_error_names))
    pose_error_names = ['Absolute Trajectory Error', 'Rotation Error']
    pose_errors = AverageMeter(i=len(pose_error_names))

    # switch to evaluate mode
    depth_net.eval()
    pose_net.eval()

    end = time.time()
    logger.valid_bar.update(0)
    for i, sample in enumerate(val_loader):
        log_output = i < len(output_writers)  # only log images for the first batches
        imgs = torch.stack(sample['imgs'], dim=1).to(device)
        batch_size, seq, c, h, w = imgs.size()
        intrinsics = sample['intrinsics'].to(device)
        intrinsics_inv = sample['intrinsics_inv'].to(device)
        if args.network_input_size is not None:
            # Downscale inputs to the network size and rescale the intrinsics
            # (focal lengths / principal point rows) to match.
            imgs = F.interpolate(imgs, (c, *args.network_input_size), mode='area')
            downscale = h / args.network_input_size[0]
            intrinsics = torch.cat(
                (intrinsics[:, 0:2] / downscale, intrinsics[:, 2:]), dim=1)
            intrinsics_inv = torch.cat(
                (intrinsics_inv[:, :, 0:2] * downscale, intrinsics_inv[:, :, 2:]), dim=2)
        GT_depth = sample['depth'].to(device)
        GT_pose = sample['pose'].to(device)

        # The middle frame of the sequence is the target.
        mid_index = (args.sequence_length - 1) // 2
        tgt_img = imgs[:, mid_index]

        if epoch == 1 and log_output:
            for j, img in enumerate(sample['imgs']):
                output_writers[i].add_image('val Input', tensor2array(img[0]), j)
            depth_to_show = GT_depth[0].cpu()
            # KITTI Like data routine to discard invalid data
            depth_to_show[depth_to_show == 0] = 1000
            disp_to_show = (1 / depth_to_show).clamp(0, 10)
            output_writers[i].add_image(
                'val target Disparity Normalized',
                tensor2array(disp_to_show, max_value=None, colormap='bone'), epoch)

        poses = pose_net(imgs)
        pose_matrices = pose_vec2mat(poses, args.rotation_mode)  # [B, seq, 3, 4]
        inverted_pose_matrices = invert_mat(pose_matrices)
        pose_errors.update(
            compute_pose_error(GT_pose[:, :-1], inverted_pose_matrices.data[:, :-1]))

        # Express all poses relative to the target frame.
        tgt_poses = pose_matrices[:, mid_index]  # [B, 3, 4]
        compensated_predicted_poses = compensate_pose(pose_matrices, tgt_poses)
        compensated_GT_poses = compensate_pose(GT_pose, GT_pose[:, mid_index])

        for j in range(args.sequence_length):
            if j == mid_index:
                # Target frame is already "stabilized" with respect to itself.
                if log_output and epoch == 1:
                    output_writers[i].add_image(
                        'val Input Stabilized', tensor2array(sample['imgs'][j][0]), j)
                continue
            # Compute displacement magnitude for each element of the batch, and
            # rescale GT depth so it matches the network's nominal baseline.
            prior_img = imgs[:, j]
            displacement = compensated_GT_poses[:, j, :, -1]  # [B,3]
            displacement_magnitude = displacement.norm(p=2, dim=1)  # [B]
            current_GT_depth = GT_depth * args.nominal_displacement / displacement_magnitude.view(
                -1, 1, 1)

            prior_predicted_pose = compensated_predicted_poses[:, j]  # [B, 3, 4]
            prior_GT_pose = compensated_GT_poses[:, j]
            prior_predicted_rot = prior_predicted_pose[:, :, :-1]
            prior_GT_rot = prior_GT_pose[:, :, :-1].transpose(1, 2)
            prior_compensated_from_GT = inverse_rotate(prior_img, prior_GT_rot,
                                                       intrinsics, intrinsics_inv)
            if log_output and epoch == 1:
                depth_to_show = current_GT_depth[0]
                output_writers[i].add_image(
                    'val target Depth {}'.format(j),
                    tensor2array(depth_to_show, max_value=args.max_depth), epoch)
                output_writers[i].add_image(
                    'val Input Stabilized',
                    tensor2array(prior_compensated_from_GT[0]), j)

            prior_compensated_from_prediction = inverse_rotate(
                prior_img, prior_predicted_rot, intrinsics, intrinsics_inv)
            # Channel-concatenated pairs, [B, 6, H, W] (was annotated [B, 6, W, H]
            # in the original; PyTorch images are NCHW).
            predicted_input_pair = torch.cat(
                [prior_compensated_from_prediction, tgt_img], dim=1)
            GT_input_pair = torch.cat([prior_compensated_from_GT, tgt_img], dim=1)

            # This is the depth from footage stabilized with GT pose; it should be
            # better than depth from footage stabilized with the predicted pose.
            raw_depth_stab = depth_net(GT_input_pair)
            raw_depth_unstab = depth_net(predicted_input_pair)

            # Upsample depth so that it matches GT size
            scale_factor = GT_depth.size(-1) // raw_depth_stab.size(-1)
            depth_stab = F.interpolate(raw_depth_stab,
                                       scale_factor=scale_factor,
                                       mode='bilinear',
                                       align_corners=False)
            depth_unstab = F.interpolate(raw_depth_unstab,
                                         scale_factor=scale_factor,
                                         mode='bilinear',
                                         align_corners=False)

            for k, depth in enumerate([depth_stab, depth_unstab]):
                disparity = 1 / depth
                errors = stab_depth_errors if k == 0 else unstab_depth_errors
                errors.update(
                    compute_depth_errors(current_GT_depth, depth, crop=True))
                if log_output:
                    prefix = 'stabilized' if k == 0 else 'unstabilized'
                    output_writers[i].add_image(
                        'val {} Dispnet Output Normalized {}'.format(prefix, j),
                        tensor2array(disparity[0], max_value=None, colormap='bone'),
                        epoch)
                    output_writers[i].add_image(
                        'val {} Depth Output {}'.format(prefix, j),
                        tensor2array(depth[0], max_value=args.max_depth), epoch)

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()
        logger.valid_bar.update(i + 1)
        if i % args.print_freq == 0:
            logger.valid_writer.write(
                'valid: Time {} ATE Error {:.4f} ({:.4f}), Unstab Rel Abs Error {:.4f} ({:.4f})'
                .format(batch_time, pose_errors.val[0], pose_errors.avg[0],
                        unstab_depth_errors.val[1], unstab_depth_errors.avg[1]))
    logger.valid_bar.update(len(val_loader))

    errors = (*pose_errors.avg, *unstab_depth_errors.avg, *stab_depth_errors.avg)
    error_names = (*pose_error_names,
                   *['unstab {}'.format(e) for e in depth_error_names],
                   *['stab {}'.format(e) for e in depth_error_names])
    return OrderedDict(zip(error_names, errors))