def predict():
    scene_name = Config.test_online_scene_path.split("/")[-1]
    system_name = "{}_{}_{}_{}_dvmvs_fusionnet_online".format(
        scene_name, Config.test_image_width, Config.test_image_height, Config.test_n_measurement_frames)

    print("Predicting with System:", system_name)
    print("# of Measurement Frames:", Config.test_n_measurement_frames)

    device = torch.device("cuda")

    feature_extractor = FeatureExtractor()
    feature_shrinker = FeatureShrinker()
    cost_volume_encoder = CostVolumeEncoder()
    lstm_fusion = LSTMFusion()
    cost_volume_decoder = CostVolumeDecoder()

    feature_extractor = feature_extractor.to(device)
    feature_shrinker = feature_shrinker.to(device)
    cost_volume_encoder = cost_volume_encoder.to(device)
    lstm_fusion = lstm_fusion.to(device)
    cost_volume_decoder = cost_volume_decoder.to(device)

    model = [feature_extractor, feature_shrinker, cost_volume_encoder, lstm_fusion, cost_volume_decoder]

    for i in range(len(model)):
        try:
            checkpoint = sorted(Path("weights").files())[i]
            weights = torch.load(checkpoint)
            model[i].load_state_dict(weights)
            model[i].eval()
            print("Loaded weights for", checkpoint)
        except Exception as e:
            print(e)
            print("Could not find the checkpoint for module", i)
            exit(1)

    feature_extractor = model[0]
    feature_shrinker = model[1]
    cost_volume_encoder = model[2]
    lstm_fusion = model[3]
    cost_volume_decoder = model[4]

    warp_grid = get_warp_grid_for_cost_volume_calculation(width=int(Config.test_image_width / 2),
                                                          height=int(Config.test_image_height / 2),
                                                          device=device)

    scale_rgb = 255.0
    mean_rgb = [0.485, 0.456, 0.406]
    std_rgb = [0.229, 0.224, 0.225]

    min_depth = 0.25
    max_depth = 20.0
    n_depth_levels = 64

    scene_folder = Path(Config.test_online_scene_path)
    scene = scene_folder.split("/")[-1]
    print("Predicting for scene:", scene)

    keyframe_buffer = KeyframeBuffer(buffer_size=Config.test_keyframe_buffer_size,
                                     keyframe_pose_distance=Config.test_keyframe_pose_distance,
                                     optimal_t_score=Config.test_optimal_t_measure,
                                     optimal_R_score=Config.test_optimal_R_measure,
                                     store_return_indices=False)

    K = np.loadtxt(scene_folder / 'K.txt').astype(np.float32)
    poses = np.fromfile(scene_folder / "poses.txt", dtype=float, sep="\n ").reshape((-1, 4, 4))
    image_filenames = sorted((scene_folder / 'images').files("*.png"))
    depth_filenames = sorted((scene_folder / 'depth').files("*.png"))

    inference_timer = InferenceTimer()

    lstm_state = None
    previous_depth = None
    previous_pose = None

    predictions = []
    reference_depths = []
    with torch.no_grad():
        for i in tqdm(range(0, len(poses))):
            reference_pose = poses[i]
            reference_image = load_image(image_filenames[i])
            reference_depth = cv2.imread(depth_filenames[i], -1).astype(float) / 1000.0

            # POLL THE KEYFRAME BUFFER
            response = keyframe_buffer.try_new_keyframe(reference_pose, reference_image)
            if response == 0 or response == 2 or response == 4 or response == 5:
                continue
            elif response == 3:
                # the buffer signals a break in the sequence: reset the temporal state and skip this frame
                previous_depth = None
                previous_pose = None
                lstm_state = None
                continue

            preprocessor = PreprocessImage(K=K,
                                           old_width=reference_image.shape[1], old_height=reference_image.shape[0],
                                           new_width=Config.test_image_width, new_height=Config.test_image_height,
                                           distortion_crop=Config.test_distortion_crop,
                                           perform_crop=Config.test_perform_crop)

            reference_image = preprocessor.apply_rgb(image=reference_image, scale_rgb=scale_rgb,
                                                     mean_rgb=mean_rgb, std_rgb=std_rgb)
            reference_depth = preprocessor.apply_depth(reference_depth)
            reference_image_torch = torch.from_numpy(np.transpose(reference_image, (2, 0, 1))).float().to(device).unsqueeze(0)
            reference_pose_torch = torch.from_numpy(reference_pose).float().to(device).unsqueeze(0)

            full_K_torch = torch.from_numpy(preprocessor.get_updated_intrinsics()).float().to(device).unsqueeze(0)

            half_K_torch = full_K_torch.clone().cuda()
            half_K_torch[:, 0:2, :] = half_K_torch[:, 0:2, :] / 2.0

            lstm_K_bottom = full_K_torch.clone().cuda()
            lstm_K_bottom[:, 0:2, :] = lstm_K_bottom[:, 0:2, :] / 32.0

            measurement_poses_torch = []
            measurement_images_torch = []
            measurement_frames = keyframe_buffer.get_best_measurement_frames(Config.test_n_measurement_frames)
            for (measurement_pose, measurement_image) in measurement_frames:
                measurement_image = preprocessor.apply_rgb(image=measurement_image, scale_rgb=scale_rgb,
                                                           mean_rgb=mean_rgb, std_rgb=std_rgb)
                measurement_image_torch = torch.from_numpy(np.transpose(measurement_image, (2, 0, 1))).float().to(device).unsqueeze(0)
                measurement_pose_torch = torch.from_numpy(measurement_pose).float().to(device).unsqueeze(0)
                measurement_images_torch.append(measurement_image_torch)
                measurement_poses_torch.append(measurement_pose_torch)

            inference_timer.record_start_time()

            measurement_feature_halfs = []
            for measurement_image_torch in measurement_images_torch:
                measurement_feature_half, _, _, _ = feature_shrinker(*feature_extractor(measurement_image_torch))
                measurement_feature_halfs.append(measurement_feature_half)

            reference_feature_half, reference_feature_quarter, \
                reference_feature_one_eight, reference_feature_one_sixteen = feature_shrinker(*feature_extractor(reference_image_torch))

            cost_volume = cost_volume_fusion(image1=reference_feature_half, image2s=measurement_feature_halfs,
                                             pose1=reference_pose_torch, pose2s=measurement_poses_torch,
                                             K=half_K_torch, warp_grid=warp_grid,
                                             min_depth=min_depth, max_depth=max_depth, n_depth_levels=n_depth_levels,
                                             device=device, dot_product=True)

            skip0, skip1, skip2, skip3, bottom = cost_volume_encoder(features_half=reference_feature_half,
                                                                     features_quarter=reference_feature_quarter,
                                                                     features_one_eight=reference_feature_one_eight,
                                                                     features_one_sixteen=reference_feature_one_sixteen,
                                                                     cost_volume=cost_volume)

            if previous_depth is not None:
                depth_estimation = get_non_differentiable_rectangle_depth_estimation(reference_pose_torch=reference_pose_torch,
                                                                                     measurement_pose_torch=previous_pose,
                                                                                     previous_depth_torch=previous_depth,
                                                                                     full_K_torch=full_K_torch,
                                                                                     half_K_torch=half_K_torch,
                                                                                     original_height=Config.test_image_height,
                                                                                     original_width=Config.test_image_width)
                depth_estimation = torch.nn.functional.interpolate(input=depth_estimation,
                                                                   scale_factor=(1.0 / 16.0), mode="nearest")
            else:
                depth_estimation = torch.zeros(size=(1, 1,
                                                     int(Config.test_image_height / 32.0),
                                                     int(Config.test_image_width / 32.0))).to(device)

            lstm_state = lstm_fusion(current_encoding=bottom,
                                     current_state=lstm_state,
                                     previous_pose=previous_pose,
                                     current_pose=reference_pose_torch,
                                     estimated_current_depth=depth_estimation,
                                     camera_matrix=lstm_K_bottom)

            prediction, _, _, _, _ = cost_volume_decoder(reference_image_torch, skip0, skip1, skip2, skip3, lstm_state[0])
            previous_depth = prediction.view(1, 1, Config.test_image_height, Config.test_image_width)
            previous_pose = reference_pose_torch

            inference_timer.record_end_time_and_elapsed_time()

            prediction = prediction.cpu().numpy().squeeze()
            reference_depths.append(reference_depth)
            predictions.append(prediction)

            if Config.test_visualize:
                visualize_predictions(numpy_reference_image=reference_image,
                                      numpy_measurement_image=measurement_image,
                                      numpy_predicted_depth=prediction,
                                      normalization_mean=mean_rgb, normalization_std=std_rgb,
                                      normalization_scale=scale_rgb,
                                      depth_multiplier_for_visualization=5000)

    inference_timer.print_statistics()

    save_results(predictions=predictions, groundtruths=reference_depths,
                 system_name=system_name, scene_name=scene, save_folder=".")
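# Assumed entry point (not part of the excerpt above): each of these evaluation
# scripts is typically run on its own and simply calls its predict() function.
if __name__ == '__main__':
    predict()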
def predict(): print("System: PAIRNET") device = torch.device("cuda") feature_extractor = FeatureExtractor() feature_shrinker = FeatureShrinker() cost_volume_encoder = CostVolumeEncoder() cost_volume_decoder = CostVolumeDecoder() feature_extractor = feature_extractor.to(device) feature_shrinker = feature_shrinker.to(device) cost_volume_encoder = cost_volume_encoder.to(device) cost_volume_decoder = cost_volume_decoder.to(device) model = [ feature_extractor, feature_shrinker, cost_volume_encoder, cost_volume_decoder ] for i in range(len(model)): try: checkpoint = sorted(Path("weights").files())[i] weights = torch.load(checkpoint) model[i].load_state_dict(weights) model[i].eval() print("Loaded weights for", checkpoint) except Exception as e: print(e) print("Could not find the checkpoint for module", i) exit(1) feature_extractor = model[0] feature_shrinker = model[1] cost_volume_encoder = model[2] cost_volume_decoder = model[3] warp_grid = get_warp_grid_for_cost_volume_calculation( width=int(Config.test_image_width / 2), height=int(Config.test_image_height / 2), device=device) scale_rgb = 255.0 mean_rgb = [0.485, 0.456, 0.406] std_rgb = [0.229, 0.224, 0.225] min_depth = 0.25 max_depth = 20.0 n_depth_levels = 64 data_path = Path(Config.test_offline_data_path) if Config.test_dataset_name is None: keyframe_index_files = sorted( (Path(Config.test_offline_data_path) / "indices").files()) else: keyframe_index_files = sorted( (Path(Config.test_offline_data_path) / "indices").files("*" + Config.test_dataset_name + "*")) for iteration, keyframe_index_file in enumerate(keyframe_index_files): keyframing_type, dataset_name, scene_name, _, n_measurement_frames = keyframe_index_file.split( "/")[-1].split("+") scene_folder = data_path / dataset_name / scene_name print("Predicting for scene:", dataset_name + "-" + scene_name, " - ", iteration, "/", len(keyframe_index_files)) keyframe_index_file_lines = np.loadtxt(keyframe_index_file, dtype=str, delimiter="\n") K = np.loadtxt(scene_folder / 'K.txt').astype(np.float32) poses = np.fromfile(scene_folder / "poses.txt", dtype=float, sep="\n ").reshape((-1, 4, 4)) image_filenames = sorted((scene_folder / 'images').files("*.png")) depth_filenames = sorted((scene_folder / 'depth').files("*.png")) input_filenames = [] for image_filename in image_filenames: input_filenames.append(image_filename.split("/")[-1]) inference_timer = InferenceTimer() predictions = [] reference_depths = [] with torch.no_grad(): for i in tqdm(range(0, len(keyframe_index_file_lines))): keyframe_index_file_line = keyframe_index_file_lines[i] if keyframe_index_file_line == "TRACKING LOST": continue else: current_input_filenames = keyframe_index_file_line.split( " ") current_indices = [ input_filenames.index(current_input_filenames[x]) for x in range(len(current_input_filenames)) ] reference_index = current_indices[0] measurement_indices = current_indices[1:] reference_pose = poses[reference_index] reference_image = load_image(image_filenames[reference_index]) reference_depth = cv2.imread(depth_filenames[reference_index], -1).astype(float) / 1000.0 preprocessor = PreprocessImage( K=K, old_width=reference_image.shape[1], old_height=reference_image.shape[0], new_width=Config.test_image_width, new_height=Config.test_image_height, distortion_crop=Config.test_distortion_crop, perform_crop=Config.test_perform_crop) reference_image = preprocessor.apply_rgb(image=reference_image, scale_rgb=scale_rgb, mean_rgb=mean_rgb, std_rgb=std_rgb) reference_depth = preprocessor.apply_depth(reference_depth) 
reference_image_torch = torch.from_numpy( np.transpose(reference_image, (2, 0, 1))).float().to(device).unsqueeze(0) reference_pose_torch = torch.from_numpy( reference_pose).float().to(device).unsqueeze(0) measurement_poses_torch = [] measurement_images_torch = [] for measurement_index in measurement_indices: measurement_image = load_image( image_filenames[measurement_index]) measurement_image = preprocessor.apply_rgb( image=measurement_image, scale_rgb=scale_rgb, mean_rgb=mean_rgb, std_rgb=std_rgb) measurement_image_torch = torch.from_numpy( np.transpose( measurement_image, (2, 0, 1))).float().to(device).unsqueeze(0) measurement_pose_torch = torch.from_numpy( poses[measurement_index]).float().to(device).unsqueeze( 0) measurement_images_torch.append(measurement_image_torch) measurement_poses_torch.append(measurement_pose_torch) full_K_torch = torch.from_numpy( preprocessor.get_updated_intrinsics()).float().to( device).unsqueeze(0) half_K_torch = full_K_torch.clone().cuda() half_K_torch[:, 0:2, :] = half_K_torch[:, 0:2, :] / 2.0 inference_timer.record_start_time() measurement_feature_halfs = [] for measurement_image_torch in measurement_images_torch: measurement_feature_half, _, _, _ = feature_shrinker( *feature_extractor(measurement_image_torch)) measurement_feature_halfs.append(measurement_feature_half) reference_feature_half, reference_feature_quarter, \ reference_feature_one_eight, reference_feature_one_sixteen = feature_shrinker(*feature_extractor(reference_image_torch)) cost_volume = cost_volume_fusion( image1=reference_feature_half, image2s=measurement_feature_halfs, pose1=reference_pose_torch, pose2s=measurement_poses_torch, K=half_K_torch, warp_grid=warp_grid, min_depth=min_depth, max_depth=max_depth, n_depth_levels=n_depth_levels, device=device, dot_product=True) skip0, skip1, skip2, skip3, bottom = cost_volume_encoder( features_half=reference_feature_half, features_quarter=reference_feature_quarter, features_one_eight=reference_feature_one_eight, features_one_sixteen=reference_feature_one_sixteen, cost_volume=cost_volume) prediction, _, _, _, _ = cost_volume_decoder( reference_image_torch, skip0, skip1, skip2, skip3, bottom) inference_timer.record_end_time_and_elapsed_time() prediction = prediction.cpu().numpy().squeeze() reference_depths.append(reference_depth) predictions.append(prediction) if Config.test_visualize: visualize_predictions( numpy_reference_image=reference_image, numpy_measurement_image=measurement_image, numpy_predicted_depth=prediction, normalization_mean=mean_rgb, normalization_std=std_rgb, normalization_scale=scale_rgb) inference_timer.print_statistics() system_name = "{}_{}_{}_{}_{}_dvmvs_pairnet".format( keyframing_type, dataset_name, Config.test_image_width, Config.test_image_height, n_measurement_frames) save_results(predictions=predictions, groundtruths=reference_depths, system_name=system_name, scene_name=scene_name, save_folder=Config.test_result_folder)
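# The PAIRNET pipeline above delegates matching to cost_volume_fusion(). The function
# below is only a simplified, self-contained sketch of a dot-product plane-sweep cost
# volume for a single measurement frame, to illustrate the idea; it is not the
# repository's implementation (which uses a precomputed warp_grid and fuses several
# measurement frames). It assumes 4x4 camera-to-world poses and inverse-depth-spaced
# hypotheses between min_depth and max_depth.
import torch


def plane_sweep_dot_product_cost_volume(ref_feat, meas_feat, ref_pose, meas_pose, K,
                                        min_depth=0.25, max_depth=20.0, n_depth_levels=64):
    # ref_feat, meas_feat: (1, C, H, W) feature maps; ref_pose, meas_pose: (4, 4); K: (3, 3)
    _, _, height, width = ref_feat.shape
    device = ref_feat.device

    # Transform taking points from the reference camera into the measurement camera.
    relative = torch.linalg.inv(meas_pose) @ ref_pose
    rotation, translation = relative[:3, :3], relative[:3, 3:]

    # Back-project every reference pixel to a unit-depth ray.
    ys, xs = torch.meshgrid(torch.arange(height, dtype=torch.float32, device=device),
                            torch.arange(width, dtype=torch.float32, device=device),
                            indexing="ij")
    pixels = torch.stack([xs, ys, torch.ones_like(xs)], dim=0).reshape(3, -1)
    rays = torch.linalg.inv(K) @ pixels

    inverse_depths = torch.linspace(1.0 / max_depth, 1.0 / min_depth, n_depth_levels, device=device)
    cost_volume = torch.empty(1, n_depth_levels, height, width, device=device)
    for level, inverse_depth in enumerate(inverse_depths):
        # Place the rays at the hypothesized depth and project them into the measurement view.
        points = rotation @ (rays / inverse_depth) + translation
        projected = K @ points
        uv = projected[:2] / projected[2:].clamp(min=1e-6)
        # Normalize pixel coordinates to [-1, 1] for grid_sample.
        grid = torch.stack([2.0 * uv[0] / (width - 1) - 1.0,
                            2.0 * uv[1] / (height - 1) - 1.0], dim=-1).view(1, height, width, 2)
        warped = torch.nn.functional.grid_sample(meas_feat, grid, mode="bilinear",
                                                 padding_mode="zeros", align_corners=True)
        # Per-pixel similarity: channel-averaged dot product of reference and warped features.
        cost_volume[:, level] = (ref_feat * warped).mean(dim=1)
    return cost_volume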
def predict():
    predict_with_finetuned = True

    if predict_with_finetuned:
        extension = "finetuned"
    else:
        extension = "without_ft"

    input_image_width = 320
    input_image_height = 240

    print("System: DPSNET, is_finetuned = ", predict_with_finetuned)

    device = torch.device('cuda')

    dpsnet = PSNet(64, 0.5)
    if predict_with_finetuned:
        weights = torch.load(Path("finetuned-weights").files("*dpsnet*")[0])
    else:
        weights = torch.load(Path("original-weights").files("*dpsnet*")[0])['state_dict']
    dpsnet.load_state_dict(weights)
    dpsnet = dpsnet.to(device)
    dpsnet.eval()

    scale_rgb = 255.0
    mean_rgb = [0.5, 0.5, 0.5]
    std_rgb = [0.5, 0.5, 0.5]

    data_path = Path(Config.test_offline_data_path)
    if Config.test_dataset_name is None:
        keyframe_index_files = sorted((Path(Config.test_offline_data_path) / "indices").files())
    else:
        keyframe_index_files = sorted((Path(Config.test_offline_data_path) / "indices").files("*" + Config.test_dataset_name + "*"))

    for iteration, keyframe_index_file in enumerate(keyframe_index_files):
        keyframing_type, dataset_name, scene_name, _, n_measurement_frames = keyframe_index_file.split("/")[-1].split("+")

        scene_folder = data_path / dataset_name / scene_name
        print("Predicting for scene:", dataset_name + "-" + scene_name, " - ", iteration, "/", len(keyframe_index_files))

        keyframe_index_file_lines = np.loadtxt(keyframe_index_file, dtype=str, delimiter="\n")

        K = np.loadtxt(scene_folder / 'K.txt').astype(np.float32)
        poses = np.fromfile(scene_folder / "poses.txt", dtype=float, sep="\n ").reshape((-1, 4, 4))
        image_filenames = sorted((scene_folder / 'images').files("*.png"))
        depth_filenames = sorted((scene_folder / 'depth').files("*.png"))

        input_filenames = []
        for image_filename in image_filenames:
            input_filenames.append(image_filename.split("/")[-1])

        inference_timer = InferenceTimer()

        predictions = []
        reference_depths = []
        with torch.no_grad():
            for i in tqdm(range(0, len(keyframe_index_file_lines))):
                keyframe_index_file_line = keyframe_index_file_lines[i]

                if keyframe_index_file_line == "TRACKING LOST":
                    continue
                else:
                    current_input_filenames = keyframe_index_file_line.split(" ")
                    current_indices = [input_filenames.index(current_input_filenames[x]) for x in range(len(current_input_filenames))]
                    reference_index = current_indices[0]
                    measurement_indices = current_indices[1:]

                reference_pose = poses[reference_index]
                reference_image = load_image(image_filenames[reference_index])
                reference_depth = cv2.imread(depth_filenames[reference_index], -1).astype(float) / 1000.0

                preprocessor = PreprocessImage(K=K,
                                               old_width=reference_image.shape[1], old_height=reference_image.shape[0],
                                               new_width=input_image_width, new_height=input_image_height,
                                               distortion_crop=0, perform_crop=False)

                reference_image = preprocessor.apply_rgb(image=reference_image, scale_rgb=scale_rgb,
                                                         mean_rgb=mean_rgb, std_rgb=std_rgb)
                reference_depth = preprocessor.apply_depth(reference_depth)

                reference_image_torch = torch.from_numpy(np.transpose(reference_image, (2, 0, 1))).float().to(device).unsqueeze(0)

                measurement_poses_torch = []
                measurement_images_torch = []
                for measurement_index in measurement_indices:
                    measurement_image = load_image(image_filenames[measurement_index])
                    measurement_image = preprocessor.apply_rgb(image=measurement_image, scale_rgb=scale_rgb,
                                                               mean_rgb=mean_rgb, std_rgb=std_rgb)
                    measurement_image_torch = torch.from_numpy(np.transpose(measurement_image, (2, 0, 1))).float().to(device).unsqueeze(0)

                    measurement_pose = poses[measurement_index]
                    measurement_pose = (np.linalg.inv(measurement_pose) @ reference_pose)[0:3, :]
                    measurement_pose_torch = torch.from_numpy(measurement_pose).float().to(device).unsqueeze(0)

                    measurement_poses_torch.append(measurement_pose_torch)
                    measurement_images_torch.append(measurement_image_torch)

                camera_k = preprocessor.get_updated_intrinsics()
                camera_k_inv = np.linalg.inv(camera_k)
                camera_k_torch = torch.from_numpy(camera_k).float().to(device).unsqueeze(0)
                camera_k_inv_torch = torch.from_numpy(camera_k_inv).float().to(device).unsqueeze(0)

                inference_timer.record_start_time()

                _, prediction = dpsnet(reference_image_torch, measurement_images_torch, measurement_poses_torch,
                                       camera_k_torch, camera_k_inv_torch)

                inference_timer.record_end_time_and_elapsed_time()

                prediction = prediction.cpu().numpy().squeeze()
                reference_depths.append(reference_depth)
                predictions.append(prediction)

                if Config.test_visualize:
                    visualize_predictions(numpy_reference_image=reference_image,
                                          numpy_measurement_image=measurement_image,
                                          numpy_predicted_depth=prediction,
                                          normalization_mean=mean_rgb, normalization_std=std_rgb,
                                          normalization_scale=scale_rgb)

        inference_timer.print_statistics()

        system_name = "{}_{}_{}_{}_{}_dpsnet_{}".format(keyframing_type, dataset_name,
                                                        input_image_width, input_image_height,
                                                        n_measurement_frames, extension)
        save_results(predictions=predictions, groundtruths=reference_depths,
                     system_name=system_name, scene_name=scene_name,
                     save_folder=Config.test_result_folder)
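# Assuming the poses in poses.txt are camera-to-world, the relative pose handed to
# DPSNet above maps reference-camera coordinates into the measurement camera and is
# truncated to 3x4. A tiny NumPy illustration with hypothetical poses:
import numpy as np

reference_cam_to_world = np.eye(4)
measurement_cam_to_world = np.eye(4)
measurement_cam_to_world[0, 3] = 0.1  # pretend the measurement camera is shifted 10 cm along x

# X_world = ref_pose @ X_ref, then X_meas = inv(meas_pose) @ X_world
relative_pose = np.linalg.inv(measurement_cam_to_world) @ reference_cam_to_world
print(relative_pose[0:3, :].shape)  # (3, 4), the shape passed to dpsnet()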
def predict(): print("System: DELTAS") device = torch.device('cuda') cudnn.benchmark = True args, supernet, trinet, depthnet = get_model() supernet.eval() trinet.eval() depthnet.eval() scale_rgb = 255.0 mean_rgb = [0.5, 0.5, 0.5] std_rgb = [0.5, 0.5, 0.5] dummy_input = torch.empty(size=(1, input_image_height, input_image_width), dtype=torch.float).to(device) data_path = Path(Config.test_offline_data_path) if Config.test_dataset_name is None: keyframe_index_files = sorted( (Path(Config.test_offline_data_path) / "indices").files( "*nmeas+{}*".format(n_measurement_frames))) else: keyframe_index_files = sorted( (Path(Config.test_offline_data_path) / "indices").files("*" + Config.test_dataset_name + "*nmeas+{}*".format(n_measurement_frames))) for iteration, keyframe_index_file in enumerate(keyframe_index_files): keyframing_type, dataset_name, scene_name, _, _ = keyframe_index_file.split( "/")[-1].split("+") scene_folder = data_path / dataset_name / scene_name print("Predicting for scene:", dataset_name + "-" + scene_name, " - ", iteration, "/", len(keyframe_index_files)) keyframe_index_file_lines = np.loadtxt(keyframe_index_file, dtype=str, delimiter="\n") K = np.loadtxt(scene_folder / 'K.txt').astype(np.float32) poses = np.fromfile(scene_folder / "poses.txt", dtype=float, sep="\n ").reshape((-1, 4, 4)) image_filenames = sorted((scene_folder / 'images').files("*.png")) depth_filenames = sorted((scene_folder / 'depth').files("*.png")) input_filenames = [] for image_filename in image_filenames: input_filenames.append(image_filename.split("/")[-1]) inference_timer = InferenceTimer() predictions = [] reference_depths = [] with torch.no_grad(): for i in tqdm(range(0, len(keyframe_index_file_lines))): keyframe_index_file_line = keyframe_index_file_lines[i] if keyframe_index_file_line == "TRACKING LOST": continue else: current_input_filenames = keyframe_index_file_line.split( " ") current_indices = [ input_filenames.index(current_input_filenames[x]) for x in range(len(current_input_filenames)) ] reference_index = current_indices[0] measurement_indices = current_indices[1:] reference_pose = poses[reference_index] reference_image = load_image(image_filenames[reference_index]) reference_depth = cv2.imread(depth_filenames[reference_index], -1).astype(float) / 1000.0 preprocessor = PreprocessImage( K=K, old_width=reference_image.shape[1], old_height=reference_image.shape[0], new_width=input_image_width, new_height=input_image_height, distortion_crop=0, perform_crop=False) reference_image = preprocessor.apply_rgb(image=reference_image, scale_rgb=scale_rgb, mean_rgb=mean_rgb, std_rgb=std_rgb) reference_depth = preprocessor.apply_depth(reference_depth) reference_image_torch = torch.from_numpy( np.transpose(reference_image, (2, 0, 1))).float().to(device).unsqueeze(0) # DELTAS ALWAYS REQUIRE A PREDETERMINED NUMBER OF MEASUREMENT FRAMES, SO FAKE IT while len(measurement_indices) < n_measurement_frames: measurement_indices.append(measurement_indices[0]) measurement_poses_torch = [] measurement_images_torch = [] for measurement_index in measurement_indices: measurement_image = load_image( image_filenames[measurement_index]) measurement_image = preprocessor.apply_rgb( image=measurement_image, scale_rgb=scale_rgb, mean_rgb=mean_rgb, std_rgb=std_rgb) measurement_image_torch = torch.from_numpy( np.transpose( measurement_image, (2, 0, 1))).float().to(device).unsqueeze(0) measurement_pose = poses[measurement_index] measurement_pose = ( np.linalg.inv(measurement_pose) @ reference_pose) measurement_pose_torch = 
torch.from_numpy( measurement_pose).float().to(device).unsqueeze( 0).unsqueeze(0) measurement_poses_torch.append(measurement_pose_torch) measurement_images_torch.append(measurement_image_torch) K_torch = torch.from_numpy( preprocessor.get_updated_intrinsics()).float().to( device).unsqueeze(0) tgt_depth = dummy_input ref_depths = [dummy_input for _ in range(n_measurement_frames)] inference_timer.record_start_time() prediction = predict_for_subsequence( args, supernet, trinet, depthnet, tgt_img=reference_image_torch, tgt_depth=tgt_depth, ref_imgs=measurement_images_torch, ref_depths=ref_depths, poses=measurement_poses_torch, intrinsics=K_torch) inference_timer.record_end_time_and_elapsed_time() prediction = prediction.cpu().numpy().squeeze() reference_depths.append(reference_depth) predictions.append(prediction) if Config.test_visualize: visualize_predictions( numpy_reference_image=reference_image, numpy_measurement_image=measurement_image, numpy_predicted_depth=prediction, normalization_mean=mean_rgb, normalization_std=std_rgb, normalization_scale=scale_rgb) inference_timer.print_statistics() system_name = "{}_{}_{}_{}_{}_deltas".format(keyframing_type, dataset_name, input_image_width, input_image_height, n_measurement_frames) save_results(predictions=predictions, groundtruths=reference_depths, system_name=system_name, scene_name=scene_name, save_folder=Config.test_result_folder)
def predict():
    predict_with_finetuned = True

    if predict_with_finetuned:
        extension = "finetuned"
    else:
        extension = "without_ft"

    input_image_width = 320
    input_image_height = 256

    print("System: GPMVS, is_finetuned = ", predict_with_finetuned)

    device = torch.device('cuda')

    if predict_with_finetuned:
        encoder_weights = torch.load(Path("finetuned-weights").files("*encoder*")[0])
        gp_weights = torch.load(Path("finetuned-weights").files("*gplayer*")[0])
        decoder_weights = torch.load(Path("finetuned-weights").files("*decoder*")[0])
    else:
        encoder_weights = torch.load(Path("original-weights").files("*encoder*")[0])['state_dict']
        gp_weights = torch.load(Path("original-weights").files("*gplayer*")[0])['state_dict']
        decoder_weights = torch.load(Path("original-weights").files("*decoder*")[0])['state_dict']

    encoder = Encoder()
    encoder = torch.nn.DataParallel(encoder)
    encoder.load_state_dict(encoder_weights)
    encoder.eval()
    encoder = encoder.to(device)

    decoder = Decoder()
    decoder = torch.nn.DataParallel(decoder)
    decoder.load_state_dict(decoder_weights)
    decoder.eval()
    decoder = decoder.to(device)

    # load GP values
    gplayer = GPlayer(device=device)
    gplayer.load_state_dict(gp_weights)
    gplayer.eval()

    gamma2 = np.exp(gp_weights['gamma2'][0].item())
    ell = np.exp(gp_weights['ell'][0].item())
    sigma2 = np.exp(gp_weights['sigma2'][0].item())

    warp_grid = get_warp_grid_for_cost_volume_calculation(width=input_image_width,
                                                          height=input_image_height,
                                                          device=device)

    min_depth = 0.5
    max_depth = 50.0
    n_depth_levels = 64

    scale_rgb = 1.0
    mean_rgb = [81.0, 81.0, 81.0]
    std_rgb = [35.0, 35.0, 35.0]

    data_path = Path(Config.test_offline_data_path)
    if Config.test_dataset_name is None:
        keyframe_index_files = sorted((Path(Config.test_offline_data_path) / "indices").files())
    else:
        keyframe_index_files = sorted((Path(Config.test_offline_data_path) / "indices").files("*" + Config.test_dataset_name + "*"))

    for iteration, keyframe_index_file in enumerate(keyframe_index_files):
        keyframing_type, dataset_name, scene_name, _, n_measurement_frames = keyframe_index_file.split("/")[-1].split("+")

        scene_folder = data_path / dataset_name / scene_name
        print("Predicting for scene:", dataset_name + "-" + scene_name, " - ", iteration, "/", len(keyframe_index_files))

        keyframe_index_file_lines = np.loadtxt(keyframe_index_file, dtype=str, delimiter="\n")

        K = np.loadtxt(scene_folder / 'K.txt').astype(np.float32)
        poses = np.fromfile(scene_folder / "poses.txt", dtype=float, sep="\n ").reshape((-1, 4, 4))
        image_filenames = sorted((scene_folder / 'images').files("*.png"))
        depth_filenames = sorted((scene_folder / 'depth').files("*.png"))

        input_filenames = []
        for image_filename in image_filenames:
            input_filenames.append(image_filename.split("/")[-1])

        # Matern-3/2 GP in state-space form: transition, stationary covariance and observation model
        lam = np.sqrt(3) / ell
        F = np.array([[0, 1],
                      [-lam ** 2, -2 * lam]])
        Pinf = np.array([[gamma2, 0],
                         [0, gamma2 * lam ** 2]])
        h = np.array([[1], [0]])

        # State mean and covariance
        M = np.zeros((F.shape[0], 512 * 8 * 10))
        P = Pinf

        inference_timer = InferenceTimer()

        previous_index = None
        predictions = []
        reference_depths = []
        with torch.no_grad():
            for i in tqdm(range(0, len(keyframe_index_file_lines))):
                keyframe_index_file_line = keyframe_index_file_lines[i]

                if keyframe_index_file_line == "TRACKING LOST":
                    continue
                else:
                    current_input_filenames = keyframe_index_file_line.split(" ")
                    current_indices = [input_filenames.index(current_input_filenames[x]) for x in range(len(current_input_filenames))]
                    reference_index = current_indices[0]
                    measurement_indices = current_indices[1:]

                reference_pose = poses[reference_index]
                reference_image = load_image(image_filenames[reference_index])
                reference_depth = cv2.imread(depth_filenames[reference_index], -1).astype(float) / 1000.0

                preprocessor = PreprocessImage(K=K,
                                               old_width=reference_image.shape[1], old_height=reference_image.shape[0],
                                               new_width=input_image_width, new_height=input_image_height,
                                               distortion_crop=0, perform_crop=False)

                reference_image = preprocessor.apply_rgb(image=reference_image, scale_rgb=scale_rgb,
                                                         mean_rgb=mean_rgb, std_rgb=std_rgb)
                reference_depth = preprocessor.apply_depth(reference_depth)

                reference_image_torch = torch.from_numpy(np.transpose(reference_image, (2, 0, 1))).float().to(device).unsqueeze(0)
                reference_pose_torch = torch.from_numpy(reference_pose).float().to(device).unsqueeze(0)

                measurement_poses_torch = []
                measurement_images_torch = []
                for measurement_index in measurement_indices:
                    measurement_image = load_image(image_filenames[measurement_index])
                    measurement_image = preprocessor.apply_rgb(image=measurement_image, scale_rgb=scale_rgb,
                                                               mean_rgb=mean_rgb, std_rgb=std_rgb)
                    measurement_image_torch = torch.from_numpy(np.transpose(measurement_image, (2, 0, 1))).float().to(device).unsqueeze(0)
                    measurement_pose_torch = torch.from_numpy(poses[measurement_index]).float().to(device).unsqueeze(0)
                    measurement_images_torch.append(measurement_image_torch)
                    measurement_poses_torch.append(measurement_pose_torch)

                full_K_torch = torch.from_numpy(preprocessor.get_updated_intrinsics()).float().to(device).unsqueeze(0)

                inference_timer.record_start_time()

                cost_volume = cost_volume_fusion(image1=reference_image_torch, image2s=measurement_images_torch,
                                                 pose1=reference_pose_torch, pose2s=measurement_poses_torch,
                                                 K=full_K_torch, warp_grid=warp_grid,
                                                 min_depth=min_depth, max_depth=max_depth, n_depth_levels=n_depth_levels,
                                                 device=device, dot_product=False)

                conv5, conv4, conv3, conv2, conv1 = encoder(reference_image_torch, cost_volume)

                batch, channel, height, width = conv5.size()
                y = np.expand_dims(conv5.cpu().numpy().flatten(), axis=0)

                if previous_index is None:
                    previous_index = measurement_index

                dt, _, _ = pose_distance(poses[reference_index], poses[previous_index])

                # Prediction step
                A = expm(F * dt)
                Q = Pinf - A.dot(Pinf).dot(A.T)
                M = A.dot(M)
                P = A.dot(P).dot(A.T) + Q

                # Update step
                v = y - h.T.dot(M)
                s = h.T.dot(P).dot(h) + sigma2
                k = P.dot(h) / s
                M += k.dot(v)
                P -= k.dot(h.T).dot(P)

                Z = torch.from_numpy(M[0]).view(batch, channel, height, width).float().to(device)
                Z = torch.nn.functional.relu(Z)

                prediction, _, _, _ = decoder(Z, conv4, conv3, conv2, conv1)
                prediction = torch.clamp(prediction, min=0.02, max=2.0)
                prediction = 1 / prediction

                inference_timer.record_end_time_and_elapsed_time()

                prediction = prediction.cpu().numpy().squeeze()
                previous_index = deepcopy(reference_index)

                reference_depths.append(reference_depth)
                predictions.append(prediction)

                if Config.test_visualize:
                    visualize_predictions(numpy_reference_image=reference_image,
                                          numpy_measurement_image=measurement_image,
                                          numpy_predicted_depth=prediction,
                                          normalization_mean=mean_rgb, normalization_std=std_rgb,
                                          normalization_scale=scale_rgb)

        inference_timer.print_statistics()

        system_name = "{}_{}_{}_{}_{}_gpmvs_{}".format(keyframing_type, dataset_name,
                                                       input_image_width, input_image_height,
                                                       n_measurement_frames, extension)
        save_results(predictions=predictions, groundtruths=reference_depths,
                     system_name=system_name, scene_name=scene_name,
                     save_folder=Config.test_result_folder)
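# The GPMVS loop above fuses the encoder bottleneck over time with a Gaussian-process
# prior written in state-space form (a Matern-3/2 kernel over pose distance) and run
# as a Kalman filter: predict with A = expm(F * dt), then update with the new latent
# observation y. Below is a minimal, self-contained sketch of that filter on a small
# hypothetical latent vector; the parameter values are made up for illustration.
import numpy as np
from scipy.linalg import expm

gamma2, ell, sigma2 = 1.0, 2.0, 0.1   # hypothetical kernel magnitude, length scale and noise
latent_dim = 4                        # stands in for the 512 * 8 * 10 bottleneck features

lam = np.sqrt(3) / ell
F = np.array([[0.0, 1.0], [-lam ** 2, -2.0 * lam]])          # continuous-time transition
Pinf = np.array([[gamma2, 0.0], [0.0, gamma2 * lam ** 2]])   # stationary state covariance
h = np.array([[1.0], [0.0]])                                 # observation picks the function value


def gp_fusion_step(M, P, y, dt):
    """One Kalman predict/update; y is the (latent_dim,) observation, dt the pose distance."""
    A = expm(F * dt)
    Q = Pinf - A @ Pinf @ A.T
    # Predict
    M = A @ M
    P = A @ P @ A.T + Q
    # Update
    v = y[np.newaxis, :] - h.T @ M      # innovation, shape (1, latent_dim)
    s = float(h.T @ P @ h) + sigma2     # innovation variance, shared across dimensions
    k = P @ h / s                       # Kalman gain, shape (2, 1)
    M = M + k @ v
    P = P - k @ h.T @ P
    return M, P


M = np.zeros((2, latent_dim))   # state mean, one column per latent dimension
P = Pinf.copy()                 # state covariance, shared across dimensions
M, P = gp_fusion_step(M, P, np.random.randn(latent_dim), dt=0.05)
fused_latent = M[0]             # the analogue of what the script reshapes back into the bottleneck tensor Z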
def predict():
    predict_with_finetuned = True

    if predict_with_finetuned:
        extension = "finetuned"
    else:
        extension = "without_ft"

    input_image_width = 320
    input_image_height = 256

    print("System: MVDEPTHNET, is_finetuned = ", predict_with_finetuned)

    device = torch.device('cuda')

    encoder = Encoder()
    decoder = Decoder()

    if predict_with_finetuned:
        encoder_weights = torch.load(Path("finetuned-weights").files("*encoder*")[0])
        decoder_weights = torch.load(Path("finetuned-weights").files("*decoder*")[0])
    else:
        mvdepth_weights = torch.load(Path("original-weights") / "pretrained_mvdepthnet_combined")
        pretrained_dict = mvdepth_weights['state_dict']

        encoder_weights = encoder.state_dict()
        pretrained_dict_encoder = {k: v for k, v in pretrained_dict.items() if k in encoder_weights}
        encoder_weights.update(pretrained_dict_encoder)

        decoder_weights = decoder.state_dict()
        pretrained_dict_decoder = {k: v for k, v in pretrained_dict.items() if k in decoder_weights}
        decoder_weights.update(pretrained_dict_decoder)

    encoder.load_state_dict(encoder_weights)
    decoder.load_state_dict(decoder_weights)

    encoder = encoder.to(device)
    decoder = decoder.to(device)

    encoder.eval()
    decoder.eval()

    warp_grid = get_warp_grid_for_cost_volume_calculation(width=input_image_width,
                                                          height=input_image_height,
                                                          device=device)

    min_depth = 0.5
    max_depth = 50.0
    n_depth_levels = 64

    scale_rgb = 1.0
    mean_rgb = [81.0, 81.0, 81.0]
    std_rgb = [35.0, 35.0, 35.0]

    data_path = Path(Config.test_offline_data_path)
    if Config.test_dataset_name is None:
        keyframe_index_files = sorted((Path(Config.test_offline_data_path) / "indices").files())
    else:
        keyframe_index_files = sorted((Path(Config.test_offline_data_path) / "indices").files("*" + Config.test_dataset_name + "*"))

    for iteration, keyframe_index_file in enumerate(keyframe_index_files):
        keyframing_type, dataset_name, scene_name, _, n_measurement_frames = keyframe_index_file.split("/")[-1].split("+")

        scene_folder = data_path / dataset_name / scene_name
        print("Predicting for scene:", dataset_name + "-" + scene_name, " - ", iteration, "/", len(keyframe_index_files))

        keyframe_index_file_lines = np.loadtxt(keyframe_index_file, dtype=str, delimiter="\n")

        K = np.loadtxt(scene_folder / 'K.txt').astype(np.float32)
        poses = np.fromfile(scene_folder / "poses.txt", dtype=float, sep="\n ").reshape((-1, 4, 4))
        image_filenames = sorted((scene_folder / 'images').files("*.png"))
        depth_filenames = sorted((scene_folder / 'depth').files("*.png"))

        input_filenames = []
        for image_filename in image_filenames:
            input_filenames.append(image_filename.split("/")[-1])

        inference_timer = InferenceTimer()

        predictions = []
        reference_depths = []
        with torch.no_grad():
            for i in tqdm(range(0, len(keyframe_index_file_lines))):
                keyframe_index_file_line = keyframe_index_file_lines[i]

                if keyframe_index_file_line == "TRACKING LOST":
                    continue
                else:
                    current_input_filenames = keyframe_index_file_line.split(" ")
                    current_indices = [input_filenames.index(current_input_filenames[x]) for x in range(len(current_input_filenames))]
                    reference_index = current_indices[0]
                    measurement_indices = current_indices[1:]

                reference_pose = poses[reference_index]
                reference_image = load_image(image_filenames[reference_index])
                reference_depth = cv2.imread(depth_filenames[reference_index], -1).astype(float) / 1000.0

                preprocessor = PreprocessImage(K=K,
                                               old_width=reference_image.shape[1], old_height=reference_image.shape[0],
                                               new_width=input_image_width, new_height=input_image_height,
                                               distortion_crop=0, perform_crop=False)

                reference_image = preprocessor.apply_rgb(image=reference_image, scale_rgb=scale_rgb,
                                                         mean_rgb=mean_rgb, std_rgb=std_rgb)
                reference_depth = preprocessor.apply_depth(reference_depth)

                reference_image_torch = torch.from_numpy(np.transpose(reference_image, (2, 0, 1))).float().to(device).unsqueeze(0)
                reference_pose_torch = torch.from_numpy(reference_pose).float().to(device).unsqueeze(0)

                measurement_poses_torch = []
                measurement_images_torch = []
                for measurement_index in measurement_indices:
                    measurement_image = load_image(image_filenames[measurement_index])
                    measurement_image = preprocessor.apply_rgb(image=measurement_image, scale_rgb=scale_rgb,
                                                               mean_rgb=mean_rgb, std_rgb=std_rgb)
                    measurement_image_torch = torch.from_numpy(np.transpose(measurement_image, (2, 0, 1))).float().to(device).unsqueeze(0)
                    measurement_pose_torch = torch.from_numpy(poses[measurement_index]).float().to(device).unsqueeze(0)
                    measurement_images_torch.append(measurement_image_torch)
                    measurement_poses_torch.append(measurement_pose_torch)

                full_K_torch = torch.from_numpy(preprocessor.get_updated_intrinsics()).float().to(device).unsqueeze(0)

                inference_timer.record_start_time()

                cost_volume = cost_volume_fusion(image1=reference_image_torch, image2s=measurement_images_torch,
                                                 pose1=reference_pose_torch, pose2s=measurement_poses_torch,
                                                 K=full_K_torch, warp_grid=warp_grid,
                                                 min_depth=min_depth, max_depth=max_depth, n_depth_levels=n_depth_levels,
                                                 device=device, dot_product=False)

                conv5, conv4, conv3, conv2, conv1 = encoder(reference_image_torch, cost_volume)
                prediction, _, _, _ = decoder(conv5, conv4, conv3, conv2, conv1)
                prediction = torch.clamp(prediction, min=0.02, max=2.0)
                prediction = 1 / prediction

                inference_timer.record_end_time_and_elapsed_time()

                prediction = prediction.cpu().numpy().squeeze()
                reference_depths.append(reference_depth)
                predictions.append(prediction)

                if Config.test_visualize:
                    visualize_predictions(numpy_reference_image=reference_image,
                                          numpy_measurement_image=measurement_image,
                                          numpy_predicted_depth=prediction,
                                          normalization_mean=mean_rgb, normalization_std=std_rgb,
                                          normalization_scale=scale_rgb)

        inference_timer.print_statistics()

        system_name = "{}_{}_{}_{}_{}_mvdepthnet_{}".format(keyframing_type, dataset_name,
                                                            input_image_width, input_image_height,
                                                            n_measurement_frames, extension)
        save_results(predictions=predictions, groundtruths=reference_depths,
                     system_name=system_name, scene_name=scene_name,
                     save_folder=Config.test_result_folder)