Example #1
def predict():
    scene_name = Config.test_online_scene_path.split("/")[-1]
    system_name = "{}_{}_{}_{}_dvmvs_fusionnet_online".format(
        scene_name, Config.test_image_width, Config.test_image_height,
        Config.test_n_measurement_frames)

    print("Predicting with System:", system_name)
    print("# of Measurement Frames:", Config.test_n_measurement_frames)

    device = torch.device("cuda")
    feature_extractor = FeatureExtractor()
    feature_shrinker = FeatureShrinker()
    cost_volume_encoder = CostVolumeEncoder()
    lstm_fusion = LSTMFusion()
    cost_volume_decoder = CostVolumeDecoder()

    feature_extractor = feature_extractor.to(device)
    feature_shrinker = feature_shrinker.to(device)
    cost_volume_encoder = cost_volume_encoder.to(device)
    lstm_fusion = lstm_fusion.to(device)
    cost_volume_decoder = cost_volume_decoder.to(device)

    model = [
        feature_extractor, feature_shrinker, cost_volume_encoder, lstm_fusion,
        cost_volume_decoder
    ]

    for i in range(len(model)):
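        # The checkpoints in "weights/" are assumed to be named so that their sorted
        # (alphabetical) order matches the module order in the list above.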
        try:
            checkpoint = sorted(Path("weights").files())[i]
            weights = torch.load(checkpoint)
            model[i].load_state_dict(weights)
            model[i].eval()
            print("Loaded weights for", checkpoint)
        except Exception as e:
            print(e)
            print("Could not find the checkpoint for module", i)
            exit(1)

    feature_extractor = model[0]
    feature_shrinker = model[1]
    cost_volume_encoder = model[2]
    lstm_fusion = model[3]
    cost_volume_decoder = model[4]

    warp_grid = get_warp_grid_for_cost_volume_calculation(
        width=int(Config.test_image_width / 2),
        height=int(Config.test_image_height / 2),
        device=device)

    scale_rgb = 255.0
    mean_rgb = [0.485, 0.456, 0.406]
    std_rgb = [0.229, 0.224, 0.225]

    min_depth = 0.25
    max_depth = 20.0
    n_depth_levels = 64

    scene_folder = Path(Config.test_online_scene_path)

    scene = scene_folder.split("/")[-1]
    print("Predicting for scene:", scene)

    keyframe_buffer = KeyframeBuffer(
        buffer_size=Config.test_keyframe_buffer_size,
        keyframe_pose_distance=Config.test_keyframe_pose_distance,
        optimal_t_score=Config.test_optimal_t_measure,
        optimal_R_score=Config.test_optimal_R_measure,
        store_return_indices=False)

    K = np.loadtxt(scene_folder / 'K.txt').astype(np.float32)
    poses = np.fromfile(scene_folder / "poses.txt", dtype=float,
                        sep="\n ").reshape((-1, 4, 4))
    image_filenames = sorted((scene_folder / 'images').files("*.png"))
    depth_filenames = sorted((scene_folder / 'depth').files("*.png"))

    inference_timer = InferenceTimer()

    lstm_state = None
    previous_depth = None
    previous_pose = None

    predictions = []
    reference_depths = []
    with torch.no_grad():
        for i in tqdm(range(0, len(poses))):
            reference_pose = poses[i]
            reference_image = load_image(image_filenames[i])
            reference_depth = cv2.imread(depth_filenames[i],
                                         -1).astype(float) / 1000.0

            # POLL THE KEYFRAME BUFFER
            response = keyframe_buffer.try_new_keyframe(
                reference_pose, reference_image)
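            # Only when the buffer accepts the frame as a new keyframe does execution fall
            # through to inference; response 3 presumably signals lost tracking, so the
            # recurrent state and the previous depth/pose are reset.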
            if response in (0, 2, 4, 5):
                continue
            elif response == 3:
                previous_depth = None
                previous_pose = None
                lstm_state = None
                continue

            preprocessor = PreprocessImage(
                K=K,
                old_width=reference_image.shape[1],
                old_height=reference_image.shape[0],
                new_width=Config.test_image_width,
                new_height=Config.test_image_height,
                distortion_crop=Config.test_distortion_crop,
                perform_crop=Config.test_perform_crop)

            reference_image = preprocessor.apply_rgb(image=reference_image,
                                                     scale_rgb=scale_rgb,
                                                     mean_rgb=mean_rgb,
                                                     std_rgb=std_rgb)
            reference_depth = preprocessor.apply_depth(reference_depth)
            reference_image_torch = torch.from_numpy(
                np.transpose(reference_image,
                             (2, 0, 1))).float().to(device).unsqueeze(0)
            reference_pose_torch = torch.from_numpy(reference_pose).float().to(
                device).unsqueeze(0)

            full_K_torch = torch.from_numpy(
                preprocessor.get_updated_intrinsics()).float().to(
                    device).unsqueeze(0)

            # Rescale the intrinsics for the half-resolution features used in the cost
            # volume and for the 1/32-resolution bottleneck that is fed to the LSTM.
            half_K_torch = full_K_torch.clone().cuda()
            half_K_torch[:, 0:2, :] = half_K_torch[:, 0:2, :] / 2.0

            lstm_K_bottom = full_K_torch.clone().cuda()
            lstm_K_bottom[:, 0:2, :] = lstm_K_bottom[:, 0:2, :] / 32.0

            measurement_poses_torch = []
            measurement_images_torch = []
            measurement_frames = keyframe_buffer.get_best_measurement_frames(
                Config.test_n_measurement_frames)
            for (measurement_pose, measurement_image) in measurement_frames:
                measurement_image = preprocessor.apply_rgb(
                    image=measurement_image,
                    scale_rgb=scale_rgb,
                    mean_rgb=mean_rgb,
                    std_rgb=std_rgb)
                measurement_image_torch = torch.from_numpy(
                    np.transpose(measurement_image,
                                 (2, 0, 1))).float().to(device).unsqueeze(0)
                measurement_pose_torch = torch.from_numpy(
                    measurement_pose).float().to(device).unsqueeze(0)
                measurement_images_torch.append(measurement_image_torch)
                measurement_poses_torch.append(measurement_pose_torch)

            inference_timer.record_start_time()

            measurement_feature_halfs = []
            for measurement_image_torch in measurement_images_torch:
                measurement_feature_half, _, _, _ = feature_shrinker(
                    *feature_extractor(measurement_image_torch))
                measurement_feature_halfs.append(measurement_feature_half)

            reference_feature_half, reference_feature_quarter, \
            reference_feature_one_eight, reference_feature_one_sixteen = feature_shrinker(*feature_extractor(reference_image_torch))

            cost_volume = cost_volume_fusion(image1=reference_feature_half,
                                             image2s=measurement_feature_halfs,
                                             pose1=reference_pose_torch,
                                             pose2s=measurement_poses_torch,
                                             K=half_K_torch,
                                             warp_grid=warp_grid,
                                             min_depth=min_depth,
                                             max_depth=max_depth,
                                             n_depth_levels=n_depth_levels,
                                             device=device,
                                             dot_product=True)

            skip0, skip1, skip2, skip3, bottom = cost_volume_encoder(
                features_half=reference_feature_half,
                features_quarter=reference_feature_quarter,
                features_one_eight=reference_feature_one_eight,
                features_one_sixteen=reference_feature_one_sixteen,
                cost_volume=cost_volume)

            if previous_depth is not None:
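                # Warp the previous prediction into the current view and downsample it to
                # match the 1/32-resolution bottleneck, giving the LSTM cell a depth hint.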
                depth_estimation = get_non_differentiable_rectangle_depth_estimation(
                    reference_pose_torch=reference_pose_torch,
                    measurement_pose_torch=previous_pose,
                    previous_depth_torch=previous_depth,
                    full_K_torch=full_K_torch,
                    half_K_torch=half_K_torch,
                    original_height=Config.test_image_height,
                    original_width=Config.test_image_width)
                depth_estimation = torch.nn.functional.interpolate(
                    input=depth_estimation,
                    scale_factor=(1.0 / 16.0),
                    mode="nearest")
            else:
                depth_estimation = torch.zeros(
                    size=(1, 1, int(Config.test_image_height / 32.0),
                          int(Config.test_image_width / 32.0))).to(device)

            lstm_state = lstm_fusion(current_encoding=bottom,
                                     current_state=lstm_state,
                                     previous_pose=previous_pose,
                                     current_pose=reference_pose_torch,
                                     estimated_current_depth=depth_estimation,
                                     camera_matrix=lstm_K_bottom)

            prediction, _, _, _, _ = cost_volume_decoder(
                reference_image_torch, skip0, skip1, skip2, skip3,
                lstm_state[0])
            previous_depth = prediction.view(1, 1, Config.test_image_height,
                                             Config.test_image_width)
            previous_pose = reference_pose_torch

            inference_timer.record_end_time_and_elapsed_time()

            prediction = prediction.cpu().numpy().squeeze()
            reference_depths.append(reference_depth)
            predictions.append(prediction)

            if Config.test_visualize:
                visualize_predictions(
                    numpy_reference_image=reference_image,
                    numpy_measurement_image=measurement_image,
                    numpy_predicted_depth=prediction,
                    normalization_mean=mean_rgb,
                    normalization_std=std_rgb,
                    normalization_scale=scale_rgb,
                    depth_multiplier_for_visualization=5000)

        inference_timer.print_statistics()

        save_results(predictions=predictions,
                     groundtruths=reference_depths,
                     system_name=system_name,
                     scene_name=scene,
                     save_folder=".")

Example #2
def predict():
    print("System: PAIRNET")

    device = torch.device("cuda")
    feature_extractor = FeatureExtractor()
    feature_shrinker = FeatureShrinker()
    cost_volume_encoder = CostVolumeEncoder()
    cost_volume_decoder = CostVolumeDecoder()

    feature_extractor = feature_extractor.to(device)
    feature_shrinker = feature_shrinker.to(device)
    cost_volume_encoder = cost_volume_encoder.to(device)
    cost_volume_decoder = cost_volume_decoder.to(device)

    model = [
        feature_extractor, feature_shrinker, cost_volume_encoder,
        cost_volume_decoder
    ]

    for i in range(len(model)):
        try:
            checkpoint = sorted(Path("weights").files())[i]
            weights = torch.load(checkpoint)
            model[i].load_state_dict(weights)
            model[i].eval()
            print("Loaded weights for", checkpoint)
        except Exception as e:
            print(e)
            print("Could not find the checkpoint for module", i)
            exit(1)

    feature_extractor = model[0]
    feature_shrinker = model[1]
    cost_volume_encoder = model[2]
    cost_volume_decoder = model[3]

    warp_grid = get_warp_grid_for_cost_volume_calculation(
        width=int(Config.test_image_width / 2),
        height=int(Config.test_image_height / 2),
        device=device)

    scale_rgb = 255.0
    mean_rgb = [0.485, 0.456, 0.406]
    std_rgb = [0.229, 0.224, 0.225]

    min_depth = 0.25
    max_depth = 20.0
    n_depth_levels = 64

    data_path = Path(Config.test_offline_data_path)
    if Config.test_dataset_name is None:
        keyframe_index_files = sorted(
            (Path(Config.test_offline_data_path) / "indices").files())
    else:
        keyframe_index_files = sorted(
            (Path(Config.test_offline_data_path) /
             "indices").files("*" + Config.test_dataset_name + "*"))
    for iteration, keyframe_index_file in enumerate(keyframe_index_files):
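        # Each keyframe index file name encodes the keyframing type, dataset, scene and
        # number of measurement frames, separated by "+".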
        keyframing_type, dataset_name, scene_name, _, n_measurement_frames = keyframe_index_file.split(
            "/")[-1].split("+")

        scene_folder = data_path / dataset_name / scene_name
        print("Predicting for scene:", dataset_name + "-" + scene_name, " - ",
              iteration, "/", len(keyframe_index_files))

        keyframe_index_file_lines = np.loadtxt(keyframe_index_file,
                                               dtype=str,
                                               delimiter="\n")

        K = np.loadtxt(scene_folder / 'K.txt').astype(np.float32)
        poses = np.fromfile(scene_folder / "poses.txt", dtype=float,
                            sep="\n ").reshape((-1, 4, 4))
        image_filenames = sorted((scene_folder / 'images').files("*.png"))
        depth_filenames = sorted((scene_folder / 'depth').files("*.png"))

        input_filenames = []
        for image_filename in image_filenames:
            input_filenames.append(image_filename.split("/")[-1])

        inference_timer = InferenceTimer()

        predictions = []
        reference_depths = []
        with torch.no_grad():
            for i in tqdm(range(0, len(keyframe_index_file_lines))):

                keyframe_index_file_line = keyframe_index_file_lines[i]

                if keyframe_index_file_line == "TRACKING LOST":
                    continue
                else:
                    current_input_filenames = keyframe_index_file_line.split(
                        " ")
                    current_indices = [
                        input_filenames.index(current_input_filenames[x])
                        for x in range(len(current_input_filenames))
                    ]

                reference_index = current_indices[0]
                measurement_indices = current_indices[1:]

                reference_pose = poses[reference_index]
                reference_image = load_image(image_filenames[reference_index])
                reference_depth = cv2.imread(depth_filenames[reference_index],
                                             -1).astype(float) / 1000.0

                preprocessor = PreprocessImage(
                    K=K,
                    old_width=reference_image.shape[1],
                    old_height=reference_image.shape[0],
                    new_width=Config.test_image_width,
                    new_height=Config.test_image_height,
                    distortion_crop=Config.test_distortion_crop,
                    perform_crop=Config.test_perform_crop)

                reference_image = preprocessor.apply_rgb(image=reference_image,
                                                         scale_rgb=scale_rgb,
                                                         mean_rgb=mean_rgb,
                                                         std_rgb=std_rgb)
                reference_depth = preprocessor.apply_depth(reference_depth)
                reference_image_torch = torch.from_numpy(
                    np.transpose(reference_image,
                                 (2, 0, 1))).float().to(device).unsqueeze(0)
                reference_pose_torch = torch.from_numpy(
                    reference_pose).float().to(device).unsqueeze(0)

                measurement_poses_torch = []
                measurement_images_torch = []
                for measurement_index in measurement_indices:
                    measurement_image = load_image(
                        image_filenames[measurement_index])
                    measurement_image = preprocessor.apply_rgb(
                        image=measurement_image,
                        scale_rgb=scale_rgb,
                        mean_rgb=mean_rgb,
                        std_rgb=std_rgb)
                    measurement_image_torch = torch.from_numpy(
                        np.transpose(
                            measurement_image,
                            (2, 0, 1))).float().to(device).unsqueeze(0)
                    measurement_pose_torch = torch.from_numpy(
                        poses[measurement_index]).float().to(device).unsqueeze(
                            0)
                    measurement_images_torch.append(measurement_image_torch)
                    measurement_poses_torch.append(measurement_pose_torch)

                full_K_torch = torch.from_numpy(
                    preprocessor.get_updated_intrinsics()).float().to(
                        device).unsqueeze(0)

                half_K_torch = full_K_torch.clone().cuda()
                half_K_torch[:, 0:2, :] = half_K_torch[:, 0:2, :] / 2.0

                inference_timer.record_start_time()

                measurement_feature_halfs = []
                for measurement_image_torch in measurement_images_torch:
                    measurement_feature_half, _, _, _ = feature_shrinker(
                        *feature_extractor(measurement_image_torch))
                    measurement_feature_halfs.append(measurement_feature_half)

                reference_feature_half, reference_feature_quarter, \
                reference_feature_one_eight, reference_feature_one_sixteen = feature_shrinker(*feature_extractor(reference_image_torch))

                cost_volume = cost_volume_fusion(
                    image1=reference_feature_half,
                    image2s=measurement_feature_halfs,
                    pose1=reference_pose_torch,
                    pose2s=measurement_poses_torch,
                    K=half_K_torch,
                    warp_grid=warp_grid,
                    min_depth=min_depth,
                    max_depth=max_depth,
                    n_depth_levels=n_depth_levels,
                    device=device,
                    dot_product=True)

                skip0, skip1, skip2, skip3, bottom = cost_volume_encoder(
                    features_half=reference_feature_half,
                    features_quarter=reference_feature_quarter,
                    features_one_eight=reference_feature_one_eight,
                    features_one_sixteen=reference_feature_one_sixteen,
                    cost_volume=cost_volume)

                prediction, _, _, _, _ = cost_volume_decoder(
                    reference_image_torch, skip0, skip1, skip2, skip3, bottom)

                inference_timer.record_end_time_and_elapsed_time()

                prediction = prediction.cpu().numpy().squeeze()
                reference_depths.append(reference_depth)
                predictions.append(prediction)

                if Config.test_visualize:
                    visualize_predictions(
                        numpy_reference_image=reference_image,
                        numpy_measurement_image=measurement_image,
                        numpy_predicted_depth=prediction,
                        normalization_mean=mean_rgb,
                        normalization_std=std_rgb,
                        normalization_scale=scale_rgb)

        inference_timer.print_statistics()

        system_name = "{}_{}_{}_{}_{}_dvmvs_pairnet".format(
            keyframing_type, dataset_name, Config.test_image_width,
            Config.test_image_height, n_measurement_frames)

        save_results(predictions=predictions,
                     groundtruths=reference_depths,
                     system_name=system_name,
                     scene_name=scene_name,
                     save_folder=Config.test_result_folder)

Example #3
def predict():
    predict_with_finetuned = True

    if predict_with_finetuned:
        extension = "finetuned"
    else:
        extension = "without_ft"

    input_image_width = 320
    input_image_height = 240

    print("System: DPSNET, is_finetuned = ", predict_with_finetuned)

    device = torch.device('cuda')
    dpsnet = PSNet(64, 0.5)
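    # PSNet(64, 0.5): presumably 64 depth hypotheses for the plane sweep and a 0.5 m minimum depth.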

    if predict_with_finetuned:
        weights = torch.load(Path("finetuned-weights").files("*dpsnet*")[0])
    else:
        weights = torch.load(
            Path("original-weights").files("*dpsnet*")[0])['state_dict']

    dpsnet.load_state_dict(weights)
    dpsnet = dpsnet.to(device)
    dpsnet.eval()

    scale_rgb = 255.0
    mean_rgb = [0.5, 0.5, 0.5]
    std_rgb = [0.5, 0.5, 0.5]

    data_path = Path(Config.test_offline_data_path)
    if Config.test_dataset_name is None:
        keyframe_index_files = sorted(
            (Path(Config.test_offline_data_path) / "indices").files())
    else:
        keyframe_index_files = sorted(
            (Path(Config.test_offline_data_path) /
             "indices").files("*" + Config.test_dataset_name + "*"))
    for iteration, keyframe_index_file in enumerate(keyframe_index_files):
        keyframing_type, dataset_name, scene_name, _, n_measurement_frames = keyframe_index_file.split(
            "/")[-1].split("+")

        scene_folder = data_path / dataset_name / scene_name
        print("Predicting for scene:", dataset_name + "-" + scene_name, " - ",
              iteration, "/", len(keyframe_index_files))

        keyframe_index_file_lines = np.loadtxt(keyframe_index_file,
                                               dtype=str,
                                               delimiter="\n")

        K = np.loadtxt(scene_folder / 'K.txt').astype(np.float32)
        poses = np.fromfile(scene_folder / "poses.txt", dtype=float,
                            sep="\n ").reshape((-1, 4, 4))
        image_filenames = sorted((scene_folder / 'images').files("*.png"))
        depth_filenames = sorted((scene_folder / 'depth').files("*.png"))

        input_filenames = []
        for image_filename in image_filenames:
            input_filenames.append(image_filename.split("/")[-1])

        inference_timer = InferenceTimer()

        predictions = []
        reference_depths = []
        with torch.no_grad():
            for i in tqdm(range(0, len(keyframe_index_file_lines))):

                keyframe_index_file_line = keyframe_index_file_lines[i]

                if keyframe_index_file_line == "TRACKING LOST":
                    continue
                else:
                    current_input_filenames = keyframe_index_file_line.split(
                        " ")
                    current_indices = [
                        input_filenames.index(current_input_filenames[x])
                        for x in range(len(current_input_filenames))
                    ]

                reference_index = current_indices[0]
                measurement_indices = current_indices[1:]

                reference_pose = poses[reference_index]
                reference_image = load_image(image_filenames[reference_index])
                reference_depth = cv2.imread(depth_filenames[reference_index],
                                             -1).astype(float) / 1000.0

                preprocessor = PreprocessImage(
                    K=K,
                    old_width=reference_image.shape[1],
                    old_height=reference_image.shape[0],
                    new_width=input_image_width,
                    new_height=input_image_height,
                    distortion_crop=0,
                    perform_crop=False)

                reference_image = preprocessor.apply_rgb(image=reference_image,
                                                         scale_rgb=scale_rgb,
                                                         mean_rgb=mean_rgb,
                                                         std_rgb=std_rgb)
                reference_depth = preprocessor.apply_depth(reference_depth)
                reference_image_torch = torch.from_numpy(
                    np.transpose(reference_image,
                                 (2, 0, 1))).float().to(device).unsqueeze(0)

                measurement_poses_torch = []
                measurement_images_torch = []
                for measurement_index in measurement_indices:
                    measurement_image = load_image(
                        image_filenames[measurement_index])
                    measurement_image = preprocessor.apply_rgb(
                        image=measurement_image,
                        scale_rgb=scale_rgb,
                        mean_rgb=mean_rgb,
                        std_rgb=std_rgb)
                    measurement_image_torch = torch.from_numpy(
                        np.transpose(
                            measurement_image,
                            (2, 0, 1))).float().to(device).unsqueeze(0)
                    measurement_pose = poses[measurement_index]
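                    # Convert the absolute poses into the relative 3x4 pose mapping
                    # reference-camera coordinates into the measurement camera's frame,
                    # as expected by DPSNet (assuming camera-to-world poses in poses.txt).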
                    measurement_pose = (np.linalg.inv(measurement_pose)
                                        @ reference_pose)[0:3, :]
                    measurement_pose_torch = torch.from_numpy(
                        measurement_pose).float().to(device).unsqueeze(0)
                    measurement_poses_torch.append(measurement_pose_torch)
                    measurement_images_torch.append(measurement_image_torch)

                camera_k = preprocessor.get_updated_intrinsics()
                camera_k_inv = np.linalg.inv(camera_k)
                camera_k_torch = torch.from_numpy(camera_k).float().to(
                    device).unsqueeze(0)
                camera_k_inv_torch = torch.from_numpy(camera_k_inv).float().to(
                    device).unsqueeze(0)

                inference_timer.record_start_time()
                _, prediction = dpsnet(reference_image_torch,
                                       measurement_images_torch,
                                       measurement_poses_torch, camera_k_torch,
                                       camera_k_inv_torch)
                inference_timer.record_end_time_and_elapsed_time()

                prediction = prediction.cpu().numpy().squeeze()
                reference_depths.append(reference_depth)
                predictions.append(prediction)

                if Config.test_visualize:
                    visualize_predictions(
                        numpy_reference_image=reference_image,
                        numpy_measurement_image=measurement_image,
                        numpy_predicted_depth=prediction,
                        normalization_mean=mean_rgb,
                        normalization_std=std_rgb,
                        normalization_scale=scale_rgb)

        inference_timer.print_statistics()

        system_name = "{}_{}_{}_{}_{}_dpsnet_{}".format(
            keyframing_type, dataset_name, input_image_width,
            input_image_height, n_measurement_frames, extension)

        save_results(predictions=predictions,
                     groundtruths=reference_depths,
                     system_name=system_name,
                     scene_name=scene_name,
                     save_folder=Config.test_result_folder)

Example #4
def predict():
    print("System: DELTAS")

    device = torch.device('cuda')
    cudnn.benchmark = True

    args, supernet, trinet, depthnet = get_model()
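    # get_model() is assumed to return the parsed arguments plus the three DELTAS
    # sub-networks: an interest-point/descriptor network, a triangulation network and a
    # depth densification network.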

    supernet.eval()
    trinet.eval()
    depthnet.eval()

    scale_rgb = 255.0
    mean_rgb = [0.5, 0.5, 0.5]
    std_rgb = [0.5, 0.5, 0.5]

    dummy_input = torch.empty(size=(1, input_image_height, input_image_width),
                              dtype=torch.float).to(device)
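    # Note: input_image_width, input_image_height and n_measurement_frames are module-level
    # settings defined outside this snippet; dummy_input is reused later as a placeholder for
    # the depth inputs that the DELTAS interface expects at call time.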

    data_path = Path(Config.test_offline_data_path)
    if Config.test_dataset_name is None:
        keyframe_index_files = sorted(
            (Path(Config.test_offline_data_path) / "indices").files(
                "*nmeas+{}*".format(n_measurement_frames)))
    else:
        keyframe_index_files = sorted(
            (Path(Config.test_offline_data_path) /
             "indices").files("*" + Config.test_dataset_name +
                              "*nmeas+{}*".format(n_measurement_frames)))
    for iteration, keyframe_index_file in enumerate(keyframe_index_files):
        keyframing_type, dataset_name, scene_name, _, _ = keyframe_index_file.split(
            "/")[-1].split("+")

        scene_folder = data_path / dataset_name / scene_name
        print("Predicting for scene:", dataset_name + "-" + scene_name, " - ",
              iteration, "/", len(keyframe_index_files))

        keyframe_index_file_lines = np.loadtxt(keyframe_index_file,
                                               dtype=str,
                                               delimiter="\n")

        K = np.loadtxt(scene_folder / 'K.txt').astype(np.float32)
        poses = np.fromfile(scene_folder / "poses.txt", dtype=float,
                            sep="\n ").reshape((-1, 4, 4))
        image_filenames = sorted((scene_folder / 'images').files("*.png"))
        depth_filenames = sorted((scene_folder / 'depth').files("*.png"))

        input_filenames = []
        for image_filename in image_filenames:
            input_filenames.append(image_filename.split("/")[-1])

        inference_timer = InferenceTimer()

        predictions = []
        reference_depths = []
        with torch.no_grad():
            for i in tqdm(range(0, len(keyframe_index_file_lines))):

                keyframe_index_file_line = keyframe_index_file_lines[i]

                if keyframe_index_file_line == "TRACKING LOST":
                    continue
                else:
                    current_input_filenames = keyframe_index_file_line.split(
                        " ")
                    current_indices = [
                        input_filenames.index(current_input_filenames[x])
                        for x in range(len(current_input_filenames))
                    ]

                reference_index = current_indices[0]
                measurement_indices = current_indices[1:]

                reference_pose = poses[reference_index]
                reference_image = load_image(image_filenames[reference_index])
                reference_depth = cv2.imread(depth_filenames[reference_index],
                                             -1).astype(float) / 1000.0

                preprocessor = PreprocessImage(
                    K=K,
                    old_width=reference_image.shape[1],
                    old_height=reference_image.shape[0],
                    new_width=input_image_width,
                    new_height=input_image_height,
                    distortion_crop=0,
                    perform_crop=False)

                reference_image = preprocessor.apply_rgb(image=reference_image,
                                                         scale_rgb=scale_rgb,
                                                         mean_rgb=mean_rgb,
                                                         std_rgb=std_rgb)
                reference_depth = preprocessor.apply_depth(reference_depth)
                reference_image_torch = torch.from_numpy(
                    np.transpose(reference_image,
                                 (2, 0, 1))).float().to(device).unsqueeze(0)

                # DELTAS ALWAYS REQUIRE A PREDETERMINED NUMBER OF MEASUREMENT FRAMES, SO FAKE IT
                while len(measurement_indices) < n_measurement_frames:
                    measurement_indices.append(measurement_indices[0])

                measurement_poses_torch = []
                measurement_images_torch = []
                for measurement_index in measurement_indices:
                    measurement_image = load_image(
                        image_filenames[measurement_index])
                    measurement_image = preprocessor.apply_rgb(
                        image=measurement_image,
                        scale_rgb=scale_rgb,
                        mean_rgb=mean_rgb,
                        std_rgb=std_rgb)
                    measurement_image_torch = torch.from_numpy(
                        np.transpose(
                            measurement_image,
                            (2, 0, 1))).float().to(device).unsqueeze(0)
                    measurement_pose = poses[measurement_index]
                    measurement_pose = (
                        np.linalg.inv(measurement_pose) @ reference_pose)
                    measurement_pose_torch = torch.from_numpy(
                        measurement_pose).float().to(device).unsqueeze(
                            0).unsqueeze(0)
                    measurement_poses_torch.append(measurement_pose_torch)
                    measurement_images_torch.append(measurement_image_torch)

                K_torch = torch.from_numpy(
                    preprocessor.get_updated_intrinsics()).float().to(
                        device).unsqueeze(0)

                tgt_depth = dummy_input
                ref_depths = [dummy_input for _ in range(n_measurement_frames)]

                inference_timer.record_start_time()
                prediction = predict_for_subsequence(
                    args,
                    supernet,
                    trinet,
                    depthnet,
                    tgt_img=reference_image_torch,
                    tgt_depth=tgt_depth,
                    ref_imgs=measurement_images_torch,
                    ref_depths=ref_depths,
                    poses=measurement_poses_torch,
                    intrinsics=K_torch)
                inference_timer.record_end_time_and_elapsed_time()

                prediction = prediction.cpu().numpy().squeeze()

                reference_depths.append(reference_depth)
                predictions.append(prediction)

                if Config.test_visualize:
                    visualize_predictions(
                        numpy_reference_image=reference_image,
                        numpy_measurement_image=measurement_image,
                        numpy_predicted_depth=prediction,
                        normalization_mean=mean_rgb,
                        normalization_std=std_rgb,
                        normalization_scale=scale_rgb)

        inference_timer.print_statistics()

        system_name = "{}_{}_{}_{}_{}_deltas".format(keyframing_type,
                                                     dataset_name,
                                                     input_image_width,
                                                     input_image_height,
                                                     n_measurement_frames)

        save_results(predictions=predictions,
                     groundtruths=reference_depths,
                     system_name=system_name,
                     scene_name=scene_name,
                     save_folder=Config.test_result_folder)

Example #5
def predict():
    predict_with_finetuned = True

    if predict_with_finetuned:
        extension = "finetuned"
    else:
        extension = "without_ft"

    input_image_width = 320
    input_image_height = 256

    print("System: GPMVS, is_finetuned = ", predict_with_finetuned)

    device = torch.device('cuda')

    if predict_with_finetuned:
        encoder_weights = torch.load(Path("finetuned-weights").files("*encoder*")[0])
        gp_weights = torch.load(Path("finetuned-weights").files("*gplayer*")[0])
        decoder_weights = torch.load(Path("finetuned-weights").files("*decoder*")[0])
    else:
        encoder_weights = torch.load(Path("original-weights").files("*encoder*")[0])['state_dict']
        gp_weights = torch.load(Path("original-weights").files("*gplayer*")[0])['state_dict']
        decoder_weights = torch.load(Path("original-weights").files("*decoder*")[0])['state_dict']

    encoder = Encoder()
    encoder = torch.nn.DataParallel(encoder)
    encoder.load_state_dict(encoder_weights)
    encoder.eval()
    encoder = encoder.to(device)

    decoder = Decoder()
    decoder = torch.nn.DataParallel(decoder)

    decoder.load_state_dict(decoder_weights)
    decoder.eval()
    decoder = decoder.to(device)

    # load GP values
    gplayer = GPlayer(device=device)
    gplayer.load_state_dict(gp_weights)
    gplayer.eval()
    gamma2 = np.exp(gp_weights['gamma2'][0].item())
    ell = np.exp(gp_weights['ell'][0].item())
    sigma2 = np.exp(gp_weights['sigma2'][0].item())
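    # The GP hyper-parameters (magnitude, length-scale, noise variance) are stored in log
    # space in the checkpoint, hence the exp().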

    warp_grid = get_warp_grid_for_cost_volume_calculation(width=input_image_width,
                                                          height=input_image_height,
                                                          device=device)

    min_depth = 0.5
    max_depth = 50.0
    n_depth_levels = 64

    scale_rgb = 1.0
    mean_rgb = [81.0, 81.0, 81.0]
    std_rgb = [35.0, 35.0, 35.0]

    data_path = Path(Config.test_offline_data_path)
    if Config.test_dataset_name is None:
        keyframe_index_files = sorted((Path(Config.test_offline_data_path) / "indices").files())
    else:
        keyframe_index_files = sorted((Path(Config.test_offline_data_path) / "indices").files("*" + Config.test_dataset_name + "*"))
    for iteration, keyframe_index_file in enumerate(keyframe_index_files):
        keyframing_type, dataset_name, scene_name, _, n_measurement_frames = keyframe_index_file.split("/")[-1].split("+")

        scene_folder = data_path / dataset_name / scene_name
        print("Predicting for scene:", dataset_name + "-" + scene_name, " - ", iteration, "/", len(keyframe_index_files))

        keyframe_index_file_lines = np.loadtxt(keyframe_index_file, dtype=str, delimiter="\n")

        K = np.loadtxt(scene_folder / 'K.txt').astype(np.float32)
        poses = np.fromfile(scene_folder / "poses.txt", dtype=float, sep="\n ").reshape((-1, 4, 4))
        image_filenames = sorted((scene_folder / 'images').files("*.png"))
        depth_filenames = sorted((scene_folder / 'depth').files("*.png"))

        input_filenames = []
        for image_filename in image_filenames:
            input_filenames.append(image_filename.split("/")[-1])

        lam = np.sqrt(3) / ell
        F = np.array([[0, 1], [-lam ** 2, -2 * lam]])
        Pinf = np.array([[gamma2, 0], [0, gamma2 * lam ** 2]])
        h = np.array([[1], [0]])
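        # State-space (Kalman filter) form of the GP prior over the latent encoding:
        # lam = sqrt(3) / ell corresponds to a Matern-3/2 kernel, F is the feedback matrix,
        # Pinf the stationary covariance and h the measurement vector.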

        # State mean and covariance
        M = np.zeros((F.shape[0], 512 * 8 * 10))
        P = Pinf

        inference_timer = InferenceTimer()

        previous_index = None
        predictions = []
        reference_depths = []
        with torch.no_grad():
            for i in tqdm(range(0, len(keyframe_index_file_lines))):

                keyframe_index_file_line = keyframe_index_file_lines[i]

                if keyframe_index_file_line == "TRACKING LOST":
                    continue
                else:
                    current_input_filenames = keyframe_index_file_line.split(" ")
                    current_indices = [input_filenames.index(current_input_filenames[x]) for x in range(len(current_input_filenames))]

                reference_index = current_indices[0]
                measurement_indices = current_indices[1:]

                reference_pose = poses[reference_index]
                reference_image = load_image(image_filenames[reference_index])
                reference_depth = cv2.imread(depth_filenames[reference_index], -1).astype(float) / 1000.0

                preprocessor = PreprocessImage(K=K,
                                               old_width=reference_image.shape[1],
                                               old_height=reference_image.shape[0],
                                               new_width=input_image_width,
                                               new_height=input_image_height,
                                               distortion_crop=0,
                                               perform_crop=False)

                reference_image = preprocessor.apply_rgb(image=reference_image,
                                                         scale_rgb=scale_rgb,
                                                         mean_rgb=mean_rgb,
                                                         std_rgb=std_rgb)
                reference_depth = preprocessor.apply_depth(reference_depth)
                reference_image_torch = torch.from_numpy(np.transpose(reference_image, (2, 0, 1))).float().to(device).unsqueeze(0)
                reference_pose_torch = torch.from_numpy(reference_pose).float().to(device).unsqueeze(0)

                measurement_poses_torch = []
                measurement_images_torch = []
                for measurement_index in measurement_indices:
                    measurement_image = load_image(image_filenames[measurement_index])
                    measurement_image = preprocessor.apply_rgb(image=measurement_image,
                                                               scale_rgb=scale_rgb,
                                                               mean_rgb=mean_rgb,
                                                               std_rgb=std_rgb)
                    measurement_image_torch = torch.from_numpy(np.transpose(measurement_image, (2, 0, 1))).float().to(device).unsqueeze(0)
                    measurement_pose_torch = torch.from_numpy(poses[measurement_index]).float().to(device).unsqueeze(0)
                    measurement_images_torch.append(measurement_image_torch)
                    measurement_poses_torch.append(measurement_pose_torch)

                full_K_torch = torch.from_numpy(preprocessor.get_updated_intrinsics()).float().to(device).unsqueeze(0)

                inference_timer.record_start_time()
                cost_volume = cost_volume_fusion(image1=reference_image_torch,
                                                 image2s=measurement_images_torch,
                                                 pose1=reference_pose_torch,
                                                 pose2s=measurement_poses_torch,
                                                 K=full_K_torch,
                                                 warp_grid=warp_grid,
                                                 min_depth=min_depth,
                                                 max_depth=max_depth,
                                                 n_depth_levels=n_depth_levels,
                                                 device=device,
                                                 dot_product=False)

                conv5, conv4, conv3, conv2, conv1 = encoder(reference_image_torch, cost_volume)
                batch, channel, height, width = conv5.size()
                y = np.expand_dims(conv5.cpu().numpy().flatten(), axis=0)

                if previous_index is None:
                    previous_index = measurement_index
                dt, _, _ = pose_distance(poses[reference_index], poses[previous_index])
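                # Prediction step: discretize the continuous-time model over the
                # pose-distance "gap" dt and propagate the state mean and covariance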
                A = expm(F * dt)
                Q = Pinf - A.dot(Pinf).dot(A.T)
                M = A.dot(M)
                P = A.dot(P).dot(A.T) + Q

                # Update step
                v = y - h.T.dot(M)
                s = h.T.dot(P).dot(h) + sigma2
                k = P.dot(h) / s
                M += k.dot(v)
                P -= k.dot(h.T).dot(P)

                Z = torch.from_numpy(M[0]).view(batch, channel, height, width).float().to(device)
                Z = torch.nn.functional.relu(Z)

                prediction, _, _, _ = decoder(Z, conv4, conv3, conv2, conv1)
                prediction = torch.clamp(prediction, min=0.02, max=2.0)
                prediction = 1 / prediction

                inference_timer.record_end_time_and_elapsed_time()

                prediction = prediction.cpu().numpy().squeeze()
                previous_index = deepcopy(reference_index)

                reference_depths.append(reference_depth)
                predictions.append(prediction)

                if Config.test_visualize:
                    visualize_predictions(numpy_reference_image=reference_image,
                                          numpy_measurement_image=measurement_image,
                                          numpy_predicted_depth=prediction,
                                          normalization_mean=mean_rgb,
                                          normalization_std=std_rgb,
                                          normalization_scale=scale_rgb)

        inference_timer.print_statistics()

        system_name = "{}_{}_{}_{}_{}_gpmvs_{}".format(keyframing_type,
                                                       dataset_name,
                                                       input_image_width,
                                                       input_image_height,
                                                       n_measurement_frames,
                                                       extension)

        save_results(predictions=predictions,
                     groundtruths=reference_depths,
                     system_name=system_name,
                     scene_name=scene_name,
                     save_folder=Config.test_result_folder)

Example #6
def predict():
    predict_with_finetuned = True

    if predict_with_finetuned:
        extension = "finetuned"
    else:
        extension = "without_ft"

    input_image_width = 320
    input_image_height = 256

    print("System: MVDEPTHNET, is_finetuned = ", predict_with_finetuned)

    device = torch.device('cuda')
    encoder = Encoder()
    decoder = Decoder()

    if predict_with_finetuned:
        encoder_weights = torch.load(
            Path("finetuned-weights").files("*encoder*")[0])
        decoder_weights = torch.load(
            Path("finetuned-weights").files("*decoder*")[0])
    else:
        mvdepth_weights = torch.load(
            Path("original-weights") / "pretrained_mvdepthnet_combined")
        pretrained_dict = mvdepth_weights['state_dict']
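        # Split the combined pretrained MVDepthNet checkpoint into encoder and decoder
        # state dicts by matching parameter names.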
        encoder_weights = encoder.state_dict()
        pretrained_dict_encoder = {
            k: v
            for k, v in pretrained_dict.items() if k in encoder_weights
        }
        encoder_weights.update(pretrained_dict_encoder)
        decoder_weights = decoder.state_dict()
        pretrained_dict_decoder = {
            k: v
            for k, v in pretrained_dict.items() if k in decoder_weights
        }
        decoder_weights.update(pretrained_dict_decoder)

    encoder.load_state_dict(encoder_weights)
    decoder.load_state_dict(decoder_weights)

    encoder = encoder.to(device)
    decoder = decoder.to(device)

    encoder.eval()
    decoder.eval()

    warp_grid = get_warp_grid_for_cost_volume_calculation(
        width=input_image_width, height=input_image_height, device=device)

    min_depth = 0.5
    max_depth = 50.0
    n_depth_levels = 64

    scale_rgb = 1.0
    mean_rgb = [81.0, 81.0, 81.0]
    std_rgb = [35.0, 35.0, 35.0]

    data_path = Path(Config.test_offline_data_path)
    if Config.test_dataset_name is None:
        keyframe_index_files = sorted(
            (Path(Config.test_offline_data_path) / "indices").files())
    else:
        keyframe_index_files = sorted(
            (Path(Config.test_offline_data_path) /
             "indices").files("*" + Config.test_dataset_name + "*"))
    for iteration, keyframe_index_file in enumerate(keyframe_index_files):
        keyframing_type, dataset_name, scene_name, _, n_measurement_frames = keyframe_index_file.split(
            "/")[-1].split("+")

        scene_folder = data_path / dataset_name / scene_name
        print("Predicting for scene:", dataset_name + "-" + scene_name, " - ",
              iteration, "/", len(keyframe_index_files))

        keyframe_index_file_lines = np.loadtxt(keyframe_index_file,
                                               dtype=str,
                                               delimiter="\n")

        K = np.loadtxt(scene_folder / 'K.txt').astype(np.float32)
        poses = np.fromfile(scene_folder / "poses.txt", dtype=float,
                            sep="\n ").reshape((-1, 4, 4))
        image_filenames = sorted((scene_folder / 'images').files("*.png"))
        depth_filenames = sorted((scene_folder / 'depth').files("*.png"))

        input_filenames = []
        for image_filename in image_filenames:
            input_filenames.append(image_filename.split("/")[-1])

        inference_timer = InferenceTimer()

        predictions = []
        reference_depths = []
        with torch.no_grad():
            for i in tqdm(range(0, len(keyframe_index_file_lines))):

                keyframe_index_file_line = keyframe_index_file_lines[i]

                if keyframe_index_file_line == "TRACKING LOST":
                    continue
                else:
                    current_input_filenames = keyframe_index_file_line.split(
                        " ")
                    current_indices = [
                        input_filenames.index(current_input_filenames[x])
                        for x in range(len(current_input_filenames))
                    ]

                reference_index = current_indices[0]
                measurement_indices = current_indices[1:]

                reference_pose = poses[reference_index]
                reference_image = load_image(image_filenames[reference_index])
                reference_depth = cv2.imread(depth_filenames[reference_index],
                                             -1).astype(float) / 1000.0

                preprocessor = PreprocessImage(
                    K=K,
                    old_width=reference_image.shape[1],
                    old_height=reference_image.shape[0],
                    new_width=input_image_width,
                    new_height=input_image_height,
                    distortion_crop=0,
                    perform_crop=False)

                reference_image = preprocessor.apply_rgb(image=reference_image,
                                                         scale_rgb=scale_rgb,
                                                         mean_rgb=mean_rgb,
                                                         std_rgb=std_rgb)
                reference_depth = preprocessor.apply_depth(reference_depth)
                reference_image_torch = torch.from_numpy(
                    np.transpose(reference_image,
                                 (2, 0, 1))).float().to(device).unsqueeze(0)
                reference_pose_torch = torch.from_numpy(
                    reference_pose).float().to(device).unsqueeze(0)

                measurement_poses_torch = []
                measurement_images_torch = []
                for measurement_index in measurement_indices:
                    measurement_image = load_image(
                        image_filenames[measurement_index])
                    measurement_image = preprocessor.apply_rgb(
                        image=measurement_image,
                        scale_rgb=scale_rgb,
                        mean_rgb=mean_rgb,
                        std_rgb=std_rgb)
                    measurement_image_torch = torch.from_numpy(
                        np.transpose(
                            measurement_image,
                            (2, 0, 1))).float().to(device).unsqueeze(0)
                    measurement_pose_torch = torch.from_numpy(
                        poses[measurement_index]).float().to(device).unsqueeze(
                            0)
                    measurement_images_torch.append(measurement_image_torch)
                    measurement_poses_torch.append(measurement_pose_torch)

                full_K_torch = torch.from_numpy(
                    preprocessor.get_updated_intrinsics()).float().to(
                        device).unsqueeze(0)

                inference_timer.record_start_time()
                cost_volume = cost_volume_fusion(
                    image1=reference_image_torch,
                    image2s=measurement_images_torch,
                    pose1=reference_pose_torch,
                    pose2s=measurement_poses_torch,
                    K=full_K_torch,
                    warp_grid=warp_grid,
                    min_depth=min_depth,
                    max_depth=max_depth,
                    n_depth_levels=n_depth_levels,
                    device=device,
                    dot_product=False)

                conv5, conv4, conv3, conv2, conv1 = encoder(
                    reference_image_torch, cost_volume)
                prediction, _, _, _ = decoder(conv5, conv4, conv3, conv2,
                                              conv1)
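                # The decoder outputs inverse depth; clamping to [0.02, 2.0] and inverting
                # maps the prediction to the [0.5 m, 50 m] range defined above.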
                prediction = torch.clamp(prediction, min=0.02, max=2.0)
                prediction = 1 / prediction

                inference_timer.record_end_time_and_elapsed_time()

                prediction = prediction.cpu().numpy().squeeze()
                reference_depths.append(reference_depth)
                predictions.append(prediction)

                if Config.test_visualize:
                    visualize_predictions(
                        numpy_reference_image=reference_image,
                        numpy_measurement_image=measurement_image,
                        numpy_predicted_depth=prediction,
                        normalization_mean=mean_rgb,
                        normalization_std=std_rgb,
                        normalization_scale=scale_rgb)

        inference_timer.print_statistics()

        system_name = "{}_{}_{}_{}_{}_mvdepthnet_{}".format(
            keyframing_type, dataset_name, input_image_width,
            input_image_height, n_measurement_frames, extension)

        save_results(predictions=predictions,
                     groundtruths=reference_depths,
                     system_name=system_name,
                     scene_name=scene_name,
                     save_folder=Config.test_result_folder)