示例#1
0
    def gather_distributed(self, tmp_dir):
        """Merge ``self.datas`` from all distributed processes onto rank 0.

        Workers with rank > 0 serialize their data dict to a per-rank file
        under ``tmp_dir``; rank 0 then loads each file, extends its own
        lists and deletes the file. After the call, rank 0 holds the union
        of all ranks' data; other ranks keep their local data unchanged.

        Args:
            tmp_dir: directory (str or Path) for the intermediate per-rank
                files; created if it does not exist.
        """
        tmp_dir = Path(tmp_dir)
        tmp_dir.mkdir(exist_ok=True, parents=True)
        rank, world_size = get_rank(), get_world_size()
        tmp_file_template = (tmp_dir / 'rank={rank}.pth.tar').as_posix()
        if rank > 0:
            tmp_file = tmp_file_template.format(rank=rank)
            torch.save(self.datas, tmp_file)

        # Wait until every worker has written its file before rank 0 reads.
        if world_size > 1:
            torch.distributed.barrier()

        if rank == 0 and world_size > 1:
            all_datas = self.datas
            for n in range(1, world_size):
                tmp_file = tmp_file_template.format(rank=n)
                datas = torch.load(tmp_file)
                # NOTE(review): only keys already present on rank 0 are
                # merged; keys unique to other ranks are dropped — confirm
                # all ranks share the same key set.
                for k in all_datas.keys():
                    all_datas[k].extend(datas.get(k, []))
                Path(tmp_file).unlink()
            self.datas = all_datas

        # Keep ranks in sync so no process exits while files are in use.
        if world_size > 1:
            torch.distributed.barrier()
        return
示例#2
0
    def __init__(self,
                 scene_ds,
                 meters,
                 batch_size=64,
                 cache_data=True,
                 n_workers=4,
                 sampler=None):
        """Build a distributed evaluation loop over ``scene_ds``.

        Args:
            scene_ds: dataset of scenes to evaluate on.
            meters: mapping of meter-name -> meter object; stored sorted
                by name so iteration order is deterministic across ranks.
            batch_size: dataloader batch size.
            cache_data: when True, materialize all batches up front.
            n_workers: number of dataloader worker processes.
            sampler: optional sampler; defaults to a shuffled
                DistributedSceneSampler over this rank's shard.
        """
        self.rank = get_rank()
        self.world_size = get_world_size()
        self.tmp_dir = get_tmp_dir()

        self.scene_ds = scene_ds
        if sampler is None:
            sampler = DistributedSceneSampler(scene_ds,
                                              num_replicas=self.world_size,
                                              rank=self.rank,
                                              shuffle=True)
        dataloader = DataLoader(scene_ds,
                                batch_size=batch_size,
                                num_workers=n_workers,
                                sampler=sampler,
                                collate_fn=self.collate_fn)

        if cache_data:
            self.dataloader = list(tqdm(dataloader))
        else:
            self.dataloader = dataloader

        # Fix: the original assigned self.meters twice (the first assignment
        # was immediately overwritten) and wrapped a redundant dict
        # comprehension; sort once by key directly.
        self.meters = OrderedDict(
            sorted(meters.items(), key=lambda item: item[0]))
示例#3
0
def run_pred_eval(pred_runner, pred_kwargs, eval_runner, eval_preds=None):
    """Run every prediction configuration, evaluate the results, gather
    predictions across processes, and format the output.

    Args:
        pred_runner: object whose ``get_predictions(**kwargs)`` returns a
            dict of name -> predictions.
        pred_kwargs: mapping of prediction-prefix -> kwargs for
            ``pred_runner.get_predictions``.
        eval_runner: object whose ``evaluate(preds)`` returns
            ``(metrics, dfs)``.
        eval_preds: optional collection of prediction keys to evaluate;
            ``None`` evaluates everything.

    Returns:
        Formatted results dict on rank 0, ``None`` on other ranks.
    """
    all_predictions = dict()
    for pred_prefix, pred_kwargs_n in pred_kwargs.items():
        print("Prediction :", pred_prefix)
        preds = pred_runner.get_predictions(**pred_kwargs_n)
        for preds_name, preds_n in preds.items():
            all_predictions[f'{pred_prefix}/{preds_name}'] = preds_n

    # Fix: build the sorted OrderedDict directly instead of wrapping a
    # redundant dict comprehension around the already-sorted items.
    all_predictions = OrderedDict(
        sorted(all_predictions.items(), key=lambda item: item[0]))
    eval_metrics, eval_dfs = dict(), dict()

    for preds_k, preds in all_predictions.items():
        if eval_preds is None or preds_k in eval_preds:
            # Fix: only announce predictions that are actually evaluated
            # (the original printed "Evaluation :" for skipped keys too).
            print("Evaluation :", preds_k)
            eval_metrics[preds_k], eval_dfs[preds_k] = eval_runner.evaluate(preds)

    all_predictions = gather_predictions(all_predictions)

    if get_rank() == 0:
        results = format_results(all_predictions,
                                 eval_metrics,
                                 eval_dfs)
    else:
        results = None
    return results
示例#4
0
 def summary(self):
     """Gather every meter's data onto rank 0 and summarize it.

     Returns a pair ``(summary, dfs)``: ``summary`` maps
     ``'<meter>/<stat>'`` to a scalar, ``dfs`` maps meter name to its
     detailed dataframe. Both are empty on ranks other than 0.
     """
     aggregated, frames = dict(), dict()
     for meter_name, meter in sorted(self.meters.items()):
         meter.gather_distributed(tmp_dir=self.tmp_dir)
         # Only rank 0 holds the gathered data; skip empty meters.
         if get_rank() != 0 or len(meter.datas) == 0:
             continue
         meter_summary, meter_df = meter.summary()
         frames[meter_name] = meter_df
         for stat_name, value in meter_summary.items():
             aggregated[meter_name + '/' + stat_name] = value
     return aggregated, frames
