Example n. 1
0
def load_models(coarse_run_id, refiner_run_id=None, n_workers=8, object_set='tless'):
    """Load coarse and (optionally) refiner pose models into a combined predictor.

    Args:
        coarse_run_id: run id of the coarse model (subdirectory of EXP_DIR).
        refiner_run_id: optional run id of the refiner model; skipped when None.
        n_workers: number of workers for the batch renderer.
        object_set: 'tless' selects the T-LESS datasets; anything else selects
            the YCB-Video ones.

    Returns:
        (model, mesh_db): the CoarseRefinePosePredictor and the un-batched
        mesh database it was built from.
    """
    if object_set == 'tless':
        object_ds_name, urdf_ds_name = 'tless.bop', 'tless.cad'
    else:
        object_ds_name, urdf_ds_name = 'ycbv.bop-compat.eval', 'ycbv'

    object_ds = make_object_dataset(object_ds_name)
    mesh_db = MeshDataBase.from_object_ds(object_ds)
    renderer = BulletBatchRenderer(object_set=urdf_ds_name, n_workers=n_workers)
    mesh_db_batched = mesh_db.batched().cuda()

    def load_model(run_id):
        # Build one model (refiner or coarse, depending on its saved config)
        # and restore its checkpoint weights.
        if run_id is None:
            return None
        run_dir = EXP_DIR / run_id
        cfg = yaml.load((run_dir / 'config.yaml').read_text(), Loader=yaml.FullLoader)
        cfg = check_update_config(cfg)
        if cfg.train_refiner:
            model = create_model_refiner(cfg, renderer=renderer, mesh_db=mesh_db_batched)
        else:
            model = create_model_coarse(cfg, renderer=renderer, mesh_db=mesh_db_batched)
        # Checkpoint loading is identical for both model kinds: hoisted out
        # of the branch (was duplicated in each).
        ckpt = torch.load(run_dir / 'checkpoint.pth.tar')
        model.load_state_dict(ckpt['state_dict'])
        model = model.cuda().eval()
        model.cfg = cfg
        return model

    coarse_model = load_model(coarse_run_id)
    refiner_model = load_model(refiner_run_id)
    model = CoarseRefinePosePredictor(coarse_model=coarse_model,
                                      refiner_model=refiner_model)
    return model, mesh_db
