Ejemplo n.º 1
0
    def __init__(self,
                 scene_ds,
                 meters,
                 batch_size=64,
                 cache_data=True,
                 n_workers=4,
                 sampler=None):
        """Set up a per-rank data pipeline over ``scene_ds`` plus its meters.

        Args:
            scene_ds: Dataset of scenes to iterate over.
            meters: Mapping of meters. It is stored as an ``OrderedDict``
                sorted by key, so meters are always visited in the same order.
            batch_size: Batch size of the underlying ``DataLoader``.
            cache_data: If True, the whole ``DataLoader`` is iterated once
                right away (with a ``tqdm`` progress bar) and the batches are
                kept in memory as a list.
            n_workers: Number of ``DataLoader`` worker processes.
            sampler: Optional sampler to reuse, e.g. one from another runner
                so both iterate the same shard. When None, a
                ``DistributedSceneSampler`` with ``shuffle=True`` is created
                from this process's rank and the world size.
        """
        self.rank = get_rank()
        self.world_size = get_world_size()
        self.tmp_dir = get_tmp_dir()

        self.scene_ds = scene_ds
        if sampler is None:
            sampler = DistributedSceneSampler(scene_ds,
                                              num_replicas=self.world_size,
                                              rank=self.rank,
                                              shuffle=True)
        # Expose the sampler like the other runners do, so callers can share it.
        self.sampler = sampler
        dataloader = DataLoader(scene_ds,
                                batch_size=batch_size,
                                num_workers=n_workers,
                                sampler=sampler,
                                collate_fn=self.collate_fn)

        if cache_data:
            self.dataloader = list(tqdm(dataloader))
        else:
            self.dataloader = dataloader

        # Dict keys are unique, so sorting by key alone is a total order.
        self.meters = OrderedDict(
            sorted(meters.items(), key=lambda item: item[0]))
Ejemplo n.º 2
0
    def __init__(self, scene_ds, batch_size=8, cache_data=False, n_workers=4):
        """Build a ``DataLoader`` over ``scene_ds`` driven by a rank-aware sampler.

        Args:
            scene_ds: Dataset of scenes to iterate over.
            batch_size: Batch size of the underlying ``DataLoader``.
            cache_data: If True, every batch is loaded immediately (with a
                ``tqdm`` progress bar) and kept in memory as a list.
            n_workers: Number of ``DataLoader`` worker processes.
        """
        self.rank = get_rank()
        self.world_size = get_world_size()
        self.tmp_dir = get_tmp_dir()

        # The sampler is configured from rank/world size (presumably one shard
        # per process) and kept on the instance so other runners can reuse it.
        self.sampler = DistributedSceneSampler(
            scene_ds, num_replicas=self.world_size, rank=self.rank)
        loader = DataLoader(
            scene_ds,
            batch_size=batch_size,
            num_workers=n_workers,
            sampler=self.sampler,
            collate_fn=self.collate_fn,
        )

        self.dataloader = list(tqdm(loader)) if cache_data else loader
Ejemplo n.º 3
0
    def __init__(self, scene_ds, batch_size=1, cache_data=False, n_workers=4):
        """Build a ``DataLoader`` over ``scene_ds`` driven by a rank-aware sampler.

        Args:
            scene_ds: Dataset of view groups to iterate over.
            batch_size: Must be 1; batching several view groups together is
                not supported.
            cache_data: If True, every batch is loaded immediately (with a
                ``tqdm`` progress bar) and kept in memory as a list.
            n_workers: Number of ``DataLoader`` worker processes.

        Raises:
            ValueError: If ``batch_size`` is not 1.
        """
        # Validate with an explicit exception rather than ``assert``, which is
        # removed when Python runs with ``-O``. Check before doing any work.
        if batch_size != 1:
            raise ValueError('Multiple view groups not supported for now.')

        self.rank = get_rank()
        self.world_size = get_world_size()
        self.tmp_dir = get_tmp_dir()

        sampler = DistributedSceneSampler(scene_ds,
                                          num_replicas=self.world_size,
                                          rank=self.rank)
        self.sampler = sampler
        dataloader = DataLoader(scene_ds,
                                batch_size=batch_size,
                                num_workers=n_workers,
                                sampler=sampler,
                                collate_fn=self.collate_fn)

        if cache_data:
            self.dataloader = list(tqdm(dataloader))
        else:
            self.dataloader = dataloader
