def load_models(coarse_run_id, refiner_run_id=None, n_workers=8, object_set='tless'):
    """Load the coarse (and optionally refiner) pose models for an object set.

    Parameters
    ----------
    coarse_run_id : str
        Run id of the coarse model; its directory must exist under EXP_DIR.
    refiner_run_id : str, optional
        Run id of the refiner model. When None, the predictor is coarse-only.
    n_workers : int
        Number of worker processes for the batch renderer.
    object_set : str
        'tless' selects the T-LESS datasets; any other value selects YCB-Video.

    Returns
    -------
    tuple
        (CoarseRefinePosePredictor, MeshDataBase) — the predictor and the
        (un-batched) mesh database it was built from.
    """
    if object_set == 'tless':
        object_ds_name, urdf_ds_name = 'tless.bop', 'tless.cad'
    else:
        object_ds_name, urdf_ds_name = 'ycbv.bop-compat.eval', 'ycbv'

    object_ds = make_object_dataset(object_ds_name)
    mesh_db = MeshDataBase.from_object_ds(object_ds)
    renderer = BulletBatchRenderer(object_set=urdf_ds_name, n_workers=n_workers)
    mesh_db_batched = mesh_db.batched().cuda()

    def load_model(run_id):
        # Restore one model (coarse or refiner) from its run directory.
        if run_id is None:
            return None
        run_dir = EXP_DIR / run_id
        cfg = yaml.load((run_dir / 'config.yaml').read_text(), Loader=yaml.FullLoader)
        cfg = check_update_config(cfg)
        # The saved config decides which architecture the checkpoint belongs to;
        # the checkpoint itself is loaded identically in both cases.
        if cfg.train_refiner:
            model = create_model_refiner(cfg, renderer=renderer, mesh_db=mesh_db_batched)
        else:
            model = create_model_coarse(cfg, renderer=renderer, mesh_db=mesh_db_batched)
        ckpt = torch.load(run_dir / 'checkpoint.pth.tar')
        model.load_state_dict(ckpt['state_dict'])
        model = model.cuda().eval()
        model.cfg = cfg
        return model

    coarse_model = load_model(coarse_run_id)
    refiner_model = load_model(refiner_run_id)
    model = CoarseRefinePosePredictor(coarse_model=coarse_model,
                                      refiner_model=refiner_model)
    return model, mesh_db
