def one_scale(cam_flow_fwd, cam_flow_bwd, flow_fwd, flow_bwd, tgt_img, ref_img_fwd, ref_img_bwd, ws):
    # wssim, wrig and epsilon come from the enclosing scope
    b, _, h, w = cam_flow_fwd.size()
    tgt_img_scaled = nn.functional.adaptive_avg_pool2d(tgt_img, (h, w))
    ref_img_scaled_fwd = nn.functional.adaptive_avg_pool2d(ref_img_fwd, (h, w))
    ref_img_scaled_bwd = nn.functional.adaptive_avg_pool2d(ref_img_bwd, (h, w))

    cam_warped_im_fwd = flow_warp(ref_img_scaled_fwd, cam_flow_fwd)
    cam_warped_im_bwd = flow_warp(ref_img_scaled_bwd, cam_flow_bwd)
    flow_warped_im_fwd = flow_warp(ref_img_scaled_fwd, flow_fwd)
    flow_warped_im_bwd = flow_warp(ref_img_scaled_bwd, flow_bwd)

    valid_pixels_cam_fwd = 1 - (cam_warped_im_fwd == 0).prod(1, keepdim=True).type_as(cam_warped_im_fwd)
    valid_pixels_cam_bwd = 1 - (cam_warped_im_bwd == 0).prod(1, keepdim=True).type_as(cam_warped_im_bwd)
    valid_pixels_cam = logical_or(valid_pixels_cam_fwd, valid_pixels_cam_bwd)  # if one of them is valid, then valid

    valid_pixels_flow_fwd = 1 - (flow_warped_im_fwd == 0).prod(1, keepdim=True).type_as(flow_warped_im_fwd)
    valid_pixels_flow_bwd = 1 - (flow_warped_im_bwd == 0).prod(1, keepdim=True).type_as(flow_warped_im_bwd)
    valid_pixels_flow = logical_or(valid_pixels_flow_fwd, valid_pixels_flow_bwd)  # if one of them is valid, then valid

    cam_err_fwd = ((1 - wssim) * robust_l1_per_pix(tgt_img_scaled - cam_warped_im_fwd).mean(1, keepdim=True)
                   + wssim * (1 - ssim(tgt_img_scaled, cam_warped_im_fwd)).mean(1, keepdim=True))
    cam_err_bwd = ((1 - wssim) * robust_l1_per_pix(tgt_img_scaled - cam_warped_im_bwd).mean(1, keepdim=True)
                   + wssim * (1 - ssim(tgt_img_scaled, cam_warped_im_bwd)).mean(1, keepdim=True))
    cam_err = torch.min(cam_err_fwd, cam_err_bwd) * valid_pixels_cam

    flow_err = ((1 - wssim) * robust_l1_per_pix(tgt_img_scaled - flow_warped_im_fwd).mean(1, keepdim=True)
                + wssim * (1 - ssim(tgt_img_scaled, flow_warped_im_fwd)).mean(1, keepdim=True))
    # flow_err_bwd = (1 - wssim) * robust_l1_per_pix(tgt_img_scaled - flow_warped_im_bwd).mean(1, keepdim=True) \
    #     + wssim * (1 - ssim(tgt_img_scaled, flow_warped_im_bwd)).mean(1, keepdim=True)
    # flow_err = torch.min(flow_err_fwd, flow_err_bwd)

    # a pixel prefers the rigid (camera) explanation when its weighted rigid
    # error does not exceed the residual-flow error
    exp_target = (wrig * cam_err <= (flow_err + epsilon)).type_as(cam_err)
    return exp_target
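# The consensus mask above leans on a few small helpers (logical_or,
# robust_l1_per_pix) defined elsewhere in this repo. A minimal sketch,
# assuming float {0, 1} masks and a Charbonnier-style penalty; treat the
# exact forms as assumptions, not the repo's definitive implementations:
import torch


def logical_or(a, b):
    # valid if either input mask is valid
    return ((a + b) > 0).type_as(a)


def logical_and(a, b):
    # valid only if both input masks are valid
    return ((a * b) > 0).type_as(a)


def robust_l1_per_pix(x, q=0.5, eps=1e-2):
    # per-pixel robust penalty (x^2 + eps)^q; q < 1 down-weights outliers
    return torch.pow(x.pow(2) + eps, q)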
def one_scale(explainability_mask, occ_masks, flows):
    assert(explainability_mask is None or flows[0].size()[2:] == explainability_mask.size()[2:])
    assert(len(flows) == len(ref_imgs))
    reconstruction_loss = 0
    b, _, h, w = flows[0].size()
    downscale = tgt_img.size(2) / h
    tgt_img_scaled = nn.functional.adaptive_avg_pool2d(tgt_img, (h, w))
    ref_imgs_scaled = [nn.functional.adaptive_avg_pool2d(ref_img, (h, w)) for ref_img in ref_imgs]

    weight = 1.
    for i, ref_img in enumerate(ref_imgs_scaled):
        current_flow = flows[i]
        ref_img_warped = flow_warp(ref_img, current_flow)  # formula (48), w_c
        valid_pixels = 1 - (ref_img_warped == 0).prod(1, keepdim=True).type_as(ref_img_warped)
        diff = (tgt_img_scaled - ref_img_warped) * valid_pixels
        ssim_loss = 1 - ssim(tgt_img_scaled, ref_img_warped) * valid_pixels
        oob_normalization_const = valid_pixels.nelement() / valid_pixels.sum()

        if explainability_mask is not None:
            diff = diff * explainability_mask[:, i:i+1].expand_as(diff)
            ssim_loss = ssim_loss * explainability_mask[:, i:i+1].expand_as(ssim_loss)
        if occ_masks is not None:
            diff = diff * (1 - occ_masks[:, i:i+1]).expand_as(diff)
            ssim_loss = ssim_loss * (1 - occ_masks[:, i:i+1]).expand_as(ssim_loss)

        reconstruction_loss += (1 - wssim) * weight * oob_normalization_const \
            * (robust_l1(diff, q=qch) + wssim * ssim_loss.mean()) \
            + lambda_oob * robust_l1(1 - valid_pixels, q=qch)
        # weight /= 2.83
        assert((reconstruction_loss == reconstruction_loss).item() == 1)  # NaN check
    return reconstruction_loss
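# one_scale above reduces its per-pixel penalty with robust_l1; a minimal
# sketch consistent with the robust_l1_per_pix form above, assuming a mean
# reduction (the repo's own definition may differ):
def robust_l1(x, q=0.5, eps=1e-2):
    # reduced robust (Charbonnier-style) penalty
    return torch.pow(x.pow(2) + eps, q).mean()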
def one_scale(flows):
    # assert(explainability_mask is None or flows[0].size()[2:] == explainability_mask.size()[2:])
    assert (len(flows) == len(ref_imgs))
    reconstruction_loss = 0
    b, _, h, w = flows[0].size()
    tgt_img_scaled = nn.functional.adaptive_avg_pool2d(tgt_img, (h, w))
    ref_imgs_scaled = [nn.functional.adaptive_avg_pool2d(ref_img, (h, w)) for ref_img in ref_imgs]

    loss = 0.0
    for i, ref_img in enumerate(ref_imgs_scaled):
        current_flow = flows[i]
        ref_img_warped = flow_warp(ref_img, current_flow)
        valid_pixels = 1 - (ref_img_warped == 0).prod(1, keepdim=True).type_as(ref_img_warped)
        diff = (tgt_img_scaled - ref_img_warped)
        if wssim:
            ssim_loss = 1 - ssim(tgt_img_scaled, ref_img_warped)
            reconstruction_loss = (1 - wssim) * robust_l1_per_pix(diff.mean(1, True), q=qch) * valid_pixels \
                + wssim * ssim_loss.mean(1, True)
        else:
            reconstruction_loss = robust_l1_per_pix(diff.mean(1, True), q=qch) * valid_pixels
        loss += reconstruction_loss.sum() / valid_pixels.sum()
    return loss
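# The (1 - wssim) / wssim blend above needs a per-pixel SSIM map. A compact
# sketch using 3x3 mean pooling and the standard C1/C2 constants; the repo's
# own ssim may use a different window and padding, so this is an assumed
# stand-in:
import torch.nn as nn


def ssim(x, y):
    C1, C2 = 0.01 ** 2, 0.03 ** 2
    mu_x = nn.functional.avg_pool2d(x, 3, 1, padding=1)
    mu_y = nn.functional.avg_pool2d(y, 3, 1, padding=1)
    sigma_x = nn.functional.avg_pool2d(x * x, 3, 1, padding=1) - mu_x ** 2
    sigma_y = nn.functional.avg_pool2d(y * y, 3, 1, padding=1) - mu_y ** 2
    sigma_xy = nn.functional.avg_pool2d(x * y, 3, 1, padding=1) - mu_x * mu_y
    num = (2 * mu_x * mu_y + C1) * (2 * sigma_xy + C2)
    den = (mu_x ** 2 + mu_y ** 2 + C1) * (sigma_x + sigma_y + C2)
    return torch.clamp(num / den, -1, 1)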
def one_scale(flows):
    # assert(explainability_mask is None or flows[0].size()[2:] == explainability_mask.size()[2:])
    assert (len(flows) == len(ref_imgs))
    reconstruction_loss = 0
    _, _, h, w = flows[0].size()
    tgt_img_scaled = nn.functional.adaptive_avg_pool2d(tgt_img, (h, w))
    ref_imgs_scaled = [nn.functional.adaptive_avg_pool2d(ref_img, (h, w)) for ref_img in ref_imgs]

    reconstruction_loss_all = 0.0
    for i, ref_img in enumerate(ref_imgs_scaled):
        current_flow = flows[i]
        ref_img_warped = flow_warp(ref_img, current_flow)
        valid_pixels = 1 - (ref_img_warped == 0).prod(1, keepdim=True).type_as(ref_img_warped)
        # the gradient loss output is one pixel smaller in h and w, so the
        # valid-pixel mask is cropped to match
        reconstruction_loss = gradient_photometric_loss(tgt_img_scaled, ref_img_warped, qch) * valid_pixels[:, :, :-1, :-1]
        # reconstruction_loss = gradient_photometric_all_direction_loss(tgt_img_scaled, ref_img_warped, qch) * valid_pixels[:, :, 1:-1, 1:-1]
        reconstruction_loss_all += reconstruction_loss.sum() / valid_pixels[:, :, :-1, :-1].sum()
    return reconstruction_loss_all
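# gradient_photometric_loss is called above but defined elsewhere in the
# repo. A minimal sketch, assuming it penalizes differences of horizontal
# and vertical image gradients (which explains the one-pixel crop of
# valid_pixels); it reuses the robust_l1_per_pix sketch above:
def gradient_photometric_loss(tgt, warped, q=0.5):
    dx_t = tgt[:, :, :, 1:] - tgt[:, :, :, :-1]
    dy_t = tgt[:, :, 1:, :] - tgt[:, :, :-1, :]
    dx_w = warped[:, :, :, 1:] - warped[:, :, :, :-1]
    dy_w = warped[:, :, 1:, :] - warped[:, :, :-1, :]
    # crop both terms to a common (h-1, w-1) support before summing
    loss = robust_l1_per_pix(dx_t[:, :, :-1, :] - dx_w[:, :, :-1, :], q=q) \
        + robust_l1_per_pix(dy_t[:, :, :, :-1] - dy_w[:, :, :, :-1], q=q)
    return loss.mean(1, keepdim=True)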
def one_scale(depth, flow_fwd, flow_bwd):
    b, _, h, w = depth.size()
    downscale = tgt_img.size(2) / h

    intrinsics_scaled = torch.cat((intrinsics[:, 0:2] / downscale, intrinsics[:, 2:]), dim=1)
    intrinsics_scaled_inv = torch.cat((intrinsics_inv[:, :, 0:2] * downscale, intrinsics_inv[:, :, 2:]), dim=2)

    ref_img_scaled_fwd = nn.functional.adaptive_avg_pool2d(ref_imgs[1], (h, w))
    ref_img_scaled_bwd = nn.functional.adaptive_avg_pool2d(ref_imgs[0], (h, w))

    depth_warped_im_fwd = inverse_warp(ref_img_scaled_fwd, depth[:, 0], poses[1],
                                       intrinsics_scaled, intrinsics_scaled_inv, rotation_mode, padding_mode)
    depth_warped_im_bwd = inverse_warp(ref_img_scaled_bwd, depth[:, 0], poses[0],
                                       intrinsics_scaled, intrinsics_scaled_inv, rotation_mode, padding_mode)

    valid_pixels_depth_fwd = 1 - (depth_warped_im_fwd == 0).prod(1, keepdim=True).type_as(depth_warped_im_fwd)
    valid_pixels_depth_bwd = 1 - (depth_warped_im_bwd == 0).prod(1, keepdim=True).type_as(depth_warped_im_bwd)
    valid_pixels_depth = logical_and(valid_pixels_depth_fwd, valid_pixels_depth_bwd)  # valid only if both warps are valid

    flow_warped_im_fwd = flow_warp(ref_img_scaled_fwd, flow_fwd)
    flow_warped_im_bwd = flow_warp(ref_img_scaled_bwd, flow_bwd)

    valid_pixels_flow_fwd = 1 - (flow_warped_im_fwd == 0).prod(1, keepdim=True).type_as(flow_warped_im_fwd)
    valid_pixels_flow_bwd = 1 - (flow_warped_im_bwd == 0).prod(1, keepdim=True).type_as(flow_warped_im_bwd)
    valid_pixels_flow = logical_and(valid_pixels_flow_fwd, valid_pixels_flow_bwd)  # valid only if both warps are valid

    # a pixel counts as valid if either the depth-based or the flow-based warp is valid
    valid_pixel = logical_or(valid_pixels_depth, valid_pixels_flow)
    return valid_pixel
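# flow_warp and inverse_warp are imported from the repo's warping utilities.
# A minimal sketch of flow_warp via bilinear grid sampling with zero padding
# (which is what makes the "== 0" validity tests above work); assumed form,
# not the repo's exact implementation:
def flow_warp(img, flow):
    b, _, h, w = flow.size()
    ys = torch.linspace(0, h - 1, h).view(1, h, 1).expand(b, h, w).type_as(flow)
    xs = torch.linspace(0, w - 1, w).view(1, 1, w).expand(b, h, w).type_as(flow)
    # target sampling positions; x comes before y to match grid_sample's convention
    x_t = xs + flow[:, 0]
    y_t = ys + flow[:, 1]
    # normalize to [-1, 1] for grid_sample
    grid = torch.stack((2.0 * x_t / max(w - 1, 1) - 1.0,
                        2.0 * y_t / max(h - 1, 1) - 1.0), dim=3)
    return nn.functional.grid_sample(img, grid, padding_mode='zeros')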
def one_scale(flows):
    # assert(explainability_mask is None or flows[0].size()[2:] == explainability_mask.size()[2:])
    assert (len(flows) == len(ref_imgs))
    reconstruction_loss = 0
    b, _, h, w = flows[0].size()
    tgt_img_scaled = nn.functional.adaptive_avg_pool2d(tgt_img, (h, w))
    ref_imgs_scaled = [nn.functional.adaptive_avg_pool2d(ref_img, (h, w)) for ref_img in ref_imgs]

    reconstruction_loss_all = []
    for i, ref_img in enumerate(ref_imgs_scaled):
        current_flow = flows[i]
        ref_img_warped = flow_warp(ref_img, current_flow)
        # valid_pixels = 1 - (ref_img_warped == 0).prod(1, keepdim=True).type_as(ref_img_warped)
        diff = (tgt_img_scaled - ref_img_warped)
        # ssim_loss = 1 - ssim(tgt_img_scaled, ref_img_warped)
        # if wssim:
        #     reconstruction_loss = (1 - wssim) * robust_l1_per_pix(diff.mean(1, True), q=qch) + wssim * ssim_loss.mean(1, True)
        # else:
        reconstruction_loss = robust_l1_per_pix(diff.mean(1, True), q=qch)
        reconstruction_loss_all.append(reconstruction_loss)

    reconstruction_loss = torch.cat(reconstruction_loss_all, 1)
    reconstruction_weight = reconstruction_loss
    # reconstruction_loss_min, _ = reconstruction_loss.min(1, keepdim=True)
    # reconstruction_loss_min = reconstruction_loss_min.repeat(1, 2, 1, 1)
    # loss_weight = reconstruction_loss_min / reconstruction_loss
    # loss_weight = torch.pow(loss_weight, 4)

    # down-weight the warping direction with the larger per-pixel error; the
    # weights are detached so gradients only flow through the loss itself
    loss_weight = 1 - torch.nn.functional.softmax(reconstruction_weight, 1)
    loss_weight = Variable(loss_weight.data, requires_grad=False)
    loss = reconstruction_loss * loss_weight
    # loss = torch.mean(loss, 3)
    # loss = torch.mean(loss, 2)
    # loss = torch.mean(loss, 0)
    return loss.sum() / loss_weight.sum()
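# Quick numeric check (illustrative, not part of training) of the
# "1 - softmax" weighting used above: with two per-pixel errors, the lower
# error receives the larger weight, a soft form of per-pixel min over the
# forward/backward reconstructions:
def _check_soft_min_weighting():
    err = torch.tensor([[[[0.2]], [[0.8]]]])  # (b=1, 2 directions, 1, 1)
    w = 1 - torch.nn.functional.softmax(err, dim=1)
    assert w[0, 0] > w[0, 1]  # ~0.65 vs ~0.35
    return w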
def train(train_loader, disp_net, pose_net, mask_net, flow_net, optimizer,
          logger=None, train_writer=None, global_vars_dict=None):
    # 0. setup
    args = global_vars_dict['args']
    n_iter = global_vars_dict['n_iter']
    device = global_vars_dict['device']

    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter(precision=4)

    # 1. loss weights and loss functions
    w1, w2, w3, w4 = args.cam_photo_loss_weight, args.mask_loss_weight, args.smooth_loss_weight, args.flow_photo_loss_weight
    w5 = args.consensus_loss_weight

    if args.robust:
        loss_camera = photometric_reconstruction_loss_robust
        loss_flow = photometric_flow_loss_robust
    else:
        loss_camera = photometric_reconstruction_loss
        loss_flow = photometric_flow_loss

    # 2. switch to train mode
    disp_net.train()
    pose_net.train()
    mask_net.train()
    flow_net.train()

    end = time.time()

    # 3. training loop
    for i, (tgt_img, ref_imgs, intrinsics, intrinsics_inv) in enumerate(train_loader):
        # measure data loading time
        data_time.update(time.time() - end)
        tgt_img = tgt_img.to(device)
        ref_imgs = [img.to(device) for img in ref_imgs]
        intrinsics = intrinsics.to(device)
        intrinsics_inv = intrinsics_inv.to(device)

        # 3.1 compute network outputs and loss-function inputs
        # 3.1.1 disparity -> depth
        disparities = disp_net(tgt_img)
        if args.spatial_normalize:
            disparities = [spatial_normalize(disp) for disp in disparities]
        depth = [1 / disp for disp in disparities]

        # 3.1.2 pose
        pose = pose_net(tgt_img, ref_imgs)  # pose: [4, 4, 6]

        # 3.1.3 flow_fwd, flow_bwd: full optical flow (depth, pose);
        # slightly modified from the original
        if args.flownet == 'Back2Future':
            # trains/infers on three neighbouring frames
            flow_fwd, flow_bwd, _ = flow_net(tgt_img, ref_imgs[1:3])
        elif args.flownet == 'FlowNetC6':
            flow_fwd = flow_net(tgt_img, ref_imgs[2])
            flow_bwd = flow_net(tgt_img, ref_imgs[1])
        elif args.flownet == 'FlowNetS':
            print(' ')  # not implemented

        # flow_cam is the background (rigid) flow induced by camera motion;
        # flow - flow_s = flow_o
        flow_cam = pose2flow(depth[0].squeeze(), pose[:, 2], intrinsics, intrinsics_inv)  # pose[:, 2] belongs to the forward frame
        flows_cam_fwd = [pose2flow(depth_.squeeze(1), pose[:, 2], intrinsics, intrinsics_inv) for depth_ in depth]
        flows_cam_bwd = [pose2flow(depth_.squeeze(1), pose[:, 1], intrinsics, intrinsics_inv) for depth_ in depth]

        exp_masks_target = consensus_exp_masks(flows_cam_fwd, flows_cam_bwd, flow_fwd, flow_bwd,
                                               tgt_img, ref_imgs[2], ref_imgs[1],
                                               wssim=args.wssim, wrig=args.wrig, ws=args.smooth_loss_weight)

        rigidity_mask_fwd = [(flows_cam_fwd_i - flow_fwd_i).abs()
                             for flows_cam_fwd_i, flow_fwd_i in zip(flows_cam_fwd, flow_fwd)]  # .normalize()
        rigidity_mask_bwd = [(flows_cam_bwd_i - flow_bwd_i).abs()
                             for flows_cam_bwd_i, flow_bwd_i in zip(flows_cam_bwd, flow_bwd)]  # .normalize() v_u

        # 3.1.4 explainability mask (valid region?)
        explainability_mask = mask_net(tgt_img, ref_imgs)
        # list(5) of tensors [4,4,128,512] ... [4,4,4,16], values ~0.33-0.63
        # -------------------------------------------------
        if args.joint_mask_for_depth:
            explainability_mask_for_depth = compute_joint_mask_for_depth(
                explainability_mask, rigidity_mask_bwd, rigidity_mask_fwd, args.THRESH)
        else:
            explainability_mask_for_depth = explainability_mask
        # explainability_mask_for_depth: list(5) of [b, 2, h/s, w/s]

        if args.no_non_rigid_mask:
            flow_exp_mask = [None for exp_mask in explainability_mask]
            if args.DEBUG:
                print('Using no masks for flow')
        else:
            # the explainability mask marks the background (background pixels = 1);
            # invert it to get a moving-object mask, keeping only the previous
            # and next frames -> list(4) of [4, 2, 256, 512]
            flow_exp_mask = [1 - exp_mask[:, 1:3] for exp_mask in explainability_mask]
        # 3.2 compute the losses
        # E_R minimizes the photometric loss on the static scene
        if w1 > 0:
            loss_1 = loss_camera(tgt_img, ref_imgs, intrinsics, intrinsics_inv,
                                 depth, explainability_mask_for_depth, pose,
                                 lambda_oob=args.lambda_oob, qch=args.qch, wssim=args.wssim)
        else:
            loss_1 = torch.tensor([0.]).to(device)

        # E_M
        if w2 > 0:
            loss_2 = explainability_loss(explainability_mask)  # + 0.2*gaussian_explainability_loss(explainability_mask)
        else:
            loss_2 = 0

        # E_S
        if w3 > 0:
            if args.smoothness_type == "regular":
                loss_3 = smooth_loss(depth) + smooth_loss(flow_fwd) + smooth_loss(flow_bwd) + smooth_loss(explainability_mask)
            elif args.smoothness_type == "edgeaware":
                loss_3 = edge_aware_smoothness_loss(tgt_img, depth) + edge_aware_smoothness_loss(tgt_img, flow_fwd)
                loss_3 += edge_aware_smoothness_loss(tgt_img, flow_bwd) + edge_aware_smoothness_loss(tgt_img, explainability_mask)
        else:
            loss_3 = torch.tensor([0.]).to(device)

        # E_F minimizes the photometric loss on moving regions
        if w4 > 0:
            loss_4 = loss_flow(tgt_img, ref_imgs[1:3], [flow_bwd, flow_fwd], flow_exp_mask,
                               lambda_oob=args.lambda_oob, qch=args.qch, wssim=args.wssim)
        else:
            loss_4 = torch.tensor([0.]).to(device)

        # E_C drives the collaboration
        # explainability_mask: list(6) of [4,4,4,16]; rigidity_mask: list(4) of [4,2,128,512]
        if w5 > 0:
            loss_5 = consensus_depth_flow_mask(explainability_mask, rigidity_mask_bwd, rigidity_mask_fwd,
                                               exp_masks_target, exp_masks_target,
                                               THRESH=args.THRESH, wbce=args.wbce)
        else:
            loss_5 = torch.tensor([0.]).to(device)

        # 3.2.6 total loss
        loss = w1 * loss_1 + w2 * loss_2 + w3 * loss_3 + w4 * loss_4 + w5 * loss_5

        # 3.3 record loss and EPE
        losses.update(loss.item(), args.batch_size)

        # compute gradient and do Adam step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        # 3.4 log data
        # scalars
        if args.scalar_freq > 0 and n_iter % args.scalar_freq == 0:
            train_writer.add_scalar('batch/cam_photometric_error', loss_1.item(), n_iter)
            if w2 > 0:
                train_writer.add_scalar('batch/explanability_loss', loss_2.item(), n_iter)
            train_writer.add_scalar('batch/disparity_smoothness_loss', loss_3.item(), n_iter)
            train_writer.add_scalar('batch/flow_photometric_error', loss_4.item(), n_iter)
            train_writer.add_scalar('batch/consensus_error', loss_5.item(), n_iter)
            train_writer.add_scalar('batch/total_loss', loss.item(), n_iter)

        # images (a frequency of 0 disables image logging)
        if args.training_output_freq > 0 and n_iter % args.training_output_freq == 0:
            train_writer.add_image('train Input', tensor2array(tgt_img[0]), n_iter)
            train_writer.add_image('train Cam Flow Output',
                                   flow_to_image(tensor2array(flow_cam.data[0].cpu())), n_iter)

            for k, scaled_depth in enumerate(depth):
                train_writer.add_image('train Dispnet Output Normalized {}'.format(k),
                                       tensor2array(disparities[k].data[0].cpu(), max_value=None, colormap='bone'),
                                       n_iter)
                train_writer.add_image('train Depth Output {}'.format(k),
                                       tensor2array(1 / disparities[k].data[0].cpu(), max_value=10),
                                       n_iter)
                train_writer.add_image('train Non Rigid Flow Output {}'.format(k),
                                       flow_to_image(tensor2array(flow_fwd[k].data[0].cpu())),
                                       n_iter)
                train_writer.add_image('train Target Rigidity {}'.format(k),
                                       tensor2array((rigidity_mask_fwd[k] > args.THRESH).type_as(rigidity_mask_fwd[k]).data[0].cpu(),
                                                    max_value=1, colormap='bone'),
                                       n_iter)

                b, _, h, w = scaled_depth.size()
                downscale = tgt_img.size(2) / h
                tgt_img_scaled = nn.functional.adaptive_avg_pool2d(tgt_img, (h, w))
                ref_imgs_scaled = [nn.functional.adaptive_avg_pool2d(ref_img, (h, w)) for ref_img in ref_imgs]
                intrinsics_scaled = torch.cat((intrinsics[:, 0:2] / downscale, intrinsics[:, 2:]), dim=1)
                intrinsics_scaled_inv = torch.cat((intrinsics_inv[:, :, 0:2] * downscale, intrinsics_inv[:, :, 2:]), dim=2)

                train_writer.add_image('train Non Rigid Warped Image {}'.format(k),
                                       tensor2array(flow_warp(ref_imgs_scaled[2], flow_fwd[k]).data[0].cpu()),
                                       n_iter)

                # log warped images along with explainability mask
                for j, ref in enumerate(ref_imgs_scaled):
                    ref_warped = inverse_warp(ref, scaled_depth[:, 0], pose[:, j],
                                              intrinsics_scaled, intrinsics_scaled_inv,
                                              rotation_mode=args.rotation_mode,
                                              padding_mode=args.padding_mode)[0]
                    train_writer.add_image('train Warped Outputs {} {}'.format(k, j),
                                           tensor2array(ref_warped.data.cpu()), n_iter)
                    train_writer.add_image('train Diff Outputs {} {}'.format(k, j),
                                           tensor2array(0.5 * (tgt_img_scaled[0] - ref_warped).abs().data.cpu()), n_iter)
                    if explainability_mask[k] is not None:
                        train_writer.add_image('train Exp mask Outputs {} {}'.format(k, j),
                                               tensor2array(explainability_mask[k][0, j].data.cpu(),
                                                            max_value=1, colormap='bone'),
                                               n_iter)

        # csv logging
        with open(args.save_path / args.log_full, 'a') as csvfile:
            writer = csv.writer(csvfile, delimiter='\t')
            writer.writerow([loss.item(), loss_1.item(),
                             loss_2.item() if w2 > 0 else 0,
                             loss_3.item(), loss_4.item()])

        # terminal output
        if args.log_terminal:
            logger.train_bar.update(i + 1)  # progress within the current epoch
            if i % args.print_freq == 0:
                logger.train_writer.write('Train: Time {} Data {} Loss {}'.format(batch_time, data_time, losses))

        # 3.5 end-of-epoch condition
        epoch_size = len(train_loader)
        if i >= epoch_size - 1:
            break

        n_iter += 1

    global_vars_dict['n_iter'] = n_iter
    return losses.avg[0]  # epoch loss
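# spatial_normalize is applied above when args.spatial_normalize is set; a
# minimal sketch, assuming it rescales each disparity map by its per-image
# spatial mean (assumed form, not the repo's exact definition):
def spatial_normalize(disp):
    # divide by the spatial mean so the scale ambiguity of monocular depth
    # does not collapse predictions toward zero
    return disp / disp.mean(3, keepdim=True).mean(2, keepdim=True).mean(1, keepdim=True)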
def validate_flow_with_gt(val_loader, disp_net, pose_net, mask_net, flow_net,
                          epoch, logger, output_writers=[]):
    global args
    batch_time = AverageMeter()
    error_names = ['epe_total', 'epe_rigid', 'epe_non_rigid', 'outliers',
                   'epe_total_with_gt_mask', 'epe_rigid_with_gt_mask',
                   'epe_non_rigid_with_gt_mask', 'outliers_gt_mask']
    errors = AverageMeter(i=len(error_names))
    log_outputs = len(output_writers) > 0

    # switch to evaluate mode
    disp_net.eval()
    pose_net.eval()
    mask_net.eval()
    flow_net.eval()

    end = time.time()
    poses = np.zeros(((len(val_loader) - 1) * 1 * (args.sequence_length - 1), 6))

    for i, (tgt_img, ref_imgs, intrinsics, intrinsics_inv, flow_gt, obj_map_gt) in enumerate(val_loader):
        tgt_img = Variable(tgt_img.cuda(), volatile=True)
        ref_imgs = [Variable(img.cuda(), volatile=True) for img in ref_imgs]
        intrinsics_var = Variable(intrinsics.cuda(), volatile=True)
        intrinsics_inv_var = Variable(intrinsics_inv.cuda(), volatile=True)
        flow_gt_var = Variable(flow_gt.cuda(), volatile=True)
        obj_map_gt_var = Variable(obj_map_gt.cuda(), volatile=True)

        # compute outputs -------------------------
        # 1. disparity
        disp = disp_net(tgt_img)
        if args.spatial_normalize:
            disp = spatial_normalize(disp)
        depth = 1 / disp

        # 2. pose
        pose = pose_net(tgt_img, ref_imgs)

        # 3. explainability mask
        explainability_mask = mask_net(tgt_img, ref_imgs)

        # 4. flow
        if args.flownet == 'Back2Future':
            flow_fwd, flow_bwd, _ = flow_net(tgt_img, ref_imgs[1:3])  # previous frame, next frame
        elif args.flownet == 'FlowNetC6':
            flow_fwd = flow_net(tgt_img, ref_imgs[2])
            flow_bwd = flow_net(tgt_img, ref_imgs[1])
        # -----------------------------------------

        if args.DEBUG:
            flow_fwd_x = flow_fwd[:, 0].view(-1).abs().data
            print("Flow Fwd Median: ", flow_fwd_x.median())
            flow_gt_var_x = flow_gt_var[:, 0].view(-1).abs().data
            print("Flow GT Median: ",
                  flow_gt_var_x.index_select(0, flow_gt_var_x.nonzero().view(-1)).median())

        flow_cam = pose2flow(depth.squeeze(1), pose[:, 2], intrinsics_var, intrinsics_inv_var)
        oob_rigid = flow2oob(flow_cam)
        oob_non_rigid = flow2oob(flow_fwd)

        rigidity_mask = 1 - (1 - explainability_mask[:, 1]) * (1 - explainability_mask[:, 2]).unsqueeze(1) > 0.5

        rigidity_mask_census_soft = (flow_cam - flow_fwd).abs()  # .normalize()
        rigidity_mask_census_u = rigidity_mask_census_soft[:, 0] < args.THRESH
        rigidity_mask_census_v = rigidity_mask_census_soft[:, 1] < args.THRESH
        rigidity_mask_census = (rigidity_mask_census_u).type_as(flow_fwd) * (rigidity_mask_census_v).type_as(flow_fwd)

        rigidity_mask_combined = 1 - (1 - rigidity_mask.type_as(explainability_mask)) \
            * (1 - rigidity_mask_census.type_as(explainability_mask))

        # compose the full flow: the rigid (camera) flow where the scene is
        # rigid, the network's flow elsewhere
        flow_fwd_non_rigid = (rigidity_mask_combined <= args.THRESH).type_as(flow_fwd).expand_as(flow_fwd) * flow_fwd
        flow_fwd_rigid = (rigidity_mask_combined > args.THRESH).type_as(flow_fwd).expand_as(flow_fwd) * flow_cam
        total_flow = flow_fwd_rigid + flow_fwd_non_rigid

        obj_map_gt_var_expanded = obj_map_gt_var.unsqueeze(1).type_as(flow_fwd)

        if log_outputs and i % 10 == 0 and i / 10 < len(output_writers):
            index = int(i // 10)
            if epoch == 0:
                output_writers[index].add_image('val flow Input', tensor2array(tgt_img[0]), 0)
                flow_to_show = flow_gt[0][:2, :, :].cpu()
                output_writers[index].add_image('val target Flow',
                                                flow_to_image(tensor2array(flow_to_show)), epoch)

            output_writers[index].add_image('val Total Flow Output',
                                            flow_to_image(tensor2array(total_flow.data[0].cpu())), epoch)
            output_writers[index].add_image('val Rigid Flow Output',
                                            flow_to_image(tensor2array(flow_fwd_rigid.data[0].cpu())), epoch)
            output_writers[index].add_image('val Non-rigid Flow Output',
                                            flow_to_image(tensor2array(flow_fwd_non_rigid.data[0].cpu())), epoch)
            output_writers[index].add_image('val Out of Bound (Rigid)',
                                            tensor2array(oob_rigid.type(torch.FloatTensor).data[0].cpu(),
                                                         max_value=1, colormap='bone'), epoch)
            output_writers[index].add_scalar('val Mean oob (Rigid)',
                                             oob_rigid.type(torch.FloatTensor).sum(), epoch)
            output_writers[index].add_image('val Out of Bound (Non-Rigid)',
                                            tensor2array(oob_non_rigid.type(torch.FloatTensor).data[0].cpu(),
                                                         max_value=1, colormap='bone'), epoch)
            output_writers[index].add_scalar('val Mean oob (Non-Rigid)',
                                             oob_non_rigid.type(torch.FloatTensor).sum(), epoch)
            output_writers[index].add_image('val Cam Flow Errors',
                                            tensor2array(flow_diff(flow_gt_var, flow_cam).data[0].cpu()), epoch)
            output_writers[index].add_image('val Rigidity Mask',
                                            tensor2array(rigidity_mask.data[0].cpu(), max_value=1, colormap='bone'), epoch)
            output_writers[index].add_image('val Rigidity Mask Census',
                                            tensor2array(rigidity_mask_census.data[0].cpu(), max_value=1, colormap='bone'), epoch)

            for j, ref in enumerate(ref_imgs):
                ref_warped = inverse_warp(ref[:1], depth[:1, 0], pose[:1, j],
                                          intrinsics_var[:1], intrinsics_inv_var[:1],
                                          rotation_mode=args.rotation_mode,
                                          padding_mode=args.padding_mode)[0]
                output_writers[index].add_image('val Warped Outputs {}'.format(j),
                                                tensor2array(ref_warped.data.cpu()), epoch)
                output_writers[index].add_image('val Diff Outputs {}'.format(j),
                                                tensor2array(0.5 * (tgt_img[0] - ref_warped).abs().data.cpu()), epoch)
                if explainability_mask is not None:
                    output_writers[index].add_image('val Exp mask Outputs {}'.format(j),
                                                    tensor2array(explainability_mask[0, j].data.cpu(),
                                                                 max_value=1, colormap='bone'), epoch)

            if args.DEBUG:
                # check that pose2flow is consistent with inverse_warp
                ref_warped_from_depth = inverse_warp(ref_imgs[2][:1], depth[:1, 0], pose[:1, 2],
                                                     intrinsics_var[:1], intrinsics_inv_var[:1],
                                                     rotation_mode=args.rotation_mode,
                                                     padding_mode=args.padding_mode)[0]
                ref_warped_from_cam_flow = flow_warp(ref_imgs[2][:1], flow_cam)[0]
                print("DEBUG_INFO: Inverse_warp vs pose2flow",
                      torch.mean(torch.abs(ref_warped_from_depth - ref_warped_from_cam_flow)).item())
                output_writers[index].add_image('val Warped Outputs from Cam Flow',
                                                tensor2array(ref_warped_from_cam_flow.data.cpu()), epoch)
                output_writers[index].add_image('val Warped Outputs from inverse warp',
                                                tensor2array(ref_warped_from_depth.data.cpu()), epoch)

        if log_outputs and i < len(val_loader) - 1:
            step = args.sequence_length - 1
            poses[i * step:(i + 1) * step] = pose.data.cpu().view(-1, 6).numpy()

        if np.isnan(flow_gt.sum().item()) or np.isnan(total_flow.data.sum().item()):
            print('NaN encountered')

        _epe_errors = compute_all_epes(flow_gt_var, flow_cam, flow_fwd, rigidity_mask_combined) \
            + compute_all_epes(flow_gt_var, flow_cam, flow_fwd, (1 - obj_map_gt_var_expanded))
        errors.update(_epe_errors)

        if args.DEBUG:
            print("DEBUG_INFO: EPE errors: ", _epe_errors)

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

    if log_outputs:
        output_writers[0].add_histogram('val poses_tx', poses[:, 0], epoch)
        output_writers[0].add_histogram('val poses_ty', poses[:, 1], epoch)
        output_writers[0].add_histogram('val poses_tz', poses[:, 2], epoch)
        if args.rotation_mode == 'euler':
            rot_coeffs = ['rx', 'ry', 'rz']
        elif args.rotation_mode == 'quat':
            rot_coeffs = ['qx', 'qy', 'qz']
        output_writers[0].add_histogram('val poses_{}'.format(rot_coeffs[0]), poses[:, 3], epoch)
        output_writers[0].add_histogram('val poses_{}'.format(rot_coeffs[1]), poses[:, 4], epoch)
        output_writers[0].add_histogram('val poses_{}'.format(rot_coeffs[2]), poses[:, 5], epoch)

    if args.DEBUG:
        print("DEBUG_INFO =================>")
        print("DEBUG_INFO: Average EPE : ", errors.avg)
        print("DEBUG_INFO =================>")

    return errors.avg, error_names
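# flow2oob is used above to visualize out-of-bound pixels; a minimal sketch,
# assuming it flags pixels whose flow vector lands outside the image
# (assumed form, not the repo's exact helper):
def flow2oob(flow):
    b, _, h, w = flow.size()
    ys = torch.linspace(0, h - 1, h).view(1, h, 1).expand(b, h, w).type_as(flow)
    xs = torch.linspace(0, w - 1, w).view(1, 1, w).expand(b, h, w).type_as(flow)
    x_t = xs + flow[:, 0]
    y_t = ys + flow[:, 1]
    oob = (x_t < 0) | (x_t > w - 1) | (y_t < 0) | (y_t > h - 1)
    return oob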
def train(train_loader, flow_net, optimizer, epoch_size, logger=None, train_writer=None):
    global args, n_iter
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter(precision=4)

    # switch to train mode
    flow_net.train()

    end = time.time()

    for i, (tgt_img, ref_imgs, intrinsics, intrinsics_inv) in enumerate(train_loader):
        # measure data loading time
        data_time.update(time.time() - end)
        tgt_img_var = Variable(tgt_img.cuda())
        ref_imgs_var = [Variable(img.cuda()) for img in ref_imgs]

        if args.flownet == 'Back2Future':
            flow_fwd, flow_bwd = flow_net(tgt_img_var, ref_imgs_var)
        else:
            flow_fwd = flow_net(tgt_img_var, ref_imgs_var[1])
            flow_bwd = flow_net(tgt_img_var, ref_imgs_var[0])

        loss_smooth = torch.zeros(1).cuda()
        loss_flow_recon = torch.zeros(1).cuda()
        loss_velocity_consis = torch.zeros(1).cuda()
        loss_weight = None  # only set by the min-loss variant below

        if args.flow_photo_loss_weight_first:
            if args.min:
                loss_flow_recon += args.flow_photo_loss_weight_first * photometric_flow_min_loss(
                    tgt_img_var, ref_imgs_var, [flow_bwd, flow_fwd],
                    lambda_oob=args.lambda_oob, qch=args.qch, wssim=args.wssim)
            else:
                loss_flow_recon += args.flow_photo_loss_weight_first * photometric_flow_loss(
                    tgt_img_var, ref_imgs_var, [flow_bwd, flow_fwd],
                    lambda_oob=args.lambda_oob, qch=args.qch, wssim=args.wssim)

        if args.flow_photo_loss_weight_second:
            if args.min:
                loss_per, loss_weight = photometric_flow_gradient_min_loss(
                    tgt_img_var, ref_imgs_var, [flow_bwd, flow_fwd],
                    lambda_oob=args.lambda_oob, qch=args.qch, wssim=args.wssim)
                loss_flow_recon += args.flow_photo_loss_weight_second * loss_per
            else:
                loss_flow_recon += args.flow_photo_loss_weight_second * photometric_flow_gradient_loss(
                    tgt_img_var, ref_imgs_var, [flow_bwd, flow_fwd],
                    lambda_oob=args.lambda_oob, qch=args.qch, wssim=args.wssim)

        if args.smooth_loss_weight_first:
            if args.smoothness_type == "regular":
                loss_smooth += args.smooth_loss_weight_first * (smooth_loss(flow_fwd) + smooth_loss(flow_bwd))
            elif args.smoothness_type == "edgeaware":
                loss_smooth += args.smooth_loss_weight_first * (edge_aware_smoothness_loss(tgt_img_var, flow_fwd)
                                                                + edge_aware_smoothness_loss(tgt_img_var, flow_bwd))

        if args.smooth_loss_weight_second:
            if args.smoothness_type == "regular":
                loss_smooth += args.smooth_loss_weight_second * (smooth_loss(flow_fwd) + smooth_loss(flow_bwd))
            elif args.smoothness_type == "edgeaware":
                loss_smooth += args.smooth_loss_weight_second * (
                    edge_aware_smoothness_second_order_loss_change_weight(tgt_img_var, flow_bwd, args.alpha)
                    + edge_aware_smoothness_second_order_loss_change_weight(tgt_img_var, flow_fwd, args.alpha))

        if args.velocity_consis_loss_weight:
            loss_velocity_consis = args.velocity_consis_loss_weight * flow_velocity_consis_loss([flow_bwd, flow_fwd])

        loss = loss_smooth + loss_flow_recon + loss_velocity_consis

        if i > 0 and n_iter % args.print_freq == 0:
            train_writer.add_scalar('flow_photometric_error', loss_flow_recon.item(), n_iter)
            train_writer.add_scalar('flow_smoothness_loss', loss_smooth.item(), n_iter)
            train_writer.add_scalar('velocity_consis_loss', loss_velocity_consis.item(), n_iter)
            train_writer.add_scalar('total_loss', loss.item(), n_iter)

        if args.training_output_freq > 0 and n_iter % args.training_output_freq == 0:
            train_writer.add_image('train Input', tensor2array(tgt_img[0]), n_iter)
            train_writer.add_image('train Flow FWD Output',
                                   flow_to_image(tensor2array(flow_fwd[0].data[0].cpu())), n_iter)
            train_writer.add_image('train Flow BWD Output',
                                   flow_to_image(tensor2array(flow_bwd[0].data[0].cpu())), n_iter)

            if loss_weight is not None:  # only available when the min-loss branch ran
                loss_weight_bwd = loss_weight[0][0, 0, :, :].unsqueeze(0)
                loss_weight_fwd = loss_weight[0][0, 1, :, :].unsqueeze(0)
                train_writer.add_image('loss_weight_bwd',
                                       tensor2array(loss_weight_bwd.data[0].cpu(), max_value=None, colormap='bone'),
                                       n_iter)
                train_writer.add_image('loss_weight_fwd',
                                       tensor2array(loss_weight_fwd.data[0].cpu(), max_value=None, colormap='bone'),
                                       n_iter)

            train_writer.add_image('train Flow FWD error Image',
                                   tensor2array(flow_warp(tgt_img_var - ref_imgs_var[1], flow_fwd[0]).data[0].cpu()),
                                   n_iter)
            train_writer.add_image('train Flow BWD error Image',
                                   tensor2array(flow_warp(tgt_img_var - ref_imgs_var[0], flow_bwd[0]).data[0].cpu()),
                                   n_iter)

        # record loss and EPE
        losses.update(loss.item(), args.batch_size)

        # compute gradient and do Adam step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if args.log_terminal:
            logger.train_bar.update(i + 1)
            if i % args.print_freq == 0:
                logger.train_writer.write('Train: Time {} Data {} Loss {}'.format(batch_time, data_time, losses))

        if i >= epoch_size - 1:
            break

        n_iter += 1

    return losses.avg[0]
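# AverageMeter appears throughout these training/validation loops with
# precision=..., i=len(error_names) and .avg[0]; a minimal sketch consistent
# with those call sites (assumed, the repo's own utility may differ):
class AverageMeter(object):
    # running average for one or more scalars
    def __init__(self, i=1, precision=3):
        self.meters, self.precision = i, precision
        self.reset()

    def reset(self):
        self.val = [0.] * self.meters
        self.avg = [0.] * self.meters
        self.sum = [0.] * self.meters
        self.count = 0

    def update(self, val, n=1):
        if not isinstance(val, list):
            val = [val]
        self.val = val
        self.count += n
        for j, v in enumerate(val):
            self.sum[j] += v * n
            self.avg[j] = self.sum[j] / self.count

    def __repr__(self):
        fmt = '{:.' + str(self.precision) + 'f}'
        return '{} ({})'.format(' '.join(fmt.format(v) for v in self.val),
                                ' '.join(fmt.format(a) for a in self.avg))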
def one_scale(flows):
    assert (len(flows) == len(ref_imgs))
    # reconstruction_loss = 0
    b, _, h, w = flows[0].size()
    tgt_img_scaled = nn.functional.adaptive_avg_pool2d(tgt_img, (h, w))
    ref_imgs_scaled = [nn.functional.adaptive_avg_pool2d(ref_img, (h, w)) for ref_img in ref_imgs]

    reconstruction_loss_all = []
    reconstruction_weight_all = []
    # consistancy_loss_all = []
    ssim_loss = 0.0

    for i, ref_img in enumerate(ref_imgs_scaled):
        current_flow = flows[i]
        ref_img_warped = flow_warp(ref_img, current_flow)
        diff = (tgt_img_scaled - ref_img_warped)
        if wssim:
            ssim_loss += wssim * (1 - ssim(tgt_img_scaled, ref_img_warped)).mean()
        # reconstruction_loss = gradient_photometric_loss(tgt_img_scaled, ref_img_warped, qch)
        reconstruction_loss = gradient_photometric_all_direction_loss(tgt_img_scaled, ref_img_warped, qch)
        reconstruction_weight = robust_l1_per_pix(diff.mean(1, True), q=qch)
        # reconstruction_weight = reconstruction_loss
        reconstruction_loss_all.append(reconstruction_loss)
        reconstruction_weight_all.append(reconstruction_weight)
        # consistancy_loss_all.append(reconstruction_loss)

    reconstruction_loss = torch.cat(reconstruction_loss_all, 1)
    reconstruction_weight = torch.cat(reconstruction_weight_all, 1)
    # consistancy_loss = torch.cat(consistancy_loss_all, 1)
    # reconstruction_weight_min, _ = reconstruction_weight.min(1, keepdim=True)
    # reconstruction_weight_min = reconstruction_weight_min.repeat(1, 2, 1, 1)
    # reconstruction_weight_sum = reconstruction_weight.sum(1, keepdim=True)
    # reconstruction_weight_sum = reconstruction_weight_sum.repeat(1, 2, 1, 1)
    # consistancy_loss = consistancy_loss[:, 0, :, :] - consistancy_loss[:, 1, :, :]
    # consistancy_loss = wconsis * torch.mean(torch.abs(consistancy_loss))
    # loss_weight = reconstruction_weight_min / (reconstruction_weight)
    # loss_weight = reconstruction_weight / reconstruction_weight_sum
    loss_weight = 1 - torch.nn.functional.softmax(reconstruction_weight, 1)
    # loss_weight = (loss_weight >= 0.4).type_as(reconstruction_loss)
    # print(loss_weight.size())
    # loss_weight = loss_weight[:, :, :-1, :-1]
    loss_weight = loss_weight[:, :, 1:-1, 1:-1]
    # loss_weight = scale_weight(loss_weight, 0.3, 10)
    # loss_weight = torch.pow(loss_weight, 4)
    loss_weight = Variable(loss_weight.data, requires_grad=False)
    loss = reconstruction_loss * loss_weight
    # loss, _ = torch.min(reconstruction_loss, dim=1)
    # loss = torch.mean(loss, 3)
    # loss = torch.mean(loss, 2)
    # loss = torch.mean(loss, 0)
    loss = loss.sum() / loss_weight.sum()
    return loss + ssim_loss, loss_weight
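# gradient_photometric_all_direction_loss is used above but defined
# elsewhere; a minimal sketch, assuming a gradient-difference penalty over
# the four axis-aligned neighbours, whose (h-2, w-2) output matches the
# loss_weight[:, :, 1:-1, 1:-1] crop above. It reuses robust_l1_per_pix and
# is an assumed form, not the repo's exact code:
def gradient_photometric_all_direction_loss(tgt, warped, q=0.5):
    def shifted(x, dy, dx):
        # central (h-2, w-2) window shifted by (dy, dx)
        return x[:, :, 1 + dy:x.size(2) - 1 + dy, 1 + dx:x.size(3) - 1 + dx]

    center_t, center_w = shifted(tgt, 0, 0), shifted(warped, 0, 0)
    loss = 0
    for dy, dx in ((0, 1), (0, -1), (1, 0), (-1, 0)):
        g_t = center_t - shifted(tgt, dy, dx)
        g_w = center_w - shifted(warped, dy, dx)
        loss = loss + robust_l1_per_pix(g_t - g_w, q=q)
    return (loss / 4).mean(1, keepdim=True)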
def flow_loss(tgt_img, ref_imgs, flows, explainability_mask, lambda_oob=0, qch=0.5, wssim=0.5):
    def one_scale(explainability_mask, occ_masks, flows):
        reconstruction_loss = 0
        b, _, h, w = flows[0].size()
        downscale = tgt_img.size(2) / h
        tgt_img_scaled = nn.functional.adaptive_avg_pool2d(tgt_img, (h, w))
        ref_imgs_scaled = [nn.functional.adaptive_avg_pool2d(ref_img, (h, w)) for ref_img in ref_imgs]

        weight = 1.
        for i, ref_img in enumerate(ref_imgs_scaled):
            current_flow = flows[i]
            ref_img_warped = flow_warp(ref_img, current_flow)
            valid_pixels = 1 - (ref_img_warped == 0).prod(1, keepdim=True).type_as(ref_img_warped)
            diff = (tgt_img_scaled - ref_img_warped) * valid_pixels
            ssim_loss = 1 - ssim(tgt_img_scaled, ref_img_warped) * valid_pixels
            oob_normalization_const = valid_pixels.nelement() / valid_pixels.sum()

            if explainability_mask is not None:
                diff = diff * explainability_mask[:, i:i+1].expand_as(diff)
                ssim_loss = ssim_loss * explainability_mask[:, i:i+1].expand_as(ssim_loss)
            if occ_masks is not None:
                diff = diff * (1 - occ_masks[:, i:i+1]).expand_as(diff)
                ssim_loss = ssim_loss * (1 - occ_masks[:, i:i+1]).expand_as(ssim_loss)

            reconstruction_loss += (1 - wssim) * weight * oob_normalization_const \
                * (robust_l1(diff, q=qch) + wssim * ssim_loss.mean()) \
                + lambda_oob * robust_l1(1 - valid_pixels, q=qch)
            # weight /= 2.83
            assert ((reconstruction_loss == reconstruction_loss).item() == 1)
        return reconstruction_loss

    if type(flows[0]) not in [tuple, list]:
        if explainability_mask is not None:
            explainability_mask = [explainability_mask]
        flows = [[uv] for uv in flows]

    loss = 0
    for i in range(len(flows[0])):
        flow_at_scale = [uv[i] for uv in flows]
        occ_mask_at_scale_bw, occ_mask_at_scale_fw = occlusion_masks(flow_at_scale[0], flow_at_scale[1])
        occ_mask_at_scale = torch.stack((occ_mask_at_scale_bw, occ_mask_at_scale_fw), dim=1)
        # occ_mask_at_scale = None
        loss += one_scale(explainability_mask[i], occ_mask_at_scale, flow_at_scale)

    ref_img_warped = flow_warp(ref_imgs[0], flows[0][0])
    diff = (tgt_img - ref_img_warped)
    return loss, ref_img_warped, diff
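# occlusion_masks is called above to down-weight occluded pixels at each
# scale; a minimal sketch based on the usual forward-backward consistency
# test (a pixel is occluded when forward and backward flows disagree beyond
# a magnitude-dependent threshold). Assumed form; the repo's version may
# differ:
def occlusion_masks(flow_bw, flow_fw):
    mag_sq = flow_fw.pow(2).sum(1) + flow_bw.pow(2).sum(1)
    # bring each flow into the other's frame and test round-trip consistency
    flow_bw_warped = flow_warp(flow_bw, flow_fw)
    flow_fw_warped = flow_warp(flow_fw, flow_bw)
    flow_diff_fw = flow_fw + flow_bw_warped
    flow_diff_bw = flow_bw + flow_fw_warped
    occ_thresh = 0.08 * mag_sq + 1.0
    occ_fw = (flow_diff_fw.pow(2).sum(1) > occ_thresh).type_as(flow_fw)
    occ_bw = (flow_diff_bw.pow(2).sum(1) > occ_thresh).type_as(flow_bw)
    return occ_bw, occ_fw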