def load_models(coarse_run_id, refiner_run_id=None, n_workers=8, object_set='tless'):
    """Build a coarse/refiner pose predictor from saved training runs.

    Args:
        coarse_run_id: Name of the run directory (under ``EXP_DIR``) holding
            the coarse model, or ``None`` to leave the coarse model unset.
        refiner_run_id: Same as ``coarse_run_id`` but for the refiner model.
        n_workers: Forwarded to ``BulletBatchRenderer``.
        object_set: ``'tless'`` selects the T-LESS dataset names; any other
            value selects the YCB-Video ones.

    Returns:
        A ``(predictor, mesh_db)`` tuple: the ``CoarseRefinePosePredictor``
        wrapping both models, and the (non-batched) ``MeshDataBase``.
    """
    if object_set == 'tless':
        dataset_names = ('tless.bop', 'tless.cad')
    else:
        dataset_names = ('ycbv.bop-compat.eval', 'ycbv')
    object_ds_name, urdf_ds_name = dataset_names

    mesh_db = MeshDataBase.from_object_ds(make_object_dataset(object_ds_name))
    renderer = BulletBatchRenderer(object_set=urdf_ds_name, n_workers=n_workers)
    # The models consume the batched, GPU-resident copy; the caller gets the
    # original mesh_db back.
    mesh_db_batched = mesh_db.batched().cuda()

    def load_model(run_id):
        """Restore one model from ``EXP_DIR / run_id``; ``None`` passes through."""
        if run_id is None:
            return None

        run_dir = EXP_DIR / run_id
        config_text = (run_dir / 'config.yaml').read_text()
        cfg = check_update_config(yaml.load(config_text, Loader=yaml.FullLoader))

        # The saved config records which kind of network the run trained.
        build_model = create_model_refiner if cfg.train_refiner else create_model_coarse
        model = build_model(cfg, renderer=renderer, mesh_db=mesh_db_batched)

        state_dict = torch.load(run_dir / 'checkpoint.pth.tar')['state_dict']
        model.load_state_dict(state_dict)

        model = model.cuda().eval()
        model.cfg = cfg
        return model

    coarse_model = load_model(coarse_run_id)
    refiner_model = load_model(refiner_run_id)
    predictor = CoarseRefinePosePredictor(
        coarse_model=coarse_model,
        refiner_model=refiner_model,
    )
    return predictor, mesh_db
