示例#1
0
def test_build_from_cfg():
    """Check ``mmcv.build_from_cfg`` on valid configs and on each kind of
    invalid input (unregistered type, bad ``type`` value, missing ``type``,
    bad registry, bad ``default_args``)."""
    BACKBONES = mmcv.Registry('backbone')

    @BACKBONES.register_module
    class ResNet:
        def __init__(self, depth, stages=4):
            self.depth = depth
            self.stages = stages

    @BACKBONES.register_module
    class ResNeXt:
        def __init__(self, depth, stages=4):
            self.depth = depth
            self.stages = stages

    cfg = dict(type='ResNet', depth=50)
    model = mmcv.build_from_cfg(cfg, BACKBONES)
    assert isinstance(model, ResNet)
    assert model.depth == 50 and model.stages == 4

    # default_args supplies values missing from cfg
    cfg = dict(type='ResNet', depth=50)
    model = mmcv.build_from_cfg(cfg, BACKBONES, default_args={'stages': 3})
    assert isinstance(model, ResNet)
    assert model.depth == 50 and model.stages == 3

    cfg = dict(type='ResNeXt', depth=50, stages=3)
    model = mmcv.build_from_cfg(cfg, BACKBONES)
    assert isinstance(model, ResNeXt)
    assert model.depth == 50 and model.stages == 3

    # cfg['type'] may also be the class itself
    cfg = dict(type=ResNet, depth=50)
    model = mmcv.build_from_cfg(cfg, BACKBONES)
    assert isinstance(model, ResNet)
    assert model.depth == 50 and model.stages == 4

    # non-registered class
    with pytest.raises(KeyError):
        cfg = dict(type='VGG')
        mmcv.build_from_cfg(cfg, BACKBONES)

    # cfg['type'] should be a str or class
    with pytest.raises(TypeError):
        cfg = dict(type=1000)
        mmcv.build_from_cfg(cfg, BACKBONES)

    # cfg should contain the key "type"
    with pytest.raises(TypeError):
        cfg = dict(depth=50, stages=4)
        mmcv.build_from_cfg(cfg, BACKBONES)

    # The two cases below must start from a *valid* cfg. Reusing the cfg
    # above (which lacks "type") would raise TypeError for the missing key
    # and never reach the registry / default_args checks.

    # incorrect registry type
    with pytest.raises(TypeError):
        cfg = dict(type='ResNet', depth=50)
        mmcv.build_from_cfg(cfg, 'BACKBONES')

    # incorrect default_args type
    with pytest.raises(TypeError):
        cfg = dict(type='ResNet', depth=50)
        mmcv.build_from_cfg(cfg, BACKBONES, default_args=0)
示例#2
0
def build(cfg, registry, default_args=None):
    """Build one module, or a chain of modules, from config.

    Args:
        cfg (dict | list[dict]): Config of a single module, or a list of
            configs whose modules are chained in list order.
        registry (obj): ``registry`` object used to look up module types.
        default_args (dict, optional): Default arguments forwarded to every
            ``build_from_cfg`` call. Defaults to None.

    Returns:
        The module built from ``cfg``. When ``cfg`` is a list, an
        ``nn.Sequential`` wrapping the built modules in list order.
    """
    if not isinstance(cfg, list):
        return build_from_cfg(cfg, registry, default_args)

    return nn.Sequential(
        *(build_from_cfg(sub_cfg, registry, default_args) for sub_cfg in cfg))
