Example #1
0
def load_models(coarse_run_id, refiner_run_id=None, n_workers=8, object_set='tless'):
    """Build a ``CoarseRefinePosePredictor`` from saved training runs.

    Each run directory lives under ``EXP_DIR`` and must contain
    ``config.yaml`` and ``checkpoint.pth.tar`` (with a ``'state_dict'`` entry).

    Args:
        coarse_run_id: Run directory name of the coarse model, or ``None`` to
            build the predictor without a coarse model.
        refiner_run_id: Run directory name of the refiner model, or ``None``
            (default) to build the predictor without a refiner.
        n_workers: Number of workers given to the ``BulletBatchRenderer``.
        object_set: ``'tless'`` selects the ``tless`` object/URDF datasets;
            any other value selects the ``ycbv`` ones.

    Returns:
        A ``(model, mesh_db)`` tuple: the predictor and the non-batched
        ``MeshDataBase`` built from the object dataset.
    """
    if object_set == 'tless':
        object_ds_name, urdf_ds_name = 'tless.bop', 'tless.cad'
    else:
        object_ds_name, urdf_ds_name = 'ycbv.bop-compat.eval', 'ycbv'

    object_ds = make_object_dataset(object_ds_name)
    mesh_db = MeshDataBase.from_object_ds(object_ds)
    # One renderer and one GPU copy of the batched meshes are shared by both models.
    renderer = BulletBatchRenderer(object_set=urdf_ds_name, n_workers=n_workers)
    mesh_db_batched = mesh_db.batched().cuda()

    def load_model(run_id):
        """Load one model from ``EXP_DIR / run_id``; return ``None`` if ``run_id`` is None."""
        if run_id is None:
            return None
        run_dir = EXP_DIR / run_id
        # NOTE(review): yaml.FullLoader is not yaml.SafeLoader; only read
        # configs from trusted experiment directories.
        cfg = yaml.load((run_dir / 'config.yaml').read_text(), Loader=yaml.FullLoader)
        cfg = check_update_config(cfg)
        # The run's own config decides which architecture the checkpoint holds.
        create_model = create_model_refiner if cfg.train_refiner else create_model_coarse
        model = create_model(cfg, renderer=renderer, mesh_db=mesh_db_batched)
        # Load tensors on the CPU so loading does not depend on the device the
        # checkpoint was saved from; the model is moved to the GPU right after.
        ckpt = torch.load(run_dir / 'checkpoint.pth.tar', map_location='cpu')
        model.load_state_dict(ckpt['state_dict'])
        model = model.cuda().eval()
        model.cfg = cfg
        return model

    coarse_model = load_model(coarse_run_id)
    refiner_model = load_model(refiner_run_id)
    model = CoarseRefinePosePredictor(coarse_model=coarse_model,
                                      refiner_model=refiner_model)
    return model, mesh_db
