def train(args, train_loader, mvdnet, depth_cons, cons_loss_, optimizer,
          epoch_size, train_writer, epoch):
    """Run one training epoch for the multi-view depth/normal pipeline.

    When ``args.train_cons`` is set, ``mvdnet`` is frozen and only the
    consistency/refinement network ``depth_cons`` is trained; otherwise
    ``mvdnet`` itself is trained.  Per-iteration losses are logged to
    TensorBoard and appended to a CSV file.

    Returns the running average of the total loss over the epoch.
    """
    global n_iter
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter(precision=4)
    d_losses = AverageMeter(precision=4)
    nmap_losses = AverageMeter(precision=4)
    cons_losses = AverageMeter(precision=4)

    # switch to training mode (only the net that actually gets gradients)
    if args.train_cons:
        depth_cons.train()
    else:
        mvdnet.train()
    print("Training")
    end = time.time()

    for i, (tgt_img, ref_imgs, gt_nmap, ref_poses, intrinsics,
            intrinsics_inv, tgt_depth) in enumerate(train_loader):
        # measure data loading time
        data_time.update(time.time() - end)
        tgt_img_var = Variable(tgt_img.cuda())
        ref_imgs_var = [Variable(img.cuda()) for img in ref_imgs]
        gt_nmap_var = Variable(gt_nmap.cuda())
        ref_poses_var = [Variable(pose.cuda()) for pose in ref_poses]
        intrinsics_var = Variable(intrinsics.cuda())
        intrinsics_inv_var = Variable(intrinsics_inv.cuda())
        # fix: original called .cuda() twice on the target depth; once suffices
        tgt_depth_var = Variable(tgt_depth.cuda())

        # compute output
        pose = torch.cat(ref_poses_var, 1)

        # valid-pixel mask: depth within [mindepth, nlabel*mindepth] and
        # finite (x == x is False only for NaN entries)
        mask = (tgt_depth_var <= args.nlabel * args.mindepth) & (
            tgt_depth_var >= args.mindepth) & (tgt_depth_var == tgt_depth_var)
        mask.detach_()
        # fix: idiomatic emptiness test instead of comparing .any() to 0
        if not mask.any():
            continue

        if args.train_cons:
            # mvdnet is frozen here; no autograd graph needed for its pass
            with torch.no_grad():
                outputs = mvdnet(tgt_img_var, ref_imgs_var, pose,
                                 intrinsics_var, intrinsics_inv_var)
            output_depth1 = outputs[0]
            nmap1 = outputs[1]
        else:
            outputs = mvdnet(tgt_img_var, ref_imgs_var, pose, intrinsics_var,
                             intrinsics_inv_var)
            # NOTE: mvdnet returns outputs at different indices in train mode
            output_depth1 = outputs[1]
            nmap1 = outputs[2]

        if args.train_cons:
            # refinement net predicts depth (channel 0) and normals (1:)
            outputs = depth_cons(output_depth1, nmap1)
            nmap = outputs[:, 1:]
            depths = [outputs[:, 0]]
        else:
            nmap = nmap1.permute(0, 3, 1, 2)
            depths = [output_depth1.squeeze(1)]

        loss = 0.
        d_loss = 0.
        nmap_loss = 0.
        cons_loss = 0.
        for l, depth in enumerate(depths):
            output = torch.squeeze(depth, 1)
            d_loss = d_loss + F.smooth_l1_loss(output[mask],
                                               tgt_depth_var[mask])
            # expand the depth mask across the 3 normal channels
            n_mask = mask.unsqueeze(1).expand(-1, 3, -1, -1)
            nmap_loss = nmap_loss + F.smooth_l1_loss(nmap[n_mask],
                                                     gt_nmap_var[n_mask])

        if args.train_cons:
            cons_loss = cons_loss + cons_loss_(
                depths[-1].unsqueeze(1), tgt_depth_var.unsqueeze(1),
                nmap.clone(), intrinsics_var, mask.unsqueeze(1))
            cons_losses.update(cons_loss.item(), args.batch_size)

        loss = loss + args.d_weight * d_loss + args.n_weight * nmap_loss \
            + args.c_weight * cons_loss

        if i > 0 and n_iter % args.print_freq == 0:
            train_writer.add_scalar('total_loss', loss.item(), n_iter)

        # record loss and EPE
        losses.update(loss.item(), args.batch_size)
        d_losses.update(d_loss.item(), args.batch_size)
        nmap_losses.update(nmap_loss.item(), args.batch_size)

        # compute gradient and do Adam step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        with open(args.save_path / args.log_full, 'a') as csvfile:
            writer = csv.writer(csvfile, delimiter='\t')
            writer.writerow([loss.item()])
        if i % args.print_freq == 0:
            print(
                'Train: Time {} Data {} Loss {} NmapLoss {} DLoss {} ConsLoss {}Iter {}/{} Epoch {}/{}'
                .format(batch_time, data_time, losses, nmap_losses, d_losses,
                        cons_losses, i, len(train_loader), epoch, args.epochs))
        if i >= epoch_size - 1:
            break
        n_iter += 1
    return losses.avg[0]
def validate_without_gt(args, val_loader, disp_net, pose_net, epoch, logger,
                        output_writers=[]):
    """Validate ``disp_net``/``pose_net`` on sequences without ground truth.

    The photometric reconstruction loss serves as the validation metric.
    Returns ``(losses.avg, loss_names)``.

    Note: ``output_writers=[]`` is a mutable default, but it is only read
    here (never mutated), so the usual shared-state pitfall does not apply.
    """
    global device
    batch_time = AverageMeter()
    losses = AverageMeter(i=4, precision=4)
    log_outputs = len(output_writers) > 0

    # switch to evaluate mode
    disp_net.eval()
    pose_net.eval()

    end = time.time()
    logger.valid_bar.update(0)
    for i, (tgt_img, ref_imgs, intrinsics,
            intrinsics_inv) in enumerate(val_loader):
        tgt_img = tgt_img.to(device)
        ref_imgs = [img.to(device) for img in ref_imgs]
        intrinsics = intrinsics.to(device)
        intrinsics_inv = intrinsics_inv.to(device)

        # fix: validation needs no autograd graphs; run inference-only to
        # save memory (loss values are identical)
        with torch.no_grad():
            # compute output (networks predict disparity; invert to depth)
            tgt_depth = [1 / disp_net(tgt_img)]
            ref_depths = []
            for ref_img in ref_imgs:
                ref_depth = [1 / disp_net(ref_img)]
                ref_depths.append(ref_depth)

            if log_outputs and i < len(output_writers):
                if epoch == 0:
                    output_writers[i].add_image('val Input',
                                                tensor2array(tgt_img[0]), 0)
                output_writers[i].add_image(
                    'val Dispnet Output Normalized',
                    tensor2array(1 / tgt_depth[0][0],
                                 max_value=None,
                                 colormap='magma'), epoch)
                output_writers[i].add_image(
                    'val Depth Output',
                    tensor2array(tgt_depth[0][0], max_value=10), epoch)

            poses, poses_inv = compute_pose_with_inv(pose_net, tgt_img,
                                                     ref_imgs, intrinsics)

            loss_1, loss_3 = compute_photo_and_geometry_loss(
                tgt_img, ref_imgs, intrinsics, tgt_depth, ref_depths, poses,
                poses_inv, args.num_scales, args.with_ssim, args.with_mask,
                False, args.padding_mode)
            loss_2 = compute_smooth_loss(tgt_depth, tgt_img, ref_depths,
                                         ref_imgs)

        loss_1 = loss_1.item()
        loss_2 = loss_2.item()
        loss_3 = loss_3.item()

        # the reported "total" validation loss is the photometric term only
        loss = loss_1
        losses.update([loss, loss_1, loss_2, loss_3])

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()
        logger.valid_bar.update(i + 1)
        if i % args.print_freq == 0:
            logger.valid_writer.write('valid: Time {} Loss {}'.format(
                batch_time, losses))
    logger.valid_bar.update(len(val_loader))
    return losses.avg, [
        'Total loss', 'Photo loss', 'Smooth loss', 'Consistency loss'
    ]
def train(args, train_loader, disp_net, pose_exp_net, optimizer, epoch_size,
          logger, tb_writer, n_iter, torch_device):
    """Train disp_net + pose_exp_net for one epoch on light-field batches.

    The total loss combines photometric reconstruction (w1), explainability
    (w2), smoothness (w3) and a pose-magnitude supervision term (w4).
    Returns the running average of the total loss.

    NOTE(review): ``n_iter`` is a plain int parameter incremented locally;
    the caller never sees the updated counter — confirm the caller tracks
    iterations itself or consider returning the new value.
    """
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter(precision=4)
    w1, w2, w3, w4 = args.photo_loss_weight, args.mask_loss_weight, args.smooth_loss_weight, args.gt_pose_loss_weight

    # switch to train mode
    disp_net.train()
    pose_exp_net.train()

    end = time.time()
    logger.train_bar.update(0)
    for i, (tgt_img, tgt_lf, ref_imgs, ref_lfs, intrinsics, intrinsics_inv,
            pose_gt) in enumerate(train_loader):
        log_losses = i > 0 and n_iter % args.print_freq == 0
        log_output = args.training_output_freq > 0 and n_iter % args.training_output_freq == 0

        # measure data loading time
        data_time.update(time.time() - end)
        tgt_img = tgt_img.to(torch_device)
        ref_imgs = [img.to(torch_device) for img in ref_imgs]
        tgt_lf = tgt_lf.to(torch_device)
        ref_lfs = [lf.to(torch_device) for lf in ref_lfs]
        intrinsics = intrinsics.to(torch_device)
        pose_gt = pose_gt.to(torch_device)

        # compute output: multi-scale disparities from the light field,
        # then inverted to depth at each scale
        disparities = disp_net(tgt_lf)
        depth = [1 / disp for disp in disparities]
        explainability_mask, pose = pose_exp_net(tgt_lf, ref_lfs)

        loss_1, warped, diff = photometric_reconstruction_loss(
            tgt_img, ref_imgs, intrinsics, depth, explainability_mask, pose,
            args.rotation_mode, args.padding_mode)
        if w2 > 0:
            loss_2 = explainability_loss(explainability_mask)
        else:
            loss_2 = 0
        loss_3 = smooth_loss(depth)

        # supervise only the translation magnitude of the predicted pose
        pred_pose_magnitude = pose[:, :, :3].norm(dim=2)
        pose_gt_magnitude = pose_gt[:, :, :3].norm(dim=2)
        pose_loss = (pred_pose_magnitude - pose_gt_magnitude).abs().mean()

        loss = w1 * loss_1 + w2 * loss_2 + w3 * loss_3 + w4 * pose_loss

        if log_losses:
            tb_writer.add_scalar('train/photometric_error', loss_1.item(),
                                 n_iter)
            tb_writer.add_scalar('train/smoothness_loss', loss_3.item(),
                                 n_iter)
            tb_writer.add_scalar('train/total_loss', loss.item(), n_iter)
            tb_writer.add_scalar('train/pose_loss', pose_loss.item(), n_iter)
            if w2 > 0:
                tb_writer.add_scalar('train/explanability_loss',
                                     loss_2.item(), n_iter)

        if log_output:
            tb_writer.add_image('train/Input', tensor2array(tgt_img[0]),
                                n_iter)
            # deliberate immediate break: only the first (finest) scale of
            # the zipped maps is logged
            for k, scaled_maps in enumerate(
                    zip(depth, disparities, warped, diff,
                        explainability_mask)):
                log_output_tensorboard(tb_writer, "train", 0, k, n_iter,
                                       *scaled_maps)
                break

        # record loss and EPE
        losses.update(loss.item(), args.batch_size)

        # compute gradient and do Adam step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        # append per-iteration loss breakdown to the CSV log
        with open(args.save_path + "/" + args.log_full, 'a') as csvfile:
            writer = csv.writer(csvfile, delimiter='\t')
            writer.writerow([
                loss.item(),
                loss_1.item(),
                loss_2.item() if w2 > 0 else 0,
                loss_3.item()
            ])
        logger.train_bar.update(i + 1)
        if i % args.print_freq == 0:
            logger.train_writer.write('Train: Time {} Data {} Loss {}'.format(
                batch_time, data_time, losses))
        if i >= epoch_size - 1:
            break
        n_iter += 1
    return losses.avg[0]
