@torch.no_grad()
def main():
    global best_error, worst_error
    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    args = parser.parse_args()
    if args.gt_type == 'KITTI':
        from kitti_eval.depth_evaluation_utils import test_framework_KITTI as test_framework
    elif args.gt_type == 'stillbox':
        from stillbox_eval.depth_evaluation_utils import test_framework_stillbox as test_framework

    weights = torch.load(args.pretrained_depthnet)
    depth_net = DepthNet(depth_activation="elu", batch_norm=('bn' in weights and weights['bn'])).to(device)

    depth_net.load_state_dict(weights['state_dict'])
    depth_net.eval()

    if args.pretrained_posenet is None:
        args.stabilize_from_GT = True
        print('no PoseNet specified, stabilization will be done from ground truth')
        seq_length = 5
    else:
        weights = torch.load(args.pretrained_posenet)
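        # each of the seq_length input frames contributes 3 channels to the first conv layer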
        seq_length = int(weights['state_dict']['conv1.0.weight'].size(1)/3)
        pose_net = PoseNet(seq_length=seq_length).to(device)
        pose_net.load_state_dict(weights['state_dict'], strict=False)

    dataset_dir = Path(args.dataset_dir)
    if args.dataset_list is not None:
        with open(args.dataset_list, 'r') as f:
            test_files = list(f.read().splitlines())
    else:
        test_files = [file.relpathto(dataset_dir) for file in sum([dataset_dir.files('*.{}'.format(ext)) for ext in args.img_exts], [])]

    framework = test_framework(dataset_dir, test_files, seq_length, args.min_depth, args.max_depth)

    print('{} files to test'.format(len(test_files)))
    errors = np.zeros((7, len(test_files)), np.float32)

    args.output_dir = Path(args.output_dir)
    args.output_dir.makedirs_p()

    for j, sample in enumerate(tqdm(framework)):
        imgs = sample['imgs']
        intrinsics = sample['intrinsics'].copy()

        h,w,_ = imgs[0].shape
        if (not args.no_resize) and (h != args.img_height or w != args.img_width):
            imgs = [imresize(img, (args.img_height, args.img_width)).astype(np.float32) for img in imgs]
            intrinsics[0] *= args.img_width/w
            intrinsics[1] *= args.img_height/h

        intrinsics_inv = np.linalg.inv(intrinsics)

        intrinsics = torch.from_numpy(intrinsics).unsqueeze(0).to(device)
        intrinsics_inv = torch.from_numpy(intrinsics_inv).unsqueeze(0).to(device)
        imgs = [torch.from_numpy(np.transpose(img, (2,0,1))) for img in imgs]
        imgs = torch.stack(imgs).unsqueeze(0).to(device)
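        # rescale pixel values from [0, 255] to [-1, 1]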
        imgs = 2*(imgs/255 - 0.5)

        tgt_img = imgs[:,sample['tgt_index']]

        # Construct a batch of all possible stabilized pairs (from PoseNet poses or
        # from GT orientation); we will keep the output closest to the target mean depth
        if args.stabilize_from_GT:
            poses_GT = torch.from_numpy(sample['poses']).to(device).unsqueeze(0)
            inv_poses_GT = invert_mat(poses_GT)
            tgt_pose = inv_poses_GT[:,sample['tgt_index']]
            inv_transform_matrices_tgt = compensate_pose(inv_poses_GT, tgt_pose)
        else:
            poses = pose_net(imgs)
            inv_transform_matrices = pose_vec2mat(poses, rotation_mode=args.rotation_mode)

            tgt_pose = inv_transform_matrices[:,sample['tgt_index']]
            inv_transform_matrices_tgt = compensate_pose(inv_transform_matrices, tgt_pose)

        stabilized_pairs = []
        corresponding_displ = []
        for i in range(seq_length):
            if i == sample['tgt_index']:
                continue
            img = imgs[:,i]
            img_pose = inv_transform_matrices_tgt[:,i]
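            # rotationally align the reference view with the target frame, so that
            # only a translation remains between the two images of the pair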
            stab_img = inverse_rotate(img, img_pose[:,:,:3], intrinsics, intrinsics_inv)
            pair = torch.cat([stab_img, tgt_img], dim=1)  # [1, 6, H, W]
            stabilized_pairs.append(pair)

            GT_translations = sample['poses'][:,:,-1]
            real_displacement = np.linalg.norm(GT_translations[sample['tgt_index']] - GT_translations[i])
            corresponding_displ.append(real_displacement)
        stab_batch = torch.cat(stabilized_pairs)  # [seq, 6, H, W]
        depth_maps = depth_net(stab_batch)  # [seq, 1 , H/4, W/4]

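        # keep the depth map whose mean is closest to DepthNet's nominal mean output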
        selected_depth, selected_index = select_best_map(depth_maps, target_mean_depthnet_output)

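        # DepthNet is trained for a nominal displacement; rescale its output by the
        # actual displacement of the selected pair to get metric depth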
        pred_depth = selected_depth.cpu().data.numpy() * corresponding_displ[selected_index] / args.nominal_displacement

        if args.save_output:
            if j == 0:
                predictions = np.zeros((len(test_files), *pred_depth.shape))
            predictions[j] = 1/pred_depth

        gt_depth = sample['gt_depth']
        pred_depth_zoomed = zoom(pred_depth,
                                 (gt_depth.shape[0]/pred_depth.shape[0],
                                  gt_depth.shape[1]/pred_depth.shape[1])
                                 ).clip(args.min_depth, args.max_depth)
        if sample['mask'] is not None:
            pred_depth_zoomed_masked = pred_depth_zoomed[sample['mask']]
            gt_depth = gt_depth[sample['mask']]
        else:
            pred_depth_zoomed_masked = pred_depth_zoomed
        errors[:,j] = compute_errors(gt_depth, pred_depth_zoomed_masked)
        if args.log_best_worst:
            if best_error > errors[0,j]:
                best_error = errors[0,j]
                log_result(pred_depth_zoomed, sample['gt_depth'], stab_batch, selected_index, args.output_dir, 'best')
            if worst_error < errors[0,j]:
                worst_error = errors[0,j]
                log_result(pred_depth_zoomed, sample['gt_depth'], stab_batch, selected_index, args.output_dir, 'worst')

    mean_errors = errors.mean(1)
    error_names = ['abs_rel','sq_rel','rms','log_rms','a1','a2','a3']

    print("Results : ")
    print("{:>10}, {:>10}, {:>10}, {:>10}, {:>10}, {:>10}, {:>10}".format(*error_names))
    print("{:10.4f}, {:10.4f}, {:10.4f}, {:10.4f}, {:10.4f}, {:10.4f}, {:10.4f}".format(*mean_errors))

    if args.save_output:
        np.save(args.output_dir/'predictions.npy', predictions)
@torch.no_grad()
def main():
    args = parser.parse_args()
    if args.gt_type == 'KITTI':
        from kitti_eval.depth_evaluation_utils import test_framework_KITTI as test_framework
    elif args.gt_type == 'stillbox':
        from stillbox_eval.depth_evaluation_utils import test_framework_stillbox as test_framework

    disp_net = DispNetS().to(device)
    weights = torch.load(args.pretrained_dispnet)
    disp_net.load_state_dict(weights['state_dict'])
    disp_net.eval()

    if args.pretrained_posenet is None:
        print('no PoseNet specified, scale_factor will be determined by median ratio, '
              'which is kiiinda cheating (but consistent with original paper)')
        seq_length = 0
    else:
        weights = torch.load(args.pretrained_posenet)
        seq_length = int(weights['state_dict']['conv1.0.weight'].size(1)/3)
        pose_net = PoseExpNet(nb_ref_imgs=seq_length - 1, output_exp=False).to(device)
        pose_net.load_state_dict(weights['state_dict'], strict=False)

    dataset_dir = Path(args.dataset_dir)
    if args.dataset_list is not None:
        with open(args.dataset_list, 'r') as f:
            test_files = list(f.read().splitlines())
    else:
        test_files = [file.relpathto(dataset_dir) for file in sum([dataset_dir.files('*.{}'.format(ext)) for ext in args.img_exts], [])]

    framework = test_framework(dataset_dir, test_files, seq_length, args.min_depth, args.max_depth)

    print('{} files to test'.format(len(test_files)))
    errors = np.zeros((2, 7, len(test_files)), np.float32)
    if args.output_dir is not None:
        output_dir = Path(args.output_dir)
        output_dir.makedirs_p()

    for j, sample in enumerate(tqdm(framework)):
        tgt_img = sample['tgt']

        ref_imgs = sample['ref']

        h,w,_ = tgt_img.shape
        if (not args.no_resize) and (h != args.img_height or w != args.img_width):
            tgt_img = imresize(tgt_img, (args.img_height, args.img_width)).astype(np.float32)
            ref_imgs = [imresize(img, (args.img_height, args.img_width)).astype(np.float32) for img in ref_imgs]

        tgt_img = np.transpose(tgt_img, (2, 0, 1))
        ref_imgs = [np.transpose(img, (2,0,1)) for img in ref_imgs]

        tgt_img = torch.from_numpy(tgt_img).unsqueeze(0)
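        # normalize pixel values from [0, 255] to [-1, 1]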
        tgt_img = ((tgt_img/255 - 0.5)/0.5).to(device)

        for i, img in enumerate(ref_imgs):
            img = torch.from_numpy(img).unsqueeze(0)
            img = ((img/255 - 0.5)/0.5).to(device)
            ref_imgs[i] = img

        pred_disp = disp_net(tgt_img).cpu().numpy()[0,0]

        if args.output_dir is not None:
            if j == 0:
                predictions = np.zeros((len(test_files), *pred_disp.shape))
            predictions[j] = 1/pred_disp

        gt_depth = sample['gt_depth']

        pred_depth = 1/pred_disp
        pred_depth_zoomed = zoom(pred_depth,
                                 (gt_depth.shape[0]/pred_depth.shape[0],
                                  gt_depth.shape[1]/pred_depth.shape[1])
                                 ).clip(args.min_depth, args.max_depth)
        if sample['mask'] is not None:
            pred_depth_zoomed = pred_depth_zoomed[sample['mask']]
            gt_depth = gt_depth[sample['mask']]

        if seq_length > 0:
            # Reorganize ref_imgs: tgt is the middle frame, but not necessarily the one
            # fed to DispNetS (the test sample may sit at the start or end of the sequence)
            middle_index = seq_length//2
            tgt = ref_imgs[middle_index]
            reorganized_refs = ref_imgs[:middle_index] + ref_imgs[middle_index + 1:]
            _, poses = pose_net(tgt, reorganized_refs)
            mean_displacement_magnitude = poses[0,:,:3].norm(2,1).mean().item()

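            # predicted depth is defined only up to scale; recover metric scale
            # from the known GT displacement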
            scale_factor = sample['displacement'] / mean_displacement_magnitude
            errors[0,:,j] = compute_errors(gt_depth, pred_depth_zoomed*scale_factor)

        scale_factor = np.median(gt_depth)/np.median(pred_depth_zoomed)
        errors[1,:,j] = compute_errors(gt_depth, pred_depth_zoomed*scale_factor)

    mean_errors = errors.mean(2)
    error_names = ['abs_rel','sq_rel','rms','log_rms','a1','a2','a3']
    if args.pretrained_posenet:
        print("Results with scale factor determined by PoseNet : ")
        print("{:>10}, {:>10}, {:>10}, {:>10}, {:>10}, {:>10}, {:>10}".format(*error_names))
        print("{:10.4f}, {:10.4f}, {:10.4f}, {:10.4f}, {:10.4f}, {:10.4f}, {:10.4f}".format(*mean_errors[0]))

    print("Results with scale factor determined by GT/prediction ratio (like the original paper) : ")
    print("{:>10}, {:>10}, {:>10}, {:>10}, {:>10}, {:>10}, {:>10}".format(*error_names))
    print("{:10.4f}, {:10.4f}, {:10.4f}, {:10.4f}, {:10.4f}, {:10.4f}, {:10.4f}".format(*mean_errors[1]))

    if args.output_dir is not None:
        np.save(output_dir/'predictions.npy', predictions)
def main():
    args = parser.parse_args()
    if args.gt_type == 'KITTI':
        from kitti_eval.depth_evaluation_utils import test_framework_KITTI as test_framework

    disp_net = DispNetS().cuda()
    weights = torch.load(args.pretrained_dispnet)
    disp_net.load_state_dict(weights['state_dict'])
    disp_net.eval()

    if args.pretrained_posenet is None:
        print('no PoseNet specified, scale_factor will be determined by median ratio, '
              'which is kiiinda cheating (but consistent with original paper)')
        seq_length = 0
    else:
        weights = torch.load(args.pretrained_posenet)
        seq_length = int(weights['state_dict']['conv1.0.weight'].size(1) / 3)
        pose_net = PoseExpNet(nb_ref_imgs=seq_length - 1,
                              output_exp=False).cuda()
        pose_net.load_state_dict(weights['state_dict'], strict=False)

    dataset_dir = Path(args.dataset_dir)
    if args.dataset_list is not None:
        with open(args.dataset_list, 'r') as f:
            test_files = list(f.read().splitlines())
    else:
        test_files = [
            file.relpathto(dataset_dir) for file in sum([
                dataset_dir.files('*.{}'.format(ext)) for ext in args.img_exts
            ], [])
        ]

    framework = test_framework(dataset_dir, test_files, seq_length,
                               args.min_depth, args.max_depth)
    output_dir = Path(args.output_dir)
    output_dir.makedirs_p()

    print('{} files to test'.format(len(test_files)))
    errors = np.zeros((2, 7, len(test_files)), np.float32)

    for j, sample in enumerate(tqdm(framework)):
        tgt_img = sample['tgt']

        ref_imgs = sample['ref']

        h, w, _ = tgt_img.shape
        if (not args.no_resize) and (h != args.img_height
                                     or w != args.img_width):
            tgt_img = imresize(
                tgt_img, (args.img_height, args.img_width)).astype(np.float32)
            ref_imgs = [
                imresize(img,
                         (args.img_height, args.img_width)).astype(np.float32)
                for img in ref_imgs
            ]

        tgt_img = np.transpose(tgt_img, (2, 0, 1))
        ref_imgs = [np.transpose(img, (2, 0, 1)) for img in ref_imgs]

        tgt_img = torch.from_numpy(tgt_img).unsqueeze(0)
        tgt_img = ((tgt_img / 255 - 0.5) / 0.2).cuda()
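        # volatile Variables are the legacy (pre-0.4) PyTorch way to disable autograd at inference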
        tgt_img_var = Variable(tgt_img, volatile=True)

        ref_imgs_var = []
        for i, img in enumerate(ref_imgs):
            img = torch.from_numpy(img).unsqueeze(0)
            img = ((img / 255 - 0.5) / 0.2).cuda()
            ref_imgs_var.append(Variable(img, volatile=True))

        pred_disp = disp_net(tgt_img_var).data.cpu().numpy()[0, 0]

        gt_depth = sample['gt_depth']

        pred_depth = 1 / pred_disp
        pred_depth_zoomed = zoom(
            pred_depth,
            (gt_depth.shape[0] / pred_depth.shape[0], gt_depth.shape[1] /
             pred_depth.shape[1])).clip(args.min_depth, args.max_depth)
        if sample['mask'] is not None:
            pred_depth_zoomed = pred_depth_zoomed[sample['mask']]
            gt_depth = gt_depth[sample['mask']]
        if seq_length > 0:
            _, poses = pose_net(tgt_img_var, ref_imgs_var)
            displacements = poses[0, :, :3].norm(
                2, 1).cpu().data.numpy()  # shape: [seq_length - 1]

            scale_factors = [
                s1 / s2
                for s1, s2 in zip(sample['displacements'], displacements)
                if s1 > 0
            ]
            scale_factor = np.mean(scale_factors) if len(scale_factors) > 0 else 0
            if len(scale_factors) == 0:
                print('no valid GT displacement for', sample['path'],
                      sample['displacements'])
            errors[0, :, j] = compute_errors(gt_depth,
                                             pred_depth_zoomed * scale_factor)

        scale_factor = np.median(gt_depth) / np.median(pred_depth_zoomed)
        errors[1, :, j] = compute_errors(gt_depth,
                                         pred_depth_zoomed * scale_factor)

    mean_errors = errors.mean(2)
    error_names = ['abs_rel', 'sq_rel', 'rms', 'log_rms', 'a1', 'a2', 'a3']
    if args.pretrained_posenet:
        print("Results with scale factor determined by PoseNet : ")
        print("{:>10}, {:>10}, {:>10}, {:>10}, {:>10}, {:>10}, {:>10}".format(
            *error_names))
        print(
            "{:10.4f}, {:10.4f}, {:10.4f}, {:10.4f}, {:10.4f}, {:10.4f}, {:10.4f}"
            .format(*mean_errors[0]))

    print(
        "Results with scale factor determined by GT/prediction ratio (like the original paper) : "
    )
    print("{:>10}, {:>10}, {:>10}, {:>10}, {:>10}, {:>10}, {:>10}".format(
        *error_names))
    print(
        "{:10.4f}, {:10.4f}, {:10.4f}, {:10.4f}, {:10.4f}, {:10.4f}, {:10.4f}".
        format(*mean_errors[1]))
def main():
    args = parser.parse_args()
    if args.gt_type == 'KITTI':
        from kitti_eval.depth_evaluation_utils import test_framework_KITTI as test_framework
    elif args.gt_type == 'stillbox':
        from stillbox_eval.depth_evaluation_utils import test_framework_stillbox as test_framework
    elif args.gt_type == 'pfm':
        from pfm_eval.depth_evaluation_utils import test_framework_stillbox as test_framework

    disp_net = getattr(models, args.dispnet)().cuda()
    weights = torch.load(args.pretrained_dispnet)
    disp_net.load_state_dict(weights['state_dict'])
    disp_net.eval()

    if args.pretrained_posenet is None:
        print('no PoseNet specified, scale_factor will be determined by median ratio, '
              'which is kiiinda cheating (but consistent with original paper)')
        seq_length = 0
    else:
        weights = torch.load(args.pretrained_posenet)
        seq_length = int(weights['state_dict']['conv1.0.weight'].size(1) / 3)
        pose_net = getattr(models, args.posenet)(nb_ref_imgs=seq_length - 1,
                                                 output_exp=False).cuda()
        pose_net.load_state_dict(weights['state_dict'], strict=False)

    dataset_dir = Path(args.dataset_dir)
    if args.dataset_list is not None:
        with open(args.dataset_list, 'r') as f:
            test_files = list(f.read().splitlines())
    else:
        test_files = [
            file.relpathto(dataset_dir) for file in sum([
                dataset_dir.files('*.{}'.format(ext)) for ext in args.img_exts
            ], [])
        ]

    framework = test_framework(dataset_dir, test_files, seq_length,
                               args.min_depth, args.max_depth)

    print('{} files to test'.format(len(test_files)))
    errors = np.zeros((2, 7, len(test_files)), np.float32)
    if args.output_dir is not None:
        output_dir = Path(args.output_dir)
        viz_dir = output_dir / 'viz'
        print("Saving output to", viz_dir)
        output_dir.makedirs_p()
        viz_dir.makedirs_p()

    for j, sample in enumerate(tqdm(framework)):
        tgt_img = sample['tgt']
        ref_imgs = sample['ref']

        h, w, _ = tgt_img.shape
        if (not args.no_resize) and (h != args.img_height
                                     or w != args.img_width):
            tgt_img = imresize(
                tgt_img, (args.img_height, args.img_width)).astype(np.float32)
            ref_imgs = [
                imresize(img,
                         (args.img_height, args.img_width)).astype(np.float32)
                for img in ref_imgs
            ]

        tgt_img = np.transpose(tgt_img, (2, 0, 1))
        ref_imgs = [np.transpose(img, (2, 0, 1)) for img in ref_imgs]

        tgt_img = torch.from_numpy(tgt_img).unsqueeze(0)
        tgt_img = ((tgt_img / 255 - 0.5) / 0.5).cuda()
        tgt_img_var = Variable(tgt_img, volatile=True)

        ref_imgs_var = []
        for i, img in enumerate(ref_imgs):
            img = torch.from_numpy(img).unsqueeze(0)
            img = ((img / 255 - 0.5) / 0.5).cuda()
            ref_imgs_var.append(Variable(img, volatile=True))

        pred_disp = disp_net(tgt_img_var)
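        # optionally normalize the raw disparity map before converting it to depth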
        if args.spatial_normalize:
            pred_disp = spatial_normalize(pred_disp)
        pred_disp = pred_disp.data.cpu().numpy()[0, 0]

        if args.output_dir is not None:
            if j == 0:
                predictions = np.zeros((len(test_files), *pred_disp.shape))
            predictions[j] = 1 / pred_disp

            depth_viz = tensor2array(torch.FloatTensor(pred_disp),
                                     max_value=None,
                                     colormap='magma')
            depth_viz = np.transpose(depth_viz, (1, 2, 0))
            depth_viz_im = Image.fromarray((255 * depth_viz).astype('uint8'))
            depth_viz_im.save(viz_dir / (str(j).zfill(4) + 'depth.png'))

        # compare to ground truth, as in the other evaluation scripts,
        # so that the error tables below are populated
        gt_depth = sample['gt_depth']

        pred_depth = 1 / pred_disp
        pred_depth_zoomed = zoom(
            pred_depth,
            (gt_depth.shape[0] / pred_depth.shape[0], gt_depth.shape[1] /
             pred_depth.shape[1])).clip(args.min_depth, args.max_depth)
        if sample['mask'] is not None:
            pred_depth_zoomed = pred_depth_zoomed[sample['mask']]
            gt_depth = gt_depth[sample['mask']]

        if seq_length > 0:
            _, poses = pose_net(tgt_img_var, ref_imgs_var)
            displacements = poses[0, :, :3].norm(2, 1).cpu().data.numpy()
            scale_factors = [
                s1 / s2
                for s1, s2 in zip(sample['displacements'], displacements)
                if s1 > 0
            ]
            scale_factor = np.mean(scale_factors) if len(scale_factors) > 0 else 0
            errors[0, :, j] = compute_errors(gt_depth,
                                             pred_depth_zoomed * scale_factor)

        scale_factor = np.median(gt_depth) / np.median(pred_depth_zoomed)
        errors[1, :, j] = compute_errors(gt_depth,
                                         pred_depth_zoomed * scale_factor)

    mean_errors = errors.mean(2)
    error_names = ['abs_rel', 'sq_rel', 'rms', 'log_rms', 'a1', 'a2', 'a3']
    if args.pretrained_posenet:
        print("Results with scale factor determined by PoseNet : ")
        print("{:>10}, {:>10}, {:>10}, {:>10}, {:>10}, {:>10}, {:>10}".format(
            *error_names))
        print(
            "{:10.4f}, {:10.4f}, {:10.4f}, {:10.4f}, {:10.4f}, {:10.4f}, {:10.4f}"
            .format(*mean_errors[0]))

    print(
        "Results with scale factor determined by GT/prediction ratio (like the original paper) : "
    )
    print("{:>10}, {:>10}, {:>10}, {:>10}, {:>10}, {:>10}, {:>10}".format(
        *error_names))
    print(
        "{:10.4f}, {:10.4f}, {:10.4f}, {:10.4f}, {:10.4f}, {:10.4f}, {:10.4f}".
        format(*mean_errors[1]))

    if args.output_dir is not None:
        np.save(output_dir / 'predictions.npy', predictions)
@torch.no_grad()
def main():
    args = parser.parse_args()
    if args.gt_type == 'KITTI':
        from kitti_eval.depth_evaluation_utils import test_framework_KITTI as test_framework

    disp_net = DispNetS().to(device)
    weights = torch.load(args.pretrained_dispnet, map_location='cpu')
    disp_net.load_state_dict(weights['disp_net_state_dict'])
    disp_net.eval()

    seq_length = 1

    dataset_dir = Path(args.dataset_dir)
    if args.dataset_list is not None:
        with open(args.dataset_list, 'r') as f:
            test_files = list(f.read().splitlines())
    else:
        test_files = [file.relpathto(dataset_dir) for file in sum([dataset_dir.files('*.{}'.format(ext)) for ext in args.img_exts], [])]

    framework = test_framework(dataset_dir, test_files, seq_length,
                               args.min_depth, args.max_depth,
                               use_gps=args.gps)

    print('{} files to test'.format(len(test_files)))
    errors = np.zeros((2, 9, len(test_files)), np.float32)
    if args.output_dir is not None:
        output_dir = Path(args.output_dir)
        output_dir.makedirs_p()

    for j, sample in enumerate(tqdm(framework)):
        tgt_img = sample['tgt']

        ref_imgs = sample['ref']

        h,w,_ = tgt_img.shape
        if (not args.no_resize) and (h != args.img_height or w != args.img_width):
            # cv2.resize takes (width, height)
            tgt_img = cv2.resize(tgt_img, (args.img_width, args.img_height)).astype(np.float32)
            ref_imgs = [cv2.resize(img, (args.img_width, args.img_height)).astype(np.float32) for img in ref_imgs]

        tgt_img = np.transpose(tgt_img, (2, 0, 1))
        ref_imgs = [np.transpose(img, (2,0,1)) for img in ref_imgs]

        tgt_img = torch.from_numpy(tgt_img).unsqueeze(0)
        tgt_img = ((tgt_img/255 - 0.5)/0.5).to(device)

        for i, img in enumerate(ref_imgs):
            img = torch.from_numpy(img).unsqueeze(0)
            img = ((img/255 - 0.5)/0.5).to(device)
            ref_imgs[i] = img

        pred_disp = disp_net(tgt_img).cpu().numpy()[0,0]

        if args.output_dir is not None:
            if j == 0:
                predictions = np.zeros((len(test_files), *pred_disp.shape))
            predictions[j] = 1/pred_disp

        gt_depth = sample['gt_depth']

        pred_depth = 1/pred_disp
        pred_depth_zoomed = zoom(pred_depth,
                                 (gt_depth.shape[0]/pred_depth.shape[0],
                                  gt_depth.shape[1]/pred_depth.shape[1])
                                 ).clip(args.min_depth, args.max_depth)
        if sample['mask'] is not None:
            pred_depth_zoomed = pred_depth_zoomed[sample['mask']]
            gt_depth = gt_depth[sample['mask']]

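        # median-ratio scaling: match the medians of GT and predicted depth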
        scale_factor = np.median(gt_depth)/np.median(pred_depth_zoomed)
        errors[1,:,j] = compute_errors(gt_depth, pred_depth_zoomed*scale_factor)

    mean_errors = errors.mean(2)
    error_names = ['abs_diff', 'abs_rel','sq_rel','rms','log_rms', 'abs_log', 'a1','a2','a3']

    print("Results with scale factor determined by GT/prediction ratio (like the original paper) : ")
    print("{:>10}, {:>10}, {:>10}, {:>10}, {:>10}, {:>10}, {:>10}, {:>10}, {:>10}".format(*error_names))
    print("{:10.4f}, {:10.4f}, {:10.4f}, {:10.4f}, {:10.4f}, {:10.4f}, {:10.4f}, {:10.4f}, {:10.4f}".format(*mean_errors[1]))

    if args.output_dir is not None:
        np.save(output_dir/'predictions.npy', predictions)
@torch.no_grad()
def main():
    args = parser.parse_args()
    if args.gt_type == 'KITTI':
        from kitti_eval.depth_evaluation_utils import test_framework_KITTI as test_framework

    disp_net_enc = SharedEncoder.SharedEncoderMain().double().to(device)
    weights = torch.load(args.pretrained_dispnet_enc,
                         map_location=lambda storage, loc: storage)
    disp_net_enc.load_state_dict(weights)
    disp_net_enc.eval()

    disp_net_dec = DepthDecoder.DepthDecoder().double().to(device)
    weights = torch.load(args.pretrained_dispnet_dec,
                         map_location=lambda storage, loc: storage)
    disp_net_dec.load_state_dict(weights)
    disp_net_dec.eval()

    if args.pretrained_posenet_dec is None:
        print('no PoseNet specified, scale_factor will be determined by median ratio, '
              'which is kiiinda cheating (but consistent with original paper)')
        seq_length = 1
    else:
        pose_net_dec = PoseNetwork.PoseDecoder().double().to(device)
        weights = torch.load(args.pretrained_posenet_dec,
                             map_location=lambda storage, loc: storage)
        seq_length = int(weights['conv1.0.weight'].size(1) / 3)
        pose_net_dec.load_state_dict(weights)
        pose_net_dec.eval()
        # the pose decoder is fed exactly 3 frames below, so override seq_length
        seq_length = 3

    dataset_dir = Path(args.dataset_dir)
    if args.dataset_list is not None:
        with open(args.dataset_list, 'r') as f:
            test_files = list(f.read().splitlines())
    else:
        test_files = [
            file.relpathto(dataset_dir) for file in sum([
                dataset_dir.files('*.{}'.format(ext)) for ext in args.img_exts
            ], [])
        ]

    framework = test_framework(dataset_dir,
                               test_files,
                               seq_length,
                               args.min_depth,
                               args.max_depth,
                               use_gps=args.gps)

    print('{} files to test'.format(len(test_files)))
    errors = np.zeros((2, 9, len(test_files)), np.float32)
    if args.output_dir is not None:
        output_dir = Path(args.output_dir)
        output_dir.makedirs_p()

    for j, sample in enumerate(tqdm(framework)):
        tgt_img = sample['tgt']

        ref_imgs = sample['ref']

        h, w, _ = tgt_img.shape
        if (not args.no_resize) and (h != args.img_height
                                     or w != args.img_width):
            tgt_img = imresize(
                tgt_img, (args.img_height, args.img_width)).astype(np.float32)
            ref_imgs = [
                imresize(img,
                         (args.img_height, args.img_width)).astype(np.float32)
                for img in ref_imgs
            ]

        tgt_img = np.transpose(tgt_img, (2, 0, 1))
        ref_imgs = [np.transpose(img, (2, 0, 1)) for img in ref_imgs]

        tgt_img = torch.from_numpy(tgt_img).unsqueeze(0)
        tgt_img = ((tgt_img / 255 - 0.5) / 0.5).to(device)

        for i, img in enumerate(ref_imgs):
            img = torch.from_numpy(img).unsqueeze(0)
            img = ((img / 255 - 0.5) / 0.5).to(device)
            ref_imgs[i] = img

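        # depth is decoded from the shared encoder's features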
        econv = disp_net_enc(tgt_img.double())
        pred_disp = disp_net_dec(tgt_img.double(), econv).cpu().numpy()[0, 0]

        if args.output_dir is not None:
            if j == 0:
                predictions = np.zeros((len(test_files), *pred_disp.shape))
            predictions[j] = 1 / pred_disp

        gt_depth = sample['gt_depth']

        pred_depth = 1 / pred_disp
        pred_depth_zoomed = zoom(
            pred_depth,
            (gt_depth.shape[0] / pred_depth.shape[0], gt_depth.shape[1] /
             pred_depth.shape[1])).clip(args.min_depth, args.max_depth)
        if sample['mask'] is not None:
            pred_depth_zoomed = pred_depth_zoomed[sample['mask']]
            gt_depth = gt_depth[sample['mask']]

        if seq_length > 1:
            # encode all reference frames in one batch, then decode relative
            # poses from the three per-frame feature maps
            econv = disp_net_enc(torch.cat(ref_imgs, dim=0).double())
            poses, _, _ = pose_net_dec(econv[4][0:1], econv[4][1:2],
                                       econv[4][2:3])

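            # translation magnitude of each predicted relative pose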
            displacement_magnitudes = poses[0, :, :3].norm(2, 1).cpu().numpy()

            scale_factor = np.mean(sample['displacements'] /
                                   displacement_magnitudes)
            errors[0, :, j] = compute_errors(gt_depth,
                                             pred_depth_zoomed * scale_factor)

        scale_factor = np.median(gt_depth) / np.median(pred_depth_zoomed)
        errors[1, :, j] = compute_errors(gt_depth,
                                         pred_depth_zoomed * scale_factor)

    mean_errors = errors.mean(2)
    error_names = [
        'abs_diff', 'abs_rel', 'sq_rel', 'rms', 'log_rms', 'abs_log', 'a1',
        'a2', 'a3'
    ]
    if args.pretrained_posenet_dec:
        print("Results with scale factor determined by PoseNet : ")
        print(
            "{:>10}, {:>10}, {:>10}, {:>10}, {:>10}, {:>10}, {:>10}, {:>10}, {:>10}"
            .format(*error_names))
        print(
            "{:10.4f}, {:10.4f}, {:10.4f}, {:10.4f}, {:10.4f}, {:10.4f}, {:10.4f}, {:10.4f}, {:10.4f}"
            .format(*mean_errors[0]))

    print(
        "Results with scale factor determined by GT/prediction ratio (like the original paper) : "
    )
    print(
        "{:>10}, {:>10}, {:>10}, {:>10}, {:>10}, {:>10}, {:>10}, {:>10}, {:>10}"
        .format(*error_names))
    print(
        "{:10.4f}, {:10.4f}, {:10.4f}, {:10.4f}, {:10.4f}, {:10.4f}, {:10.4f}, {:10.4f}, {:10.4f}"
        .format(*mean_errors[1]))

    if args.output_dir is not None:
        np.save(output_dir / 'predictions.npy', predictions)