def main():
    """Run BOP-challenge pose inference over all supported datasets.

    Parses CLI options, builds a run configuration, then launches
    ``run_inference`` once per dataset with the detector/coarse/refiner
    run-ids matching ``--config`` ('bop-pbr' or 'bop-synt+real').
    Datasets with a missing model are skipped with a log message.
    """
    # Turn on DEBUG logging for every cosypose logger already registered.
    loggers = [logging.getLogger(name) for name in logging.root.manager.loggerDict]
    for logger in loggers:
        if 'cosypose' in logger.name:
            logger.setLevel(logging.DEBUG)

    parser = argparse.ArgumentParser('Evaluation')
    parser.add_argument('--debug', action='store_true')
    parser.add_argument('--comment', default='', type=str)
    parser.add_argument('--id', default=-1, type=int)
    parser.add_argument('--config', default='bop-pbr', type=str)
    parser.add_argument('--nviews', dest='n_views', default=1, type=int)
    parser.add_argument('--icp', action='store_true')
    args = parser.parse_args()

    init_distributed_mode()

    # Empty Namespace used as a mutable configuration object.
    cfg = argparse.ArgumentParser('').parse_args([])
    cfg.n_workers = 8
    cfg.pred_bsz = 1
    cfg.n_frames = None
    cfg.n_groups = None
    cfg.skip_evaluation = False
    cfg.external_predictions = True

    cfg.n_coarse_iterations = 1
    cfg.n_refiner_iterations = 4
    cfg.icp = args.icp
    cfg.debug = args.debug
    cfg.n_views = args.n_views
    if args.debug:
        # Shrink the workload for quick debugging runs.
        if args.n_views > 1:
            cfg.n_groups = 1
        else:
            cfg.n_frames = 4
        # cfg.n_workers = 1

    # Without an explicit --id, pick a random one so runs don't collide.
    if args.id < 0:
        n_rand = np.random.randint(1e6)
        args.id = n_rand

    if args.icp:
        args.comment = f'icp-{args.comment}'

    if args.n_views > 1:
        args.comment = f'nviews={args.n_views}-{args.comment}'

    save_dir = RESULTS_DIR / f'{args.config}-{args.comment}-{args.id}'
    logger.info(f'Save dir: {save_dir}')

    if args.config == 'bop-pbr':
        MODELS_DETECTORS = PBR_DETECTORS
        MODELS_COARSE = PBR_COARSE
        MODELS_REFINER = PBR_REFINER
    elif args.config == 'bop-synt+real':
        MODELS_DETECTORS = SYNT_REAL_DETECTORS
        MODELS_COARSE = SYNT_REAL_COARSE
        MODELS_REFINER = SYNT_REAL_REFINER
    else:
        # Fail fast on an unknown config instead of a later NameError
        # on MODELS_DETECTORS.
        raise ValueError(args.config)

    if args.n_views > 1:
        ds_names = ['hb', 'tless', 'ycbv']
    else:
        ds_names = ['hb', 'icbin', 'itodd', 'lmo', 'tless', 'tudl', 'ycbv']

    for ds_name in ds_names:
        this_cfg = deepcopy(cfg)
        this_cfg.ds_name = BOP_CONFIG[ds_name]['inference_ds_name'][0]
        this_cfg.save_dir = save_dir / f'dataset={ds_name}'
        this_cfg.detector_run_id = MODELS_DETECTORS.get(ds_name)
        this_cfg.coarse_run_id = MODELS_COARSE.get(ds_name)
        this_cfg.refiner_run_id = MODELS_REFINER.get(ds_name)
        # Skip datasets for which any of the three models is missing.
        if this_cfg.detector_run_id is None \
                or this_cfg.coarse_run_id is None \
                or this_cfg.refiner_run_id is None:
            logger.info(f'Skipped {ds_name}')
            continue
        run_inference(this_cfg)
