Example No. 1
    def __init__(self,
                 scene_ds,
                 meters,
                 batch_size=64,
                 cache_data=True,
                 n_workers=4,
                 sampler=None):

        self.rank = get_rank()
        self.world_size = get_world_size()
        self.tmp_dir = get_tmp_dir()

        self.scene_ds = scene_ds
        if sampler is None:
            sampler = DistributedSceneSampler(scene_ds,
                                              num_replicas=self.world_size,
                                              rank=self.rank)
        dataloader = DataLoader(scene_ds,
                                batch_size=batch_size,
                                num_workers=n_workers,
                                sampler=sampler,
                                collate_fn=self.collate_fn)

        if cache_data:
            self.dataloader = list(tqdm(dataloader))
        else:
            self.dataloader = dataloader

        self.meters = meters
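A note on the helpers used throughout these examples: get_rank(), get_world_size() and get_tmp_dir() come from robopose.utils.distributed and are not shown here. Below is a minimal sketch of what the first two typically do, assuming they fall back to single-process defaults when torch.distributed is not initialized; this is a plausible reimplementation for illustration, not robopose's actual code:

import torch.distributed as dist

def get_rank():
    # Rank of the current process; 0 when not running distributed.
    if dist.is_available() and dist.is_initialized():
        return dist.get_rank()
    return 0

def get_world_size():
    # Total number of processes; 1 when not running distributed.
    if dist.is_available() and dist.is_initialized():
        return dist.get_world_size()
    return 1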
Example No. 2
    def gather_distributed(self, tmp_dir):
        tmp_dir = Path(tmp_dir)
        tmp_dir.mkdir(exist_ok=True, parents=True)
        rank, world_size = get_rank(), get_world_size()
        tmp_file_template = (tmp_dir / 'rank={rank}.pth.tar').as_posix()
        if rank > 0:
            tmp_file = tmp_file_template.format(rank=rank)
            torch.save(self.datas, tmp_file)

        if world_size > 1:
            torch.distributed.barrier()

        if rank == 0 and world_size > 1:
            all_datas = self.datas
            for n in range(1, world_size):
                tmp_file = tmp_file_template.format(rank=n)
                datas = torch.load(tmp_file)
                for k in all_datas.keys():
                    all_datas[k].extend(datas.get(k, []))
                Path(tmp_file).unlink()
            self.datas = all_datas

        if world_size > 1:
            torch.distributed.barrier()
        return
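This pattern only works if every rank sees the same tmp_dir (e.g. a shared network filesystem): non-zero ranks serialize their shard, the first barrier guarantees all files exist before rank 0 merges and deletes them, and the second barrier keeps fast ranks from moving on early. A single-process toy round-trip of the save/merge step (paths and contents are made up for illustration):

from pathlib import Path
import torch

datas = {'df': ['shard-from-rank-0']}          # rank 0's own data
tmp_dir = Path('/tmp/gather_demo')
tmp_dir.mkdir(exist_ok=True, parents=True)
torch.save({'df': ['shard-from-rank-1']}, tmp_dir / 'rank=1.pth.tar')
loaded = torch.load(tmp_dir / 'rank=1.pth.tar')
for k in datas.keys():
    datas[k].extend(loaded.get(k, []))
(tmp_dir / 'rank=1.pth.tar').unlink()
print(datas)  # {'df': ['shard-from-rank-0', 'shard-from-rank-1']}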
Example No. 3
    def __init__(self, scene_ds, cache_data=False, n_workers=4):

        self.rank = get_rank()
        self.world_size = get_world_size()
        self.tmp_dir = get_tmp_dir()
        assert self.world_size == 1

        self.sampler = ListSampler(np.argsort(scene_ds.frame_index['view_id']))
        dataloader = DataLoader(scene_ds,
                                batch_size=1,
                                num_workers=n_workers,
                                sampler=self.sampler,
                                collate_fn=self.collate_fn)

        if cache_data:
            self.dataloader = list(tqdm(dataloader))
        else:
            self.dataloader = dataloader
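ListSampler is external to these snippets; from its use with np.argsort it plausibly just replays a fixed list of dataset indices in order, so frames are iterated sorted by view_id. A hypothetical sketch, not robopose's actual implementation:

from torch.utils.data import Sampler

class ListSampler(Sampler):
    def __init__(self, ids):
        self.ids = list(ids)

    def __iter__(self):
        # Yield the dataset indices in exactly the order given.
        return iter(self.ids)

    def __len__(self):
        return len(self.ids)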
Example No. 4
    def __init__(self, scene_ds, batch_size=4, cache_data=False, n_workers=4):

        self.rank = get_rank()
        self.world_size = get_world_size()
        self.tmp_dir = get_tmp_dir()

        sampler = DistributedSceneSampler(scene_ds,
                                          num_replicas=self.world_size,
                                          rank=self.rank)
        self.sampler = sampler
        dataloader = DataLoader(scene_ds,
                                batch_size=batch_size,
                                num_workers=n_workers,
                                sampler=sampler,
                                collate_fn=self.collate_fn)

        if cache_data:
            self.dataloader = list(tqdm(dataloader))
        else:
            self.dataloader = dataloader
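DistributedSceneSampler presumably hands each rank a disjoint subset of the dataset so that the ranks jointly cover every frame exactly once. A plausible round-robin sketch (hypothetical; the real class may shuffle or chunk differently):

import numpy as np
from torch.utils.data import Sampler

class DistributedSceneSampler(Sampler):
    def __init__(self, scene_ds, num_replicas=1, rank=0):
        # Every num_replicas-th index, offset by rank: disjoint and exhaustive.
        self.indices = np.arange(len(scene_ds))[rank::num_replicas]

    def __iter__(self):
        return iter(self.indices.tolist())

    def __len__(self):
        return len(self.indices)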
Example No. 5
    def gather_distributed(self, tmp_dir=None):
        assert tmp_dir is not None, 'a shared tmp_dir is required'
        tmp_dir = Path(tmp_dir)
        tmp_dir.mkdir(exist_ok=True, parents=True)
        rank, world_size = get_rank(), get_world_size()
        tmp_file_template = (tmp_dir / 'rank={rank}.pth.tar').as_posix()

        if rank > 0:
            tmp_file = tmp_file_template.format(rank=rank)
            torch.save(self, tmp_file)

        if world_size > 1:
            torch.distributed.barrier()

        datas = [self]
        if rank == 0 and world_size > 1:
            for n in range(1, world_size):
                tmp_file = tmp_file_template.format(rank=n)
                data = torch.load(tmp_file)
                datas.append(data)
                Path(tmp_file).unlink()

        if world_size > 1:
            torch.distributed.barrier()
        return concatenate(datas)
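concatenate() is also external to these snippets. Assuming each gathered object exposes dict-of-lists data (an assumption for illustration, not robopose's API), a minimal stand-in looks like:

def concatenate(datas):
    # Merge a list of dicts-of-lists into one dict-of-lists.
    out = {k: [] for k in datas[0]}
    for d in datas:
        for k, v in d.items():
            out[k].extend(v)
    return out

Note that with this pattern only rank 0 ends up with the full result; every other rank gets back the concatenation of just its own shard.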
Example No. 6
    def add(self, pred_data, gt_data):
        # ArticulatedObjectData
        pred_data = pred_data.float()
        gt_data = gt_data.float()

        TXO_gt = gt_data.poses
        q_gt = gt_data.joints
        gt_infos = gt_data.infos
        gt_infos['valid'] = True

        TXO_pred = pred_data.poses
        q_pred = pred_data.joints
        pred_infos = pred_data.infos

        cand_infos = get_candidate_matches(pred_infos, gt_infos)
        cand_TXO_gt = TXO_gt[cand_infos['gt_id']]
        cand_TXO_pred = TXO_pred[cand_infos['pred_id']]

        loc_errors_xyz, rot_errors = loc_rot_pose_errors(
            cand_TXO_pred, cand_TXO_gt)
        loc_errors_norm = torch.norm(loc_errors_xyz, dim=-1, p=2)

        error = loc_errors_norm.cpu().numpy().astype(np.float64)
        error[np.isnan(error)] = 1000
        cand_infos['error'] = error
        matches = match_poses(cand_infos)

        n_gt = len(gt_infos)

        def empty_array(shape, default='nan', dtype=np.float64):
            # Fill directly with the default (NaN) instead of multiplying
            # uninitialized memory; np.float was removed in NumPy 1.24.
            return np.full(shape, float(default), dtype=dtype)

        gt_infos['rank'] = get_rank()
        gt_infos['world_size'] = get_world_size()
        df = xr.Dataset(gt_infos).rename(dict(dim_0='gt_object_id'))

        scores = empty_array(n_gt)
        scores[matches.gt_id] = pred_infos.loc[matches.pred_id, 'score']
        df['pred_score'] = 'gt_object_id', scores

        rot_errors_ = empty_array((n_gt, 3))
        matches_rot_errors = rot_errors[matches.cand_id].cpu().numpy()
        matches_rot_errors = np.abs((matches_rot_errors + 180.0) % 360.0 -
                                    180.0)
        rot_errors_[matches.gt_id] = matches_rot_errors
        df['rot_error'] = ('gt_object_id', 'ypr'), rot_errors_

        loc_errors_xyz_ = empty_array((n_gt, 3))
        loc_errors_xyz_[matches.gt_id] = loc_errors_xyz[
            matches.cand_id].cpu().numpy()
        df['loc_error_xyz'] = ('gt_object_id', 'xyz'), loc_errors_xyz_

        loc_errors_norm_ = empty_array(n_gt)
        loc_errors_norm_[matches.gt_id] = loc_errors_norm[
            matches.cand_id].cpu().numpy()
        df['loc_error_norm'] = ('gt_object_id', ), loc_errors_norm_

        q_errors = empty_array((n_gt, q_gt.shape[1]))
        matches_q_errors = (q_gt[matches.gt_id] - q_pred[matches.pred_id]
                            ).cpu().numpy() * 180 / np.pi
        matches_q_errors = np.abs((matches_q_errors + 180.0) % 360.0 - 180.0)
        q_errors[matches.gt_id] = matches_q_errors

        df['joint_error'] = ('gt_object_id', 'dofs'), q_errors
        self.datas['df'].append(df)
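The (x + 180) % 360 - 180 expression used for the rotation and joint errors wraps a difference in degrees to the smallest equivalent angle before taking its magnitude. A quick standalone check:

import numpy as np

diffs = np.array([350.0, -350.0, 181.0, 10.0])
wrapped = np.abs((diffs + 180.0) % 360.0 - 180.0)
print(wrapped)  # [ 10.  10. 179.  10.] -- a 350 degree error is really 10 degrees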
Example No. 7
import os
import torch
from robopose.utils.distributed import init_distributed_mode, get_world_size, get_tmp_dir, get_rank
from robopose.utils.logging import get_logger
logger = get_logger(__name__)

if __name__ == '__main__':
    init_distributed_mode()
    proc_id = get_rank()
    n_tasks = get_world_size()
    n_cpus = os.environ.get('N_CPUS', 'not specified')
    logger.info(f'Number of processes (=num GPUs): {n_tasks}')
    logger.info(f'Process ID: {proc_id}')
    logger.info(f'TMP Directory for this job: {get_tmp_dir()}')
    logger.info(f'GPU CUDA ID: {torch.cuda.current_device()}')
    logger.info(f'Max number of CPUs for this process: {n_cpus}')
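init_distributed_mode() itself is part of robopose.utils.distributed. A hedged sketch of what such a helper typically does, assuming the standard torchrun-style environment variables (the real, cluster-aware implementation may differ):

import os
import torch

def init_distributed_mode_sketch():
    rank = int(os.environ.get('RANK', 0))
    world_size = int(os.environ.get('WORLD_SIZE', 1))
    local_rank = int(os.environ.get('LOCAL_RANK', 0))
    if world_size > 1:
        # Requires MASTER_ADDR / MASTER_PORT to be set by the launcher.
        torch.distributed.init_process_group(backend='nccl',
                                             rank=rank,
                                             world_size=world_size)
        torch.cuda.set_device(local_rank)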
Example No. 8
def train_pose(args):
    torch.set_num_threads(1)

    if args.resume_run_id:
        resume_dir = EXP_DIR / args.resume_run_id
        # The config contains Python objects (argparse.Namespace), so a
        # loader that can construct arbitrary objects is required.
        resume_args = yaml.load((resume_dir / 'config.yaml').read_text(),
                                Loader=yaml.UnsafeLoader)
        keep_fields = set(
            ['resume', 'resume_path', 'epoch_size', 'resume_run_id'])
        vars(args).update({
            k: v
            for k, v in vars(resume_args).items() if k not in keep_fields
        })

    args.save_dir = EXP_DIR / args.run_id
    args = check_update_config(args)

    logger.info(f"{'-'*80}")
    for k, v in args.__dict__.items():
        logger.info(f"{k}: {v}")
    logger.info(f"{'-'*80}")

    # Initialize distributed
    device = torch.cuda.current_device()
    init_distributed_mode()
    if get_rank() == 0:
        tmp_dir = get_tmp_dir()
        (tmp_dir / 'config.yaml').write_text(yaml.dump(args))

    world_size = get_world_size()
    args.n_gpus = world_size
    args.global_batch_size = world_size * args.batch_size
    logger.info(f'Connection established with {world_size} gpus.')

    # Make train/val datasets
    def make_datasets(dataset_names):
        datasets = []
        for (ds_name, n_repeat) in dataset_names:
            ds = make_scene_dataset(ds_name)
            logger.info(f'Loaded {ds_name} with {len(ds)} images.')
            for _ in range(n_repeat):
                datasets.append(ds)
        return ConcatDataset(datasets)

    scene_ds_train = make_datasets(args.train_ds_names)
    scene_ds_val = make_datasets(args.val_ds_names)

    ds_kwargs = dict(
        resize=args.input_resize,
        rgb_augmentation=args.rgb_augmentation,
        background_augmentation=args.background_augmentation,
    )
    ds_train = ArticulatedDataset(scene_ds_train, **ds_kwargs)
    ds_val = ArticulatedDataset(scene_ds_val, **ds_kwargs)

    train_sampler = PartialSampler(ds_train, epoch_size=args.epoch_size)
    ds_iter_train = DataLoader(ds_train,
                               sampler=train_sampler,
                               batch_size=args.batch_size,
                               num_workers=args.n_dataloader_workers,
                               collate_fn=ds_train.collate_fn,
                               drop_last=False,
                               pin_memory=True)
    ds_iter_train = MultiEpochDataLoader(ds_iter_train)

    val_sampler = PartialSampler(ds_val, epoch_size=int(0.1 * args.epoch_size))
    ds_iter_val = DataLoader(ds_val,
                             sampler=val_sampler,
                             batch_size=args.batch_size,
                             num_workers=args.n_dataloader_workers,
                             collate_fn=ds_val.collate_fn,
                             drop_last=False,
                             pin_memory=True)
    ds_iter_val = MultiEpochDataLoader(ds_iter_val)

    # Make model
    renderer = BulletBatchRenderer(object_set=args.urdf_ds_name,
                                   n_workers=args.n_rendering_workers)
    urdf_ds = make_urdf_dataset(args.urdf_ds_name)
    mesh_db = MeshDataBase.from_urdf_ds(urdf_ds).cuda().float()

    model = create_model(cfg=args, renderer=renderer, mesh_db=mesh_db).cuda()

    if args.run_id_pretrain is not None:
        pretrain_path = EXP_DIR / args.run_id_pretrain / 'checkpoint.pth.tar'
        logger.info(f'Using pretrained model from {pretrain_path}.')
        model.load_state_dict(torch.load(pretrain_path)['state_dict'])

    if args.run_id_pretrain_backbone is not None and get_rank() == 0:
        pretrain_path = EXP_DIR / args.run_id_pretrain_backbone / 'checkpoint.pth.tar'
        logger.info(f'Using pretrained backbone from {pretrain_path}.')
        pretrain_state_dict = torch.load(pretrain_path)['state_dict']

        model_state_dict = model.state_dict()
        conv1_key = 'backbone.conv1.weight'
        if model_state_dict[conv1_key].shape[1] != pretrain_state_dict[
                conv1_key].shape[1]:
            logger.info('Using inflated input layer')
            logger.info(
                f'Original size: {pretrain_state_dict[conv1_key].shape}')
            logger.info(f'Target size: {model_state_dict[conv1_key].shape}')
            pretrain_n_inputs = pretrain_state_dict[conv1_key].shape[1]
            model_n_inputs = model_state_dict[conv1_key].shape[1]
            conv1_weight = pretrain_state_dict[conv1_key]
            # Repeat the first input channel to fill the extra input channels.
            extra = conv1_weight[:, [0]].repeat(
                1, model_n_inputs - pretrain_n_inputs, 1, 1)
            weight_inflated = torch.cat([conv1_weight, extra], dim=1)
            pretrain_state_dict[conv1_key] = weight_inflated.clone()

        pretrain_state_dict = {
            k: v
            for k, v in pretrain_state_dict.items()
            if ('backbone' in k and k in model_state_dict)
        }
        logger.info(f"Pretrain keys: {list(pretrain_state_dict.keys())}")
        model.load_state_dict(pretrain_state_dict, strict=False)

    if args.resume_run_id:
        resume_dir = EXP_DIR / args.resume_run_id
        path = resume_dir / 'checkpoint.pth.tar'
        logger.info(f'Loading checkpoint from {path}')
        save = torch.load(path)
        state_dict = save['state_dict']
        model.load_state_dict(state_dict)
        start_epoch = save['epoch'] + 1
    else:
        start_epoch = 0
    end_epoch = args.n_epochs

    # Synchronize models across processes.
    model = sync_model(model)
    model = torch.nn.parallel.DistributedDataParallel(model,
                                                      device_ids=[device],
                                                      output_device=device)

    # Optimizer
    optimizer = torch.optim.Adam(model.parameters(),
                                 lr=args.lr,
                                 weight_decay=args.weight_decay)

    # Warmup
    def get_lr_ratio(batch):
        n_batch_per_epoch = args.epoch_size // args.batch_size
        epoch_id = batch // n_batch_per_epoch

        if args.n_epochs_warmup == 0:
            lr_ratio = 1.0
        else:
            n_batches_warmup = args.n_epochs_warmup * (args.epoch_size //
                                                       args.batch_size)
            lr_ratio = min(max(batch, 1) / n_batches_warmup, 1.0)

        lr_ratio /= 10**(epoch_id // args.lr_epoch_decay)
        return lr_ratio

    lr_scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, get_lr_ratio)
    lr_scheduler.last_epoch = start_epoch * args.epoch_size // args.batch_size - 1

    # Just remove the annoying warning
    optimizer._step_count = 1
    lr_scheduler.step()
    optimizer._step_count = 0

    for epoch in range(start_epoch, end_epoch + 1):
        meters_train = defaultdict(lambda: AverageValueMeter())
        meters_val = defaultdict(lambda: AverageValueMeter())
        meters_time = defaultdict(lambda: AverageValueMeter())

        if args.add_iteration_epoch_interval is None:
            n_iterations = args.n_iterations
        else:
            n_iterations = min(epoch // args.add_iteration_epoch_interval + 1,
                               args.n_iterations)
        h = functools.partial(h_pose,
                              model=model,
                              cfg=args,
                              n_iterations=n_iterations,
                              mesh_db=mesh_db)

        def train_epoch():
            model.train()
            iterator = tqdm(ds_iter_train, ncols=80)
            t = time.time()
            for n, sample in enumerate(iterator):
                if n > 0:
                    meters_time['data'].add(time.time() - t)

                optimizer.zero_grad()

                t = time.time()
                loss = h(data=sample, meters=meters_train, train=True)
                meters_time['forward'].add(time.time() - t)
                iterator.set_postfix(loss=loss.item())
                meters_train['loss_total'].add(loss.item())

                t = time.time()
                loss.backward()
                total_grad_norm = torch.nn.utils.clip_grad_norm_(
                    model.parameters(),
                    max_norm=args.clip_grad_norm,
                    norm_type=2)
                meters_train['grad_norm'].add(
                    torch.as_tensor(total_grad_norm).item())

                optimizer.step()
                meters_time['backward'].add(time.time() - t)
                meters_time['memory'].add(torch.cuda.max_memory_allocated() /
                                          1024.**2)

                t = time.time()

                lr_scheduler.step()

        @torch.no_grad()
        def validation():
            model.eval()
            for sample in tqdm(ds_iter_val, ncols=80):
                loss = h(data=sample, meters=meters_val, train=False)
                meters_val['loss_total'].add(loss.item())

        @torch.no_grad()
        def test():
            model.eval()
            return run_test(args, epoch=epoch)

        train_epoch()

        if epoch % args.val_epoch_interval == 0:
            validation()

        log_dict = dict()
        log_dict.update({
            'grad_norm': meters_train['grad_norm'].mean,
            'grad_norm_std': meters_train['grad_norm'].std,
            'learning_rate': optimizer.param_groups[0]['lr'],
            'time_forward': meters_time['forward'].mean,
            'time_backward': meters_time['backward'].mean,
            'time_data': meters_time['data'].mean,
            'gpu_memory': meters_time['memory'].mean,
            'time': time.time(),
            'n_iterations': (epoch + 1) * len(ds_iter_train),
            'n_datas': (epoch + 1) * args.global_batch_size * len(ds_iter_train),
        })

        for string, meters in zip(('train', 'val'),
                                  (meters_train, meters_val)):
            for k in dict(meters).keys():
                log_dict[f'{string}_{k}'] = meters[k].mean

        log_dict = reduce_dict(log_dict)
        if get_rank() == 0:
            # Checkpoint only at this point (no metrics yet); the full
            # log_dict and test_dict are written after testing below.
            log(config=args,
                model=model,
                epoch=epoch,
                log_dict=None,
                test_dict=None)

        dist.barrier()

        test_dict = None
        if args.test_only_last_epoch:
            if epoch == end_epoch:
                test_dict = test()
        else:
            if epoch % args.test_epoch_interval == 0:
                test_dict = test()

        if get_rank() == 0:
            log(config=args,
                model=model,
                epoch=epoch,
                log_dict=log_dict,
                test_dict=test_dict)

        dist.barrier()
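The warmup-plus-decay logic in get_lr_ratio above is easy to sanity-check in isolation. With made-up hyperparameters (the values below are examples, not the training run's actual settings):

epoch_size, batch_size = 1000, 10        # 100 batches per epoch
n_epochs_warmup, lr_epoch_decay = 2, 5

def lr_ratio(batch):
    n_batch_per_epoch = epoch_size // batch_size
    epoch_id = batch // n_batch_per_epoch
    n_batches_warmup = n_epochs_warmup * n_batch_per_epoch
    ratio = min(max(batch, 1) / n_batches_warmup, 1.0)
    # Divide the learning rate by 10 every lr_epoch_decay epochs.
    return ratio / 10 ** (epoch_id // lr_epoch_decay)

for b in (0, 100, 200, 500, 1000):
    print(b, lr_ratio(b))  # 0.005, 0.5, 1.0, 0.1, 0.01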