def __init__(self, scene_ds, meters, batch_size=64, cache_data=True, n_workers=4, sampler=None):
    self.rank = get_rank()
    self.world_size = get_world_size()
    self.tmp_dir = get_tmp_dir()
    self.scene_ds = scene_ds

    # Each process evaluates its own subset of the scene dataset.
    if sampler is None:
        sampler = DistributedSceneSampler(scene_ds,
                                          num_replicas=self.world_size,
                                          rank=self.rank)
    dataloader = DataLoader(scene_ds,
                            batch_size=batch_size,
                            num_workers=n_workers,
                            sampler=sampler,
                            collate_fn=self.collate_fn)

    # Optionally load the whole per-rank subset into memory once,
    # so repeated passes over the data do not pay the loading cost again.
    if cache_data:
        self.dataloader = list(tqdm(dataloader))
    else:
        self.dataloader = dataloader

    self.meters = meters
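# Minimal, self-contained stand-in for the per-rank index split that
# DistributedSceneSampler is assumed to perform above. The class below is only
# an illustrative sketch built on torch's Sampler API, not the repo's sampler.
from torch.utils.data import Sampler


class NaiveDistributedSampler(Sampler):
    """Hypothetical example: give each process a strided subset of the dataset."""

    def __init__(self, dataset, num_replicas, rank):
        self.indices = list(range(rank, len(dataset), num_replicas))

    def __iter__(self):
        return iter(self.indices)

    def __len__(self):
        return len(self.indices)


# Usage sketch: rank 1 of 4 processes would see indices 1, 5, 9, ...
# sampler = NaiveDistributedSampler(scene_ds, num_replicas=4, rank=1)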
def gather_distributed(self, tmp_dir):
    tmp_dir = Path(tmp_dir)
    tmp_dir.mkdir(exist_ok=True, parents=True)
    rank, world_size = get_rank(), get_world_size()
    tmp_file_template = (tmp_dir / 'rank={rank}.pth.tar').as_posix()

    # Every non-master process dumps its results to a rank-specific file.
    if rank > 0:
        tmp_file = tmp_file_template.format(rank=rank)
        torch.save(self.datas, tmp_file)

    if world_size > 1:
        torch.distributed.barrier()

    # The master process merges all files into its own data dict and cleans up.
    if rank == 0 and world_size > 1:
        all_datas = self.datas
        for n in range(1, world_size):
            tmp_file = tmp_file_template.format(rank=n)
            datas = torch.load(tmp_file)
            for k in all_datas.keys():
                all_datas[k].extend(datas.get(k, []))
            Path(tmp_file).unlink()
        self.datas = all_datas

    if world_size > 1:
        torch.distributed.barrier()
    return
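# Alternative sketch (not used in this repo): recent PyTorch versions can gather
# arbitrary picklable objects in memory with all_gather_object, avoiding the
# temp-file round trip above when the per-rank payloads are small. Assumes
# `datas` is a dict of lists and that the default process group is initialized
# whenever world_size > 1; names here are illustrative only.
import torch.distributed as dist


def gather_in_memory(datas, rank, world_size):
    if world_size == 1:
        return datas
    gathered = [None for _ in range(world_size)]
    dist.all_gather_object(gathered, datas)   # every rank receives every dict
    if rank == 0:
        merged = gathered[0]
        for other in gathered[1:]:
            for k in merged.keys():
                merged[k].extend(other.get(k, []))
        return merged
    return datas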
def __init__(self, scene_ds, cache_data=False, n_workers=4):
    self.rank = get_rank()
    self.world_size = get_world_size()
    self.tmp_dir = get_tmp_dir()
    assert self.world_size == 1

    # Process the views one by one, ordered by view_id.
    self.sampler = ListSampler(np.argsort(scene_ds.frame_index['view_id']))
    dataloader = DataLoader(scene_ds,
                            batch_size=1,
                            num_workers=n_workers,
                            sampler=self.sampler,
                            collate_fn=self.collate_fn)
    if cache_data:
        self.dataloader = list(tqdm(dataloader))
    else:
        self.dataloader = dataloader
def __init__(self, scene_ds, batch_size=4, cache_data=False, n_workers=4):
    self.rank = get_rank()
    self.world_size = get_world_size()
    self.tmp_dir = get_tmp_dir()

    sampler = DistributedSceneSampler(scene_ds,
                                      num_replicas=self.world_size,
                                      rank=self.rank)
    self.sampler = sampler
    dataloader = DataLoader(scene_ds,
                            batch_size=batch_size,
                            num_workers=n_workers,
                            sampler=sampler,
                            collate_fn=self.collate_fn)
    if cache_data:
        self.dataloader = list(tqdm(dataloader))
    else:
        self.dataloader = dataloader
def gather_distributed(self, tmp_dir=None):
    # Note: tmp_dir must be a Path to a directory visible to all processes.
    rank, world_size = get_rank(), get_world_size()
    tmp_file_template = (tmp_dir / 'rank={rank}.pth.tar').as_posix()

    if rank > 0:
        tmp_file = tmp_file_template.format(rank=rank)
        torch.save(self, tmp_file)

    if world_size > 1:
        torch.distributed.barrier()

    datas = [self]
    if rank == 0 and world_size > 1:
        for n in range(1, world_size):
            tmp_file = tmp_file_template.format(rank=n)
            data = torch.load(tmp_file)
            datas.append(data)
            Path(tmp_file).unlink()

    if world_size > 1:
        torch.distributed.barrier()
    return concatenate(datas)
def add(self, pred_data, gt_data):
    # ArticulatedObjectData
    pred_data = pred_data.float()
    gt_data = gt_data.float()

    TXO_gt = gt_data.poses
    q_gt = gt_data.joints
    gt_infos = gt_data.infos
    gt_infos['valid'] = True

    TXO_pred = pred_data.poses
    q_pred = pred_data.joints
    pred_infos = pred_data.infos

    # Compute pose errors for every candidate (prediction, ground-truth) pair,
    # then match predictions to ground-truth objects.
    cand_infos = get_candidate_matches(pred_infos, gt_infos)
    cand_TXO_gt = TXO_gt[cand_infos['gt_id']]
    cand_TXO_pred = TXO_pred[cand_infos['pred_id']]
    loc_errors_xyz, rot_errors = loc_rot_pose_errors(cand_TXO_pred, cand_TXO_gt)
    loc_errors_norm = torch.norm(loc_errors_xyz, dim=-1, p=2)
    error = loc_errors_norm.cpu().numpy().astype(float)
    error[np.isnan(error)] = 1000
    cand_infos['error'] = error
    matches = match_poses(cand_infos)
    n_gt = len(gt_infos)

    def empty_array(shape, default='nan', dtype=float):
        # NaN-filled array; entries stay NaN for unmatched ground-truth objects.
        return np.empty(shape, dtype=dtype) * float(default)

    gt_infos['rank'] = get_rank()
    gt_infos['world_size'] = get_world_size()
    df = xr.Dataset(gt_infos).rename(dict(dim_0='gt_object_id'))

    scores = empty_array(n_gt)
    scores[matches.gt_id] = pred_infos.loc[matches.pred_id, 'score']
    df['pred_score'] = 'gt_object_id', scores

    # Rotation errors (yaw/pitch/roll), wrapped to [0, 180] degrees.
    rot_errors_ = empty_array((n_gt, 3))
    matches_rot_errors = rot_errors[matches.cand_id].cpu().numpy()
    matches_rot_errors = np.abs((matches_rot_errors + 180.0) % 360.0 - 180.0)
    rot_errors_[matches.gt_id] = matches_rot_errors
    df['rot_error'] = ('gt_object_id', 'ypr'), rot_errors_

    # Translation errors, per axis and as an L2 norm.
    loc_errors_xyz_ = empty_array((n_gt, 3))
    loc_errors_xyz_[matches.gt_id] = loc_errors_xyz[matches.cand_id].cpu().numpy()
    df['loc_error_xyz'] = ('gt_object_id', 'xyz'), loc_errors_xyz_

    loc_errors_norm_ = empty_array(n_gt)
    loc_errors_norm_[matches.gt_id] = loc_errors_norm[matches.cand_id].cpu().numpy()
    df['loc_error_norm'] = ('gt_object_id', ), loc_errors_norm_

    # Joint-angle errors in degrees, wrapped to [0, 180].
    q_errors = empty_array((n_gt, q_gt.shape[1]))
    matches_q_errors = (q_gt[matches.gt_id] - q_pred[matches.pred_id]).cpu().numpy() * 180 / np.pi
    matches_q_errors = np.abs((matches_q_errors + 180.0) % 360.0 - 180.0)
    q_errors[matches.gt_id] = matches_q_errors
    df['joint_error'] = ('gt_object_id', 'dofs'), q_errors

    self.datas['df'].append(df)
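# Quick numeric check of the angle wrapping used above: raw angular differences
# are mapped to the magnitude of the shortest path around the circle, i.e. a
# value in [0, 180] degrees (standalone example, not part of the repo).
import numpy as np

raw_errors = np.array([10.0, 190.0, 350.0, -350.0])
wrapped = np.abs((raw_errors + 180.0) % 360.0 - 180.0)
assert np.allclose(wrapped, [10.0, 170.0, 10.0, 10.0])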
import os

import torch

from robopose.utils.distributed import init_distributed_mode, get_world_size, get_tmp_dir, get_rank
from robopose.utils.logging import get_logger

logger = get_logger(__name__)


if __name__ == '__main__':
    init_distributed_mode()
    proc_id = get_rank()
    n_tasks = get_world_size()
    n_cpus = os.environ.get('N_CPUS', 'not specified')
    logger.info(f'Number of processes (=num GPUs): {n_tasks}')
    logger.info(f'Process ID: {proc_id}')
    logger.info(f'TMP Directory for this job: {get_tmp_dir()}')
    logger.info(f'GPU CUDA ID: {torch.cuda.current_device()}')
    logger.info(f'Max number of CPUs for this process: {n_cpus}')
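# For comparison, a generic sketch of initializing torch.distributed from the
# environment variables exported by torchrun (RANK, WORLD_SIZE, LOCAL_RANK).
# The repo's init_distributed_mode is its own helper and may rely on a
# different mechanism (e.g. a job scheduler's variables); this is only an
# illustrative, self-contained alternative.
import os

import torch
import torch.distributed as dist


def manual_init_sketch():
    rank = int(os.environ.get('RANK', 0))
    world_size = int(os.environ.get('WORLD_SIZE', 1))
    local_rank = int(os.environ.get('LOCAL_RANK', 0))
    torch.cuda.set_device(local_rank)
    if world_size > 1:
        # env:// also expects MASTER_ADDR / MASTER_PORT, which torchrun sets.
        dist.init_process_group(backend='nccl', init_method='env://')
    return rank, world_size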
def train_pose(args):
    torch.set_num_threads(1)

    if args.resume_run_id:
        resume_dir = EXP_DIR / args.resume_run_id
        # The saved config is a serialized argparse.Namespace: use the unsafe loader.
        resume_args = yaml.load((resume_dir / 'config.yaml').read_text(),
                                Loader=yaml.UnsafeLoader)
        keep_fields = set(['resume', 'resume_path', 'epoch_size', 'resume_run_id'])
        vars(args).update({k: v for k, v in vars(resume_args).items() if k not in keep_fields})

    args.save_dir = EXP_DIR / args.run_id
    args = check_update_config(args)

    logger.info(f"{'-'*80}")
    for k, v in args.__dict__.items():
        logger.info(f"{k}: {v}")
    logger.info(f"{'-'*80}")

    # Initialize distributed
    device = torch.cuda.current_device()
    init_distributed_mode()
    if get_rank() == 0:
        tmp_dir = get_tmp_dir()
        (tmp_dir / 'config.yaml').write_text(yaml.dump(args))
    world_size = get_world_size()
    args.n_gpus = world_size
    args.global_batch_size = world_size * args.batch_size
    logger.info(f'Connection established with {world_size} gpus.')

    # Make train/val datasets
    def make_datasets(dataset_names):
        datasets = []
        for (ds_name, n_repeat) in dataset_names:
            ds = make_scene_dataset(ds_name)
            logger.info(f'Loaded {ds_name} with {len(ds)} images.')
            for _ in range(n_repeat):
                datasets.append(ds)
        return ConcatDataset(datasets)

    scene_ds_train = make_datasets(args.train_ds_names)
    scene_ds_val = make_datasets(args.val_ds_names)

    ds_kwargs = dict(
        resize=args.input_resize,
        rgb_augmentation=args.rgb_augmentation,
        background_augmentation=args.background_augmentation,
    )
    ds_train = ArticulatedDataset(scene_ds_train, **ds_kwargs)
    ds_val = ArticulatedDataset(scene_ds_val, **ds_kwargs)

    train_sampler = PartialSampler(ds_train, epoch_size=args.epoch_size)
    ds_iter_train = DataLoader(ds_train,
                               sampler=train_sampler,
                               batch_size=args.batch_size,
                               num_workers=args.n_dataloader_workers,
                               collate_fn=ds_train.collate_fn,
                               drop_last=False,
                               pin_memory=True)
    ds_iter_train = MultiEpochDataLoader(ds_iter_train)

    val_sampler = PartialSampler(ds_val, epoch_size=int(0.1 * args.epoch_size))
    ds_iter_val = DataLoader(ds_val,
                             sampler=val_sampler,
                             batch_size=args.batch_size,
                             num_workers=args.n_dataloader_workers,
                             collate_fn=ds_val.collate_fn,
                             drop_last=False,
                             pin_memory=True)
    ds_iter_val = MultiEpochDataLoader(ds_iter_val)

    # Make model
    renderer = BulletBatchRenderer(object_set=args.urdf_ds_name,
                                   n_workers=args.n_rendering_workers)
    urdf_ds = make_urdf_dataset(args.urdf_ds_name)
    mesh_db = MeshDataBase.from_urdf_ds(urdf_ds).cuda().float()

    model = create_model(cfg=args, renderer=renderer, mesh_db=mesh_db).cuda()

    if args.run_id_pretrain is not None:
        pretrain_path = EXP_DIR / args.run_id_pretrain / 'checkpoint.pth.tar'
        logger.info(f'Using pretrained model from {pretrain_path}.')
        model.load_state_dict(torch.load(pretrain_path)['state_dict'])

    # Only rank 0 loads the pretrained backbone; sync_model below broadcasts the
    # weights to the other processes.
    if args.run_id_pretrain_backbone is not None and get_rank() == 0:
        pretrain_path = EXP_DIR / args.run_id_pretrain_backbone / 'checkpoint.pth.tar'
        logger.info(f'Using pretrained backbone from {pretrain_path}.')
        pretrain_state_dict = torch.load(pretrain_path)['state_dict']
        model_state_dict = model.state_dict()

        conv1_key = 'backbone.conv1.weight'
        if model_state_dict[conv1_key].shape[1] != pretrain_state_dict[conv1_key].shape[1]:
            logger.info('Using inflated input layer')
            logger.info(f'Original size: {pretrain_state_dict[conv1_key].shape}')
            logger.info(f'Target size: {model_state_dict[conv1_key].shape}')
            pretrain_n_inputs = pretrain_state_dict[conv1_key].shape[1]
            model_n_inputs = model_state_dict[conv1_key].shape[1]
            conv1_weight = pretrain_state_dict[conv1_key]
            # Inflate the first conv layer by repeating one of the pretrained
            # input channels until the model's input-channel count is reached.
            weight_inflated = torch.cat([
                conv1_weight,
                conv1_weight[:, [0]].repeat(1, model_n_inputs - pretrain_n_inputs, 1, 1)],
                axis=1)
            pretrain_state_dict[conv1_key] = weight_inflated.clone()

        pretrain_state_dict = {
            k: v for k, v in pretrain_state_dict.items()
            if ('backbone' in k and k in model_state_dict)
        }
        logger.info(f"Pretrain keys: {list(pretrain_state_dict.keys())}")
        model.load_state_dict(pretrain_state_dict, strict=False)

    if args.resume_run_id:
        resume_dir = EXP_DIR / args.resume_run_id
        path = resume_dir / 'checkpoint.pth.tar'
        logger.info(f'Loading checkpoint from {path}')
        save = torch.load(path)
        state_dict = save['state_dict']
        model.load_state_dict(state_dict)
        start_epoch = save['epoch'] + 1
    else:
        start_epoch = 0
    end_epoch = args.n_epochs

    # Synchronize models across processes.
    model = sync_model(model)
    model = torch.nn.parallel.DistributedDataParallel(model,
                                                      device_ids=[device],
                                                      output_device=device)

    # Optimizer
    optimizer = torch.optim.Adam(model.parameters(),
                                 lr=args.lr,
                                 weight_decay=args.weight_decay)

    # Warmup
    def get_lr_ratio(batch):
        n_batch_per_epoch = args.epoch_size // args.batch_size
        epoch_id = batch // n_batch_per_epoch

        if args.n_epochs_warmup == 0:
            lr_ratio = 1.0
        else:
            n_batches_warmup = args.n_epochs_warmup * (args.epoch_size // args.batch_size)
            lr_ratio = min(max(batch, 1) / n_batches_warmup, 1.0)

        lr_ratio /= 10**(epoch_id // args.lr_epoch_decay)
        return lr_ratio

    lr_scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, get_lr_ratio)
    lr_scheduler.last_epoch = start_epoch * args.epoch_size // args.batch_size - 1

    # Just remove the annoying warning
    optimizer._step_count = 1
    lr_scheduler.step()
    optimizer._step_count = 0

    for epoch in range(start_epoch, end_epoch + 1):
        meters_train = defaultdict(lambda: AverageValueMeter())
        meters_val = defaultdict(lambda: AverageValueMeter())
        meters_time = defaultdict(lambda: AverageValueMeter())

        if args.add_iteration_epoch_interval is None:
            n_iterations = args.n_iterations
        else:
            n_iterations = min(epoch // args.add_iteration_epoch_interval + 1,
                               args.n_iterations)

        h = functools.partial(h_pose, model=model, cfg=args,
                              n_iterations=n_iterations, mesh_db=mesh_db)

        def train_epoch():
            model.train()
            iterator = tqdm(ds_iter_train, ncols=80)
            t = time.time()
            for n, sample in enumerate(iterator):
                if n > 0:
                    meters_time['data'].add(time.time() - t)

                optimizer.zero_grad()

                t = time.time()
                loss = h(data=sample, meters=meters_train, train=True)
                meters_time['forward'].add(time.time() - t)
                iterator.set_postfix(loss=loss.item())
                meters_train['loss_total'].add(loss.item())

                t = time.time()
                loss.backward()
                total_grad_norm = torch.nn.utils.clip_grad_norm_(
                    model.parameters(), max_norm=args.clip_grad_norm, norm_type=2)
                meters_train['grad_norm'].add(torch.as_tensor(total_grad_norm).item())

                optimizer.step()
                meters_time['backward'].add(time.time() - t)
                meters_time['memory'].add(torch.cuda.max_memory_allocated() / 1024. ** 2)

                t = time.time()
                lr_scheduler.step()

        @torch.no_grad()
        def validation():
            model.eval()
            for sample in tqdm(ds_iter_val, ncols=80):
                loss = h(data=sample, meters=meters_val, train=False)
                meters_val['loss_total'].add(loss.item())

        @torch.no_grad()
        def test():
            model.eval()
            return run_test(args, epoch=epoch)

        train_epoch()

        if epoch % args.val_epoch_interval == 0:
            validation()

        log_dict = dict()
        log_dict.update({
            'grad_norm': meters_train['grad_norm'].mean,
            'grad_norm_std': meters_train['grad_norm'].std,
            'learning_rate': optimizer.param_groups[0]['lr'],
            'time_forward': meters_time['forward'].mean,
            'time_backward': meters_time['backward'].mean,
            'time_data': meters_time['data'].mean,
            'gpu_memory': meters_time['memory'].mean,
            'time': time.time(),
            'n_iterations': (epoch + 1) * len(ds_iter_train),
            'n_datas': (epoch + 1) * args.global_batch_size * len(ds_iter_train),
        })

        for string, meters in zip(('train', 'val'), (meters_train, meters_val)):
            for k in dict(meters).keys():
                log_dict[f'{string}_{k}'] = meters[k].mean

        log_dict = reduce_dict(log_dict)

        # Log/save once before testing, then again below with the metrics and
        # test results.
        if get_rank() == 0:
            log(config=args, model=model, epoch=epoch,
                log_dict=None, test_dict=None)
        dist.barrier()

        test_dict = None
        if args.test_only_last_epoch:
            if epoch == end_epoch:
                test_dict = test()
        else:
            if epoch % args.test_epoch_interval == 0:
                test_dict = test()

        if get_rank() == 0:
            log(config=args, model=model, epoch=epoch,
                log_dict=log_dict, test_dict=test_dict)
        dist.barrier()
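# Worked example of the warmup + step-decay schedule defined by get_lr_ratio
# above, copied with made-up hyperparameters (epoch_size=1000, batch_size=10,
# n_epochs_warmup=2, lr_epoch_decay=5). None of these values come from the
# repo's configs; they only make the arithmetic easy to check.
def lr_ratio_example(batch, epoch_size=1000, batch_size=10,
                     n_epochs_warmup=2, lr_epoch_decay=5):
    n_batch_per_epoch = epoch_size // batch_size          # 100 batches per epoch
    epoch_id = batch // n_batch_per_epoch
    if n_epochs_warmup == 0:
        lr_ratio = 1.0
    else:
        n_batches_warmup = n_epochs_warmup * n_batch_per_epoch
        lr_ratio = min(max(batch, 1) / n_batches_warmup, 1.0)
    # Divide the learning rate by 10 every `lr_epoch_decay` epochs.
    lr_ratio /= 10 ** (epoch_id // lr_epoch_decay)
    return lr_ratio


assert lr_ratio_example(0) == 1 / 200     # warmup starts at lr / 200
assert lr_ratio_example(100) == 0.5       # halfway through the 2-epoch warmup
assert lr_ratio_example(300) == 1.0       # warmup done, full learning rate
assert lr_ratio_example(550) == 0.1       # epoch 5: first /10 decay step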