示例#1
0
def export_descriptor(config, output_dir, args):
    """Export keypoints, descriptors and matches for every test image pair.

    For each sample of the test loader the front-end model is run on
    ``sample["image"]`` and on ``sample["warped_image"]``; both detections are
    pushed into a ``PointTracker`` to obtain matches, and everything is saved
    to ``<save_output>/<index>.npz`` (index starts at 0).

    Saved prediction (shapes follow the notes left in this function; the
    image size and descriptor width depend on the dataset/model -- TODO
    confirm):
        pred:
            'image': np (H, W)
            'prob' (keypoints): np (N1, 3), transpose of ``pts`` [3, N1]
            'desc': np (N1, 256)
            'warped_image': np (H, W)
            'warped_prob' (keypoints): np (N2, 3)
            'warped_desc': np (N2, 256)
            'homography': np (3, 3)
            'matches': np [N3, 4]

    :param config: configuration mapping; it is dumped to
        ``<output_dir>/config.yml`` and the keys ``model.subpixel.enable``,
        ``model.subpixel.patch_size``, ``data.dataset`` and
        ``front_end_model`` are read here.
    :param output_dir: directory receiving ``config.yml``; also passed to
        ``get_save_path`` to derive the prediction directory.
    :param args: parsed command-line arguments; only ``args.command`` is read.
    :return: None
    """
    from utils.loader import dataLoader_test as dataLoader
    from utils.loader import get_module, get_save_path
    from utils.print_tool import datasize
    from utils.var_dim import squeezeToNumpy

    # basic settings
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    logging.info("train on device: %s", device)
    with open(os.path.join(output_dir, "config.yml"), "w") as f:
        yaml.dump(config, f, default_flow_style=False)
    # Nothing is logged through the writer below; it is still created so the
    # side effects of constructing it (as in the original code) are kept.
    writer = SummaryWriter(getWriterPath(task=args.command, date=True))
    save_path = get_save_path(output_dir)
    save_output = save_path / "../predictions"
    os.makedirs(save_output, exist_ok=True)

    # parameters
    output_matches = True
    subpixel = config["model"]["subpixel"]["enable"]
    patch_size = config["model"]["subpixel"]["patch_size"]

    # data loading
    task = config["data"]["dataset"]
    data = dataLoader(config, dataset=task)
    test_loader = data["test_loader"]
    datasize(test_loader, config, tag="test")

    # model loading (pretrained weights are restored by loadModel)
    Val_model_heatmap = get_module("", config["front_end_model"])
    val_agent = Val_model_heatmap(config["model"], device=device)
    val_agent.loadModel()

    # tracker: only the two images of a pair are matched against each other
    tracker = PointTracker(max_length=2, nn_thresh=val_agent.nn_thresh)

    def get_pts_desc_from_agent(val_agent, img, device="cpu"):
        """Run the agent on ``img`` and return the first batch entry.

        Returns a dict with
            'pts': numpy (3, N)      -- per the original note, TODO confirm
            'desc': numpy (256, N)   -- per the original note, TODO confirm
        """
        # The return value (heatmap batch) is not needed here; the agent keeps
        # the state that the following calls read.
        val_agent.run(img.to(device))
        # heatmap to pts
        pts = val_agent.heatmap_to_pts()
        if subpixel:
            pts = val_agent.soft_argmax_points(pts, patch_size=patch_size)
        # heatmap, pts to desc
        desc_sparse = val_agent.desc_to_sparseDesc()
        return {"pts": pts[0], "desc": desc_sparse[0]}

    count = 0
    # Wrap the loader itself (not enumerate) so tqdm can see its length.
    for sample in tqdm(test_loader):
        img_0, img_1 = sample["image"], sample["warped_image"]

        # first image, no matches yet
        outs = get_pts_desc_from_agent(val_agent, img_0, device=device)
        pts, desc = outs["pts"], outs["desc"]  # pts: np [3, N]
        if output_matches:
            tracker.update(pts, desc)

        # save keypoints of the first image
        pred = {
            "image": squeezeToNumpy(img_0),
            "prob": pts.transpose(),
            "desc": desc.transpose(),
        }

        # second image, output matches
        outs = get_pts_desc_from_agent(val_agent, img_1, device=device)
        pts, desc = outs["pts"], outs["desc"]
        if output_matches:
            tracker.update(pts, desc)

        pred["warped_image"] = squeezeToNumpy(img_1)
        pred.update(
            {
                "warped_prob": pts.transpose(),
                "warped_desc": desc.transpose(),
                "homography": squeezeToNumpy(sample["homography"]),
            }
        )

        if output_matches:
            matches = tracker.get_matches()
            print("matches: ", matches.transpose().shape)
            pred["matches"] = matches.transpose()
        print("pts: ", pts.shape, ", desc: ", desc.shape)

        # clean last descriptor so the next pair starts from an empty tracker
        tracker.clear_desc()

        path = Path(save_output, "{}.npz".format(count))
        np.savez_compressed(path, **pred)
        count += 1
    print("output pairs: ", count)