示例#3
0
def test_multitask_gather():
    """MultitaskGatherTarget returns one target / target_weight per entry of
    ``pipeline_indices``, each produced by the indexed sub-pipeline."""
    num_joints = 17
    ann_info = dict(
        image_size=np.array([256, 256]),
        heatmap_size=np.array([64, 64]),
        num_joints=num_joints,
        joint_weights=np.ones((num_joints, 1), dtype=np.float32),
        use_different_joint_weights=False)

    results = dict(
        joints_3d=np.zeros([num_joints, 3]),
        joints_3d_visible=np.ones([num_joints, 3]),
        ann_info=ann_info)

    sub_pipelines = [
        [dict(type='TopDownGenerateTarget', sigma=2)],
        [dict(type='TopDownGenerateTargetRegression')],
    ]
    gather_cfg = dict(
        type='MultitaskGatherTarget',
        pipeline_list=sub_pipelines,
        pipeline_indices=[0, 1, 0],
    )
    gather = build_from_cfg(gather_cfg, PIPELINES)

    results = gather(results)
    target = results['target']
    target_weight = results['target_weight']
    assert isinstance(target, list)
    assert isinstance(target_weight, list)

    # (target shape, target_weight shape) expected at each output position
    expected_shapes = [
        ((num_joints, 64, 64), (num_joints, 1)),  # sub-pipeline 0
        ((num_joints, 2), (num_joints, 2)),  # sub-pipeline 1
        ((num_joints, 64, 64), (num_joints, 1)),  # sub-pipeline 0 again
    ]
    for pos, (target_shape, weight_shape) in enumerate(expected_shapes):
        assert target[pos].shape == target_shape
        assert target_weight[pos].shape == weight_shape
示例#4
0
文件: compose.py 项目: microsoft/CtP
 def __init__(self, transform_cfgs: List[dict]):
     """Collect the transforms to apply, in the given order.

     Entries that are already ``BaseTransform`` instances are kept as-is.
     Every other entry is treated as a config and built from ``TRANSFORMS``.
     """
     self.transforms = [
         cfg if isinstance(cfg, BaseTransform)
         else build_from_cfg(cfg, TRANSFORMS)
         for cfg in transform_cfgs
     ]  # type: List[BaseTransform]
示例#5
0
 def register_logger_hooks(self, log_config):
     """Build and register the logger hooks listed in ``log_config``.

     Args:
         log_config (dict | None): Must provide ``'interval'`` and
             ``'hooks'`` (a list of hook configs). Each hook is built from
             ``HOOKS`` with ``interval`` passed as a default argument, then
             registered with priority ``'VERY_LOW'``. ``None`` registers
             nothing.
     """
     if log_config is not None:
         interval = log_config['interval']
         for hook_cfg in log_config['hooks']:
             hook = mmcv.build_from_cfg(
                 hook_cfg, HOOKS, default_args=dict(interval=interval))
             self.register_hook(hook, priority='VERY_LOW')
示例#6
0
def test_photometric_distortion_transform():
    """An image loaded from disk is still uint8 after PhotometricDistortion."""
    results = dict(
        image_file=osp.join('tests/data/coco/', '000000000785.jpg'))

    # Two-step pipeline: read the image, then distort it.
    loader = build_from_cfg(dict(type='LoadImageFromFile'), PIPELINES)
    distort = build_from_cfg(dict(type='PhotometricDistortion'), PIPELINES)

    results = distort(loader(results))

    assert results['img'].dtype == np.uint8
示例#7
0
 def register_profiler_hook(self, profiler_config):
     """Register a profiler hook.

     Args:
         profiler_config (dict | hook | None): ``None`` registers nothing.
             A dict is built from ``HOOKS``; its ``'type'`` defaults to
             ``'ProfilerHook'``, and that default is written into the dict
             in place. Any other object is registered unchanged.
     """
     if profiler_config is None:
         return
     hook = profiler_config
     if isinstance(profiler_config, dict):
         profiler_config.setdefault('type', 'ProfilerHook')
         hook = mmcv.build_from_cfg(profiler_config, HOOKS)
     self.register_hook(hook)
def build_lr_hook(lr_config: Dict[Any, Any]) -> mmcv_hooks.LrUpdaterHook:
    """Build an LR updater hook from an mmcv-style ``lr_config``.

    The ``"policy"`` value selects the hook class ``<Policy>LrUpdaterHook``.
    An all-lowercase policy such as ``"step"`` is title-cased (``"Step"``).
    A policy that already contains capitals (e.g. ``"CosineAnnealing"``) is
    used unchanged. The remaining keys become the hook's build arguments.

    Args:
        lr_config: LR config; must contain ``"policy"``. It is not modified,
            so the same config can be passed more than once.

    Returns:
        The hook built from ``mmcv_hooks.HOOKS``.

    Raises:
        AssertionError: If ``"policy"`` is missing (when assertions are
            enabled).
    """
    assert "policy" in lr_config, "policy must be specified in lr_config"
    # Work on a shallow copy. Popping/overwriting keys on the caller's dict
    # would strip "policy" and make a second call with it fail.
    hook_cfg = dict(lr_config)
    policy_type = hook_cfg.pop("policy")
    if policy_type == policy_type.lower():
        policy_type = policy_type.title()
    hook_cfg["type"] = policy_type + "LrUpdaterHook"
    return mmcv.build_from_cfg(hook_cfg, mmcv_hooks.HOOKS)