def main():
    """Run the BOP detector evaluation over the configured datasets.

    Parses CLI options, builds a run configuration, then calls
    ``run_detection_eval`` once per dataset.  Datasets without a known
    detector run-id are evaluated on external predictions only.
    """
    # Turn on DEBUG logging for every cosypose logger already registered.
    loggers = [logging.getLogger(name) for name in logging.root.manager.loggerDict]
    for logger in loggers:
        if 'cosypose' in logger.name:
            logger.setLevel(logging.DEBUG)

    parser = argparse.ArgumentParser('Evaluation')
    parser.add_argument('--debug', action='store_true')
    parser.add_argument('--skip_predictions', action='store_true')
    parser.add_argument('--comment', default='', type=str)
    parser.add_argument('--id', default=-1, type=int)
    parser.add_argument('--config', default='', type=str)
    parser.add_argument('--models', default='', type=str)
    args = parser.parse_args()

    init_distributed_mode()

    # Empty Namespace used as a mutable configuration object.
    cfg = argparse.ArgumentParser('').parse_args([])
    cfg.n_workers = 8
    cfg.pred_bsz = 8
    cfg.eval_bsz = 8
    cfg.n_frames = None
    cfg.skip_evaluation = False
    cfg.skip_model_predictions = args.skip_predictions
    cfg.external_predictions = True
    cfg.detector = None
    if args.debug:
        # Only evaluate a handful of frames when debugging.
        cfg.n_frames = 10

    if args.config == 'bop':
        # ds_names = ['ycbv.bop19', 'tless.bop19']
        ds_names = ['itodd.val', 'hb.val']
    else:
        # Include the offending value in the error for easier diagnosis
        # (matches the other scripts' raise style).
        raise ValueError(args.config)

    # Mapping dataset -> trained detector run-id.
    detector_run_ids = {
        'ycbv.bop19': 'ycbv--377940',
        'hb.val': 'detector-bop-hb--497808',
        'itodd.val': 'detector-bop-itodd--509908',
    }

    # Without an explicit --id, pick a random one so runs don't collide.
    if args.id < 0:
        n_rand = np.random.randint(1e6)
        args.id = n_rand
    save_dir = RESULTS_DIR / f'{args.config}-{args.models}-{args.comment}-{args.id}'
    logger.info(f'Save dir: {save_dir}')

    for ds_name in ds_names:
        this_cfg = deepcopy(cfg)
        this_cfg.ds_name = ds_name
        this_cfg.save_dir = save_dir / f'dataset={ds_name}'
        logger.info(f'DATASET: {ds_name}')
        if ds_name in detector_run_ids:
            this_cfg.detector_run_id = detector_run_ids[ds_name]
        else:
            # No trained detector for this dataset: evaluate external
            # predictions only.
            this_cfg.skip_model_predictions = True
            logger.info(f'No model provided for dataset: {ds_name}.')
        run_detection_eval(this_cfg)
        logger.info('')
