# Example 1
def run_detection_eval(args, detector=None):
    """Run object-detection predictions and evaluation over a scene dataset.

    Args:
        args: configuration namespace; fields read here include ds_name,
            n_frames, pred_bsz, eval_bsz, n_workers, skip_model_predictions,
            detector_run_id, external_predictions, skip_evaluation, save_dir.
        detector: optional pre-loaded detector model. When None and model
            predictions are not skipped, a detector is loaded from
            args.detector_run_id.

    Returns:
        The formatted results dict on distributed rank 0; None on other ranks.
    """
    # Log the full configuration for reproducibility.
    logger.info(f"{'-'*80}")
    for k, v in args.__dict__.items():
        logger.info(f"{k}: {v}")
    logger.info(f"{'-'*80}")

    scene_ds = make_scene_dataset(args.ds_name, n_frames=args.n_frames)

    pred_kwargs = dict()
    # NOTE(review): pred_kwargs is still empty at this point, so cache_data is
    # always False; if caching was meant to depend on the populated dict, this
    # should be computed after the update below — confirm intent before changing.
    pred_runner = DetectionRunner(scene_ds,
                                  batch_size=args.pred_bsz,
                                  cache_data=len(pred_kwargs) > 1,
                                  n_workers=args.n_workers)

    if not args.skip_model_predictions:
        model = detector if detector is not None else load_detector(args.detector_run_id)
        pred_kwargs.update(
            {'model': dict(detector=model, gt_detections=False)})

    all_predictions = dict()

    # Optionally load externally-published baseline detections for comparison.
    if args.external_predictions:
        if 'ycbv' in args.ds_name:
            all_predictions['posecnn'] = load_posecnn_results().cpu()
        elif 'tless' in args.ds_name:
            all_predictions['retinanet/pix2pose'] = load_pix2pose_results(
                all_detections=True).cpu()

    # Run each configured prediction pass and collect its outputs under a
    # '<prefix>/<name>' key.
    for pred_prefix, pred_kwargs_n in pred_kwargs.items():
        logger.info(f"Prediction: {pred_prefix}")
        preds = pred_runner.get_predictions(**pred_kwargs_n)
        for preds_name, preds_n in preds.items():
            all_predictions[f'{pred_prefix}/{preds_name}'] = preds_n

    logger.info("Done with predictions")
    torch.distributed.barrier()

    # Evaluation.
    meters = get_meters(scene_ds)
    logger.info(f"Meters: {meters}")
    eval_runner = DetectionEvaluation(scene_ds,
                                      meters,
                                      batch_size=args.eval_bsz,
                                      cache_data=len(all_predictions) > 1,
                                      n_workers=args.n_workers,
                                      sampler=pred_runner.sampler)

    eval_metrics, eval_dfs = dict(), dict()
    if not args.skip_evaluation:
        # Fix: removed the dead `do_eval = True` flag and its unreachable
        # "Skipped" else-branch — every prediction set was always evaluated.
        for preds_k, preds in all_predictions.items():
            logger.info(f"Evaluation of predictions: {preds_k}")
            if len(preds) == 0:
                # Empty prediction sets still need a well-formed placeholder
                # so the evaluator can run.
                preds = eval_runner.make_empty_predictions()
            eval_metrics[preds_k], eval_dfs[
                preds_k] = eval_runner.evaluate(preds)

    # Gather predictions from all distributed workers.
    for k, v in all_predictions.items():
        all_predictions[k] = v.gather_distributed(tmp_dir=get_tmp_dir()).cpu()

    results = None
    if get_rank() == 0:
        # Only rank 0 writes results to disk.
        save_dir = Path(args.save_dir)
        save_dir.mkdir(exist_ok=True, parents=True)
        logger.info(f'Finished evaluation on {args.ds_name}')
        results = format_results(all_predictions, eval_metrics, eval_dfs)
        torch.save(results, save_dir / 'results.pth.tar')
        (save_dir / 'summary.txt').write_text(results.get('summary_txt', ''))
        (save_dir / 'config.yaml').write_text(yaml.dump(args))
        logger.info(f'Saved predictions+metrics in {save_dir}')

    logger.info("Done with evaluation")
    torch.distributed.barrier()
    return results