示例#5
0
    def __init__(self, scene_ds, batch_size=8, cache_data=False, n_workers=4):
        """Set up a distributed dataloader over this rank's dataset shard.

        Args:
            scene_ds: dataset of scenes.
            batch_size: dataloader batch size.
            cache_data: when True, materialize every batch up front.
            n_workers: number of dataloader worker processes.
        """
        self.rank = get_rank()
        self.world_size = get_world_size()
        self.tmp_dir = get_tmp_dir()

        self.sampler = DistributedSceneSampler(
            scene_ds, num_replicas=self.world_size, rank=self.rank)
        loader = DataLoader(
            scene_ds,
            batch_size=batch_size,
            num_workers=n_workers,
            sampler=self.sampler,
            collate_fn=self.collate_fn,
        )

        # Optionally pre-load all batches into memory.
        self.dataloader = list(tqdm(loader)) if cache_data else loader
    def __init__(self, scene_ds, batch_size=1, cache_data=False, n_workers=4):
        """Set up a distributed dataloader over multi-view groups.

        Args:
            scene_ds: dataset of multi-view groups.
            batch_size: must be 1 (one view group per batch).
            cache_data: when True, materialize every batch up front.
            n_workers: number of dataloader worker processes.
        """
        self.rank = get_rank()
        self.world_size = get_world_size()
        self.tmp_dir = get_tmp_dir()

        assert batch_size == 1, 'Multiple view groups not supported for now.'
        self.sampler = DistributedSceneSampler(
            scene_ds, num_replicas=self.world_size, rank=self.rank)
        loader = DataLoader(
            scene_ds,
            batch_size=batch_size,
            num_workers=n_workers,
            sampler=self.sampler,
            collate_fn=self.collate_fn,
        )

        # Optionally pre-load all batches into memory.
        self.dataloader = list(tqdm(loader)) if cache_data else loader
示例#7
0
    def gather_distributed(self, tmp_dir=None):
        """Concatenate this object across all distributed processes.

        Workers with rank > 0 serialize themselves to a per-rank file under
        ``tmp_dir``; rank 0 loads every file, deletes it, and concatenates
        everything (including its own data).

        Args:
            tmp_dir: directory (str or Path) for the intermediate per-rank
                files; created if missing. Required — the default of None
                is kept only for signature compatibility.

        Returns:
            ``concatenate(datas)`` over all ranks' data on rank 0; on other
            ranks only the local data is included.
        """
        # Fix: the original crashed with an opaque TypeError when tmp_dir
        # was left at its default None; fail fast with a clear message, and
        # accept str paths (coerce + create, consistent with the sibling
        # gather_distributed implementation).
        if tmp_dir is None:
            raise ValueError('tmp_dir is required to gather distributed data.')
        tmp_dir = Path(tmp_dir)
        tmp_dir.mkdir(exist_ok=True, parents=True)
        rank, world_size = get_rank(), get_world_size()
        tmp_file_template = (tmp_dir / 'rank={rank}.pth.tar').as_posix()

        if rank > 0:
            tmp_file = tmp_file_template.format(rank=rank)
            torch.save(self, tmp_file)

        # Wait until every worker has written its file before rank 0 reads.
        if world_size > 1:
            torch.distributed.barrier()

        datas = [self]
        if rank == 0 and world_size > 1:
            for n in range(1, world_size):
                tmp_file = tmp_file_template.format(rank=n)
                data = torch.load(tmp_file)
                datas.append(data)
                Path(tmp_file).unlink()

        # Keep ranks in sync so no process exits while files are in use.
        if world_size > 1:
            torch.distributed.barrier()
        return concatenate(datas)
示例#8
0
def run_inference(args):
    """Run detection + pose-estimation inference on a BOP dataset and save
    the gathered predictions (rank 0 only).

    Args:
        args: namespace with ds_name, n_frames, icp, n_views, n_groups,
            pred_bsz, n_workers, detector_run_id, coarse_run_id,
            refiner_run_id, n_coarse_iterations, n_refiner_iterations,
            save_dir.
    """
    logger.info(f"{'-'*80}")
    for k, v in args.__dict__.items():
        logger.info(f"{k}: {v}")
    logger.info(f"{'-'*80}")

    scene_ds = make_scene_dataset(args.ds_name, n_frames=args.n_frames)

    # Depth images are only needed when the ICP refiner is enabled.
    if args.icp:
        scene_ds.load_depth = args.icp

    scene_ds_multi = MultiViewWrapper(scene_ds, n_views=args.n_views)

    if args.n_groups is not None:
        # Keep only the first n_groups view groups (debugging aid).
        # Fix: rewrote the original attribute-access line continuation
        # (split across three lines) into a readable single statement.
        frame_index = scene_ds_multi.frame_index[:args.n_groups]
        scene_ds_multi.frame_index = frame_index.reset_index(drop=True)

    pred_kwargs = dict()
    pred_runner = BopPredictionRunner(scene_ds_multi,
                                      batch_size=args.pred_bsz,
                                      cache_data=False,
                                      n_workers=args.n_workers)

    detector = load_detector(args.detector_run_id)
    pose_predictor, mesh_db = load_pose_models(
        coarse_run_id=args.coarse_run_id,
        refiner_run_id=args.refiner_run_id,
        n_workers=args.n_workers,
    )

    icp_refiner = None
    if args.icp:
        renderer = pose_predictor.coarse_model.renderer
        icp_refiner = ICPRefiner(
            mesh_db,
            renderer=renderer,
            resolution=pose_predictor.coarse_model.cfg.input_resize)

    # Multi-view scene reconstruction is only used with several views.
    mv_predictor = None
    if args.n_views > 1:
        mv_predictor = MultiviewScenePredictor(mesh_db)

    pred_kwargs.update({
        'maskrcnn_detections':
        dict(
            detector=detector,
            pose_predictor=pose_predictor,
            n_coarse_iterations=args.n_coarse_iterations,
            n_refiner_iterations=args.n_refiner_iterations,
            icp_refiner=icp_refiner,
            mv_predictor=mv_predictor,
        )
    })

    all_predictions = dict()
    for pred_prefix, pred_kwargs_n in pred_kwargs.items():
        logger.info(f"Prediction: {pred_prefix}")
        preds = pred_runner.get_predictions(**pred_kwargs_n)
        for preds_name, preds_n in preds.items():
            all_predictions[f'{pred_prefix}/{preds_name}'] = preds_n

    logger.info("Done with inference.")
    torch.distributed.barrier()

    # Merge each rank's predictions onto rank 0.
    for k, v in all_predictions.items():
        all_predictions[k] = v.gather_distributed(tmp_dir=get_tmp_dir()).cpu()

    if get_rank() == 0:
        save_dir = Path(args.save_dir)
        save_dir.mkdir(exist_ok=True, parents=True)
        logger.info(f'Finished inference on {args.ds_name}')
        results = format_results(all_predictions, dict(), dict())
        torch.save(results, save_dir / 'results.pth.tar')
        (save_dir / 'config.yaml').write_text(yaml.dump(args))
        logger.info(f'Saved predictions in {save_dir}')

    # Keep ranks in sync until rank 0 has finished saving.
    torch.distributed.barrier()
    return