def validate_without_gt(args, val_loader, disp_net, pose_exp_net, epoch,
                        logger, output_writers=[]):
    """Validate disp_net + pose_exp_net without ground truth.

    Uses the weighted photometric/explainability/smoothness loss as the
    validation metric, logs sample images every 100 batches and histograms
    of predicted poses/disparities at the end.  Returns
    ``(losses.avg, loss_names)``.

    NOTE: this function still uses the legacy ``Variable(..., volatile=True)``
    / ``.data[0]`` autograd API, unlike its siblings in this file.
    """
    batch_time = AverageMeter()
    losses = AverageMeter(i=3, precision=4)
    log_outputs = len(output_writers) > 0
    w1, w2, w3 = args.photo_loss_weight, args.mask_loss_weight, args.smooth_loss_weight
    # pre-allocate buffers for pose/disparity statistics (last batch dropped)
    poses = np.zeros(
        ((len(val_loader) - 1) * args.batch_size * (args.sequence_length - 1),
         6))
    disp_values = np.zeros(((len(val_loader) - 1) * args.batch_size * 3))

    # switch to evaluate mode
    disp_net.eval()
    pose_exp_net.eval()

    end = time.time()
    logger.valid_bar.update(0)
    for i, (tgt_img, ref_imgs, intrinsics,
            intrinsics_inv) in enumerate(val_loader):
        tgt_img_var = Variable(tgt_img.cuda(), volatile=True)
        ref_imgs_var = [
            Variable(img.cuda(), volatile=True) for img in ref_imgs
        ]
        intrinsics_var = Variable(intrinsics.cuda(), volatile=True)
        intrinsics_inv_var = Variable(intrinsics_inv.cuda(), volatile=True)

        # compute output
        disp = disp_net(tgt_img_var)
        depth = 1 / disp
        explainability_mask, pose = pose_exp_net(tgt_img_var, ref_imgs_var)

        loss_1 = photometric_reconstruction_loss(tgt_img_var, ref_imgs_var,
                                                 intrinsics_var,
                                                 intrinsics_inv_var, depth,
                                                 explainability_mask, pose,
                                                 args.rotation_mode,
                                                 args.padding_mode)
        loss_1 = loss_1.data[0]
        if w2 > 0:
            loss_2 = explainability_loss(explainability_mask).data[0]
        else:
            loss_2 = 0
        loss_3 = smooth_loss(disp).data[0]

        if log_outputs and i % 100 == 0 and i / 100 < len(
                output_writers):  # log first output of every 100 batch
            index = int(i // 100)
            if epoch == 0:
                for j, ref in enumerate(ref_imgs):
                    output_writers[index].add_image('val Input {}'.format(j),
                                                    tensor2array(tgt_img[0]),
                                                    0)
                    output_writers[index].add_image('val Input {}'.format(j),
                                                    tensor2array(ref[0]), 1)
            output_writers[index].add_image(
                'val Dispnet Output Normalized',
                tensor2array(disp.data[0].cpu(),
                             max_value=None,
                             colormap='bone'), epoch)
            output_writers[index].add_image(
                'val Depth Output',
                tensor2array(1. / disp.data[0].cpu(), max_value=10), epoch)
            # log warped images along with explainability mask
            for j, ref in enumerate(ref_imgs_var):
                ref_warped = inverse_warp(ref[:1],
                                          depth[:1, 0],
                                          pose[:1, j],
                                          intrinsics_var[:1],
                                          intrinsics_inv_var[:1],
                                          rotation_mode=args.rotation_mode,
                                          padding_mode=args.padding_mode)[0]
                output_writers[index].add_image(
                    'val Warped Outputs {}'.format(j),
                    tensor2array(ref_warped.data.cpu()), epoch)
                output_writers[index].add_image(
                    'val Diff Outputs {}'.format(j),
                    tensor2array(
                        0.5 * (tgt_img_var[0] - ref_warped).abs().data.cpu()),
                    epoch)
                if explainability_mask is not None:
                    output_writers[index].add_image(
                        'val Exp mask Outputs {}'.format(j),
                        tensor2array(explainability_mask[0, j].data.cpu(),
                                     max_value=1,
                                     colormap='bone'), epoch)

        if log_outputs and i < len(val_loader) - 1:
            step = args.batch_size * (args.sequence_length - 1)
            poses[i * step:(i + 1) *
                  step] = pose.data.cpu().view(-1, 6).numpy()
            step = args.batch_size * 3
            disp_unraveled = disp.data.cpu().view(args.batch_size, -1)
            disp_values[i * step:(i + 1) * step] = torch.cat([
                disp_unraveled.min(-1)[0],
                disp_unraveled.median(-1)[0],
                disp_unraveled.max(-1)[0]
            ]).numpy()

        loss = w1 * loss_1 + w2 * loss_2 + w3 * loss_3
        losses.update([loss, loss_1, loss_2])

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()
        logger.valid_bar.update(i + 1)
        if i % args.print_freq == 0:
            logger.valid_writer.write('valid: Time {} Loss {}'.format(
                batch_time, losses))

    if log_outputs:
        prefix = 'valid poses'
        coeffs_names = ['tx', 'ty', 'tz']
        if args.rotation_mode == 'euler':
            coeffs_names.extend(['rx', 'ry', 'rz'])
        elif args.rotation_mode == 'quat':
            coeffs_names.extend(['qx', 'qy', 'qz'])
        # fix: histograms must go to a writer, not the list of writers
        # (the original called add_histogram on the list -> AttributeError);
        # also renamed the loop index so it no longer shadows the batch `i`
        for k in range(poses.shape[1]):
            output_writers[0].add_histogram(
                '{} {}'.format(prefix, coeffs_names[k]), poses[:, k], epoch)
        output_writers[0].add_histogram('disp_values', disp_values, epoch)
    logger.valid_bar.update(len(val_loader))
    return losses.avg, ['Total loss', 'Photo loss', 'Exp loss']
def main():
    """Evaluate a pretrained disp/pose/mask/flow ensemble on KITTI flow.

    Loads the four pretrained networks, decomposes the predicted scene flow
    into rigid (camera-motion) and non-rigid components via rigidity masks,
    accumulates EPE metrics and optionally dumps images/masks/visualizations
    to ``args.output_dir``.
    """
    global args
    args = parser.parse_args()
    args.pretrained_disp = Path(args.pretrained_disp)
    args.pretrained_pose = Path(args.pretrained_pose)
    args.pretrained_mask = Path(args.pretrained_mask)
    args.pretrained_flow = Path(args.pretrained_flow)

    if args.output_dir is not None:
        args.output_dir = Path(args.output_dir)
        args.output_dir.makedirs_p()
        image_dir = args.output_dir / 'images'
        gt_dir = args.output_dir / 'gt'
        mask_dir = args.output_dir / 'mask'
        viz_dir = args.output_dir / 'viz'
        image_dir.makedirs_p()
        gt_dir.makedirs_p()
        mask_dir.makedirs_p()
        viz_dir.makedirs_p()
        output_writer = SummaryWriter(args.output_dir)

    normalize = custom_transforms.Normalize(mean=[0.5, 0.5, 0.5],
                                            std=[0.5, 0.5, 0.5])
    flow_loader_h, flow_loader_w = 256, 832
    valid_flow_transform = custom_transforms.Compose([
        custom_transforms.Scale(h=flow_loader_h, w=flow_loader_w),
        custom_transforms.ArrayToTensor(), normalize
    ])
    # NOTE(review): only "kitti2015" defines val_flow_set here — any other
    # dataset value raises NameError below; confirm other branches exist
    # upstream or were removed intentionally.
    if args.dataset == "kitti2015":
        val_flow_set = ValidationFlow(root=args.kitti_dir,
                                      sequence_length=5,
                                      transform=valid_flow_transform)
    val_loader = torch.utils.data.DataLoader(val_flow_set,
                                             batch_size=1,
                                             shuffle=False,
                                             num_workers=2,
                                             pin_memory=True,
                                             drop_last=True)

    # instantiate networks by name and load pretrained weights
    disp_net = getattr(models, args.dispnet)().cuda()
    pose_net = getattr(models, args.posenet)(nb_ref_imgs=4).cuda()
    mask_net = getattr(models, args.masknet)(nb_ref_imgs=4).cuda()
    flow_net = getattr(models, args.flownet)(nlevels=args.nlevels).cuda()

    dispnet_weights = torch.load(args.pretrained_disp)
    posenet_weights = torch.load(args.pretrained_pose)
    masknet_weights = torch.load(args.pretrained_mask)
    flownet_weights = torch.load(args.pretrained_flow)
    disp_net.load_state_dict(dispnet_weights['state_dict'])
    pose_net.load_state_dict(posenet_weights['state_dict'])
    flow_net.load_state_dict(flownet_weights['state_dict'])
    mask_net.load_state_dict(masknet_weights['state_dict'])

    disp_net.eval()
    pose_net.eval()
    mask_net.eval()
    flow_net.eval()

    error_names = [
        'epe_total', 'epe_sp', 'epe_mv', 'Fl', 'epe_total_gt_mask',
        'epe_sp_gt_mask', 'epe_mv_gt_mask', 'Fl_gt_mask'
    ]
    errors = AverageMeter(i=len(error_names))

    for i, (tgt_img, ref_imgs, intrinsics, intrinsics_inv, flow_gt,
            obj_map_gt) in enumerate(tqdm(val_loader)):
        # legacy volatile Variables: inference only, no autograd graphs
        tgt_img_var = Variable(tgt_img.cuda(), volatile=True)
        ref_imgs_var = [
            Variable(img.cuda(), volatile=True) for img in ref_imgs
        ]
        intrinsics_var = Variable(intrinsics.cuda(), volatile=True)
        intrinsics_inv_var = Variable(intrinsics_inv.cuda(), volatile=True)
        flow_gt_var = Variable(flow_gt.cuda(), volatile=True)
        obj_map_gt_var = Variable(obj_map_gt.cuda(), volatile=True)

        disp = disp_net(tgt_img_var)
        depth = 1 / disp
        pose = pose_net(tgt_img_var, ref_imgs_var)
        explainability_mask = mask_net(tgt_img_var, ref_imgs_var)

        if args.flownet == 'Back2Future':
            flow_fwd, flow_bwd, _ = flow_net(tgt_img_var, ref_imgs_var[1:3])
        else:
            flow_fwd = flow_net(tgt_img_var, ref_imgs_var[2])

        # rigid flow induced purely by camera motion (fwd and bwd neighbors)
        flow_cam = pose2flow(depth.squeeze(1), pose[:, 2], intrinsics_var,
                             intrinsics_inv_var)
        flow_cam_bwd = pose2flow(depth.squeeze(1), pose[:, 1], intrinsics_var,
                                 intrinsics_inv_var)

        # NOTE(review): .unsqueeze(1) binds only to the second factor here —
        # confirm the intended broadcasting of the two explainability slices
        rigidity_mask = 1 - (1 - explainability_mask[:, 1]) * (
            1 - explainability_mask[:, 2]).unsqueeze(1) > 0.5

        # census-style rigidity: pixels where rigid and predicted flow agree
        rigidity_mask_census_soft = (flow_cam - flow_fwd).abs()  #.normalize()
        rigidity_mask_census_u = rigidity_mask_census_soft[:, 0] < args.THRESH
        rigidity_mask_census_v = rigidity_mask_census_soft[:, 1] < args.THRESH
        rigidity_mask_census = (rigidity_mask_census_u).type_as(flow_fwd) * (
            rigidity_mask_census_v).type_as(flow_fwd)

        # combined mask: rigid if either source says rigid (soft OR)
        rigidity_mask_combined = 1 - (
            1 - rigidity_mask.type_as(explainability_mask)) * (
                1 - rigidity_mask_census.type_as(explainability_mask))

        obj_map_gt_var_expanded = obj_map_gt_var.unsqueeze(1).type_as(
            flow_fwd)

        # compose final flow: rigid regions take camera flow, the rest the
        # network's flow prediction
        flow_fwd_non_rigid = (rigidity_mask_combined <= args.THRESH).type_as(
            flow_fwd).expand_as(flow_fwd) * flow_fwd
        flow_fwd_rigid = (rigidity_mask_combined > args.THRESH
                          ).type_as(flow_cam).expand_as(flow_cam) * flow_cam
        total_flow = flow_fwd_rigid + flow_fwd_non_rigid
        rigidity_mask = rigidity_mask.type_as(flow_fwd)

        # EPEs with the estimated mask, then with the ground-truth object map
        _epe_errors = compute_all_epes(
            flow_gt_var, flow_cam, flow_fwd,
            rigidity_mask_combined) + compute_all_epes(
                flow_gt_var, flow_cam, flow_fwd,
                (1 - obj_map_gt_var_expanded))
        errors.update(_epe_errors)

        tgt_img_np = tgt_img[0].numpy()
        rigidity_mask_combined_np = rigidity_mask_combined.cpu().data[0].numpy(
        )
        gt_mask_np = obj_map_gt[0].numpy()

        if args.output_dir is not None:
            np.save(image_dir / str(i).zfill(3), tgt_img_np)
            np.save(gt_dir / str(i).zfill(3), gt_mask_np)
            np.save(mask_dir / str(i).zfill(3), rigidity_mask_combined_np)

        if (args.output_dir is not None) and i % 10 == 0:
            ind = int(i // 10)
            output_writer.add_image(
                'val Dispnet Output Normalized',
                tensor2array(disp.data[0].cpu(),
                             max_value=None,
                             colormap='bone'), ind)
            output_writer.add_image('val Input',
                                    tensor2array(tgt_img[0].cpu()), i)
            output_writer.add_image(
                'val Total Flow Output',
                flow_to_image(tensor2array(total_flow.data[0].cpu())), ind)
            output_writer.add_image(
                'val Rigid Flow Output',
                flow_to_image(tensor2array(flow_fwd_rigid.data[0].cpu())),
                ind)
            output_writer.add_image(
                'val Non-rigid Flow Output',
                flow_to_image(tensor2array(flow_fwd_non_rigid.data[0].cpu())),
                ind)
            output_writer.add_image(
                'val Rigidity Mask',
                tensor2array(rigidity_mask.data[0].cpu(),
                             max_value=1,
                             colormap='bone'), ind)
            output_writer.add_image(
                'val Rigidity Mask Census',
                tensor2array(rigidity_mask_census.data[0].cpu(),
                             max_value=1,
                             colormap='bone'), ind)
            output_writer.add_image(
                'val Rigidity Mask Combined',
                tensor2array(rigidity_mask_combined.data[0].cpu(),
                             max_value=1,
                             colormap='bone'), ind)

            # side-by-side visualization rows saved as PNG
            tgt_img_viz = tensor2array(tgt_img[0].cpu())
            depth_viz = tensor2array(disp.data[0].cpu(),
                                     max_value=None,
                                     colormap='bone')
            mask_viz = tensor2array(
                rigidity_mask_census_soft.data[0].prod(dim=0).cpu(),
                max_value=1,
                colormap='bone')
            rigid_flow_viz = flow_to_image(tensor2array(
                flow_cam.data[0].cpu()))
            non_rigid_flow_viz = flow_to_image(
                tensor2array(flow_fwd_non_rigid.data[0].cpu()))
            total_flow_viz = flow_to_image(
                tensor2array(total_flow.data[0].cpu()))
            row1_viz = np.hstack((tgt_img_viz, depth_viz, mask_viz))
            row2_viz = np.hstack(
                (rigid_flow_viz, non_rigid_flow_viz, total_flow_viz))

            # NOTE(review): row1 is scaled by 255 but row2 is not —
            # presumably flow_to_image already returns 0-255; confirm.
            row1_viz_im = Image.fromarray((255 * row1_viz).astype('uint8'))
            row2_viz_im = Image.fromarray((row2_viz).astype('uint8'))
            row1_viz_im.save(viz_dir / str(i).zfill(3) + '01.png')
            row2_viz_im.save(viz_dir / str(i).zfill(3) + '02.png')

    print("Results")
    print("\t {:>10}, {:>10}, {:>10}, {:>6}, {:>10}, {:>10}, {:>10}, {:>10} ".
          format(*error_names))
    print(
        "Errors \t {:10.4f}, {:10.4f}, {:10.4f}, {:10.4f}, {:10.4f}, {:10.4f}, {:10.4f}, {:10.4f}"
        .format(*errors.avg))
def train(train_loader,
          alice_net,
          bob_net,
          mod_net,
          optimizer,
          epoch_size,
          logger=None,
          train_writer=None,
          mode='compete'):
    """Train the Alice/Bob/moderator triplet for one epoch.

    In ``'compete'`` mode the moderator's (detached) sigmoid output gates
    the per-sample cross-entropy of Alice vs Bob; in ``'collaborate'`` mode
    only the moderator is trained against the detached per-sample losses.
    Returns the running average of the total loss.

    Raises:
        ValueError: if ``mode`` is neither 'compete' nor 'collaborate'.
    """
    global args, n_iter
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter(precision=4)

    # switch to train mode
    alice_net.train()
    bob_net.train()
    mod_net.train()

    end = time.time()
    for i, (img, target) in enumerate(train_loader):
        # measure data loading time
        #mode = 'compete' if (i%2)==0 else 'collaborate'
        data_time.update(time.time() - end)
        img_var = Variable(img.cuda())
        target_var = Variable(target.cuda())

        pred_alice = alice_net(img_var)
        pred_bob = bob_net(img_var)
        pred_mod = mod_net(img_var)

        # per-sample losses (reduce=False keeps the batch dimension)
        loss_alice = F.cross_entropy(pred_alice, target_var, reduce=False)
        loss_bob = F.cross_entropy(pred_bob, target_var, reduce=False)

        if mode == 'compete':
            if args.fix_bob:
                if args.DEBUG:
                    print("Training Alice Only")
                loss = loss_alice.mean()
            elif args.fix_alice:
                loss = loss_bob.mean()
            else:
                if args.DEBUG:
                    print("Training Both Alice and Bob")
                # moderator output is detached: it only weights, not learns
                pred_mod_soft = Variable(F.sigmoid(pred_mod).data,
                                         requires_grad=False)
                loss = pred_mod_soft * loss_alice + (
                    1 - pred_mod_soft) * loss_bob
                loss = loss.mean()
        elif mode == 'collaborate':
            # here the players' losses are detached: only the moderator learns
            loss_alice2 = Variable(loss_alice.data, requires_grad=False)
            loss_bob2 = Variable(loss_bob.data, requires_grad=False)
            loss1 = F.sigmoid(pred_mod) * loss_alice2 + (
                1 - F.sigmoid(pred_mod)) * loss_bob2
            loss2 = collaboration_loss(pred_mod, loss_alice2, loss_bob2)
            loss = loss1.mean() + loss2.mean(
            ) + args.wr * mod_regularization_loss(pred_mod)
        else:
            # fix: an unknown mode previously fell through and crashed later
            # with a confusing NameError on `loss`; fail fast instead
            raise ValueError("unknown training mode: {!r}".format(mode))

        if i > 0 and n_iter % args.print_freq == 0:
            train_writer.add_scalar('loss_alice',
                                    loss_alice.mean().item(), n_iter)
            train_writer.add_scalar('loss_bob',
                                    loss_bob.mean().item(), n_iter)
            train_writer.add_scalar('mod_mean',
                                    F.sigmoid(pred_mod).mean().item(), n_iter)
            train_writer.add_scalar('mod_var',
                                    F.sigmoid(pred_mod).var().item(), n_iter)
            train_writer.add_scalar('loss_regularization',
                                    mod_regularization_loss(pred_mod).item(),
                                    n_iter)
            if mode == 'compete':
                train_writer.add_scalar('competetion_loss', loss.item(),
                                        n_iter)
            elif mode == 'collaborate':
                train_writer.add_scalar('collaboration_loss', loss.item(),
                                        n_iter)

        # record loss
        losses.update(loss.item(), args.batch_size)

        # compute gradient and do Adam step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        with open(args.save_path / args.log_full, 'a') as csvfile:
            writer = csv.writer(csvfile, delimiter='\t')
            writer.writerow([
                loss.item(),
                loss_alice.mean().item(),
                loss_bob.mean().item()
            ])
        if args.log_terminal:
            logger.train_bar.update(i + 1)
            if i % args.print_freq == 0:
                logger.train_writer.write(
                    'Train: Time {} Data {} Loss {}'.format(
                        batch_time, data_time, losses))
        if i >= epoch_size - 1:
            break
        n_iter += 1
    return losses.avg[0]
def validate_with_gt(args, val_loader, mvdnet, epoch, output_writers=[]):
    """Evaluate mvdnet against ground-truth depth and normals.

    Computes standard depth metrics plus a mean normal-angle error; for the
    'sceneflow' dataset the comparison is done in disparity space with a
    focal-length-derived scale factor.  Optionally dumps predictions to
    ``args.output_dir``.  Returns ``(errors.avg, error_names)``.
    """
    batch_time = AverageMeter()
    error_names = [
        'abs_rel', 'abs_diff', 'sq_rel', 'a1', 'a2', 'a3', 'mean_angle'
    ]
    test_error_names = [
        'abs_rel', 'abs_diff', 'sq_rel', 'rms', 'log_rms', 'a1', 'a2', 'a3',
        'mean_angle'
    ]
    errors = AverageMeter(i=len(error_names))
    test_errors = AverageMeter(i=len(test_error_names))
    log_outputs = len(output_writers) > 0
    output_dir = Path(args.output_dir)
    if not os.path.isdir(output_dir):
        os.mkdir(output_dir)

    # switch to evaluate mode
    mvdnet.eval()

    end = time.time()
    with torch.no_grad():
        for i, (tgt_img, ref_imgs, gt_nmap, ref_poses, intrinsics,
                intrinsics_inv, tgt_depth) in enumerate(val_loader):
            tgt_img_var = Variable(tgt_img.cuda())
            ref_imgs_var = [Variable(img.cuda()) for img in ref_imgs]
            gt_nmap_var = Variable(gt_nmap.cuda())
            ref_poses_var = [Variable(pose.cuda()) for pose in ref_poses]
            intrinsics_var = Variable(intrinsics.cuda())
            intrinsics_inv_var = Variable(intrinsics_inv.cuda())
            tgt_depth_var = Variable(tgt_depth.cuda())

            pose = torch.cat(ref_poses_var, 1)
            # skip samples whose poses contain NaN (x != x only for NaN)
            if (pose != pose).any():
                continue

            if args.dataset == 'sceneflow':
                # disparity/depth scale derived from the focal length
                factor = (1.0 / args.scale) * intrinsics_var[:, 0, 0] / 1050.0
                factor = factor.view(-1, 1, 1)
            else:
                factor = torch.ones(
                    (tgt_depth_var.size(0), 1, 1)).type_as(tgt_depth_var)
            # get mask: depth in valid range and finite
            mask = (tgt_depth_var <= args.nlabel * args.mindepth * factor *
                    3) & (tgt_depth_var >= args.mindepth * factor) & (
                        tgt_depth_var == tgt_depth_var)
            if not mask.any():
                continue

            output_depth, nmap = mvdnet(tgt_img_var,
                                        ref_imgs_var,
                                        pose,
                                        intrinsics_var,
                                        intrinsics_inv_var,
                                        factor=factor.unsqueeze(1))
            output_disp = args.nlabel * args.mindepth / (output_depth)
            if args.dataset == 'sceneflow':
                output_disp = (args.nlabel * args.mindepth) * 3 / (
                    output_depth)
                output_depth = (args.nlabel * 3) * (args.mindepth *
                                                    factor) / output_disp
                tgt_disp_var = ((1.0 / args.scale) *
                                intrinsics_var[:, 0, 0].view(-1, 1, 1) /
                                tgt_depth_var)
            if args.dataset == 'sceneflow':
                # sceneflow is evaluated in disparity space
                output = torch.squeeze(output_disp.data.cpu(), 1)
                errors_ = compute_errors_train(tgt_disp_var.cpu(), output,
                                               mask)
                test_errors_ = list(
                    compute_errors_test(tgt_disp_var.cpu()[mask],
                                        output[mask]))
            else:
                output = torch.squeeze(output_depth.data.cpu(), 1)
                errors_ = compute_errors_train(tgt_depth, output, mask)
                test_errors_ = list(
                    compute_errors_test(tgt_depth[mask], output[mask]))

            # normal-angle error over pixels whose GT normal is non-zero
            n_mask = (gt_nmap_var.permute(0, 2, 3, 1)[0, :, :] != 0)
            n_mask = n_mask[:, :, 0] | n_mask[:, :, 1] | n_mask[:, :, 2]
            total_angles_m = compute_angles(
                gt_nmap_var.permute(0, 2, 3, 1)[0], nmap[0])
            mask_angles = total_angles_m[n_mask]
            total_angles_m[~n_mask] = 0
            # mean angle error appended as the last entry of both metric lists
            errors_.append(torch.mean(mask_angles).item())
            test_errors_.append(torch.mean(mask_angles).item())

            errors.update(errors_)
            test_errors.update(test_errors_)

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()

            if args.output_print:
                np.save(output_dir / '{:04d}{}'.format(i, '_depth.npy'),
                        output.numpy()[0])
                plt.imsave(output_dir / '{:04d}_gt{}'.format(i, '.png'),
                           tgt_depth.numpy()[0],
                           cmap='rainbow')
                imsave(output_dir / '{:04d}_aimage{}'.format(i, '.png'),
                       np.transpose(tgt_img.numpy()[0], (1, 2, 0)))
                np.save(output_dir / '{:04d}_cam{}'.format(i, '.npy'),
                        intrinsics_var.cpu().numpy()[0])
                np.save(output_dir / '{:04d}{}'.format(i, '_normal.npy'),
                        nmap.cpu().numpy()[0])
            if i % args.print_freq == 0:
                print(
                    'valid: Time {} Abs Error {:.4f} ({:.4f}) Abs angle Error {:.4f} ({:.4f}) Iter {}/{}'
                    .format(batch_time, test_errors.val[0],
                            test_errors.avg[0], test_errors.val[-1],
                            test_errors.avg[-1], i, len(val_loader)))

    # NOTE(review): both files receive the same test_errors.avg payload —
    # presumably angle_errors.csv was meant to hold only the angle metrics;
    # also `/` binds before `+`, so the filename is concatenated onto the
    # `output_dir / args.ttype` path. Confirm both are intentional.
    if args.output_print:
        np.savetxt(output_dir / args.ttype + 'errors.csv',
                   test_errors.avg,
                   fmt='%1.4f',
                   delimiter=',')
        np.savetxt(output_dir / args.ttype + 'angle_errors.csv',
                   test_errors.avg,
                   fmt='%1.4f',
                   delimiter=',')
    return errors.avg, error_names
def train(train_loader, model, optimizer, epoch, args, log, mp=None):
    '''train given model and dataloader

    Supports three regimes selected by ``args.train``:
      * 'vanilla'       — plain training on clean images;
      * 'mixup'         — input mixup; with ``args.graph`` a saliency map
                          (and optionally adversarial noise) is computed
                          first and passed to the model's mixup;
      * 'mixup_hidden'  — manifold mixup inside the network.
    Returns (top1.avg, top5.avg, losses.avg).
    '''
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()
    mixing_avg = []

    # switch to train mode
    model.train()

    end = time.time()
    for i, (input, target) in enumerate(train_loader):
        data_time.update(time.time() - end)
        optimizer.zero_grad()

        input = input.cuda()
        target = target.long().cuda()

        unary = None
        noise = None
        adv_mask1 = 0
        adv_mask2 = 0

        # train with clean images
        if args.train == 'vanilla':
            input_var, target_var = Variable(input), Variable(target)
            output, reweighted_target = model(input_var, target_var)
            loss = bce_loss(softmax(output), reweighted_target)

        # train with mixup images
        elif args.train == 'mixup':
            # process for Puzzle Mix: compute a saliency map first
            if args.graph:
                # whether to add adversarial noise or not
                if args.adv_p > 0:
                    adv_mask1 = np.random.binomial(n=1, p=args.adv_p)
                    adv_mask2 = np.random.binomial(n=1, p=args.adv_p)
                else:
                    adv_mask1 = 0
                    adv_mask2 = 0

                # random start: uniform noise in the *unnormalized* image
                # space, clamped to [0, 1], then renormalized
                if (adv_mask1 == 1 or adv_mask2 == 1):
                    noise = torch.zeros_like(input).uniform_(
                        -args.adv_eps / 255., args.adv_eps / 255.)
                    input_orig = input * args.std + args.mean
                    input_noise = input_orig + noise
                    input_noise = torch.clamp(input_noise, 0, 1)
                    noise = input_noise - input_orig
                    input_noise = (input_noise - args.mean) / args.std
                    input_var = Variable(input_noise, requires_grad=True)
                else:
                    input_var = Variable(input, requires_grad=True)
                target_var = Variable(target)

                # calculate saliency (unary): gradient magnitude of the
                # clean loss w.r.t. the input
                if args.clean_lam == 0:
                    model.eval()
                    output = model(input_var)
                    loss_batch = criterion_batch(output, target_var)
                else:
                    model.train()
                    output = model(input_var)
                    loss_batch = 2 * args.clean_lam * criterion_batch(
                        output, target_var) / args.num_classes
                loss_batch_mean = torch.mean(loss_batch, dim=0)
                # retain_graph so the clean-loss graph survives for the
                # later mixup backward when clean_lam > 0
                loss_batch_mean.backward(retain_graph=True)
                unary = torch.sqrt(torch.mean(input_var.grad**2, dim=1))

                # calculate adversarial noise (FGSM-style step, reclamped)
                if (adv_mask1 == 1 or adv_mask2 == 1):
                    noise += (args.adv_eps + 2) / 255. * input_var.grad.sign()
                    noise = torch.clamp(noise, -args.adv_eps / 255.,
                                        args.adv_eps / 255.)
                    adv_mix_coef = np.random.uniform(0, 1)
                    noise = adv_mix_coef * noise

                # saliency pass was eval-mode only; reset for real training
                if args.clean_lam == 0:
                    model.train()
                    optimizer.zero_grad()

            # perform mixup and calculate loss
            input_var, target_var = Variable(input), Variable(target)
            output, reweighted_target = model(input_var,
                                              target_var,
                                              mixup=True,
                                              args=args,
                                              grad=unary,
                                              noise=noise,
                                              adv_mask1=adv_mask1,
                                              adv_mask2=adv_mask2,
                                              mp=mp)
            loss = bce_loss(softmax(output), reweighted_target)

        # for manifold mixup
        elif args.train == 'mixup_hidden':
            input_var, target_var = Variable(input), Variable(target)
            output, reweighted_target = model(input_var,
                                              target_var,
                                              mixup_hidden=True,
                                              args=args)
            loss = bce_loss(softmax(output), reweighted_target)
        else:
            raise AssertionError('wrong train type!!')

        # measure accuracy and record loss
        prec1, prec5 = accuracy(output, target, topk=(1, 5))
        losses.update(loss.item(), input.size(0))
        top1.update(prec1.item(), input.size(0))
        top5.update(prec5.item(), input.size(0))

        # compute gradient and do SGD step
        loss.backward()
        optimizer.step()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

    print_log(
        ' **Train** Prec@1 {top1.avg:.3f} Prec@5 {top5.avg:.3f} Error@1 {error1:.3f}'
        .format(top1=top1, top5=top5, error1=100 - top1.avg), log)
    return top1.avg, top5.avg, losses.avg
def validate(val_loader, model, log, fgsm=False, eps=4, rand_init=False, mean=None, std=None):
    '''evaluate trained model

    When ``fgsm`` is set, each batch is first perturbed with a one-step
    FGSM attack of strength ``eps``/255 (applied in the unnormalized image
    space using ``mean``/``std``) before being evaluated.
    Returns (top1.avg, losses.avg).
    '''
    losses = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()

    # switch to evaluate mode
    model.eval()

    for images, labels in val_loader:
        if args.use_cuda:
            images = images.cuda()
            labels = labels.cuda()

        # optionally craft an FGSM adversarial example for this batch
        if fgsm:
            adv_images = Variable(images, requires_grad=True)
            adv_labels = Variable(labels)
            grad_opt = torch.optim.SGD([adv_images], lr=0.1)
            adv_out = model(adv_images)
            adv_loss = criterion(adv_out, adv_labels)
            grad_opt.zero_grad()
            adv_loss.backward()
            step = adv_images.grad.sign()
            # perturb in the unnormalized space, clamp, renormalize
            images = images * std + mean + eps / 255. * step
            images = torch.clamp(images, 0, 1)
            images = (images - mean) / std

        # compute output on the (possibly perturbed) batch
        with torch.no_grad():
            batch_in = Variable(images)
            batch_labels = Variable(labels)
            output = model(batch_in)
            loss = criterion(output, batch_labels)

        # measure accuracy and record loss
        acc1, acc5 = accuracy(output.data, labels, topk=(1, 5))
        n = images.size(0)
        losses.update(loss.item(), n)
        top1.update(acc1.item(), n)
        top5.update(acc5.item(), n)

    if fgsm:
        print_log(
            'Attack (eps : {}) Prec@1 {top1.avg:.2f}'.format(eps, top1=top1),
            log)
    else:
        print_log(
            ' **Test** Prec@1 {top1.avg:.3f} Prec@5 {top5.avg:.3f} Error@1 {error1:.3f} Loss: {losses.avg:.3f} '
            .format(top1=top1,
                    top5=top5,
                    error1=100 - top1.avg,
                    losses=losses), log)
    return top1.avg, losses.avg
def validate_with_gt(args,
                     val_loader,
                     disp_net,
                     epoch,
                     logger,
                     output_writers=[]):
    '''Validate ``disp_net`` against ground-truth depth.

    Predicted depth is taken as the reciprocal of the first disparity
    channel.  Returns ``(errors.avg, error_names)``.
    NOTE(review): the mutable default ``output_writers=[]`` is shared across
    calls; safe here only because it is never mutated.
    '''
    global device
    batch_time = AverageMeter()
    error_names = ['abs_diff', 'abs_rel', 'sq_rel', 'a1', 'a2', 'a3']
    errors = AverageMeter(i=len(error_names))
    log_outputs = len(output_writers) > 0

    # switch to evaluate mode
    disp_net.eval()

    end = time.time()
    logger.valid_bar.update(0)
    for i, (tgt_img, depth) in enumerate(val_loader):
        tgt_img = tgt_img.to(device)
        depth = depth.to(device)

        # compute output
        output_disp = disp_net(tgt_img)
        output_depth = 1 / output_disp[:, 0]

        if log_outputs and i < len(output_writers):
            # inputs and ground truth only need to be logged once (epoch 0)
            if epoch == 0:
                output_writers[i].add_image('val Input',
                                            tensor2array(tgt_img[0]), 0)
                depth_to_show = depth[0]
                output_writers[i].add_image(
                    'val target Depth',
                    tensor2array(depth_to_show, max_value=10), epoch)
                # invalid (zero) depths are pushed far away so the disparity
                # visualization is not dominated by 1/0
                depth_to_show[depth_to_show == 0] = 1000
                disp_to_show = (1 / depth_to_show).clamp(0, 10)
                output_writers[i].add_image(
                    'val target Disparity Normalized',
                    tensor2array(disp_to_show, max_value=None,
                                 colormap='bone'), epoch)

            output_writers[i].add_image(
                'val Dispnet Output Normalized',
                tensor2array(output_disp[0], max_value=None, colormap='bone'),
                epoch)
            output_writers[i].add_image(
                'val Depth Output',
                tensor2array(output_depth[0], max_value=3), epoch)

        # debug for the errors: optional median-scaling of the prediction
        # before computing errors (disabled)
        #**************************************
        # scale_factor = torch.div(torch.median(depth), torch.median(output_depth))
        # #scale_factor = np.median(depth)/np.median(output_depth)
        # #sl_tensor=torch.tensor(scale_factor)
        # #print()
        # errors.update(compute_errors(depth, output_depth*scale_factor))
        #**************************************
        # original (unscaled) error computation
        errors.update(compute_errors(depth, output_depth))

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()
        logger.valid_bar.update(i + 1)
        if i % args.print_freq == 0:
            logger.valid_writer.write(
                'valid: Time {} Abs Error {:.4f} ({:.4f})'.format(
                    batch_time, errors.val[0], errors.avg[0]))
    logger.valid_bar.update(len(val_loader))
    # debug
    # print(errors.avg)
    # print(error_names)
    return errors.avg, error_names
def train(train_loader, distilled_model, epoch, args):
    '''Run one epoch of knowledge-distillation training for dehazing.

    All optimization happens inside ``distilled_model.backward(gt, hazy,
    args)``, which is expected to return a dict of loss tensors (teacher/
    student reconstruction, perceptual, dehazing, PSNR, SSIM).  Returns a
    dict mapping the same keys to their ``AverageMeter`` trackers.
    '''
    distilled_model.train()

    batch_time = AverageMeter()
    loss_teacher_rec = AverageMeter()
    loss_student_rec = AverageMeter()
    loss_student_perceptual = AverageMeter()
    loss_dehazing_network = AverageMeter()
    loss_psnr = AverageMeter()
    loss_ssim = AverageMeter()

    # Start counting time
    time_start = time.time()

    for i, item in enumerate(tqdm(train_loader)):
        gt, hazy = item["gt"], item["hazy"]
        if torch.cuda.is_available():
            gt, hazy = gt.cuda(), hazy.cuda()

        # forward + backward + optimizer step are all done by the model
        loss = distilled_model.backward(gt, hazy, args)

        # meters are weighted by batch size gt.size(0)
        loss_teacher_rec.update(loss["teacher_rec_loss"].item(), gt.size(0))
        loss_student_rec.update(loss["student_rec_loss"].item(), gt.size(0))
        loss_student_perceptual.update(loss["perceptual_loss"].item(),
                                       gt.size(0))
        loss_dehazing_network.update(loss["dehazing_loss"].item(), gt.size(0))
        loss_psnr.update(loss["loss_psnr"].item(), gt.size(0))
        loss_ssim.update(loss["loss_ssim"].item(), gt.size(0))

        # time
        time_end = time.time()
        batch_time.update(time_end - time_start)
        time_start = time_end

        if (i + 1) % args.log_interval == 0:
            print(
                '[Train] Epoch: [{0}][{1}/{2}]\t'
                'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                'Teacher Reconstruction Loss {loss_teacher.val:.4f} ({loss_teacher.avg:.4f})\t'
                'Student Reconstruction Loss {loss_student.val:.4f} ({loss_student.avg:.4f})\t'
                'Student Perceptual Loss {loss_perc.val:.4f} ({loss_perc.avg:.4f})\t'
                'Dehazing Network Loss {loss_dehaze.val:.4f} ({loss_dehaze.avg:.4f})\t'
                'PSNR {loss_psnr.val:.4f} ({loss_psnr.avg:.4f})\t'
                'SSIM {loss_ssim.val:.4f} ({loss_ssim.avg:.4f})\t'.format(
                    epoch + 1,
                    i + 1,
                    len(train_loader),
                    batch_time=batch_time,
                    loss_teacher=loss_teacher_rec,
                    loss_student=loss_student_rec,
                    loss_perc=loss_student_perceptual,
                    loss_dehaze=loss_dehazing_network,
                    loss_psnr=loss_psnr,
                    loss_ssim=loss_ssim))

    losses = {
        "teacher_rec_loss": loss_teacher_rec,
        "student_rec_loss": loss_student_rec,
        "perceptual_loss": loss_student_perceptual,
        "dehazing_loss": loss_dehazing_network,
        "loss_psnr": loss_psnr,
        "loss_ssim": loss_ssim
    }
    return losses
def adjust_shifts(args, train_set, adjust_loader, depth_net, pose_net, epoch,
                  logger, training_writer, **env):
    '''Adjust the per-sample frame shifts of ``train_set`` so that the mean
    depth predicted by ``depth_net`` matches ``args.target_mean_depth``.

    For every (reference, target) pair in each sequence, the mean predicted
    depth is compared to a per-offset target value; the ratio is handed to
    ``train_set.reset_shifts`` which rescales the temporal shifts.  Returns
    the average of the new shifts.
    '''
    batch_time = AverageMeter()
    data_time = AverageMeter()
    new_shifts = AverageMeter(args.sequence_length - 1, precision=2)
    pose_net.eval()
    depth_net.eval()
    upsample_depth_net = models.UpSampleNet(depth_net, args.network_input_size)

    end = time.time()

    mid_index = (args.sequence_length - 1) // 2

    # we constrain the mean value of depth net output for each pair
    # (offset k, target); target_values[k] is the desired mean depth for
    # the pair at temporal offset k (the middle frame is excluded)
    target_values = np.arange(
        -mid_index, mid_index + 1) / (args.target_mean_depth * mid_index)
    target_values = 1 / np.abs(
        np.concatenate(
            [target_values[:mid_index], target_values[mid_index + 1:]]))

    logger.reset_train_bar(len(adjust_loader))

    for i, sample in enumerate(adjust_loader):
        index = sample['index']
        # measure data loading time
        data_time.update(time.time() - end)
        imgs = torch.stack(sample['imgs'], dim=1).to(device)

        intrinsics = sample['intrinsics'].to(device)
        intrinsics_inv = sample['intrinsics_inv'].to(device)

        # compute output
        batch_size, seq = imgs.size()[:2]
        if args.network_input_size is not None:
            h, w = args.network_input_size
            downsample_imgs = F.interpolate(imgs, (3, h, w), mode='area')
            poses = pose_net(downsample_imgs)  # [B, seq, 6]
        else:
            poses = pose_net(imgs)
        pose_matrices = pose_vec2mat(poses,
                                     args.rotation_mode)  # [B, seq, 3, 4]

        tgt_imgs = imgs[:, mid_index]  # [B, 3, H, W]
        tgt_poses = pose_matrices[:, mid_index]  # [B, 3, 4]
        compensated_poses = compensate_pose(
            pose_matrices,
            tgt_poses)  # [B, seq, 3, 4] tgt_poses are now neutral pose

        ref_indices = list(range(args.sequence_length))
        ref_indices.remove(mid_index)

        mean_depth_batch = []

        for ref_index in ref_indices:
            prior_imgs = imgs[:, ref_index]
            prior_poses = compensated_poses[:, ref_index]  # [B, 3, 4]

            # rotation-compensate the reference image before feeding the
            # (reference, target) pair to DepthNet
            prior_imgs_compensated = inverse_rotate(prior_imgs,
                                                    prior_poses[:, :, :3],
                                                    intrinsics,
                                                    intrinsics_inv)
            input_pair = torch.cat([prior_imgs_compensated, tgt_imgs],
                                   dim=1)  # [B, 6, W, H]
            depth = upsample_depth_net(input_pair)  # [B, 1, H, W]
            mean_depth = depth.view(batch_size, -1).mean(-1).cpu().numpy()  # B
            mean_depth_batch.append(mean_depth)

        for j, mean_values in zip(index, np.stack(mean_depth_batch, axis=-1)):
            ratio = mean_values / target_values
            # if mean value is too high, raise the shift, lower otherwise
            train_set.reset_shifts(j, ratio[:mid_index], ratio[mid_index:])
            new_shifts.update(train_set.get_shifts(j))

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()
        logger.train_bar.update(i)
        if i % args.print_freq == 0:
            logger.train_writer.write('Adjustement:'
                                      'Time {} Data {} shifts {}'.format(
                                          batch_time, data_time, new_shifts))

    for i, shift in enumerate(new_shifts.avg):
        training_writer.add_scalar('shifts{}'.format(i), shift, epoch)

    return new_shifts.avg
def adjust_shifts(args, train_set, adjust_loader, pose_exp_net, epoch, logger,
                  tb_writer):
    '''Adjust the per-sample frame shifts of ``train_set`` so that the
    translation magnitude predicted by ``pose_exp_net`` matches
    ``args.target_displacement`` per frame offset.

    Also collects all predicted 6-DoF pose vectors (except the last,
    possibly-incomplete batch) and logs one histogram per coefficient.
    Returns the average of the new shifts.
    '''
    batch_time = AverageMeter()
    data_time = AverageMeter()
    new_shifts = AverageMeter(args.sequence_length - 1)
    pose_exp_net.train()
    # buffer sized for all full batches (last batch is dropped below)
    poses = np.zeros(((len(adjust_loader) - 1) * args.batch_size *
                      (args.sequence_length - 1), 6))

    mid_index = (args.sequence_length - 1) // 2

    # desired displacement for each temporal offset (middle frame excluded)
    target_values = np.abs(np.arange(
        -mid_index, mid_index + 1)) * (args.target_displacement)
    target_values = np.concatenate(
        [target_values[:mid_index], target_values[mid_index + 1:]])

    end = time.time()

    for i, (indices, tgt_img, ref_imgs, intrinsics,
            intrinsics_inv) in enumerate(adjust_loader):

        # measure data loading time
        data_time.update(time.time() - end)
        tgt_img = tgt_img.to(device)
        ref_imgs = [img.to(device) for img in ref_imgs]

        # compute output
        explainability_mask, pose_batch = pose_exp_net(tgt_img, ref_imgs)

        # skip the last batch when buffering poses: it may be smaller than
        # args.batch_size and would not fit the preallocated array
        if i < len(adjust_loader) - 1:
            step = args.batch_size * (args.sequence_length - 1)
            poses[i * step:(i + 1) * step] = pose_batch.cpu().reshape(
                -1, 6).numpy()

        for index, pose in zip(indices, pose_batch):
            # per-offset translation magnitude; ratio > 1 means the predicted
            # displacement is too small, so the shift is increased
            displacements = pose[:, :3].norm(p=2, dim=1).cpu().numpy()
            ratio = target_values / displacements
            train_set.reset_shifts(index, ratio[:mid_index],
                                   ratio[mid_index:])
            new_shifts.update(train_set.get_shifts(index))

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()
        logger.train_bar.update(i)
        if i % args.print_freq == 0:
            logger.train_writer.write('Adjustement:'
                                      'Time {} Data {} shifts {}'.format(
                                          batch_time, data_time, new_shifts))

    prefix = 'train poses'
    coeffs_names = ['tx', 'ty', 'tz']
    if args.rotation_mode == 'euler':
        coeffs_names.extend(['rx', 'ry', 'rz'])
    elif args.rotation_mode == 'quat':
        coeffs_names.extend(['qx', 'qy', 'qz'])
    for i in range(poses.shape[1]):
        tb_writer.add_histogram('{} {}'.format(prefix, coeffs_names[i]),
                                poses[:, i], epoch)

    return new_shifts.avg
def validate_with_gt(args,
                     val_loader,
                     mvdnet,
                     depth_cons,
                     epoch,
                     output_writers=[]):
    '''Validate MVDNet (and optionally the depth-consistency refiner).

    Computes depth and surface-normal errors for the raw MVDNet output and,
    when ``args.train_cons`` is set, for the ``depth_cons``-refined output
    as well ("Prev" = raw, "Curr" = refined).  Optionally dumps predictions
    to ``args.output_dir``.  Returns ``(errors.avg, error_names)``.

    Fix: the original assigned ``output_depth1 = output_depth.clone()``
    twice in a row; the redundant second assignment is removed.
    '''
    batch_time = AverageMeter()
    error_names = [
        'abs_rel', 'abs_diff', 'sq_rel', 'a1', 'a2', 'a3', 'mean_angle'
    ]
    test_error_names = [
        'abs_rel', 'abs_diff', 'sq_rel', 'rms', 'log_rms', 'a1', 'a2', 'a3',
        'mean_angle'
    ]
    test_error_names1 = [
        'abs_rel', 'abs_diff', 'sq_rel', 'rms', 'log_rms', 'a1', 'a2', 'a3',
        'mean_angle'
    ]
    errors = AverageMeter(i=len(error_names))
    test_errors = AverageMeter(i=len(test_error_names))
    test_errors1 = AverageMeter(i=len(test_error_names1))
    log_outputs = len(output_writers) > 0

    # switch to evaluate mode
    if args.train_cons:
        depth_cons.eval()
    else:
        mvdnet.eval()

    end = time.time()
    with torch.no_grad():
        for i, (tgt_img, ref_imgs, gt_nmap, ref_poses, intrinsics,
                intrinsics_inv, tgt_depth) in enumerate(val_loader):
            tgt_img_var = Variable(tgt_img.cuda())
            ref_imgs_var = [Variable(img.cuda()) for img in ref_imgs]
            gt_nmap_var = Variable(gt_nmap.cuda())
            ref_poses_var = [Variable(pose.cuda()) for pose in ref_poses]
            intrinsics_var = Variable(intrinsics.cuda())
            intrinsics_inv_var = Variable(intrinsics_inv.cuda())
            tgt_depth_var = Variable(tgt_depth.cuda())

            pose = torch.cat(ref_poses_var, 1)

            # skip samples with NaN poses (NaN != NaN)
            if (pose != pose).any():
                continue

            outputs = mvdnet(tgt_img_var, ref_imgs_var, pose, intrinsics_var,
                             intrinsics_inv_var)
            output_depth = outputs[0]
            # keep a copy of the raw (pre-refinement) predictions
            output_depth1 = output_depth.clone()
            nmap = outputs[1]
            nmap1 = nmap.clone()

            if args.train_cons:
                # refine depth + normals jointly; channel 0 is depth,
                # channels 1: are the normal map
                outputs = depth_cons(output_depth, nmap.permute(0, 3, 1, 2))
                nmap = outputs[:, 1:].permute(0, 2, 3, 1)
                output_depth = outputs[:, 0].unsqueeze(1)

            # valid ground truth: inside the depth range and not NaN
            mask = (tgt_depth <= args.nlabel * args.mindepth) & (
                tgt_depth >= args.mindepth) & (tgt_depth == tgt_depth)
            #mask = (tgt_depth <= 10) & (tgt_depth >= args.mindepth) & (tgt_depth == tgt_depth) #for DeMoN testing, to compare against DPSNet you might need to turn on this for fair comparison

            if not mask.any():
                continue

            output_depth1_ = torch.squeeze(output_depth1.data.cpu(), 1)
            output_depth_ = torch.squeeze(output_depth.data.cpu(), 1)

            errors_ = compute_errors_train(tgt_depth, output_depth_, mask)
            test_errors_ = list(
                compute_errors_test(tgt_depth[mask], output_depth_[mask]))
            test_errors1_ = list(
                compute_errors_test(tgt_depth[mask], output_depth1_[mask]))

            # normal-map mask: pixels where the GT normal is non-zero in any
            # channel
            n_mask = (gt_nmap_var.permute(0, 2, 3, 1)[0, :, :] != 0)
            n_mask = n_mask[:, :, 0] | n_mask[:, :, 1] | n_mask[:, :, 2]

            total_angles_m = compute_angles(
                gt_nmap_var.permute(0, 2, 3, 1)[0], nmap[0])
            total_angles_m1 = compute_angles(
                gt_nmap_var.permute(0, 2, 3, 1)[0], nmap1[0])

            mask_angles = total_angles_m[n_mask]
            mask_angles1 = total_angles_m1[n_mask]
            total_angles_m[~n_mask] = 0
            total_angles_m1[~n_mask] = 0
            # mean angular error over valid normals is appended as the last
            # error entry ('mean_angle')
            errors_.append(torch.mean(mask_angles).item())
            test_errors_.append(torch.mean(mask_angles).item())
            test_errors1_.append(torch.mean(mask_angles1).item())

            errors.update(errors_)
            test_errors.update(test_errors_)
            test_errors1.update(test_errors1_)

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()
            if i % args.print_freq == 0 or i == len(val_loader) - 1:
                if args.train_cons:
                    print(
                        'valid: Time {} Prev Error {:.4f}({:.4f}) Curr Error {:.4f} ({:.4f}) Prev angle Error {:.4f} ({:.4f}) Curr angle Error {:.4f} ({:.4f}) Iter {}/{}'
                        .format(batch_time, test_errors1.val[0],
                                test_errors1.avg[0], test_errors.val[0],
                                test_errors.avg[0], test_errors1.val[-1],
                                test_errors1.avg[-1], test_errors.val[-1],
                                test_errors.avg[-1], i, len(val_loader)))
                else:
                    print(
                        'valid: Time {} Rel Error {:.4f} ({:.4f}) Angle Error {:.4f} ({:.4f}) Iter {}/{}'
                        .format(batch_time, test_errors.val[0],
                                test_errors.avg[0], test_errors.val[-1],
                                test_errors.avg[-1], i, len(val_loader)))

            if args.output_print:
                output_dir = Path(args.output_dir)
                if not os.path.isdir(output_dir):
                    os.mkdir(output_dir)
                plt.imsave(output_dir / '{:04d}_map{}'.format(i, '_dps.png'),
                           output_depth_.numpy()[0],
                           cmap='rainbow')
                np.save(output_dir / '{:04d}{}'.format(i, '_dps.npy'),
                        output_depth_.numpy()[0])
                if args.train_cons:
                    plt.imsave(
                        output_dir / '{:04d}_map{}'.format(i, '_prev.png'),
                        output_depth1_.numpy()[0],
                        cmap='rainbow')
                    np.save(output_dir / '{:04d}{}'.format(i, '_prev.npy'),
                            output_depth1_.numpy()[0])
                # np.save(output_dir/'{:04d}{}'.format(i,'_gt.npy'),tgt_depth.numpy()[0])
                # imsave(output_dir/'{:04d}_aimage{}'.format(i,'.png'), np.transpose(tgt_img.numpy()[0],(1,2,0)))
                # np.save(output_dir/'{:04d}_cam{}'.format(i,'.npy'),intrinsics_var.cpu().numpy()[0])

    if args.output_print:
        np.savetxt(output_dir / args.ttype + 'errors.csv',
                   test_errors.avg,
                   fmt='%1.4f',
                   delimiter=',')
        np.savetxt(output_dir / args.ttype + 'prev_errors.csv',
                   test_errors1.avg,
                   fmt='%1.4f',
                   delimiter=',')
    return errors.avg, error_names
def train(args, train_loader, disp_net, pose_exp_net, seg_net, optimizer,
          epoch_size, logger, tb_writer):
    '''Train disp_net and pose_exp_net for one epoch with an additional
    segmentation-consistency photometric term.

    ``seg_net`` is frozen (eval mode) and only provides segmentation maps:
    a vertical-gradient "edge" map is fed to disp_net, and the segmentation
    maps are warped with the same depth/pose to form ``loss_seg``.
    Loss = w1*photometric + w2*explainability + w3*smoothness + w4*seg.
    Returns the average total loss.
    '''
    global n_iter, device
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter(precision=4)
    w1, w2, w3, w4 = args.photo_loss_weight, args.mask_loss_weight, args.smooth_loss_weight, args.seg_loss

    # switch to train mode
    disp_net.train()
    pose_exp_net.train()
    seg_net.eval()

    end = time.time()
    logger.train_bar.update(0)

    for i, (tgt_img, ref_imgs, intrinsics,
            intrinsics_inv) in enumerate(train_loader):
        log_losses = i > 0 and n_iter % args.print_freq == 0
        log_output = args.training_output_freq > 0 and n_iter % args.training_output_freq == 0

        # measure data loading time
        data_time.update(time.time() - end)
        tgt_img = tgt_img.to(device)
        ref_imgs = [img.to(device) for img in ref_imgs]
        intrinsics = intrinsics.to(device)

        # compute output
        tgt_seg = seg_net(tgt_img)
        # vertical finite difference of the segmentation as an edge prior
        edge = tgt_seg[:, :, 0:-1, :] - tgt_seg[:, :, 1:, :]
        disparities = disp_net(tgt_img, edge)
        # NOTE(review): the comprehension variable shadows loop index `i`;
        # harmless in Python 3 (comprehensions have their own scope)
        ref_seg = [seg_net(i) for i in ref_imgs]
        depth = [1 / disp for disp in disparities]
        explainability_mask, pose = pose_exp_net(tgt_img, ref_imgs)

        loss_1, warped, diff = photometric_reconstruction_loss(
            tgt_img, ref_imgs, intrinsics, depth, explainability_mask, pose,
            args.rotation_mode, args.padding_mode)
        # same warping applied to segmentation maps
        loss_seg, warped_seg, diff_seg = photometric_reconstruction_loss(
            tgt_seg, ref_seg, intrinsics, depth, explainability_mask, pose,
            args.rotation_mode, args.padding_mode)
        if w2 > 0:
            loss_2 = explainability_loss(explainability_mask)
        else:
            loss_2 = 0
        loss_3 = smooth_loss(depth)

        loss = w1 * loss_1 + w2 * loss_2 + w3 * loss_3 + w4 * loss_seg

        if log_losses:
            tb_writer.add_scalar('seg_loss', loss_seg.item(), n_iter)
            tb_writer.add_scalar('photometric_error', loss_1.item(), n_iter)
            if w2 > 0:
                tb_writer.add_scalar('explanability_loss', loss_2.item(),
                                     n_iter)
            tb_writer.add_scalar('disparity_smoothness_loss', loss_3.item(),
                                 n_iter)
            tb_writer.add_scalar('total_loss', loss.item(), n_iter)

        if log_output:
            tb_writer.add_image('train Input', tensor2array(tgt_img[0]),
                                n_iter)
            for k, scaled_maps in enumerate(
                    zip(depth, disparities, warped_seg, diff_seg,
                        explainability_mask)):
                log_output_tensorboard(tb_writer, "train", 0, " {}".format(k),
                                       n_iter, *scaled_maps)

        # record loss and EPE
        losses.update(loss.item(), args.batch_size)

        # compute gradient and do Adam step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        with open(args.save_path / args.log_full, 'a') as csvfile:
            writer = csv.writer(csvfile, delimiter='\t')
            writer.writerow(
                [loss.item(),
                 loss_2.item() if w2 > 0 else 0,
                 loss_3.item()])
        logger.train_bar.update(i + 1)
        if i % args.print_freq == 0:
            logger.train_writer.write('Train: Time {} Data {} Loss {}'.format(
                batch_time, data_time, losses))
        if i >= epoch_size - 1:
            break

        n_iter += 1

    return losses.avg[0]
def main():
    '''Evaluate motion segmentation / flow on KITTI 2015.

    Loads pretrained disp/pose/flow networks, combines camera-induced
    ("rigid") flow with the flow network's output using a consensus-based
    rigidity mask, and reports EPE/Fl metrics with and without the GT object
    mask.  Optionally dumps images, masks and flow visualizations to
    ``args.output_dir``.

    Fix: the definition of ``obj_map_gt_var_expanded`` had been commented
    out while still being used in the second ``compute_all_epes`` call,
    causing a NameError at runtime; it is restored below.
    '''
    global args
    args = parser.parse_args()
    args.pretrained_disp = Path(args.pretrained_disp)
    args.pretrained_pose = Path(args.pretrained_pose)
    # args.pretrained_mask = Path(args.pretrained_mask)
    args.pretrained_flow = Path(args.pretrained_flow)

    if args.output_dir is not None:
        args.output_dir = Path(args.output_dir)
        args.output_dir.makedirs_p()
        image_dir = args.output_dir / 'images'
        gt_dir = args.output_dir / 'gt'
        mask_dir = args.output_dir / 'mask'
        viz_dir = args.output_dir / 'viz'
        image_dir.makedirs_p()
        gt_dir.makedirs_p()
        mask_dir.makedirs_p()
        viz_dir.makedirs_p()

        output_writer = SummaryWriter(args.output_dir)

    normalize = custom_transforms.Normalize(mean=[0.5, 0.5, 0.5],
                                            std=[0.5, 0.5, 0.5])
    flow_loader_h, flow_loader_w = 256, 832
    valid_flow_transform = custom_transforms.Compose([
        custom_transforms.Scale(h=flow_loader_h, w=flow_loader_w),
        custom_transforms.ArrayToTensor(), normalize
    ])
    if args.dataset == "kitti2015":
        val_flow_set = ValidationFlow(root=args.kitti_dir,
                                      sequence_length=3,
                                      transform=valid_flow_transform)

    val_loader = torch.utils.data.DataLoader(val_flow_set,
                                             batch_size=1,
                                             shuffle=False,
                                             num_workers=2,
                                             pin_memory=True,
                                             drop_last=True)

    disp_net = getattr(models, args.dispnet)().cuda()
    pose_net = getattr(models, args.posenet)(nb_ref_imgs=2).cuda()
    # mask_net = getattr(models, args.masknet)(nb_ref_imgs=4).cuda()
    flow_net = getattr(models, args.flownet)(nlevels=args.nlevels).cuda()

    dispnet_weights = torch.load(args.pretrained_disp)
    posenet_weights = torch.load(args.pretrained_pose)
    # masknet_weights = torch.load(args.pretrained_mask)
    flownet_weights = torch.load(args.pretrained_flow)
    disp_net.load_state_dict(dispnet_weights['state_dict'])
    pose_net.load_state_dict(posenet_weights['state_dict'])
    flow_net.load_state_dict(flownet_weights['state_dict'])
    # mask_net.load_state_dict(masknet_weights['state_dict'])

    disp_net.eval()
    pose_net.eval()
    # mask_net.eval()
    flow_net.eval()

    error_names = [
        'epe_total', 'epe_sp', 'epe_mv', 'Fl', 'epe_total_gt_mask',
        'epe_sp_gt_mask', 'epe_mv_gt_mask', 'Fl_gt_mask'
    ]
    errors = AverageMeter(i=len(error_names))

    for i, (tgt_img, ref_imgs, intrinsics, intrinsics_inv, flow_gt,
            obj_map_gt) in enumerate(tqdm(val_loader)):
        # NOTE: volatile=True is the legacy (pre-0.4) way to disable autograd
        tgt_img_var = Variable(tgt_img.cuda(), volatile=True)
        ref_imgs_var = [
            Variable(img.cuda(), volatile=True) for img in ref_imgs
        ]
        intrinsics_var = Variable(intrinsics.cuda(), volatile=True)
        intrinsics_inv_var = Variable(intrinsics_inv.cuda(), volatile=True)

        flow_gt_var = Variable(flow_gt.cuda(), volatile=True)
        obj_map_gt_var = Variable(obj_map_gt.cuda(), volatile=True)

        disp = disp_net(tgt_img_var)
        depth = 1 / disp
        pose = pose_net(tgt_img_var, ref_imgs_var)
        # explainability_mask = mask_net(tgt_img_var, ref_imgs_var)
        # print(len(explainability_mask))

        # NOTE(review): flow_bwd is only assigned on the Back2Future branch;
        # the else branch would hit a NameError at flow_bwd_list.append below
        # — presumably only Back2Future is used here; confirm before relying
        # on the else path.
        if args.flownet == 'Back2Future':
            flow_fwd, flow_bwd, _ = flow_net(tgt_img_var, ref_imgs_var)
        else:
            flow_fwd = flow_net(tgt_img_var, ref_imgs_var[2])
        # camera-induced ("rigid") flow from depth + forward pose
        flow_cam = pose2flow(depth.squeeze(1), pose[:, 1], intrinsics_var,
                             intrinsics_inv_var)
        # flow_cam_bwd = pose2flow(depth.squeeze(1), pose[:,1], intrinsics_var, intrinsics_inv_var)

        #---------------------------------------------------------------
        flows_cam_fwd = [
            pose2flow(depth_.squeeze(1), pose[:, 1], intrinsics_var,
                      intrinsics_inv_var) for depth_ in depth
        ]
        flows_cam_bwd = [
            pose2flow(depth_.squeeze(1), pose[:, 0], intrinsics_var,
                      intrinsics_inv_var) for depth_ in depth
        ]
        flow_fwd_list = []
        flow_fwd_list.append(flow_fwd)
        flow_bwd_list = []
        flow_bwd_list.append(flow_bwd)
        rigidity_mask_fwd = consensus_exp_masks(flows_cam_fwd,
                                                flows_cam_bwd,
                                                flow_fwd_list,
                                                flow_bwd_list,
                                                tgt_img_var,
                                                ref_imgs_var[1],
                                                ref_imgs_var[0],
                                                wssim=0.85,
                                                wrig=1.0,
                                                ws=0.1)[0]
        del flow_fwd_list
        del flow_bwd_list
        #--------------------------------------------------------------

        #rigidity_mask = 1 - (1-explainability_mask[:,1])*(1-explainability_mask[:,2]).unsqueeze(1) > 0.5
        rigidity_mask_census_soft = (flow_cam - flow_fwd).abs()  #.normalize()
        #rigidity_mask_census_u = rigidity_mask_census_soft[:,0] < args.THRESH
        #rigidity_mask_census_v = rigidity_mask_census_soft[:,1] < args.THRESH
        #rigidity_mask_census = (rigidity_mask_census_u).type_as(flow_fwd) * (rigidity_mask_census_v).type_as(flow_fwd)
        # rigidity_mask_census = ( torch.pow( (torch.pow(rigidity_mask_census_soft[:,0],2) + torch.pow(rigidity_mask_census_soft[:,1] , 2)), 0.5) < args.THRESH ).type_as(flow_fwd)
        # census mask: pixel is "rigid" if the squared flow discrepancy is
        # below a magnitude-relative threshold
        THRESH_1 = 1
        THRESH_2 = 1
        rigidity_mask_census = (
            (torch.pow(rigidity_mask_census_soft[:, 0], 2) +
             torch.pow(rigidity_mask_census_soft[:, 1], 2)) < THRESH_1 *
            (flow_cam.pow(2).sum(dim=1) + flow_fwd.pow(2).sum(dim=1)) +
            THRESH_2).type_as(flow_fwd)
        # rigidity_mask_census = torch.zeros_like(rigidity_mask_census)
        # consensus mask deliberately zeroed: only the census mask is active
        rigidity_mask_fwd = torch.zeros_like(rigidity_mask_fwd)

        rigidity_mask_combined = 1 - (1 - rigidity_mask_fwd) * (
            1 - rigidity_mask_census)
        # restored: used by compute_all_epes below (was commented out)
        obj_map_gt_var_expanded = obj_map_gt_var.unsqueeze(1).type_as(
            flow_fwd)

        flow_fwd_non_rigid = (rigidity_mask_combined <= args.THRESH).type_as(
            flow_fwd).expand_as(flow_fwd) * flow_fwd
        flow_fwd_rigid = (rigidity_mask_combined > args.THRESH
                          ).type_as(flow_cam).expand_as(flow_cam) * flow_cam
        total_flow = flow_fwd_rigid + flow_fwd_non_rigid
        # rigidity_mask = rigidity_mask.type_as(flow_fwd)

        _epe_errors = compute_all_epes(
            flow_gt_var, flow_cam, flow_fwd,
            torch.zeros_like(rigidity_mask_combined)) + compute_all_epes(
                flow_gt_var, flow_cam, flow_fwd, (1 - obj_map_gt_var_expanded))
        errors.update(_epe_errors)

        tgt_img_np = tgt_img[0].numpy()
        rigidity_mask_combined_np = rigidity_mask_combined.cpu().data[0].numpy(
        )
        gt_mask_np = obj_map_gt[0].numpy()

        if args.output_dir is not None:
            np.save(image_dir / str(i).zfill(3), tgt_img_np)
            np.save(gt_dir / str(i).zfill(3), gt_mask_np)
            np.save(mask_dir / str(i).zfill(3), rigidity_mask_combined_np)

        if (args.output_dir is not None):
            tmp1 = flow_fwd.data[0].permute(1, 2, 0).cpu().numpy()
            tmp1 = flow_2_image(tmp1)
            scipy.misc.imsave(viz_dir / str(i).zfill(3) + 'flow.png', tmp1)

    print("Results")
    print("\t {:>10}, {:>10}, {:>10}, {:>6}, {:>10}, {:>10}, {:>10}, {:>10} ".
          format(*error_names))
    print(
        "Errors \t {:10.4f}, {:10.4f}, {:10.4f}, {:10.4f}, {:10.4f}, {:10.4f}, {:10.4f}, {:10.4f}"
        .format(*errors.avg))
def validate_with_gt(args,
                     val_loader,
                     disp_net,
                     segnet,
                     epoch,
                     logger,
                     tb_writer,
                     sample_nb_to_log=3):
    '''Validate ``disp_net`` (which consumes a segmentation edge map from
    ``segnet``) against ground-truth depth.

    Logs the first ``sample_nb_to_log`` samples to tensorboard and returns
    ``(errors.avg, error_names)``.
    '''
    global device
    batch_time = AverageMeter()
    error_names = ['abs_diff', 'abs_rel', 'sq_rel', 'a1', 'a2', 'a3']
    errors = AverageMeter(i=len(error_names))
    log_outputs = sample_nb_to_log > 0

    # switch to evaluate mode
    disp_net.eval()

    end = time.time()
    logger.valid_bar.update(0)
    for i, (tgt_img, depth) in enumerate(val_loader):
        tgt_img = tgt_img.to(device)
        depth = depth.to(device)

        # compute output: vertical gradient of the segmentation is the edge
        # prior fed to disp_net alongside the image
        tgt_seg = segnet(tgt_img)
        edge = tgt_seg[:, :, 0:-1, :] - tgt_seg[:, :, 1:, :]
        output_disp = disp_net(tgt_img, edge)
        output_depth = 1 / output_disp[:, 0]

        if log_outputs and i < sample_nb_to_log:
            # inputs and ground truth only need to be logged once (epoch 0)
            if epoch == 0:
                tb_writer.add_image('val Input/{}'.format(i),
                                    tensor2array(tgt_img[0]), 0)
                depth_to_show = depth[0]
                tb_writer.add_image(
                    'val target Depth Normalized/{}'.format(i),
                    tensor2array(depth_to_show, max_value=None), epoch)
                # push invalid (zero) depths far away so 1/depth stays finite
                depth_to_show[depth_to_show == 0] = 1000
                disp_to_show = (1 / depth_to_show).clamp(0, 10)
                tb_writer.add_image(
                    'val target Disparity Normalized/{}'.format(i),
                    tensor2array(disp_to_show, max_value=None,
                                 colormap='magma'), epoch)

            tb_writer.add_image(
                'val Dispnet Output Normalized/{}'.format(i),
                tensor2array(output_disp[0], max_value=None,
                             colormap='magma'), epoch)
            tb_writer.add_image('val Depth Output Normalized/{}'.format(i),
                                tensor2array(output_depth[0], max_value=None),
                                epoch)

        errors.update(compute_errors(depth, output_depth))

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()
        logger.valid_bar.update(i + 1)
        if i % args.print_freq == 0:
            logger.valid_writer.write(
                'valid: Time {} Abs Error {:.4f} ({:.4f})'.format(
                    batch_time, errors.val[0], errors.avg[0]))
    logger.valid_bar.update(len(val_loader))
    return errors.avg, error_names
def train_one_epoch(args, train_loader, depth_net, pose_net, optimizer, epoch,
                    n_iter, logger, training_writer, **env):
    '''Train depth_net and pose_net for one epoch (unsupervised, pair-based).

    Each batch: predict poses for the whole sequence, pick a target frame
    and a reference ("prior") frame per the training milestones ``e1``/``e2``
    (fixed at first, random later), rotation-compensate the prior frame,
    feed the pair to DepthNet, normalize translations to
    ``args.nominal_displacement`` and minimize a multi-scale photometric
    reconstruction loss plus texture-aware smoothness.  Returns
    ``(losses.avg[0], n_iter)``.
    '''
    global device
    logger.reset_train_bar()
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter(precision=4)
    w1, w2, w3 = args.photo_loss_weight, args.smooth_loss_weight, args.ssim
    e1, e2 = args.training_milestones

    # switch to train mode
    depth_net.train()
    pose_net.train()
    upsample_depth_net = models.UpSampleNet(depth_net, args.network_input_size)

    end = time.time()
    logger.train_bar.update(0)

    for i, sample in enumerate(train_loader):
        log_losses = i > 0 and n_iter % args.print_freq == 0
        log_output = args.training_output_freq > 0 and n_iter % args.training_output_freq == 0

        # measure data loading time
        data_time.update(time.time() - end)
        imgs = torch.stack(sample['imgs'], dim=1).to(device)

        intrinsics = sample['intrinsics'].to(device)
        intrinsics_inv = sample['intrinsics_inv'].to(device)

        batch_size, seq = imgs.size()[:2]

        if args.network_input_size is not None:
            h, w = args.network_input_size
            downsample_imgs = F.interpolate(imgs, (3, h, w), mode='area')
            poses = pose_net(downsample_imgs)  # [B, seq, 6]
        else:
            poses = pose_net(imgs)

        pose_matrices = pose_vec2mat(poses,
                                     args.rotation_mode)  # [B, seq, 3, 4]

        total_indices = torch.arange(seq).long().to(device).unsqueeze(
            0).expand(batch_size, seq)
        batch_range = torch.arange(batch_size).long().to(device)
        ''' for each element of the batch select a random picture in the
        sequence to which we will compute the depth, all poses are then
        converted so that pose of this very picture is exactly identity. At
        first this image is always in the middle of the sequence'''
        if epoch > e2:
            tgt_id = torch.floor(torch.rand(batch_size) *
                                 seq).long().to(device)
        else:
            tgt_id = torch.zeros(batch_size).long().to(
                device) + args.sequence_length // 2
        ''' Select what other picture we are going to feed DepthNet, it must
        not be the same as tgt_id. At first, it's always first picture of the
        sequence, it is randomly chosen when first training milestone is
        reached '''
        ref_indices = total_indices[total_indices != tgt_id.unsqueeze(1)].view(
            batch_size, seq - 1)
        if epoch > e1:
            prior_id = torch.floor(torch.rand(batch_size) *
                                   (seq - 1)).long().to(device)
        else:
            prior_id = torch.zeros(batch_size).long().to(device)
        prior_id = ref_indices[batch_range, prior_id]

        tgt_imgs = imgs[batch_range, tgt_id]  # [B, 3, H, W]
        tgt_poses = pose_matrices[batch_range, tgt_id]  # [B, 3, 4]
        prior_imgs = imgs[batch_range, prior_id]

        compensated_poses = compensate_pose(
            pose_matrices,
            tgt_poses)  # [B, seq, 3, 4] tgt_poses are now neutral pose
        prior_poses = compensated_poses[batch_range, prior_id]  # [B, 3, 4]

        if args.supervise_pose:
            # use ground-truth poses for the rotation compensation instead of
            # the predicted ones
            from_GT = invert_mat(sample['pose']).to(device)
            compensated_GT_poses = compensate_pose(
                from_GT, from_GT[batch_range, tgt_id])
            prior_GT_poses = compensated_GT_poses[batch_range, prior_id]
            prior_imgs_compensated = inverse_rotate(prior_imgs,
                                                    prior_GT_poses[:, :, :-1],
                                                    intrinsics,
                                                    intrinsics_inv)
        else:
            prior_imgs_compensated = inverse_rotate(prior_imgs,
                                                    prior_poses[:, :, :-1],
                                                    intrinsics,
                                                    intrinsics_inv)

        input_pair = torch.cat([prior_imgs_compensated, tgt_imgs],
                               dim=1)  # [B, 6, W, H]
        # depth is a multi-scale list of predictions
        depth = upsample_depth_net(input_pair)
        # depth = [sample['depth'].to(device).unsqueeze(1) * 3 / abs(tgt_id[0] - prior_id[0])]
        # depth.append(torch.nn.functional.interpolate(depth[0], scale_factor=2))
        disparities = [1 / d for d in depth]

        # rescale all translations so the prior->target displacement equals
        # args.nominal_displacement (fixes the monocular scale ambiguity)
        predicted_magnitude = prior_poses[:, :, -1:].norm(
            p=2, dim=1, keepdim=True).unsqueeze(1)
        scale_factor = args.nominal_displacement / (predicted_magnitude + 1e-5)
        normalized_translation = compensated_poses[:, :, :, -1:] * scale_factor  # [B, seq_length-1, 3]
        new_pose_matrices = torch.cat(
            [compensated_poses[:, :, :, :-1], normalized_translation], dim=-1)

        biggest_scale = depth[0].size(-1)
        loss_1 = 0

        # multi-scale photometric loss, each scale weighted by its relative
        # spatial resolution
        for k, scaled_depth in enumerate(depth):
            size_ratio = scaled_depth.size(-1) / biggest_scale
            loss, diff_maps, warped_imgs = photometric_reconstruction_loss(
                imgs,
                tgt_id,
                ref_indices,
                scaled_depth,
                new_pose_matrices,
                intrinsics,
                intrinsics_inv,
                args.rotation_mode,
                ssim_weight=w3)
            loss_1 += loss * size_ratio
            if log_output:
                training_writer.add_image(
                    'train Dispnet Output Normalized scale {}'.format(k),
                    tensor2array(disparities[k][0],
                                 max_value=None,
                                 colormap='bone'), n_iter)
                training_writer.add_image(
                    'train Depth Output scale {}'.format(k),
                    tensor2array(scaled_depth[0], max_value=args.max_depth),
                    n_iter)
                for j, (diff, warped) in enumerate(zip(diff_maps,
                                                       warped_imgs)):
                    training_writer.add_image(
                        'train Warped Outputs {} {}'.format(k, j),
                        tensor2array(warped[0]), n_iter)
                    training_writer.add_image(
                        'train Diff Outputs {} {}'.format(k, j),
                        tensor2array(diff.abs()[0] - 1), n_iter)

        loss_2 = texture_aware_smooth_loss(
            depth, tgt_imgs if args.texture_loss else None)

        loss = w1 * loss_1 + w2 * loss_2
        if args.supervise_pose:
            # L1 supervision on the rotation part of the predicted poses
            loss += (from_GT[:, :, :, :3] -
                     pose_matrices[:, :, :, :3]).abs().mean()

        if log_losses:
            training_writer.add_scalar('photometric_error', loss_1.item(),
                                       n_iter)
            training_writer.add_scalar('disparity_smoothness_loss',
                                       loss_2.item(), n_iter)
            training_writer.add_scalar('total_loss', loss.item(), n_iter)

        if log_output:
            nominal_translation_magnitude = poses[:, -2, :3].norm(p=2, dim=-1)
            # last pose is always identity and penultimate translation
            # magnitude is always 1, so you don't need to log them
            for j in range(args.sequence_length - 2):
                trans_mag = poses[:, j, :3].norm(p=2, dim=-1)
                training_writer.add_histogram(
                    'tr {}'.format(j),
                    (trans_mag /
                     nominal_translation_magnitude).detach().cpu().numpy(),
                    n_iter)
            for j in range(args.sequence_length - 1):
                # TODO log a better value : this is magnitude of vector
                # (yaw, pitch, roll) which is not a physical value
                rot_mag = poses[:, j, 3:].norm(p=2, dim=-1)
                training_writer.add_histogram('rot {}'.format(j),
                                              rot_mag.detach().cpu().numpy(),
                                              n_iter)
            training_writer.add_image('train Input',
                                      tensor2array(tgt_imgs[0]), n_iter)

        # record loss for average meter
        losses.update(loss.item(), args.batch_size)

        # compute gradient and do Adam step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()
        with open(args.save_path / args.log_full, 'a') as csvfile:
            writer = csv.writer(csvfile, delimiter='\t')
            writer.writerow([loss.item(), loss_1.item(), loss_2.item()])
        logger.train_bar.update(i + 1)
        if i % args.print_freq == 0:
            logger.train_writer.write('Train: Time {} Data {} Loss {}'.format(
                batch_time, data_time, losses))
        if i >= args.epoch_size - 1:
            break

        n_iter += 1

    return losses.avg[0], n_iter
def train(args, train_loader, mvdnet, optimizer, epoch_size, train_writer, epoch):
    """Run one training epoch of MVDNet (multi-view depth + normal map).

    Supervised by ground-truth depth (smooth-L1 on depth, or on disparity for
    the 'sceneflow' dataset) and ground-truth normals.

    Returns the running average of the total loss (losses.avg[0]).
    """
    global n_iter
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter(precision=4)
    d_losses = AverageMeter(precision=4)
    nmap_losses = AverageMeter(precision=4)
    # switch to training mode
    mvdnet.train()
    print("Training")
    end = time.time()
    for i, (tgt_img, ref_imgs, gt_nmap, ref_poses, intrinsics, intrinsics_inv,
            tgt_depth) in enumerate(train_loader):
        # measure data loading time
        data_time.update(time.time() - end)
        # Variable() is a legacy (pre-0.4) wrapper; kept for compatibility.
        tgt_img_var = Variable(tgt_img.cuda())
        ref_imgs_var = [Variable(img.cuda()) for img in ref_imgs]
        gt_nmap_var = Variable(gt_nmap.cuda())
        ref_poses_var = [Variable(pose.cuda()) for pose in ref_poses]
        intrinsics_var = Variable(intrinsics.cuda())
        intrinsics_inv_var = Variable(intrinsics_inv.cuda())
        # NOTE(review): the trailing .cuda() is redundant (tensor is already on GPU).
        tgt_depth_var = Variable(tgt_depth.cuda()).cuda()

        # compute output: concatenate reference poses along dim 1
        pose = torch.cat(ref_poses_var, 1)

        # Per-sample depth scaling factor; sceneflow normalizes by focal length.
        if args.dataset == 'sceneflow':
            factor = (1.0 / args.scale) * intrinsics_var[:, 0, 0] / 1050.0
            factor = factor.view(-1, 1, 1)
        else:
            factor = torch.ones(
                (tgt_depth_var.size(0), 1, 1)).type_as(tgt_depth_var)

        # get mask: in-range GT depth; the self-equality test filters out NaNs.
        mask = (tgt_depth_var <= args.nlabel * args.mindepth * factor * 3) & (
            tgt_depth_var >= args.mindepth * factor) & (tgt_depth_var == tgt_depth_var)
        mask.detach_()
        # Skip batches with no valid GT pixel at all.
        if mask.any() == 0:
            continue

        # NOTE(review): targetimg is computed but never used below — dead code?
        targetimg = inverse_warp(ref_imgs_var[0], tgt_depth_var.unsqueeze(1),
                                 pose[:, 0], intrinsics_var,
                                 intrinsics_inv_var)  #[B,CH,D,H,W,1]

        outputs = mvdnet(tgt_img_var, ref_imgs_var, pose, intrinsics_var,
                         intrinsics_inv_var, factor=factor.unsqueeze(1))
        # presumably outputs = (coarse depth, refined depth, normal map) — TODO confirm
        nmap = outputs[2].permute(0, 3, 1, 2)
        depths = outputs[0:2]
        disps = [args.mindepth * args.nlabel / (depth) for depth in depths]  # correct disps
        if args.dataset == 'sceneflow':
            disps = [(args.mindepth * args.nlabel) * 3 / (depth)
                     for depth in depths]  # correct disps
            depths = [(args.mindepth * factor) * (args.nlabel * 3) / disp
                      for disp in disps]

        loss = 0.
        d_loss = 0.
        nmap_loss = 0.
        # Multi-scale depth/disparity supervision, geometrically weighted by 0.7
        # so later (finer) predictions count more.
        if args.dataset == 'sceneflow':
            # sceneflow is supervised in disparity space
            tgt_disp_var = ((1.0 / args.scale) *
                            intrinsics_var[:, 0, 0].view(-1, 1, 1) / tgt_depth_var)
            for l, disp in enumerate(disps):
                output = torch.squeeze(disp, 1)
                d_loss = d_loss + F.smooth_l1_loss(
                    output[mask], tgt_disp_var[mask]) * pow(0.7, len(disps) - l - 1)
        else:
            for l, depth in enumerate(depths):
                output = torch.squeeze(depth, 1)
                d_loss = d_loss + F.smooth_l1_loss(
                    output[mask], tgt_depth_var[mask]) * pow(0.7, len(depths) - l - 1)

        # Normal-map loss over the 3 channels at valid pixels only.
        n_mask = mask.unsqueeze(1).expand(-1, 3, -1, -1)
        nmap_loss = nmap_loss + F.smooth_l1_loss(nmap[n_mask], gt_nmap_var[n_mask])

        loss = loss + args.d_weight * d_loss + args.n_weight * nmap_loss

        if i > 0 and n_iter % args.print_freq == 0:
            train_writer.add_scalar('total_loss', loss.item(), n_iter)

        # record loss and EPE
        losses.update(loss.item(), args.batch_size)
        d_losses.update(d_loss.item(), args.batch_size)
        nmap_losses.update(nmap_loss.item(), args.batch_size)

        # compute gradient and do Adam step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()
        with open(args.save_path / args.log_full, 'a') as csvfile:
            writer = csv.writer(csvfile, delimiter='\t')
            writer.writerow([loss.item()])
        if i % args.print_freq == 0:
            print(
                'Train: Time {} Data {} Loss {} NmapLoss {} DLoss {} Iter {}/{} Epoch {}/{}'
                .format(batch_time, data_time, losses, nmap_losses, d_losses,
                        i, len(train_loader), epoch, args.epochs))
        if i >= epoch_size - 1:
            break
        n_iter += 1
    return losses.avg[0]
def validate_without_gt(args, val_loader, depth_net, pose_net, epoch, logger,
                        output_writers=[], **env):
    """Self-supervised validation: photometric + smoothness loss, no GT depth.

    Stabilizes each reference frame with the predicted rotation before feeding
    the (prior, target) pair to the depth network.

    Returns an OrderedDict of averaged total / photometric / smoothness losses.
    """
    global device
    batch_time = AverageMeter()
    losses = AverageMeter(i=3, precision=4)
    w1, w2, w3 = args.photo_loss_weight, args.smooth_loss_weight, args.ssim
    if args.log_output:
        # Pre-allocated histogram buffers; the last (possibly short) batch is skipped.
        poses_values = np.zeros(((len(val_loader) - 1) * args.test_batch_size *
                                 (args.sequence_length - 1), 6))
        disp_values = np.zeros(
            ((len(val_loader) - 1) * args.test_batch_size * 3))
    # switch to evaluate mode
    depth_net.eval()
    pose_net.eval()
    upsample_depth_net = models.UpSampleNet(depth_net, args.network_input_size)
    end = time.time()
    logger.valid_bar.update(0)
    for i, sample in enumerate(val_loader):
        log_output = i < len(output_writers)
        imgs = torch.stack(sample['imgs'], dim=1).to(device)
        intrinsics = sample['intrinsics'].to(device)
        intrinsics_inv = sample['intrinsics_inv'].to(device)
        if epoch == 1 and log_output:
            for j, img in enumerate(sample['imgs']):
                output_writers[i].add_image('val Input', tensor2array(img[0]), j)
        batch_size, seq = imgs.size()[:2]
        if args.network_input_size is not None:
            h, w = args.network_input_size
            downsample_imgs = F.interpolate(imgs, (3, h, w), mode='area')
            poses = pose_net(downsample_imgs)  # [B, seq, 6]
        else:
            poses = pose_net(imgs)
        pose_matrices = pose_vec2mat(poses, args.rotation_mode)  # [B, seq, 3, 4]
        mid_index = (args.sequence_length - 1) // 2
        tgt_imgs = imgs[:, mid_index]  # [B, 3, H, W]
        tgt_poses = pose_matrices[:, mid_index]  # [B, 3, 4]
        compensated_poses = compensate_pose(
            pose_matrices, tgt_poses)  # [B, seq, 3, 4] tgt_poses are now neutral pose
        ref_indices = list(range(args.sequence_length))
        ref_indices.remove(mid_index)
        loss_1 = 0
        loss_2 = 0
        for ref_index in ref_indices:
            prior_imgs = imgs[:, ref_index]
            prior_poses = compensated_poses[:, ref_index]  # [B, 3, 4]
            # Rotation-compensate the prior frame so only translation remains.
            prior_imgs_compensated = inverse_rotate(prior_imgs,
                                                    prior_poses[:, :, :3],
                                                    intrinsics, intrinsics_inv)
            input_pair = torch.cat([prior_imgs_compensated, tgt_imgs],
                                   dim=1)  # [B, 6, W, H]
            # Normalize translations so the pair has a nominal displacement.
            predicted_magnitude = prior_poses[:, :, -1:].norm(
                p=2, dim=1, keepdim=True).unsqueeze(1)  # [B, 1, 1, 1]
            scale_factor = args.nominal_displacement / predicted_magnitude
            normalized_translation = compensated_poses[:, :, :, -1:] * scale_factor  # [B, seq, 3, 1]
            new_pose_matrices = torch.cat(
                [compensated_poses[:, :, :, :-1], normalized_translation], dim=-1)
            depth = upsample_depth_net(input_pair)
            disparity = 1 / depth
            total_indices = torch.arange(seq).long().unsqueeze(0).expand(
                batch_size, seq).to(device)
            tgt_id = total_indices[:, mid_index]
            # NOTE(review): this rebinds the name `ref_indices` (list -> tensor)
            # while the outer loop is still iterating over the original list.
            # Python keeps iterating the old object, so behavior is unchanged,
            # but the shadowing is confusing — consider renaming.
            ref_indices = total_indices[
                total_indices != tgt_id.unsqueeze(1)].view(
                    batch_size, seq - 1)
            photo_loss, diff_maps, warped_imgs = photometric_reconstruction_loss(
                imgs, tgt_id, ref_indices, depth, new_pose_matrices,
                intrinsics, intrinsics_inv, args.rotation_mode, ssim_weight=w3)
            loss_1 += photo_loss
            if log_output:
                output_writers[i].add_image(
                    'val Dispnet Output Normalized {}'.format(ref_index),
                    tensor2array(disparity[0], max_value=None, colormap='bone'),
                    epoch)
                output_writers[i].add_image(
                    'val Depth Output {}'.format(ref_index),
                    tensor2array(depth[0].cpu(), max_value=args.max_depth),
                    epoch)
                for j, (diff, warped) in enumerate(zip(diff_maps, warped_imgs)):
                    output_writers[i].add_image(
                        'val Warped Outputs {} {}'.format(j, ref_index),
                        tensor2array(warped[0]), epoch)
                    output_writers[i].add_image(
                        'val Diff Outputs {} {}'.format(j, ref_index),
                        tensor2array(diff[0].abs() - 1), epoch)
            loss_2 += texture_aware_smooth_loss(
                disparity, tgt_imgs if args.texture_loss else None)
        if args.log_output and i < len(val_loader) - 1:
            step = args.test_batch_size * (args.sequence_length - 1)
            poses_values[i * step:(i + 1) * step] = poses[:, :-1].cpu().view(
                -1, 6).numpy()
            step = args.test_batch_size * 3
            # min/median/max of the last computed disparity per sample
            disp_unraveled = disparity.cpu().view(args.test_batch_size, -1)
            disp_values[i * step:(i + 1) * step] = torch.cat([
                disp_unraveled.min(-1)[0], disp_unraveled.median(-1)[0],
                disp_unraveled.max(-1)[0]
            ]).numpy()
        loss = w1 * loss_1 + w2 * loss_2
        losses.update([loss.item(), loss_1.item(), loss_2.item()])
        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()
        logger.valid_bar.update(i + 1)
        if i % args.print_freq == 0:
            logger.valid_writer.write('valid: Time {} Loss {}'.format(
                batch_time, losses))
    if args.log_output:
        rot_coeffs = ['rx', 'ry', 'rz'] if args.rotation_mode == 'euler' else [
            'qx', 'qy', 'qz'
        ]
        tr_coeffs = ['tx', 'ty', 'tz']
        for k, (coeff_name) in enumerate(tr_coeffs + rot_coeffs):
            output_writers[0].add_histogram('val poses_{}'.format(coeff_name),
                                            poses_values[:, k], epoch)
        output_writers[0].add_histogram('disp_values', disp_values, epoch)
    logger.valid_bar.update(len(val_loader))
    return OrderedDict(
        zip(['Total loss', 'Photo loss', 'Smooth loss'], losses.avg))
def train(args, train_loader, disp_net, pose_exp_net, optimizer, epoch_size,
          logger, train_writer):
    """One SfMLearner-style training epoch: photometric + explainability +
    smoothness losses over multi-scale disparity predictions.

    Uses the legacy (pre-0.4) PyTorch API (`Variable`, `.data[0]`) — keep as-is
    unless the whole file is migrated.

    Returns the running average of the total loss.
    """
    global n_iter
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter(precision=4)
    w1, w2, w3 = args.photo_loss_weight, args.mask_loss_weight, args.smooth_loss_weight
    # switch to train mode
    disp_net.train()
    pose_exp_net.train()
    end = time.time()
    logger.train_bar.update(0)
    for i, (tgt_img, ref_imgs, intrinsics,
            intrinsics_inv) in enumerate(train_loader):
        # measure data loading time
        data_time.update(time.time() - end)
        tgt_img_var = Variable(tgt_img.cuda())
        ref_imgs_var = [Variable(img.cuda()) for img in ref_imgs]
        intrinsics_var = Variable(intrinsics.cuda())
        intrinsics_inv_var = Variable(intrinsics_inv.cuda())
        # compute output: multi-scale disparities -> depths
        disparities = disp_net(tgt_img_var)
        depth = [1 / disp for disp in disparities]
        explainability_mask, pose = pose_exp_net(tgt_img_var, ref_imgs_var)
        loss_1 = photometric_reconstruction_loss(tgt_img_var, ref_imgs_var,
                                                 intrinsics_var,
                                                 intrinsics_inv_var, depth,
                                                 explainability_mask, pose,
                                                 args.rotation_mode,
                                                 args.padding_mode)
        if w2 > 0:
            loss_2 = explainability_loss(explainability_mask)
        else:
            loss_2 = 0
        loss_3 = smooth_loss(disparities)
        loss = w1 * loss_1 + w2 * loss_2 + w3 * loss_3
        if i > 0 and n_iter % args.print_freq == 0:
            train_writer.add_scalar('photometric_error', loss_1.data[0], n_iter)
            if w2 > 0:
                # NOTE: tag keeps the original (misspelled) name for continuity
                # of existing tensorboard runs.
                train_writer.add_scalar('explanability_loss', loss_2.data[0],
                                        n_iter)
            train_writer.add_scalar('disparity_smoothness_loss', loss_3.data[0],
                                    n_iter)
            train_writer.add_scalar('total_loss', loss.data[0], n_iter)
        if args.training_output_freq > 0 and n_iter % args.training_output_freq == 0:
            train_writer.add_image('train Input', tensor2array(tgt_img[0]),
                                   n_iter)
            for k, scaled_depth in enumerate(depth):
                train_writer.add_image(
                    'train Dispnet Output Normalized {}'.format(k),
                    tensor2array(disparities[k].data[0].cpu(),
                                 max_value=None,
                                 colormap='bone'), n_iter)
                train_writer.add_image(
                    'train Depth Output {}'.format(k),
                    tensor2array(1 / disparities[k].data[0].cpu(),
                                 max_value=10), n_iter)
                b, _, h, w = scaled_depth.size()
                downscale = tgt_img_var.size(2) / h
                # Rescale inputs and intrinsics to this pyramid level.
                tgt_img_scaled = nn.functional.adaptive_avg_pool2d(
                    tgt_img_var, (h, w))
                ref_imgs_scaled = [
                    nn.functional.adaptive_avg_pool2d(ref_img, (h, w))
                    for ref_img in ref_imgs_var
                ]
                intrinsics_scaled = torch.cat(
                    (intrinsics_var[:, 0:2] / downscale, intrinsics_var[:, 2:]),
                    dim=1)
                intrinsics_scaled_inv = torch.cat(
                    (intrinsics_inv_var[:, :, 0:2] * downscale,
                     intrinsics_inv_var[:, :, 2:]), dim=2)
                # log warped images along with explainability mask
                for j, ref in enumerate(ref_imgs_scaled):
                    ref_warped = inverse_warp(
                        ref, scaled_depth[:, 0], pose[:, j],
                        intrinsics_scaled, intrinsics_scaled_inv,
                        rotation_mode=args.rotation_mode,
                        padding_mode=args.padding_mode)[0]
                    train_writer.add_image(
                        'train Warped Outputs {} {}'.format(k, j),
                        tensor2array(ref_warped.data.cpu()), n_iter)
                    train_writer.add_image(
                        'train Diff Outputs {} {}'.format(k, j),
                        tensor2array(
                            0.5 * (tgt_img_scaled[0] - ref_warped).abs().data.cpu()),
                        n_iter)
                    if explainability_mask[k] is not None:
                        train_writer.add_image(
                            'train Exp mask Outputs {} {}'.format(k, j),
                            tensor2array(explainability_mask[k][0, j].data.cpu(),
                                         max_value=1,
                                         colormap='bone'), n_iter)
        # record loss and EPE
        losses.update(loss.data[0], args.batch_size)
        # compute gradient and do Adam step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()
        with open(args.save_path / args.log_full, 'a') as csvfile:
            writer = csv.writer(csvfile, delimiter='\t')
            writer.writerow([
                loss.data[0], loss_1.data[0],
                loss_2.data[0] if w2 > 0 else 0, loss_3.data[0]
            ])
        logger.train_bar.update(i + 1)
        if i % args.print_freq == 0:
            logger.train_writer.write('Train: Time {} Data {} Loss {}'.format(
                batch_time, data_time, losses))
        if i >= epoch_size - 1:
            break
        n_iter += 1
    return losses.avg[0]
def validate_with_gt(args, val_loader, depth_net, pose_net, epoch, logger,
                     output_writers=[], **env):
    """GT-based validation comparing depth from GT-stabilized vs
    prediction-stabilized input pairs, plus pose error.

    Returns an OrderedDict mapping error names to averaged values
    (pose errors, then unstabilized, then stabilized depth errors).
    """
    global device
    batch_time = AverageMeter()
    depth_error_names = ['abs diff', 'abs rel', 'sq rel', 'a1', 'a2', 'a3']
    stab_depth_errors = AverageMeter(i=len(depth_error_names))
    unstab_depth_errors = AverageMeter(i=len(depth_error_names))
    pose_error_names = ['Absolute Trajectory Error', 'Rotation Error']
    pose_errors = AverageMeter(i=len(pose_error_names))
    # switch to evaluate mode
    depth_net.eval()
    pose_net.eval()
    end = time.time()
    logger.valid_bar.update(0)
    for i, sample in enumerate(val_loader):
        log_output = i < len(output_writers)
        imgs = torch.stack(sample['imgs'], dim=1).to(device)
        batch_size, seq, c, h, w = imgs.size()
        intrinsics = sample['intrinsics'].to(device)
        intrinsics_inv = sample['intrinsics_inv'].to(device)
        if args.network_input_size is not None:
            # Resize inputs and rescale intrinsics consistently.
            imgs = F.interpolate(imgs, (c, *args.network_input_size),
                                 mode='area')
            downscale = h / args.network_input_size[0]
            intrinsics = torch.cat(
                (intrinsics[:, 0:2] / downscale, intrinsics[:, 2:]), dim=1)
            intrinsics_inv = torch.cat(
                (intrinsics_inv[:, :, 0:2] * downscale,
                 intrinsics_inv[:, :, 2:]), dim=2)
        GT_depth = sample['depth'].to(device)
        GT_pose = sample['pose'].to(device)
        mid_index = (args.sequence_length - 1) // 2
        tgt_img = imgs[:, mid_index]
        if epoch == 1 and log_output:
            for j, img in enumerate(sample['imgs']):
                output_writers[i].add_image('val Input', tensor2array(img[0]), j)
            # GT_depth[0].cpu() copies off-device, so the in-place edit below
            # does not touch the GT tensor used for metrics.
            depth_to_show = GT_depth[0].cpu()
            # KITTI Like data routine to discard invalid data
            depth_to_show[depth_to_show == 0] = 1000
            disp_to_show = (1 / depth_to_show).clamp(0, 10)
            output_writers[i].add_image(
                'val target Disparity Normalized',
                tensor2array(disp_to_show, max_value=None, colormap='bone'),
                epoch)
        poses = pose_net(imgs)
        pose_matrices = pose_vec2mat(poses, args.rotation_mode)  # [B, seq, 3, 4]
        inverted_pose_matrices = invert_mat(pose_matrices)
        pose_errors.update(
            compute_pose_error(GT_pose[:, :-1],
                               inverted_pose_matrices.data[:, :-1]))
        tgt_poses = pose_matrices[:, mid_index]  # [B, 3, 4]
        # Express all poses relative to the (middle) target frame.
        compensated_predicted_poses = compensate_pose(pose_matrices, tgt_poses)
        compensated_GT_poses = compensate_pose(GT_pose, GT_pose[:, mid_index])
        for j in range(args.sequence_length):
            if j == mid_index:
                if log_output and epoch == 1:
                    output_writers[i].add_image(
                        'val Input Stabilized',
                        tensor2array(sample['imgs'][j][0]), j)
                continue
            '''compute displacement magnitude for each element of batch, and rescale depth accordingly.'''
            prior_img = imgs[:, j]
            displacement = compensated_GT_poses[:, j, :, -1]  # [B,3]
            displacement_magnitude = displacement.norm(p=2, dim=1)  # [B]
            current_GT_depth = GT_depth * args.nominal_displacement / displacement_magnitude.view(
                -1, 1, 1)
            prior_predicted_pose = compensated_predicted_poses[:, j]  # [B, 3, 4]
            prior_GT_pose = compensated_GT_poses[:, j]
            prior_predicted_rot = prior_predicted_pose[:, :, :-1]
            # transpose = inverse for a rotation matrix
            prior_GT_rot = prior_GT_pose[:, :, :-1].transpose(1, 2)
            prior_compensated_from_GT = inverse_rotate(prior_img, prior_GT_rot,
                                                       intrinsics,
                                                       intrinsics_inv)
            if log_output and epoch == 1:
                depth_to_show = current_GT_depth[0]
                output_writers[i].add_image(
                    'val target Depth {}'.format(j),
                    tensor2array(depth_to_show, max_value=args.max_depth),
                    epoch)
                output_writers[i].add_image(
                    'val Input Stabilized',
                    tensor2array(prior_compensated_from_GT[0]), j)
            prior_compensated_from_prediction = inverse_rotate(
                prior_img, prior_predicted_rot, intrinsics, intrinsics_inv)
            predicted_input_pair = torch.cat(
                [prior_compensated_from_prediction, tgt_img],
                dim=1)  # [B, 6, W, H]
            GT_input_pair = torch.cat([prior_compensated_from_GT, tgt_img],
                                      dim=1)  # [B, 6, W, H]
            # This is the depth from footage stabilized with GT pose, it should
            # be better than depth from raw footage without any GT info
            raw_depth_stab = depth_net(GT_input_pair)
            raw_depth_unstab = depth_net(predicted_input_pair)
            # Upsample depth so that it matches GT size
            scale_factor = GT_depth.size(-1) // raw_depth_stab.size(-1)
            depth_stab = F.interpolate(raw_depth_stab,
                                       scale_factor=scale_factor,
                                       mode='bilinear',
                                       align_corners=False)
            depth_unstab = F.interpolate(raw_depth_unstab,
                                         scale_factor=scale_factor,
                                         mode='bilinear',
                                         align_corners=False)
            for k, depth in enumerate([depth_stab, depth_unstab]):
                disparity = 1 / depth
                errors = stab_depth_errors if k == 0 else unstab_depth_errors
                errors.update(
                    compute_depth_errors(current_GT_depth, depth, crop=True))
                if log_output:
                    prefix = 'stabilized' if k == 0 else 'unstabilized'
                    output_writers[i].add_image(
                        'val {} Dispnet Output Normalized {}'.format(prefix, j),
                        tensor2array(disparity[0],
                                     max_value=None,
                                     colormap='bone'), epoch)
                    output_writers[i].add_image(
                        'val {} Depth Output {}'.format(prefix, j),
                        tensor2array(depth[0], max_value=args.max_depth),
                        epoch)
        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()
        logger.valid_bar.update(i + 1)
        if i % args.print_freq == 0:
            logger.valid_writer.write(
                'valid: Time {} ATE Error {:.4f} ({:.4f}), Unstab Rel Abs Error {:.4f} ({:.4f})'
                .format(batch_time, pose_errors.val[0], pose_errors.avg[0],
                        unstab_depth_errors.val[1],
                        unstab_depth_errors.avg[1]))
    logger.valid_bar.update(len(val_loader))
    errors = (*pose_errors.avg, *unstab_depth_errors.avg,
              *stab_depth_errors.avg)
    error_names = (*pose_error_names,
                   *['unstab {}'.format(e) for e in depth_error_names],
                   *['stab {}'.format(e) for e in depth_error_names])
    return OrderedDict(zip(error_names, errors))
def validate_with_gt(args, val_loader, disp_net, epoch, logger,
                     output_writers=[]):
    """Validate disparity network against GT depth (legacy torch API).

    Uses the deprecated `Variable(..., volatile=True)` inference mode — keep
    as-is unless the file is migrated to `torch.no_grad()`.

    Returns (errors.avg, error_names).
    """
    batch_time = AverageMeter()
    error_names = ['abs_diff', 'abs_rel', 'sq_rel', 'a1', 'a2', 'a3']
    errors = AverageMeter(i=len(error_names))
    log_outputs = len(output_writers) > 0
    # switch to evaluate mode
    disp_net.eval()
    end = time.time()
    logger.valid_bar.update(0)
    for i, (tgt_img, depth) in enumerate(val_loader):
        tgt_img_var = Variable(tgt_img.cuda(), volatile=True)
        depth = depth.cuda()
        # compute output
        output_disp = disp_net(tgt_img_var)
        output_depth = 1 / output_disp
        # Log every 100th batch to one of the available writers.
        if log_outputs and i % 100 == 0 and i / 100 < len(output_writers):
            index = int(i // 100)
            if epoch == 0:
                output_writers[index].add_image('val Input',
                                                tensor2array(tgt_img[0]), 0)
                # .cpu() copies the CUDA tensor, so the in-place sentinel write
                # below does not corrupt `depth` used for error computation.
                depth_to_show = depth[0].cpu()
                output_writers[index].add_image(
                    'val target Depth',
                    tensor2array(depth_to_show, max_value=10), epoch)
                depth_to_show[depth_to_show == 0] = 1000
                disp_to_show = (1 / depth_to_show).clamp(0, 10)
                output_writers[index].add_image(
                    'val target Disparity Normalized',
                    tensor2array(disp_to_show, max_value=None,
                                 colormap='bone'), epoch)
            output_writers[index].add_image(
                'val Dispnet Output Normalized',
                tensor2array(output_disp.data[0].cpu(),
                             max_value=None,
                             colormap='bone'), epoch)
            output_writers[index].add_image(
                'val Depth Output',
                tensor2array(output_depth.data[0].cpu(), max_value=10), epoch)
        errors.update(compute_errors(depth, output_depth.data))
        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()
        logger.valid_bar.update(i + 1)
        if i % args.print_freq == 0:
            logger.valid_writer.write(
                'valid: Time {} Abs Error {:.4f} ({:.4f})'.format(
                    batch_time, errors.val[0], errors.avg[0]))
    logger.valid_bar.update(len(val_loader))
    return errors.avg, error_names
def test(val_loader, disp_net, mask_net, pose_net, flow_net, tb_writer,
         global_vars_dict=None):
    """Run all four networks over the validation set and log visualizations.

    No metrics are computed; returns the collected visualization arrays
    (disp_list, disp_arr, flow_list, mask_list).
    """
    # data prepared
    device = global_vars_dict['device']
    n_iter_val = global_vars_dict['n_iter_val']
    args = global_vars_dict['args']
    data_time = AverageMeter()
    # to eval model
    disp_net.eval()
    pose_net.eval()
    mask_net.eval()
    flow_net.eval()
    end = time.time()
    poses = np.zeros(((len(val_loader) - 1) * 1 *
                      (args.sequence_length - 1), 6))  # init
    disp_list = []
    flow_list = []
    mask_list = []
    # 3. validation cycle
    for i, (tgt_img, ref_imgs, intrinsics,
            intrinsics_inv) in tqdm(enumerate(val_loader)):
        data_time.update(time.time() - end)
        tgt_img = tgt_img.to(device)
        ref_imgs = [img.to(device) for img in ref_imgs]
        intrinsics, intrinsics_inv = intrinsics.to(device), intrinsics_inv.to(device)
        # 3.1 forward pass
        # disp
        disp = disp_net(tgt_img)
        if args.spatial_normalize:
            disp = spatial_normalize(disp)
        depth = 1 / disp
        # pose
        pose = pose_net(tgt_img, ref_imgs)
        # flow ----
        # build optical flow w.r.t. the previous and next frames
        # NOTE(review): flow_fwd/flow_bwd are unbound if args.flownet is
        # neither of these two options.
        if args.flownet == 'Back2Future':
            flow_fwd, flow_bwd, _ = flow_net(tgt_img, ref_imgs[1:3])
        elif args.flownet == 'FlowNetC6':
            flow_fwd = flow_net(tgt_img, ref_imgs[2])
            flow_bwd = flow_net(tgt_img, ref_imgs[1])
        # FLOW FWD [B,2,H,W]
        # flow cam: tensor[b,2,h,w]
        # flow_background (rigid flow induced by camera motion + depth)
        flow_cam = pose2flow(depth.squeeze(1), pose[:, 2], intrinsics,
                             intrinsics_inv)
        flows_cam_fwd = pose2flow(depth.squeeze(1), pose[:, 2], intrinsics,
                                  intrinsics_inv)
        flows_cam_bwd = pose2flow(depth.squeeze(1), pose[:, 1], intrinsics,
                                  intrinsics_inv)
        #exp_masks_target = consensus_exp_masks(flows_cam_fwd, flows_cam_bwd, flow_fwd, flow_bwd, tgt_img,
        #                                       ref_imgs[2], ref_imgs[1], wssim=args.wssim, wrig=args.wrig,
        #                                       ws=args.smooth_loss_weight)
        # residual between rigid (camera) flow and full flow
        rigidity_mask_fwd = (flows_cam_fwd - flow_fwd).abs()  # [b,2,h,w]
        rigidity_mask_bwd = (flows_cam_bwd - flow_bwd).abs()
        # mask
        # 4. explainability_mask(none)
        explainability_mask = mask_net(tgt_img, ref_imgs)  # valid region? 4??
        # list(5): item: tensor:[4,4,128,512]...[4,4,4,16] value:[0.33~0.48~0.63]
        end = time.time()
        # 3.4 check log
        # inspect forward-pass outputs
        # 2 disp
        disp_to_show = tensor2array(disp[0].cpu(), max_value=None,
                                    colormap='bone')  # tensor disp_to_show: [1,h,w], 0.5~3.1~10
        tb_writer.add_image('Disp/disp0', disp_to_show, i)
        disp_list.append(disp_to_show)
        # accumulate disparities into one [N,1,h,w] array
        if i == 0:
            disp_arr = np.expand_dims(disp_to_show, axis=0)
        else:
            disp_to_show = np.expand_dims(disp_to_show, axis=0)
            disp_arr = np.concatenate([disp_arr, disp_to_show], 0)
        # 3. flow
        tb_writer.add_image('Flow/Flow Output',
                            flow2rgb(flow_fwd[0], max_value=6), i)
        tb_writer.add_image('Flow/cam_Flow Output',
                            flow2rgb(flow_cam[0], max_value=6), i)
        tb_writer.add_image('Flow/rigid_Flow Output',
                            flow2rgb(rigidity_mask_fwd[0], max_value=6), i)
        tb_writer.add_image('Flow/rigidity_mask_fwd',
                            flow2rgb(rigidity_mask_fwd[0], max_value=6), i)
        flow_list.append(flow2rgb(flow_fwd[0], max_value=6))
        # 4. mask
        tb_writer.add_image(
            'Mask /mask0',
            tensor2array(explainability_mask[0][0],
                         max_value=None,
                         colormap='magma'), i)
        #tb_writer.add_image('Mask Output/mask1 sample{}'.format(i),tensor2array(explainability_mask[1][0], max_value=None, colormap='magma'), epoch)
        #tb_writer.add_image('Mask Output/mask2 sample{}'.format(i),tensor2array(explainability_mask[2][0], max_value=None, colormap='magma'), epoch)
        #tb_writer.add_image('Mask Output/mask3 sample{}'.format(i),tensor2array(explainability_mask[3][0], max_value=None, colormap='magma'), epoch)
        mask_list.append(
            tensor2array(explainability_mask[0][0],
                         max_value=None,
                         colormap='magma'))
    #
    return disp_list, disp_arr, flow_list, mask_list
def train(args, train_loader, disp_net, pose_net, optimizer, epoch_size,
          logger, train_writer):
    """One SC-SfMLearner-style training epoch: photometric + smoothness +
    geometry-consistency losses.

    Returns the running average of the total loss.
    """
    global n_iter, device
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter(precision=4)
    w1, w2, w3 = args.photo_loss_weight, args.smooth_loss_weight, args.geometry_consistency_weight
    # switch to train mode
    disp_net.train()
    pose_net.train()
    end = time.time()
    logger.train_bar.update(0)
    for i, (tgt_img, ref_imgs, intrinsics,
            intrinsics_inv) in enumerate(train_loader):
        log_losses = i > 0 and n_iter % args.print_freq == 0
        # measure data loading time
        data_time.update(time.time() - end)
        tgt_img = tgt_img.to(device)
        ref_imgs = [img.to(device) for img in ref_imgs]
        intrinsics = intrinsics.to(device)
        # compute output
        tgt_depth, ref_depths = compute_depth(disp_net, tgt_img, ref_imgs)
        poses, poses_inv = compute_pose_with_inv(pose_net, tgt_img, ref_imgs,
                                                 intrinsics)
        #if poses is None:
        loss_1, loss_3 = compute_photo_and_geometry_loss(
            tgt_img, ref_imgs, intrinsics, tgt_depth, ref_depths, poses,
            poses_inv, args.num_scales, args.with_ssim, args.with_mask,
            args.with_auto_mask, args.padding_mode)
        loss_2 = compute_smooth_loss(tgt_depth, tgt_img, ref_depths, ref_imgs)
        loss = w1 * loss_1 + w2 * loss_2 + w3 * loss_3
        if log_losses:
            train_writer.add_scalar('photometric_error', loss_1.item(), n_iter)
            train_writer.add_scalar('disparity_smoothness_loss', loss_2.item(),
                                    n_iter)
            train_writer.add_scalar('geometry_consistency_loss',
                                    loss_3.item(), n_iter)
            train_writer.add_scalar('total_loss', loss.item(), n_iter)
        # record loss and EPE
        losses.update(loss.item(), args.batch_size)
        # compute gradient and do Adam step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()
        with open(args.save_path / args.log_full, 'a') as csvfile:
            writer = csv.writer(csvfile, delimiter='\t')
            writer.writerow(
                [loss.item(), loss_1.item(), loss_2.item(), loss_3.item()])
        logger.train_bar.update(i + 1)
        if i % args.print_freq == 0:
            logger.train_writer.write('Train: Time {} Data {} Loss {}'.format(
                batch_time, data_time, losses))
        if i >= epoch_size - 1:
            break
        n_iter += 1
    return losses.avg[0]
def train(train_loader, model, optimizer, epoch, args, log, mpp=None):
    '''Train the given model for one epoch, optionally with Co-Mixup.

    Clean branch: plain BCE on softened one-hot targets.
    Co-Mixup branch: a saliency pass (gradient w.r.t. input) feeds the mixup
    solver, then the loss is computed on the mixed batch.

    Returns (top1.avg, top5.avg, losses.avg).
    '''
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()
    mixing_avg = []
    # switch to train mode
    model.train()
    end = time.time()
    for input, target in train_loader:
        data_time.update(time.time() - end)
        optimizer.zero_grad()
        input = input.cuda()
        target = target.long().cuda()
        sc = None
        # train with clean images
        if not args.comix:
            target_reweighted = to_one_hot(target, args.num_classes)
            output = model(input)
            loss = bce_loss(softmax(output), target_reweighted)
        # train with Co-Mixup images
        else:
            input_var = Variable(input, requires_grad=True)
            target_var = Variable(target)
            A_dist = None
            # Calculate saliency (unary)
            if args.clean_lam == 0:
                model.eval()
                output = model(input_var)
                loss_batch = criterion_batch(output, target_var)
            else:
                model.train()
                output = model(input_var)
                loss_batch = 2 * args.clean_lam * criterion_batch(
                    output, target_var) / args.num_classes
            loss_batch_mean = torch.mean(loss_batch, dim=0)
            # retain_graph: when clean_lam > 0 these clean-pass gradients are
            # kept and accumulated with the mixup loss backward below.
            loss_batch_mean.backward(retain_graph=True)
            # per-pixel saliency: RMS of input gradient over channels
            sc = torch.sqrt(torch.mean(input_var.grad**2, dim=1))
            # Here, we calculate distance between most salient location (Compatibility)
            # We can try various measurements
            with torch.no_grad():
                z = F.avg_pool2d(sc, kernel_size=8, stride=1)
                z_reshape = z.reshape(args.batch_size, -1)
                z_idx_1d = torch.argmax(z_reshape, dim=1)
                # unravel flat argmax into (row, col)
                z_idx_2d = torch.zeros((args.batch_size, 2), device=z.device)
                z_idx_2d[:, 0] = z_idx_1d // z.shape[-1]
                z_idx_2d[:, 1] = z_idx_1d % z.shape[-1]
                A_dist = distance(z_idx_2d, dist_type='l1')
            # Drop the saliency-pass gradients when they are not reused.
            if args.clean_lam == 0:
                model.train()
                optimizer.zero_grad()
            # Perform mixup and calculate loss
            target_reweighted = to_one_hot(target, args.num_classes)
            if args.parallel:
                device = input.device
                out, target_reweighted = mpp(input.cpu(),
                                             target_reweighted.cpu(),
                                             args=args,
                                             sc=sc.cpu(),
                                             A_dist=A_dist.cpu())
                out = out.to(device)
                target_reweighted = target_reweighted.to(device)
            else:
                out, target_reweighted = mixup_process(input,
                                                       target_reweighted,
                                                       args=args,
                                                       sc=sc,
                                                       A_dist=A_dist)
            out = model(out)
            loss = bce_loss(softmax(out), target_reweighted)
        # measure accuracy and record loss
        # NOTE(review): in the comix branch this measures accuracy on the
        # clean (saliency-pass) `output`, not the mixed predictions — appears
        # deliberate, confirm against upstream Co-Mixup.
        prec1, prec5 = accuracy(output, target, topk=(1, 5))
        losses.update(loss.item(), input.size(0))
        top1.update(prec1.item(), input.size(0))
        top5.update(prec5.item(), input.size(0))
        # compute gradient and do SGD step
        loss.backward()
        optimizer.step()
        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()
    print_log(
        '**Train** Prec@1 {top1.avg:.2f} Prec@5 {top5.avg:.2f} Error@1 {error1:.2f}'
        .format(top1=top1, top5=top5, error1=100 - top1.avg), log)
    return top1.avg, top5.avg, losses.avg
def validate_with_gt(args, val_loader, disp_net, epoch, logger,
                     output_writers=[]):
    """Validate the disparity network against ground-truth depth.

    Logs target/prediction visualizations to the first `len(output_writers)`
    batches and accumulates depth error metrics.

    Returns (errors.avg, error_names).
    """
    global device
    batch_time = AverageMeter()
    error_names = ['abs_diff', 'abs_rel', 'sq_rel', 'a1', 'a2', 'a3']
    errors = AverageMeter(i=len(error_names))
    log_outputs = len(output_writers) > 0

    # switch to evaluate mode
    disp_net.eval()

    end = time.time()
    logger.valid_bar.update(0)
    for i, (tgt_img, depth) in enumerate(val_loader):
        tgt_img = tgt_img.to(device)
        depth = depth.to(device)

        # check gt: skip samples shipped without a depth map
        if depth.nelement() == 0:
            continue

        # compute output
        output_disp = disp_net(tgt_img)
        output_depth = 1 / output_disp[:, 0]

        if log_outputs and i < len(output_writers):
            if epoch == 0:
                output_writers[i].add_image('val Input',
                                            tensor2array(tgt_img[0]), 0)
                # BUGFIX: clone before mutating. depth[0] is a *view* of
                # `depth`; the 0 -> 1000 sentinel write below used to corrupt
                # the ground truth later passed to compute_errors() for the
                # logged batches.
                depth_to_show = depth[0].clone()
                output_writers[i].add_image(
                    'val target Depth',
                    tensor2array(depth_to_show, max_value=10), epoch)
                # KITTI-like routine: mark invalid (zero) pixels with a huge
                # depth so their displayed disparity is ~0.
                depth_to_show[depth_to_show == 0] = 1000
                disp_to_show = (1 / depth_to_show).clamp(0, 10)
                output_writers[i].add_image(
                    'val target Disparity Normalized',
                    tensor2array(disp_to_show, max_value=None,
                                 colormap='magma'), epoch)
            output_writers[i].add_image(
                'val Dispnet Output Normalized',
                tensor2array(output_disp[0], max_value=None, colormap='magma'),
                epoch)
            output_writers[i].add_image(
                'val Depth Output',
                tensor2array(output_depth[0], max_value=10), epoch)

        # upsample prediction to GT resolution when they differ
        if depth.nelement() != output_depth.nelement():
            b, h, w = depth.size()
            output_depth = torch.nn.functional.interpolate(
                output_depth.unsqueeze(1), [h, w]).squeeze(1)

        errors.update(compute_errors(depth, output_depth, args.dataset))

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()
        logger.valid_bar.update(i + 1)
        if i % args.print_freq == 0:
            logger.valid_writer.write(
                'valid: Time {} Abs Error {:.4f} ({:.4f})'.format(
                    batch_time, errors.val[0], errors.avg[0]))
    logger.valid_bar.update(len(val_loader))
    return errors.avg, error_names
def main():
    """Entry point: set up experiment dirs, data, model; run the train/eval
    loop; checkpoint and record metrics.

    Relies on the module-level `args` (argparse namespace) and `best_acc`.
    """
    # For CUDA multi-processing
    if args.parallel:
        import torch.multiprocessing as mp
        mp.set_start_method("spawn")

    # Set up the experiment directories
    if not args.log_off:
        exp_name = define_exp_name()
        exp_dir = os.path.join(args.root_dir, exp_name)
        if not os.path.exists(exp_dir):
            os.makedirs(exp_dir)
        copy_script_to_folder(os.path.abspath(__file__), exp_dir)
        result_png_path = os.path.join(exp_dir, 'results.png')
        # FIX: removed the no-op 'log.txt'.format(args.seed) — the literal has
        # no placeholder, so the format call silently discarded the seed.
        log = open(os.path.join(exp_dir, 'log.txt'), 'w')
        print_log('save path : {}'.format(exp_dir), log)
    else:
        log = None
        exp_dir = None
        result_png_path = None

    global best_acc
    state = {k: v for k, v in args._get_kwargs()}
    print("")
    print_log(state, log)
    print("")
    print_log("Random Seed: {}".format(args.seed), log)
    print_log("python version : {}".format(sys.version.replace('\n', ' ')),
              log)
    print_log("torch version : {}".format(torch.__version__), log)
    print_log("cudnn version : {}".format(torch.backends.cudnn.version()), log)

    # Dataloader
    train_loader, _, _, test_loader, num_classes = load_data_subset(
        args.batch_size,
        0,
        args.dataset,
        args.data_dir,
        labels_per_class=args.labels_per_class,
        valid_labels_per_class=args.valid_labels_per_class)

    # Per-dataset first-conv stride and input normalization constants.
    if args.dataset == 'tiny-imagenet-200':
        stride = 2
        args.mean = torch.tensor([0.5] * 3,
                                 dtype=torch.float32).reshape(1, 3, 1, 1).cuda()
        args.std = torch.tensor([0.5] * 3,
                                dtype=torch.float32).reshape(1, 3, 1, 1).cuda()
        args.labels_per_class = 500
    elif args.dataset == 'cifar10':
        stride = 1
        args.mean = torch.tensor([x / 255 for x in [125.3, 123.0, 113.9]],
                                 dtype=torch.float32).reshape(1, 3, 1, 1).cuda()
        args.std = torch.tensor([x / 255 for x in [63.0, 62.1, 66.7]],
                                dtype=torch.float32).reshape(1, 3, 1, 1).cuda()
        args.labels_per_class = 5000
    elif args.dataset == 'cifar100':
        stride = 1
        args.mean = torch.tensor([x / 255 for x in [129.3, 124.1, 112.4]],
                                 dtype=torch.float32).reshape(1, 3, 1, 1).cuda()
        args.std = torch.tensor([x / 255 for x in [68.2, 65.4, 70.4]],
                                dtype=torch.float32).reshape(1, 3, 1, 1).cuda()
        args.labels_per_class = 500
    else:
        raise AssertionError('Given Dataset is not supported!')

    # Create model
    print_log("=> creating model '{}'".format(args.arch), log)
    net = models.__dict__[args.arch](num_classes, args.dropout, stride).cuda()
    args.num_classes = num_classes
    net = torch.nn.DataParallel(net, device_ids=list(range(args.ngpu)))
    optimizer = torch.optim.SGD(list(net.parameters()),
                                state['learning_rate'],
                                momentum=state['momentum'],
                                weight_decay=state['decay'],
                                nesterov=True)

    # Parallel Co-Mixup solver (optional)
    if args.parallel:
        mpp = MixupProcessParallel(args.m_part, args.batch_size, 1)
    else:
        mpp = None

    recorder = RecorderMeter(args.epochs)
    # Optionally resume from a checkpoint
    if args.resume:
        if os.path.isfile(args.resume):
            print_log("\n=> loading checkpoint '{}'".format(args.resume), log)
            checkpoint = torch.load(args.resume)
            recorder = checkpoint['recorder']
            args.start_epoch = checkpoint['epoch']
            net.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            best_acc = recorder.max_accuracy(False)
            print_log(
                "=> loaded checkpoint '{}' (epoch {})".format(
                    args.resume, checkpoint['epoch']), log)
        else:
            print_log("=> no checkpoint found at '{}'".format(args.resume),
                      log)
    else:
        print_log(
            "=> do not use any checkpoint for {} model".format(args.arch), log)

    if args.evaluate:
        validate(test_loader, net, log)
        return

    # Main training loop
    start_time = time.time()
    epoch_time = AverageMeter()
    train_loss = []
    train_acc = []
    test_loss = []
    test_acc = []
    for epoch in range(args.start_epoch, args.epochs):
        current_learning_rate = adjust_learning_rate(optimizer, epoch,
                                                     args.gammas,
                                                     args.schedule)
        need_hour, need_mins, need_secs = convert_secs2time(
            epoch_time.avg * (args.epochs - epoch))
        need_time = '[Need: {:02d}:{:02d}:{:02d}]'.format(
            need_hour, need_mins, need_secs)
        print_log('\n==>>{:s} [Epoch={:03d}/{:03d}] {:s} [learning_rate={:6.4f}]'.format(time_string(), epoch, args.epochs, need_time, current_learning_rate) \
            + ' [Best : Accuracy={:.2f}, Error={:.2f}]'.format(recorder.max_accuracy(False), 100-recorder.max_accuracy(False)), log)

        # Train for one epoch
        tr_acc, tr_acc5, tr_los = train(train_loader, net, optimizer, epoch,
                                        args, log, mpp)
        # Evaluate on validation set
        val_acc, val_los = validate(test_loader, net, log)
        train_loss.append(tr_los)
        train_acc.append(tr_acc)
        test_loss.append(val_los)
        test_acc.append(val_acc)

        is_best = False
        if val_acc > best_acc:
            is_best = True
            best_acc = val_acc

        # Measure elapsed time
        epoch_time.update(time.time() - start_time)
        start_time = time.time()

        if args.log_off:
            continue

        # Save log
        dummy = recorder.update(epoch, tr_los, tr_acc, val_los, val_acc)
        if (epoch + 1) % 100 == 0:
            recorder.plot_curve(result_png_path)
        train_log = OrderedDict()
        train_log['train_loss'] = train_loss
        train_log['train_acc'] = train_acc
        train_log['test_loss'] = test_loss
        train_log['test_acc'] = test_acc
        pickle.dump(train_log, open(os.path.join(exp_dir, 'log.pkl'), 'wb'))
        plotting(exp_dir)
        save_checkpoint(
            {
                'epoch': epoch + 1,
                'arch': args.arch,
                'state_dict': net.state_dict(),
                'recorder': recorder,
                'optimizer': optimizer.state_dict(),
            }, is_best, exp_dir, 'checkpoint.pth.tar')

    # Robust spread estimate over the final 10 epochs.
    acc_var = np.maximum(
        np.max(test_acc[-10:]) - np.median(test_acc[-10:]),
        np.median(test_acc[-10:]) - np.min(test_acc[-10:]))
    print_log(
        "\nfinal 10 epoch acc (median) : {:.2f} (+- {:.2f})".format(
            np.median(test_acc[-10:]), acc_var), log)

    if not args.log_off:
        log.close()
    if args.parallel:
        mpp.close()
def validate_without_gt(args, val_loader, disp_net, pose_exp_net, epoch, logger, tb_writer, torch_device, sample_nb_to_log=2):
    """Run one validation epoch without ground-truth depth.

    Evaluates `disp_net` and `pose_exp_net` on `val_loader`, accumulating the
    same weighted photometric / explainability / smoothness / pose-magnitude
    losses used in training, and logs sample images, loss values, and pose /
    disparity histograms to tensorboard.

    Args:
        args: namespace providing loss weights (`photo_loss_weight`,
            `mask_loss_weight`, `smooth_loss_weight`, `gt_pose_loss_weight`),
            `batch_size`, `sequence_length`, `rotation_mode`, `padding_mode`,
            and `print_freq`.
        val_loader: yields (tgt_img, tgt_lf, ref_imgs, ref_lfs, intrinsics,
            intrinsics_inv, pose_gt) tuples.
        disp_net: network mapping a light-field tensor to disparity.
        pose_exp_net: network producing (explainability_mask, pose) from the
            target and reference light fields.
        epoch: current epoch index (used for tensorboard step tagging).
        logger: progress logger exposing `valid_bar` and `valid_writer`.
        tb_writer: tensorboard SummaryWriter.
        torch_device: device all input tensors are moved to.
        sample_nb_to_log: number of leading batches whose images/outputs are
            written to tensorboard; <= 0 disables all tensorboard output.

    Returns:
        (losses.avg, names): the running averages of
        [total_loss, photometric_error, explainability_loss] and their
        tensorboard tag names.
    """
    batch_time = AverageMeter()
    # Tracks 3 values per update: [total, photometric, explainability].
    losses = AverageMeter(i=3, precision=4)
    log_outputs = sample_nb_to_log > 0
    w1, w2, w3, w4 = args.photo_loss_weight, args.mask_loss_weight, args.smooth_loss_weight, args.gt_pose_loss_weight
    # Pre-allocated histogram buffers. The last loader batch is excluded
    # (len(val_loader) - 1) — presumably because it may be smaller than
    # args.batch_size and would not fill its slice; confirm against loader's
    # drop_last setting.
    poses = np.zeros(
        ((len(val_loader) - 1) * args.batch_size * (args.sequence_length - 1),
         6))
    disp_values = np.zeros(((len(val_loader) - 1) * args.batch_size * 3))

    # switch to evaluate mode
    disp_net.eval()
    pose_exp_net.eval()

    end = time.time()
    logger.valid_bar.update(0)
    for i, (tgt_img, tgt_lf, ref_imgs, ref_lfs, intrinsics, intrinsics_inv,
            pose_gt) in enumerate(val_loader):
        tgt_img = tgt_img.to(torch_device)
        ref_imgs = [img.to(torch_device) for img in ref_imgs]
        tgt_lf = tgt_lf.to(torch_device)
        ref_lfs = [lf.to(torch_device) for lf in ref_lfs]
        intrinsics = intrinsics.to(torch_device)
        intrinsics_inv = intrinsics_inv.to(torch_device)
        pose_gt = pose_gt.to(torch_device)

        # compute output: disparity from the light field, depth as its inverse
        disp = disp_net(tgt_lf)
        depth = 1 / disp
        explainability_mask, pose = pose_exp_net(tgt_lf, ref_lfs)

        loss_1, warped, diff = photometric_reconstruction_loss(
            tgt_img, ref_imgs, intrinsics, depth, explainability_mask, pose,
            args.rotation_mode, args.padding_mode)
        loss_1 = loss_1.item()
        # Explainability loss only matters when its weight is non-zero.
        if w2 > 0:
            loss_2 = explainability_loss(explainability_mask).item()
        else:
            loss_2 = 0
        loss_3 = smooth_loss(depth).item()

        # Pose supervision compares only translation magnitudes (first 3
        # pose coefficients), not full 6-DoF error.
        pred_pose_magnitude = pose[:, :, :3].norm(dim=2)
        pose_gt_magnitude = pose_gt[:, :, :3].norm(dim=2)
        pose_loss = (pred_pose_magnitude - pose_gt_magnitude).abs().mean()

        if log_outputs and i < sample_nb_to_log - 1:  # log first output of first batches
            # Input images only need to be logged once (epoch 0).
            if epoch == 0:
                for j, ref in enumerate(ref_imgs):
                    tb_writer.add_image('val/Input {}/{}'.format(j, i),
                                        tensor2array(tgt_img[0]), 0)
                    tb_writer.add_image('val/Input {}/{}'.format(j, i),
                                        tensor2array(ref[0]), 1)
            # NOTE(review): both 1./disp (depth) and disp are passed —
            # confirm argument order against log_output_tensorboard.
            log_output_tensorboard(tb_writer, 'val', i, '', epoch, 1. / disp,
                                   disp, warped[0], diff[0],
                                   explainability_mask)

        # Fill the pre-allocated histogram buffers (last batch excluded,
        # matching the allocation above).
        if log_outputs and i < len(val_loader) - 1:
            step = args.batch_size * (args.sequence_length - 1)
            poses[i * step:(i + 1) * step] = pose.cpu().view(-1, 6).numpy()
            step = args.batch_size * 3
            disp_unraveled = disp.cpu().view(args.batch_size, -1)
            # Store (min, median, max) of disparity per sample.
            disp_values[i * step:(i + 1) * step] = torch.cat([
                disp_unraveled.min(-1)[0], disp_unraveled.median(-1)[0],
                disp_unraveled.max(-1)[0]
            ]).numpy()

        loss = w1 * loss_1 + w2 * loss_2 + w3 * loss_3 + w4 * pose_loss
        losses.update([loss, loss_1, loss_2])

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()
        logger.valid_bar.update(i + 1)
        if i % args.print_freq == 0:
            logger.valid_writer.write('valid: Time {} Loss {}'.format(
                batch_time, losses))
    if log_outputs:
        # Per-coefficient pose histograms; coefficient names depend on the
        # rotation parameterization (Euler angles vs. quaternion).
        prefix = 'valid poses'
        coeffs_names = ['tx', 'ty', 'tz']
        if args.rotation_mode == 'euler':
            coeffs_names.extend(['rx', 'ry', 'rz'])
        elif args.rotation_mode == 'quat':
            coeffs_names.extend(['qx', 'qy', 'qz'])
        for i in range(poses.shape[1]):
            tb_writer.add_histogram('{} {}'.format(prefix, coeffs_names[i]),
                                    poses[:, i], epoch)
        tb_writer.add_histogram('disp_values', disp_values, epoch)
    logger.valid_bar.update(len(val_loader))
    return losses.avg, [
        'val/total_loss', 'val/photometric_error', 'val/explainability_loss'
    ]
def main():
    """Evaluate a pretrained optical-flow network on KITTI 2015 or 2012.

    Reads options from the module-level ``parser`` (dataset choice, flow
    architecture name, pyramid levels, pretrained weights path), builds the
    validation loader, runs inference over every sample under
    ``torch.no_grad()``, and prints the averaged EPE / outlier metrics.

    Raises:
        ValueError: if ``args.dataset`` is neither ``"kitti2015"`` nor
            ``"kitti2012"``.
    """
    global args
    args = parser.parse_args()

    normalize = custom_transforms.Normalize(mean=[0.5, 0.5, 0.5],
                                            std=[0.5, 0.5, 0.5])
    flow_loader_h, flow_loader_w = 256, 832
    valid_flow_transform = custom_transforms.Compose([
        custom_transforms.Scale(h=flow_loader_h, w=flow_loader_w),
        custom_transforms.ArrayToTensor(), normalize
    ])
    if args.dataset == "kitti2015":
        val_flow_set = ValidationFlow(
            root='/home/anuragr/datasets/kitti/kitti2015',
            sequence_length=5,
            transform=valid_flow_transform)
    elif args.dataset == "kitti2012":
        val_flow_set = ValidationFlowKitti2012(
            root='/is/ps2/aranjan/AllFlowData/kitti/kitti2012',
            sequence_length=5,
            transform=valid_flow_transform)
    else:
        # Fail fast: the original fell through and crashed later with a
        # confusing NameError on val_flow_set.
        raise ValueError("Unsupported dataset: {}".format(args.dataset))

    val_flow_loader = torch.utils.data.DataLoader(val_flow_set,
                                                  batch_size=1,
                                                  shuffle=False,
                                                  num_workers=2,
                                                  pin_memory=True,
                                                  drop_last=True)

    flow_net = getattr(models, args.flownet)(nlevels=args.nlevels).cuda()

    if args.pretrained_flow:
        print("=> using pre-trained weights from {}".format(
            args.pretrained_flow))
        weights = torch.load(args.pretrained_flow)
        flow_net.load_state_dict(weights['state_dict'])  #, strict=False)

    # (The model is already on GPU from construction above; the original's
    # second .cuda() call was redundant.)
    flow_net.eval()

    error_names = ['epe_total', 'epe_non_rigid', 'epe_rigid', 'outliers']
    errors = AverageMeter(i=len(error_names))

    # `Variable(..., volatile=True)` has been a no-op since PyTorch 0.4, so
    # the original evaluated WITH autograd tracking; torch.no_grad() is the
    # supported replacement and restores the intended memory savings.
    with torch.no_grad():
        for i, (tgt_img, ref_imgs, intrinsics, intrinsics_inv, flow_gt,
                obj_map) in enumerate(tqdm(val_flow_loader)):
            tgt_img_var = tgt_img.cuda()
            if args.dataset == "kitti2015":
                # KITTI2015 samples carry a 5-frame sequence; only the two
                # frames adjacent to the target are fed to the network.
                ref_imgs_var = [img.cuda() for img in ref_imgs]
                ref_img_var = ref_imgs_var[1:3]
            elif args.dataset == "kitti2012":
                ref_img_var = ref_imgs.cuda()
            flow_gt_var = flow_gt.cuda()

            # compute output
            flow_fwd, flow_bwd, occ = flow_net(tgt_img_var, ref_img_var)

            obj_map_gt_var = obj_map.cuda()
            obj_map_gt_var_expanded = obj_map_gt_var.unsqueeze(1).type_as(
                flow_fwd)
            # NOTE(review): flow_fwd is passed as both the non-rigid and the
            # rigid flow — presumably intentional for a flow-only network,
            # but worth confirming against compute_all_epes' signature.
            epe = compute_all_epes(flow_gt_var, flow_fwd, flow_fwd,
                                   (1 - obj_map_gt_var_expanded))
            errors.update(epe)

    # Fixed typo in the original output string ("Averge").
    print("Average EPE", errors.avg)