    def __init__(self, checkpoint_path, device="cuda"):
        self.device = device

        # Load the trained IMIP network from a Lightning checkpoint,
        # freeze its weights, and move it to the requested device.
        checkpoint_net = IMIPLightning.load_from_checkpoint(
            checkpoint_path, strict=False)  # calls seed everything
        checkpoint_net.freeze()
        self.network = checkpoint_net.to(device=self.device)

def main():
    parser = ArgumentParser()
    parser.add_argument('--data_root', default="./data")
    parser.add_argument("--output_dir", type=str, default="./example_matches")
    params = parser.parse_args()

    preprocessor = preprocess_registry["center"]()

    os.makedirs(params.output_dir, exist_ok=True)

    for model_key, dataset_constructor_list in model_dataset_map.items():
        snapshot_dir = snapshots_dirs[model_key]
        checkpoints = os.listdir(snapshot_dir)
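        # Use the first "epoch=*.ckpt" checkpoint found in the snapshot directory.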
        checkpoint_path = os.path.join(snapshot_dir, [
            x for x in checkpoints
            if x.startswith("epoch=") and x.endswith(".ckpt")
        ][0])
        checkpoint_net = IMIPLightning.load_from_checkpoint(
            checkpoint_path, strict=False)  # calls seed everything
        checkpoint_net = checkpoint_net.to(device="cuda")
        checkpoint_net.freeze()

        for dataset_constructor in dataset_constructor_list:
            dataset = dataset_constructor(params.data_root)
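            # Visualize matches for 100 randomly sampled image pairs from this dataset.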
            for i in range(100):
                pair = dataset[np.random.randint(0, len(dataset))]

                image_1_torch = preprocessor(
                    load_image_for_torch(pair.image_1, device="cuda"))
                image_2_torch = preprocessor(
                    load_image_for_torch(pair.image_2, device="cuda"))
                image_1_corrs = checkpoint_net.network.extract_keypoints(
                    image_1_torch)[0]
                image_2_corrs = checkpoint_net.network.extract_keypoints(
                    image_2_torch)[0]
                batched_corrs = torch.stack((image_1_corrs, image_2_corrs),
                                            dim=0).cpu()

                # Filter out invalid correspondences
                ground_truth_corrs, tracked_idx = pair.correspondences(
                    batched_corrs[0].numpy(), inverse=False)
                batched_corrs = batched_corrs[:, :, tracked_idx]
                valid_idx = np.linalg.norm(
                    ground_truth_corrs - batched_corrs[1].numpy(),
                    ord=2,
                    axis=0) < 1
                batched_corrs = batched_corrs[:, :, valid_idx]

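                # Round to pixel coordinates and wrap the surviving correspondences in
                # OpenCV KeyPoint/DMatch objects so cv2.drawMatches can render them.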
                batched_corrs = batched_corrs.round()
                anchor_keypoints = [
                    cv2.KeyPoint(int(batched_corrs[0][0][i]),
                                 int(batched_corrs[0][1][i]), 1)
                    for i in range(batched_corrs.shape[2])
                ]
                corr_keypoints = [
                    cv2.KeyPoint(int(batched_corrs[1][0][i]),
                                 int(batched_corrs[1][1][i]), 1)
                    for i in range(batched_corrs.shape[2])
                ]
                matches = [
                    cv2.DMatch(i, i, 0.0) for i in range(len(anchor_keypoints))
                ]

                image_1 = cv2.cvtColor(pair.image_1, cv2.COLOR_RGB2BGR)
                image_2 = cv2.cvtColor(pair.image_2, cv2.COLOR_RGB2BGR)

                match_img_1 = cv2.drawMatches(image_1, anchor_keypoints,
                                              image_2, corr_keypoints, matches,
                                              None)
                pair_name = (checkpoint_net.get_name() + "_" +
                             pair.name.replace("/", "_") + ".png")
                out_path = os.path.join(params.output_dir, pair_name)
                cv2.imwrite(out_path, match_img_1)
                print("Wrote {0}".format(out_path), flush=True)
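The stacking and filtering logic above is easier to see on toy data. Below is a minimal sketch with made-up coordinates (not taken from any dataset): it reproduces the (2, 2, N) layout produced by torch.stack, the 1-pixel reprojection-error mask, and the resulting boolean indexing.

import numpy as np
import torch

# Two sets of N = 3 keypoints, each stored as a (2, N) array of x/y pixel coordinates.
image_1_corrs = torch.tensor([[10.0, 20.0, 30.0],
                              [15.0, 25.0, 35.0]])
image_2_corrs = torch.tensor([[11.0, 50.0, 30.5],
                              [15.0, 60.0, 35.2]])
batched_corrs = torch.stack((image_1_corrs, image_2_corrs), dim=0)  # shape (2, 2, 3)

# Made-up ground-truth locations of the image-1 points projected into image 2.
ground_truth_corrs = np.array([[11.0, 21.0, 30.0],
                               [15.0, 26.0, 35.0]])

# Keep correspondences whose predicted image-2 location is within 1 pixel of
# the ground truth (L2 norm taken over the x/y axis, i.e. axis 0).
valid_idx = np.linalg.norm(
    ground_truth_corrs - batched_corrs[1].numpy(), ord=2, axis=0) < 1
batched_corrs = batched_corrs[:, :, valid_idx]

print(valid_idx)            # [ True False  True]
print(batched_corrs.shape)  # torch.Size([2, 2, 2])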
Example #3
import os
from argparse import ArgumentParser

from pytorch_lightning.utilities import move_data_to_device

from imipnet.datasets.shuffle import ShuffledDataset
from imipnet.lightning_module import IMIPLightning, test_dataset_registry

parser = ArgumentParser()
parser.add_argument("checkpoint", type=str)
parser.add_argument('test_set', choices=test_dataset_registry.keys())
parser.add_argument('--data_root', default="./data")
parser.add_argument('--n_eval_samples', type=int, default=-1)
parser.add_argument("--output_dir", type=str, default="./test_results")

params = parser.parse_args()
run_name = os.path.basename(os.path.dirname(params.checkpoint))
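# Name the run after the directory that contains the checkpoint file.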

checkpoint_net = IMIPLightning.load_from_checkpoint(
    params.checkpoint, strict=False)  # calls seed everything
checkpoint_net.freeze()

# override test set
if params.test_set is not None:
    checkpoint_net.hparams.test_set = params.test_set
    checkpoint_net.hparams.n_eval_samples = params.n_eval_samples
    checkpoint_net.test_set = ShuffledDataset(
        test_dataset_registry[params.test_set](params.data_root))

print("Evaluating {} on {}".format(checkpoint_net.get_name(),
                                   checkpoint_net.hparams.test_set))

# TODO: load device from params
if checkpoint_net.hparams.n_eval_samples > 0:
    print("Number of samples: {}".format(