def main():
    """Run stages 2 and 3 of CosyPose (multi-view RANSAC + bundle adjustment)
    on a custom scenario directory and write per-subscene results to disk.

    Reads ``candidates.csv``, ``scene_camera.json`` and ``models/`` from
    ``LOCAL_DATA_DIR / 'custom_scenarios' / <scenario>`` and writes
    ``results/subscene=<k>/predicted_scene.json`` and
    ``scene_reprojected.csv`` for every reconstructed view group.
    """
    parser = argparse.ArgumentParser(
        'CosyPose multi-view reconstruction for a custom scenario')
    parser.add_argument(
        '--scenario', default='', type=str,
        help='Id of the scenario, matching directory must be in local_data/scenarios')
    # BUGFIX: this threshold is a fractional score (default 0.3); it was
    # declared type=int, which would crash on e.g. --sv_score_th 0.3 and
    # truncate any integer input to 0 or 1.
    parser.add_argument('--sv_score_th', default=0.3, type=float,
                        help="Score to filter single-view predictions")
    parser.add_argument(
        '--n_symmetries_rot', default=64, type=int,
        help="Number of discretized symmetries to use for continuous symmetries")
    parser.add_argument(
        '--ransac_n_iter', default=2000, type=int,
        help="Max number of RANSAC iterations per pair of views")
    parser.add_argument(
        '--ransac_dist_threshold', default=0.02, type=float,
        help="Threshold (in meters) on symmetric distance to consider a tentative match an inlier")
    parser.add_argument('--ba_n_iter', default=10, type=int,
                        help="Maximum number of LM iterations in stage 3")
    parser.add_argument('--nms_th', default=0.04, type=float,
                        help='Threshold (meter) for NMS 3D')
    parser.add_argument('--no_visualization', action='store_true')
    args = parser.parse_args()

    # Raise verbosity of all cosypose loggers.
    # NOTE(review): the loop variable shadows the module-level `logger`; the
    # subsequent logger.info calls therefore go through the last logger of the
    # dict. Preserved as-is to avoid changing logging behavior — confirm intent.
    loggers = [
        logging.getLogger(name) for name in logging.root.manager.loggerDict
    ]
    for logger in loggers:
        if 'cosypose' in logger.name:
            logger.setLevel(logging.DEBUG)

    logger.info(f"{'-'*80}")
    for k, v in args.__dict__.items():
        logger.info(f"{k}: {v}")
    logger.info(f"{'-'*80}")

    scenario_dir = LOCAL_DATA_DIR / 'custom_scenarios' / args.scenario

    # Single-view 6D pose candidates (stage 1 output, provided by the user).
    candidates = read_csv_candidates(scenario_dir / 'candidates.csv').float().cuda()
    candidates.infos['group_id'] = 0
    scene_ids = np.unique(candidates.infos['scene_id'])
    assert len(scene_ids) == 1, \
        'Please only provide 6D pose estimations that correspond to the same scene.'
    scene_id = scene_ids.item()
    view_ids = np.unique(candidates.infos['view_id'])
    n_views = len(view_ids)
    logger.info(f'Loaded {len(candidates)} candidates in {n_views} views.')

    cameras = read_cameras(scenario_dir / 'scene_camera.json', view_ids).float().cuda()
    cameras.infos['scene_id'] = scene_id
    cameras.infos['batch_im_id'] = np.arange(len(view_ids))
    logger.info(f'Loaded cameras intrinsics.')

    object_ds = BOPObjectDataset(scenario_dir / 'models')
    mesh_db = MeshDataBase.from_object_ds(object_ds)
    logger.info(f'Loaded {len(object_ds)} 3D object models.')

    logger.info('Running stage 2 and 3 of CosyPose...')
    mv_predictor = MultiviewScenePredictor(mesh_db)
    predictions = mv_predictor.predict_scene_state(
        candidates, cameras,
        score_th=args.sv_score_th,
        use_known_camera_poses=False,
        ransac_n_iter=args.ransac_n_iter,
        ransac_dist_threshold=args.ransac_dist_threshold,
        ba_n_iter=args.ba_n_iter)

    objects = predictions['scene/objects']
    cameras = predictions['scene/cameras']
    reproj = predictions['ba_output']

    # One result directory per reconstructed subscene (view group).
    for view_group in np.unique(objects.infos['view_group']):
        objects_ = objects[np.where(objects.infos['view_group'] == view_group)[0]]
        cameras_ = cameras[np.where(cameras.infos['view_group'] == view_group)[0]]
        reproj_ = reproj[np.where(reproj.infos['view_group'] == view_group)[0]]
        objects_ = nms3d(objects_, th=args.nms_th, poses_attr='TWO')

        view_group_dir = scenario_dir / 'results' / f'subscene={view_group}'
        view_group_dir.mkdir(exist_ok=True, parents=True)
        logger.info(
            f'Subscene {view_group} has {len(objects_)} objects and {len(cameras_)} cameras.')

        predicted_scene_path = view_group_dir / 'predicted_scene.json'
        scene_reprojected_path = view_group_dir / 'scene_reprojected.csv'
        save_scene_json(objects_, cameras_, predicted_scene_path)
        tc_to_csv(reproj_, scene_reprojected_path)
        logger.info(
            f'Wrote predicted scene (objects+cameras): {predicted_scene_path}')
        logger.info(
            f'Wrote predicted objects with pose expressed in camera frame: {scene_reprojected_path}')