示例#9
0
 def register_checkpoint_hook(self, checkpoint_config):
     """Register a checkpoint hook.

     Args:
         checkpoint_config (dict | hook | None): ``None`` registers nothing.
             A dict is built from ``HOOKS``; ``'type'`` defaults to
             ``'CheckpointHook'`` and is written into the dict in place.
             Any other object is registered unchanged.
     """
     if checkpoint_config is None:
         return
     if not isinstance(checkpoint_config, dict):
         self.register_hook(checkpoint_config)
         return
     checkpoint_config.setdefault('type', 'CheckpointHook')
     self.register_hook(mmcv.build_from_cfg(checkpoint_config, HOOKS))
示例#10
0
 def register_logger_hooks(self, log_config):
     """Build and register the logger hooks described by ``log_config``.

     Args:
         log_config (dict | None): Must provide ``'interval'`` and
             ``'hooks'`` (a list of hook configs). Each hook is built from
             ``HOOKS`` with ``interval`` and ``initial_config``
             (``self.things_to_log``) as default arguments, then registered
             with priority ``'VERY_LOW'``. ``None`` registers nothing.
     """
     # Without this guard a missing log config fails with an opaque
     # "'NoneType' object is not subscriptable" TypeError.
     if log_config is None:
         return
     log_interval = log_config['interval']
     for info in log_config['hooks']:
         logger_hook = mmcv.build_from_cfg(
             info,
             HOOKS,
             default_args=dict(interval=log_interval,
                               initial_config=self.things_to_log))
         self.register_hook(logger_hook, priority='VERY_LOW')
示例#11
0
 def register_lr_hook(self, lr_config):
     """Register a learning-rate updater hook.

     Args:
         lr_config (dict | hook | None): ``None`` registers nothing. A dict
             must contain ``'policy'``. The policy is popped and
             ``'<Policy>LrUpdaterHook'`` is stored under ``'type'`` before
             the hook is built from ``HOOKS``, so the dict is modified in
             place. Any other object is registered as-is.
     """
     if lr_config is None:
         return
     if isinstance(lr_config, dict):
         assert 'policy' in lr_config
         policy_type = lr_config.pop('policy')
         # Title-case only all-lowercase policies ('step' -> 'Step').
         # Calling str.title() unconditionally would turn CamelCase names
         # such as 'CosineAnnealing' into 'Cosineannealing'.
         if policy_type == policy_type.lower():
             policy_type = policy_type.title()
         lr_config['type'] = policy_type + 'LrUpdaterHook'
         hook = mmcv.build_from_cfg(lr_config, HOOKS)
     else:
         hook = lr_config
     self.register_hook(hook)
示例#12
0
 def register_timer_hook(self, timer_config):
     """Register a timer hook.

     Args:
         timer_config (dict | hook | None): ``None`` registers nothing. A
             dict is deep-copied and the copy is built from ``HOOKS``, so
             the caller's dict is left untouched. Any other object is
             registered unchanged.
     """
     if timer_config is None:
         return
     hook = (mmcv.build_from_cfg(copy.deepcopy(timer_config), HOOKS)
             if isinstance(timer_config, dict) else timer_config)
     self.register_hook(hook)
示例#13
0
 def register_optimizer_hook(self, optimizer_config):
     """Register an optimizer hook.

     Args:
         optimizer_config (dict | hook | None): ``None`` registers nothing.
             A dict is built from ``HOOKS``; ``'type'`` defaults to
             ``'OptimizerHook'`` and is written into the dict in place.
             Any other object is registered unchanged.
     """
     if optimizer_config is None:
         return
     if not isinstance(optimizer_config, dict):
         self.register_hook(optimizer_config)
         return
     optimizer_config.setdefault('type', 'OptimizerHook')
     self.register_hook(mmcv.build_from_cfg(optimizer_config, HOOKS))
