def __init__(self, checkpoint_path, device="cuda"):
    """Restore a frozen IMIPLightning module from a checkpoint file.

    Args:
        checkpoint_path: Path of the Lightning checkpoint to restore.
        device: Target device for the restored module; also kept on
            ``self.device``.
    """
    self.device = device
    # strict=False is forwarded to state-dict loading, so mismatched keys
    # between checkpoint and model are tolerated. Loading also calls
    # seed_everything (per the original author's note).
    module = IMIPLightning.load_from_checkpoint(checkpoint_path, strict=False)
    # Freeze before handing the module out: it is only used for inference.
    module.freeze()
    self.network = module.to(device=self.device)
import os
from argparse import ArgumentParser

from pytorch_lightning import Trainer, seed_everything
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.loggers import TensorBoardLogger

from imipnet.lightning_module import IMIPLightning

# Training configuration: start from the model's own CLI options, then parse
# a fixed argument list (not sys.argv), so running this script always uses
# the same loss / patch / eval-set / preprocessing settings.
parser = ArgumentParser()
parser = IMIPLightning.add_model_specific_args(parser)
args = parser.parse_args([
    "--loss", "outlier-balanced-bce-bce-uml",
    "--n_top_patches", "1",
    "--eval_set", "kitti-gray-0.5",
    "--preprocess", "harris"
])
imip_module = IMIPLightning(args)

# One run name is used for both the TensorBoard log directory and the
# checkpoint directory, so the two can be matched up later.
name = imip_module.get_new_run_name()
logger = TensorBoardLogger("./runs", name)
checkpoint_dir = os.path.join(".", "checkpoints", "simple-conv", name)
os.makedirs(checkpoint_dir, exist_ok=True)

# Checkpoints are selected on the "eval_true_inliers" metric with mode='max'
# (higher is better); save_last=True additionally asks Lightning to keep the
# most recent checkpoint.
checkpoint_callback = ModelCheckpoint(
    filepath=checkpoint_dir,
    save_last=True,
    verbose=True,
    monitor="eval_true_inliers",
    mode='max',
    period=0  # don't wait for a new epoch to save a better model
)
def _find_epoch_checkpoint(snapshot_dir):
    """Return the path of an ``epoch=*.ckpt`` checkpoint inside ``snapshot_dir``.

    ``os.listdir`` returns entries in arbitrary order, so the matching names
    are sorted to make the choice reproducible when several exist.

    Raises:
        FileNotFoundError: if ``snapshot_dir`` contains no ``epoch=*.ckpt`` file.
    """
    candidates = sorted(
        entry for entry in os.listdir(snapshot_dir)
        if entry.startswith("epoch=") and entry.endswith(".ckpt")
    )
    if not candidates:
        raise FileNotFoundError(
            "No 'epoch=*.ckpt' checkpoint found in {}".format(snapshot_dir))
    return os.path.join(snapshot_dir, candidates[0])


def _verified_correspondences(checkpoint_net, pair, preprocessor, device,
                              max_error=1):
    """Extract keypoints on both images of ``pair`` and keep the verified ones.

    A keypoint pair is kept only if ``pair.correspondences`` reports the
    image-1 keypoint as tracked and the returned ground-truth location lies
    within ``max_error`` (L2 distance) of the corresponding image-2 keypoint.

    Returns:
        A CPU tensor indexed ``[image, coordinate, keypoint]`` (image 0 is the
        anchor image), with coordinates rounded to whole values.
    """
    image_1_torch = preprocessor(load_image_for_torch(pair.image_1, device=device))
    image_2_torch = preprocessor(load_image_for_torch(pair.image_2, device=device))
    image_1_corrs = checkpoint_net.network.extract_keypoints(image_1_torch)[0]
    image_2_corrs = checkpoint_net.network.extract_keypoints(image_2_torch)[0]
    batched_corrs = torch.stack((image_1_corrs, image_2_corrs), dim=0).cpu()

    # Filter out invalid correspondences: first drop keypoints the ground
    # truth cannot track (presumably out of view / no depth — confirm
    # against pair.correspondences), then those that land too far from
    # their ground-truth location.
    ground_truth_corrs, tracked_idx = pair.correspondences(
        batched_corrs[0].numpy(), inverse=False)
    batched_corrs = batched_corrs[:, :, tracked_idx]
    valid_idx = np.linalg.norm(
        ground_truth_corrs - batched_corrs[1].numpy(), ord=2, axis=0) < max_error
    return batched_corrs[:, :, valid_idx].round()


