示例#1
0
    def __call__(self, model):
        """Build an optimizer for ``model`` from the stored optimizer config.

        When ``model`` has a ``module`` attribute, that attribute is used in
        its place (presumably to unwrap a parallel wrapper -- confirm against
        callers).

        Args:
            model: Object exposing ``parameters()``; it is also handed to
                ``self.add_params`` when ``self.paramwise_cfg`` is truthy.

        Returns:
            The object produced by ``build_from_cfg`` for the completed
            config and the ``OPTIMIZERS`` registry.
        """
        model = model.module if hasattr(model, 'module') else model

        # Work on a copy so the stored config never receives a 'params' key.
        cfg = self.optimizer_cfg.copy()

        if self.paramwise_cfg:
            # Param-wise options are set: let add_params fill the list
            # (per the original note: lr and weight decay, recursively).
            collected = []
            self.add_params(collected, model)
            cfg['params'] = collected
        else:
            # No param-wise option: every parameter uses the global setting.
            cfg['params'] = model.parameters()

        return build_from_cfg(cfg, OPTIMIZERS)
示例#2
0
 def __init__(self, transforms):
     """Collect the given transforms into ``self.transforms``.

     Args:
         transforms (Sequence): Each item is either a dict, which is built
             through ``build_from_cfg`` with the ``PIPELINES`` registry, or
             a callable, which is stored unchanged.

     Raises:
         TypeError: If an item is neither a dict nor callable.
     """
     assert isinstance(transforms, collections.abc.Sequence)
     self.transforms = []
     for item in transforms:
         if isinstance(item, dict):
             # Config dicts are turned into objects via the registry.
             self.transforms.append(build_from_cfg(item, PIPELINES))
             continue
         if not callable(item):
             raise TypeError('transform must be callable or a dict')
         self.transforms.append(item)
示例#3
0
def build(cfg, registry, default_args=None):
    """Build a module, or a sequence of modules, from config.

    Args:
        cfg (dict, list[dict]): The config of the module(s); either a single
            config or a list of configs.
        registry (:obj:`Registry`): A registry the module belongs to.
        default_args (dict, optional): Default arguments to build the module.
            Defaults to None.

    Returns:
        nn.Module: For a list ``cfg``, an ``nn.Sequential`` wrapping the
        items built in order; otherwise whatever ``build_from_cfg`` returns
        for the single config.
    """
    if not isinstance(cfg, list):
        return build_from_cfg(cfg, registry, default_args)

    # Build every entry in order, then chain them into one container.
    layers = []
    for item_cfg in cfg:
        layers.append(build_from_cfg(item_cfg, registry, default_args))
    return nn.Sequential(*layers)
示例#4
0
def build_dataset(cfg, default_args=None):
    """Build a dataset from ``cfg``, choosing one of three construction paths.

    Args:
        cfg (dict): Dataset config. ``cfg['type']`` is always read;
            ``cfg['dataset']`` and ``cfg['times']`` are read for the
            ``'RepeatDataset'`` type, and ``cfg.get('ann_file')`` otherwise.
        default_args (dict, optional): Forwarded unchanged to whichever
            builder is used. Defaults to None.

    Returns:
        The dataset object produced by the selected path.
    """
    from .dataset_wrappers import RepeatDataset

    if cfg['type'] == 'RepeatDataset':
        # Repeat the wrapped dataset `times` times. Mainly for small
        # datasets, where each epoch would otherwise be too short and
        # training would take longer.
        wrapped = build_dataset(cfg['dataset'], default_args)
        return RepeatDataset(wrapped, cfg['times'])

    if isinstance(cfg.get('ann_file'), (list, tuple)):
        # Multiple annotation files.
        return _concat_dataset(cfg, default_args)

    return build_from_cfg(cfg, DATASETS, default_args)
示例#5
0
def build_anchor_generator(cfg, default_args=None):
    """Builder of anchor generator.

    Thin wrapper that forwards ``cfg`` and ``default_args`` to
    ``build_from_cfg`` together with the ``ANCHOR_GENERATORS`` registry.

    Args:
        cfg: Config passed through unchanged (presumably a dict with a
            ``type`` key -- confirm against ``build_from_cfg``).
        default_args (optional): Passed through unchanged. Defaults to None.

    Returns:
        Whatever ``build_from_cfg`` returns.
    """
    return build_from_cfg(cfg, ANCHOR_GENERATORS, default_args)
示例#6
0
def build_iou_calculator(cfg, default_args=None):
    """Builder of IoU calculator.

    Thin wrapper that forwards ``cfg`` and ``default_args`` to
    ``build_from_cfg`` together with the ``IOU_CALCULATORS`` registry.

    Args:
        cfg: Config passed through unchanged (presumably a dict with a
            ``type`` key -- confirm against ``build_from_cfg``).
        default_args (optional): Passed through unchanged. Defaults to None.

    Returns:
        Whatever ``build_from_cfg`` returns.
    """
    return build_from_cfg(cfg, IOU_CALCULATORS, default_args)
示例#7
0
def build_assigner(cfg, **default_args):
    """Builder of box assigner.

    Extra keyword arguments are collected into the ``default_args`` dict and
    forwarded as the third positional argument of ``build_from_cfg``, next to
    the ``BBOX_ASSIGNERS`` registry.

    Args:
        cfg: Config passed through unchanged (presumably a dict with a
            ``type`` key -- confirm against ``build_from_cfg``).
        **default_args: Arbitrary keyword arguments, passed on as one dict
            (an empty dict when none are given).

    Returns:
        Whatever ``build_from_cfg`` returns.
    """
    return build_from_cfg(cfg, BBOX_ASSIGNERS, default_args)
示例#8
0
def build_bbox_coder(cfg, **default_args):
    """Builder of box coder.

    Extra keyword arguments are collected into the ``default_args`` dict and
    forwarded as the third positional argument of ``build_from_cfg``, next to
    the ``BBOX_CODERS`` registry.

    Args:
        cfg: Config passed through unchanged (presumably a dict with a
            ``type`` key -- confirm against ``build_from_cfg``).
        **default_args: Arbitrary keyword arguments, passed on as one dict
            (an empty dict when none are given).

    Returns:
        Whatever ``build_from_cfg`` returns.
    """
    return build_from_cfg(cfg, BBOX_CODERS, default_args)
示例#9
0
def build_sampler(cfg, **default_args):
    """Builder of box sampler.

    Extra keyword arguments are collected into the ``default_args`` dict and
    forwarded as the third positional argument of ``build_from_cfg``, next to
    the ``BBOX_SAMPLERS`` registry.

    Args:
        cfg: Config passed through unchanged (presumably a dict with a
            ``type`` key -- confirm against ``build_from_cfg``).
        **default_args: Arbitrary keyword arguments, passed on as one dict
            (an empty dict when none are given).

    Returns:
        Whatever ``build_from_cfg`` returns.
    """
    return build_from_cfg(cfg, BBOX_SAMPLERS, default_args)