示例#9
0
def run_detection_eval(args, detector=None):
    """Run 2D detection predictions and evaluation on a dataset, then save
    the formatted results (rank 0 only).

    Args:
        args: namespace with ds_name, n_frames, pred_bsz, eval_bsz,
            n_workers, detector_run_id, skip_model_predictions,
            external_predictions, skip_evaluation, save_dir.
        detector: optional pre-loaded detector; when None it is loaded
            from args.detector_run_id.

    Returns:
        Formatted results dict on rank 0, None on other ranks.
    """
    logger.info(f"{'-'*80}")
    for k, v in args.__dict__.items():
        logger.info(f"{k}: {v}")
    logger.info(f"{'-'*80}")

    scene_ds = make_scene_dataset(args.ds_name, n_frames=args.n_frames)

    pred_kwargs = dict()
    # NOTE(review): pred_kwargs is empty at this point, so cache_data is
    # always False here — confirm whether caching was intended once a
    # model prediction config is added below.
    pred_runner = DetectionRunner(scene_ds,
                                  batch_size=args.pred_bsz,
                                  cache_data=len(pred_kwargs) > 1,
                                  n_workers=args.n_workers)

    if not args.skip_model_predictions:
        model = detector if detector is not None else load_detector(
            args.detector_run_id)
        pred_kwargs.update(
            {'model': dict(detector=model, gt_detections=False)})

    all_predictions = dict()

    # Optionally add published external detections for comparison.
    if args.external_predictions:
        if 'ycbv' in args.ds_name:
            all_predictions['posecnn'] = load_posecnn_results().cpu()
        elif 'tless' in args.ds_name:
            all_predictions['retinanet/pix2pose'] = load_pix2pose_results(
                all_detections=True).cpu()
        else:
            pass

    for pred_prefix, pred_kwargs_n in pred_kwargs.items():
        logger.info(f"Prediction: {pred_prefix}")
        preds = pred_runner.get_predictions(**pred_kwargs_n)
        for preds_name, preds_n in preds.items():
            all_predictions[f'{pred_prefix}/{preds_name}'] = preds_n

    logger.info("Done with predictions")
    torch.distributed.barrier()

    # Evaluation.
    meters = get_meters(scene_ds)
    logger.info(f"Meters: {meters}")
    eval_runner = DetectionEvaluation(scene_ds,
                                      meters,
                                      batch_size=args.eval_bsz,
                                      cache_data=len(all_predictions) > 1,
                                      n_workers=args.n_workers,
                                      sampler=pred_runner.sampler)

    eval_metrics, eval_dfs = dict(), dict()
    if not args.skip_evaluation:
        for preds_k, preds in all_predictions.items():
            # Fix: removed the constant `do_eval = True` flag — the
            # "Skipped" else-branch was unreachable dead code.
            logger.info(f"Evaluation of predictions: {preds_k}")
            if len(preds) == 0:
                preds = eval_runner.make_empty_predictions()
            eval_metrics[preds_k], eval_dfs[preds_k] = eval_runner.evaluate(preds)

    # Merge each rank's predictions onto rank 0.
    for k, v in all_predictions.items():
        all_predictions[k] = v.gather_distributed(tmp_dir=get_tmp_dir()).cpu()

    results = None
    if get_rank() == 0:
        save_dir = Path(args.save_dir)
        save_dir.mkdir(exist_ok=True, parents=True)
        logger.info(f'Finished evaluation on {args.ds_name}')
        results = format_results(all_predictions, eval_metrics, eval_dfs)
        torch.save(results, save_dir / 'results.pth.tar')
        (save_dir / 'summary.txt').write_text(results.get('summary_txt', ''))
        (save_dir / 'config.yaml').write_text(yaml.dump(args))
        logger.info(f'Saved predictions+metrics in {save_dir}')

    logger.info("Done with evaluation")
    torch.distributed.barrier()
    return results