def main():
    """Run stages 2 (multi-view matching) and 3 (bundle adjustment) of CosyPose
    on a custom scenario and write the reconstructed subscenes to disk.

    Expects under LOCAL_DATA_DIR / 'custom_scenarios' / <scenario>:
      - candidates.csv: single-view 6D pose candidates (one scene only).
      - scene_camera.json: per-view camera intrinsics.
      - models/: BOP-format 3D object models.
    Results go to <scenario>/results/subscene=<view_group>/.
    """
    parser = argparse.ArgumentParser(
        'CosyPose multi-view reconstruction for a custom scenario')
    parser.add_argument(
        '--scenario',
        default='',
        type=str,
        # Fix: help pointed at local_data/scenarios, but the code reads
        # local_data/custom_scenarios (see scenario_dir below).
        help=
        'Id of the scenario, matching directory must be in local_data/custom_scenarios'
    )
    # Fix: the score threshold is fractional (default 0.3); with type=int a
    # command-line value such as '0.3' would raise at parse time.
    parser.add_argument('--sv_score_th',
                        default=0.3,
                        type=float,
                        help="Score to filter single-view predictions")
    parser.add_argument(
        '--n_symmetries_rot',
        default=64,
        type=int,
        help="Number of discretized symmetries to use for continuous symmetries"
    )
    parser.add_argument(
        '--ransac_n_iter',
        default=2000,
        type=int,
        help="Max number of RANSAC iterations per pair of views")
    parser.add_argument(
        '--ransac_dist_threshold',
        default=0.02,
        type=float,
        help=
        "Threshold (in meters) on symmetric distance to consider a tentative match an inlier"
    )
    parser.add_argument('--ba_n_iter',
                        default=10,
                        type=int,
                        help="Maximum number of LM iterations in stage 3")
    parser.add_argument('--nms_th',
                        default=0.04,
                        type=float,
                        help='Threshold (meter) for NMS 3D')
    parser.add_argument('--no_visualization', action='store_true')
    args = parser.parse_args()

    # Enable debug logging on every cosypose logger.
    # Fix: the loop variable was named `logger`, shadowing the module-level
    # `logger` for the rest of this function — every logger.info call below
    # would go to whichever child logger was iterated last.
    loggers = [
        logging.getLogger(name) for name in logging.root.manager.loggerDict
    ]
    for logger_ in loggers:
        if 'cosypose' in logger_.name:
            logger_.setLevel(logging.DEBUG)

    logger.info(f"{'-'*80}")
    for k, v in args.__dict__.items():
        logger.info(f"{k}: {v}")
    logger.info(f"{'-'*80}")

    scenario_dir = LOCAL_DATA_DIR / 'custom_scenarios' / args.scenario

    # Stage-1 outputs: single-view 6D pose candidates, all from one scene.
    candidates = read_csv_candidates(scenario_dir /
                                     'candidates.csv').float().cuda()
    candidates.infos['group_id'] = 0
    scene_ids = np.unique(candidates.infos['scene_id'])
    assert len(
        scene_ids
    ) == 1, 'Please only provide 6D pose estimations that correspond to the same scene.'
    scene_id = scene_ids.item()
    view_ids = np.unique(candidates.infos['view_id'])
    n_views = len(view_ids)
    logger.info(f'Loaded {len(candidates)} candidates in {n_views} views.')

    cameras = read_cameras(scenario_dir / 'scene_camera.json',
                           view_ids).float().cuda()
    cameras.infos['scene_id'] = scene_id
    cameras.infos['batch_im_id'] = np.arange(len(view_ids))
    logger.info('Loaded cameras intrinsics.')

    object_ds = BOPObjectDataset(scenario_dir / 'models')
    mesh_db = MeshDataBase.from_object_ds(object_ds)
    logger.info(f'Loaded {len(object_ds)} 3D object models.')

    logger.info('Running stage 2 and 3 of CosyPose...')
    mv_predictor = MultiviewScenePredictor(mesh_db)
    predictions = mv_predictor.predict_scene_state(
        candidates,
        cameras,
        score_th=args.sv_score_th,
        use_known_camera_poses=False,
        ransac_n_iter=args.ransac_n_iter,
        ransac_dist_threshold=args.ransac_dist_threshold,
        ba_n_iter=args.ba_n_iter)

    objects = predictions['scene/objects']
    cameras = predictions['scene/cameras']
    reproj = predictions['ba_output']

    # Each view_group is an independently reconstructed subscene; write each
    # one to its own results directory after 3D non-maximum suppression.
    for view_group in np.unique(objects.infos['view_group']):
        objects_ = objects[np.where(
            objects.infos['view_group'] == view_group)[0]]
        cameras_ = cameras[np.where(
            cameras.infos['view_group'] == view_group)[0]]
        reproj_ = reproj[np.where(reproj.infos['view_group'] == view_group)[0]]
        objects_ = nms3d(objects_, th=args.nms_th, poses_attr='TWO')

        view_group_dir = scenario_dir / 'results' / f'subscene={view_group}'
        view_group_dir.mkdir(exist_ok=True, parents=True)

        logger.info(
            f'Subscene {view_group} has {len(objects_)} objects and {len(cameras_)} cameras.'
        )

        predicted_scene_path = view_group_dir / 'predicted_scene.json'
        scene_reprojected_path = view_group_dir / 'scene_reprojected.csv'
        save_scene_json(objects_, cameras_, predicted_scene_path)
        tc_to_csv(reproj_, scene_reprojected_path)

        logger.info(
            f'Wrote predicted scene (objects+cameras): {predicted_scene_path}')
        logger.info(
            f'Wrote predicted objects with pose expressed in camera frame: {scene_reprojected_path}'
        )