def get_pose_meters(scene_ds):
    """Build the dictionary of PoseErrorMeter instances for evaluating scene_ds.

    The dataset name selects the matching/filtering configuration (BOP19
    targets vs. per-class matching) and which error types are computed
    (ADD-S always; ADD(-S) additionally on YCB-Video).
    """
    ds_name = scene_ds.name
    is_tless = 'tless' in ds_name
    is_ycbv = 'ycbv' in ds_name

    # Diameter ratio used for the loose "OVERLAP" matching threshold.
    overlap_match_ratio = 0.5

    # Per-dataset evaluation configuration.
    if ds_name == 'tless.primesense.test.bop19':
        cfg = dict(targets_filename='test_targets_bop19.json',
                   visib_gt_min=-1,
                   n_top=-1,  # Given by targets
                   compute_add=False,
                   spheres_overlap_check=True)
    elif ds_name == 'tless.primesense.test':
        cfg = dict(targets_filename='all_target_tless.json',
                   visib_gt_min=0.1,
                   n_top=1,
                   compute_add=False,
                   spheres_overlap_check=True)
    elif is_ycbv:
        cfg = dict(targets_filename=None,
                   visib_gt_min=-1,
                   n_top=1,
                   compute_add=True,
                   spheres_overlap_check=False)
    else:
        raise ValueError

    if is_tless:
        object_ds_name = 'tless.eval'
    elif is_ycbv:
        # This is important for definition of symmetric objects
        object_ds_name = 'ycbv.bop-compat.eval'
    else:
        raise ValueError

    if cfg['targets_filename'] is not None:
        targets = pd.read_json(scene_ds.ds_dir / cfg['targets_filename'])
        targets = remap_bop_targets(targets)
    else:
        targets = None

    mesh_db = MeshDataBase.from_object_ds(make_object_dataset(object_ds_name))

    base_kwargs = dict(
        mesh_db=mesh_db,
        exact_meshes=True,
        sample_n_points=None,
        errors_bsz=1,
        # BOP-Like parameters
        n_top=cfg['n_top'],
        visib_gt_min=cfg['visib_gt_min'],
        targets=targets,
        spheres_overlap_check=cfg['spheres_overlap_check'],
    )

    error_types = ['ADD-S']
    if cfg['compute_add']:
        error_types.append('ADD(-S)')

    meters = dict()
    for err in error_types:
        # For measuring ADD-S AUC on T-LESS and average errors on ycbv/tless.
        meters[f'{err}_ntop=BOP_matching=OVERLAP'] = PoseErrorMeter(
            error_type=err,
            consider_all_predictions=False,
            match_threshold=overlap_match_ratio,
            report_error_stats=True,
            report_error_AUC=True,
            **base_kwargs)

        if is_ycbv:
            # For fair comparison with PoseCNN/DeepIM on YCB-Video ADD(-S) AUC
            meters[f'{err}_ntop=1_matching=CLASS'] = PoseErrorMeter(
                error_type=err,
                consider_all_predictions=False,
                match_threshold=np.inf,
                report_error_stats=False,
                report_error_AUC=True,
                **base_kwargs)

        if is_tless:
            # For ADD-S<0.1d
            meters[f'{err}_ntop=BOP_matching=BOP'] = PoseErrorMeter(
                error_type=err, match_threshold=0.1, **base_kwargs)
            # For mAP
            meters[f'{err}_ntop=ALL_matching=BOP'] = PoseErrorMeter(
                error_type=err, match_threshold=0.1,
                consider_all_predictions=True, report_AP=True,
                **base_kwargs)
    return meters