示例#14
0
    def __init__(self,
                 info_path,
                 data_root,
                 rate,
                 prepare,
                 sample_groups,
                 classes=None,
                 points_loader=None):
        """Load the database infos, filter them and set up per-key samplers.

        Args:
            info_path (str): Path of the database info file, read with
                ``mmcv.load``. It is iterated as a mapping of key -> list
                (presumably class name -> infos; confirm).
            data_root (str): Stored as ``self.data_root``.
            rate: Stored as ``self.rate``; not used in this method.
            prepare (dict): Maps the name of a method of this class to its
                argument. Each is applied in order as
                ``db_infos = method(db_infos, value)``.
            sample_groups (dict): Maps a class name to its maximum number of
                samples; values are converted with ``int``.
            classes (list[str], optional): Class names. List positions
                define ``cat2label`` / ``label2cat``.
                NOTE(review): the ``None`` default is not usable as-is,
                because ``enumerate(None)`` below raises TypeError.
            points_loader (dict, optional): Config of the point loader built
                from ``PIPELINES``. ``None`` (the default) selects
                ``LoadPointsFromFile`` with ``coord_type='LIDAR'``,
                ``load_dim=4`` and ``use_dim=[0, 1, 2, 3]``.
        """
        super().__init__()
        if points_loader is None:
            # Build the default config per call instead of sharing a single
            # mutable dict object as the parameter's default value.
            points_loader = dict(
                type='LoadPointsFromFile',
                coord_type='LIDAR',
                load_dim=4,
                use_dim=[0, 1, 2, 3])
        self.data_root = data_root
        self.info_path = info_path
        self.rate = rate
        self.prepare = prepare
        self.classes = classes
        self.cat2label = {name: i for i, name in enumerate(classes)}
        self.label2cat = {i: name for i, name in enumerate(classes)}
        self.points_loader = mmcv.build_from_cfg(points_loader, PIPELINES)

        db_infos = mmcv.load(info_path)

        # filter database infos
        # Imported locally, presumably to avoid an import cycle (confirm).
        from mmdet3d.utils import get_root_logger
        logger = get_root_logger()
        for k, v in db_infos.items():
            logger.info(f'load {len(v)} {k} database infos')
        for prep_func, val in prepare.items():
            db_infos = getattr(self, prep_func)(db_infos, val)
        logger.info('After filter database:')
        for k, v in db_infos.items():
            logger.info(f'load {len(v)} {k} database infos')

        self.db_infos = db_infos

        # load sample groups
        # TODO: more elegant way to load sample groups
        # One single-entry {class_name: max_num} dict per group.
        self.sample_groups = [{name: int(num)}
                              for name, num in sample_groups.items()]

        self.group_db_infos = self.db_infos  # just use db_infos
        self.sample_classes = []
        self.sample_max_nums = []
        for group_info in self.sample_groups:
            self.sample_classes += list(group_info.keys())
            self.sample_max_nums += list(group_info.values())

        # One shuffling BatchSampler per database key.
        self.sampler_dict = {
            k: BatchSampler(v, k, shuffle=True)
            for k, v in self.group_db_infos.items()
        }
示例#15
0
 def register_momentum_hooks(self, momentum_config):
     """Register a momentum updater hook.

     Args:
         momentum_config (dict | hook | None): ``None`` registers nothing.
             A dict must contain ``'policy'``. The policy is popped and
             ``'<Policy>MomentumUpdaterHook'`` is stored under ``'type'``
             before the hook is built from ``HOOKS``, so the dict is
             modified in place. Any other object is registered as-is.
     """
     if momentum_config is None:
         return
     if isinstance(momentum_config, dict):
         assert 'policy' in momentum_config
         policy_type = momentum_config.pop('policy')
         # Title-case only all-lowercase policies ('cyclic' -> 'Cyclic').
         # Calling str.title() unconditionally would turn CamelCase names
         # such as 'CosineAnnealing' into 'Cosineannealing'.
         if policy_type == policy_type.lower():
             policy_type = policy_type.title()
         momentum_config['type'] = policy_type + 'MomentumUpdaterHook'
         hook = mmcv.build_from_cfg(momentum_config, HOOKS)
     else:
         hook = momentum_config
     self.register_hook(hook)