Example n. 3
0
def get_pose_meters(scene_ds):
    """Build the dict of PoseErrorMeter evaluators for `scene_ds`.

    The meter set depends on the dataset family: T-LESS gets BOP-style
    matching meters (ADD-S < 0.1d and mAP), while YCB-Video additionally
    gets the ADD(-S) AUC meter used for comparison with PoseCNN/DeepIM.

    Returns:
        dict mapping meter names to configured PoseErrorMeter instances.

    Raises:
        ValueError: if the dataset name is not a recognized T-LESS/YCB-V split.
    """
    ds_name = scene_ds.name

    # Family-dependent settings; these defaults are overridden below.
    with_add = False
    overlap_check = True
    diameter_ratio_threshold = 0.5
    if ds_name == 'tless.primesense.test.bop19':
        targets_filename = 'test_targets_bop19.json'
        visib_gt_min = -1
        n_top = -1  # n_top comes from the BOP targets file.
    elif ds_name == 'tless.primesense.test':
        targets_filename = 'all_target_tless.json'
        n_top = 1
        visib_gt_min = 0.1
    elif 'ycbv' in ds_name:
        with_add = True
        visib_gt_min = -1
        targets_filename = None
        n_top = 1
        overlap_check = False
    else:
        raise ValueError

    if 'tless' in ds_name:
        object_ds_name = 'tless.eval'
    elif 'ycbv' in ds_name:
        # The bop-compat set matters: it defines which objects are symmetric.
        object_ds_name = 'ycbv.bop-compat.eval'
    else:
        raise ValueError

    targets = None
    if targets_filename is not None:
        targets_path = scene_ds.ds_dir / targets_filename
        targets = remap_bop_targets(pd.read_json(targets_path))

    mesh_db = MeshDataBase.from_object_ds(make_object_dataset(object_ds_name))

    error_types = ['ADD-S']
    if with_add:
        error_types.append('ADD(-S)')

    base_kwargs = dict(
        mesh_db=mesh_db,
        exact_meshes=True,
        sample_n_points=None,
        errors_bsz=1,

        # BOP-like parameters.
        n_top=n_top,
        visib_gt_min=visib_gt_min,
        targets=targets,
        spheres_overlap_check=overlap_check,
    )

    meters = dict()
    for error_type in error_types:
        # Measures ADD-S AUC on T-LESS and average errors on ycbv/tless.
        meters[f'{error_type}_ntop=BOP_matching=OVERLAP'] = PoseErrorMeter(
            error_type=error_type,
            consider_all_predictions=False,
            match_threshold=diameter_ratio_threshold,
            report_error_stats=True,
            report_error_AUC=True,
            **base_kwargs)

        if 'ycbv' in ds_name:
            # Fair comparison with PoseCNN/DeepIM on YCB-Video ADD(-S) AUC.
            meters[f'{error_type}_ntop=1_matching=CLASS'] = PoseErrorMeter(
                error_type=error_type,
                consider_all_predictions=False,
                match_threshold=np.inf,
                report_error_stats=False,
                report_error_AUC=True,
                **base_kwargs)

        if 'tless' in ds_name:
            # For ADD-S < 0.1d.
            meters[f'{error_type}_ntop=BOP_matching=BOP'] = PoseErrorMeter(
                error_type=error_type, match_threshold=0.1, **base_kwargs)
            # For mAP.
            meters[f'{error_type}_ntop=ALL_matching=BOP'] = PoseErrorMeter(
                error_type=error_type,
                match_threshold=0.1,
                consider_all_predictions=True,
                report_AP=True,
                **base_kwargs)
    return meters