def main():
    """Run single- and multi-view pose prediction + evaluation on T-LESS/YCB-V.

    Pipeline: load the scene dataset, run the coarse/refiner predictors
    (seeded by pix2pose or PoseCNN detections), optionally run multiview
    bundle adjustment, evaluate the selected prediction keys, and (on
    rank 0) write summaries and results to ``RESULTS_DIR``.
    """
    # Turn on DEBUG logging for every cosypose logger already registered.
    loggers = [logging.getLogger(name) for name in logging.root.manager.loggerDict]
    for logger in loggers:
        if 'cosypose' in logger.name:
            logger.setLevel(logging.DEBUG)

    logger.info("Starting ...")
    init_distributed_mode()

    parser = argparse.ArgumentParser('Evaluation')
    parser.add_argument('--config', default='tless-bop', type=str)
    parser.add_argument('--debug', action='store_true')
    parser.add_argument('--job_dir', default='', type=str)
    parser.add_argument('--comment', default='', type=str)
    parser.add_argument('--nviews', dest='n_views', default=1, type=int)
    args = parser.parse_args()

    coarse_run_id = None
    refiner_run_id = None
    n_workers = 8
    n_plotters = 8
    n_views = 1

    n_frames = None
    scene_id = None
    group_id = None
    n_groups = None
    n_views = args.n_views
    # Multiview (bundle-adjustment) stage only runs with >= 2 views.
    skip_mv = args.n_views < 2
    skip_predictions = False

    # Select models per object set.
    object_set = 'tless'
    if 'tless' in args.config:
        object_set = 'tless'
        coarse_run_id = 'tless-coarse--10219'
        refiner_run_id = 'tless-refiner--585928'
        n_coarse_iterations = 1
        n_refiner_iterations = 4
    elif 'ycbv' in args.config:
        object_set = 'ycbv'
        # YCB-V: the finetuned refiner is used alone (no coarse stage).
        refiner_run_id = 'ycbv-refiner-finetune--251020'
        n_coarse_iterations = 0
        n_refiner_iterations = 2
    else:
        raise ValueError(args.config)

    # Select the dataset split.
    if args.config == 'tless-siso':
        ds_name = 'tless.primesense.test'
        assert n_views == 1
    elif args.config == 'tless-vivo':
        ds_name = 'tless.primesense.test.bop19'
    elif args.config == 'ycbv':
        ds_name = 'ycbv.test.keyframes'
    else:
        raise ValueError(args.config)

    if args.debug:
        # Restrict the workload for quick debugging runs.
        if 'tless' in args.config:
            scene_id = None
            group_id = 64
            n_groups = 2
        else:
            scene_id = 48
            n_groups = 2
        n_frames = None
        n_workers = 0
        n_plotters = 0

    n_rand = np.random.randint(1e10)
    save_dir = RESULTS_DIR / f'{args.config}-n_views={n_views}-{args.comment}-{n_rand}'
    logger.info(f"SAVE DIR: {save_dir}")
    logger.info(f"Coarse: {coarse_run_id}")
    logger.info(f"Refiner: {refiner_run_id}")

    # Load dataset
    scene_ds = make_scene_dataset(ds_name)

    if scene_id is not None:
        mask = scene_ds.frame_index['scene_id'] == scene_id
        scene_ds.frame_index = scene_ds.frame_index[mask].reset_index(drop=True)
    if n_frames is not None:
        # BUGFIX: previously re-applied `mask` here, which is undefined
        # when scene_id is None and length-mismatched after the
        # reset_index above.  The frame_index is already filtered;
        # simply truncate it.
        scene_ds.frame_index = scene_ds.frame_index[:n_frames]

    # Predictions
    predictor, mesh_db = load_models(coarse_run_id, refiner_run_id,
                                     n_workers=n_plotters, object_set=object_set)
    mv_predictor = MultiviewScenePredictor(mesh_db)

    base_pred_kwargs = dict(
        n_coarse_iterations=n_coarse_iterations,
        n_refiner_iterations=n_refiner_iterations,
        skip_mv=skip_mv,
        pose_predictor=predictor,
        mv_predictor=mv_predictor,
    )

    # Detections used to seed pose prediction depend on the dataset.
    if skip_predictions:
        pred_kwargs = {}
    elif 'tless' in ds_name:
        pix2pose_detections = load_pix2pose_results(all_detections='bop19' in ds_name).cpu()
        pred_kwargs = {
            'pix2pose_detections': dict(
                detections=pix2pose_detections,
                **base_pred_kwargs
            ),
        }
    elif 'ycbv' in ds_name:
        posecnn_detections = load_posecnn_results()
        pred_kwargs = {
            'posecnn_init': dict(
                detections=posecnn_detections,
                use_detections_TCO=posecnn_detections,
                **base_pred_kwargs
            ),
        }
    else:
        raise ValueError(ds_name)

    # Group frames into n_views-sized multiview groups.
    scene_ds_pred = MultiViewWrapper(scene_ds, n_views=n_views)

    if group_id is not None:
        mask = scene_ds_pred.frame_index['group_id'] == group_id
        scene_ds_pred.frame_index = scene_ds_pred.frame_index[mask].reset_index(drop=True)
    elif n_groups is not None:
        scene_ds_pred.frame_index = scene_ds_pred.frame_index[:n_groups]

    pred_runner = MultiviewPredictionRunner(
        scene_ds_pred, batch_size=1, n_workers=n_workers,
        cache_data=len(pred_kwargs) > 1)

    all_predictions = dict()
    for pred_prefix, pred_kwargs_n in pred_kwargs.items():
        logger.info(f"Prediction: {pred_prefix}")
        preds = pred_runner.get_predictions(**pred_kwargs_n)
        for preds_name, preds_n in preds.items():
            all_predictions[f'{pred_prefix}/{preds_name}'] = preds_n

    logger.info("Done with predictions")
    torch.distributed.barrier()

    # Evaluation
    predictions_to_evaluate = set()
    if 'ycbv' in ds_name:
        det_key = 'posecnn_init'
        # Also evaluate the raw PoseCNN detections as a baseline.
        all_predictions['posecnn'] = posecnn_detections
        predictions_to_evaluate.add('posecnn')
    elif 'tless' in ds_name:
        det_key = 'pix2pose_detections'
    else:
        raise ValueError(ds_name)
    predictions_to_evaluate.add(f'{det_key}/refiner/iteration={n_refiner_iterations}')

    if args.n_views > 1:
        for k in [
                # f'ba_input',
                # f'ba_output',
                f'ba_output+all_cand'
        ]:
            predictions_to_evaluate.add(f'{det_key}/{k}')

    # Deterministic key order for gathering/reporting.
    all_predictions = OrderedDict({k: v for k, v in sorted(all_predictions.items(),
                                                           key=lambda item: item[0])})

    # Evaluation.
    meters = get_pose_meters(scene_ds)
    mv_group_ids = list(iter(pred_runner.sampler))
    scene_ds_ids = np.concatenate(
        scene_ds_pred.frame_index.loc[mv_group_ids, 'scene_ds_ids'].values)
    sampler = ListSampler(scene_ds_ids)
    eval_runner = PoseEvaluation(scene_ds, meters, n_workers=n_workers,
                                 cache_data=True, batch_size=1, sampler=sampler)

    eval_metrics, eval_dfs = dict(), dict()
    for preds_k, preds in all_predictions.items():
        if preds_k in predictions_to_evaluate:
            logger.info(f"Evaluation : {preds_k} (N={len(preds)})")
            if len(preds) == 0:
                preds = eval_runner.make_empty_predictions()
            eval_metrics[preds_k], eval_dfs[preds_k] = eval_runner.evaluate(preds)
            preds.cpu()
        else:
            logger.info(f"Skipped: {preds_k} (N={len(preds)})")

    all_predictions = gather_predictions(all_predictions)

    # Human-readable names for the metric keys worth printing.
    metrics_to_print = dict()
    if 'ycbv' in ds_name:
        metrics_to_print.update({
            f'posecnn/ADD(-S)_ntop=1_matching=CLASS/AUC/objects/mean': f'PoseCNN/AUC of ADD(-S)',
            f'{det_key}/refiner/iteration={n_refiner_iterations}/ADD(-S)_ntop=1_matching=CLASS/AUC/objects/mean': f'Singleview/AUC of ADD(-S)',
            f'{det_key}/refiner/iteration={n_refiner_iterations}/ADD-S_ntop=1_matching=CLASS/AUC/objects/mean': f'Singleview/AUC of ADD-S',
            f'{det_key}/ba_output+all_cand/ADD(-S)_ntop=1_matching=CLASS/AUC/objects/mean': f'Multiview (n={args.n_views})/AUC of ADD(-S)',
            f'{det_key}/ba_output+all_cand/ADD-S_ntop=1_matching=CLASS/AUC/objects/mean': f'Multiview (n={args.n_views})/AUC of ADD-S',
        })
    elif 'tless' in ds_name:
        metrics_to_print.update({
            f'{det_key}/refiner/iteration={n_refiner_iterations}/ADD-S_ntop=BOP_matching=OVERLAP/AUC/objects/mean': f'Singleview/AUC of ADD-S',
            # f'{det_key}/refiner/iteration={n_refiner_iterations}/ADD-S_ntop=BOP_matching=BOP/0.1d': f'Singleview/ADD-S<0.1d',
            f'{det_key}/refiner/iteration={n_refiner_iterations}/ADD-S_ntop=ALL_matching=BOP/mAP': f'Singleview/mAP@ADD-S<0.1d',
            f'{det_key}/ba_output+all_cand/ADD-S_ntop=BOP_matching=OVERLAP/AUC/objects/mean': f'Multiview (n={args.n_views})/AUC of ADD-S',
            # f'{det_key}/ba_output+all_cand/ADD-S_ntop=BOP_matching=BOP/0.1d': f'Multiview (n={args.n_views})/ADD-S<0.1d',
            f'{det_key}/ba_output+all_cand/ADD-S_ntop=ALL_matching=BOP/mAP': f'Multiview (n={args.n_views}/mAP@ADD-S<0.1d)',
        })
    else:
        raise ValueError

    metrics_to_print.update({
        f'{det_key}/ba_input/ADD-S_ntop=BOP_matching=OVERLAP/norm': f'Multiview before BA/ADD-S (m)',
        f'{det_key}/ba_output/ADD-S_ntop=BOP_matching=OVERLAP/norm': f'Multiview after BA/ADD-S (m)',
    })

    # Only rank 0 writes results to disk.
    if get_rank() == 0:
        save_dir.mkdir()
        results = format_results(all_predictions, eval_metrics, eval_dfs,
                                 print_metrics=False)
        (save_dir / 'full_summary.txt').write_text(results.get('summary_txt', ''))

        full_summary = results['summary']
        summary_txt = 'Results:'
        for k, v in metrics_to_print.items():
            if k in full_summary:
                summary_txt += f"\n{v}: {full_summary[k]}"
        logger.info(f"{'-'*80}")
        logger.info(summary_txt)
        logger.info(f"{'-'*80}")

        torch.save(results, save_dir / 'results.pth.tar')
        (save_dir / 'summary.txt').write_text(summary_txt)
        logger.info(f"Saved: {save_dir}")
