def register_checkpoint_hook(self, checkpoint_config):
    """Register the checkpoint-saving hook on this runner.

    Args:
        checkpoint_config (dict | Hook | None): Either a config dict for a
            hook in the ``HOOKS`` registry (``type`` defaults to
            ``'CheckpointHook'``; the default is written into the given
            dict), or an already-constructed hook object. ``None`` means
            no checkpoint hook is registered.
    """
    if checkpoint_config is None:
        # Treat a missing config as "no checkpoint hook" rather than
        # forwarding None to register_hook.
        return
    if isinstance(checkpoint_config, dict):
        checkpoint_config.setdefault('type', 'CheckpointHook')
        hook = build_from_cfg(checkpoint_config, HOOKS)
    else:
        hook = checkpoint_config
    self.register_hook(hook)
def build_dataset(cfg, default_args=None):
    """Build a dataset from a config dict.

    Supported configs:

    * ``cfg['type'] == 'RepeatDataset'``: the inner dataset
      ``cfg['dataset']`` is built recursively (with the same
      ``default_args``) and wrapped in ``RepeatDataset`` with
      ``cfg['times']``.
    * Any other ``type``: built from the ``DATASETS`` registry via
      ``build_from_cfg``.

    Not supported yet (raise ``NotImplementedError``):

    * ``cfg`` given as a list/tuple of configs (concatenated datasets).
    * ``cfg['ann_file']`` given as a list/tuple (one dataset per
      annotation file).

    Args:
        cfg (dict): Config dict. It should at least contain the key "type".
        default_args (dict, optional): Default initialization arguments.
            Default: None.

    Returns:
        Dataset: The constructed dataset.

    Raises:
        NotImplementedError: For the unsupported config forms listed above.
    """
    if isinstance(cfg, (list, tuple)):
        raise NotImplementedError(
            "does not support list(tuple) configs for dataset build now")
    if cfg['type'] == 'RepeatDataset':
        return RepeatDataset(build_dataset(cfg['dataset'], default_args),
                             cfg['times'])
    if isinstance(cfg.get('ann_file'), (list, tuple)):
        raise NotImplementedError(
            "does not support list(tuple) ann_files for dataset build now")
    return build_from_cfg(cfg, DATASETS, default_args)
def register_optimizer_hook(self, optimizer_config):
    """Register the optimizer-step hook, if one is configured.

    Args:
        optimizer_config (dict | Hook | None): ``None`` skips registration.
            A dict is built through the ``HOOKS`` registry, with ``type``
            filled in as ``'OptimizerHook'`` when absent (the default is
            written into the given dict). Anything else is registered
            as-is.
    """
    if optimizer_config is None:
        return
    hook = optimizer_config
    if isinstance(optimizer_config, dict):
        optimizer_config.setdefault('type', 'OptimizerHook')
        hook = build_from_cfg(optimizer_config, HOOKS)
    self.register_hook(hook)
def __call__(self, model):
    """Build the optimizer for ``model`` from ``self.optimizer_cfg``.

    Logs the model's total parameter count, then builds an optimizer from
    the ``OPTIMIZERS`` registry over all of ``model.parameters()`` using
    the global settings in ``self.optimizer_cfg``.

    Args:
        model: Model whose ``parameters()`` are counted and optimized.

    Returns:
        The optimizer produced by ``build_from_cfg``.

    Raises:
        NotImplementedError: If ``self.paramwise_cfg`` is set; per-parameter
            settings are not supported yet.
    """
    # Work on a copy so the stored config does not gain a 'params' entry.
    optimizer_cfg = self.optimizer_cfg.copy()
    logger = get_root_logger()
    # int(...) keeps the count integral: np.prod of an empty shape (a 0-d
    # parameter) is the float 1.0, which would turn the total into a float.
    param_nums = sum(int(np.prod(item.shape)) for item in model.parameters())
    logger.info("model: {} 's total parameter nums: {}".format(
        model.__class__.__name__, param_nums))
    if self.paramwise_cfg:
        # TODO: param-wise lr / weight decay (collecting param groups via
        # self.add_params) is not supported yet.
        raise NotImplementedError("paramwise_cfg not implemented now")
    # No paramwise option is specified: every parameter uses the global
    # setting.
    optimizer_cfg['params'] = model.parameters()
    return build_from_cfg(optimizer_cfg, OPTIMIZERS)
def __init__(self, transforms):
    """Store ``transforms`` as a list of callables.

    Args:
        transforms (Sequence[dict | callable]): Pipeline steps. Dict
            entries are built through the ``PIPELINES`` registry; callable
            entries are kept unchanged.

    Raises:
        TypeError: If an entry is neither a dict nor callable. Entries
            processed before the bad one remain in ``self.transforms``.
    """
    assert isinstance(transforms, Sequence)
    self.transforms = []
    for step in transforms:
        if not isinstance(step, dict):
            if not callable(step):
                raise TypeError(f'transform must be callable or a dict, '
                                f'but got {type(step)}')
            self.transforms.append(step)
            continue
        self.transforms.append(build_from_cfg(step, PIPELINES))