Ejemplo n.º 4
0
def run_inference(args):
    """Run pose inference on ``args.ds_name`` and save the raw predictions.

    The input dataset is wrapped into multi-view groups of ``args.n_views``.
    Optional stages are enabled by flags:

    * ICP refinement when ``args.icp`` is set (depth loading is switched on);
    * multi-view scene prediction when ``args.n_views > 1``.

    Predictions from every rank are gathered. Rank 0 writes
    ``results.pth.tar`` and ``config.yaml`` into ``args.save_dir``.

    Must be called from every process of the ``torch.distributed`` group,
    since it synchronizes them with ``torch.distributed.barrier()``.
    """
    separator = '-' * 80
    logger.info(separator)
    for name, value in vars(args).items():
        logger.info(f"{name}: {value}")
    logger.info(separator)

    # --- Data -------------------------------------------------------------
    scene_ds = make_scene_dataset(args.ds_name, n_frames=args.n_frames)
    if args.icp:
        scene_ds.load_depth = args.icp

    scene_ds_multi = MultiViewWrapper(scene_ds, n_views=args.n_views)
    if args.n_groups is not None:
        # Keep only the first ``n_groups`` view groups.
        truncated = scene_ds_multi.frame_index[:args.n_groups]
        scene_ds_multi.frame_index = truncated.reset_index(drop=True)

    pred_runner = BopPredictionRunner(scene_ds_multi,
                                      batch_size=args.pred_bsz,
                                      cache_data=False,
                                      n_workers=args.n_workers)

    # --- Models -----------------------------------------------------------
    detector = load_detector(args.detector_run_id)
    pose_predictor, mesh_db = load_pose_models(
        coarse_run_id=args.coarse_run_id,
        refiner_run_id=args.refiner_run_id,
        n_workers=args.n_workers,
    )

    if args.icp:
        coarse_model = pose_predictor.coarse_model
        icp_refiner = ICPRefiner(mesh_db,
                                 renderer=coarse_model.renderer,
                                 resolution=coarse_model.cfg.input_resize)
    else:
        icp_refiner = None

    mv_predictor = MultiviewScenePredictor(mesh_db) if args.n_views > 1 else None

    # --- Prediction -------------------------------------------------------
    pred_kwargs = {
        'maskrcnn_detections': {
            'detector': detector,
            'pose_predictor': pose_predictor,
            'n_coarse_iterations': args.n_coarse_iterations,
            'n_refiner_iterations': args.n_refiner_iterations,
            'icp_refiner': icp_refiner,
            'mv_predictor': mv_predictor,
        },
    }

    all_predictions = {}
    for pred_prefix, runner_kwargs in pred_kwargs.items():
        logger.info(f"Prediction: {pred_prefix}")
        preds = pred_runner.get_predictions(**runner_kwargs)
        all_predictions.update(
            {f'{pred_prefix}/{preds_name}': preds_n
             for preds_name, preds_n in preds.items()})

    logger.info("Done with inference.")
    torch.distributed.barrier()

    # Merge per-rank predictions and move them to CPU.
    for key in list(all_predictions):
        gathered = all_predictions[key].gather_distributed(tmp_dir=get_tmp_dir())
        all_predictions[key] = gathered.cpu()

    # --- Saving (rank 0 only) ---------------------------------------------
    if get_rank() == 0:
        save_dir = Path(args.save_dir)
        save_dir.mkdir(exist_ok=True, parents=True)
        logger.info(f'Finished inference on {args.ds_name}')
        results = format_results(all_predictions, {}, {})
        torch.save(results, save_dir / 'results.pth.tar')
        (save_dir / 'config.yaml').write_text(yaml.dump(args))
        logger.info(f'Saved predictions in {save_dir}')

    torch.distributed.barrier()