示例#10
0
def main():
    """Evaluate CosyPose single-view / multi-view pose estimation on
    T-LESS or YCB-Video and save summarized metrics (rank 0 only)."""
    # Fix: the original iterated `for logger in loggers`, clobbering the
    # module-level `logger` used below with the last logger in the list.
    loggers = [logging.getLogger(name) for name in logging.root.manager.loggerDict]
    for logger_n in loggers:
        if 'cosypose' in logger_n.name:
            logger_n.setLevel(logging.DEBUG)

    logger.info("Starting ...")
    init_distributed_mode()

    parser = argparse.ArgumentParser('Evaluation')
    parser.add_argument('--config', default='tless-bop', type=str)
    parser.add_argument('--debug', action='store_true')
    parser.add_argument('--job_dir', default='', type=str)
    parser.add_argument('--comment', default='', type=str)
    parser.add_argument('--nviews', dest='n_views', default=1, type=int)
    args = parser.parse_args()

    coarse_run_id = None
    refiner_run_id = None
    n_workers = 8
    n_plotters = 8

    n_frames = None
    scene_id = None
    group_id = None
    n_groups = None
    # (Removed a redundant `n_views = 1` that was immediately overwritten.)
    n_views = args.n_views
    skip_mv = args.n_views < 2
    skip_predictions = False

    # Select pretrained run ids and iteration counts per object set.
    object_set = 'tless'
    if 'tless' in args.config:
        object_set = 'tless'
        coarse_run_id = 'tless-coarse--10219'
        refiner_run_id = 'tless-refiner--585928'
        n_coarse_iterations = 1
        n_refiner_iterations = 4
    elif 'ycbv' in args.config:
        object_set = 'ycbv'
        refiner_run_id = 'ycbv-refiner-finetune--251020'
        n_coarse_iterations = 0
        n_refiner_iterations = 2
    else:
        raise ValueError(args.config)

    if args.config == 'tless-siso':
        ds_name = 'tless.primesense.test'
        assert n_views == 1
    elif args.config == 'tless-vivo':
        ds_name = 'tless.primesense.test.bop19'
    elif args.config == 'ycbv':
        ds_name = 'ycbv.test.keyframes'
    else:
        raise ValueError(args.config)

    if args.debug:
        # Restrict the dataset so a debug run finishes quickly.
        if 'tless' in args.config:
            scene_id = None
            group_id = 64
            n_groups = 2
        else:
            scene_id = 48
            n_groups = 2
        n_frames = None
        n_workers = 0
        n_plotters = 0

    n_rand = np.random.randint(1e10)
    save_dir = RESULTS_DIR / f'{args.config}-n_views={n_views}-{args.comment}-{n_rand}'
    logger.info(f"SAVE DIR: {save_dir}")
    logger.info(f"Coarse: {coarse_run_id}")
    logger.info(f"Refiner: {refiner_run_id}")

    # Load dataset
    scene_ds = make_scene_dataset(ds_name)

    if scene_id is not None:
        mask = scene_ds.frame_index['scene_id'] == scene_id
        scene_ds.frame_index = scene_ds.frame_index[mask].reset_index(drop=True)
    if n_frames is not None:
        # Fix: the original re-applied `mask` here, raising a NameError when
        # scene_id was None (and re-applying a stale mask otherwise); simply
        # truncate to the first n_frames.
        scene_ds.frame_index = scene_ds.frame_index[:n_frames].reset_index(drop=True)

    # Predictions
    predictor, mesh_db = load_models(coarse_run_id, refiner_run_id, n_workers=n_plotters, object_set=object_set)

    mv_predictor = MultiviewScenePredictor(mesh_db)

    base_pred_kwargs = dict(
        n_coarse_iterations=n_coarse_iterations,
        n_refiner_iterations=n_refiner_iterations,
        skip_mv=skip_mv,
        pose_predictor=predictor,
        mv_predictor=mv_predictor,
    )

    # External 2D detections initialize the pose estimator.
    if skip_predictions:
        pred_kwargs = {}
    elif 'tless' in ds_name:
        pix2pose_detections = load_pix2pose_results(all_detections='bop19' in ds_name).cpu()
        pred_kwargs = {
            'pix2pose_detections': dict(
                detections=pix2pose_detections,
                **base_pred_kwargs
            ),
        }
    elif 'ycbv' in ds_name:
        posecnn_detections = load_posecnn_results()
        pred_kwargs = {
            'posecnn_init': dict(
                detections=posecnn_detections,
                use_detections_TCO=posecnn_detections,
                **base_pred_kwargs
            ),
        }
    else:
        raise ValueError(ds_name)

    scene_ds_pred = MultiViewWrapper(scene_ds, n_views=n_views)

    if group_id is not None:
        mask = scene_ds_pred.frame_index['group_id'] == group_id
        scene_ds_pred.frame_index = scene_ds_pred.frame_index[mask].reset_index(drop=True)
    elif n_groups is not None:
        scene_ds_pred.frame_index = scene_ds_pred.frame_index[:n_groups]

    pred_runner = MultiviewPredictionRunner(
        scene_ds_pred, batch_size=1, n_workers=n_workers,
        cache_data=len(pred_kwargs) > 1)

    all_predictions = dict()
    for pred_prefix, pred_kwargs_n in pred_kwargs.items():
        logger.info(f"Prediction: {pred_prefix}")
        preds = pred_runner.get_predictions(**pred_kwargs_n)
        for preds_name, preds_n in preds.items():
            all_predictions[f'{pred_prefix}/{preds_name}'] = preds_n

    logger.info("Done with predictions")
    torch.distributed.barrier()

    # Select which prediction keys get evaluated.
    predictions_to_evaluate = set()
    if 'ycbv' in ds_name:
        det_key = 'posecnn_init'
        all_predictions['posecnn'] = posecnn_detections
        predictions_to_evaluate.add('posecnn')
    elif 'tless' in ds_name:
        det_key = 'pix2pose_detections'
    else:
        # Fix: the original raised a bare ValueError; include ds_name for
        # consistency with the other branches.
        raise ValueError(ds_name)
    predictions_to_evaluate.add(f'{det_key}/refiner/iteration={n_refiner_iterations}')

    if args.n_views > 1:
        # Only the bundle-adjusted multi-view output is evaluated
        # (removed the commented-out ba_input/ba_output candidates).
        predictions_to_evaluate.add(f'{det_key}/ba_output+all_cand')

    # Sort by key for deterministic evaluation order across ranks.
    all_predictions = OrderedDict(
        sorted(all_predictions.items(), key=lambda item: item[0]))

    # Evaluation.
    meters = get_pose_meters(scene_ds)
    mv_group_ids = list(iter(pred_runner.sampler))
    scene_ds_ids = np.concatenate(scene_ds_pred.frame_index.loc[mv_group_ids, 'scene_ds_ids'].values)
    sampler = ListSampler(scene_ds_ids)
    eval_runner = PoseEvaluation(scene_ds, meters, n_workers=n_workers,
                                 cache_data=True, batch_size=1, sampler=sampler)

    eval_metrics, eval_dfs = dict(), dict()
    for preds_k, preds in all_predictions.items():
        if preds_k in predictions_to_evaluate:
            logger.info(f"Evaluation : {preds_k} (N={len(preds)})")
            if len(preds) == 0:
                preds = eval_runner.make_empty_predictions()
            eval_metrics[preds_k], eval_dfs[preds_k] = eval_runner.evaluate(preds)
            preds.cpu()
        else:
            logger.info(f"Skipped: {preds_k} (N={len(preds)})")

    all_predictions = gather_predictions(all_predictions)

    # Map full metric keys to human-readable labels for the short summary.
    metrics_to_print = dict()
    if 'ycbv' in ds_name:
        metrics_to_print.update({
            'posecnn/ADD(-S)_ntop=1_matching=CLASS/AUC/objects/mean': 'PoseCNN/AUC of ADD(-S)',

            f'{det_key}/refiner/iteration={n_refiner_iterations}/ADD(-S)_ntop=1_matching=CLASS/AUC/objects/mean': 'Singleview/AUC of ADD(-S)',
            f'{det_key}/refiner/iteration={n_refiner_iterations}/ADD-S_ntop=1_matching=CLASS/AUC/objects/mean': 'Singleview/AUC of ADD-S',

            f'{det_key}/ba_output+all_cand/ADD(-S)_ntop=1_matching=CLASS/AUC/objects/mean': f'Multiview (n={args.n_views})/AUC of ADD(-S)',
            f'{det_key}/ba_output+all_cand/ADD-S_ntop=1_matching=CLASS/AUC/objects/mean': f'Multiview (n={args.n_views})/AUC of ADD-S',
        })
    elif 'tless' in ds_name:
        metrics_to_print.update({
            f'{det_key}/refiner/iteration={n_refiner_iterations}/ADD-S_ntop=BOP_matching=OVERLAP/AUC/objects/mean': 'Singleview/AUC of ADD-S',
            f'{det_key}/refiner/iteration={n_refiner_iterations}/ADD-S_ntop=ALL_matching=BOP/mAP': 'Singleview/mAP@ADD-S<0.1d',

            f'{det_key}/ba_output+all_cand/ADD-S_ntop=BOP_matching=OVERLAP/AUC/objects/mean': f'Multiview (n={args.n_views})/AUC of ADD-S',
            # Fix: the closing parenthesis was misplaced in the original
            # label ('Multiview (n=.../mAP@ADD-S<0.1d)').
            f'{det_key}/ba_output+all_cand/ADD-S_ntop=ALL_matching=BOP/mAP': f'Multiview (n={args.n_views})/mAP@ADD-S<0.1d',
        })
    else:
        raise ValueError(ds_name)

    metrics_to_print.update({
        f'{det_key}/ba_input/ADD-S_ntop=BOP_matching=OVERLAP/norm': 'Multiview before BA/ADD-S (m)',
        f'{det_key}/ba_output/ADD-S_ntop=BOP_matching=OVERLAP/norm': 'Multiview after BA/ADD-S (m)',
    })

    if get_rank() == 0:
        # Fix: create parent directories and tolerate re-runs, consistent
        # with the other save paths in this codebase.
        save_dir.mkdir(exist_ok=True, parents=True)
        results = format_results(all_predictions, eval_metrics, eval_dfs, print_metrics=False)
        (save_dir / 'full_summary.txt').write_text(results.get('summary_txt', ''))

        full_summary = results['summary']
        summary_txt = 'Results:'
        for k, v in metrics_to_print.items():
            if k in full_summary:
                summary_txt += f"\n{v}: {full_summary[k]}"
        logger.info(f"{'-'*80}")
        logger.info(summary_txt)
        logger.info(f"{'-'*80}")

        torch.save(results, save_dir / 'results.pth.tar')
        (save_dir / 'summary.txt').write_text(summary_txt)
        logger.info(f"Saved: {save_dir}")