def train_pose(args):
    """Train a coarse or refiner pose model (distributed, one process per GPU).

    ``args`` is a Namespace-like config; which of the two models is trained is
    derived from ``args.TCO_input_generator`` ('gt+noise' selects the refiner).
    Checkpointing/resume is handled through ``EXP_DIR / run_id``.
    """
    torch.set_num_threads(1)

    if args.resume_run_id:
        resume_dir = EXP_DIR / args.resume_run_id
        # BUGFIX: yaml.load() without a Loader is a hard error in PyYAML >= 6.
        # vars(resume_args) below requires a deserialized python object (an
        # argparse-style Namespace), which only the unsafe loader can build.
        # The file is trusted local training output — do not load untrusted
        # YAML this way.
        resume_args = yaml.load((resume_dir / 'config.yaml').read_text(),
                                Loader=yaml.UnsafeLoader)
        # Take every field from the resumed config EXCEPT these, which keep
        # their current (command-line) values.
        keep_fields = set([
            'resume_run_id',
            'epoch_size',
        ])
        vars(args).update({
            k: v
            for k, v in vars(resume_args).items() if k not in keep_fields
        })

    # 'gt+noise' input generation <=> refiner training; otherwise coarse.
    args.train_refiner = args.TCO_input_generator == 'gt+noise'
    args.train_coarse = not args.train_refiner
    args.save_dir = EXP_DIR / args.run_id

    logger.info(f"{'-'*80}")
    for k, v in args.__dict__.items():
        logger.info(f"{k}: {v}")
    logger.info(f"{'-'*80}")

    # Initialize distributed
    device = torch.cuda.current_device()
    init_distributed_mode()
    world_size = get_world_size()
    args.n_gpus = world_size
    args.global_batch_size = world_size * args.batch_size
    logger.info(f'Connection established with {world_size} gpus.')

    # Make train/val datasets
    def make_datasets(dataset_names):
        # dataset_names: iterable of (name, n_repeat) pairs; each dataset is
        # appended n_repeat times to weight it in the concatenation.
        datasets = []
        for (ds_name, n_repeat) in dataset_names:
            assert 'test' not in ds_name
            ds = make_scene_dataset(ds_name)
            logger.info(f'Loaded {ds_name} with {len(ds)} images.')
            for _ in range(n_repeat):
                datasets.append(ds)
        return ConcatDataset(datasets)

    # tracking dataset
    scene_ds_train = make_datasets(args.train_ds_names)
    scene_ds_val = make_datasets(args.val_ds_names)

    ds_kwargs = dict(
        resize=args.input_resize,
        rgb_augmentation=args.rgb_augmentation,
        background_augmentation=args.background_augmentation,
        min_area=args.min_area,
        gray_augmentation=args.gray_augmentation,
    )
    ds_train = PoseTrackingDataset(scene_ds_train, **ds_kwargs)
    ds_val = PoseTrackingDataset(scene_ds_val, **ds_kwargs)

    train_sampler = PartialSampler(ds_train, epoch_size=args.epoch_size)
    ds_iter_train = DataLoader(ds_train,
                               sampler=train_sampler,
                               batch_size=args.batch_size,
                               num_workers=args.n_dataloader_workers,
                               collate_fn=ds_train.collate_fn,
                               drop_last=False,
                               pin_memory=True)
    ds_iter_train = MultiEpochDataLoader(ds_iter_train)

    # Validation runs over 10% of an epoch worth of samples.
    val_sampler = PartialSampler(ds_val, epoch_size=int(0.1 * args.epoch_size))
    ds_iter_val = DataLoader(ds_val,
                             sampler=val_sampler,
                             batch_size=args.batch_size,
                             num_workers=args.n_dataloader_workers,
                             collate_fn=ds_val.collate_fn,
                             drop_last=False,
                             pin_memory=True)
    ds_iter_val = MultiEpochDataLoader(ds_iter_val)

    # Make model
    # renderer = BulletBatchRenderer(object_set=args.urdf_ds_name, n_workers=args.n_rendering_workers)
    object_ds = make_object_dataset(args.object_ds_name)
    mesh_db = MeshDataBase.from_object_ds(object_ds).batched(
        n_sym=args.n_symmetries_batch).cuda().float()
    model = create_model_pose_custom(cfg=args, mesh_db=mesh_db).cuda()

    eval_bundle = make_eval_bundle(args, model)

    if args.resume_run_id:
        resume_dir = EXP_DIR / args.resume_run_id
        path = resume_dir / 'checkpoint.pth.tar'
        logger.info(f'Loading checkpoint from {path}')
        save = torch.load(path)
        state_dict = save['state_dict']
        model.load_state_dict(state_dict)
        start_epoch = save['epoch'] + 1
    else:
        start_epoch = 0
    end_epoch = args.n_epochs

    if args.run_id_pretrain is not None:
        pretrain_path = EXP_DIR / args.run_id_pretrain / 'checkpoint.pth.tar'
        logger.info(f'Using pretrained model from {pretrain_path}.')
        model.load_state_dict(torch.load(pretrain_path)['state_dict'])

    # Synchronize models across processes.
    model = sync_model(model)
    model = torch.nn.parallel.DistributedDataParallel(model,
                                                      device_ids=[device],
                                                      output_device=device)

    # Optimizer
    optimizer = torch.optim.Adam(model.parameters(),
                                 lr=args.lr,
                                 weight_decay=args.weight_decay)

    # Warmup: ramp the LR linearly over the first n_epochs_warmup epochs
    # (stepped per-batch inside train_epoch).
    if args.n_epochs_warmup == 0:
        lambd = lambda epoch: 1
    else:
        n_batches_warmup = args.n_epochs_warmup * (args.epoch_size // args.batch_size)
        lambd = lambda batch: (batch + 1) / n_batches_warmup
    lr_scheduler_warmup = torch.optim.lr_scheduler.LambdaLR(optimizer, lambd)
    lr_scheduler_warmup.last_epoch = start_epoch * args.epoch_size // args.batch_size

    # LR schedulers
    # Divide LR by 10 every args.lr_epoch_decay
    lr_scheduler = torch.optim.lr_scheduler.StepLR(
        optimizer,
        step_size=args.lr_epoch_decay,
        gamma=0.1,
    )
    lr_scheduler.last_epoch = start_epoch - 1
    lr_scheduler.step()

    for epoch in range(start_epoch, end_epoch):
        meters_train = defaultdict(lambda: AverageValueMeter())
        meters_val = defaultdict(lambda: AverageValueMeter())
        meters_time = defaultdict(lambda: AverageValueMeter())

        h = functools.partial(h_pose_custom,
                              model=model,
                              cfg=args,
                              n_iterations=args.n_iterations,
                              mesh_db=mesh_db,
                              input_generator=args.TCO_input_generator)

        def train_epoch():
            model.train()
            iterator = tqdm(ds_iter_train, ncols=80)
            t = time.time()
            for n, sample in enumerate(iterator):
                if n > 0:
                    meters_time['data'].add(time.time() - t)

                optimizer.zero_grad()

                t = time.time()
                loss = h(data=sample, meters=meters_train)
                meters_time['forward'].add(time.time() - t)
                iterator.set_postfix(loss=loss.item())
                meters_train['loss_total'].add(loss.item())

                t = time.time()
                loss.backward()
                total_grad_norm = torch.nn.utils.clip_grad_norm_(
                    model.parameters(), max_norm=args.clip_grad_norm, norm_type=2)
                meters_train['grad_norm'].add(
                    torch.as_tensor(total_grad_norm).item())

                optimizer.step()
                meters_time['backward'].add(time.time() - t)
                meters_time['memory'].add(torch.cuda.max_memory_allocated() / 1024.**2)

                # Warmup scheduler steps per batch; the decay scheduler steps
                # once per epoch after warmup (below).
                if epoch < args.n_epochs_warmup:
                    lr_scheduler_warmup.step()
                t = time.time()
            if epoch >= args.n_epochs_warmup:
                lr_scheduler.step()

        @torch.no_grad()
        def validation():
            model.eval()
            for sample in tqdm(ds_iter_val, ncols=80):
                loss = h(data=sample, meters=meters_val)
                meters_val['loss_total'].add(loss.item())

        @torch.no_grad()
        def test():
            model.eval()
            return run_eval(eval_bundle, epoch=epoch)

        train_epoch()

        if epoch % args.val_epoch_interval == 0:
            validation()

        test_dict = None
        if epoch % args.test_epoch_interval == 0:
            test_dict = test()

        log_dict = dict()
        log_dict.update({
            'grad_norm': meters_train['grad_norm'].mean,
            'grad_norm_std': meters_train['grad_norm'].std,
            'learning_rate': optimizer.param_groups[0]['lr'],
            'time_forward': meters_time['forward'].mean,
            'time_backward': meters_time['backward'].mean,
            'time_data': meters_time['data'].mean,
            'gpu_memory': meters_time['memory'].mean,
            'time': time.time(),
            'n_iterations': (epoch + 1) * len(ds_iter_train),
            'n_datas': (epoch + 1) * args.global_batch_size * len(ds_iter_train),
        })

        for string, meters in zip(('train', 'val'), (meters_train, meters_val)):
            for k in dict(meters).keys():
                log_dict[f'{string}_{k}'] = meters[k].mean

        # Aggregate metrics across ranks; only rank 0 writes the log.
        log_dict = reduce_dict(log_dict)
        if get_rank() == 0:
            log(config=args,
                model=model,
                epoch=epoch,
                log_dict=log_dict,
                test_dict=test_dict)
        dist.barrier()