def make_eval_bundle(args, model_training):
    """Assemble, per test dataset, everything needed for pose evaluation.

    Returns a dict mapping dataset name to a tuple
    (pred_runner, pred_kwargs, eval_runner, save_dir).
    """
    bundle = dict()
    model_training.cfg = args

    def restore_model(run_id):
        # Load a previously-trained pose model by run id; None passes through.
        if run_id is None:
            return None
        run_dir = EXP_DIR / run_id
        loaded_cfg = yaml.load((run_dir / 'config.yaml').read_text(),
                               Loader=yaml.FullLoader)
        loaded_cfg = check_update_config(loaded_cfg)
        net = create_model_pose(
            loaded_cfg,
            renderer=model_training.renderer,
            mesh_db=model_training.mesh_db).cuda().eval()
        state = torch.load(run_dir / 'checkpoint.pth.tar')['state_dict']
        net.load_state_dict(state)
        net.eval()
        net.cfg = loaded_cfg
        return net

    # The model being trained fills one role; its counterpart is restored
    # from a fixed checkpoint (or left as None).
    if args.train_refiner:
        refiner_model = model_training
        coarse_model = restore_model(args.coarse_run_id_for_test)
    elif args.train_coarse:
        coarse_model = model_training
        refiner_model = restore_model(args.refiner_run_id_for_test)
    else:
        raise ValueError

    predictor = CoarseRefinePosePredictor(coarse_model=coarse_model,
                                          refiner_model=refiner_model)

    base_pred_kwargs = dict(
        pose_predictor=predictor,
        mv_predictor=None,
        use_gt_detections=False,
        skip_mv=True,
    )

    for ds_name in args.test_ds_names:
        assert ds_name in {'ycbv.test.keyframes', 'tless.primesense.test'}
        scene_ds = make_scene_dataset(ds_name, n_frames=args.n_test_frames)
        logger.info(f'TEST: Loaded {ds_name} with {len(scene_ds)} images.')
        scene_ds_pred = MultiViewWrapper(scene_ds, n_views=1)

        # Predictions
        pred_runner = MultiviewPredictionRunner(
            scene_ds_pred,
            batch_size=1,
            n_workers=args.n_dataloader_workers,
            cache_data=False)

        # External baseline detections differ per benchmark.
        if 'tless' in ds_name:
            detections = load_pix2pose_results(
                all_detections=False, remove_incorrect_poses=False).cpu()
            coarse_detections = load_pix2pose_results(
                all_detections=False, remove_incorrect_poses=True).cpu()
            det_k, coarse_k = 'pix2pose_detections', 'pix2pose_coarse'
        elif 'ycbv' in ds_name:
            detections = load_posecnn_results().cpu()
            coarse_detections = detections
            det_k, coarse_k = 'posecnn_detections', 'posecnn_coarse'
        else:
            raise ValueError(ds_name)

        pred_kwargs = dict()
        if refiner_model is not None:
            # Refinement-only pass starting from external coarse poses.
            pred_kwargs[coarse_k] = dict(
                detections=coarse_detections,
                use_detections_TCO=True,
                n_coarse_iterations=0,
                n_refiner_iterations=1,
                **base_pred_kwargs,
            )
        if coarse_model is not None:
            # Full coarse(+optional refiner) pipeline from raw detections.
            pred_kwargs[det_k] = dict(
                detections=detections,
                use_detections_TCO=False,
                n_coarse_iterations=coarse_model.cfg.n_iterations,
                n_refiner_iterations=1 if refiner_model is not None else 0,
                **base_pred_kwargs,
            )

        # Evaluation
        meters = get_pose_meters(scene_ds)
        meters = {name.split('_')[0]: meter for name, meter in meters.items()}
        mv_group_ids = list(iter(pred_runner.sampler))
        # Restrict evaluation to exactly the frames the prediction sampler visits.
        scene_ds_ids = np.concatenate(
            scene_ds_pred.frame_index.loc[mv_group_ids, 'scene_ds_ids'].values)
        eval_runner = PoseEvaluation(scene_ds,
                                     meters,
                                     batch_size=1,
                                     cache_data=True,
                                     n_workers=args.n_dataloader_workers,
                                     sampler=ListSampler(scene_ds_ids))

        save_dir = Path(args.save_dir) / 'eval' / ds_name
        save_dir.mkdir(exist_ok=True, parents=True)
        bundle[ds_name] = (pred_runner, pred_kwargs, eval_runner, save_dir)

    return bundle