Example n. 4
0
def train_pose(args):
    """Distributed training loop for a pose model (coarse or refiner).

    `args` is a config namespace; whether the refiner or the coarse model is
    trained is derived from `args.TCO_input_generator` ('gt+noise' -> refiner).
    Supports resuming from a checkpoint (`args.resume_run_id`) and
    warm-starting from a pretrained run (`args.run_id_pretrain`).
    """
    torch.set_num_threads(1)

    if args.resume_run_id:
        # Resume: reload the saved run config, but keep the `keep_fields`
        # values from the current invocation (they may legitimately differ
        # between the original run and the resume).
        resume_dir = EXP_DIR / args.resume_run_id
        # NOTE(review): yaml.load without an explicit Loader is deprecated
        # and unsafe on untrusted files; load_models elsewhere in this file
        # passes Loader=yaml.FullLoader -- confirm and align.
        resume_args = yaml.load((resume_dir / 'config.yaml').read_text())
        keep_fields = set([
            'resume_run_id',
            'epoch_size',
        ])
        vars(args).update({
            k: v
            for k, v in vars(resume_args).items() if k not in keep_fields
        })

    # The refiner is trained when inputs are ground-truth poses + noise;
    # otherwise the coarse model is trained.
    args.train_refiner = args.TCO_input_generator == 'gt+noise'
    args.train_coarse = not args.train_refiner
    args.save_dir = EXP_DIR / args.run_id

    logger.info(f"{'-'*80}")
    for k, v in args.__dict__.items():
        logger.info(f"{k}: {v}")
    logger.info(f"{'-'*80}")

    # Initialize distributed
    device = torch.cuda.current_device()
    init_distributed_mode()
    world_size = get_world_size()
    args.n_gpus = world_size
    args.global_batch_size = world_size * args.batch_size
    logger.info(f'Connection established with {world_size} gpus.')

    # Make train/val datasets
    def make_datasets(dataset_names):
        # Concatenate the listed scene datasets; `n_repeat` duplicates a
        # dataset to weight it more heavily in the training mix.
        datasets = []
        for (ds_name, n_repeat) in dataset_names:
            assert 'test' not in ds_name
            ds = make_scene_dataset(ds_name)
            logger.info(f'Loaded {ds_name} with {len(ds)} images.')
            for _ in range(n_repeat):
                datasets.append(ds)
        return ConcatDataset(datasets)

    # tracking dataset
    scene_ds_train = make_datasets(args.train_ds_names)
    scene_ds_val = make_datasets(args.val_ds_names)

    ds_kwargs = dict(
        resize=args.input_resize,
        rgb_augmentation=args.rgb_augmentation,
        background_augmentation=args.background_augmentation,
        min_area=args.min_area,
        gray_augmentation=args.gray_augmentation,
    )
    ds_train = PoseTrackingDataset(scene_ds_train, **ds_kwargs)
    ds_val = PoseTrackingDataset(scene_ds_val, **ds_kwargs)

    # An "epoch" here is a fixed-size random subset of the full dataset.
    train_sampler = PartialSampler(ds_train, epoch_size=args.epoch_size)
    ds_iter_train = DataLoader(ds_train,
                               sampler=train_sampler,
                               batch_size=args.batch_size,
                               num_workers=args.n_dataloader_workers,
                               collate_fn=ds_train.collate_fn,
                               drop_last=False,
                               pin_memory=True)
    ds_iter_train = MultiEpochDataLoader(ds_iter_train)

    # Validation runs on 10% of the training epoch size.
    val_sampler = PartialSampler(ds_val, epoch_size=int(0.1 * args.epoch_size))
    ds_iter_val = DataLoader(ds_val,
                             sampler=val_sampler,
                             batch_size=args.batch_size,
                             num_workers=args.n_dataloader_workers,
                             collate_fn=ds_val.collate_fn,
                             drop_last=False,
                             pin_memory=True)
    ds_iter_val = MultiEpochDataLoader(ds_iter_val)

    # Make model
    # renderer = BulletBatchRenderer(object_set=args.urdf_ds_name, n_workers=args.n_rendering_workers)
    object_ds = make_object_dataset(args.object_ds_name)
    mesh_db = MeshDataBase.from_object_ds(object_ds).batched(
        n_sym=args.n_symmetries_batch).cuda().float()

    model = create_model_pose_custom(cfg=args, mesh_db=mesh_db).cuda()

    eval_bundle = make_eval_bundle(args, model)

    if args.resume_run_id:
        # Restore weights and continue from the epoch after the checkpoint.
        resume_dir = EXP_DIR / args.resume_run_id
        path = resume_dir / 'checkpoint.pth.tar'
        # NOTE(review): typo in the log message ('checkpoing').
        logger.info(f'Loading checkpoing from {path}')
        save = torch.load(path)
        state_dict = save['state_dict']
        model.load_state_dict(state_dict)
        start_epoch = save['epoch'] + 1
    else:
        start_epoch = 0
    end_epoch = args.n_epochs

    if args.run_id_pretrain is not None:
        pretrain_path = EXP_DIR / args.run_id_pretrain / 'checkpoint.pth.tar'
        logger.info(f'Using pretrained model from {pretrain_path}.')
        model.load_state_dict(torch.load(pretrain_path)['state_dict'])

    # Synchronize models across processes.
    model = sync_model(model)
    model = torch.nn.parallel.DistributedDataParallel(model,
                                                      device_ids=[device],
                                                      output_device=device)

    # Optimizer
    optimizer = torch.optim.Adam(model.parameters(),
                                 lr=args.lr,
                                 weight_decay=args.weight_decay)

    # Warmup: LR ramps up linearly per *batch* over the first
    # `n_epochs_warmup` epochs; a constant factor of 1 disables it.
    if args.n_epochs_warmup == 0:
        lambd = lambda epoch: 1
    else:
        n_batches_warmup = args.n_epochs_warmup * (args.epoch_size //
                                                   args.batch_size)
        lambd = lambda batch: (batch + 1) / n_batches_warmup
    lr_scheduler_warmup = torch.optim.lr_scheduler.LambdaLR(optimizer, lambd)
    # Fast-forward the warmup scheduler to the batch count already trained
    # when resuming.
    lr_scheduler_warmup.last_epoch = start_epoch * args.epoch_size // args.batch_size

    # LR schedulers
    # Divide LR by 10 every args.lr_epoch_decay
    lr_scheduler = torch.optim.lr_scheduler.StepLR(
        optimizer,
        step_size=args.lr_epoch_decay,
        gamma=0.1,
    )
    # Align the step scheduler with the resumed epoch; the immediate step()
    # applies any decay already earned before training restarts.
    lr_scheduler.last_epoch = start_epoch - 1
    lr_scheduler.step()

    for epoch in range(start_epoch, end_epoch):
        # Fresh meters each epoch; `AverageValueMeter` accumulates mean/std.
        meters_train = defaultdict(lambda: AverageValueMeter())
        meters_val = defaultdict(lambda: AverageValueMeter())
        meters_time = defaultdict(lambda: AverageValueMeter())

        # `h` computes the loss for one batch and records per-loss meters.
        h = functools.partial(h_pose_custom,
                              model=model,
                              cfg=args,
                              n_iterations=args.n_iterations,
                              mesh_db=mesh_db,
                              input_generator=args.TCO_input_generator)

        def train_epoch():
            # One optimization pass over the (partial) training epoch,
            # recording data/forward/backward timings and gradient norms.
            model.train()
            iterator = tqdm(ds_iter_train, ncols=80)
            t = time.time()
            for n, sample in enumerate(iterator):
                if n > 0:
                    meters_time['data'].add(time.time() - t)

                optimizer.zero_grad()

                t = time.time()
                loss = h(data=sample, meters=meters_train)
                meters_time['forward'].add(time.time() - t)
                iterator.set_postfix(loss=loss.item())
                meters_train['loss_total'].add(loss.item())

                t = time.time()
                loss.backward()
                total_grad_norm = torch.nn.utils.clip_grad_norm_(
                    model.parameters(),
                    max_norm=args.clip_grad_norm,
                    norm_type=2)
                meters_train['grad_norm'].add(
                    torch.as_tensor(total_grad_norm).item())

                optimizer.step()
                meters_time['backward'].add(time.time() - t)
                meters_time['memory'].add(torch.cuda.max_memory_allocated() /
                                          1024.**2)

                # During warmup the LR steps per batch; afterwards the
                # StepLR scheduler steps once per epoch (below).
                if epoch < args.n_epochs_warmup:
                    lr_scheduler_warmup.step()
                t = time.time()
            if epoch >= args.n_epochs_warmup:
                lr_scheduler.step()

        @torch.no_grad()
        def validation():
            # Loss-only evaluation on the validation split.
            model.eval()
            for sample in tqdm(ds_iter_val, ncols=80):
                loss = h(data=sample, meters=meters_val)
                meters_val['loss_total'].add(loss.item())

        @torch.no_grad()
        def test():
            # Full evaluation (pose metrics) via the prepared eval bundle.
            model.eval()
            return run_eval(eval_bundle, epoch=epoch)

        train_epoch()
        if epoch % args.val_epoch_interval == 0:
            validation()

        test_dict = None
        if epoch % args.test_epoch_interval == 0:
            test_dict = test()

        # Aggregate per-epoch statistics for logging.
        log_dict = dict()
        log_dict.update({
            'grad_norm':
            meters_train['grad_norm'].mean,
            'grad_norm_std':
            meters_train['grad_norm'].std,
            'learning_rate':
            optimizer.param_groups[0]['lr'],
            'time_forward':
            meters_time['forward'].mean,
            'time_backward':
            meters_time['backward'].mean,
            'time_data':
            meters_time['data'].mean,
            'gpu_memory':
            meters_time['memory'].mean,
            'time':
            time.time(),
            'n_iterations': (epoch + 1) * len(ds_iter_train),
            'n_datas':
            (epoch + 1) * args.global_batch_size * len(ds_iter_train),
        })

        for string, meters in zip(('train', 'val'),
                                  (meters_train, meters_val)):
            for k in dict(meters).keys():
                log_dict[f'{string}_{k}'] = meters[k].mean

        # Average metrics across distributed workers; only rank 0 writes logs.
        log_dict = reduce_dict(log_dict)
        if get_rank() == 0:
            log(config=args,
                model=model,
                epoch=epoch,
                log_dict=log_dict,
                test_dict=test_dict)
        dist.barrier()