Example #2
0
def make_eval_bundle(args, model_training):
    """Prepare prediction and evaluation runners for every test dataset.

    ``model_training`` fills the coarse slot (``args.train_coarse``) or the
    refiner slot (``args.train_refiner``) of a ``CoarseRefinePosePredictor``.
    The other slot is loaded from the saved run named by
    ``args.refiner_run_id_for_test`` / ``args.coarse_run_id_for_test``, or
    left as ``None`` when that id is ``None``.

    Side effect: sets ``model_training.cfg = args``.

    Args:
        args: Configuration object. Attributes read here: ``train_refiner``,
            ``train_coarse``, ``coarse_run_id_for_test``,
            ``refiner_run_id_for_test``, ``test_ds_names``, ``n_test_frames``,
            ``n_dataloader_workers``, ``save_dir``. When the coarse model is
            the training model, ``args.n_iterations`` is also read.
        model_training: Model being trained. It must expose ``renderer`` and
            ``mesh_db``, which are reused to build the companion model.

    Returns:
        dict mapping each test dataset name to
        ``(pred_runner, pred_kwargs, eval_runner, save_dir)``.

    Raises:
        ValueError: if neither ``args.train_refiner`` nor ``args.train_coarse``
            is set, or if a name in ``args.test_ds_names`` is not supported.
    """
    eval_bundle = {}
    model_training.cfg = args

    def load_model(run_id):
        """Load a saved pose model from ``EXP_DIR / run_id`` (``None`` -> ``None``)."""
        if run_id is None:
            return None
        run_dir = EXP_DIR / run_id
        # NOTE(review): yaml.FullLoader is not yaml.SafeLoader; only read
        # configs from trusted experiment directories.
        cfg = yaml.load((run_dir / 'config.yaml').read_text(),
                        Loader=yaml.FullLoader)
        cfg = check_update_config(cfg)
        # Reuse the training model's renderer and meshes instead of building new ones.
        model = create_model_pose(
            cfg,
            renderer=model_training.renderer,
            mesh_db=model_training.mesh_db).cuda().eval()
        # Load tensors on the CPU so loading does not depend on the device the
        # checkpoint was saved from; load_state_dict copies them onto the model.
        ckpt = torch.load(run_dir / 'checkpoint.pth.tar',
                          map_location='cpu')['state_dict']
        model.load_state_dict(ckpt)
        model.cfg = cfg
        return model

    if args.train_refiner:
        refiner_model = model_training
        coarse_model = load_model(args.coarse_run_id_for_test)
    elif args.train_coarse:
        coarse_model = model_training
        refiner_model = load_model(args.refiner_run_id_for_test)
    else:
        raise ValueError(
            'make_eval_bundle requires args.train_refiner or args.train_coarse to be set.')

    predictor = CoarseRefinePosePredictor(coarse_model=coarse_model,
                                          refiner_model=refiner_model)

    # Options shared by every prediction configuration built below.
    base_pred_kwargs = dict(
        pose_predictor=predictor,
        mv_predictor=None,
        use_gt_detections=False,
        skip_mv=True,
    )
    supported_ds_names = {'ycbv.test.keyframes', 'tless.primesense.test'}
    for ds_name in args.test_ds_names:
        # Explicit check rather than `assert`, so validation survives `python -O`.
        if ds_name not in supported_ds_names:
            raise ValueError(
                f'Unsupported test dataset {ds_name!r}; '
                f'expected one of {sorted(supported_ds_names)}.')
        scene_ds = make_scene_dataset(ds_name, n_frames=args.n_test_frames)
        logger.info(f'TEST: Loaded {ds_name} with {len(scene_ds)} images.')
        # Wrap with n_views=1: every prediction group holds a single view.
        scene_ds_pred = MultiViewWrapper(scene_ds, n_views=1)

        # Predictions
        pred_runner = MultiviewPredictionRunner(
            scene_ds_pred,
            batch_size=1,
            n_workers=args.n_dataloader_workers,
            cache_data=False)
        pred_kwargs = {}

        if 'tless' in ds_name:
            detections = load_pix2pose_results(
                all_detections=False, remove_incorrect_poses=False).cpu()
            coarse_detections = load_pix2pose_results(
                all_detections=False, remove_incorrect_poses=True).cpu()
            det_k = 'pix2pose_detections'
            coarse_k = 'pix2pose_coarse'
        elif 'ycbv' in ds_name:
            detections = load_posecnn_results().cpu()
            coarse_detections = detections
            det_k = 'posecnn_detections'
            coarse_k = 'posecnn_coarse'
        else:
            # Defensive: only reachable if supported_ds_names gains a new family.
            raise ValueError(ds_name)

        if refiner_model is not None:
            # Refiner only: zero coarse iterations, one refiner iteration on
            # top of the detections' poses (use_detections_TCO=True).
            pred_kwargs.update({
                coarse_k:
                dict(
                    detections=coarse_detections,
                    use_detections_TCO=True,
                    n_coarse_iterations=0,
                    n_refiner_iterations=1,
                    **base_pred_kwargs,
                )
            })

        if coarse_model is not None:
            # Full pipeline: coarse iterations from the coarse model's config,
            # followed by one refiner iteration when a refiner is available.
            pred_kwargs.update({
                det_k:
                dict(
                    detections=detections,
                    use_detections_TCO=False,
                    n_coarse_iterations=coarse_model.cfg.n_iterations,
                    n_refiner_iterations=1 if refiner_model is not None else 0,
                    **base_pred_kwargs,
                )
            })

        # Evaluation
        meters = get_pose_meters(scene_ds)
        # Key each meter by the part of its name before the first '_'.
        meters = {k.split('_')[0]: v for k, v in meters.items()}
        # Restrict evaluation to the frames covered by the prediction sampler.
        mv_group_ids = list(pred_runner.sampler)
        scene_ds_ids = np.concatenate(
            scene_ds_pred.frame_index.loc[mv_group_ids, 'scene_ds_ids'].values)
        sampler = ListSampler(scene_ds_ids)
        eval_runner = PoseEvaluation(scene_ds,
                                     meters,
                                     batch_size=1,
                                     cache_data=True,
                                     n_workers=args.n_dataloader_workers,
                                     sampler=sampler)

        save_dir = Path(args.save_dir) / 'eval' / ds_name
        save_dir.mkdir(exist_ok=True, parents=True)
        eval_bundle[ds_name] = (pred_runner, pred_kwargs, eval_runner,
                                save_dir)
    return eval_bundle