def _to_cv_keypoints(coords):
    """Convert a ``[coordinate, keypoint]`` tensor into ``cv2.KeyPoint`` objects.

    Row 0 is used as x and row 1 as y; every keypoint gets size 1.
    """
    return [cv2.KeyPoint(int(x), int(y), 1) for x, y in zip(coords[0], coords[1])]


def _draw_match_image(pair, batched_corrs):
    """Render the one-to-one matches in ``batched_corrs`` over ``pair``'s images."""
    anchor_keypoints = _to_cv_keypoints(batched_corrs[0])
    corr_keypoints = _to_cv_keypoints(batched_corrs[1])
    # Keypoint k of image 1 is matched with keypoint k of image 2.
    matches = [cv2.DMatch(k, k, 0.0) for k in range(len(anchor_keypoints))]
    # Pair images are converted RGB -> BGR because OpenCV draws and writes BGR.
    image_1 = cv2.cvtColor(pair.image_1, cv2.COLOR_RGB2BGR)
    image_2 = cv2.cvtColor(pair.image_2, cv2.COLOR_RGB2BGR)
    return cv2.drawMatches(image_1, anchor_keypoints, image_2, corr_keypoints,
                           matches, None)


def main():
    """Write match visualisations for randomly sampled pairs of each dataset.

    For every model in ``model_dataset_map``, one ``epoch=*.ckpt`` checkpoint
    from its snapshot directory is loaded. For each of its datasets,
    ``--n_samples`` random pairs are drawn, and the ground-truth-verified
    keypoint matches are drawn into PNGs under ``--output_dir``.
    """
    parser = ArgumentParser()
    parser.add_argument('--data_root', default="./data")
    parser.add_argument("--output_dir", type=str, default="./example_matches")
    parser.add_argument("--device", type=str, default="cuda")
    parser.add_argument("--n_samples", type=int, default=100)
    params = parser.parse_args()

    preprocessor = preprocess_registry["center"]()
    os.makedirs(params.output_dir, exist_ok=True)
    for model_key, dataset_constructor_list in model_dataset_map.items():
        checkpoint_path = _find_epoch_checkpoint(snapshots_dirs[model_key])
        checkpoint_net = IMIPLightning.load_from_checkpoint(
            checkpoint_path, strict=False)  # calls seed everything
        checkpoint_net = checkpoint_net.to(device=params.device)
        checkpoint_net.freeze()
        for dataset_constructor in dataset_constructor_list:
            dataset = dataset_constructor(params.data_root)
            for _ in range(params.n_samples):
                pair = dataset[np.random.randint(0, len(dataset))]
                batched_corrs = _verified_correspondences(
                    checkpoint_net, pair, preprocessor, params.device)
                match_img = _draw_match_image(pair, batched_corrs)

                pair_name = (checkpoint_net.get_name() + "_"
                             + pair.name.replace("/", "_") + ".png")
                output_path = os.path.join(params.output_dir, pair_name)
                # cv2.imwrite reports failure via its return value, not an
                # exception, so check it before claiming success.
                if cv2.imwrite(output_path, match_img):
                    print("Wrote {0}".format(output_path), flush=True)
                else:
                    print("Failed to write {0}".format(output_path), flush=True)
from pytorch_lightning.utilities import move_data_to_device from imipnet.datasets.shuffle import ShuffledDataset from imipnet.lightning_module import IMIPLightning, test_dataset_registry parser = ArgumentParser() parser.add_argument("checkpoint", type=str) parser.add_argument('test_set', choices=test_dataset_registry.keys()) parser.add_argument('--data_root', default="./data") parser.add_argument('--n_eval_samples', type=int, default=-1) parser.add_argument("--output_dir", type=str, default="./test_results") params = parser.parse_args() run_name = os.path.basename(os.path.dirname(params.checkpoint)) checkpoint_net = IMIPLightning.load_from_checkpoint( params.checkpoint, strict=False) # calls seed everything checkpoint_net.freeze() # override test set if params.test_set is not None: checkpoint_net.hparams.test_set = params.test_set checkpoint_net.hparams.n_eval_samples = params.n_eval_samples checkpoint_net.test_set = ShuffledDataset( test_dataset_registry[params.test_set](params.data_root)) print("Evaluating {} on {}".format(checkpoint_net.get_name(), checkpoint_net.hparams.test_set)) # TODO: load device from params if checkpoint_net.hparams.n_eval_samples > 0: print("Number of samples: {}".format(