def train_pose(args):
    """Train a cosypose coarse or refiner pose model (distributed).

    Args:
        args: Namespace of training hyperparameters and run identifiers.
            Mutated in place (save_dir, n_gpus, global_batch_size, ...).
            If ``args.resume_run_id`` is set, most fields are overwritten
            from the saved run's ``config.yaml``.
    """
    torch.set_num_threads(1)

    if args.resume_run_id:
        resume_dir = EXP_DIR / args.resume_run_id
        # NOTE(review): yaml.load without an explicit Loader — only safe
        # because config.yaml is produced by this codebase; consider an
        # explicit Loader to silence the PyYAML warning.
        resume_args = yaml.load((resume_dir / 'config.yaml').read_text())
        # Fields of the CURRENT invocation that must NOT be overwritten
        # by the resumed run's config.
        keep_fields = set([
            'resume_run_id',
            'epoch_size',
        ])
        vars(args).update({
            k: v for k, v in vars(resume_args).items() if k not in keep_fields
        })

    # 'gt+noise' input generation <=> training the refiner model.
    args.train_refiner = args.TCO_input_generator == 'gt+noise'
    args.train_coarse = not args.train_refiner
    args.save_dir = EXP_DIR / args.run_id

    logger.info(f"{'-'*80}")
    for k, v in args.__dict__.items():
        logger.info(f"{k}: {v}")
    logger.info(f"{'-'*80}")

    # Initialize distributed
    device = torch.cuda.current_device()
    init_distributed_mode()
    world_size = get_world_size()
    args.n_gpus = world_size
    args.global_batch_size = world_size * args.batch_size
    logger.info(f'Connection established with {world_size} gpus.')

    # Make train/val datasets
    def make_datasets(dataset_names):
        # dataset_names: iterable of (ds_name, n_repeat); each dataset is
        # appended n_repeat times to weight its sampling frequency.
        datasets = []
        for (ds_name, n_repeat) in dataset_names:
            assert 'test' not in ds_name
            ds = make_scene_dataset(ds_name)
            logger.info(f'Loaded {ds_name} with {len(ds)} images.')
            for _ in range(n_repeat):
                datasets.append(ds)
        return ConcatDataset(datasets)

    # tracking dataset
    scene_ds_train = make_datasets(args.train_ds_names)
    scene_ds_val = make_datasets(args.val_ds_names)

    ds_kwargs = dict(
        resize=args.input_resize,
        rgb_augmentation=args.rgb_augmentation,
        background_augmentation=args.background_augmentation,
        min_area=args.min_area,
        gray_augmentation=args.gray_augmentation,
    )
    ds_train = PoseTrackingDataset(scene_ds_train, **ds_kwargs)
    ds_val = PoseTrackingDataset(scene_ds_val, **ds_kwargs)

    train_sampler = PartialSampler(ds_train, epoch_size=args.epoch_size)
    ds_iter_train = DataLoader(ds_train, sampler=train_sampler,
                               batch_size=args.batch_size,
                               num_workers=args.n_dataloader_workers,
                               collate_fn=ds_train.collate_fn,
                               drop_last=False, pin_memory=True)
    ds_iter_train = MultiEpochDataLoader(ds_iter_train)

    # Validation uses 10% of the training epoch size.
    val_sampler = PartialSampler(ds_val, epoch_size=int(0.1 * args.epoch_size))
    ds_iter_val = DataLoader(ds_val, sampler=val_sampler,
                             batch_size=args.batch_size,
                             num_workers=args.n_dataloader_workers,
                             collate_fn=ds_val.collate_fn,
                             drop_last=False, pin_memory=True)
    ds_iter_val = MultiEpochDataLoader(ds_iter_val)

    # Make model
    # renderer = BulletBatchRenderer(object_set=args.urdf_ds_name, n_workers=args.n_rendering_workers)
    object_ds = make_object_dataset(args.object_ds_name)
    mesh_db = MeshDataBase.from_object_ds(object_ds).batched(
        n_sym=args.n_symmetries_batch).cuda().float()

    model = create_model_pose_custom(cfg=args, mesh_db=mesh_db).cuda()

    eval_bundle = make_eval_bundle(args, model)

    if args.resume_run_id:
        resume_dir = EXP_DIR / args.resume_run_id
        path = resume_dir / 'checkpoint.pth.tar'
        # Fixed typo in log message ('checkpoing' -> 'checkpoint').
        logger.info(f'Loading checkpoint from {path}')
        save = torch.load(path)
        state_dict = save['state_dict']
        model.load_state_dict(state_dict)
        start_epoch = save['epoch'] + 1
    else:
        start_epoch = 0
    end_epoch = args.n_epochs

    if args.run_id_pretrain is not None:
        pretrain_path = EXP_DIR / args.run_id_pretrain / 'checkpoint.pth.tar'
        logger.info(f'Using pretrained model from {pretrain_path}.')
        model.load_state_dict(torch.load(pretrain_path)['state_dict'])

    # Synchronize models across processes.
    model = sync_model(model)
    model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[device],
                                                      output_device=device)

    # Optimizer
    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr,
                                 weight_decay=args.weight_decay)

    # Warmup: linearly ramp the LR over the first n_epochs_warmup epochs.
    if args.n_epochs_warmup == 0:
        lambd = lambda epoch: 1
    else:
        n_batches_warmup = args.n_epochs_warmup * (args.epoch_size // args.batch_size)
        lambd = lambda batch: (batch + 1) / n_batches_warmup
    lr_scheduler_warmup = torch.optim.lr_scheduler.LambdaLR(optimizer, lambd)
    lr_scheduler_warmup.last_epoch = start_epoch * args.epoch_size // args.batch_size

    # LR schedulers
    # Divide LR by 10 every args.lr_epoch_decay
    lr_scheduler = torch.optim.lr_scheduler.StepLR(
        optimizer, step_size=args.lr_epoch_decay, gamma=0.1,
    )
    lr_scheduler.last_epoch = start_epoch - 1
    lr_scheduler.step()

    for epoch in range(start_epoch, end_epoch):
        # defaultdict(AverageValueMeter): consistent with train_detector,
        # no lambda wrapper needed.
        meters_train = defaultdict(AverageValueMeter)
        meters_val = defaultdict(AverageValueMeter)
        meters_time = defaultdict(AverageValueMeter)

        # h(data, meters) -> scalar loss tensor.
        h = functools.partial(h_pose_custom, model=model, cfg=args,
                              n_iterations=args.n_iterations, mesh_db=mesh_db,
                              input_generator=args.TCO_input_generator)

        def train_epoch():
            model.train()
            iterator = tqdm(ds_iter_train, ncols=80)
            t = time.time()
            for n, sample in enumerate(iterator):
                if n > 0:
                    meters_time['data'].add(time.time() - t)

                optimizer.zero_grad()

                t = time.time()
                loss = h(data=sample, meters=meters_train)
                meters_time['forward'].add(time.time() - t)
                iterator.set_postfix(loss=loss.item())
                meters_train['loss_total'].add(loss.item())

                t = time.time()
                loss.backward()
                total_grad_norm = torch.nn.utils.clip_grad_norm_(
                    model.parameters(), max_norm=args.clip_grad_norm, norm_type=2)
                meters_train['grad_norm'].add(
                    torch.as_tensor(total_grad_norm).item())

                optimizer.step()
                meters_time['backward'].add(time.time() - t)
                meters_time['memory'].add(torch.cuda.max_memory_allocated() / 1024.**2)

                # Warmup steps per-batch; decay steps per-epoch below.
                if epoch < args.n_epochs_warmup:
                    lr_scheduler_warmup.step()
                t = time.time()
            if epoch >= args.n_epochs_warmup:
                lr_scheduler.step()

        @torch.no_grad()
        def validation():
            model.eval()
            for sample in tqdm(ds_iter_val, ncols=80):
                loss = h(data=sample, meters=meters_val)
                meters_val['loss_total'].add(loss.item())

        @torch.no_grad()
        def test():
            model.eval()
            return run_eval(eval_bundle, epoch=epoch)

        train_epoch()

        if epoch % args.val_epoch_interval == 0:
            validation()

        test_dict = None
        if epoch % args.test_epoch_interval == 0:
            test_dict = test()

        log_dict = dict()
        log_dict.update({
            'grad_norm': meters_train['grad_norm'].mean,
            'grad_norm_std': meters_train['grad_norm'].std,
            'learning_rate': optimizer.param_groups[0]['lr'],
            'time_forward': meters_time['forward'].mean,
            'time_backward': meters_time['backward'].mean,
            'time_data': meters_time['data'].mean,
            'gpu_memory': meters_time['memory'].mean,
            'time': time.time(),
            'n_iterations': (epoch + 1) * len(ds_iter_train),
            'n_datas': (epoch + 1) * args.global_batch_size * len(ds_iter_train),
        })

        for string, meters in zip(('train', 'val'), (meters_train, meters_val)):
            for k in dict(meters).keys():
                log_dict[f'{string}_{k}'] = meters[k].mean
        # Average metrics across distributed workers before logging.
        log_dict = reduce_dict(log_dict)

        if get_rank() == 0:
            log(config=args, model=model, epoch=epoch,
                log_dict=log_dict, test_dict=test_dict)

        dist.barrier()
import os
import torch
from cosypose.utils.distributed import init_distributed_mode, get_world_size, get_tmp_dir, get_rank
from cosypose.utils.logging import get_logger
logger = get_logger(__name__)


if __name__ == '__main__':
    # Diagnostic script: join the distributed process group and log
    # what this process sees (rank, world size, temp dir, GPU, CPUs).
    init_distributed_mode()
    proc_id = get_rank()
    n_tasks = get_world_size()
    # N_CPUS is expected to be exported by the job scheduler; fall back
    # to a readable placeholder when it is absent.
    n_cpus = os.environ.get('N_CPUS', 'not specified')
    logger.info(f'Number of processes (=num GPUs): {n_tasks}')
    logger.info(f'Process ID: {proc_id}')
    logger.info(f'TMP Directory for this job: {get_tmp_dir()}')
    logger.info(f'GPU CUDA ID: {torch.cuda.current_device()}')
    logger.info(f'Max number of CPUs for this process: {n_cpus}')
def train_detector(args):
    """Train the Mask R-CNN object detector (distributed).

    Args:
        args: Namespace of training hyperparameters and run identifiers.
            Mutated in place (save_dir, n_gpus, label_to_category_id, ...).
            If ``args.resume_run_id`` is set, most fields are overwritten
            from the saved run's ``config.yaml``.
    """
    torch.set_num_threads(1)

    if args.resume_run_id:
        resume_dir = EXP_DIR / args.resume_run_id
        # NOTE(review): yaml.load without an explicit Loader — only safe
        # because config.yaml is produced by this codebase; consider an
        # explicit Loader to silence the PyYAML warning.
        resume_args = yaml.load((resume_dir / 'config.yaml').read_text())
        # Fields of the CURRENT invocation that must NOT be overwritten
        # by the resumed run's config.
        keep_fields = set([
            'resume_run_id',
            'epoch_size',
        ])
        vars(args).update({
            k: v for k, v in vars(resume_args).items() if k not in keep_fields
        })

    args = check_update_config(args)
    args.save_dir = EXP_DIR / args.run_id

    logger.info(f"{'-'*80}")
    for k, v in args.__dict__.items():
        logger.info(f"{k}: {v}")
    logger.info(f"{'-'*80}")

    # Initialize distributed
    device = torch.cuda.current_device()
    init_distributed_mode()
    world_size = get_world_size()
    args.n_gpus = world_size
    args.global_batch_size = world_size * args.batch_size
    logger.info(f'Connection established with {world_size} gpus.')

    # Make train/val datasets
    def make_datasets(dataset_names):
        # dataset_names: iterable of (ds_name, n_repeat); each dataset is
        # appended n_repeat times to weight its sampling frequency.
        # Also collects the union of object labels across datasets.
        datasets = []
        all_labels = set()
        for (ds_name, n_repeat) in dataset_names:
            assert 'test' not in ds_name
            ds = make_scene_dataset(ds_name)
            logger.info(f'Loaded {ds_name} with {len(ds)} images.')
            all_labels = all_labels.union(set(ds.all_labels))
            for _ in range(n_repeat):
                datasets.append(ds)
        return ConcatDataset(datasets), all_labels

    scene_ds_train, train_labels = make_datasets(args.train_ds_names)
    scene_ds_val, _ = make_datasets(args.val_ds_names)

    # Category 0 is background; object labels get ids 1..N in sorted order.
    label_to_category_id = dict()
    label_to_category_id['background'] = 0
    for n, label in enumerate(sorted(list(train_labels)), 1):
        label_to_category_id[label] = n
    logger.info(
        f'Training with {len(label_to_category_id)} categories: {label_to_category_id}'
    )
    args.label_to_category_id = label_to_category_id

    ds_kwargs = dict(
        resize=args.input_resize,
        rgb_augmentation=args.rgb_augmentation,
        background_augmentation=args.background_augmentation,
        gray_augmentation=args.gray_augmentation,
        label_to_category_id=label_to_category_id,
    )
    ds_train = DetectionDataset(scene_ds_train, **ds_kwargs)
    ds_val = DetectionDataset(scene_ds_val, **ds_kwargs)

    train_sampler = PartialSampler(ds_train, epoch_size=args.epoch_size)
    ds_iter_train = DataLoader(ds_train, sampler=train_sampler,
                               batch_size=args.batch_size,
                               num_workers=args.n_dataloader_workers,
                               collate_fn=collate_fn,
                               drop_last=False, pin_memory=True)
    ds_iter_train = MultiEpochDataLoader(ds_iter_train)

    # Validation uses 10% of the training epoch size.
    val_sampler = PartialSampler(ds_val, epoch_size=int(0.1 * args.epoch_size))
    ds_iter_val = DataLoader(ds_val, sampler=val_sampler,
                             batch_size=args.batch_size,
                             num_workers=args.n_dataloader_workers,
                             collate_fn=collate_fn,
                             drop_last=False, pin_memory=True)
    ds_iter_val = MultiEpochDataLoader(ds_iter_val)

    model = create_model_detector(cfg=args,
                                  n_classes=len(args.label_to_category_id)).cuda()

    if args.resume_run_id:
        resume_dir = EXP_DIR / args.resume_run_id
        path = resume_dir / 'checkpoint.pth.tar'
        # Fixed typo in log message ('checkpoing' -> 'checkpoint').
        logger.info(f'Loading checkpoint from {path}')
        save = torch.load(path)
        state_dict = save['state_dict']
        model.load_state_dict(state_dict)
        start_epoch = save['epoch'] + 1
    else:
        start_epoch = 0
    end_epoch = args.n_epochs

    if args.run_id_pretrain is not None:
        pretrain_path = EXP_DIR / args.run_id_pretrain / 'checkpoint.pth.tar'
        logger.info(f'Using pretrained model from {pretrain_path}.')
        model.load_state_dict(torch.load(pretrain_path)['state_dict'])
    elif args.pretrain_coco:
        # COCO-pretrained backbone; drop the class-specific predictor
        # heads since our category count differs.
        state_dict = load_state_dict_from_url(
            model_urls['maskrcnn_resnet50_fpn_coco'])
        keep = lambda k: 'box_predictor' not in k and 'mask_predictor' not in k
        state_dict = {k: v for k, v in state_dict.items() if keep(k)}
        model.load_state_dict(state_dict, strict=False)
        logger.info(
            'Using model pre-trained on coco. Removed predictor heads.')
    else:
        logger.info('Training MaskRCNN from scratch.')

    # Synchronize models across processes.
    model = sync_model(model)
    model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[device],
                                                      output_device=device)

    # Optimizer
    params = [p for p in model.parameters() if p.requires_grad]
    if args.optimizer.lower() == 'sgd':
        optimizer = torch.optim.SGD(params, lr=args.lr,
                                    weight_decay=args.weight_decay,
                                    momentum=args.momentum)
    elif args.optimizer.lower() == 'adam':
        optimizer = torch.optim.Adam(params, lr=args.lr,
                                     weight_decay=args.weight_decay)
    else:
        raise ValueError(f'Unknown optimizer {args.optimizer}')

    # Warmup: linearly ramp the LR over the first n_epochs_warmup epochs.
    if args.n_epochs_warmup == 0:
        lambd = lambda epoch: 1
    else:
        n_batches_warmup = args.n_epochs_warmup * (args.epoch_size // args.batch_size)
        lambd = lambda batch: (batch + 1) / n_batches_warmup
    lr_scheduler_warmup = torch.optim.lr_scheduler.LambdaLR(optimizer, lambd)
    lr_scheduler_warmup.last_epoch = start_epoch * args.epoch_size // args.batch_size

    # LR schedulers
    # Divide LR by 10 every args.lr_epoch_decay
    lr_scheduler = torch.optim.lr_scheduler.StepLR(
        optimizer, step_size=args.lr_epoch_decay, gamma=0.1,
    )
    lr_scheduler.last_epoch = start_epoch - 1
    lr_scheduler.step()

    for epoch in range(start_epoch, end_epoch):
        meters_train = defaultdict(AverageValueMeter)
        meters_val = defaultdict(AverageValueMeter)
        meters_time = defaultdict(AverageValueMeter)

        # h(data, meters) -> scalar loss tensor.
        h = functools.partial(h_maskrcnn, model=model, cfg=args)

        def train_epoch():
            model.train()
            iterator = tqdm(ds_iter_train, ncols=80)
            t = time.time()
            for n, sample in enumerate(iterator):
                if n > 0:
                    meters_time['data'].add(time.time() - t)

                optimizer.zero_grad()

                t = time.time()
                loss = h(data=sample, meters=meters_train)
                meters_time['forward'].add(time.time() - t)
                iterator.set_postfix(loss=loss.item())
                meters_train['loss_total'].add(loss.item())

                t = time.time()
                loss.backward()
                # max_norm=inf: no clipping, just record the grad norm.
                total_grad_norm = torch.nn.utils.clip_grad_norm_(
                    model.parameters(), max_norm=np.inf, norm_type=2)
                meters_train['grad_norm'].add(
                    torch.as_tensor(total_grad_norm).item())

                optimizer.step()
                meters_time['backward'].add(time.time() - t)
                meters_time['memory'].add(torch.cuda.max_memory_allocated() / 1024.**2)

                # Warmup steps per-batch; decay steps per-epoch below.
                if epoch < args.n_epochs_warmup:
                    lr_scheduler_warmup.step()
                t = time.time()
            if epoch >= args.n_epochs_warmup:
                lr_scheduler.step()

        @torch.no_grad()
        def validation():
            # NOTE(review): train() (not eval()) looks deliberate here —
            # torchvision-style Mask R-CNN only returns the loss dict in
            # train mode; gradients are still disabled by @torch.no_grad.
            model.train()
            for sample in tqdm(ds_iter_val, ncols=80):
                loss = h(data=sample, meters=meters_val)
                meters_val['loss_total'].add(loss.item())

        train_epoch()
        if epoch % args.val_epoch_interval == 0:
            validation()

        test_dict = None
        if epoch % args.test_epoch_interval == 0:
            model.eval()
            test_dict = run_eval(args, model, epoch)

        log_dict = dict()
        log_dict.update({
            'grad_norm': meters_train['grad_norm'].mean,
            'grad_norm_std': meters_train['grad_norm'].std,
            'learning_rate': optimizer.param_groups[0]['lr'],
            'time_forward': meters_time['forward'].mean,
            'time_backward': meters_time['backward'].mean,
            'time_data': meters_time['data'].mean,
            'gpu_memory': meters_time['memory'].mean,
            'time': time.time(),
            'n_iterations': (epoch + 1) * len(ds_iter_train),
            'n_datas': (epoch + 1) * args.global_batch_size * len(ds_iter_train),
        })

        for string, meters in zip(('train', 'val'), (meters_train, meters_val)):
            for k in dict(meters).keys():
                log_dict[f'{string}_{k}'] = meters[k].mean
        # Average metrics across distributed workers before logging.
        log_dict = reduce_dict(log_dict)

        if get_rank() == 0:
            log(config=args, model=model, epoch=epoch,
                log_dict=log_dict, test_dict=test_dict)

        dist.barrier()