def train(model, datasets, cfg, rank):
    """Build a runner for ``model``, register hooks and run training.

    Args:
        model: The model to train; passed to the runner constructor.
        datasets: Iterable of training datasets; one loader is built per
            dataset with ``get_loader(ds, cfg, 'train')``.
        cfg: Experiment config. Entries read here: ``total_iters`` or
            ``total_epochs``, ``optimizers``, ``work_dir``, ``resume_from``,
            ``resume_optim`` (optional, default False), ``load_from``,
            ``lr_config``, ``checkpoint_config``, ``log_config``,
            ``visual_config`` (optional), ``evaluation`` (optional),
            ``data.eval`` (only when ``evaluation`` is set) and ``workflow``.
        rank: Not used in this function body.
    """
    data_loaders = []
    for ds in datasets:
        data_loaders.append(get_loader(ds, cfg, 'train'))
    # build runner for training
    # Iteration-based when cfg.total_iters is set; otherwise epoch-based,
    # which then requires cfg.total_epochs.
    if cfg.get('total_iters', None) is not None:
        runner = IterBasedRunner(model=model, optimizers_cfg=cfg.optimizers,
                                 work_dir=cfg.work_dir)
        total_iters_or_epochs = cfg.total_iters
    else:
        runner = EpochBasedRunner(model=model, optimizers_cfg=cfg.optimizers,
                                  work_dir=cfg.work_dir)
        assert cfg.get('total_epochs', None) is not None
        total_iters_or_epochs = cfg.total_epochs
    # resume and create optimizers
    if cfg.resume_from is not None:
        # Resume a previous run (model parameters and optimizer); whether
        # optimizer state is restored is controlled by cfg.resume_optim
        # (default False).
        runner.resume(cfg.resume_from, cfg.get('resume_optim', False))
    elif cfg.load_from is not None:
        # Behave as if training from scratch: the rank-0 process loads the
        # parameters, then every process creates its optimizers. Per the
        # original author, optimizer init synchronizes model parameters
        # across processes. NOTE(review): that sync happens inside
        # create_optimizers/optimizer init and is not visible here.
        runner.load_checkpoint(cfg.load_from, load_optim=False)
        runner.create_optimizers()
    else:
        # Load no parameters; every process creates its optimizers directly.
        runner.create_optimizers()
    # register hooks
    runner.register_training_hooks(lr_config=cfg.lr_config,
                                   checkpoint_config=cfg.checkpoint_config,
                                   log_config=cfg.log_config)
    # visual hook
    if cfg.get('visual_config', None) is not None:
        # NOTE(review): rewrites cfg.visual_config['output_dir'] in place, so
        # calling train() again with the same cfg may prefix work_dir twice.
        cfg.visual_config['output_dir'] = os.path.join(
            cfg.work_dir, cfg.visual_config['output_dir'])
        runner.register_hook(build_from_cfg(cfg.visual_config, HOOKS))
    # evaluation hook
    if cfg.get('evaluation', None) is not None:
        dataset = build_dataset(cfg.data.eval)
        save_path = os.path.join(cfg.work_dir, 'eval_visuals')
        log_path = cfg.work_dir
        runner.register_hook(
            EvalIterHook(get_loader(dataset, cfg, 'eval'),
                         save_path=save_path,
                         log_path=log_path,
                         **cfg.evaluation))
    runner.run(data_loaders, cfg.workflow, total_iters_or_epochs)
def build(cfg, registry, default_args=None):
    """Build a module from a config via ``registry``.

    Args:
        cfg (dict): Configuration for building modules.
        registry (obj): ``registry`` object.
        default_args (dict, optional): Default arguments. Defaults to None.

    Returns:
        The object returned by ``build_from_cfg(cfg, registry, default_args)``.

    Raises:
        NotImplementedError: If ``cfg`` is a list or tuple of configs;
            building a sequence of modules is not supported yet.
    """
    # Sequence configs (list or tuple) are rejected up front instead of
    # being handed to build_from_cfg.
    if isinstance(cfg, (list, tuple)):
        raise NotImplementedError("list of cfg does not support now")
    return build_from_cfg(cfg, registry, default_args)
def register_lr_hook(self, lr_config):
    """Register the learning-rate updater hook on this runner.

    Args:
        lr_config (dict | Hook | None): Either a config dict with a
            ``policy`` key, or an already-constructed hook. For a dict the
            hook type is ``policy + 'LrUpdaterHook'``, where an
            all-lowercase policy is title-cased first (``'step'`` becomes
            ``'StepLrUpdaterHook'``); the remaining keys are passed to
            ``build_from_cfg``. The caller's dict is not modified.
            ``None`` registers nothing.
    """
    if lr_config is None:
        return
    if isinstance(lr_config, dict):
        assert 'policy' in lr_config
        # Shallow copy so popping 'policy' / setting 'type' does not mutate
        # the caller's config; mutating it would make a second call with
        # the same config fail the assert above.
        lr_config = dict(lr_config)
        policy_type = lr_config.pop('policy')
        # An all-lowercase policy (e.g. 'cyclic') is capitalized ('Cyclic')
        # for convenience; names already containing capitals
        # (e.g. 'CosineAnnealing') are used unchanged.
        if policy_type == policy_type.lower():
            policy_type = policy_type.title()
        lr_config['type'] = policy_type + 'LrUpdaterHook'
        hook = build_from_cfg(lr_config, HOOKS)
    else:
        hook = lr_config
    self.register_hook(hook)
def build_optimizer_constructor(cfg):
    """Build an optimizer constructor from ``cfg``.

    Args:
        cfg (dict): Config resolved against the ``OPTIMIZER_BUILDERS``
            registry by ``build_from_cfg``.

    Returns:
        The object ``build_from_cfg`` creates from ``cfg``.
    """
    constructor = build_from_cfg(cfg, OPTIMIZER_BUILDERS)
    return constructor
def register_logger_hooks(self, log_config):
    """Build and register every logger hook listed in ``log_config``.

    Args:
        log_config (dict): Must contain ``interval`` (passed to each hook
            as the default ``interval`` argument) and ``hooks`` (iterable
            of hook config dicts for the ``HOOKS`` registry). Each hook is
            registered with priority ``'HIGH'``.
    """
    interval = log_config['interval']
    for hook_cfg in log_config['hooks']:
        self.register_hook(
            build_from_cfg(hook_cfg, HOOKS,
                           default_args={'interval': interval}),
            priority='HIGH')