Example #1
 def register_checkpoint_hook(self, checkpoint_config):
     if isinstance(checkpoint_config, dict):
         checkpoint_config.setdefault('type', 'CheckpointHook')
         hook = build_from_cfg(checkpoint_config, HOOKS)
     else:
         hook = checkpoint_config
     self.register_hook(hook)
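A rough usage sketch (the interval value and the direct CheckpointHook instance are illustrative; the method accepts either a config dict or an already-built hook):

# pass a config dict: 'type' defaults to 'CheckpointHook' and the hook is built via build_from_cfg
runner.register_checkpoint_hook(dict(interval=1))

# or pass an already constructed hook object, which is registered unchanged
runner.register_checkpoint_hook(CheckpointHook(interval=1))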
Example #2
def build_dataset(cfg, default_args=None):
    """Build a dataset from config dict.

    It supports a variety of dataset configs. If ``cfg`` is a sequence (list
    or tuple), it will be a concatenated dataset of the datasets specified by
    the sequence. If it is a ``RepeatDataset``, it will repeat the dataset
    ``cfg['dataset']`` for ``cfg['times']`` times. If the ``ann_file`` of the
    dataset is a sequence, it will build a concatenated dataset with the same
    dataset type but different ``ann_file`` values.

    Args:
        cfg (dict): Config dict. It should at least contain the key "type".
        default_args (dict, optional): Default initialization arguments.
            Default: None.

    Returns:
        Dataset: The constructed dataset.
    """
    if isinstance(cfg, (list, tuple)):
        raise NotImplementedError(
            "list (tuple) configs are not supported for dataset building yet")
    elif cfg['type'] == 'RepeatDataset':
        dataset = RepeatDataset(build_dataset(cfg['dataset'], default_args),
                                cfg['times'])
    elif isinstance(cfg.get('ann_file'), (list, tuple)):
        raise NotImplementedError(
            "list (tuple) ann_files are not supported for dataset building yet")
    else:
        dataset = build_from_cfg(cfg, DATASETS, default_args)

    return dataset
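As a usage sketch (the dataset type 'MyDataset' and the file names are hypothetical placeholders, not taken from the code above):

# a plain config is dispatched straight to build_from_cfg(cfg, DATASETS)
base_cfg = dict(type='MyDataset', ann_file='train_list.txt', pipeline=[])
train_set = build_dataset(base_cfg)

# wrapping it in a RepeatDataset builds the inner dataset and repeats it 'times' times
repeat_cfg = dict(type='RepeatDataset', dataset=base_cfg, times=10)
train_set = build_dataset(repeat_cfg)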
Example #3
 def register_optimizer_hook(self, optimizer_config):
     if optimizer_config is None:
         return
     if isinstance(optimizer_config, dict):
         optimizer_config.setdefault('type', 'OptimizerHook')
         hook = build_from_cfg(optimizer_config, HOOKS)
     else:
         hook = optimizer_config
     self.register_hook(hook)
Example #4
    def __call__(self, model):
        optimizer_cfg = self.optimizer_cfg.copy()
        # if no paramwise option is specified, just use the global setting
        logger = get_root_logger()
        # count and log the total number of model parameters
        param_nums = 0
        for item in model.parameters():
            param_nums += np.prod(np.array(item.shape))
        logger.info("model {}: total parameter count: {}".format(
            model.__class__.__name__, param_nums))

        if not self.paramwise_cfg:
            optimizer_cfg['params'] = model.parameters()
            return build_from_cfg(optimizer_cfg, OPTIMIZERS)
        else:
            raise NotImplementedError("paramwise_cfg is not implemented yet")

            # unreachable while the NotImplementedError above is raised; kept
            # as a reference: set param-wise lr and weight decay recursively
            params = []
            self.add_params(params, model)
            optimizer_cfg['params'] = params

            return build_from_cfg(optimizer_cfg, OPTIMIZERS)
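A minimal usage sketch, assuming the enclosing class is an optimizer constructor whose __init__ stores optimizer_cfg and paramwise_cfg (the class name and the hyper-parameters below are assumptions for illustration):

# assumed signature: OptimizerConstructor(optimizer_cfg, paramwise_cfg)
optimizer_cfg = dict(type='SGD', lr=0.01, momentum=0.9, weight_decay=1e-4)
constructor = OptimizerConstructor(optimizer_cfg, paramwise_cfg=None)
optimizer = constructor(model)  # logs the parameter count, then builds from OPTIMIZERS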
Example #5
 def __init__(self, transforms):
     assert isinstance(transforms, Sequence)
     self.transforms = []
     for transform in transforms:
         if isinstance(transform, dict):
             transform = build_from_cfg(transform, PIPELINES)
             self.transforms.append(transform)
         elif callable(transform):
             self.transforms.append(transform)
         else:
             raise TypeError(f'transform must be callable or a dict, '
                             f'but got {type(transform)}')
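A brief sketch of how such a pipeline is assembled (assuming the enclosing class is the usual Compose; 'MyTransform' is a hypothetical registered pipeline step):

pipeline = Compose([
    dict(type='MyTransform', scale=0.5),  # dict entries are built via build_from_cfg(..., PIPELINES)
    lambda results: results,              # any plain callable is accepted as-is
])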
Example #6
def train(model, datasets, cfg, rank):
    data_loaders = []
    for ds in datasets:
        data_loaders.append(get_loader(ds, cfg, 'train'))

    # build runner for training
    if cfg.get('total_iters', None) is not None:
        runner = IterBasedRunner(model=model,
                                 optimizers_cfg=cfg.optimizers,
                                 work_dir=cfg.work_dir)
        total_iters_or_epochs = cfg.total_iters
    else:
        runner = EpochBasedRunner(model=model,
                                  optimizers_cfg=cfg.optimizers,
                                  work_dir=cfg.work_dir)
        assert cfg.get('total_epochs', None) is not None
        total_iters_or_epochs = cfg.total_epochs

    # resume and create optimizers
    if cfg.resume_from is not None:
        # resume a previous run (restores both the model parameters and the optimizers)
        runner.resume(cfg.resume_from, cfg.get('resume_optim', False))
    elif cfg.load_from is not None:
        # pretend to train from scratch: rank 0 loads the parameters, then each
        # process creates its optimizers; the model parameters are synchronized
        # automatically when the optimizers are initialized
        runner.load_checkpoint(cfg.load_from, load_optim=False)
        runner.create_optimizers()
    else:
        # load no parameters; each process creates its optimizers directly
        runner.create_optimizers()

    # register hooks
    runner.register_training_hooks(lr_config=cfg.lr_config,
                                   checkpoint_config=cfg.checkpoint_config,
                                   log_config=cfg.log_config)

    # visual hook
    if cfg.get('visual_config', None) is not None:
        cfg.visual_config['output_dir'] = os.path.join(
            cfg.work_dir, cfg.visual_config['output_dir'])
        runner.register_hook(build_from_cfg(cfg.visual_config, HOOKS))

    # evaluation hook
    if cfg.get('evaluation', None) is not None:
        dataset = build_dataset(cfg.data.eval)
        save_path = os.path.join(cfg.work_dir, 'eval_visuals')
        log_path = cfg.work_dir
        runner.register_hook(
            EvalIterHook(get_loader(dataset, cfg, 'eval'),
                         save_path=save_path,
                         log_path=log_path,
                         **cfg.evaluation))

    runner.run(data_loaders, cfg.workflow, total_iters_or_epochs)
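For orientation, a sketch of the fields this train() reads from cfg (attribute access suggests an mmcv-style Config object; every value below is illustrative, and optional keys such as visual_config and evaluation are simply omitted here):

cfg = Config(dict(
    work_dir='./work_dirs/demo',
    optimizers=dict(type='Adam', lr=1e-4),        # forwarded as optimizers_cfg to the runner
    total_iters=100000,                           # present -> IterBasedRunner; absent -> total_epochs + EpochBasedRunner
    resume_from=None,
    load_from=None,
    lr_config=dict(policy='Step', step=[50000]),
    checkpoint_config=dict(interval=5000),
    log_config=dict(interval=100, hooks=[dict(type='TextLoggerHook')]),
    workflow=[('train', 1)],
))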
Example #7
def build(cfg, registry, default_args=None):
    """Build module function.

    Args:
        cfg (dict): Configuration for building modules.
        registry (obj): The registry to build the module from.
        default_args (dict, optional): Default arguments. Defaults to None.
    """
    if isinstance(cfg, list):
        raise NotImplementedError("list of cfg does not support now")
        # modules = [
        #     build_from_cfg(cfg_, registry, default_args) for cfg_ in cfg
        # ]
        # return Sequential(*modules)
    else:
        return build_from_cfg(cfg, registry, default_args)
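Usage mirrors build_from_cfg for a single config dict (the type name and the registry pairing here are placeholders):

dataset = build(dict(type='MyDataset', ann_file='train_list.txt'), DATASETS)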
Example #8
 def register_lr_hook(self, lr_config):
     if isinstance(lr_config, dict):
         assert 'policy' in lr_config
         policy_type = lr_config.pop('policy')
         # If the type of policy is all in lower case, e.g., 'cyclic',
         # then its first letter will be capitalized, e.g., to be 'Cyclic'.
         # This is for the convenient usage of Lr updater.
         # Since this is not applicable for `CosineAnnealingLrUpdater`,
         # the string will not be changed if it contains capital letters.
         if policy_type == policy_type.lower():
             policy_type = policy_type.title()
         hook_type = policy_type + 'LrUpdaterHook'
         lr_config['type'] = hook_type
         hook = build_from_cfg(lr_config, HOOKS)
     else:
         hook = lr_config
     self.register_hook(hook)
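Two sketches of the policy-name handling (assuming hooks with these names are registered; the extra keyword arguments are illustrative):

# 'step' is all lower case, so it is title-cased -> type becomes 'StepLrUpdaterHook'
runner.register_lr_hook(dict(policy='step', step=[30, 60], gamma=0.1))

# 'CosineAnnealing' already contains capitals, so it is kept as-is
# -> type becomes 'CosineAnnealingLrUpdaterHook'
runner.register_lr_hook(dict(policy='CosineAnnealing', min_lr=0))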
Example #9
def build_optimizer_constructor(cfg):
    return build_from_cfg(cfg, OPTIMIZER_BUILDERS)
Example #10
 def register_logger_hooks(self, log_config):
     log_interval = log_config['interval']
     for info in log_config['hooks']:
         logger_hook = build_from_cfg(
             info, HOOKS, default_args=dict(interval=log_interval))
         self.register_hook(logger_hook, priority='HIGH')
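A sketch of the expected log_config shape (the hook types are illustrative; each entry receives the shared interval through default_args and is registered with 'HIGH' priority):

log_config = dict(
    interval=100,                            # shared logging interval, injected via default_args
    hooks=[
        dict(type='TextLoggerHook'),         # e.g. a plain text/console logger
        dict(type='TensorboardLoggerHook'),  # e.g. a TensorBoard logger
    ])
runner.register_logger_hooks(log_config)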