def depth_occlusion_masks(depth, pose, intrinsics, intrinsics_inv):
    """Derive per-reference-frame occlusion masks from depth and camera pose.

    For every reference frame encoded along dim 1 of ``pose``, the rigid
    (camera-motion) flow is synthesised via ``pose2flow``.  The two inner
    flows (indices 1, 2) and the two outer flows (indices 0, 3) are then
    cross-checked with ``occlusion_masks``, and the four resulting masks are
    stacked along dim 1 in frame order 0..3.

    Assumes ``pose.size(1) == 4`` (four reference frames) — TODO confirm
    against callers.
    """
    cam_flows = []
    for ref_idx in range(pose.size(1)):
        cam_flows.append(
            pose2flow(depth.squeeze(), pose[:, ref_idx], intrinsics,
                      intrinsics_inv))
    # Pair up temporally symmetric flows for the forward/backward check.
    inner_prev, inner_next = occlusion_masks(cam_flows[1], cam_flows[2])
    outer_prev, outer_next = occlusion_masks(cam_flows[0], cam_flows[3])
    return torch.stack((outer_prev, inner_prev, inner_next, outer_next),
                       dim=1)
def main():
    """Evaluate motion-segmentation masks on the ValidationMask KITTI split.

    Loads four pretrained networks (disparity, pose, explainability mask,
    flow), derives three rigidity masks per sample — census-only, network-only
    ("bare") and their combination — and reports per-class IoU for each
    variant.  Optionally dumps intermediate arrays and visualizations to
    ``args.output_dir`` and logs images to tensorboard.
    """
    global args
    args = parser.parse_args()
    args.pretrained_disp = Path(args.pretrained_disp)
    args.pretrained_pose = Path(args.pretrained_pose)
    args.pretrained_mask = Path(args.pretrained_mask)
    args.pretrained_flow = Path(args.pretrained_flow)

    if args.output_dir is not None:
        args.output_dir = Path(args.output_dir)
        args.output_dir.makedirs_p()
        image_dir = args.output_dir / 'images'
        gt_dir = args.output_dir / 'gt'
        mask_dir = args.output_dir / 'mask'
        viz_dir = args.output_dir / 'viz'
        rigidity_mask_dir = args.output_dir / 'rigidity'
        rigidity_census_mask_dir = args.output_dir / 'rigidity_census'
        explainability_mask_dir = args.output_dir / 'explainability'
        image_dir.makedirs_p()
        gt_dir.makedirs_p()
        mask_dir.makedirs_p()
        viz_dir.makedirs_p()
        rigidity_mask_dir.makedirs_p()
        rigidity_census_mask_dir.makedirs_p()
        explainability_mask_dir.makedirs_p()

    output_writer = SummaryWriter(args.output_dir)

    normalize = custom_transforms.Normalize(mean=[0.5, 0.5, 0.5],
                                            std=[0.5, 0.5, 0.5])
    flow_loader_h, flow_loader_w = 256, 832
    valid_flow_transform = custom_transforms.Compose([
        custom_transforms.Scale(h=flow_loader_h, w=flow_loader_w),
        custom_transforms.ArrayToTensor(), normalize
    ])
    val_flow_set = ValidationMask(root=args.kitti_dir,
                                  sequence_length=5,
                                  transform=valid_flow_transform)
    val_loader = torch.utils.data.DataLoader(val_flow_set,
                                             batch_size=1,
                                             shuffle=False,
                                             num_workers=2,
                                             pin_memory=True,
                                             drop_last=True)

    # Network classes are looked up by name on the models module.
    disp_net = getattr(models, args.dispnet)().cuda()
    pose_net = getattr(models, args.posenet)(nb_ref_imgs=4).cuda()
    mask_net = getattr(models, args.masknet)(nb_ref_imgs=4).cuda()
    flow_net = getattr(models, args.flownet)(nlevels=args.nlevels).cuda()

    dispnet_weights = torch.load(args.pretrained_disp)
    posenet_weights = torch.load(args.pretrained_pose)
    masknet_weights = torch.load(args.pretrained_mask)
    flownet_weights = torch.load(args.pretrained_flow)
    disp_net.load_state_dict(dispnet_weights['state_dict'])
    pose_net.load_state_dict(posenet_weights['state_dict'])
    flow_net.load_state_dict(flownet_weights['state_dict'])
    mask_net.load_state_dict(masknet_weights['state_dict'])

    disp_net.eval()
    pose_net.eval()
    mask_net.eval()
    flow_net.eval()

    # Accumulators track TP/FP/FN for class 0 (background) and class 1
    # (foreground) for the three mask variants.
    error_names = ['tp_0', 'fp_0', 'fn_0', 'tp_1', 'fp_1', 'fn_1']
    errors = AverageMeter(i=len(error_names))
    errors_census = AverageMeter(i=len(error_names))
    errors_bare = AverageMeter(i=len(error_names))

    for i, (tgt_img, ref_imgs, intrinsics, intrinsics_inv, flow_gt,
            obj_map_gt, semantic_map_gt) in enumerate(tqdm(val_loader)):
        # Legacy pre-0.4 PyTorch inference API (Variable + volatile).
        tgt_img_var = Variable(tgt_img.cuda(), volatile=True)
        ref_imgs_var = [
            Variable(img.cuda(), volatile=True) for img in ref_imgs
        ]
        intrinsics_var = Variable(intrinsics.cuda(), volatile=True)
        intrinsics_inv_var = Variable(intrinsics_inv.cuda(), volatile=True)
        flow_gt_var = Variable(flow_gt.cuda(), volatile=True)
        obj_map_gt_var = Variable(obj_map_gt.cuda(), volatile=True)

        disp = disp_net(tgt_img_var)
        depth = 1 / disp
        pose = pose_net(tgt_img_var, ref_imgs_var)
        explainability_mask = mask_net(tgt_img_var, ref_imgs_var)

        if args.flownet in ['Back2Future']:
            flow_fwd, flow_bwd, _ = flow_net(tgt_img_var, ref_imgs_var[1:3])
        else:
            flow_fwd = flow_net(tgt_img_var, ref_imgs_var[2])

        # Rigid flow induced purely by camera motion towards the forward
        # reference frame (index 2).
        flow_cam = pose2flow(depth.squeeze(1), pose[:, 2], intrinsics_var,
                             intrinsics_inv_var)

        # "Bare" rigidity mask from the explainability network: pixel is
        # rigid if either of the two middle reference masks says so.
        # NOTE(review): the unsqueeze(1) applies only to the second factor;
        # shapes only line up here because batch_size == 1 — confirm intent.
        rigidity_mask = 1 - (1 - explainability_mask[:, 1]) * (
            1 - explainability_mask[:, 2]).unsqueeze(1) > 0.5

        # Census-style soft rigidity: L2 distance between rigid flow and
        # estimated flow, inverted and normalized by its max.
        rigidity_mask_census_soft = (flow_cam - flow_fwd).pow(2).sum(
            dim=1).unsqueeze(1).sqrt()  # .normalize()
        rigidity_mask_census_soft = 1 - rigidity_mask_census_soft / rigidity_mask_census_soft.max(
        )
        rigidity_mask_census = rigidity_mask_census_soft > args.THRESH

        # Combine: rigid if either the network mask or the census mask is.
        rigidity_mask_combined = 1 - (
            1 - rigidity_mask.type_as(explainability_mask)) * (
                1 - rigidity_mask_census.type_as(explainability_mask))

        # Compose the total flow: estimated flow on moving regions, rigid
        # (camera) flow on static regions.
        flow_fwd_non_rigid = (1 - rigidity_mask_combined).type_as(
            flow_fwd).expand_as(flow_fwd) * flow_fwd
        flow_fwd_rigid = rigidity_mask_combined.type_as(flow_fwd).expand_as(
            flow_fwd) * flow_cam
        total_flow = flow_fwd_rigid + flow_fwd_non_rigid

        obj_map_gt_var_expanded = obj_map_gt_var.unsqueeze(1).type_as(flow_fwd)

        tgt_img_np = tgt_img[0].numpy()
        rigidity_mask_combined_np = rigidity_mask_combined.cpu().data[0].numpy(
        )
        rigidity_mask_census_np = rigidity_mask_census.cpu().data[0].numpy()
        rigidity_mask_bare_np = rigidity_mask.cpu().data[0].numpy()
        gt_mask_np = obj_map_gt[0].numpy()
        semantic_map_np = semantic_map_gt[0].numpy()

        _errors = mask_error(gt_mask_np, semantic_map_np,
                             rigidity_mask_combined_np[0])
        _errors_census = mask_error(gt_mask_np, semantic_map_np,
                                    rigidity_mask_census_np[0])
        _errors_bare = mask_error(gt_mask_np, semantic_map_np,
                                  rigidity_mask_bare_np[0])

        errors.update(_errors)
        errors_census.update(_errors_census)
        errors_bare.update(_errors_bare)

        if args.output_dir is not None:
            np.save(image_dir / str(i).zfill(3), tgt_img_np)
            np.save(gt_dir / str(i).zfill(3), gt_mask_np)
            np.save(mask_dir / str(i).zfill(3), rigidity_mask_combined_np)
            np.save(rigidity_mask_dir / str(i).zfill(3),
                    rigidity_mask.cpu().data[0].numpy())
            np.save(rigidity_census_mask_dir / str(i).zfill(3),
                    rigidity_mask_census.cpu().data[0].numpy())
            np.save(explainability_mask_dir / str(i).zfill(3),
                    explainability_mask[:, 1].cpu().data[0].numpy())

        # Log every 10th sample to tensorboard.
        if (args.output_dir is not None) and i % 10 == 0:
            ind = int(i // 10)
            output_writer.add_image(
                'val Dispnet Output Normalized',
                tensor2array(disp.data[0].cpu(),
                             max_value=None,
                             colormap='bone'), ind)
            output_writer.add_image('val Input',
                                    tensor2array(tgt_img[0].cpu()), i)
            output_writer.add_image(
                'val Total Flow Output',
                flow_to_image(tensor2array(total_flow.data[0].cpu())), ind)
            output_writer.add_image(
                'val Rigid Flow Output',
                flow_to_image(tensor2array(flow_fwd_rigid.data[0].cpu())),
                ind)
            output_writer.add_image(
                'val Non-rigid Flow Output',
                flow_to_image(tensor2array(flow_fwd_non_rigid.data[0].cpu())),
                ind)
            output_writer.add_image(
                'val Rigidity Mask',
                tensor2array(rigidity_mask.data[0].cpu(),
                             max_value=1,
                             colormap='bone'), ind)
            output_writer.add_image(
                'val Rigidity Mask Census',
                tensor2array(rigidity_mask_census.data[0].cpu(),
                             max_value=1,
                             colormap='bone'), ind)
            output_writer.add_image(
                'val Rigidity Mask Combined',
                tensor2array(rigidity_mask_combined.data[0].cpu(),
                             max_value=1,
                             colormap='bone'), ind)

        if args.output_dir is not None:
            tgt_img_viz = tensor2array(tgt_img[0].cpu())
            depth_viz = tensor2array(disp.data[0].cpu(),
                                     max_value=None,
                                     colormap='magma')
            mask_viz = tensor2array(rigidity_mask_census_soft.data[0].cpu(),
                                    max_value=1,
                                    colormap='bone')
            row2_viz = flow_to_image(
                np.hstack((tensor2array(flow_cam.data[0].cpu()),
                           tensor2array(flow_fwd_non_rigid.data[0].cpu()),
                           tensor2array(total_flow.data[0].cpu()))))
            row1_viz = np.hstack((tgt_img_viz, depth_viz, mask_viz))
            # Changed the two vstacks to hstack (single wide strip per image).
            viz3 = np.hstack(
                (255 * tgt_img_viz, 255 * depth_viz, 255 * mask_viz,
                 flow_to_image(
                     np.hstack(
                         (tensor2array(flow_fwd_non_rigid.data[0].cpu()),
                          tensor2array(total_flow.data[0].cpu()))))))
            # Convert CHW -> HWC so PIL can consume the arrays.
            row1_viz = np.transpose(row1_viz, (1, 2, 0))
            row2_viz = np.transpose(row2_viz, (1, 2, 0))
            viz3 = np.transpose(viz3, (1, 2, 0))
            row1_viz_im = Image.fromarray((255 * row1_viz).astype('uint8'))
            row2_viz_im = Image.fromarray((row2_viz).astype('uint8'))
            viz3_im = Image.fromarray(viz3.astype('uint8'))
            row1_viz_im.save(viz_dir / str(i).zfill(3) + '01.png')
            row2_viz_im.save(viz_dir / str(i).zfill(3) + '02.png')
            viz3_im.save(viz_dir / str(i).zfill(3) + '03.png')

    # IoU per class from accumulated TP/FP/FN, for each mask variant.
    bg_iou = errors.sum[0] / (errors.sum[0] + errors.sum[1] + errors.sum[2])
    fg_iou = errors.sum[3] / (errors.sum[3] + errors.sum[4] + errors.sum[5])
    avg_iou = (bg_iou + fg_iou) / 2

    bg_iou_census = errors_census.sum[0] / (
        errors_census.sum[0] + errors_census.sum[1] + errors_census.sum[2])
    fg_iou_census = errors_census.sum[3] / (
        errors_census.sum[3] + errors_census.sum[4] + errors_census.sum[5])
    avg_iou_census = (bg_iou_census + fg_iou_census) / 2

    bg_iou_bare = errors_bare.sum[0] / (
        errors_bare.sum[0] + errors_bare.sum[1] + errors_bare.sum[2])
    fg_iou_bare = errors_bare.sum[3] / (
        errors_bare.sum[3] + errors_bare.sum[4] + errors_bare.sum[5])
    avg_iou_bare = (bg_iou_bare + fg_iou_bare) / 2

    print("Results Full Model")
    print("\t {:>10}, {:>10}, {:>10} ".format('iou', 'bg_iou', 'fg_iou'))
    print("Errors \t {:10.4f}, {:10.4f} {:10.4f}".format(
        avg_iou, bg_iou, fg_iou))

    print("Results Census only")
    print("\t {:>10}, {:>10}, {:>10} ".format('iou', 'bg_iou', 'fg_iou'))
    print("Errors \t {:10.4f}, {:10.4f} {:10.4f}".format(
        avg_iou_census, bg_iou_census, fg_iou_census))

    print("Results Bare")
    print("\t {:>10}, {:>10}, {:>10} ".format('iou', 'bg_iou', 'fg_iou'))
    print("Errors \t {:10.4f}, {:10.4f} {:10.4f}".format(
        avg_iou_bare, bg_iou_bare, fg_iou_bare))
def main():
    """Evaluate optical-flow EPE on KITTI 2015 with the collaboration models.

    Composes a total flow from rigid (camera) flow and estimated flow using a
    combined rigidity mask, computes end-point errors against ground truth
    (both with the predicted mask and with the GT object mask), and logs
    visualizations to tensorboard / ``args.output_dir``.
    """
    global args
    args = parser.parse_args()
    args.pretrained_disp = Path(args.pretrained_disp)
    args.pretrained_pose = Path(args.pretrained_pose)
    args.pretrained_mask = Path(args.pretrained_mask)
    args.pretrained_flow = Path(args.pretrained_flow)

    if args.output_dir is not None:
        args.output_dir = Path(args.output_dir)
        args.output_dir.makedirs_p()
        image_dir = args.output_dir / 'images'
        gt_dir = args.output_dir / 'gt'
        mask_dir = args.output_dir / 'mask'
        viz_dir = args.output_dir / 'viz'
        image_dir.makedirs_p()
        gt_dir.makedirs_p()
        mask_dir.makedirs_p()
        viz_dir.makedirs_p()

    output_writer = SummaryWriter(args.output_dir)

    normalize = custom_transforms.Normalize(mean=[0.5, 0.5, 0.5],
                                            std=[0.5, 0.5, 0.5])
    flow_loader_h, flow_loader_w = 256, 832
    valid_flow_transform = custom_transforms.Compose([
        custom_transforms.Scale(h=flow_loader_h, w=flow_loader_w),
        custom_transforms.ArrayToTensor(), normalize
    ])
    # NOTE(review): val_flow_set is only bound for this dataset value; any
    # other args.dataset raises NameError below — confirm intended datasets.
    if args.dataset == "kitti2015":
        val_flow_set = ValidationFlow(root=args.kitti_dir,
                                      sequence_length=5,
                                      transform=valid_flow_transform)
    val_loader = torch.utils.data.DataLoader(val_flow_set,
                                             batch_size=1,
                                             shuffle=False,
                                             num_workers=2,
                                             pin_memory=True,
                                             drop_last=True)

    disp_net = getattr(models, args.dispnet)().cuda()
    pose_net = getattr(models, args.posenet)(nb_ref_imgs=4).cuda()
    mask_net = getattr(models, args.masknet)(nb_ref_imgs=4).cuda()
    flow_net = getattr(models, args.flownet)(nlevels=args.nlevels).cuda()

    dispnet_weights = torch.load(args.pretrained_disp)
    posenet_weights = torch.load(args.pretrained_pose)
    masknet_weights = torch.load(args.pretrained_mask)
    flownet_weights = torch.load(args.pretrained_flow)
    disp_net.load_state_dict(dispnet_weights['state_dict'])
    pose_net.load_state_dict(posenet_weights['state_dict'])
    flow_net.load_state_dict(flownet_weights['state_dict'])
    mask_net.load_state_dict(masknet_weights['state_dict'])

    disp_net.eval()
    pose_net.eval()
    mask_net.eval()
    flow_net.eval()

    # EPE / Fl metrics, once using the predicted rigidity mask and once
    # using the ground-truth object mask.
    error_names = [
        'epe_total', 'epe_sp', 'epe_mv', 'Fl', 'epe_total_gt_mask',
        'epe_sp_gt_mask', 'epe_mv_gt_mask', 'Fl_gt_mask'
    ]
    errors = AverageMeter(i=len(error_names))

    for i, (tgt_img, ref_imgs, intrinsics, intrinsics_inv, flow_gt,
            obj_map_gt) in enumerate(tqdm(val_loader)):
        # Legacy pre-0.4 PyTorch inference API (Variable + volatile).
        tgt_img_var = Variable(tgt_img.cuda(), volatile=True)
        ref_imgs_var = [
            Variable(img.cuda(), volatile=True) for img in ref_imgs
        ]
        intrinsics_var = Variable(intrinsics.cuda(), volatile=True)
        intrinsics_inv_var = Variable(intrinsics_inv.cuda(), volatile=True)
        flow_gt_var = Variable(flow_gt.cuda(), volatile=True)
        obj_map_gt_var = Variable(obj_map_gt.cuda(), volatile=True)

        disp = disp_net(tgt_img_var)
        depth = 1 / disp
        pose = pose_net(tgt_img_var, ref_imgs_var)
        explainability_mask = mask_net(tgt_img_var, ref_imgs_var)

        if args.flownet == 'Back2Future':
            flow_fwd, flow_bwd, _ = flow_net(tgt_img_var, ref_imgs_var[1:3])
        else:
            flow_fwd = flow_net(tgt_img_var, ref_imgs_var[2])

        # Rigid flow towards the forward (index 2) and backward (index 1)
        # reference frames, from depth + pose.
        flow_cam = pose2flow(depth.squeeze(1), pose[:, 2], intrinsics_var,
                             intrinsics_inv_var)
        flow_cam_bwd = pose2flow(depth.squeeze(1), pose[:, 1],
                                 intrinsics_var, intrinsics_inv_var)

        # Network ("bare") rigidity mask from the explainability outputs.
        # NOTE(review): the unsqueeze(1) applies only to the second factor;
        # shapes only line up here because batch_size == 1 — confirm intent.
        rigidity_mask = 1 - (1 - explainability_mask[:, 1]) * (
            1 - explainability_mask[:, 2]).unsqueeze(1) > 0.5

        # Census rigidity: per-component |rigid - estimated| flow residual,
        # thresholded independently in u and v then ANDed.
        rigidity_mask_census_soft = (flow_cam - flow_fwd).abs()  # .normalize()
        rigidity_mask_census_u = rigidity_mask_census_soft[:, 0] < args.THRESH
        rigidity_mask_census_v = rigidity_mask_census_soft[:, 1] < args.THRESH
        rigidity_mask_census = (rigidity_mask_census_u).type_as(flow_fwd) * (
            rigidity_mask_census_v).type_as(flow_fwd)

        rigidity_mask_combined = 1 - (
            1 - rigidity_mask.type_as(explainability_mask)) * (
                1 - rigidity_mask_census.type_as(explainability_mask))

        obj_map_gt_var_expanded = obj_map_gt_var.unsqueeze(1).type_as(
            flow_fwd)

        # Total flow: estimated flow where non-rigid, camera flow elsewhere.
        flow_fwd_non_rigid = (rigidity_mask_combined <= args.THRESH).type_as(
            flow_fwd).expand_as(flow_fwd) * flow_fwd
        flow_fwd_rigid = (rigidity_mask_combined > args.THRESH).type_as(
            flow_cam).expand_as(flow_cam) * flow_cam
        total_flow = flow_fwd_rigid + flow_fwd_non_rigid

        rigidity_mask = rigidity_mask.type_as(flow_fwd)

        # Concatenate errors under the predicted mask and under the GT mask.
        _epe_errors = compute_all_epes(
            flow_gt_var, flow_cam, flow_fwd,
            rigidity_mask_combined) + compute_all_epes(
                flow_gt_var, flow_cam, flow_fwd,
                (1 - obj_map_gt_var_expanded))
        errors.update(_epe_errors)

        tgt_img_np = tgt_img[0].numpy()
        rigidity_mask_combined_np = rigidity_mask_combined.cpu().data[0].numpy(
        )
        gt_mask_np = obj_map_gt[0].numpy()

        if args.output_dir is not None:
            np.save(image_dir / str(i).zfill(3), tgt_img_np)
            np.save(gt_dir / str(i).zfill(3), gt_mask_np)
            np.save(mask_dir / str(i).zfill(3), rigidity_mask_combined_np)

        # Log every 10th sample to tensorboard and dump row visualizations.
        if (args.output_dir is not None) and i % 10 == 0:
            ind = int(i // 10)
            output_writer.add_image(
                'val Dispnet Output Normalized',
                tensor2array(disp.data[0].cpu(),
                             max_value=None,
                             colormap='bone'), ind)
            output_writer.add_image('val Input',
                                    tensor2array(tgt_img[0].cpu()), i)
            output_writer.add_image(
                'val Total Flow Output',
                flow_to_image(tensor2array(total_flow.data[0].cpu())), ind)
            output_writer.add_image(
                'val Rigid Flow Output',
                flow_to_image(tensor2array(flow_fwd_rigid.data[0].cpu())),
                ind)
            output_writer.add_image(
                'val Non-rigid Flow Output',
                flow_to_image(tensor2array(flow_fwd_non_rigid.data[0].cpu())),
                ind)
            output_writer.add_image(
                'val Rigidity Mask',
                tensor2array(rigidity_mask.data[0].cpu(),
                             max_value=1,
                             colormap='bone'), ind)
            output_writer.add_image(
                'val Rigidity Mask Census',
                tensor2array(rigidity_mask_census.data[0].cpu(),
                             max_value=1,
                             colormap='bone'), ind)
            output_writer.add_image(
                'val Rigidity Mask Combined',
                tensor2array(rigidity_mask_combined.data[0].cpu(),
                             max_value=1,
                             colormap='bone'), ind)

            tgt_img_viz = tensor2array(tgt_img[0].cpu())
            depth_viz = tensor2array(disp.data[0].cpu(),
                                     max_value=None,
                                     colormap='bone')
            mask_viz = tensor2array(
                rigidity_mask_census_soft.data[0].prod(dim=0).cpu(),
                max_value=1,
                colormap='bone')
            rigid_flow_viz = flow_to_image(tensor2array(
                flow_cam.data[0].cpu()))
            non_rigid_flow_viz = flow_to_image(
                tensor2array(flow_fwd_non_rigid.data[0].cpu()))
            total_flow_viz = flow_to_image(
                tensor2array(total_flow.data[0].cpu()))
            row1_viz = np.hstack((tgt_img_viz, depth_viz, mask_viz))
            row2_viz = np.hstack(
                (rigid_flow_viz, non_rigid_flow_viz, total_flow_viz))

            row1_viz_im = Image.fromarray((255 * row1_viz).astype('uint8'))
            row2_viz_im = Image.fromarray((row2_viz).astype('uint8'))
            row1_viz_im.save(viz_dir / str(i).zfill(3) + '01.png')
            row2_viz_im.save(viz_dir / str(i).zfill(3) + '02.png')

    print("Results")
    print("\t {:>10}, {:>10}, {:>10}, {:>6}, {:>10}, {:>10}, {:>10}, {:>10} ".
          format(*error_names))
    print(
        "Errors \t {:10.4f}, {:10.4f}, {:10.4f}, {:10.4f}, {:10.4f}, {:10.4f}, {:10.4f}, {:10.4f}"
        .format(*errors.avg))
def train(train_loader, mask_net, pose_net, flow_net, optimizer, epoch_size,
          train_writer):
    """Train mask/pose/flow networks for one epoch with ground-truth depth.

    The loader yields RGB target/reference images plus depth maps and a pose
    list; depth is used to synthesise rigid flow (``pose2flow``) whose
    residual against the estimated flow drives the consensus loss.  Total
    loss = w1*smoothness + w2*explainability + w3*consensus + w4*flow-photo.

    Returns the mean total loss over the epoch.  Note: ``epoch_size`` is
    currently unused; iteration count comes from the loader itself.
    """
    global args, n_iter

    w1 = args.smooth_loss_weight
    w2 = args.mask_loss_weight
    w3 = args.consensus_loss_weight
    w4 = args.flow_loss_weight

    mask_net.train()
    pose_net.train()
    flow_net.train()

    average_loss = 0
    for i, (rgb_tgt_img, rgb_ref_imgs, depth_tgt_img, depth_ref_imgs,
            intrinsics, intrinsics_inv,
            pose_list) in enumerate(tqdm(train_loader)):
        rgb_tgt_img_var = Variable(rgb_tgt_img.cuda())
        rgb_ref_imgs_var = [Variable(img.cuda()) for img in rgb_ref_imgs]
        # Depth arrives as [B,H,W]; add a channel dim for the networks.
        depth_tgt_img_var = Variable(depth_tgt_img.unsqueeze(1).cuda())
        depth_ref_imgs_var = [
            Variable(img.unsqueeze(1).cuda()) for img in depth_ref_imgs
        ]
        intrinsics_var = Variable(intrinsics.cuda())
        intrinsics_inv_var = Variable(intrinsics_inv.cuda())

        explainability_mask = mask_net(rgb_tgt_img_var, rgb_ref_imgs_var)

        # Pixels with zero depth are invalid and must not contribute.
        valid_pixle_mask = torch.where(
            depth_tgt_img_var == 0, torch.zeros_like(depth_tgt_img_var),
            torch.ones_like(depth_tgt_img_var))  # zero is invalid

        pose = pose_net(rgb_tgt_img_var, rgb_ref_imgs_var)

        # Estimated flow plus rigid flow generated from camera pose + depth.
        flow_fwd, flow_bwd, _ = flow_net(rgb_tgt_img_var, rgb_ref_imgs_var)
        flows_cam_fwd = pose2flow(depth_ref_imgs_var[1].squeeze(1),
                                  pose[:, 1], intrinsics_var,
                                  intrinsics_inv_var)
        flows_cam_bwd = pose2flow(depth_ref_imgs_var[0].squeeze(1),
                                  pose[:, 0], intrinsics_var,
                                  intrinsics_inv_var)
        # Residual between rigid and estimated flow: large where the scene
        # moves independently of the camera.
        rigidity_mask_fwd = (flows_cam_fwd - flow_fwd[0]).abs()
        rigidity_mask_bwd = (flows_cam_bwd - flow_bwd[0]).abs()

        # loss 1: smoothness of masks and flows.
        loss1 = smooth_loss(explainability_mask) + smooth_loss(
            flow_bwd) + smooth_loss(flow_fwd)

        # loss 2: explainability regularizer (keeps masks from collapsing).
        loss2 = explainability_loss(explainability_mask)

        # loss 3: consensus between network mask and residual-derived mask.
        depth_Res_mask, depth_ref_img_warped, depth_diff = depth_residual_mask(
            valid_pixle_mask, explainability_mask[0], rgb_tgt_img_var,
            rgb_ref_imgs_var, intrinsics_var, intrinsics_inv_var,
            depth_tgt_img_var, pose)
        loss3 = consensus_loss(explainability_mask[0], rigidity_mask_bwd,
                               rigidity_mask_fwd, args.THRESH, args.wbce)

        # loss 4: photometric flow loss.
        loss4, flow_ref_img_warped, flow_diff = flow_loss(
            rgb_tgt_img_var, rgb_ref_imgs_var, [flow_bwd, flow_fwd],
            explainability_mask)

        # Weighted total; single Adam step per batch.
        loss = w1 * loss1 + w2 * loss2 + w3 * loss3 + w4 * loss4
        average_loss += loss.item()

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Scalar logging at print_freq granularity.
        if i > 0 and n_iter % args.print_freq == 0:
            train_writer.add_scalar('smoothness loss', loss1.item(), n_iter)
            train_writer.add_scalar('explainability loss', loss2.item(),
                                    n_iter)
            train_writer.add_scalar('consensus loss', loss3.item(), n_iter)
            train_writer.add_scalar('flow loss', loss4.item(), n_iter)
            train_writer.add_scalar('total loss', loss.item(), n_iter)

        # Image logging at training_output_freq granularity.
        if n_iter % (args.training_output_freq) == 0:
            train_writer.add_image('train Input',
                                   tensor2array(rgb_tgt_img_var[0]), n_iter)
            train_writer.add_image(
                'train Exp mask Outputs ',
                tensor2array(explainability_mask[0][0, 0].data.cpu(),
                             max_value=1,
                             colormap='bone'), n_iter)
            train_writer.add_image(
                'train depth Res mask ',
                tensor2array(depth_Res_mask[0][0].data.cpu(),
                             max_value=1,
                             colormap='bone'), n_iter)
            train_writer.add_image(
                'train depth ',
                tensor2array(depth_tgt_img_var[0].data.cpu(),
                             max_value=1,
                             colormap='bone'), n_iter)
            train_writer.add_image(
                'train valid pixel ',
                tensor2array(valid_pixle_mask[0].data.cpu(),
                             max_value=1,
                             colormap='bone'), n_iter)
            train_writer.add_image(
                'train after mask',
                tensor2array(rgb_tgt_img_var[0] *
                             explainability_mask[0][0, 0]), n_iter)
            train_writer.add_image('train depth diff',
                                   tensor2array(depth_diff[0]), n_iter)
            train_writer.add_image('train flow diff',
                                   tensor2array(flow_diff[0]), n_iter)
            train_writer.add_image('train depth warped img',
                                   tensor2array(depth_ref_img_warped[0]),
                                   n_iter)
            train_writer.add_image('train flow warped img',
                                   tensor2array(flow_ref_img_warped[0]),
                                   n_iter)
            train_writer.add_image(
                'train Cam Flow Output',
                flow_to_image(tensor2array(flow_fwd[0].data[0].cpu())),
                n_iter)
            train_writer.add_image(
                'train Flow from Depth Output',
                flow_to_image(tensor2array(flows_cam_fwd.data[0].cpu())),
                n_iter)
            train_writer.add_image(
                'train Flow and Depth diff',
                flow_to_image(tensor2array(rigidity_mask_fwd.data[0].cpu())),
                n_iter)
        n_iter += 1

    # NOTE(review): divides by the last loop index i, not the batch count
    # (i + 1) — off-by-one in the epoch average; confirm intent.
    return average_loss / i
def test(val_loader, disp_net, mask_net, pose_net, flow_net, tb_writer,
         global_vars_dict=None):
    """Run a visualization-only validation pass over ``val_loader``.

    All four networks are put in eval mode; for each sample the disparity,
    rigid flow, estimated flow and explainability mask are rendered and
    written to tensorboard.  Returns (disp_list, disp_arr, flow_list,
    mask_list) with the rendered arrays.
    """
    # Unpack shared state supplied by the caller.
    device = global_vars_dict['device']
    n_iter_val = global_vars_dict['n_iter_val']
    args = global_vars_dict['args']

    data_time = AverageMeter()

    # Switch all networks to eval mode.
    disp_net.eval()
    pose_net.eval()
    mask_net.eval()
    flow_net.eval()

    end = time.time()
    # Pre-allocated pose buffer (currently unused below).
    poses = np.zeros(
        ((len(val_loader) - 1) * 1 * (args.sequence_length - 1), 6))  # init

    disp_list = []
    flow_list = []
    mask_list = []

    # Validation cycle.
    for i, (tgt_img, ref_imgs, intrinsics,
            intrinsics_inv) in tqdm(enumerate(val_loader)):
        data_time.update(time.time() - end)
        tgt_img = tgt_img.to(device)
        ref_imgs = [img.to(device) for img in ref_imgs]
        intrinsics, intrinsics_inv = intrinsics.to(device), intrinsics_inv.to(
            device)

        # Forward pass: disparity (optionally spatially normalized) -> depth.
        disp = disp_net(tgt_img)
        if args.spatial_normalize:
            disp = spatial_normalize(disp)
        depth = 1 / disp

        # Pose for all reference frames.
        pose = pose_net(tgt_img, ref_imgs)

        # Flow between neighbouring frames; Back2Future produces both
        # directions at once, FlowNetC6 needs two calls.
        if args.flownet == 'Back2Future':
            flow_fwd, flow_bwd, _ = flow_net(tgt_img, ref_imgs[1:3])
        elif args.flownet == 'FlowNetC6':
            flow_fwd = flow_net(tgt_img, ref_imgs[2])
            flow_bwd = flow_net(tgt_img, ref_imgs[1])

        # Background (camera-motion) flow from depth + pose; tensors are
        # [b, 2, h, w].
        flow_cam = pose2flow(depth.squeeze(1), pose[:, 2], intrinsics,
                             intrinsics_inv)
        flows_cam_fwd = pose2flow(depth.squeeze(1), pose[:, 2], intrinsics,
                                  intrinsics_inv)
        flows_cam_bwd = pose2flow(depth.squeeze(1), pose[:, 1], intrinsics,
                                  intrinsics_inv)

        # Residual between rigid and estimated flow, [b,2,h,w].
        rigidity_mask_fwd = (flows_cam_fwd - flow_fwd).abs()
        rigidity_mask_bwd = (flows_cam_bwd - flow_bwd).abs()

        # Explainability mask over the reference frames.
        explainability_mask = mask_net(tgt_img, ref_imgs)

        end = time.time()

        # --- logging: disparity ---
        disp_to_show = tensor2array(disp[0].cpu(),
                                    max_value=None,
                                    colormap='bone')
        tb_writer.add_image('Disp/disp0', disp_to_show, i)
        disp_list.append(disp_to_show)
        # Accumulate all disparity renders into one [N, ...] array.
        if i == 0:
            disp_arr = np.expand_dims(disp_to_show, axis=0)
        else:
            disp_to_show = np.expand_dims(disp_to_show, axis=0)
            disp_arr = np.concatenate([disp_arr, disp_to_show], 0)

        # --- logging: flows ---
        tb_writer.add_image('Flow/Flow Output',
                            flow2rgb(flow_fwd[0], max_value=6), i)
        tb_writer.add_image('Flow/cam_Flow Output',
                            flow2rgb(flow_cam[0], max_value=6), i)
        tb_writer.add_image('Flow/rigid_Flow Output',
                            flow2rgb(rigidity_mask_fwd[0], max_value=6), i)
        tb_writer.add_image('Flow/rigidity_mask_fwd',
                            flow2rgb(rigidity_mask_fwd[0], max_value=6), i)
        flow_list.append(flow2rgb(flow_fwd[0], max_value=6))

        # --- logging: explainability mask (scale 0 only) ---
        tb_writer.add_image(
            'Mask /mask0',
            tensor2array(explainability_mask[0][0],
                         max_value=None,
                         colormap='magma'), i)
        mask_list.append(
            tensor2array(explainability_mask[0][0],
                         max_value=None,
                         colormap='magma'))

    return disp_list, disp_arr, flow_list, mask_list
def train(train_loader, disp_net, pose_net, mask_net, flow_net, optimizer,
          logger=None, train_writer=None, global_vars_dict=None):
    """Jointly train disp/pose/mask/flow networks for one epoch.

    Five weighted losses: E_R camera photometric (w1), E_M explainability
    (w2), E_S smoothness (w3), E_F flow photometric (w4) and E_C consensus
    (w5).  Scalars/images go to ``train_writer``; per-batch losses are also
    appended to the CSV at ``args.save_path / args.log_full``.  Returns the
    epoch-average total loss.
    """
    # 0. unpack shared state.
    args = global_vars_dict['args']
    n_iter = global_vars_dict['n_iter']
    device = global_vars_dict['device']

    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter(precision=4)

    w1, w2, w3, w4 = args.cam_photo_loss_weight, args.mask_loss_weight, args.smooth_loss_weight, args.flow_photo_loss_weight
    w5 = args.consensus_loss_weight

    # Robust variants swap in outlier-tolerant photometric losses.
    if args.robust:
        loss_camera = photometric_reconstruction_loss_robust
        loss_flow = photometric_flow_loss_robust
    else:
        loss_camera = photometric_reconstruction_loss
        loss_flow = photometric_flow_loss

    # 2. switch to train mode.
    disp_net.train()
    pose_net.train()
    mask_net.train()
    flow_net.train()

    end = time.time()

    # 3. train cycle.
    for i, (tgt_img, ref_imgs, intrinsics,
            intrinsics_inv) in enumerate(train_loader):
        # measure data loading time
        data_time.update(time.time() - end)
        tgt_img = tgt_img.to(device)
        ref_imgs = [img.to(device) for img in ref_imgs]
        intrinsics = intrinsics.to(device)
        intrinsics_inv = intrinsics_inv.to(device)

        # 3.1 forward passes -------------------------------------------
        # 1. multi-scale disparities -> depths.
        disparities = disp_net(tgt_img)
        if args.spatial_normalize:
            disparities = [spatial_normalize(disp) for disp in disparities]
        depth = [1 / disp for disp in disparities]

        # 2. pose; presumably [B, num_refs, 6] — TODO confirm.
        pose = pose_net(tgt_img, ref_imgs)

        # 3. full flow fwd/bwd. Back2Future uses the three neighbouring
        # frames in one call; FlowNetC6 needs one call per direction.
        if args.flownet == 'Back2Future':
            flow_fwd, flow_bwd, _ = flow_net(tgt_img, ref_imgs[1:3])
        elif args.flownet == 'FlowNetC6':
            flow_fwd = flow_net(tgt_img, ref_imgs[2])
            flow_bwd = flow_net(tgt_img, ref_imgs[1])
        elif args.flownet == 'FlowNetS':
            print(' ')

        # flow_cam is the background (camera-only) flow; flow - flow_s
        # leaves the object flow.  pose[:, 2] is the forward frame.
        flow_cam = pose2flow(depth[0].squeeze(), pose[:, 2], intrinsics,
                             intrinsics_inv)
        flows_cam_fwd = [
            pose2flow(depth_.squeeze(1), pose[:, 2], intrinsics,
                      intrinsics_inv) for depth_ in depth
        ]
        flows_cam_bwd = [
            pose2flow(depth_.squeeze(1), pose[:, 1], intrinsics,
                      intrinsics_inv) for depth_ in depth
        ]

        exp_masks_target = consensus_exp_masks(flows_cam_fwd, flows_cam_bwd,
                                               flow_fwd, flow_bwd, tgt_img,
                                               ref_imgs[2], ref_imgs[1],
                                               wssim=args.wssim,
                                               wrig=args.wrig,
                                               ws=args.smooth_loss_weight)

        # Per-scale residual between rigid and estimated flow.
        rigidity_mask_fwd = [
            (flows_cam_fwd_i - flow_fwd_i).abs()
            for flows_cam_fwd_i, flow_fwd_i in zip(flows_cam_fwd, flow_fwd)
        ]  # .normalize()
        rigidity_mask_bwd = [
            (flows_cam_bwd_i - flow_bwd_i).abs()
            for flows_cam_bwd_i, flow_bwd_i in zip(flows_cam_bwd, flow_bwd)
        ]  # .normalize()

        # 4. explainability mask (multi-scale list).
        explainability_mask = mask_net(tgt_img, ref_imgs)

        # ---------------------------------------------------------------
        if args.joint_mask_for_depth:
            explainability_mask_for_depth = compute_joint_mask_for_depth(
                explainability_mask, rigidity_mask_bwd, rigidity_mask_fwd,
                args.THRESH)
        else:
            explainability_mask_for_depth = explainability_mask

        if args.no_non_rigid_mask:
            flow_exp_mask = [None for exp_mask in explainability_mask]
            if args.DEBUG:
                print('Using no masks for flow')
        else:
            # The explainability mask marks background (static) pixels with
            # 1; invert it to mask moving objects, keeping only the two
            # neighbouring frames (channels 1:3).
            flow_exp_mask = [
                1 - exp_mask[:, 1:3] for exp_mask in explainability_mask
            ]

        # 3.2 losses ----------------------------------------------------
        # E_R: photometric loss on the static scene.
        if w1 > 0:
            loss_1 = loss_camera(tgt_img, ref_imgs, intrinsics,
                                 intrinsics_inv, depth,
                                 explainability_mask_for_depth, pose,
                                 lambda_oob=args.lambda_oob, qch=args.qch,
                                 wssim=args.wssim)
        else:
            loss_1 = torch.tensor([0.]).to(device)
        # E_M: explainability regularizer.
        if w2 > 0:
            loss_2 = explainability_loss(explainability_mask)
        else:
            loss_2 = 0
        # E_S: smoothness (regular or edge-aware).
        if w3 > 0:
            if args.smoothness_type == "regular":
                loss_3 = smooth_loss(depth) + smooth_loss(
                    flow_fwd) + smooth_loss(flow_bwd) + smooth_loss(
                        explainability_mask)
            elif args.smoothness_type == "edgeaware":
                loss_3 = edge_aware_smoothness_loss(
                    tgt_img, depth) + edge_aware_smoothness_loss(
                        tgt_img, flow_fwd)
                loss_3 += edge_aware_smoothness_loss(
                    tgt_img, flow_bwd) + edge_aware_smoothness_loss(
                        tgt_img, explainability_mask)
        else:
            loss_3 = torch.tensor([0.]).to(device)
        # E_F: photometric loss on moving regions.
        if w4 > 0:
            loss_4 = loss_flow(tgt_img, ref_imgs[1:3], [flow_bwd, flow_fwd],
                               flow_exp_mask, lambda_oob=args.lambda_oob,
                               qch=args.qch, wssim=args.wssim)
        else:
            loss_4 = torch.tensor([0.]).to(device)
        # E_C: consensus — drives the collaboration between depth and flow.
        if w5 > 0:
            loss_5 = consensus_depth_flow_mask(explainability_mask,
                                               rigidity_mask_bwd,
                                               rigidity_mask_fwd,
                                               exp_masks_target,
                                               exp_masks_target,
                                               THRESH=args.THRESH,
                                               wbce=args.wbce)
        else:
            loss_5 = torch.tensor([0.]).to(device)

        # 3.2.6 weighted total.
        loss = w1 * loss_1 + w2 * loss_2 + w3 * loss_3 + w4 * loss_4 + w5 * loss_5

        # 3.3 record loss, then one Adam step.
        losses.update(loss.item(), args.batch_size)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        # 3.4 logging ---------------------------------------------------
        # Scalars.
        if args.scalar_freq > 0 and n_iter % args.scalar_freq == 0:
            train_writer.add_scalar('batch/cam_photometric_error',
                                    loss_1.item(), n_iter)
            if w2 > 0:
                train_writer.add_scalar('batch/explanability_loss',
                                        loss_2.item(), n_iter)
            train_writer.add_scalar('batch/disparity_smoothness_loss',
                                    loss_3.item(), n_iter)
            train_writer.add_scalar('batch/flow_photometric_error',
                                    loss_4.item(), n_iter)
            train_writer.add_scalar('batch/consensus_error', loss_5.item(),
                                    n_iter)
            train_writer.add_scalar('batch/total_loss', loss.item(), n_iter)

        # Images (freq 0 disables output).
        if args.training_output_freq > 0 and n_iter % args.training_output_freq == 0:
            train_writer.add_image('train Input', tensor2array(tgt_img[0]),
                                   n_iter)
            train_writer.add_image(
                'train Cam Flow Output',
                flow_to_image(tensor2array(flow_cam.data[0].cpu())), n_iter)

            for k, scaled_depth in enumerate(depth):
                train_writer.add_image(
                    'train Dispnet Output Normalized111 {}'.format(k),
                    tensor2array(disparities[k].data[0].cpu(),
                                 max_value=None,
                                 colormap='bone'), n_iter)
                train_writer.add_image(
                    'train Depth Output {}'.format(k),
                    tensor2array(1 / disparities[k].data[0].cpu(),
                                 max_value=10), n_iter)
                train_writer.add_image(
                    'train Non Rigid Flow Output {}'.format(k),
                    flow_to_image(tensor2array(flow_fwd[k].data[0].cpu())),
                    n_iter)
                train_writer.add_image(
                    'train Target Rigidity {}'.format(k),
                    tensor2array((rigidity_mask_fwd[k] >
                                  args.THRESH).type_as(
                                      rigidity_mask_fwd[k]).data[0].cpu(),
                                 max_value=1,
                                 colormap='bone'), n_iter)

                # Downscale inputs and intrinsics to this scale for warping.
                b, _, h, w = scaled_depth.size()
                downscale = tgt_img.size(2) / h
                tgt_img_scaled = nn.functional.adaptive_avg_pool2d(
                    tgt_img, (h, w))
                ref_imgs_scaled = [
                    nn.functional.adaptive_avg_pool2d(ref_img, (h, w))
                    for ref_img in ref_imgs
                ]
                intrinsics_scaled = torch.cat(
                    (intrinsics[:, 0:2] / downscale, intrinsics[:, 2:]),
                    dim=1)
                intrinsics_scaled_inv = torch.cat(
                    (intrinsics_inv[:, :, 0:2] * downscale,
                     intrinsics_inv[:, :, 2:]), dim=2)

                train_writer.add_image(
                    'train Non Rigid Warped Image {}'.format(k),
                    tensor2array(
                        flow_warp(ref_imgs_scaled[2],
                                  flow_fwd[k]).data[0].cpu()), n_iter)

                # Warped images along with the explainability mask.
                for j, ref in enumerate(ref_imgs_scaled):
                    ref_warped = inverse_warp(
                        ref, scaled_depth[:, 0], pose[:, j],
                        intrinsics_scaled, intrinsics_scaled_inv,
                        rotation_mode=args.rotation_mode,
                        padding_mode=args.padding_mode)[0]
                    train_writer.add_image(
                        'train Warped Outputs {} {}'.format(k, j),
                        tensor2array(ref_warped.data.cpu()), n_iter)
                    train_writer.add_image(
                        'train Diff Outputs {} {}'.format(k, j),
                        tensor2array(
                            0.5 *
                            (tgt_img_scaled[0] - ref_warped).abs().data.cpu()),
                        n_iter)
                    if explainability_mask[k] is not None:
                        train_writer.add_image(
                            'train Exp mask Outputs {} {}'.format(k, j),
                            tensor2array(
                                explainability_mask[k][0, j].data.cpu(),
                                max_value=1,
                                colormap='bone'), n_iter)

        # CSV log of the per-batch losses.
        with open(args.save_path / args.log_full, 'a') as csvfile:
            writer = csv.writer(csvfile, delimiter='\t')
            writer.writerow([
                loss.item(),
                loss_1.item(),
                loss_2.item() if w2 > 0 else 0,
                loss_3.item(),
                loss_4.item()
            ])

        # Terminal progress output.
        if args.log_terminal:
            logger.train_bar.update(i + 1)  # progress within this epoch
            if i % args.print_freq == 0:
                logger.valid_bar_writer.write(
                    'Train: Time {} Data {} Loss {}'.format(
                        batch_time, data_time, losses))

        # 3.4 edge condition: stop at the end of the epoch.
        epoch_size = len(train_loader)
        if i >= epoch_size - 1:
            break

        n_iter += 1
    global_vars_dict['n_iter'] = n_iter

    return losses.avg[0]  # epoch loss
def main():
    """Run the four pretrained sub-networks (disparity, pose, explainability
    mask, flow) over the KITTI 2015 test split and write out the fused
    rigid/non-rigid flow predictions.

    Side effects: loads checkpoints from ``args.pretrained_path``, and when
    ``args.output_dir`` is set, saves per-sample images, masks, PNG/.flo flow
    files and visualization rows under that directory.

    NOTE(review): another ``main`` is defined earlier in this file; this later
    definition shadows it at import time.
    """
    global args
    args = parser.parse_args()
    args.pretrained_path = Path(args.pretrained_path)

    # Prepare the output directory tree up front so every later save succeeds.
    if args.output_dir is not None:
        args.output_dir = Path(args.output_dir)
        args.output_dir.makedirs_p()
        image_dir = args.output_dir / 'images'
        mask_dir = args.output_dir / 'mask'
        viz_dir = args.output_dir / 'viz'
        testing_dir = args.output_dir / 'testing'
        testing_dir_flo = args.output_dir / 'testing_flo'
        image_dir.makedirs_p()
        mask_dir.makedirs_p()
        viz_dir.makedirs_p()
        testing_dir.makedirs_p()
        testing_dir_flo.makedirs_p()

    # Same normalization / resize as used elsewhere in this file (256x832).
    normalize = custom_transforms.Normalize(mean=[0.5, 0.5, 0.5],
                                            std=[0.5, 0.5, 0.5])
    flow_loader_h, flow_loader_w = 256, 832
    valid_flow_transform = custom_transforms.Compose([
        custom_transforms.Scale(h=flow_loader_h, w=flow_loader_w),
        custom_transforms.ArrayToTensor(), normalize
    ])
    val_flow_set = KITTI2015Test(root=args.kitti_dir,
                                 sequence_length=5,
                                 transform=valid_flow_transform)

    if args.DEBUG:
        # Debug mode swaps the test split for the training split.
        print("DEBUG MODE: Using Training Set")
        val_flow_set = KITTI2015Test(root=args.kitti_dir,
                                     sequence_length=5,
                                     transform=valid_flow_transform,
                                     phase='training')

    val_loader = torch.utils.data.DataLoader(val_flow_set,
                                             batch_size=1,
                                             shuffle=False,
                                             num_workers=2,
                                             pin_memory=True,
                                             drop_last=True)

    # Instantiate the four networks by name from the models package.
    disp_net = getattr(models, args.dispnet)().cuda()
    pose_net = getattr(models, args.posenet)(nb_ref_imgs=4).cuda()
    mask_net = getattr(models, args.masknet)(nb_ref_imgs=4).cuda()
    flow_net = getattr(models, args.flownet)(nlevels=args.nlevels).cuda()

    # Load the best checkpoint of each sub-network from the shared directory.
    dispnet_weights = torch.load(args.pretrained_path /
                                 'dispnet_model_best.pth.tar')
    posenet_weights = torch.load(args.pretrained_path /
                                 'posenet_model_best.pth.tar')
    masknet_weights = torch.load(args.pretrained_path /
                                 'masknet_model_best.pth.tar')
    flownet_weights = torch.load(args.pretrained_path /
                                 'flownet_model_best.pth.tar')
    disp_net.load_state_dict(dispnet_weights['state_dict'])
    pose_net.load_state_dict(posenet_weights['state_dict'])
    flow_net.load_state_dict(flownet_weights['state_dict'])
    mask_net.load_state_dict(masknet_weights['state_dict'])

    disp_net.eval()
    pose_net.eval()
    mask_net.eval()
    flow_net.eval()

    for i, (tgt_img, ref_imgs, intrinsics, intrinsics_inv,
            tgt_img_original) in enumerate(tqdm(val_loader)):
        # NOTE(review): Variable(..., volatile=True) is the legacy (pre-0.4)
        # inference idiom; modern PyTorch would use torch.no_grad().
        tgt_img_var = Variable(tgt_img.cuda(), volatile=True)
        ref_imgs_var = [
            Variable(img.cuda(), volatile=True) for img in ref_imgs
        ]
        intrinsics_var = Variable(intrinsics.cuda(), volatile=True)
        intrinsics_inv_var = Variable(intrinsics_inv.cuda(), volatile=True)

        # Forward pass through all four networks.
        disp = disp_net(tgt_img_var)
        depth = 1 / disp
        pose = pose_net(tgt_img_var, ref_imgs_var)
        explainability_mask = mask_net(tgt_img_var, ref_imgs_var)

        if args.flownet == 'Back2Future':
            # Back2Future consumes the previous and next reference frames.
            flow_fwd, _, _ = flow_net(tgt_img_var, ref_imgs_var[1:3])
        else:
            flow_fwd = flow_net(tgt_img_var, ref_imgs_var[2])

        # Rigid flow induced by camera motion (from depth + pose).
        flow_cam = pose2flow(depth.squeeze(1), pose[:, 2], intrinsics_var,
                             intrinsics_inv_var)

        # Rigidity from the explainability masks of the two middle ref frames.
        rigidity_mask = 1 - (1 - explainability_mask[:, 1]) * (
            1 - explainability_mask[:, 2]).unsqueeze(1) > 0.5
        # Census-style rigidity: pixels where rigid and estimated flow agree
        # (per-component difference below THRESH) are considered rigid.
        rigidity_mask_census_soft = (flow_cam - flow_fwd).abs()  #.normalize()
        rigidity_mask_census_u = rigidity_mask_census_soft[:, 0] < args.THRESH
        rigidity_mask_census_v = rigidity_mask_census_soft[:, 1] < args.THRESH
        rigidity_mask_census = (rigidity_mask_census_u).type_as(flow_fwd) * (
            rigidity_mask_census_v).type_as(flow_fwd)
        # Soft OR of the two rigidity estimates.
        rigidity_mask_combined = 1 - (
            1 - rigidity_mask.type_as(explainability_mask)) * (
                1 - rigidity_mask_census.type_as(explainability_mask))

        _, _, h_pred, w_pred = flow_cam.size()
        _, _, h_gt, w_gt = tgt_img_original.size()
        # NOTE(review): this resamples the combined mask to (h_pred, w_pred),
        # the prediction's own resolution — confirm this is intentional.
        rigidity_pred_mask = nn.functional.upsample(rigidity_mask_combined,
                                                    size=(h_pred, w_pred),
                                                    mode='bilinear')

        # Fuse: non-rigid flow where the mask is low, camera flow where high.
        non_rigid_pred = (rigidity_pred_mask <= args.THRESH
                          ).type_as(flow_fwd).expand_as(flow_fwd) * flow_fwd
        rigid_pred = (rigidity_pred_mask > args.THRESH
                      ).type_as(flow_cam).expand_as(flow_cam) * flow_cam
        total_pred = non_rigid_pred + rigid_pred

        # Upsample to ground-truth resolution and rescale the u/v components
        # by the resize ratio so magnitudes stay correct in pixels.
        pred_fullres = nn.functional.upsample(total_pred,
                                              size=(h_gt, w_gt),
                                              mode='bilinear')
        pred_fullres[:, 0, :, :] = pred_fullres[:, 0, :, :] * (w_gt / w_pred)
        pred_fullres[:, 1, :, :] = pred_fullres[:, 1, :, :] * (h_gt / h_pred)

        flow_fwd_fullres = nn.functional.upsample(flow_fwd,
                                                  size=(h_gt, w_gt),
                                                  mode='bilinear')
        flow_fwd_fullres[:, 0, :, :] = flow_fwd_fullres[:, 0, :, :] * (w_gt /
                                                                       w_pred)
        flow_fwd_fullres[:, 1, :, :] = flow_fwd_fullres[:, 1, :, :] * (h_gt /
                                                                       h_pred)

        flow_cam_fullres = nn.functional.upsample(flow_cam,
                                                  size=(h_gt, w_gt),
                                                  mode='bilinear')
        flow_cam_fullres[:, 0, :, :] = flow_cam_fullres[:, 0, :, :] * (w_gt /
                                                                       w_pred)
        flow_cam_fullres[:, 1, :, :] = flow_cam_fullres[:, 1, :, :] * (h_gt /
                                                                       h_pred)

        tgt_img_np = tgt_img[0].numpy()
        rigidity_mask_combined_np = rigidity_mask_combined.cpu().data[0].numpy(
        )

        if args.output_dir is not None:
            # Raw numpy dumps plus KITTI-format flow files (PNG and .flo).
            np.save(image_dir / str(i).zfill(3), tgt_img_np)
            np.save(mask_dir / str(i).zfill(3), rigidity_mask_combined_np)
            pred_u = pred_fullres[0][0].data.cpu().numpy()
            pred_v = pred_fullres[0][1].data.cpu().numpy()
            flow_io.flow_write_png(testing_dir / str(i).zfill(6) + '_10.png',
                                   u=pred_u,
                                   v=pred_v)
            flow_io.flow_write(testing_dir_flo / str(i).zfill(6) + '_10.flo',
                               pred_u, pred_v)

        if (args.output_dir is not None):
            ind = int(i)
            # Row 1: input / disparity / combined rigidity mask.
            tgt_img_viz = tensor2array(tgt_img[0].cpu())
            depth_viz = tensor2array(disp.data[0].cpu(),
                                     max_value=None,
                                     colormap='magma')
            mask_viz = tensor2array(rigidity_mask_combined.data[0].cpu(),
                                    max_value=1,
                                    colormap='magma')
            # Row 2: camera flow / estimated flow / fused prediction.
            row2_viz = flow_to_image(
                np.hstack((tensor2array(flow_cam_fullres.data[0].cpu()),
                           tensor2array(flow_fwd_fullres.data[0].cpu()),
                           tensor2array(pred_fullres.data[0].cpu()))))
            row1_viz = np.hstack((tgt_img_viz, depth_viz, mask_viz))
            row1_viz_im = Image.fromarray(
                (255 * row1_viz.transpose(1, 2, 0)).astype('uint8'))
            row2_viz_im = Image.fromarray(
                (255 * row2_viz.transpose(1, 2, 0)).astype('uint8'))
            row1_viz_im.save(viz_dir / str(i).zfill(3) + '01.png')
            row2_viz_im.save(viz_dir / str(i).zfill(3) + '02.png')

    print("Done!")
def validate_without_gt(val_loader,
                        disp_net,
                        pose_net,
                        mask_net,
                        flow_net,
                        epoch,
                        logger,
                        tb_writer,
                        nb_writers,
                        global_vars_dict=None):
    """Run one validation epoch without ground truth: forward all four
    networks, recompute the full training loss on the validation set, and log
    scalars/images to tensorboard.

    Returns the running-average total loss for the epoch
    (``losses.avg[0]``). Also advances ``global_vars_dict['n_iter_val']``.
    """
    # data prepared
    device = global_vars_dict['device']
    n_iter_val = global_vars_dict['n_iter_val']
    args = global_vars_dict['args']

    # Convert the fractional sample positions in args.show_samples into
    # integer batch indices of this loader (floor via // 1).
    show_samples = copy.deepcopy(args.show_samples)
    for i in range(len(show_samples)):
        show_samples[i] *= len(val_loader)
        show_samples[i] = show_samples[i] // 1

    batch_time = AverageMeter()
    data_time = AverageMeter()
    log_outputs = nb_writers > 0
    losses = AverageMeter(precision=4)
    # Loss weights: w1 photometric (camera), w2 mask, w3 smoothness,
    # w4 flow photometric, w5 consensus.
    w1, w2, w3, w4 = args.cam_photo_loss_weight, args.mask_loss_weight, args.smooth_loss_weight, args.flow_photo_loss_weight
    w5 = args.consensus_loss_weight
    loss_camera = photometric_reconstruction_loss
    loss_flow = photometric_flow_loss

    # to eval model
    disp_net.eval()
    pose_net.eval()
    mask_net.eval()
    flow_net.eval()

    end = time.time()
    poses = np.zeros(
        ((len(val_loader) - 1) * 1 * (args.sequence_length - 1), 6))  # init
    # NOTE(review): `poses` is allocated but never filled in this function.

    # 3. validation cycle
    for i, (tgt_img, ref_imgs, intrinsics,
            intrinsics_inv) in enumerate(val_loader):
        data_time.update(time.time() - end)
        tgt_img = tgt_img.to(device)
        ref_imgs = [img.to(device) for img in ref_imgs]
        intrinsics, intrinsics_inv = intrinsics.to(device), intrinsics_inv.to(
            device)

        # 3.1 forward pass
        # disp
        disp = disp_net(tgt_img)
        if args.spatial_normalize:
            disp = spatial_normalize(disp)
        depth = 1 / disp
        # pose
        pose = pose_net(tgt_img, ref_imgs)  # [b,3,h,w]; list
        # flow ----
        # build flow w.r.t. the previous and next frames
        if args.flownet == 'Back2Future':
            flow_fwd, flow_bwd, _ = flow_net(tgt_img, ref_imgs[1:3])
        elif args.flownet == 'FlowNetC6':
            flow_fwd = flow_net(tgt_img, ref_imgs[2])
            flow_bwd = flow_net(tgt_img, ref_imgs[1])

        # Rigid (camera-motion) flow from depth + pose.
        flow_cam = pose2flow(depth.squeeze(1), pose[:, 2], intrinsics,
                             intrinsics_inv)
        flows_cam_fwd = pose2flow(depth.squeeze(1), pose[:, 2], intrinsics,
                                  intrinsics_inv)
        flows_cam_bwd = pose2flow(depth.squeeze(1), pose[:, 1], intrinsics,
                                  intrinsics_inv)
        # Target explainability masks from the rigid/non-rigid consensus.
        exp_masks_target = consensus_exp_masks(flows_cam_fwd,
                                               flows_cam_bwd,
                                               flow_fwd,
                                               flow_bwd,
                                               tgt_img,
                                               ref_imgs[2],
                                               ref_imgs[1],
                                               wssim=args.wssim,
                                               wrig=args.wrig,
                                               ws=args.smooth_loss_weight)
        no_rigid_flow = flow_fwd - flows_cam_fwd

        # Residual between camera flow and estimated flow. [b,2,h,w]
        rigidity_mask_fwd = (flows_cam_fwd - flow_fwd).abs()
        rigidity_mask_bwd = (flows_cam_bwd - flow_bwd).abs()

        # mask
        # 4. explainability_mask (valid-region masks, one per ref image —
        # TODO confirm against mask_net's output shape)
        explainability_mask = mask_net(tgt_img, ref_imgs)
        # list(5): item: tensor [4,4,128,512] ... [4,4,4,16], values ~[0.33~0.48~0.63]

        if args.joint_mask_for_depth:  # false
            explainability_mask_for_depth = explainability_mask
            # explainability_mask_for_depth = compute_joint_mask_for_depth(explainability_mask, rigidity_mask_bwd,
            #                                 rigidity_mask_fwd,THRESH=args.THRESH)
            # NOTE(review): both branches assign the same value; the joint-mask
            # variant is commented out.
        else:
            explainability_mask_for_depth = explainability_mask

        # change
        if args.no_non_rigid_mask:
            flow_exp_mask = None
            if args.DEBUG:
                print('Using no masks for flow')
        else:
            # Non-rigid regions (inverse of explainability) for the two middle
            # reference frames.
            flow_exp_mask = 1 - explainability_mask[:, 1:3]

        # 3.2 loss compute (same terms as training; weights of zero skip the
        # corresponding forward computation entirely)
        if w1 > 0:
            loss_1 = loss_camera(tgt_img,
                                 ref_imgs,
                                 intrinsics,
                                 intrinsics_inv,
                                 depth,
                                 explainability_mask_for_depth,
                                 pose,
                                 lambda_oob=args.lambda_oob,
                                 qch=args.qch,
                                 wssim=args.wssim)
        else:
            loss_1 = torch.tensor([0.]).to(device)
        # E_M
        if w2 > 0:
            loss_2 = explainability_loss(
                explainability_mask
            )  # + 0.2*gaussian_explainability_loss(explainability_mask)
        else:
            loss_2 = 0
        # if args.smoothness_type == "regular":
        if w3 > 0:
            loss_3 = smooth_loss(depth) + smooth_loss(
                explainability_mask) + smooth_loss(flow_fwd) + smooth_loss(
                    flow_bwd)
        else:
            loss_3 = torch.tensor([0.]).to(device)
        if w4 > 0:
            loss_4 = loss_flow(tgt_img,
                               ref_imgs[1:3], [flow_bwd, flow_fwd],
                               flow_exp_mask,
                               lambda_oob=args.lambda_oob,
                               qch=args.qch,
                               wssim=args.wssim)
        else:
            loss_4 = torch.tensor([0.]).to(device)
        if w5 > 0:
            loss_5 = consensus_depth_flow_mask(explainability_mask,
                                               rigidity_mask_bwd,
                                               rigidity_mask_fwd,
                                               exp_masks_target,
                                               exp_masks_target,
                                               THRESH=args.THRESH,
                                               wbce=args.wbce)
        else:
            loss_5 = torch.tensor([0.]).to(device)

        loss = w1 * loss_1 + w2 * loss_2 + w3 * loss_3 + w4 * loss_4 + w5 * loss_5

        # 3.3 data update
        losses.update(loss.item(), args.batch_size)
        batch_time.update(time.time() - end)
        end = time.time()

        # 3.4 check log — inspect forward-pass quality on selected samples
        if args.img_freq > 0 and i in show_samples:  # output_writers list(3)
            if epoch == 0:  # validation before training: assess the untrained network first
                # 1. input images (only executed once; note ref_imgs axis 0
                # indexes the batch, axis 1 indexes the adjacent-frame list!)
                tb_writer.add_image(
                    'epoch 0 Input/sample{}(img{} to img{})'.format(
                        i, i + 1, i + args.sequence_length),
                    tensor2array(ref_imgs[0][0]), 0)
                tb_writer.add_image(
                    'epoch 0 Input/sample{}(img{} to img{})'.format(
                        i, i + 1, i + args.sequence_length),
                    tensor2array(ref_imgs[1][0]), 1)
                tb_writer.add_image(
                    'epoch 0 Input/sample{}(img{} to img{})'.format(
                        i, i + 1, i + args.sequence_length),
                    tensor2array(tgt_img[0]), 2)
                tb_writer.add_image(
                    'epoch 0 Input/sample{}(img{} to img{})'.format(
                        i, i + 1, i + args.sequence_length),
                    tensor2array(ref_imgs[2][0]), 3)
                tb_writer.add_image(
                    'epoch 0 Input/sample{}(img{} to img{})'.format(
                        i, i + 1, i + args.sequence_length),
                    tensor2array(ref_imgs[3][0]), 4)

                depth_to_show = depth[0].cpu(
                )  # tensor disp_to_show: [1,h,w], ~0.5~3.1~10
                tb_writer.add_image(
                    'Disp Output/sample{}'.format(i),
                    tensor2array(depth_to_show, max_value=None,
                                 colormap='bone'), 0)
            else:
                # 2. disp
                depth_to_show = disp[0].cpu(
                )  # tensor disp_to_show: [1,h,w], ~0.5~3.1~10
                tb_writer.add_image(
                    'Disp Output/sample{}'.format(i),
                    tensor2array(depth_to_show, max_value=None,
                                 colormap='bone'), epoch)
                # 3. flow
                tb_writer.add_image('Flow/Flow Output sample {}'.format(i),
                                    flow2rgb(flow_fwd[0], max_value=6), epoch)
                tb_writer.add_image(
                    'Flow/cam_Flow Output sample {}'.format(i),
                    flow2rgb(flow_cam[0], max_value=6), epoch)
                tb_writer.add_image(
                    'Flow/no rigid flow Output sample {}'.format(i),
                    flow2rgb(no_rigid_flow[0], max_value=6), epoch)
                tb_writer.add_image(
                    'Flow/rigidity_mask_fwd{}'.format(i),
                    flow2rgb(rigidity_mask_fwd[0], max_value=6), epoch)
                # 4. mask
                tb_writer.add_image(
                    'Mask Output/mask0 sample{}'.format(i),
                    tensor2array(explainability_mask[0][0],
                                 max_value=None,
                                 colormap='magma'), epoch)
                # tb_writer.add_image('Mask Output/mask1 sample{}'.format(i),tensor2array(explainability_mask[1][0], max_value=None, colormap='magma'), epoch)
                # tb_writer.add_image('Mask Output/mask2 sample{}'.format(i),tensor2array(explainability_mask[2][0], max_value=None, colormap='magma'), epoch)
                # tb_writer.add_image('Mask Output/mask3 sample{}'.format(i),tensor2array(explainability_mask[3][0], max_value=None, colormap='magma'), epoch)
                tb_writer.add_image(
                    'Mask Output/exp_masks_target sample{}'.format(i),
                    tensor2array(exp_masks_target[0][0],
                                 max_value=None,
                                 colormap='magma'), epoch)
                # tb_writer.add_image('Mask Output/mask0 sample{}'.format(i),
                #     tensor2array(explainability_mask[0][0], max_value=None, colormap='magma'), epoch)
                #
                # output_writers[index].add_image('val Depth Output', tensor2array(depth.data[0].cpu(), max_value=10),
                #     epoch)
                # errors.update(compute_errors(depth, output_depth.data.squeeze(1)))

        # add scalar
        if args.scalar_freq > 0 and n_iter_val % args.scalar_freq == 0:
            tb_writer.add_scalar('val/E_R', loss_1.item(), n_iter_val)
            if w2 > 0:
                tb_writer.add_scalar('val/E_M', loss_2.item(), n_iter_val)
            tb_writer.add_scalar('val/E_S', loss_3.item(), n_iter_val)
            tb_writer.add_scalar('val/E_F', loss_4.item(), n_iter_val)
            tb_writer.add_scalar('val/E_C', loss_5.item(), n_iter_val)
            tb_writer.add_scalar('val/total_loss', loss.item(), n_iter_val)

        # terminal output
        if args.log_terminal:
            logger.valid_bar.update(i + 1)  # progress within the current epoch
            if i % args.print_freq == 0:
                logger.valid_bar_writer.write(
                    'Valid: Time {} Data {} Loss {}'.format(
                        batch_time, data_time, losses))

        n_iter_val += 1

    global_vars_dict['n_iter_val'] = n_iter_val

    return losses.avg[0]  # epoch validate loss
def validate_flow_with_gt(val_loader,
                          disp_net,
                          pose_net,
                          mask_net,
                          flow_net,
                          epoch,
                          logger,
                          output_writers=[]):
    """Validate flow against ground truth: fuse rigid (camera) and non-rigid
    (flow-net) flow via the combined rigidity mask and accumulate EPE /
    outlier metrics, optionally logging visualizations to ``output_writers``.

    Returns ``(errors.avg, error_names)``.

    NOTE(review): the mutable default ``output_writers=[]`` is only read here,
    but a ``None`` default would be the safer idiom.
    """
    global args
    batch_time = AverageMeter()
    error_names = [
        'epe_total', 'epe_rigid', 'epe_non_rigid', 'outliers',
        'epe_total_with_gt_mask', 'epe_rigid_with_gt_mask',
        'epe_non_rigid_with_gt_mask', 'outliers_gt_mask'
    ]
    errors = AverageMeter(i=len(error_names))
    log_outputs = len(output_writers) > 0

    # switch to evaluate mode
    disp_net.eval()
    pose_net.eval()
    mask_net.eval()
    flow_net.eval()

    end = time.time()
    poses = np.zeros(
        ((len(val_loader) - 1) * 1 * (args.sequence_length - 1), 6))

    for i, (tgt_img, ref_imgs, intrinsics, intrinsics_inv, flow_gt,
            obj_map_gt) in enumerate(val_loader):
        # NOTE(review): Variable(..., volatile=True) is the legacy (pre-0.4)
        # inference idiom; modern PyTorch would use torch.no_grad().
        tgt_img = Variable(tgt_img.cuda(), volatile=True)
        ref_imgs = [Variable(img.cuda(), volatile=True) for img in ref_imgs]
        intrinsics_var = Variable(intrinsics.cuda(), volatile=True)
        intrinsics_inv_var = Variable(intrinsics_inv.cuda(), volatile=True)
        flow_gt_var = Variable(flow_gt.cuda(), volatile=True)
        obj_map_gt_var = Variable(obj_map_gt.cuda(), volatile=True)

        # compute output-------------------------
        # 1. disp fwd
        disp = disp_net(tgt_img)
        if args.spatial_normalize:
            disp = spatial_normalize(disp)
        depth = 1 / disp
        # 2. pose fwd
        pose = pose_net(tgt_img, ref_imgs)
        # 3. mask fwd
        explainability_mask = mask_net(tgt_img, ref_imgs)
        # 4. flow fwd
        if args.flownet == 'Back2Future':
            # previous frame, next frame
            flow_fwd, flow_bwd, _ = flow_net(tgt_img, ref_imgs[1:3])
        elif args.flownet == 'FlowNetC6':
            flow_fwd = flow_net(tgt_img, ref_imgs[2])
            flow_bwd = flow_net(tgt_img, ref_imgs[1])
        # compute output-------------------------

        if args.DEBUG:
            # Compare magnitude statistics of predicted vs. GT flow
            # (GT median computed over non-zero entries only).
            flow_fwd_x = flow_fwd[:, 0].view(-1).abs().data
            print("Flow Fwd Median: ", flow_fwd_x.median())
            flow_gt_var_x = flow_gt_var[:, 0].view(-1).abs().data
            print(
                "Flow GT Median: ",
                flow_gt_var_x.index_select(
                    0, flow_gt_var_x.nonzero().view(-1)).median())

        # Rigid flow induced by camera motion, plus out-of-bounds maps.
        flow_cam = pose2flow(depth.squeeze(1), pose[:, 2], intrinsics_var,
                             intrinsics_inv_var)
        oob_rigid = flow2oob(flow_cam)
        oob_non_rigid = flow2oob(flow_fwd)

        # Rigidity from the explainability masks of the two middle ref frames.
        rigidity_mask = 1 - (1 - explainability_mask[:, 1]) * (
            1 - explainability_mask[:, 2]).unsqueeze(1) > 0.5
        # Census-style rigidity: rigid where camera and estimated flow agree
        # per-component within THRESH.
        rigidity_mask_census_soft = (flow_cam - flow_fwd).abs()  #.normalize()
        rigidity_mask_census_u = rigidity_mask_census_soft[:, 0] < args.THRESH
        rigidity_mask_census_v = rigidity_mask_census_soft[:, 1] < args.THRESH
        rigidity_mask_census = (rigidity_mask_census_u).type_as(flow_fwd) * (
            rigidity_mask_census_v).type_as(flow_fwd)
        # Soft OR of the two rigidity estimates.
        rigidity_mask_combined = 1 - (
            1 - rigidity_mask.type_as(explainability_mask)) * (
                1 - rigidity_mask_census.type_as(explainability_mask))

        # get flow: fuse non-rigid flow (mask low) with camera flow (mask high)
        flow_fwd_non_rigid = (rigidity_mask_combined <= args.THRESH).type_as(
            flow_fwd).expand_as(flow_fwd) * flow_fwd
        flow_fwd_rigid = (rigidity_mask_combined > args.THRESH
                          ).type_as(flow_fwd).expand_as(flow_fwd) * flow_cam
        total_flow = flow_fwd_rigid + flow_fwd_non_rigid

        # GT object map broadcast to flow dtype for the masked EPE variant.
        obj_map_gt_var_expanded = obj_map_gt_var.unsqueeze(1).type_as(flow_fwd)

        # Log every 10th sample, one tensorboard writer per logged sample.
        if log_outputs and i % 10 == 0 and i / 10 < len(output_writers):
            index = int(i // 10)
            if epoch == 0:
                output_writers[index].add_image('val flow Input',
                                                tensor2array(tgt_img[0]), 0)
                flow_to_show = flow_gt[0][:2, :, :].cpu()
                output_writers[index].add_image(
                    'val target Flow',
                    flow_to_image(tensor2array(flow_to_show)), epoch)

            output_writers[index].add_image(
                'val Total Flow Output',
                flow_to_image(tensor2array(total_flow.data[0].cpu())), epoch)
            output_writers[index].add_image(
                'val Rigid Flow Output',
                flow_to_image(tensor2array(flow_fwd_rigid.data[0].cpu())),
                epoch)
            output_writers[index].add_image(
                'val Non-rigid Flow Output',
                flow_to_image(tensor2array(flow_fwd_non_rigid.data[0].cpu())),
                epoch)
            output_writers[index].add_image(
                'val Out of Bound (Rigid)',
                tensor2array(oob_rigid.type(torch.FloatTensor).data[0].cpu(),
                             max_value=1,
                             colormap='bone'), epoch)
            output_writers[index].add_scalar(
                'val Mean oob (Rigid)',
                oob_rigid.type(torch.FloatTensor).sum(), epoch)
            output_writers[index].add_image(
                'val Out of Bound (Non-Rigid)',
                tensor2array(oob_non_rigid.type(
                    torch.FloatTensor).data[0].cpu(),
                             max_value=1,
                             colormap='bone'), epoch)
            output_writers[index].add_scalar(
                'val Mean oob (Non-Rigid)',
                oob_non_rigid.type(torch.FloatTensor).sum(), epoch)
            output_writers[index].add_image(
                'val Cam Flow Errors',
                tensor2array(flow_diff(flow_gt_var, flow_cam).data[0].cpu()),
                epoch)
            output_writers[index].add_image(
                'val Rigidity Mask',
                tensor2array(rigidity_mask.data[0].cpu(),
                             max_value=1,
                             colormap='bone'), epoch)
            output_writers[index].add_image(
                'val Rigidity Mask Census',
                tensor2array(rigidity_mask_census.data[0].cpu(),
                             max_value=1,
                             colormap='bone'), epoch)

            # Warp each reference frame into the target view and log the
            # warped image, the photometric difference and the exp. mask.
            for j, ref in enumerate(ref_imgs):
                ref_warped = inverse_warp(ref[:1],
                                          depth[:1, 0],
                                          pose[:1, j],
                                          intrinsics_var[:1],
                                          intrinsics_inv_var[:1],
                                          rotation_mode=args.rotation_mode,
                                          padding_mode=args.padding_mode)[0]
                output_writers[index].add_image(
                    'val Warped Outputs {}'.format(j),
                    tensor2array(ref_warped.data.cpu()), epoch)
                output_writers[index].add_image(
                    'val Diff Outputs {}'.format(j),
                    tensor2array(0.5 *
                                 (tgt_img[0] - ref_warped).abs().data.cpu()),
                    epoch)
                if explainability_mask is not None:
                    output_writers[index].add_image(
                        'val Exp mask Outputs {}'.format(j),
                        tensor2array(explainability_mask[0, j].data.cpu(),
                                     max_value=1,
                                     colormap='bone'), epoch)

            if args.DEBUG:
                # Check if pose2flow is consistant with inverse warp
                ref_warped_from_depth = inverse_warp(
                    ref_imgs[2][:1],
                    depth[:1, 0],
                    pose[:1, 2],
                    intrinsics_var[:1],
                    intrinsics_inv_var[:1],
                    rotation_mode=args.rotation_mode,
                    padding_mode=args.padding_mode)[0]
                ref_warped_from_cam_flow = flow_warp(ref_imgs[2][:1],
                                                     flow_cam)[0]
                print(
                    "DEBUG_INFO: Inverse_warp vs pose2flow",
                    torch.mean(
                        torch.abs(ref_warped_from_depth -
                                  ref_warped_from_cam_flow)).item())
                output_writers[index].add_image(
                    'val Warped Outputs from Cam Flow',
                    tensor2array(ref_warped_from_cam_flow.data.cpu()), epoch)
                output_writers[index].add_image(
                    'val Warped Outputs from inverse warp',
                    tensor2array(ref_warped_from_depth.data.cpu()), epoch)

        # Collect poses for the end-of-epoch histograms.
        if log_outputs and i < len(val_loader) - 1:
            step = args.sequence_length - 1
            poses[i * step:(i + 1) * step] = pose.data.cpu().view(-1,
                                                                  6).numpy()

        if np.isnan(flow_gt.sum().item()) or np.isnan(
                total_flow.data.sum().item()):
            print('NaN encountered')
        #
        # EPE errors with the predicted rigidity mask, then with the GT
        # object map (inverted), concatenated in error_names order.
        _epe_errors = compute_all_epes(
            flow_gt_var, flow_cam, flow_fwd,
            rigidity_mask_combined) + compute_all_epes(
                flow_gt_var, flow_cam, flow_fwd,
                (1 - obj_map_gt_var_expanded))
        errors.update(_epe_errors)

        if args.DEBUG:
            print("DEBUG_INFO: EPE errors: ", _epe_errors)
        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

    if log_outputs:
        # Histograms of translation and rotation components over the epoch.
        output_writers[0].add_histogram('val poses_tx', poses[:, 0], epoch)
        output_writers[0].add_histogram('val poses_ty', poses[:, 1], epoch)
        output_writers[0].add_histogram('val poses_tz', poses[:, 2], epoch)
        if args.rotation_mode == 'euler':
            rot_coeffs = ['rx', 'ry', 'rz']
        elif args.rotation_mode == 'quat':
            rot_coeffs = ['qx', 'qy', 'qz']
        output_writers[0].add_histogram('val poses_{}'.format(rot_coeffs[0]),
                                        poses[:, 3], epoch)
        output_writers[0].add_histogram('val poses_{}'.format(rot_coeffs[1]),
                                        poses[:, 4], epoch)
        output_writers[0].add_histogram('val poses_{}'.format(rot_coeffs[2]),
                                        poses[:, 5], epoch)

    if args.DEBUG:
        print("DEBUG_INFO =================>")
        print("DEBUG_INFO: Average EPE : ", errors.avg)
        print("DEBUG_INFO =================>")
        print("DEBUG_INFO =================>")
        print("DEBUG_INFO =================>")

    return errors.avg, error_names