示例#11
0
def train_pose(args):
    """Train a CosyPose single-view pose model (coarse or refiner).

    Sets up distributed training, train/val datasets and dataloaders, the
    model, optimizer and LR schedulers, then runs the epoch loop with
    periodic validation, evaluation and rank-0 logging.

    Args:
        args: argparse-style namespace; mutated in place (train_refiner,
            train_coarse, save_dir, n_gpus, global_batch_size).
    """
    torch.set_num_threads(1)

    if args.resume_run_id:
        # Resuming: reload the saved run config, but keep the fields listed
        # below from the current invocation.
        resume_dir = EXP_DIR / args.resume_run_id
        # NOTE(review): yaml.load without an explicit Loader is deprecated
        # and unsafe on untrusted files — confirm configs are trusted.
        resume_args = yaml.load((resume_dir / 'config.yaml').read_text())
        keep_fields = set([
            'resume_run_id',
            'epoch_size',
        ])
        vars(args).update({
            k: v
            for k, v in vars(resume_args).items() if k not in keep_fields
        })

    # 'gt+noise' input generation means the refiner is being trained;
    # otherwise the coarse model is trained.
    args.train_refiner = args.TCO_input_generator == 'gt+noise'
    args.train_coarse = not args.train_refiner
    args.save_dir = EXP_DIR / args.run_id

    logger.info(f"{'-'*80}")
    for k, v in args.__dict__.items():
        logger.info(f"{k}: {v}")
    logger.info(f"{'-'*80}")

    # Initialize distributed
    # NOTE(review): the CUDA device is queried before init_distributed_mode();
    # presumably the device is already fixed per process (e.g. via
    # CUDA_VISIBLE_DEVICES) — confirm.
    device = torch.cuda.current_device()
    init_distributed_mode()
    world_size = get_world_size()
    args.n_gpus = world_size
    args.global_batch_size = world_size * args.batch_size
    logger.info(f'Connection established with {world_size} gpus.')

    # Make train/val datasets
    def make_datasets(dataset_names):
        # Concatenate the listed datasets; each entry is (name, n_repeat),
        # so a dataset can be oversampled by repetition.
        datasets = []
        for (ds_name, n_repeat) in dataset_names:
            assert 'test' not in ds_name
            ds = make_scene_dataset(ds_name)
            logger.info(f'Loaded {ds_name} with {len(ds)} images.')
            for _ in range(n_repeat):
                datasets.append(ds)
        return ConcatDataset(datasets)

    # tracking dataset
    scene_ds_train = make_datasets(args.train_ds_names)
    scene_ds_val = make_datasets(args.val_ds_names)

    ds_kwargs = dict(
        resize=args.input_resize,
        rgb_augmentation=args.rgb_augmentation,
        background_augmentation=args.background_augmentation,
        min_area=args.min_area,
        gray_augmentation=args.gray_augmentation,
    )
    ds_train = PoseTrackingDataset(scene_ds_train, **ds_kwargs)
    ds_val = PoseTrackingDataset(scene_ds_val, **ds_kwargs)

    # Each epoch visits a random subset of args.epoch_size training samples.
    train_sampler = PartialSampler(ds_train, epoch_size=args.epoch_size)
    ds_iter_train = DataLoader(ds_train,
                               sampler=train_sampler,
                               batch_size=args.batch_size,
                               num_workers=args.n_dataloader_workers,
                               collate_fn=ds_train.collate_fn,
                               drop_last=False,
                               pin_memory=True)
    ds_iter_train = MultiEpochDataLoader(ds_iter_train)

    # Validation uses 10% of the training epoch size.
    val_sampler = PartialSampler(ds_val, epoch_size=int(0.1 * args.epoch_size))
    ds_iter_val = DataLoader(ds_val,
                             sampler=val_sampler,
                             batch_size=args.batch_size,
                             num_workers=args.n_dataloader_workers,
                             collate_fn=ds_val.collate_fn,
                             drop_last=False,
                             pin_memory=True)
    ds_iter_val = MultiEpochDataLoader(ds_iter_val)

    # Make model
    # renderer = BulletBatchRenderer(object_set=args.urdf_ds_name, n_workers=args.n_rendering_workers)
    object_ds = make_object_dataset(args.object_ds_name)
    mesh_db = MeshDataBase.from_object_ds(object_ds).batched(
        n_sym=args.n_symmetries_batch).cuda().float()

    model = create_model_pose_custom(cfg=args, mesh_db=mesh_db).cuda()

    eval_bundle = make_eval_bundle(args, model)

    if args.resume_run_id:
        # Restore weights and the epoch counter from the checkpoint.
        resume_dir = EXP_DIR / args.resume_run_id
        path = resume_dir / 'checkpoint.pth.tar'
        logger.info(f'Loading checkpoing from {path}')
        save = torch.load(path)
        state_dict = save['state_dict']
        model.load_state_dict(state_dict)
        start_epoch = save['epoch'] + 1
    else:
        start_epoch = 0
    end_epoch = args.n_epochs

    if args.run_id_pretrain is not None:
        pretrain_path = EXP_DIR / args.run_id_pretrain / 'checkpoint.pth.tar'
        logger.info(f'Using pretrained model from {pretrain_path}.')
        model.load_state_dict(torch.load(pretrain_path)['state_dict'])

    # Synchronize models across processes.
    model = sync_model(model)
    model = torch.nn.parallel.DistributedDataParallel(model,
                                                      device_ids=[device],
                                                      output_device=device)

    # Optimizer
    optimizer = torch.optim.Adam(model.parameters(),
                                 lr=args.lr,
                                 weight_decay=args.weight_decay)

    # Warmup
    # Linearly ramp the LR, stepped once per batch, over the first
    # n_epochs_warmup epochs; with no warmup the factor is constant 1.
    if args.n_epochs_warmup == 0:
        lambd = lambda epoch: 1
    else:
        n_batches_warmup = args.n_epochs_warmup * (args.epoch_size //
                                                   args.batch_size)
        lambd = lambda batch: (batch + 1) / n_batches_warmup
    lr_scheduler_warmup = torch.optim.lr_scheduler.LambdaLR(optimizer, lambd)
    # Resume warmup progress at the batch corresponding to start_epoch.
    lr_scheduler_warmup.last_epoch = start_epoch * args.epoch_size // args.batch_size

    # LR schedulers
    # Divide LR by 10 every args.lr_epoch_decay
    lr_scheduler = torch.optim.lr_scheduler.StepLR(
        optimizer,
        step_size=args.lr_epoch_decay,
        gamma=0.1,
    )
    lr_scheduler.last_epoch = start_epoch - 1
    lr_scheduler.step()

    for epoch in range(start_epoch, end_epoch):
        # Fresh meters every epoch; logged (averaged) at the end of it.
        meters_train = defaultdict(lambda: AverageValueMeter())
        meters_val = defaultdict(lambda: AverageValueMeter())
        meters_time = defaultdict(lambda: AverageValueMeter())

        # h computes the training loss for one batch with the current model.
        h = functools.partial(h_pose_custom,
                              model=model,
                              cfg=args,
                              n_iterations=args.n_iterations,
                              mesh_db=mesh_db,
                              input_generator=args.TCO_input_generator)

        def train_epoch():
            # One optimization pass over the (partial) training epoch.
            model.train()
            iterator = tqdm(ds_iter_train, ncols=80)
            t = time.time()
            for n, sample in enumerate(iterator):
                if n > 0:
                    meters_time['data'].add(time.time() - t)

                optimizer.zero_grad()

                t = time.time()
                loss = h(data=sample, meters=meters_train)
                meters_time['forward'].add(time.time() - t)
                iterator.set_postfix(loss=loss.item())
                meters_train['loss_total'].add(loss.item())

                t = time.time()
                loss.backward()
                total_grad_norm = torch.nn.utils.clip_grad_norm_(
                    model.parameters(),
                    max_norm=args.clip_grad_norm,
                    norm_type=2)
                meters_train['grad_norm'].add(
                    torch.as_tensor(total_grad_norm).item())

                optimizer.step()
                meters_time['backward'].add(time.time() - t)
                meters_time['memory'].add(torch.cuda.max_memory_allocated() /
                                          1024.**2)

                # Warmup steps once per batch; afterwards StepLR steps once
                # per epoch (below).
                if epoch < args.n_epochs_warmup:
                    lr_scheduler_warmup.step()
                t = time.time()
            if epoch >= args.n_epochs_warmup:
                lr_scheduler.step()

        @torch.no_grad()
        def validation():
            # Loss-only pass over the validation subset.
            model.eval()
            for sample in tqdm(ds_iter_val, ncols=80):
                loss = h(data=sample, meters=meters_val)
                meters_val['loss_total'].add(loss.item())

        @torch.no_grad()
        def test():
            # Full evaluation of the current model on the eval bundle.
            model.eval()
            return run_eval(eval_bundle, epoch=epoch)

        train_epoch()
        if epoch % args.val_epoch_interval == 0:
            validation()

        test_dict = None
        if epoch % args.test_epoch_interval == 0:
            test_dict = test()

        log_dict = dict()
        log_dict.update({
            'grad_norm':
            meters_train['grad_norm'].mean,
            'grad_norm_std':
            meters_train['grad_norm'].std,
            'learning_rate':
            optimizer.param_groups[0]['lr'],
            'time_forward':
            meters_time['forward'].mean,
            'time_backward':
            meters_time['backward'].mean,
            'time_data':
            meters_time['data'].mean,
            'gpu_memory':
            meters_time['memory'].mean,
            'time':
            time.time(),
            'n_iterations': (epoch + 1) * len(ds_iter_train),
            'n_datas':
            (epoch + 1) * args.global_batch_size * len(ds_iter_train),
        })

        for string, meters in zip(('train', 'val'),
                                  (meters_train, meters_val)):
            for k in dict(meters).keys():
                log_dict[f'{string}_{k}'] = meters[k].mean

        # Average scalars across processes, then log/checkpoint from rank 0.
        log_dict = reduce_dict(log_dict)
        if get_rank() == 0:
            log(config=args,
                model=model,
                epoch=epoch,
                log_dict=log_dict,
                test_dict=test_dict)
        dist.barrier()
