Example #1
def depth_occlusion_masks(depth, pose, intrinsics, intrinsics_inv):
    flow_cam = [
        pose2flow(depth.squeeze(), pose[:, i], intrinsics, intrinsics_inv)
        for i in range(pose.size(1))
    ]
    masks1, masks2 = occlusion_masks(flow_cam[1], flow_cam[2])
    masks0, masks3 = occlusion_masks(flow_cam[0], flow_cam[3])
    masks = torch.stack((masks0, masks1, masks2, masks3), dim=1)
    return masks
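
A note on occlusion_masks: it is imported by this snippet, not defined in it. A minimal sketch of a forward/backward-consistency check with the same call shape (two [B,2,H,W] flows in, two masks out) could look like the following; the thresholding rule here is an illustrative assumption, not the repository's implementation.

import torch

def occlusion_masks_sketch(flow_a, flow_b, thresh=0.75):
    # flow_a, flow_b: [B, 2, H, W] flows in opposite directions.
    # Pixels where the two flows fail to cancel out are marked occluded.
    mag = (flow_a.pow(2).sum(dim=1) + flow_b.pow(2).sum(dim=1)).sqrt()
    diff = (flow_a + flow_b).pow(2).sum(dim=1).sqrt()
    mask_a = (diff < thresh * mag.clamp(min=1.0)).float()  # 1 = visible
    return mask_a, mask_a.clone()  # symmetric criterion in this sketch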
Example #2
def main():
    global args
    args = parser.parse_args()

    args.pretrained_disp = Path(args.pretrained_disp)
    args.pretrained_pose = Path(args.pretrained_pose)
    args.pretrained_mask = Path(args.pretrained_mask)
    args.pretrained_flow = Path(args.pretrained_flow)

    if args.output_dir is not None:
        args.output_dir = Path(args.output_dir)
        args.output_dir.makedirs_p()

        image_dir = args.output_dir / 'images'
        gt_dir = args.output_dir / 'gt'
        mask_dir = args.output_dir / 'mask'
        viz_dir = args.output_dir / 'viz'
        rigidity_mask_dir = args.output_dir / 'rigidity'
        rigidity_census_mask_dir = args.output_dir / 'rigidity_census'
        explainability_mask_dir = args.output_dir / 'explainability'

        image_dir.makedirs_p()
        gt_dir.makedirs_p()
        mask_dir.makedirs_p()
        viz_dir.makedirs_p()
        rigidity_mask_dir.makedirs_p()
        rigidity_census_mask_dir.makedirs_p()
        explainability_mask_dir.makedirs_p()

        output_writer = SummaryWriter(args.output_dir)

    normalize = custom_transforms.Normalize(mean=[0.5, 0.5, 0.5],
                                            std=[0.5, 0.5, 0.5])
    flow_loader_h, flow_loader_w = 256, 832
    valid_flow_transform = custom_transforms.Compose([
        custom_transforms.Scale(h=flow_loader_h, w=flow_loader_w),
        custom_transforms.ArrayToTensor(), normalize
    ])
    val_flow_set = ValidationMask(root=args.kitti_dir,
                                  sequence_length=5,
                                  transform=valid_flow_transform)

    val_loader = torch.utils.data.DataLoader(val_flow_set,
                                             batch_size=1,
                                             shuffle=False,
                                             num_workers=2,
                                             pin_memory=True,
                                             drop_last=True)

    disp_net = getattr(models, args.dispnet)().cuda()
    pose_net = getattr(models, args.posenet)(nb_ref_imgs=4).cuda()
    mask_net = getattr(models, args.masknet)(nb_ref_imgs=4).cuda()
    flow_net = getattr(models, args.flownet)(nlevels=args.nlevels).cuda()

    dispnet_weights = torch.load(args.pretrained_disp)
    posenet_weights = torch.load(args.pretrained_pose)
    masknet_weights = torch.load(args.pretrained_mask)
    flownet_weights = torch.load(args.pretrained_flow)
    disp_net.load_state_dict(dispnet_weights['state_dict'])
    pose_net.load_state_dict(posenet_weights['state_dict'])
    flow_net.load_state_dict(flownet_weights['state_dict'])
    mask_net.load_state_dict(masknet_weights['state_dict'])

    disp_net.eval()
    pose_net.eval()
    mask_net.eval()
    flow_net.eval()

    error_names = ['tp_0', 'fp_0', 'fn_0', 'tp_1', 'fp_1', 'fn_1']
    errors = AverageMeter(i=len(error_names))
    errors_census = AverageMeter(i=len(error_names))
    errors_bare = AverageMeter(i=len(error_names))

    for i, (tgt_img, ref_imgs, intrinsics, intrinsics_inv, flow_gt, obj_map_gt,
            semantic_map_gt) in enumerate(tqdm(val_loader)):
        tgt_img_var = Variable(tgt_img.cuda(), volatile=True)
        ref_imgs_var = [
            Variable(img.cuda(), volatile=True) for img in ref_imgs
        ]
        intrinsics_var = Variable(intrinsics.cuda(), volatile=True)
        intrinsics_inv_var = Variable(intrinsics_inv.cuda(), volatile=True)

        flow_gt_var = Variable(flow_gt.cuda(), volatile=True)
        obj_map_gt_var = Variable(obj_map_gt.cuda(), volatile=True)

        disp = disp_net(tgt_img_var)
        depth = 1 / disp
        pose = pose_net(tgt_img_var, ref_imgs_var)
        explainability_mask = mask_net(tgt_img_var, ref_imgs_var)
        if args.flownet == 'Back2Future':
            flow_fwd, flow_bwd, _ = flow_net(tgt_img_var, ref_imgs_var[1:3])
        else:
            flow_fwd = flow_net(tgt_img_var, ref_imgs_var[2])
        flow_cam = pose2flow(depth.squeeze(1), pose[:, 2], intrinsics_var,
                             intrinsics_inv_var)

        rigidity_mask = 1 - (1 - explainability_mask[:, 1]) * (
            1 - explainability_mask[:, 2]).unsqueeze(1) > 0.5
        rigidity_mask_census_soft = (flow_cam - flow_fwd).pow(2).sum(dim=1).unsqueeze(1).sqrt()
        rigidity_mask_census_soft = 1 - rigidity_mask_census_soft / rigidity_mask_census_soft.max()
        rigidity_mask_census = rigidity_mask_census_soft > args.THRESH

        rigidity_mask_combined = 1 - (
            1 - rigidity_mask.type_as(explainability_mask)) * (
                1 - rigidity_mask_census.type_as(explainability_mask))

        flow_fwd_non_rigid = (1 - rigidity_mask_combined).type_as(
            flow_fwd).expand_as(flow_fwd) * flow_fwd
        flow_fwd_rigid = rigidity_mask_combined.type_as(flow_fwd).expand_as(
            flow_fwd) * flow_cam
        total_flow = flow_fwd_rigid + flow_fwd_non_rigid

        obj_map_gt_var_expanded = obj_map_gt_var.unsqueeze(1).type_as(flow_fwd)

        tgt_img_np = tgt_img[0].numpy()
        rigidity_mask_combined_np = rigidity_mask_combined.cpu().data[0].numpy()
        rigidity_mask_census_np = rigidity_mask_census.cpu().data[0].numpy()
        rigidity_mask_bare_np = rigidity_mask.cpu().data[0].numpy()

        gt_mask_np = obj_map_gt[0].numpy()
        semantic_map_np = semantic_map_gt[0].numpy()

        _errors = mask_error(gt_mask_np, semantic_map_np,
                             rigidity_mask_combined_np[0])
        _errors_census = mask_error(gt_mask_np, semantic_map_np,
                                    rigidity_mask_census_np[0])
        _errors_bare = mask_error(gt_mask_np, semantic_map_np,
                                  rigidity_mask_bare_np[0])

        errors.update(_errors)
        errors_census.update(_errors_census)
        errors_bare.update(_errors_bare)

        if args.output_dir is not None:
            np.save(image_dir / str(i).zfill(3), tgt_img_np)
            np.save(gt_dir / str(i).zfill(3), gt_mask_np)
            np.save(mask_dir / str(i).zfill(3), rigidity_mask_combined_np)
            np.save(rigidity_mask_dir / str(i).zfill(3),
                    rigidity_mask.cpu().data[0].numpy())
            np.save(rigidity_census_mask_dir / str(i).zfill(3),
                    rigidity_mask_census.cpu().data[0].numpy())
            np.save(explainability_mask_dir / str(i).zfill(3),
                    explainability_mask[:, 1].cpu().data[0].numpy())

        if (args.output_dir is not None) and i % 10 == 0:
            ind = int(i // 10)
            output_writer.add_image(
                'val Dispnet Output Normalized',
                tensor2array(disp.data[0].cpu(),
                             max_value=None,
                             colormap='bone'), ind)
            output_writer.add_image('val Input',
                                    tensor2array(tgt_img[0].cpu()), i)
            output_writer.add_image(
                'val Total Flow Output',
                flow_to_image(tensor2array(total_flow.data[0].cpu())), ind)
            output_writer.add_image(
                'val Rigid Flow Output',
                flow_to_image(tensor2array(flow_fwd_rigid.data[0].cpu())), ind)
            output_writer.add_image(
                'val Non-rigid Flow Output',
                flow_to_image(tensor2array(flow_fwd_non_rigid.data[0].cpu())),
                ind)
            output_writer.add_image(
                'val Rigidity Mask',
                tensor2array(rigidity_mask.data[0].cpu(),
                             max_value=1,
                             colormap='bone'), ind)
            output_writer.add_image(
                'val Rigidity Mask Census',
                tensor2array(rigidity_mask_census.data[0].cpu(),
                             max_value=1,
                             colormap='bone'), ind)
            output_writer.add_image(
                'val Rigidity Mask Combined',
                tensor2array(rigidity_mask_combined.data[0].cpu(),
                             max_value=1,
                             colormap='bone'), ind)

        if args.output_dir is not None:
            tgt_img_viz = tensor2array(tgt_img[0].cpu())
            depth_viz = tensor2array(disp.data[0].cpu(),
                                     max_value=None,
                                     colormap='magma')
            mask_viz = tensor2array(rigidity_mask_census_soft.data[0].cpu(),
                                    max_value=1,
                                    colormap='bone')
            row2_viz = flow_to_image(
                np.hstack((tensor2array(flow_cam.data[0].cpu()),
                           tensor2array(flow_fwd_non_rigid.data[0].cpu()),
                           tensor2array(total_flow.data[0].cpu()))))

            row1_viz = np.hstack((tgt_img_viz, depth_viz, mask_viz))
            ####### changed the two vstack calls to hstack ###############
            viz3 = np.hstack(
                (255 * tgt_img_viz, 255 * depth_viz, 255 * mask_viz,
                 flow_to_image(
                     np.hstack((tensor2array(flow_fwd_non_rigid.data[0].cpu()),
                                tensor2array(total_flow.data[0].cpu()))))))
            ########################################################
            ######## manually added code ####################
            row1_viz = np.transpose(row1_viz, (1, 2, 0))
            row2_viz = np.transpose(row2_viz, (1, 2, 0))
            viz3 = np.transpose(viz3, (1, 2, 0))
            ##########################################

            row1_viz_im = Image.fromarray((255 * row1_viz).astype('uint8'))
            row2_viz_im = Image.fromarray((row2_viz).astype('uint8'))
            viz3_im = Image.fromarray(viz3.astype('uint8'))

            row1_viz_im.save(viz_dir / str(i).zfill(3) + '01.png')
            row2_viz_im.save(viz_dir / str(i).zfill(3) + '02.png')
            viz3_im.save(viz_dir / str(i).zfill(3) + '03.png')

    bg_iou = errors.sum[0] / (errors.sum[0] + errors.sum[1] + errors.sum[2])
    fg_iou = errors.sum[3] / (errors.sum[3] + errors.sum[4] + errors.sum[5])
    avg_iou = (bg_iou + fg_iou) / 2

    bg_iou_census = errors_census.sum[0] / (
        errors_census.sum[0] + errors_census.sum[1] + errors_census.sum[2])
    fg_iou_census = errors_census.sum[3] / (
        errors_census.sum[3] + errors_census.sum[4] + errors_census.sum[5])
    avg_iou_census = (bg_iou_census + fg_iou_census) / 2

    bg_iou_bare = errors_bare.sum[0] / (
        errors_bare.sum[0] + errors_bare.sum[1] + errors_bare.sum[2])
    fg_iou_bare = errors_bare.sum[3] / (
        errors_bare.sum[3] + errors_bare.sum[4] + errors_bare.sum[5])
    avg_iou_bare = (bg_iou_bare + fg_iou_bare) / 2

    print("Results Full Model")
    print("\t {:>10}, {:>10}, {:>10} ".format('iou', 'bg_iou', 'fg_iou'))
    print("Errors \t {:10.4f}, {:10.4f} {:10.4f}".format(
        avg_iou, bg_iou, fg_iou))

    print("Results Census only")
    print("\t {:>10}, {:>10}, {:>10} ".format('iou', 'bg_iou', 'fg_iou'))
    print("Errors \t {:10.4f}, {:10.4f} {:10.4f}".format(
        avg_iou_census, bg_iou_census, fg_iou_census))

    print("Results Bare")
    print("\t {:>10}, {:>10}, {:>10} ".format('iou', 'bg_iou', 'fg_iou'))
    print("Errors \t {:10.4f}, {:10.4f} {:10.4f}".format(
        avg_iou_bare, bg_iou_bare, fg_iou_bare))
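
The final block of main() reduces to the standard confusion-count definition of intersection-over-union, IoU = TP / (TP + FP + FN), evaluated once for background (sums 0-2) and once for foreground (sums 3-5). A standalone helper, assuming the tp/fp/fn ordering that mask_error produces above:

def iou_from_counts(tp, fp, fn):
    # standard IoU from accumulated confusion counts
    return tp / float(tp + fp + fn)

# e.g. bg_iou = iou_from_counts(errors.sum[0], errors.sum[1], errors.sum[2])
#      fg_iou = iou_from_counts(errors.sum[3], errors.sum[4], errors.sum[5])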
Example #3
def main():
    global args
    args = parser.parse_args()
    args.pretrained_disp = Path(args.pretrained_disp)
    args.pretrained_pose = Path(args.pretrained_pose)
    args.pretrained_mask = Path(args.pretrained_mask)
    args.pretrained_flow = Path(args.pretrained_flow)

    if args.output_dir is not None:
        args.output_dir = Path(args.output_dir)
        args.output_dir.makedirs_p()

        image_dir = args.output_dir / 'images'
        gt_dir = args.output_dir / 'gt'
        mask_dir = args.output_dir / 'mask'
        viz_dir = args.output_dir / 'viz'

        image_dir.makedirs_p()
        gt_dir.makedirs_p()
        mask_dir.makedirs_p()
        viz_dir.makedirs_p()

        output_writer = SummaryWriter(args.output_dir)

    normalize = custom_transforms.Normalize(mean=[0.5, 0.5, 0.5],
                                            std=[0.5, 0.5, 0.5])
    flow_loader_h, flow_loader_w = 256, 832
    valid_flow_transform = custom_transforms.Compose([
        custom_transforms.Scale(h=flow_loader_h, w=flow_loader_w),
        custom_transforms.ArrayToTensor(), normalize
    ])
    if args.dataset == "kitti2015":
        val_flow_set = ValidationFlow(root=args.kitti_dir,
                                      sequence_length=5,
                                      transform=valid_flow_transform)

    val_loader = torch.utils.data.DataLoader(val_flow_set,
                                             batch_size=1,
                                             shuffle=False,
                                             num_workers=2,
                                             pin_memory=True,
                                             drop_last=True)

    disp_net = getattr(models, args.dispnet)().cuda()
    pose_net = getattr(models, args.posenet)(nb_ref_imgs=4).cuda()
    mask_net = getattr(models, args.masknet)(nb_ref_imgs=4).cuda()
    flow_net = getattr(models, args.flownet)(nlevels=args.nlevels).cuda()

    dispnet_weights = torch.load(args.pretrained_disp)
    posenet_weights = torch.load(args.pretrained_pose)
    masknet_weights = torch.load(args.pretrained_mask)
    flownet_weights = torch.load(args.pretrained_flow)
    disp_net.load_state_dict(dispnet_weights['state_dict'])
    pose_net.load_state_dict(posenet_weights['state_dict'])
    flow_net.load_state_dict(flownet_weights['state_dict'])
    mask_net.load_state_dict(masknet_weights['state_dict'])

    disp_net.eval()
    pose_net.eval()
    mask_net.eval()
    flow_net.eval()

    error_names = [
        'epe_total', 'epe_sp', 'epe_mv', 'Fl', 'epe_total_gt_mask',
        'epe_sp_gt_mask', 'epe_mv_gt_mask', 'Fl_gt_mask'
    ]
    errors = AverageMeter(i=len(error_names))
    for i, (tgt_img, ref_imgs, intrinsics, intrinsics_inv, flow_gt,
            obj_map_gt) in enumerate(tqdm(val_loader)):
        tgt_img_var = Variable(tgt_img.cuda(), volatile=True)
        ref_imgs_var = [
            Variable(img.cuda(), volatile=True) for img in ref_imgs
        ]
        intrinsics_var = Variable(intrinsics.cuda(), volatile=True)
        intrinsics_inv_var = Variable(intrinsics_inv.cuda(), volatile=True)

        flow_gt_var = Variable(flow_gt.cuda(), volatile=True)
        obj_map_gt_var = Variable(obj_map_gt.cuda(), volatile=True)

        disp = disp_net(tgt_img_var)
        depth = 1 / disp
        pose = pose_net(tgt_img_var, ref_imgs_var)
        explainability_mask = mask_net(tgt_img_var, ref_imgs_var)

        if args.flownet == 'Back2Future':
            flow_fwd, flow_bwd, _ = flow_net(tgt_img_var, ref_imgs_var[1:3])
        else:
            flow_fwd = flow_net(tgt_img_var, ref_imgs_var[2])

        flow_cam = pose2flow(depth.squeeze(1), pose[:, 2], intrinsics_var,
                             intrinsics_inv_var)
        flow_cam_bwd = pose2flow(depth.squeeze(1), pose[:, 1], intrinsics_var,
                                 intrinsics_inv_var)

        rigidity_mask = 1 - (1 - explainability_mask[:, 1]) * (
            1 - explainability_mask[:, 2]).unsqueeze(1) > 0.5
        rigidity_mask_census_soft = (flow_cam - flow_fwd).abs()  #.normalize()
        rigidity_mask_census_u = rigidity_mask_census_soft[:, 0] < args.THRESH
        rigidity_mask_census_v = rigidity_mask_census_soft[:, 1] < args.THRESH
        rigidity_mask_census = (rigidity_mask_census_u).type_as(flow_fwd) * (
            rigidity_mask_census_v).type_as(flow_fwd)

        rigidity_mask_combined = 1 - (
            1 - rigidity_mask.type_as(explainability_mask)) * (
                1 - rigidity_mask_census.type_as(explainability_mask))

        obj_map_gt_var_expanded = obj_map_gt_var.unsqueeze(1).type_as(flow_fwd)

        flow_fwd_non_rigid = (rigidity_mask_combined <= args.THRESH).type_as(
            flow_fwd).expand_as(flow_fwd) * flow_fwd
        flow_fwd_rigid = (rigidity_mask_combined > args.THRESH
                          ).type_as(flow_cam).expand_as(flow_cam) * flow_cam
        total_flow = flow_fwd_rigid + flow_fwd_non_rigid

        rigidity_mask = rigidity_mask.type_as(flow_fwd)
        _epe_errors = compute_all_epes(
            flow_gt_var, flow_cam,
            flow_fwd, rigidity_mask_combined) + compute_all_epes(
                flow_gt_var, flow_cam, flow_fwd, (1 - obj_map_gt_var_expanded))
        errors.update(_epe_errors)

        tgt_img_np = tgt_img[0].numpy()
        rigidity_mask_combined_np = rigidity_mask_combined.cpu().data[0].numpy()
        gt_mask_np = obj_map_gt[0].numpy()

        if args.output_dir is not None:
            np.save(image_dir / str(i).zfill(3), tgt_img_np)
            np.save(gt_dir / str(i).zfill(3), gt_mask_np)
            np.save(mask_dir / str(i).zfill(3), rigidity_mask_combined_np)

        if (args.output_dir is not None) and i % 10 == 0:
            ind = int(i // 10)
            output_writer.add_image(
                'val Dispnet Output Normalized',
                tensor2array(disp.data[0].cpu(),
                             max_value=None,
                             colormap='bone'), ind)
            output_writer.add_image('val Input',
                                    tensor2array(tgt_img[0].cpu()), i)
            output_writer.add_image(
                'val Total Flow Output',
                flow_to_image(tensor2array(total_flow.data[0].cpu())), ind)
            output_writer.add_image(
                'val Rigid Flow Output',
                flow_to_image(tensor2array(flow_fwd_rigid.data[0].cpu())), ind)
            output_writer.add_image(
                'val Non-rigid Flow Output',
                flow_to_image(tensor2array(flow_fwd_non_rigid.data[0].cpu())),
                ind)
            output_writer.add_image(
                'val Rigidity Mask',
                tensor2array(rigidity_mask.data[0].cpu(),
                             max_value=1,
                             colormap='bone'), ind)
            output_writer.add_image(
                'val Rigidity Mask Census',
                tensor2array(rigidity_mask_census.data[0].cpu(),
                             max_value=1,
                             colormap='bone'), ind)
            output_writer.add_image(
                'val Rigidity Mask Combined',
                tensor2array(rigidity_mask_combined.data[0].cpu(),
                             max_value=1,
                             colormap='bone'), ind)

            tgt_img_viz = tensor2array(tgt_img[0].cpu())
            depth_viz = tensor2array(disp.data[0].cpu(),
                                     max_value=None,
                                     colormap='bone')
            mask_viz = tensor2array(
                rigidity_mask_census_soft.data[0].prod(dim=0).cpu(),
                max_value=1,
                colormap='bone')
            rigid_flow_viz = flow_to_image(tensor2array(
                flow_cam.data[0].cpu()))
            non_rigid_flow_viz = flow_to_image(
                tensor2array(flow_fwd_non_rigid.data[0].cpu()))
            total_flow_viz = flow_to_image(
                tensor2array(total_flow.data[0].cpu()))
            row1_viz = np.hstack((tgt_img_viz, depth_viz, mask_viz))
            row2_viz = np.hstack(
                (rigid_flow_viz, non_rigid_flow_viz, total_flow_viz))

            row1_viz_im = Image.fromarray((255 * row1_viz).astype('uint8'))
            row2_viz_im = Image.fromarray((row2_viz).astype('uint8'))

            row1_viz_im.save(viz_dir / str(i).zfill(3) + '01.png')
            row2_viz_im.save(viz_dir / str(i).zfill(3) + '02.png')

    print("Results")
    print("\t {:>10}, {:>10}, {:>10}, {:>6}, {:>10}, {:>10}, {:>10}, {:>10} ".
          format(*error_names))
    print(
        "Errors \t {:10.4f}, {:10.4f}, {:10.4f}, {:10.4f}, {:10.4f}, {:10.4f}, {:10.4f}, {:10.4f}"
        .format(*errors.avg))
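
The compositing step shared by Examples #2 and #3 selects, per pixel, between the pose-induced (rigid) flow and the directly estimated flow. A minimal standalone sketch of that pattern, with shapes assumed as in the loops above:

import torch

def composite_flow(flow_cam, flow_fwd, rigidity_mask, thresh=0.5):
    # flow_cam, flow_fwd: [B, 2, H, W]; rigidity_mask: [B, 1, H, W] in [0, 1].
    # Rigid (static-scene) pixels take the camera-motion flow; the remaining
    # pixels keep the network's flow estimate.
    rigid = (rigidity_mask > thresh).type_as(flow_cam).expand_as(flow_cam)
    return rigid * flow_cam + (1 - rigid) * flow_fwd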
Example #4
def train(train_loader, mask_net, pose_net, flow_net, optimizer, epoch_size,
          train_writer):
    global args, n_iter
    w1 = args.smooth_loss_weight
    w2 = args.mask_loss_weight
    w3 = args.consensus_loss_weight
    w4 = args.flow_loss_weight

    mask_net.train()
    pose_net.train()
    flow_net.train()
    average_loss = 0
    for i, (rgb_tgt_img, rgb_ref_imgs, depth_tgt_img, depth_ref_imgs,
            intrinsics, intrinsics_inv,
            pose_list) in enumerate(tqdm(train_loader)):
        rgb_tgt_img_var = Variable(rgb_tgt_img.cuda())
        # print(rgb_tgt_img_var.size())
        rgb_ref_imgs_var = [Variable(img.cuda()) for img in rgb_ref_imgs]
        # rgb_ref_imgs_var = [rgb_ref_imgs_var[0], rgb_ref_imgs_var[0], rgb_ref_imgs_var[1], rgb_ref_imgs_var[1]]
        depth_tgt_img_var = Variable(depth_tgt_img.unsqueeze(1).cuda())
        depth_ref_imgs_var = [
            Variable(img.unsqueeze(1).cuda()) for img in depth_ref_imgs
        ]
        intrinsics_var = Variable(intrinsics.cuda())
        intrinsics_inv_var = Variable(intrinsics_inv.cuda())
        # pose_list_var = [Variable(one_pose.float().cuda()) for one_pose in pose_list]

        explainability_mask = mask_net(rgb_tgt_img_var, rgb_ref_imgs_var)
        valid_pixel_mask = torch.where(
            depth_tgt_img_var == 0, torch.zeros_like(depth_tgt_img_var),
            torch.ones_like(depth_tgt_img_var))  # zero depth marks invalid pixels

        # print(explainability_mask[0].size()) #torch.Size([4, 2, 384, 512])
        # print()
        pose = pose_net(rgb_tgt_img_var, rgb_ref_imgs_var)

        # generate flow from camera pose and depth
        flow_fwd, flow_bwd, _ = flow_net(rgb_tgt_img_var, rgb_ref_imgs_var)
        flows_cam_fwd = pose2flow(depth_ref_imgs_var[1].squeeze(1), pose[:, 1],
                                  intrinsics_var, intrinsics_inv_var)
        flows_cam_bwd = pose2flow(depth_ref_imgs_var[0].squeeze(1), pose[:, 0],
                                  intrinsics_var, intrinsics_inv_var)
        rigidity_mask_fwd = (flows_cam_fwd - flow_fwd[0]).abs()
        rigidity_mask_bwd = (flows_cam_bwd - flow_bwd[0]).abs()

        # loss 1: smoothness loss
        loss1 = smooth_loss(explainability_mask) + smooth_loss(
            flow_bwd) + smooth_loss(flow_fwd)

        # loss 2: explainability loss
        loss2 = explainability_loss(explainability_mask)

        # loss 3: consensus loss (between the network's mask and the residual-based mask)
        depth_Res_mask, depth_ref_img_warped, depth_diff = depth_residual_mask(
            valid_pixel_mask, explainability_mask[0], rgb_tgt_img_var,
            rgb_ref_imgs_var, intrinsics_var, intrinsics_inv_var,
            depth_tgt_img_var, pose)
        # print(depth_Res_mask[0].size(), explainability_mask[0].size())

        loss3 = consensus_loss(explainability_mask[0], rigidity_mask_bwd,
                               rigidity_mask_fwd, args.THRESH, args.wbce)

        # loss 4: flow loss
        loss4, flow_ref_img_warped, flow_diff = flow_loss(
            rgb_tgt_img_var, rgb_ref_imgs_var, [flow_bwd, flow_fwd],
            explainability_mask)

        # compute gradient and do Adam step
        loss = w1 * loss1 + w2 * loss2 + w3 * loss3 + w4 * loss4
        average_loss += loss.item()
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # visualization in tensorboard
        if i > 0 and n_iter % args.print_freq == 0:
            train_writer.add_scalar('smoothness loss', loss1.item(), n_iter)
            train_writer.add_scalar('explainability loss', loss2.item(),
                                    n_iter)
            train_writer.add_scalar('consensus loss', loss3.item(), n_iter)
            train_writer.add_scalar('flow loss', loss4.item(), n_iter)
            train_writer.add_scalar('total loss', loss.item(), n_iter)
        if n_iter % (args.training_output_freq) == 0:
            train_writer.add_image('train Input',
                                   tensor2array(rgb_tgt_img_var[0]), n_iter)
            train_writer.add_image(
                'train Exp mask Outputs ',
                tensor2array(explainability_mask[0][0, 0].data.cpu(),
                             max_value=1,
                             colormap='bone'), n_iter)
            train_writer.add_image(
                'train depth Res mask ',
                tensor2array(depth_Res_mask[0][0].data.cpu(),
                             max_value=1,
                             colormap='bone'), n_iter)
            train_writer.add_image(
                'train depth ',
                tensor2array(depth_tgt_img_var[0].data.cpu(),
                             max_value=1,
                             colormap='bone'), n_iter)
            train_writer.add_image(
                'train valid pixel ',
                tensor2array(valid_pixel_mask[0].data.cpu(),
                             max_value=1,
                             colormap='bone'), n_iter)
            train_writer.add_image(
                'train after mask',
                tensor2array(rgb_tgt_img_var[0] *
                             explainability_mask[0][0, 0]), n_iter)
            train_writer.add_image('train depth diff',
                                   tensor2array(depth_diff[0]), n_iter)
            train_writer.add_image('train flow diff',
                                   tensor2array(flow_diff[0]), n_iter)
            train_writer.add_image('train depth warped img',
                                   tensor2array(depth_ref_img_warped[0]),
                                   n_iter)
            train_writer.add_image('train flow warped img',
                                   tensor2array(flow_ref_img_warped[0]),
                                   n_iter)
            train_writer.add_image(
                'train Cam Flow Output',
                flow_to_image(tensor2array(flow_fwd[0].data[0].cpu())), n_iter)
            train_writer.add_image(
                'train Flow from Depth Output',
                flow_to_image(tensor2array(flows_cam_fwd.data[0].cpu())),
                n_iter)
            train_writer.add_image(
                'train Flow and Depth diff',
                flow_to_image(tensor2array(rigidity_mask_fwd.data[0].cpu())),
                n_iter)

        n_iter += 1

    return average_loss / (i + 1)  # average over the number of batches
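
consensus_loss is imported rather than defined here. One plausible sketch of the idea it implements, supervising the explainability mask with a weighted binary cross-entropy against thresholded flow residuals, is shown below; the channel layout and weighting are assumptions, not the repository's code.

import torch.nn.functional as F

def consensus_loss_sketch(exp_mask, res_bwd, res_fwd, thresh, wbce):
    # res_bwd / res_fwd: [B, 2, H, W] absolute flow residuals against the
    # camera flow; pixels below `thresh` become rigid (value 1) targets.
    target_bwd = (res_bwd.mean(dim=1, keepdim=True) < thresh).float()
    target_fwd = (res_fwd.mean(dim=1, keepdim=True) < thresh).float()
    return wbce * (F.binary_cross_entropy(exp_mask[:, 0:1], target_bwd) +
                   F.binary_cross_entropy(exp_mask[:, 1:2], target_fwd))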
Example #5
def test(val_loader, disp_net, mask_net, pose_net, flow_net, tb_writer,
         global_vars_dict=None):
    # data preparation
    device = global_vars_dict['device']
    n_iter_val = global_vars_dict['n_iter_val']
    args = global_vars_dict['args']


    data_time = AverageMeter()


    # switch models to eval mode
    disp_net.eval()
    pose_net.eval()
    mask_net.eval()
    flow_net.eval()

    end = time.time()
    poses = np.zeros(((len(val_loader) - 1) * 1 * (args.sequence_length - 1), 6))  # init

    disp_list = []

    flow_list = []
    mask_list = []

    # 3. validation cycle
    for i, (tgt_img, ref_imgs, intrinsics, intrinsics_inv) in tqdm(enumerate(val_loader)):
        data_time.update(time.time() - end)
        tgt_img = tgt_img.to(device)
        ref_imgs = [img.to(device) for img in ref_imgs]
        intrinsics,intrinsics_inv = intrinsics.to(device),intrinsics_inv.to(device)
        # 3.1 forward pass
        # disp
        disp = disp_net(tgt_img)
        if args.spatial_normalize:
            disp = spatial_normalize(disp)
        depth = 1 / disp

        #pose
        pose = pose_net(tgt_img, ref_imgs)
        # flow: build forward/backward flow for the adjacent frames
        if args.flownet == 'Back2Future':
            flow_fwd, flow_bwd, _ = flow_net(tgt_img, ref_imgs[1:3])
        elif args.flownet == 'FlowNetC6':
            flow_fwd = flow_net(tgt_img, ref_imgs[2])
            flow_bwd = flow_net(tgt_img, ref_imgs[1])
        # flow_fwd: [B,2,H,W]
        # flow_cam: [B,2,H,W], the background (camera-motion) flow
        flow_cam = pose2flow(depth.squeeze(1), pose[:, 2], intrinsics, intrinsics_inv)

        flows_cam_fwd = pose2flow(depth.squeeze(1), pose[:, 2], intrinsics, intrinsics_inv)
        flows_cam_bwd = pose2flow(depth.squeeze(1), pose[:, 1], intrinsics, intrinsics_inv)

        #exp_masks_target = consensus_exp_masks(flows_cam_fwd, flows_cam_bwd, flow_fwd, flow_bwd, tgt_img,
        #                                       ref_imgs[2], ref_imgs[1], wssim=args.wssim, wrig=args.wrig,
        #                                       ws=args.smooth_loss_weight)

        rigidity_mask_fwd = (flows_cam_fwd - flow_fwd).abs()  # [b,2,h,w]
        rigidity_mask_bwd = (flows_cam_bwd - flow_bwd).abs()

        # 4. explainability mask
        explainability_mask = mask_net(tgt_img, ref_imgs)  # valid-region mask

        # list(5):item:tensor:[4,4,128,512]...[4,4,4,16] value:[0.33~0.48~0.63]
        end = time.time()


        # 3.4 logging: inspect the forward-pass results

        # 2. disp
        disp_to_show = tensor2array(disp[0].cpu(), max_value=None, colormap='bone')  # [1,h,w], values roughly 0.5-10
        tb_writer.add_image('Disp/disp0', disp_to_show, i)
        disp_list.append(disp_to_show)

        if i == 0:
            disp_arr = np.expand_dims(disp_to_show, axis=0)
        else:
            disp_to_show = np.expand_dims(disp_to_show, axis=0)
            disp_arr = np.concatenate([disp_arr, disp_to_show], 0)


        # 3. flow
        tb_writer.add_image('Flow/Flow Output', flow2rgb(flow_fwd[0], max_value=6), i)
        tb_writer.add_image('Flow/cam_Flow Output', flow2rgb(flow_cam[0], max_value=6), i)
        tb_writer.add_image('Flow/rigid_Flow Output', flow2rgb(rigidity_mask_fwd[0], max_value=6), i)
        tb_writer.add_image('Flow/rigidity_mask_fwd', flow2rgb(rigidity_mask_fwd[0], max_value=6), i)
        flow_list.append(flow2rgb(flow_fwd[0], max_value=6))
        # 4. mask
        tb_writer.add_image('Mask /mask0', tensor2array(explainability_mask[0][0], max_value=None, colormap='magma'), i)
        #tb_writer.add_image('Mask Output/mask1 sample{}'.format(i),tensor2array(explainability_mask[1][0], max_value=None, colormap='magma'), epoch)
        #tb_writer.add_image('Mask Output/mask2 sample{}'.format(i),tensor2array(explainability_mask[2][0], max_value=None, colormap='magma'), epoch)
        #tb_writer.add_image('Mask Output/mask3 sample{}'.format(i),tensor2array(explainability_mask[3][0], max_value=None, colormap='magma'), epoch)
        mask_list.append(tensor2array(explainability_mask[0][0], max_value=None, colormap='magma'))

    return disp_list,disp_arr,flow_list,mask_list
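
pose2flow is the key primitive in all of these examples: it converts depth plus camera motion into a dense flow field. A conceptual sketch of that projection chain (back-project, transform, re-project), assuming a [B,3,4] rigid-transform matrix in place of the 6-DoF pose vector the real pose2flow takes:

import torch

def pose2flow_sketch(depth, pose_mat, K, K_inv):
    # depth: [B, H, W]; pose_mat: [B, 3, 4] rigid transform [R|t];
    # K, K_inv: [B, 3, 3] camera intrinsics and their inverse.
    b, h, w = depth.shape
    ys, xs = torch.meshgrid(torch.arange(h), torch.arange(w), indexing='ij')
    pix = torch.stack((xs, ys, torch.ones_like(xs))).float()      # [3, H, W]
    pix = pix.view(3, -1).unsqueeze(0).expand(b, 3, h * w)
    cam = (K_inv @ pix) * depth.view(b, 1, -1)                    # back-project
    cam2 = pose_mat[:, :, :3] @ cam + pose_mat[:, :, 3:]          # apply camera motion
    pix2 = K @ cam2
    pix2 = pix2[:, :2] / pix2[:, 2:].clamp(min=1e-6)              # re-project
    return (pix2 - pix[:, :2]).view(b, 2, h, w)                   # displacement = flow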
Example #6
def train(train_loader,
          disp_net,
          pose_net,
          mask_net,
          flow_net,
          optimizer,
          logger=None,
          train_writer=None,
          global_vars_dict=None):
    # 0. setup
    args = global_vars_dict['args']
    n_iter = global_vars_dict['n_iter']
    device = global_vars_dict['device']

    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter(precision=4)
    w1, w2, w3, w4 = args.cam_photo_loss_weight, args.mask_loss_weight, args.smooth_loss_weight, args.flow_photo_loss_weight
    w5 = args.consensus_loss_weight

    if args.robust:
        loss_camera = photometric_reconstruction_loss_robust
        loss_flow = photometric_flow_loss_robust
    else:
        loss_camera = photometric_reconstruction_loss
        loss_flow = photometric_flow_loss


    # 2. switch to train mode
    disp_net.train()
    pose_net.train()
    mask_net.train()
    flow_net.train()

    end = time.time()
    #3. train cycle
    for i, (tgt_img, ref_imgs, intrinsics,
            intrinsics_inv) in enumerate(train_loader):
        # measure data loading time
        data_time.update(time.time() - end)
        tgt_img = tgt_img.to(device)
        ref_imgs = [img.to(device) for img in ref_imgs]
        intrinsics = intrinsics.to(device)
        intrinsics_inv = intrinsics_inv.to(device)

        #3.1 compute output and lossfunc input valve---------------------

        #1. disp->depth(none)
        disparities = disp_net(tgt_img)
        if args.spatial_normalize:
            disparities = [spatial_normalize(disp) for disp in disparities]

        depth = [1 / disp for disp in disparities]

        #2. pose(none)
        pose = pose_net(tgt_img, ref_imgs)
        #pose:[4,4,6]

        # 3. flow_fwd, flow_bwd: full optical flow (from depth and pose)
        # slightly modified from the original
        if args.flownet == 'Back2Future':  # uses three adjacent frames for training/inference
            flow_fwd, flow_bwd, _ = flow_net(tgt_img, ref_imgs[1:3])
        elif args.flownet == 'FlowNetC6':
            flow_fwd = flow_net(tgt_img, ref_imgs[2])
            flow_bwd = flow_net(tgt_img, ref_imgs[1])
        elif args.flownet == 'FlowNetS':
            pass  # FlowNetS is not handled here

        # flow_cam is the background (camera-motion) flow
        # flow - flow_s = flow_o
        flow_cam = pose2flow(
            depth[0].squeeze(), pose[:, 2], intrinsics,
            intrinsics_inv)  # pose[:,2] belongs to forward frame
        flows_cam_fwd = [
            pose2flow(depth_.squeeze(1), pose[:, 2], intrinsics,
                      intrinsics_inv) for depth_ in depth
        ]
        flows_cam_bwd = [
            pose2flow(depth_.squeeze(1), pose[:, 1], intrinsics,
                      intrinsics_inv) for depth_ in depth
        ]

        exp_masks_target = consensus_exp_masks(flows_cam_fwd,
                                               flows_cam_bwd,
                                               flow_fwd,
                                               flow_bwd,
                                               tgt_img,
                                               ref_imgs[2],
                                               ref_imgs[1],
                                               wssim=args.wssim,
                                               wrig=args.wrig,
                                               ws=args.smooth_loss_weight)
        rigidity_mask_fwd = [
            (flows_cam_fwd_i - flow_fwd_i).abs()
            for flows_cam_fwd_i, flow_fwd_i in zip(flows_cam_fwd, flow_fwd)
        ]  # .normalize()
        rigidity_mask_bwd = [
            (flows_cam_bwd_i - flow_bwd_i).abs()
            for flows_cam_bwd_i, flow_bwd_i in zip(flows_cam_bwd, flow_bwd)
        ]  # .normalize()
        #v_u

        # 4. explainability mask
        explainability_mask = mask_net(tgt_img, ref_imgs)  # valid-region mask
        # list of 5 tensors: [4,4,128,512] ... [4,4,4,16], values roughly 0.33-0.63
        #-------------------------------------------------

        if args.joint_mask_for_depth:
            explainability_mask_for_depth = compute_joint_mask_for_depth(
                explainability_mask, rigidity_mask_bwd, rigidity_mask_fwd,
                args.THRESH)
        else:
            explainability_mask_for_depth = explainability_mask
        # explainability_mask_for_depth: list(5) of [b, 2, h_k, w_k] per scale
        if args.no_non_rigid_mask:
            flow_exp_mask = [None for exp_mask in explainability_mask]
            if args.DEBUG:
                print('Using no masks for flow')
        else:
            flow_exp_mask = [
                1 - exp_mask[:, 1:3] for exp_mask in explainability_mask
            ]
            # the explainability mask is originally a background mask (background pixels are 1);
            # invert it into a moving-object mask and keep only the previous/next frames
            # list(4) of [4,2,256,512]

        # 3.2 compute losses

        # E_R minimizes the photometric loss on the static scene
        if w1 > 0:
            loss_1 = loss_camera(tgt_img,
                                 ref_imgs,
                                 intrinsics,
                                 intrinsics_inv,
                                 depth,
                                 explainability_mask_for_depth,
                                 pose,
                                 lambda_oob=args.lambda_oob,
                                 qch=args.qch,
                                 wssim=args.wssim)
        else:
            loss_1 = torch.tensor([0.]).to(device)
        # E_M
        if w2 > 0:
            loss_2 = explainability_loss(
                explainability_mask
            )  #+ 0.2*gaussian_explainability_loss(explainability_mask)
        else:
            loss_2 = 0
        # E_S
        if w3 > 0:
            if args.smoothness_type == "regular":
                loss_3 = smooth_loss(depth) + smooth_loss(
                    flow_fwd) + smooth_loss(flow_bwd) + smooth_loss(
                        explainability_mask)
            elif args.smoothness_type == "edgeaware":
                loss_3 = edge_aware_smoothness_loss(
                    tgt_img, depth) + edge_aware_smoothness_loss(
                        tgt_img, flow_fwd)
                loss_3 += edge_aware_smoothness_loss(
                    tgt_img, flow_bwd) + edge_aware_smoothness_loss(
                        tgt_img, explainability_mask)
        else:
            loss_3 = torch.tensor([0.]).to(device)
        # E_F
        # minimizes photometric loss on moving regions

        if w4 > 0:
            loss_4 = loss_flow(tgt_img,
                               ref_imgs[1:3], [flow_bwd, flow_fwd],
                               flow_exp_mask,
                               lambda_oob=args.lambda_oob,
                               qch=args.qch,
                               wssim=args.wssim)
        else:
            loss_4 = torch.tensor([0.]).to(device)
        # E_C
        # drives the collaboration
        # explainability_mask: list(6) of [4,4,4,16]; rigidity_mask: list(4) of [4,2,128,512]
        if w5 > 0:
            loss_5 = consensus_depth_flow_mask(explainability_mask,
                                               rigidity_mask_bwd,
                                               rigidity_mask_fwd,
                                               exp_masks_target,
                                               exp_masks_target,
                                               THRESH=args.THRESH,
                                               wbce=args.wbce)
        else:
            loss_5 = torch.tensor([0.]).to(device)

        #3.2.6
        loss = w1 * loss_1 + w2 * loss_2 + w3 * loss_3 + w4 * loss_4 + w5 * loss_5
        #end of loss

        #3.3
        # record loss and EPE
        losses.update(loss.item(), args.batch_size)

        # compute gradient and do Adam step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        #3.4 log data

        # add scalar
        if args.scalar_freq > 0 and n_iter % args.scalar_freq == 0:
            train_writer.add_scalar('batch/cam_photometric_error',
                                    loss_1.item(), n_iter)
            if w2 > 0:
                train_writer.add_scalar('batch/explanability_loss',
                                        loss_2.item(), n_iter)
            train_writer.add_scalar('batch/disparity_smoothness_loss',
                                    loss_3.item(), n_iter)
            train_writer.add_scalar('batch/flow_photometric_error',
                                    loss_4.item(), n_iter)
            train_writer.add_scalar('batch/consensus_error', loss_5.item(),
                                    n_iter)
            train_writer.add_scalar('batch/total_loss', loss.item(), n_iter)

        # no image output when training_output_freq is 0
        if args.training_output_freq > 0 and n_iter % args.training_output_freq == 0:

            train_writer.add_image('train Input', tensor2array(tgt_img[0]),
                                   n_iter)
            train_writer.add_image(
                'train Cam Flow Output',
                flow_to_image(tensor2array(flow_cam.data[0].cpu())), n_iter)

            for k, scaled_depth in enumerate(depth):
                train_writer.add_image(
                    'train Dispnet Output Normalized {}'.format(k),
                    tensor2array(disparities[k].data[0].cpu(),
                                 max_value=None,
                                 colormap='bone'), n_iter)
                train_writer.add_image(
                    'train Depth Output {}'.format(k),
                    tensor2array(1 / disparities[k].data[0].cpu(),
                                 max_value=10), n_iter)
                train_writer.add_image(
                    'train Non Rigid Flow Output {}'.format(k),
                    flow_to_image(tensor2array(flow_fwd[k].data[0].cpu())),
                    n_iter)
                train_writer.add_image(
                    'train Target Rigidity {}'.format(k),
                    tensor2array((rigidity_mask_fwd[k] > args.THRESH).type_as(
                        rigidity_mask_fwd[k]).data[0].cpu(),
                                 max_value=1,
                                 colormap='bone'), n_iter)

                b, _, h, w = scaled_depth.size()
                downscale = tgt_img.size(2) / h

                tgt_img_scaled = nn.functional.adaptive_avg_pool2d(
                    tgt_img, (h, w))
                ref_imgs_scaled = [
                    nn.functional.adaptive_avg_pool2d(ref_img, (h, w))
                    for ref_img in ref_imgs
                ]

                intrinsics_scaled = torch.cat(
                    (intrinsics[:, 0:2] / downscale, intrinsics[:, 2:]), dim=1)
                intrinsics_scaled_inv = torch.cat(
                    (intrinsics_inv[:, :, 0:2] * downscale,
                     intrinsics_inv[:, :, 2:]),
                    dim=2)

                train_writer.add_image(
                    'train Non Rigid Warped Image {}'.format(k),
                    tensor2array(
                        flow_warp(ref_imgs_scaled[2],
                                  flow_fwd[k]).data[0].cpu()), n_iter)

                # log warped images along with explainability mask
                for j, ref in enumerate(ref_imgs_scaled):
                    ref_warped = inverse_warp(
                        ref,
                        scaled_depth[:, 0],
                        pose[:, j],
                        intrinsics_scaled,
                        intrinsics_scaled_inv,
                        rotation_mode=args.rotation_mode,
                        padding_mode=args.padding_mode)[0]
                    train_writer.add_image(
                        'train Warped Outputs {} {}'.format(k, j),
                        tensor2array(ref_warped.data.cpu()), n_iter)
                    train_writer.add_image(
                        'train Diff Outputs {} {}'.format(k, j),
                        tensor2array(
                            0.5 *
                            (tgt_img_scaled[0] - ref_warped).abs().data.cpu()),
                        n_iter)
                    if explainability_mask[k] is not None:
                        train_writer.add_image(
                            'train Exp mask Outputs {} {}'.format(k, j),
                            tensor2array(explainability_mask[k][0,
                                                                j].data.cpu(),
                                         max_value=1,
                                         colormap='bone'), n_iter)

        # csv file write
        with open(args.save_path / args.log_full, 'a') as csvfile:
            writer = csv.writer(csvfile, delimiter='\t')
            writer.writerow([
                loss.item(),
                loss_1.item(),
                loss_2.item() if w2 > 0 else 0,
                loss_3.item(),
                loss_4.item()
            ])
        #terminal output
        if args.log_terminal:
            logger.train_bar.update(i + 1)  # progress within the current epoch
            if i % args.print_freq == 0:
                logger.valid_bar_writer.write(
                    'Train: Time {} Data {} Loss {}'.format(
                        batch_time, data_time, losses))

        # 3.4 edge condition: stop after epoch_size batches
        epoch_size = len(train_loader)
        if i >= epoch_size - 1:
            break

        n_iter += 1

    global_vars_dict['n_iter'] = n_iter
    return losses.avg[0]  #epoch loss
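
The intrinsics_scaled block in the loop above rescales K for each pyramid level: when the feature map shrinks by `downscale`, fx, fy, cx, cy shrink with it, and the first two columns of K^-1 grow by the same factor. A small helper capturing just that step:

import torch

def scale_intrinsics(K, K_inv, downscale):
    # K: [B, 3, 3]; fx, fy, cx, cy all live in the first two rows
    K_scaled = torch.cat((K[:, 0:2] / downscale, K[:, 2:]), dim=1)
    # K^-1: the same factor multiplies the first two columns
    K_inv_scaled = torch.cat(
        (K_inv[:, :, 0:2] * downscale, K_inv[:, :, 2:]), dim=2)
    return K_scaled, K_inv_scaled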
Example #7
def main():
    global args
    args = parser.parse_args()
    args.pretrained_path = Path(args.pretrained_path)

    if args.output_dir is not None:
        args.output_dir = Path(args.output_dir)
        args.output_dir.makedirs_p()

        image_dir = args.output_dir / 'images'
        mask_dir = args.output_dir / 'mask'
        viz_dir = args.output_dir / 'viz'
        testing_dir = args.output_dir / 'testing'
        testing_dir_flo = args.output_dir / 'testing_flo'

        image_dir.makedirs_p()
        mask_dir.makedirs_p()
        viz_dir.makedirs_p()
        testing_dir.makedirs_p()
        testing_dir_flo.makedirs_p()

    normalize = custom_transforms.Normalize(mean=[0.5, 0.5, 0.5],
                                            std=[0.5, 0.5, 0.5])
    flow_loader_h, flow_loader_w = 256, 832
    valid_flow_transform = custom_transforms.Compose([
        custom_transforms.Scale(h=flow_loader_h, w=flow_loader_w),
        custom_transforms.ArrayToTensor(), normalize
    ])

    val_flow_set = KITTI2015Test(root=args.kitti_dir,
                                 sequence_length=5,
                                 transform=valid_flow_transform)

    if args.DEBUG:
        print("DEBUG MODE: Using Training Set")
        val_flow_set = KITTI2015Test(root=args.kitti_dir,
                                     sequence_length=5,
                                     transform=valid_flow_transform,
                                     phase='training')

    val_loader = torch.utils.data.DataLoader(val_flow_set,
                                             batch_size=1,
                                             shuffle=False,
                                             num_workers=2,
                                             pin_memory=True,
                                             drop_last=True)

    disp_net = getattr(models, args.dispnet)().cuda()
    pose_net = getattr(models, args.posenet)(nb_ref_imgs=4).cuda()
    mask_net = getattr(models, args.masknet)(nb_ref_imgs=4).cuda()
    flow_net = getattr(models, args.flownet)(nlevels=args.nlevels).cuda()

    dispnet_weights = torch.load(args.pretrained_path /
                                 'dispnet_model_best.pth.tar')
    posenet_weights = torch.load(args.pretrained_path /
                                 'posenet_model_best.pth.tar')
    masknet_weights = torch.load(args.pretrained_path /
                                 'masknet_model_best.pth.tar')
    flownet_weights = torch.load(args.pretrained_path /
                                 'flownet_model_best.pth.tar')
    disp_net.load_state_dict(dispnet_weights['state_dict'])
    pose_net.load_state_dict(posenet_weights['state_dict'])
    flow_net.load_state_dict(flownet_weights['state_dict'])
    mask_net.load_state_dict(masknet_weights['state_dict'])

    disp_net.eval()
    pose_net.eval()
    mask_net.eval()
    flow_net.eval()

    for i, (tgt_img, ref_imgs, intrinsics, intrinsics_inv,
            tgt_img_original) in enumerate(tqdm(val_loader)):
        tgt_img_var = Variable(tgt_img.cuda(), volatile=True)
        ref_imgs_var = [
            Variable(img.cuda(), volatile=True) for img in ref_imgs
        ]
        intrinsics_var = Variable(intrinsics.cuda(), volatile=True)
        intrinsics_inv_var = Variable(intrinsics_inv.cuda(), volatile=True)

        disp = disp_net(tgt_img_var)
        depth = 1 / disp
        pose = pose_net(tgt_img_var, ref_imgs_var)
        explainability_mask = mask_net(tgt_img_var, ref_imgs_var)
        if args.flownet == 'Back2Future':
            flow_fwd, _, _ = flow_net(tgt_img_var, ref_imgs_var[1:3])
        else:
            flow_fwd = flow_net(tgt_img_var, ref_imgs_var[2])
        flow_cam = pose2flow(depth.squeeze(1), pose[:, 2], intrinsics_var,
                             intrinsics_inv_var)

        rigidity_mask = 1 - (1 - explainability_mask[:, 1]) * (
            1 - explainability_mask[:, 2]).unsqueeze(1) > 0.5

        rigidity_mask_census_soft = (flow_cam - flow_fwd).abs()  #.normalize()
        rigidity_mask_census_u = rigidity_mask_census_soft[:, 0] < args.THRESH
        rigidity_mask_census_v = rigidity_mask_census_soft[:, 1] < args.THRESH
        rigidity_mask_census = (rigidity_mask_census_u).type_as(flow_fwd) * (
            rigidity_mask_census_v).type_as(flow_fwd)
        rigidity_mask_combined = 1 - (
            1 - rigidity_mask.type_as(explainability_mask)) * (
                1 - rigidity_mask_census.type_as(explainability_mask))

        _, _, h_pred, w_pred = flow_cam.size()
        _, _, h_gt, w_gt = tgt_img_original.size()
        rigidity_pred_mask = nn.functional.upsample(rigidity_mask_combined,
                                                    size=(h_pred, w_pred),
                                                    mode='bilinear')

        non_rigid_pred = (rigidity_pred_mask <= args.THRESH
                          ).type_as(flow_fwd).expand_as(flow_fwd) * flow_fwd
        rigid_pred = (rigidity_pred_mask > args.THRESH
                      ).type_as(flow_cam).expand_as(flow_cam) * flow_cam
        total_pred = non_rigid_pred + rigid_pred

        pred_fullres = nn.functional.upsample(total_pred,
                                              size=(h_gt, w_gt),
                                              mode='bilinear')
        pred_fullres[:, 0, :, :] = pred_fullres[:, 0, :, :] * (w_gt / w_pred)
        pred_fullres[:, 1, :, :] = pred_fullres[:, 1, :, :] * (h_gt / h_pred)

        flow_fwd_fullres = nn.functional.upsample(flow_fwd,
                                                  size=(h_gt, w_gt),
                                                  mode='bilinear')
        flow_fwd_fullres[:, 0, :, :] = flow_fwd_fullres[:, 0, :, :] * (w_gt / w_pred)
        flow_fwd_fullres[:, 1, :, :] = flow_fwd_fullres[:, 1, :, :] * (h_gt / h_pred)

        flow_cam_fullres = nn.functional.upsample(flow_cam,
                                                  size=(h_gt, w_gt),
                                                  mode='bilinear')
        flow_cam_fullres[:, 0, :, :] = flow_cam_fullres[:, 0, :, :] * (w_gt / w_pred)
        flow_cam_fullres[:, 1, :, :] = flow_cam_fullres[:, 1, :, :] * (h_gt / h_pred)

        tgt_img_np = tgt_img[0].numpy()
        rigidity_mask_combined_np = rigidity_mask_combined.cpu().data[0].numpy()

        if args.output_dir is not None:
            np.save(image_dir / str(i).zfill(3), tgt_img_np)
            np.save(mask_dir / str(i).zfill(3), rigidity_mask_combined_np)
            pred_u = pred_fullres[0][0].data.cpu().numpy()
            pred_v = pred_fullres[0][1].data.cpu().numpy()
            flow_io.flow_write_png(testing_dir / str(i).zfill(6) + '_10.png',
                                   u=pred_u,
                                   v=pred_v)
            flow_io.flow_write(testing_dir_flo / str(i).zfill(6) + '_10.flo',
                               pred_u, pred_v)

        if (args.output_dir is not None):
            ind = int(i)
            tgt_img_viz = tensor2array(tgt_img[0].cpu())
            depth_viz = tensor2array(disp.data[0].cpu(),
                                     max_value=None,
                                     colormap='magma')
            mask_viz = tensor2array(rigidity_mask_combined.data[0].cpu(),
                                    max_value=1,
                                    colormap='magma')
            row2_viz = flow_to_image(
                np.hstack((tensor2array(flow_cam_fullres.data[0].cpu()),
                           tensor2array(flow_fwd_fullres.data[0].cpu()),
                           tensor2array(pred_fullres.data[0].cpu()))))

            row1_viz = np.hstack((tgt_img_viz, depth_viz, mask_viz))

            row1_viz_im = Image.fromarray(
                (255 * row1_viz.transpose(1, 2, 0)).astype('uint8'))
            row2_viz_im = Image.fromarray(
                (255 * row2_viz.transpose(1, 2, 0)).astype('uint8'))

            row1_viz_im.save(viz_dir / str(i).zfill(3) + '01.png')
            row2_viz_im.save(viz_dir / str(i).zfill(3) + '02.png')

    print("Done!")
Example #8
def validate_without_gt(val_loader,
                        disp_net,
                        pose_net,
                        mask_net,
                        flow_net,
                        epoch,
                        logger,
                        tb_writer,
                        nb_writers,
                        global_vars_dict=None):
    #data prepared
    device = global_vars_dict['device']
    n_iter_val = global_vars_dict['n_iter_val']
    args = global_vars_dict['args']
    # convert the fractional positions in args.show_samples to batch indices
    show_samples = [int(s * len(val_loader)) for s in args.show_samples]

    batch_time = AverageMeter()
    data_time = AverageMeter()
    log_outputs = nb_writers > 0
    losses = AverageMeter(precision=4)

    w1, w2, w3, w4 = args.cam_photo_loss_weight, args.mask_loss_weight, args.smooth_loss_weight, args.flow_photo_loss_weight
    w5 = args.consensus_loss_weight

    loss_camera = photometric_reconstruction_loss
    loss_flow = photometric_flow_loss

    # to eval model
    disp_net.eval()
    pose_net.eval()
    mask_net.eval()
    flow_net.eval()

    end = time.time()
    poses = np.zeros(
        ((len(val_loader) - 1) * 1 * (args.sequence_length - 1), 6))  #init

    #3. validation cycle
    for i, (tgt_img, ref_imgs, intrinsics,
            intrinsics_inv) in enumerate(val_loader):
        data_time.update(time.time() - end)
        tgt_img = tgt_img.to(device)
        ref_imgs = [img.to(device) for img in ref_imgs]
        intrinsics, intrinsics_inv = intrinsics.to(device), intrinsics_inv.to(
            device)
        #3.1 forwardpass
        #disp
        disp = disp_net(tgt_img)
        if args.spatial_normalize:
            disp = spatial_normalize(disp)
        depth = 1 / disp

        #pose
        pose = pose_net(tgt_img, ref_imgs)  #[b,3,h,w]; list

        # flow: build forward/backward flow for the adjacent frames
        if args.flownet == 'Back2Future':
            flow_fwd, flow_bwd, _ = flow_net(tgt_img, ref_imgs[1:3])
        elif args.flownet == 'FlowNetC6':
            flow_fwd = flow_net(tgt_img, ref_imgs[2])
            flow_bwd = flow_net(tgt_img, ref_imgs[1])
        flow_cam = pose2flow(depth.squeeze(1), pose[:, 2], intrinsics,
                             intrinsics_inv)

        flows_cam_fwd = pose2flow(depth.squeeze(1), pose[:, 2], intrinsics,
                                  intrinsics_inv)
        flows_cam_bwd = pose2flow(depth.squeeze(1), pose[:, 1], intrinsics,
                                  intrinsics_inv)

        exp_masks_target = consensus_exp_masks(flows_cam_fwd,
                                               flows_cam_bwd,
                                               flow_fwd,
                                               flow_bwd,
                                               tgt_img,
                                               ref_imgs[2],
                                               ref_imgs[1],
                                               wssim=args.wssim,
                                               wrig=args.wrig,
                                               ws=args.smooth_loss_weight)
        no_rigid_flow = flow_fwd - flows_cam_fwd

        rigidity_mask_fwd = (flows_cam_fwd - flow_fwd).abs()  #[b,2,h,w]
        rigidity_mask_bwd = (flows_cam_bwd - flow_bwd).abs()

        # 4. explainability mask
        explainability_mask = mask_net(tgt_img, ref_imgs)  # valid-region mask

        # list(5):item:tensor:[4,4,128,512]...[4,4,4,16] value:[0.33~0.48~0.63]

        if args.joint_mask_for_depth:  # both branches are identical here;
            # the joint-mask computation is left commented out:
            # explainability_mask_for_depth = compute_joint_mask_for_depth(
            #     explainability_mask, rigidity_mask_bwd, rigidity_mask_fwd, THRESH=args.THRESH)
            explainability_mask_for_depth = explainability_mask
        else:
            explainability_mask_for_depth = explainability_mask

        # changed

        if args.no_non_rigid_mask:
            flow_exp_mask = None
            if args.DEBUG:
                print('Using no masks for flow')
        else:
            flow_exp_mask = 1 - explainability_mask[:, 1:3]
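        # Channels 1:3 of the mask select the two adjacent reference frames
        # (previous, next), matching the ref_imgs[1:3] pair fed to the flow
        # photometric loss below.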

        # 3.2 loss computation
        if w1 > 0:
            loss_1 = loss_camera(tgt_img,
                                 ref_imgs,
                                 intrinsics,
                                 intrinsics_inv,
                                 depth,
                                 explainability_mask_for_depth,
                                 pose,
                                 lambda_oob=args.lambda_oob,
                                 qch=args.qch,
                                 wssim=args.wssim)
        else:
            loss_1 = torch.tensor([0.]).to(device)

        # E_M: explainability regularization
        if w2 > 0:
            loss_2 = explainability_loss(
                explainability_mask
            )  # + 0.2*gaussian_explainability_loss(explainability_mask)
        else:
            loss_2 = torch.tensor([0.]).to(device)

        #if args.smoothness_type == "regular":
        if w3 > 0:
            loss_3 = smooth_loss(depth) + smooth_loss(
                explainability_mask) + smooth_loss(flow_fwd) + smooth_loss(
                    flow_bwd)
        else:
            loss_3 = torch.tensor([0.]).to(device)
        if w4 > 0:
            loss_4 = loss_flow(tgt_img,
                               ref_imgs[1:3], [flow_bwd, flow_fwd],
                               flow_exp_mask,
                               lambda_oob=args.lambda_oob,
                               qch=args.qch,
                               wssim=args.wssim)
        else:
            loss_4 = torch.tensor([0.]).to(device)
        if w5 > 0:
            loss_5 = consensus_depth_flow_mask(explainability_mask,
                                               rigidity_mask_bwd,
                                               rigidity_mask_fwd,
                                               exp_masks_target,
                                               exp_masks_target,
                                               THRESH=args.THRESH,
                                               wbce=args.wbce)
        else:
            loss_5 = torch.tensor([0.]).to(device)

        loss = w1 * loss_1 + w2 * loss_2 + w3 * loss_3 + w4 * loss_4 + w5 * loss_5
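        # Total objective: w1*E_R (camera photometric) + w2*E_M (explainability)
        # + w3*E_S (smoothness) + w4*E_F (flow photometric) + w5*E_C (consensus),
        # matching the val/E_* scalar tags logged below.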

        # 3.3 meter updates
        losses.update(loss.item(), args.batch_size)
        batch_time.update(time.time() - end)
        end = time.time()

        # 3.4 logging

        # Inspect the forward-pass results for selected samples.
        if args.img_freq > 0 and i in show_samples:
            if epoch == 0:  # pre-training validation, to get an untrained baseline
                # 1. input images (logged only once). ref_imgs is a list over
                # adjacent frames; axis 0 of each tensor indexes the batch.
                tb_writer.add_image(
                    'epoch 0 Input/sample{}(img{} to img{})'.format(
                        i, i + 1, i + args.sequence_length),
                    tensor2array(ref_imgs[0][0]), 0)
                tb_writer.add_image(
                    'epoch 0 Input/sample{}(img{} to img{})'.format(
                        i, i + 1, i + args.sequence_length),
                    tensor2array(ref_imgs[1][0]), 1)
                tb_writer.add_image(
                    'epoch 0 Input/sample{}(img{} to img{})'.format(
                        i, i + 1, i + args.sequence_length),
                    tensor2array(tgt_img[0]), 2)
                tb_writer.add_image(
                    'epoch 0 Input/sample{}(img{} to img{})'.format(
                        i, i + 1, i + args.sequence_length),
                    tensor2array(ref_imgs[2][0]), 3)
                tb_writer.add_image(
                    'epoch 0 Input/sample{}(img{} to img{})'.format(
                        i, i + 1, i + args.sequence_length),
                    tensor2array(ref_imgs[3][0]), 4)

                # Log the initial disparity under the same tag used in later
                # epochs (the original logged depth here, which mixed units
                # within one TensorBoard tag).
                disp_to_show = disp[0].cpu()  # [1, h, w]
                tb_writer.add_image(
                    'Disp Output/sample{}'.format(i),
                    tensor2array(disp_to_show,
                                 max_value=None,
                                 colormap='bone'), 0)

            else:
                # 2. disparity
                disp_to_show = disp[0].cpu()  # [1, h, w]
                tb_writer.add_image(
                    'Disp Output/sample{}'.format(i),
                    tensor2array(disp_to_show,
                                 max_value=None,
                                 colormap='bone'), epoch)
                # 3. flow visualizations
                tb_writer.add_image('Flow/Flow Output sample {}'.format(i),
                                    flow2rgb(flow_fwd[0], max_value=6), epoch)
                tb_writer.add_image('Flow/cam_Flow Output sample {}'.format(i),
                                    flow2rgb(flow_cam[0], max_value=6), epoch)
                tb_writer.add_image(
                    'Flow/no rigid flow Output sample {}'.format(i),
                    flow2rgb(no_rigid_flow[0], max_value=6), epoch)
                tb_writer.add_image(
                    'Flow/rigidity_mask_fwd{}'.format(i),
                    flow2rgb(rigidity_mask_fwd[0], max_value=6), epoch)

                # 4. mask outputs
                tb_writer.add_image(
                    'Mask Output/mask0 sample{}'.format(i),
                    tensor2array(explainability_mask[0][0],
                                 max_value=None,
                                 colormap='magma'), epoch)
                #tb_writer.add_image('Mask Output/mask1 sample{}'.format(i),tensor2array(explainability_mask[1][0], max_value=None, colormap='magma'), epoch)
                #tb_writer.add_image('Mask Output/mask2 sample{}'.format(i),tensor2array(explainability_mask[2][0], max_value=None, colormap='magma'), epoch)
                #tb_writer.add_image('Mask Output/mask3 sample{}'.format(i),tensor2array(explainability_mask[3][0], max_value=None, colormap='magma'), epoch)
                tb_writer.add_image(
                    'Mask Output/exp_masks_target sample{}'.format(i),
                    tensor2array(exp_masks_target[0][0],
                                 max_value=None,
                                 colormap='magma'), epoch)

        # scalar logging
        if args.scalar_freq > 0 and n_iter_val % args.scalar_freq == 0:
            tb_writer.add_scalar('val/E_R', loss_1.item(), n_iter_val)
            if w2 > 0:
                tb_writer.add_scalar('val/E_M', loss_2.item(), n_iter_val)
            tb_writer.add_scalar('val/E_S', loss_3.item(), n_iter_val)
            tb_writer.add_scalar('val/E_F', loss_4.item(), n_iter_val)
            tb_writer.add_scalar('val/E_C', loss_5.item(), n_iter_val)
            tb_writer.add_scalar('val/total_loss', loss.item(), n_iter_val)

        # terminal output
        if args.log_terminal:
            logger.valid_bar.update(i + 1)  # progress within the current epoch
            if i % args.print_freq == 0:
                logger.valid_bar_writer.write(
                    'Valid: Time {} Data {} Loss {}'.format(
                        batch_time, data_time, losses))

        n_iter_val += 1

    global_vars_dict['n_iter_val'] = n_iter_val
    return losses.avg[0]  # average validation loss for this epoch
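
# The AverageMeter used throughout these listings is defined elsewhere in the
# repo. A minimal sketch consistent with its usage here (a vector of `i`
# meters, a display `precision`, `.update(val, n)`, and an indexable `.avg`)
# might look like the following -- an illustrative assumption, not the repo's
# exact class:
class AverageMeterSketch:
    """Tracks running averages for one or more scalar metrics."""

    def __init__(self, i=1, precision=3):
        self.n_meters = i
        self.precision = precision
        self.reset()

    def reset(self):
        self.val = [0.0] * self.n_meters
        self.sum = [0.0] * self.n_meters
        self.avg = [0.0] * self.n_meters
        self.count = 0

    def update(self, val, n=1):
        # Accept a scalar for a single meter or a list for several meters.
        if not isinstance(val, (list, tuple)):
            val = [val]
        self.count += n
        for k, v in enumerate(val):
            self.val[k] = v
            self.sum[k] += v * n
            self.avg[k] = self.sum[k] / self.count

    def __repr__(self):
        return ' '.join('{:.{p}f} ({:.{p}f})'.format(v, a, p=self.precision)
                        for v, a in zip(self.val, self.avg))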
Beispiel #9
0
@torch.no_grad()  # replaces the deprecated Variable(..., volatile=True) pattern below
def validate_flow_with_gt(val_loader,
                          disp_net,
                          pose_net,
                          mask_net,
                          flow_net,
                          epoch,
                          logger,
                          output_writers=[]):
    global args
    batch_time = AverageMeter()
    error_names = [
        'epe_total', 'epe_rigid', 'epe_non_rigid', 'outliers',
        'epe_total_with_gt_mask', 'epe_rigid_with_gt_mask',
        'epe_non_rigid_with_gt_mask', 'outliers_gt_mask'
    ]
    errors = AverageMeter(i=len(error_names))
    log_outputs = len(output_writers) > 0

    # switch to evaluate mode
    disp_net.eval()
    pose_net.eval()
    mask_net.eval()
    flow_net.eval()

    end = time.time()

    poses = np.zeros(
        ((len(val_loader) - 1) * 1 * (args.sequence_length - 1), 6))

    for i, (tgt_img, ref_imgs, intrinsics, intrinsics_inv, flow_gt,
            obj_map_gt) in enumerate(val_loader):
        # Move inputs to the GPU; gradient tracking is disabled by the
        # torch.no_grad decorator (volatile=True was removed in PyTorch 0.4).
        tgt_img = tgt_img.cuda()
        ref_imgs = [img.cuda() for img in ref_imgs]
        intrinsics_var = intrinsics.cuda()
        intrinsics_inv_var = intrinsics_inv.cuda()

        flow_gt_var = flow_gt.cuda()
        obj_map_gt_var = obj_map_gt.cuda()

        # compute output-------------------------

        #1. disp fwd
        disp = disp_net(tgt_img)
        if args.spatial_normalize:
            disp = spatial_normalize(disp)

        depth = 1 / disp

        #2. pose fwd
        pose = pose_net(tgt_img, ref_imgs)

        #3. mask fwd
        explainability_mask = mask_net(tgt_img, ref_imgs)

        #4. flow fwd
        if args.flownet == 'Back2Future':
            flow_fwd, flow_bwd, _ = flow_net(tgt_img, ref_imgs[1:3])  # previous and next frames
        elif args.flownet == 'FlowNetC6':
            flow_fwd = flow_net(tgt_img, ref_imgs[2])
            flow_bwd = flow_net(tgt_img, ref_imgs[1])
        # compute output-------------------------

        if args.DEBUG:
            flow_fwd_x = flow_fwd[:, 0].view(-1).abs().data
            print("Flow Fwd Median: ", flow_fwd_x.median())
            flow_gt_var_x = flow_gt_var[:, 0].view(-1).abs().data
            print(
                "Flow GT Median: ",
                flow_gt_var_x.index_select(
                    0,
                    flow_gt_var_x.nonzero().view(-1)).median())

        flow_cam = pose2flow(depth.squeeze(1), pose[:, 2], intrinsics_var,
                             intrinsics_inv_var)
        oob_rigid = flow2oob(flow_cam)
        oob_non_rigid = flow2oob(flow_fwd)

        # A pixel is rigid if either adjacent-frame explainability is high.
        rigidity_mask = (1 - (1 - explainability_mask[:, 1]) *
                         (1 - explainability_mask[:, 2])).unsqueeze(1) > 0.5

        # Census-style rigidity: the estimated flow agrees with the
        # camera-motion flow in both components up to THRESH.
        rigidity_mask_census_soft = (flow_cam - flow_fwd).abs()
        rigidity_mask_census_u = rigidity_mask_census_soft[:, 0] < args.THRESH
        rigidity_mask_census_v = rigidity_mask_census_soft[:, 1] < args.THRESH
        rigidity_mask_census = (rigidity_mask_census_u.type_as(flow_fwd) *
                                rigidity_mask_census_v.type_as(flow_fwd)
                                ).unsqueeze(1)

        # Union of the two rigidity estimates, shape [b, 1, h, w].
        rigidity_mask_combined = 1 - (
            1 - rigidity_mask.type_as(explainability_mask)) * (
                1 - rigidity_mask_census.type_as(explainability_mask))

        # compose the full flow: camera flow in rigid regions, estimated flow elsewhere
        flow_fwd_non_rigid = (rigidity_mask_combined <= args.THRESH).type_as(
            flow_fwd).expand_as(flow_fwd) * flow_fwd
        flow_fwd_rigid = (rigidity_mask_combined > args.THRESH
                          ).type_as(flow_fwd).expand_as(flow_fwd) * flow_cam
        total_flow = flow_fwd_rigid + flow_fwd_non_rigid
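        # Equivalently: total_flow = M * flow_cam + (1 - M) * flow_fwd, where
        # M = (rigidity_mask_combined > THRESH) is the binary rigidity mask.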

        obj_map_gt_var_expanded = obj_map_gt_var.unsqueeze(1).type_as(flow_fwd)

        if log_outputs and i % 10 == 0 and i // 10 < len(output_writers):
            index = i // 10
            if epoch == 0:
                output_writers[index].add_image('val flow Input',
                                                tensor2array(tgt_img[0]), 0)
                flow_to_show = flow_gt[0][:2, :, :].cpu()
                output_writers[index].add_image(
                    'val target Flow',
                    flow_to_image(tensor2array(flow_to_show)), epoch)

            output_writers[index].add_image(
                'val Total Flow Output',
                flow_to_image(tensor2array(total_flow.data[0].cpu())), epoch)
            output_writers[index].add_image(
                'val Rigid Flow Output',
                flow_to_image(tensor2array(flow_fwd_rigid.data[0].cpu())),
                epoch)
            output_writers[index].add_image(
                'val Non-rigid Flow Output',
                flow_to_image(tensor2array(flow_fwd_non_rigid.data[0].cpu())),
                epoch)
            output_writers[index].add_image(
                'val Out of Bound (Rigid)',
                tensor2array(oob_rigid.type(torch.FloatTensor).data[0].cpu(),
                             max_value=1,
                             colormap='bone'), epoch)
            output_writers[index].add_scalar(
                'val Mean oob (Rigid)',
                oob_rigid.type(torch.FloatTensor).sum(), epoch)
            output_writers[index].add_image(
                'val Out of Bound (Non-Rigid)',
                tensor2array(oob_non_rigid.type(
                    torch.FloatTensor).data[0].cpu(),
                             max_value=1,
                             colormap='bone'), epoch)
            output_writers[index].add_scalar(
                'val Mean oob (Non-Rigid)',
                oob_non_rigid.type(torch.FloatTensor).sum(), epoch)
            output_writers[index].add_image(
                'val Cam Flow Errors',
                tensor2array(flow_diff(flow_gt_var, flow_cam).data[0].cpu()),
                epoch)
            output_writers[index].add_image(
                'val Rigidity Mask',
                tensor2array(rigidity_mask.data[0].cpu(),
                             max_value=1,
                             colormap='bone'), epoch)
            output_writers[index].add_image(
                'val Rigidity Mask Census',
                tensor2array(rigidity_mask_census.data[0].cpu(),
                             max_value=1,
                             colormap='bone'), epoch)

            for j, ref in enumerate(ref_imgs):
                ref_warped = inverse_warp(ref[:1],
                                          depth[:1, 0],
                                          pose[:1, j],
                                          intrinsics_var[:1],
                                          intrinsics_inv_var[:1],
                                          rotation_mode=args.rotation_mode,
                                          padding_mode=args.padding_mode)[0]

                output_writers[index].add_image(
                    'val Warped Outputs {}'.format(j),
                    tensor2array(ref_warped.data.cpu()), epoch)
                output_writers[index].add_image(
                    'val Diff Outputs {}'.format(j),
                    tensor2array(0.5 *
                                 (tgt_img[0] - ref_warped).abs().data.cpu()),
                    epoch)
                if explainability_mask is not None:
                    output_writers[index].add_image(
                        'val Exp mask Outputs {}'.format(j),
                        tensor2array(explainability_mask[0, j].data.cpu(),
                                     max_value=1,
                                     colormap='bone'), epoch)

            if args.DEBUG:
                # Check that pose2flow is consistent with inverse_warp
                ref_warped_from_depth = inverse_warp(
                    ref_imgs[2][:1],
                    depth[:1, 0],
                    pose[:1, 2],
                    intrinsics_var[:1],
                    intrinsics_inv_var[:1],
                    rotation_mode=args.rotation_mode,
                    padding_mode=args.padding_mode)[0]
                ref_warped_from_cam_flow = flow_warp(ref_imgs[2][:1],
                                                     flow_cam)[0]
                print(
                    "DEBUG_INFO: Inverse_warp vs pose2flow",
                    torch.mean(
                        torch.abs(ref_warped_from_depth -
                                  ref_warped_from_cam_flow)).item())
                output_writers[index].add_image(
                    'val Warped Outputs from Cam Flow',
                    tensor2array(ref_warped_from_cam_flow.data.cpu()), epoch)
                output_writers[index].add_image(
                    'val Warped Outputs from inverse warp',
                    tensor2array(ref_warped_from_depth.data.cpu()), epoch)

        if log_outputs and i < len(val_loader) - 1:
            step = args.sequence_length - 1
            poses[i * step:(i + 1) * step] = pose.data.cpu().view(-1,
                                                                  6).numpy()
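        # Each iteration contributes sequence_length - 1 pose vectors (one per
        # reference image), stored row-wise for the pose histograms below.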

        if np.isnan(flow_gt.sum().item()) or np.isnan(
                total_flow.data.sum().item()):
            print('NaN encountered')
        _epe_errors = compute_all_epes(
            flow_gt_var, flow_cam,
            flow_fwd, rigidity_mask_combined) + compute_all_epes(
                flow_gt_var, flow_cam, flow_fwd, (1 - obj_map_gt_var_expanded))
        errors.update(_epe_errors)

        if args.DEBUG:
            print("DEBUG_INFO: EPE errors: ", _epe_errors)
        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

    if log_outputs:
        output_writers[0].add_histogram('val poses_tx', poses[:, 0], epoch)
        output_writers[0].add_histogram('val poses_ty', poses[:, 1], epoch)
        output_writers[0].add_histogram('val poses_tz', poses[:, 2], epoch)
        if args.rotation_mode == 'euler':
            rot_coeffs = ['rx', 'ry', 'rz']
        elif args.rotation_mode == 'quat':
            rot_coeffs = ['qx', 'qy', 'qz']
        output_writers[0].add_histogram('val poses_{}'.format(rot_coeffs[0]),
                                        poses[:, 3], epoch)
        output_writers[0].add_histogram('val poses_{}'.format(rot_coeffs[1]),
                                        poses[:, 4], epoch)
        output_writers[0].add_histogram('val poses_{}'.format(rot_coeffs[2]),
                                        poses[:, 5], epoch)

    if args.DEBUG:
        print("DEBUG_INFO =================>")
        print("DEBUG_INFO: Average EPE : ", errors.avg)
        print("DEBUG_INFO =================>")
        print("DEBUG_INFO =================>")
        print("DEBUG_INFO =================>")

    return errors.avg, error_names
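
# compute_all_epes is defined elsewhere in the repo. For reference, the core
# end-point-error (EPE) metric it aggregates can be sketched as follows -- an
# illustrative implementation, assuming channel-first flows [b, 2, h, w] and a
# ground-truth validity mask stored as the third channel of flow_gt, as in
# KITTI:
import torch


def epe_sketch(flow_gt, flow_pred):
    """Mean end-point error over valid ground-truth pixels."""
    # flow_gt: [b, 3, h, w] = (u, v, valid); flow_pred: [b, 2, h, w]
    valid = flow_gt[:, 2] > 0
    epe_map = (flow_gt[:, :2] - flow_pred).norm(p=2, dim=1)  # per-pixel EPE
    return epe_map[valid].mean()

# e.g. epe_sketch(flow_gt_var, total_flow) roughly corresponds to the
# un-masked 'epe_total' entry above.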