示例#16
0
 def register_lr_hook(self, lr_config):
     """Register a learning-rate updater hook.

     Args:
         lr_config (dict | hook | None): ``None`` registers nothing, as in
             the other ``register_*`` helpers. A dict must contain
             ``'policy'``. The policy is popped and
             ``'<Policy>LrUpdaterHook'`` is stored under ``'type'`` (with
             all-lowercase policies title-cased first) before the hook is
             built from ``HOOKS``. The dict is modified in place. Any other
             object is registered as-is.
     """
     if lr_config is None:
         return
     if isinstance(lr_config, dict):
         assert 'policy' in lr_config
         policy_type = lr_config.pop('policy')
         # Leave CamelCase names ('CosineAnnealing') untouched.
         if policy_type == policy_type.lower():
             policy_type = policy_type.title()
         lr_config['type'] = policy_type + 'LrUpdaterHook'
         hook = mmcv.build_from_cfg(lr_config, HOOKS)
     else:
         hook = lr_config
     self.register_hook(hook)
示例#17
0
    def __init__(self,
                 embed_dim,
                 visual: Union[SwinTransformer],
                 text: TextTransformer,
                 is_token_wise=False,
                 init_scale=1.,
                 pretrained=None):
        """Build the visual and text sub-models and optionally load weights.

        Args:
            embed_dim: Passed to both sub-models as ``output_dim``.
            visual: Built with ``build_from_cfg`` from ``MODELS`` into
                ``self.visual``.
                NOTE(review): annotated as a model class but passed to
                ``build_from_cfg`` like a config; confirm which is expected.
            text: Built the same way into ``self.transformer`` (the same note
                applies to its ``TextTransformer`` annotation).
            is_token_wise (bool): Stored as ``self.is_token_wise``.
            init_scale (float): Passed to both sub-models as ``init_scale``.
            pretrained: If truthy, handed to ``self.load_pretrain``.
        """
        super().__init__()

        visual_defaults = dict(output_dim=embed_dim, init_scale=init_scale)
        self.visual = build_from_cfg(
            visual, MODELS, default_args=visual_defaults)

        text_defaults = dict(output_dim=embed_dim, init_scale=init_scale)
        self.transformer = build_from_cfg(
            text, MODELS, default_args=text_defaults)

        self.is_token_wise = is_token_wise
        if pretrained:
            self.load_pretrain(pretrained)
示例#18
0
 def register_optimizer_hook(self,
                             optimizer_config,
                             priority='NORMAL',
                             optim_type="OptimHookB"):
     """Register an optimizer hook at the given priority.

     Args:
         optimizer_config (dict | hook | None): ``None`` registers nothing.
             A dict is built from ``HOOKS``; its ``'type'`` defaults to
             ``optim_type`` and that default is written into the dict in
             place. Any other object is registered unchanged.
         priority: Priority forwarded to ``self.register_hook``.
         optim_type (str): Hook type used when the dict has no ``'type'``.
     """
     if optimizer_config is None:
         return
     hook = optimizer_config
     if isinstance(optimizer_config, dict):
         optimizer_config.setdefault('type', optim_type)
         hook = mmcv.build_from_cfg(optimizer_config, HOOKS)
     self.register_hook(hook, priority)
示例#19
0
def test_albu_transform():
    """After LoadImageFromFile and an Albumentation chain ending in
    ``ToFloat``, the image is float32."""
    results = dict(
        image_file=osp.join('tests/data/coco/', '000000000785.jpg'))

    # Two-step pipeline: read the image, then apply Albumentations.
    loader = build_from_cfg(dict(type='LoadImageFromFile'), PIPELINES)
    albu_cfg = dict(
        type='Albumentation',
        transforms=[
            dict(type='RandomBrightnessContrast', p=0.2),
            dict(type='ToFloat'),
        ])
    albu = build_from_cfg(albu_cfg, PIPELINES)

    results = albu(loader(results))

    assert results['img'].dtype == np.float32