Ejemplo n.º 5
0
def run_detection_eval(args, detector=None):
    """Run detection on ``args.ds_name``, evaluate the predictions, and save them.

    Args:
        args: Parsed options; every attribute is logged at start-up.
        detector: Optional already-loaded detector. When None and
            ``args.skip_model_predictions`` is false, one is loaded with
            ``load_detector(args.detector_run_id)``.

    Returns:
        On rank 0, the object returned by ``format_results`` (also saved to
        ``<args.save_dir>/results.pth.tar``). ``None`` on every other rank.

    Must be called from every process of the ``torch.distributed`` group,
    since it synchronizes them with ``torch.distributed.barrier()``.
    """
    logger.info(f"{'-'*80}")
    for k, v in args.__dict__.items():
        logger.info(f"{k}: {v}")
    logger.info(f"{'-'*80}")

    scene_ds = make_scene_dataset(args.ds_name, n_frames=args.n_frames)

    # Collect the prediction jobs before building the runner, so that
    # ``cache_data`` reflects how many passes over the data will actually run.
    pred_kwargs = dict()
    if not args.skip_model_predictions:
        if detector is not None:
            model = detector
        else:
            model = load_detector(args.detector_run_id)
        pred_kwargs['model'] = dict(detector=model, gt_detections=False)

    # Caching the whole dataset only pays off if it is iterated more than once.
    pred_runner = DetectionRunner(scene_ds,
                                  batch_size=args.pred_bsz,
                                  cache_data=len(pred_kwargs) > 1,
                                  n_workers=args.n_workers)

    all_predictions = dict()

    if args.external_predictions:
        # Only datasets whose name contains 'ycbv' or 'tless' have external
        # predictions wired in here; others are left without any.
        if 'ycbv' in args.ds_name:
            all_predictions['posecnn'] = load_posecnn_results().cpu()
        elif 'tless' in args.ds_name:
            all_predictions['retinanet/pix2pose'] = load_pix2pose_results(
                all_detections=True).cpu()

    for pred_prefix, pred_kwargs_n in pred_kwargs.items():
        logger.info(f"Prediction: {pred_prefix}")
        preds = pred_runner.get_predictions(**pred_kwargs_n)
        for preds_name, preds_n in preds.items():
            all_predictions[f'{pred_prefix}/{preds_name}'] = preds_n

    logger.info("Done with predictions")
    torch.distributed.barrier()

    # Evaluation. The evaluator is only built when it will be used: with
    # ``cache_data`` enabled it may load the whole dataset up front.
    eval_metrics, eval_dfs = dict(), dict()
    if not args.skip_evaluation:
        meters = get_meters(scene_ds)
        logger.info(f"Meters: {meters}")
        # Reuse the prediction sampler so both runners iterate the same items.
        eval_runner = DetectionEvaluation(scene_ds,
                                          meters,
                                          batch_size=args.eval_bsz,
                                          cache_data=len(all_predictions) > 1,
                                          n_workers=args.n_workers,
                                          sampler=pred_runner.sampler)
        for preds_k, preds in all_predictions.items():
            logger.info(f"Evaluation of predictions: {preds_k}")
            if len(preds) == 0:
                preds = eval_runner.make_empty_predictions()
            eval_metrics[preds_k], eval_dfs[preds_k] = eval_runner.evaluate(
                preds)

    # Merge per-rank predictions and move them to CPU.
    for k, v in all_predictions.items():
        all_predictions[k] = v.gather_distributed(tmp_dir=get_tmp_dir()).cpu()

    results = None
    if get_rank() == 0:
        save_dir = Path(args.save_dir)
        save_dir.mkdir(exist_ok=True, parents=True)
        logger.info(f'Finished evaluation on {args.ds_name}')
        results = format_results(all_predictions, eval_metrics, eval_dfs)
        torch.save(results, save_dir / 'results.pth.tar')
        (save_dir / 'summary.txt').write_text(results.get('summary_txt', ''))
        (save_dir / 'config.yaml').write_text(yaml.dump(args))
        logger.info(f'Saved predictions+metrics in {save_dir}')

    logger.info("Done with evaluation")
    torch.distributed.barrier()
    return results
Ejemplo n.º 6
0
def gather_predictions(all_predictions):
    """Replace each entry of ``all_predictions`` with its cross-rank, CPU copy.

    Every value's ``gather_distributed(tmp_dir=get_tmp_dir())`` is called and
    the result's ``.cpu()`` is stored back under the same key. The dict is
    updated in place and returned for convenience.
    """
    for key in list(all_predictions):
        gathered = all_predictions[key].gather_distributed(tmp_dir=get_tmp_dir())
        all_predictions[key] = gathered.cpu()
    return all_predictions