示例#12
0
import os
import torch
from cosypose.utils.distributed import init_distributed_mode, get_world_size, get_tmp_dir, get_rank
from cosypose.utils.logging import get_logger

logger = get_logger(__name__)

if __name__ == '__main__':
    # Initialize the distributed environment and log basic job information.
    init_distributed_mode()
    rank = get_rank()
    world_size = get_world_size()
    cpu_budget = os.environ.get('N_CPUS', 'not specified')
    logger.info(f'Number of processes (=num GPUs): {world_size}')
    logger.info(f'Process ID: {rank}')
    logger.info(f'TMP Directory for this job: {get_tmp_dir()}')
    logger.info(f'GPU CUDA ID: {torch.cuda.current_device()}')
    logger.info(f'Max number of CPUs for this process: {cpu_budget}')
示例#13
0
def train_detector(args):
    """Train a MaskRCNN detector described by the configuration ``args``.

    Runs one process per GPU (DistributedDataParallel); supports resuming a
    previous run, warm-starting from a pretrained run or from COCO weights.
    ``args`` is an argparse-style namespace carrying the full training
    configuration (run ids, dataset names, optimizer settings, ...).
    """
    # Keep the main process light; dataloader workers do the heavy lifting.
    torch.set_num_threads(1)

    if args.resume_run_id:
        # Reload the configuration of the run being resumed, but keep the
        # fields in `keep_fields` from the *current* invocation.
        resume_dir = EXP_DIR / args.resume_run_id
        # PyYAML >= 6 requires an explicit Loader. The config stores a
        # serialized argparse.Namespace (a python object tag), so the
        # unsafe loader is needed — only load configs you trust.
        resume_args = yaml.load((resume_dir / 'config.yaml').read_text(),
                                Loader=yaml.UnsafeLoader)
        keep_fields = set([
            'resume_run_id',
            'epoch_size',
        ])
        vars(args).update({
            k: v
            for k, v in vars(resume_args).items() if k not in keep_fields
        })

    args = check_update_config(args)
    args.save_dir = EXP_DIR / args.run_id

    # Dump the final configuration to the log.
    logger.info(f"{'-'*80}")
    for k, v in args.__dict__.items():
        logger.info(f"{k}: {v}")
    logger.info(f"{'-'*80}")

    # Initialize distributed
    # NOTE(review): the CUDA device is queried *before* init_distributed_mode();
    # this assumes the launcher already pinned one GPU per process (e.g. via
    # CUDA_VISIBLE_DEVICES) — confirm against the job launcher.
    device = torch.cuda.current_device()
    init_distributed_mode()
    world_size = get_world_size()
    args.n_gpus = world_size
    args.global_batch_size = world_size * args.batch_size
    logger.info(f'Connection established with {world_size} gpus.')

    # Make train/val datasets
    def make_datasets(dataset_names):
        """Concatenate scene datasets; return (ConcatDataset, union of labels).

        ``dataset_names`` is a list of (name, n_repeat) pairs; repeating a
        dataset weights it proportionally in the concatenation.
        """
        datasets = []
        all_labels = set()
        for (ds_name, n_repeat) in dataset_names:
            # Test splits must never leak into training.
            assert 'test' not in ds_name
            ds = make_scene_dataset(ds_name)
            logger.info(f'Loaded {ds_name} with {len(ds)} images.')
            all_labels = all_labels.union(set(ds.all_labels))
            for _ in range(n_repeat):
                datasets.append(ds)
        return ConcatDataset(datasets), all_labels

    scene_ds_train, train_labels = make_datasets(args.train_ds_names)
    scene_ds_val, _ = make_datasets(args.val_ds_names)
    # Category id 0 is reserved for the background class; object labels get
    # ids 1..N in sorted order so the mapping is deterministic across runs.
    label_to_category_id = dict()
    label_to_category_id['background'] = 0
    for n, label in enumerate(sorted(list(train_labels)), 1):
        label_to_category_id[label] = n
    logger.info(
        f'Training with {len(label_to_category_id)} categories: {label_to_category_id}'
    )
    args.label_to_category_id = label_to_category_id

    ds_kwargs = dict(
        resize=args.input_resize,
        rgb_augmentation=args.rgb_augmentation,
        background_augmentation=args.background_augmentation,
        gray_augmentation=args.gray_augmentation,
        label_to_category_id=label_to_category_id,
    )
    ds_train = DetectionDataset(scene_ds_train, **ds_kwargs)
    ds_val = DetectionDataset(scene_ds_val, **ds_kwargs)

    # PartialSampler draws `epoch_size` samples per epoch; MultiEpochDataLoader
    # keeps workers alive across epochs.
    train_sampler = PartialSampler(ds_train, epoch_size=args.epoch_size)
    ds_iter_train = DataLoader(ds_train,
                               sampler=train_sampler,
                               batch_size=args.batch_size,
                               num_workers=args.n_dataloader_workers,
                               collate_fn=collate_fn,
                               drop_last=False,
                               pin_memory=True)
    ds_iter_train = MultiEpochDataLoader(ds_iter_train)

    # Validation uses 10% of the training epoch size.
    val_sampler = PartialSampler(ds_val, epoch_size=int(0.1 * args.epoch_size))
    ds_iter_val = DataLoader(ds_val,
                             sampler=val_sampler,
                             batch_size=args.batch_size,
                             num_workers=args.n_dataloader_workers,
                             collate_fn=collate_fn,
                             drop_last=False,
                             pin_memory=True)
    ds_iter_val = MultiEpochDataLoader(ds_iter_val)

    model = create_model_detector(cfg=args,
                                  n_classes=len(
                                      args.label_to_category_id)).cuda()

    # Resume weights + epoch counter from a previous checkpoint if requested.
    if args.resume_run_id:
        resume_dir = EXP_DIR / args.resume_run_id
        path = resume_dir / 'checkpoint.pth.tar'
        logger.info(f'Loading checkpoint from {path}')
        save = torch.load(path)
        state_dict = save['state_dict']
        model.load_state_dict(state_dict)
        start_epoch = save['epoch'] + 1
    else:
        start_epoch = 0
    end_epoch = args.n_epochs

    # Optional warm start: either another run's checkpoint, or COCO weights
    # with the class-specific predictor heads dropped.
    if args.run_id_pretrain is not None:
        pretrain_path = EXP_DIR / args.run_id_pretrain / 'checkpoint.pth.tar'
        logger.info(f'Using pretrained model from {pretrain_path}.')
        model.load_state_dict(torch.load(pretrain_path)['state_dict'])
    elif args.pretrain_coco:
        state_dict = load_state_dict_from_url(
            model_urls['maskrcnn_resnet50_fpn_coco'])
        # Drop the predictor heads: they are sized for COCO's classes.
        keep = lambda k: 'box_predictor' not in k and 'mask_predictor' not in k
        state_dict = {k: v for k, v in state_dict.items() if keep(k)}
        model.load_state_dict(state_dict, strict=False)
        logger.info(
            'Using model pre-trained on coco. Removed predictor heads.')
    else:
        logger.info('Training MaskRCNN from scratch.')

    # Synchronize models across processes.
    model = sync_model(model)
    model = torch.nn.parallel.DistributedDataParallel(model,
                                                      device_ids=[device],
                                                      output_device=device)

    # Optimizer
    params = [p for p in model.parameters() if p.requires_grad]
    if args.optimizer.lower() == 'sgd':
        optimizer = torch.optim.SGD(params,
                                    lr=args.lr,
                                    weight_decay=args.weight_decay,
                                    momentum=args.momentum)
    elif args.optimizer.lower() == 'adam':
        optimizer = torch.optim.Adam(params,
                                     lr=args.lr,
                                     weight_decay=args.weight_decay)
    else:
        raise ValueError(f'Unknown optimizer {args.optimizer}')

    # Warmup: linearly ramp the LR over the first n_epochs_warmup epochs,
    # stepped once per batch.
    if args.n_epochs_warmup == 0:
        lambd = lambda epoch: 1
    else:
        n_batches_warmup = args.n_epochs_warmup * (args.epoch_size //
                                                   args.batch_size)
        lambd = lambda batch: (batch + 1) / n_batches_warmup
    lr_scheduler_warmup = torch.optim.lr_scheduler.LambdaLR(optimizer, lambd)
    # Fast-forward the warmup state when resuming mid-training.
    lr_scheduler_warmup.last_epoch = start_epoch * args.epoch_size // args.batch_size

    # LR schedulers
    # Divide LR by 10 every args.lr_epoch_decay
    lr_scheduler = torch.optim.lr_scheduler.StepLR(
        optimizer,
        step_size=args.lr_epoch_decay,
        gamma=0.1,
    )
    lr_scheduler.last_epoch = start_epoch - 1
    lr_scheduler.step()

    for epoch in range(start_epoch, end_epoch):
        meters_train = defaultdict(AverageValueMeter)
        meters_val = defaultdict(AverageValueMeter)
        meters_time = defaultdict(AverageValueMeter)

        # h(data, meters) runs the forward pass and returns the total loss.
        h = functools.partial(h_maskrcnn, model=model, cfg=args)

        def train_epoch():
            """One optimization pass over ds_iter_train; updates meters in place."""
            model.train()
            iterator = tqdm(ds_iter_train, ncols=80)
            t = time.time()
            for n, sample in enumerate(iterator):
                if n > 0:
                    # Skip the first batch: it includes worker spin-up time.
                    meters_time['data'].add(time.time() - t)

                optimizer.zero_grad()

                t = time.time()
                loss = h(data=sample, meters=meters_train)
                meters_time['forward'].add(time.time() - t)
                iterator.set_postfix(loss=loss.item())
                meters_train['loss_total'].add(loss.item())

                t = time.time()
                loss.backward()
                # max_norm=inf: no clipping, only measure the gradient norm.
                total_grad_norm = torch.nn.utils.clip_grad_norm_(
                    model.parameters(), max_norm=np.inf, norm_type=2)
                meters_train['grad_norm'].add(
                    torch.as_tensor(total_grad_norm).item())

                optimizer.step()
                meters_time['backward'].add(time.time() - t)
                meters_time['memory'].add(torch.cuda.max_memory_allocated() /
                                          1024.**2)

                # Warmup steps per-batch; the decay scheduler steps per-epoch.
                if epoch < args.n_epochs_warmup:
                    lr_scheduler_warmup.step()
                t = time.time()
            if epoch >= args.n_epochs_warmup:
                lr_scheduler.step()

        @torch.no_grad()
        def validation():
            """Compute validation losses (no gradient updates)."""
            # model.train() is intentional: presumably the MaskRCNN heads only
            # return losses in train mode — confirm against h_maskrcnn.
            model.train()
            for sample in tqdm(ds_iter_val, ncols=80):
                loss = h(data=sample, meters=meters_val)
                meters_val['loss_total'].add(loss.item())

        train_epoch()
        if epoch % args.val_epoch_interval == 0:
            validation()

        test_dict = None
        if epoch % args.test_epoch_interval == 0:
            model.eval()
            test_dict = run_eval(args, model, epoch)

        # Aggregate the epoch's statistics for logging.
        log_dict = dict()
        log_dict.update({
            'grad_norm':
            meters_train['grad_norm'].mean,
            'grad_norm_std':
            meters_train['grad_norm'].std,
            'learning_rate':
            optimizer.param_groups[0]['lr'],
            'time_forward':
            meters_time['forward'].mean,
            'time_backward':
            meters_time['backward'].mean,
            'time_data':
            meters_time['data'].mean,
            'gpu_memory':
            meters_time['memory'].mean,
            'time':
            time.time(),
            'n_iterations': (epoch + 1) * len(ds_iter_train),
            'n_datas':
            (epoch + 1) * args.global_batch_size * len(ds_iter_train),
        })

        for string, meters in zip(('train', 'val'),
                                  (meters_train, meters_val)):
            for k in dict(meters).keys():
                log_dict[f'{string}_{k}'] = meters[k].mean

        # Average metrics across processes; only rank 0 writes logs/checkpoints.
        log_dict = reduce_dict(log_dict)
        if get_rank() == 0:
            log(config=args,
                model=model,
                epoch=epoch,
                log_dict=log_dict,
                test_dict=test_dict)
        dist.barrier()