示例#20
0
 def register_momentum_hook(self, momentum_config):
     """Register a momentum updater hook.

     Args:
         momentum_config (dict | hook | None): ``None`` registers nothing.
             A dict must contain ``'policy'``. The policy is popped and
             ``'<Policy>MomentumUpdaterHook'`` is stored under ``'type'``
             (all-lowercase policies are title-cased first) before the hook
             is built from ``HOOKS``. The dict is modified in place. Any
             other object is registered as-is.
     """
     if momentum_config is None:
         return
     if not isinstance(momentum_config, dict):
         self.register_hook(momentum_config)
         return
     assert 'policy' in momentum_config
     policy = momentum_config.pop('policy')
     # CamelCase policy names are kept; lowercase ones are capitalised.
     if policy.lower() == policy:
         policy = policy.title()
     momentum_config['type'] = policy + 'MomentumUpdaterHook'
     self.register_hook(mmcv.build_from_cfg(momentum_config, HOOKS))
示例#21
0
def test_rename_keys():
    """RenameKeys moves every value from its old key to its new key."""
    results = dict(
        joints_3d=np.ones([17, 3]), joints_3d_visible=np.ones([17, 3]))
    renames = (('joints_3d', 'target'),
               ('joints_3d_visible', 'target_weight'))
    rename = build_from_cfg(
        dict(type='RenameKeys', key_pairs=list(renames)), PIPELINES)
    results = rename(results)

    for old_key, new_key in renames:
        assert old_key not in results
        assert new_key in results
        assert results[new_key].shape == (17, 3)
示例#22
0
    def register_hook_from_cfg(self, hook_cfg):
        """Build a hook from ``hook_cfg`` and register it.

        Args:
            hook_cfg (dict): Hook config. ``'type'`` selects the hook class
                in ``HOOKS``. The optional ``'priority'`` key (default
                ``'NORMAL'``) is removed before building and used as the
                registration priority. Only a shallow copy is edited, so the
                caller's top-level keys are left untouched.

        Notes:
            The hook class must not expect ``type`` or ``priority`` as
            ``__init__`` arguments.
        """
        cfg = hook_cfg.copy()
        priority = cfg.pop('priority', 'NORMAL')
        self.register_hook(mmcv.build_from_cfg(cfg, HOOKS), priority=priority)
示例#23
0
 def register_lr_hook(self, lr_config):
     """Register a learning-rate updater hook.

     Args:
         lr_config (dict | hook | None): ``None`` registers nothing, as in
             the other ``register_*`` helpers. A dict must contain
             ``'policy'``; it is turned into ``'<Policy>LrUpdaterHook'``
             and built from ``HOOKS``, and the dict is modified in place.
             Any other object is registered as-is.
     """
     if lr_config is None:
         return
     if isinstance(lr_config, dict):
         assert 'policy' in lr_config
         policy_type = lr_config.pop('policy')
         # If the type of policy is all in lower case, e.g., 'cyclic',
         # then its first letter will be capitalized, e.g., to be 'Cyclic'.
         # This is for the convenient usage of Lr updater.
         # Since this is not applicable for `CosineAnnealingLrUpdater`,
         # the string will not be changed if it contains capital letters.
         if policy_type == policy_type.lower():
             policy_type = policy_type.title()
         hook_type = policy_type + 'LrUpdaterHook'
         lr_config['type'] = hook_type
         hook = mmcv.build_from_cfg(lr_config, HOOKS)
     else:
         hook = lr_config
     self.register_hook(hook)
示例#24
0
 def register_momentum_hook(self, momentum_config):
     """Register a momentum updater hook with priority 30.

     Args:
         momentum_config (dict | hook | None): ``None`` registers nothing.
             A dict must contain ``'policy'``. The policy is popped and
             ``'<Policy>MomentumUpdaterHook'`` is stored under ``'type'``
             before the hook is built from ``HOOKS``, so the dict is
             modified in place. Any other object is registered as-is.
     """
     if momentum_config is None:
         return
     hook = momentum_config
     if isinstance(momentum_config, dict):
         assert 'policy' in momentum_config
         policy_type = momentum_config.pop('policy')
         # An all-lowercase policy such as 'cyclic' becomes 'Cyclic' for
         # convenience. Names that already contain capitals, e.g. the one
         # behind `CosineAnnealingMomentumUpdater`, are kept verbatim
         # because str.title() would break them.
         if policy_type.lower() == policy_type:
             policy_type = policy_type.title()
         momentum_config['type'] = policy_type + 'MomentumUpdaterHook'
         hook = mmcv.build_from_cfg(momentum_config, HOOKS)
     self.register_hook(hook, priority=30)