def export_descriptor(config, output_dir, args):
    """Export classical-detector keypoints, descriptors and matches.

    For each sample of the test loader a classical detector/descriptor is run
    on ``sample['image']`` and ``sample['warped_image']``, the descriptors are
    matched with ``get_sift_match`` and the result is written to
    ``<output_dir>/predictions/<index>.npz`` (index starts at 0). The number
    of exported pairs is appended to ``<output_dir>/predictions/export.txt``.

    Saved prediction (shapes follow the notes left in this function; the
    image size depends on the dataset and the descriptor width on the
    detector -- the original notes mention 128 for SIFT, TODO confirm):
        pred:
            'image': np (H, W)
            'prob' (keypoints): np (N1, 2)
            'desc': np (N1, D)
            'warped_image': np (H, W)
            'warped_prob' (keypoints): np (N2, 2)
            'warped_desc': np (N2, D)
            'homography': np (3, 3)
            'matches': np (N3, 4)
            'matches_all': all matches as returned by ``get_sift_match``

    :param config: configuration mapping; dumped to
        ``<output_dir>/config.yml``; ``config['model']['method']`` is read.
    :param output_dir: directory receiving ``config.yml``, ``checkpoints/``
        and ``predictions/``.
    :param args: parsed command-line arguments; only ``args.command`` is read.
    :return: None
    """
    from pathlib import Path

    from utils.loader import dataLoader_test as dataLoader
    from utils.print_tool import datasize

    # config
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    logging.info('train on device: %s', device)
    with open(os.path.join(output_dir, 'config.yml'), 'w') as f:
        yaml.dump(config, f, default_flow_style=False)
    # Nothing is logged through the writer below; it is still created so the
    # side effects of constructing it (as in the original code) are kept.
    writer = SummaryWriter(getWriterPath(task=args.command, date=True))

    # output locations
    save_path = Path(output_dir) / 'checkpoints'
    save_output = Path(output_dir) / 'predictions'
    logging.info('=> will save everything to {}'.format(save_path))
    os.makedirs(save_path, exist_ok=True)
    os.makedirs(save_output, exist_ok=True)

    # data loading
    data = dataLoader(config, dataset='hpatches')
    test_loader = data['test_loader']
    datasize(test_loader, config, tag='test')

    def squeezeToNumpy(tensor_arr):
        """Detach a tensor, move it to CPU and drop singleton dimensions."""
        return tensor_arr.detach().cpu().numpy().squeeze()

    def classicalDetectors(image, method='sift', round_method=False):
        """Detect keypoints and compute descriptors on one image.

        :param image: numpy image; it is multiplied by 255 before detection
            (presumably the loader yields values in [0, 1] -- TODO confirm).
        :param method: detector name; only forwarded when ``round_method`` is
            true. The default branch always uses ``SIFT_det``.
        :param round_method: if true, use the dense (quantized) detector and
            gather descriptors at the detected pixels; otherwise use SIFT
            with subpixel accuracy.
        :return: ``(pnts, desc)``
        """
        image = image * 255
        if round_method:
            # with quantization
            from models.classical_detectors_descriptors import classical_detector_descriptor
            points, desc = classical_detector_descriptor(image, **{'method': method})
            y, x = np.where(points)
            pnts = np.stack((x, y), axis=1)  # should be (x, y)
            # collect descriptors at the detected locations
            desc = desc[y, x, :]
        else:
            # sift with subpixel accuracy
            from models.classical_detectors_descriptors import SIFT_det
            pnts, desc = SIFT_det(image, image)

        print("desc shape: ", desc.shape)
        return pnts, desc

    count = 0
    method = config['model']['method']

    # Wrap the loader itself (not enumerate) so tqdm can see its length.
    for sample in tqdm(test_loader):
        img_0, img_1 = sample['image'], sample['warped_image']
        img_0_np = img_0.numpy().squeeze()
        img_1_np = img_1.numpy().squeeze()

        # first image
        pts_1, desc_1 = classicalDetectors(img_0_np, method=method)
        print("total points: ", pts_1.shape)
        pred = {
            'image': img_0_np,
            'prob': pts_1,
            'desc': desc_1,
            'warped_image': img_1_np,
        }

        # second image
        pts_2, desc_2 = classicalDetectors(img_1_np, method=method)
        print("total points: ", pts_2.shape)
        pred.update({
            'warped_prob': pts_2,
            'warped_desc': desc_2,
            'homography': squeezeToNumpy(sample['homography']),
        })

        # get matches between the two descriptor sets
        match_data = get_sift_match(
            sift_kps_ii=pts_1,
            sift_des_ii=desc_1,
            sift_kps_jj=pts_2,
            sift_des_jj=desc_2,
            if_BF_matcher=True,
        )
        matches = match_data['match_quality_good']
        print(f"matches: {matches.shape}")
        pred.update({
            'matches': matches,
            'matches_all': match_data['match_quality_all'],
        })

        # save data
        path = Path(save_output, '{}.npz'.format(count))
        np.savez_compressed(path, **pred)
        count += 1

    print("output pairs: ", count)
    save_file = save_output / "export.txt"
    # Opened in append mode: repeated runs accumulate one line each.
    with open(save_file, "a") as myfile:
        myfile.write("output pairs: " + str(count) + '\n')