def make_eval_bundle(args, model_training):
    """Assemble the prediction and evaluation runners for each test dataset.

    The model currently being trained fills one role (refiner when
    ``args.train_refiner`` is set, otherwise coarse when ``args.train_coarse``
    is set). The complementary role is restored from a saved run when the
    matching ``*_run_id_for_test`` is not ``None``.

    Args:
        args: Configuration namespace. Reads ``train_refiner``,
            ``train_coarse``, ``coarse_run_id_for_test``,
            ``refiner_run_id_for_test``, ``test_ds_names``, ``n_test_frames``,
            ``n_dataloader_workers`` and ``save_dir``.
        model_training: The model under training. Its ``renderer`` and
            ``mesh_db`` attributes are reused for the loaded model, and its
            ``cfg`` attribute is overwritten with ``args``.

    Returns:
        A dict mapping each dataset name to a
        ``(pred_runner, pred_kwargs, eval_runner, save_dir)`` tuple.

    Raises:
        ValueError: If neither ``args.train_refiner`` nor ``args.train_coarse``
            is set, or if a name in ``args.test_ds_names`` is not supported.
    """
    eval_bundle = dict()
    # The trained model stands in for a loaded one below, which reads
    # ``coarse_model.cfg.n_iterations``; give it the same ``cfg`` attribute.
    model_training.cfg = args

    def load_model(run_id):
        """Restore the complementary model from ``EXP_DIR / run_id``."""
        if run_id is None:
            return None
        run_dir = EXP_DIR / run_id
        # NOTE(review): yaml.load with FullLoader is only appropriate for
        # trusted files; presumably these configs are written by our own
        # training runs -- confirm before pointing this at external data.
        cfg = yaml.load((run_dir / 'config.yaml').read_text(), Loader=yaml.FullLoader)
        cfg = check_update_config(cfg)
        model = create_model_pose(
            cfg, renderer=model_training.renderer, mesh_db=model_training.mesh_db).cuda()
        ckpt = torch.load(run_dir / 'checkpoint.pth.tar')['state_dict']
        model.load_state_dict(ckpt)
        # Switch to inference mode once, after the weights are in place.
        model.eval()
        model.cfg = cfg
        return model

    if args.train_refiner:
        refiner_model = model_training
        coarse_model = load_model(args.coarse_run_id_for_test)
    elif args.train_coarse:
        coarse_model = model_training
        refiner_model = load_model(args.refiner_run_id_for_test)
    else:
        raise ValueError(
            'Expected args.train_refiner or args.train_coarse to be set.')

    predictor = CoarseRefinePosePredictor(coarse_model=coarse_model,
                                          refiner_model=refiner_model)

    base_pred_kwargs = dict(
        pose_predictor=predictor,
        mv_predictor=None,
        use_gt_detections=False,
        skip_mv=True,
    )

    supported_ds_names = {'ycbv.test.keyframes', 'tless.primesense.test'}

    for ds_name in args.test_ds_names:
        # Explicit check instead of ``assert`` so it survives ``python -O``.
        if ds_name not in supported_ds_names:
            raise ValueError(
                f'Unsupported test dataset {ds_name!r}; '
                f'expected one of {sorted(supported_ds_names)}.')

        scene_ds = make_scene_dataset(ds_name, n_frames=args.n_test_frames)
        logger.info(f'TEST: Loaded {ds_name} with {len(scene_ds)} images.')
        scene_ds_pred = MultiViewWrapper(scene_ds, n_views=1)

        # Predictions
        pred_runner = MultiviewPredictionRunner(
            scene_ds_pred, batch_size=1,
            n_workers=args.n_dataloader_workers, cache_data=False)

        pred_kwargs = dict()
        if 'tless' in ds_name:
            detections = load_pix2pose_results(
                all_detections=False, remove_incorrect_poses=False).cpu()
            coarse_detections = load_pix2pose_results(
                all_detections=False, remove_incorrect_poses=True).cpu()
            det_k = 'pix2pose_detections'
            coarse_k = 'pix2pose_coarse'
        elif 'ycbv' in ds_name:
            detections = load_posecnn_results().cpu()
            coarse_detections = detections
            det_k = 'posecnn_detections'
            coarse_k = 'posecnn_coarse'
        else:
            raise ValueError(ds_name)

        if refiner_model is not None:
            # Refiner-only configuration: zero coarse iterations, one refiner
            # iteration, with use_detections_TCO enabled.
            pred_kwargs.update({
                coarse_k: dict(
                    detections=coarse_detections,
                    use_detections_TCO=True,
                    n_coarse_iterations=0,
                    n_refiner_iterations=1,
                    **base_pred_kwargs,
                )
            })

        if coarse_model is not None:
            # Coarse configuration: run the coarse model on the detections,
            # followed by one refiner iteration when a refiner is available.
            pred_kwargs.update({
                det_k: dict(
                    detections=detections,
                    use_detections_TCO=False,
                    n_coarse_iterations=coarse_model.cfg.n_iterations,
                    n_refiner_iterations=1 if refiner_model is not None else 0,
                    **base_pred_kwargs,
                )
            })

        # Evaluation
        meters = get_pose_meters(scene_ds)
        # Keep only the part of each meter name before the first underscore.
        meters = {k.split('_')[0]: v for k, v in meters.items()}
        # Evaluate on exactly the frames the prediction runner's sampler yields.
        mv_group_ids = list(iter(pred_runner.sampler))
        scene_ds_ids = np.concatenate(
            scene_ds_pred.frame_index.loc[mv_group_ids, 'scene_ds_ids'].values)
        sampler = ListSampler(scene_ds_ids)
        eval_runner = PoseEvaluation(scene_ds, meters, batch_size=1, cache_data=True,
                                     n_workers=args.n_dataloader_workers,
                                     sampler=sampler)

        save_dir = Path(args.save_dir) / 'eval' / ds_name
        save_dir.mkdir(exist_ok=True, parents=True)
        eval_bundle[ds_name] = (pred_runner, pred_kwargs, eval_runner, save_dir)
    return eval_bundle