示例#25
0
def build_transform(cfg: dict, default_args: dict = None):
    """Build a transform from ``cfg`` using the ``TRANSFORMS`` registry.

    Args:
        cfg (dict): Transform config.
        default_args (dict, optional): Forwarded to ``build_from_cfg`` as
            ``default_args``. Defaults to None.

    Returns:
        The object ``build_from_cfg`` builds for ``cfg``.
    """
    transform = build_from_cfg(cfg, TRANSFORMS, default_args=default_args)
    return transform
示例#26
0
def build_model(cfg: dict, default_args: dict = None):
    """Build a model from ``cfg`` using the ``MODELS`` registry.

    Args:
        cfg (dict): Model config.
        default_args (dict, optional): Forwarded to ``build_from_cfg`` as
            ``default_args``. Defaults to None.

    Returns:
        The object ``build_from_cfg`` builds for ``cfg``.
    """
    model = build_from_cfg(cfg, MODELS, default_args=default_args)
    return model
示例#27
0
def _dist_train(model,
                dataset,
                cfg,
                validate=False,
                logger=None,
                timestamp=None,
                meta=None):
    """Distributed training function.

    Wraps ``model`` in ``DistributedDataParallelWrapper``, builds an
    ``IterBasedRunner`` with its hooks, optionally adds visual and
    evaluation hooks, resumes or loads a checkpoint if configured, then
    runs training.

    Args:
        model (nn.Module): The model to be trained.
        dataset (:obj:`Dataset` | list | tuple): Train dataset(s).
        cfg (dict): The config dict for training.
        validate (bool): Whether to do evaluation. Default: False.
        logger (logging.Logger | None): Logger for training. Default: None.
        timestamp (str | None): Local time for runner. Default: None.
        meta (dict | None): Meta dict to record some important information.
            Default: None.
    """
    # One distributed data loader per training dataset.
    train_sets = dataset if isinstance(dataset, (list, tuple)) else [dataset]
    data_loaders = [
        build_dataloader(train_set,
                         cfg.data.samples_per_gpu,
                         cfg.data.workers_per_gpu,
                         dist=True,
                         drop_last=cfg.data.get('drop_last', False),
                         seed=cfg.seed)
        for train_set in train_sets
    ]

    # Wrap the model for distributed training on the current CUDA device.
    ddp_model = DistributedDataParallelWrapper(
        model,
        device_ids=[torch.cuda.current_device()],
        broadcast_buffers=False,
        find_unused_parameters=cfg.get('find_unused_parameters', False))

    optimizer = build_optimizers(ddp_model, cfg.optimizers)
    runner = IterBasedRunner(ddp_model,
                             optimizer=optimizer,
                             work_dir=cfg.work_dir,
                             logger=logger,
                             meta=meta)
    # Workaround so the .log and .log.json file names share one timestamp.
    runner.timestamp = timestamp

    runner.register_training_hooks(cfg.lr_config,
                                   checkpoint_config=cfg.checkpoint_config,
                                   log_config=cfg.log_config)

    # Visual hook: its output_dir is re-rooted under work_dir (in place).
    if cfg.get('visual_config', None) is not None:
        cfg.visual_config['output_dir'] = os.path.join(
            cfg.work_dir, cfg.visual_config['output_dir'])
        runner.register_hook(mmcv.build_from_cfg(cfg.visual_config, HOOKS))

    # Evaluation hook on the validation split.
    if validate and cfg.get('evaluation', None) is not None:
        val_dataset = build_dataset(cfg.data.val)
        val_loader = build_dataloader(
            val_dataset,
            samples_per_gpu=cfg.data.get('val_samples_per_gpu',
                                         cfg.data.samples_per_gpu),
            workers_per_gpu=cfg.data.get('val_workers_per_gpu',
                                         cfg.data.workers_per_gpu),
            dist=True,
            shuffle=False)
        runner.register_hook(
            DistEvalIterHook(val_loader,
                             save_path=osp.join(cfg.work_dir, 'val_visuals'),
                             **cfg.evaluation))

    # resume_from takes precedence over load_from.
    if cfg.resume_from:
        runner.resume(cfg.resume_from)
    elif cfg.load_from:
        runner.load_checkpoint(cfg.load_from)
    runner.run(data_loaders, cfg.workflow, cfg.total_iters)
示例#28
0
def build_dataset(cfg: dict, default_args: dict = None):
    """Build a dataset from ``cfg`` using the ``DATASETS`` registry.

    Args:
        cfg (dict): Dataset config.
        default_args (dict, optional): Forwarded to ``build_from_cfg`` as
            ``default_args``. Defaults to None.

    Returns:
        The object ``build_from_cfg`` builds for ``cfg``.
    """
    dataset = build_from_cfg(cfg, DATASETS, default_args=default_args)
    return dataset
示例#29
0
def test_build_from_cfg():
    """``mmcv.build_from_cfg`` builds registered classes (with ``type`` given
    as a name or a class, in ``cfg`` or in ``default_args``) and raises on
    malformed configs, registries and arguments."""
    registry = mmcv.Registry('backbone')

    @registry.register_module()
    class ResNet:

        def __init__(self, depth, stages=4):
            self.depth = depth
            self.stages = stages

    @registry.register_module()
    class ResNeXt:

        def __init__(self, depth, stages=4):
            self.depth = depth
            self.stages = stages

    # (cfg, default_args, expected class, expected stages); depth is 50
    valid_cases = [
        (dict(type='ResNet', depth=50), None, ResNet, 4),
        (dict(type='ResNet', depth=50), {'stages': 3}, ResNet, 3),
        (dict(type='ResNeXt', depth=50, stages=3), None, ResNeXt, 3),
        (dict(type=ResNet, depth=50), None, ResNet, 4),
        # type defined using default_args
        (dict(depth=50), dict(type='ResNet'), ResNet, 4),
        (dict(depth=50), dict(type=ResNet), ResNet, 4),
    ]
    for cfg, defaults, expected_cls, expected_stages in valid_cases:
        model = mmcv.build_from_cfg(cfg, registry, default_args=defaults)
        assert isinstance(model, expected_cls)
        assert model.depth == 50 and model.stages == expected_stages

    missing_type = 'must contain the key "type"'
    # (expected exception, match pattern, cfg, registry, default_args)
    invalid_cases = [
        # not a registry
        (TypeError, None, dict(type='VGG'), 'BACKBONES', None),
        # non-registered class
        (KeyError, None, dict(type='VGG'), registry, None),
        # default_args must be a dict or None
        (TypeError, None, dict(type='ResNet', depth=50), registry, 1),
        # cfg['type'] should be a str or class
        (TypeError, None, dict(type=1000), registry, None),
        # cfg should contain the key "type"
        (KeyError, missing_type, dict(depth=50, stages=4), registry, None),
        # cfg or default_args should contain the key "type"
        (KeyError, missing_type, dict(depth=50), registry, dict(stages=4)),
        # incorrect registry type
        (TypeError, None, dict(type='ResNet', depth=50), 'BACKBONES', None),
        # incorrect default_args type
        (TypeError, None, dict(type='ResNet', depth=50), registry, 0),
        # incorrect arguments
        (TypeError, None, dict(type='ResNet', non_existing_arg=50), registry,
         None),
    ]
    for exc_type, pattern, cfg, reg, defaults in invalid_cases:
        with pytest.raises(exc_type, match=pattern):
            mmcv.build_from_cfg(cfg, reg, default_args=defaults)
示例#30
0
def build_backbone(cfg: dict, default_args: dict = None):
    """Build a backbone from ``cfg`` using the ``BACKBONES`` registry.

    Args:
        cfg (dict): Backbone config.
        default_args (dict, optional): Forwarded to ``build_from_cfg`` as
            ``default_args``. Defaults to None.

    Returns:
        The object ``build_from_cfg`` builds for ``cfg``.
    """
    backbone = build_from_cfg(cfg, BACKBONES, default_args=default_args)
    return backbone