예제 #1
0
 def register_logger_hooks(self, log_config):
     """Build and register every logger hook described in ``log_config``.

     Each entry in ``log_config['hooks']`` is built through the ``HOOKS``
     registry, with the shared ``log_config['interval']`` supplied as the
     default logging interval, and registered at ``VERY_LOW`` priority.
     A ``None`` config is a no-op.
     """
     if log_config is None:
         return
     interval = log_config['interval']
     for hook_cfg in log_config['hooks']:
         hook = cv_core.build_from_cfg(
             hook_cfg, HOOKS, default_args=dict(interval=interval))
         self.register_hook(hook, priority='VERY_LOW')
예제 #2
0
 def register_checkpoint_hook(self, checkpoint_config):
     """Register the checkpoint hook.

     A dict config gets a default ``type`` of ``'CheckpointHook'`` and is
     built through the ``HOOKS`` registry; any non-dict value is assumed
     to be an already-constructed hook and registered as-is. ``None``
     disables checkpointing.
     """
     if checkpoint_config is None:
         return
     if not isinstance(checkpoint_config, dict):
         self.register_hook(checkpoint_config)
         return
     checkpoint_config.setdefault('type', 'CheckpointHook')
     self.register_hook(cv_core.build_from_cfg(checkpoint_config, HOOKS))
예제 #3
0
 def register_optimizer_hook(self, optimizer_config):
     """Register the optimizer hook.

     A dict config gets a default ``type`` of ``'OptimizerHook'`` and is
     built through the ``HOOKS`` registry; any non-dict value is treated
     as an already-built hook. ``None`` is a no-op.
     """
     if optimizer_config is None:
         return
     needs_build = isinstance(optimizer_config, dict)
     if needs_build:
         optimizer_config.setdefault('type', 'OptimizerHook')
     hook = (cv_core.build_from_cfg(optimizer_config, HOOKS)
             if needs_build else optimizer_config)
     self.register_hook(hook)
예제 #4
0
def resize_test():
    """Smoke-test the ``Resize`` pipeline transform.

    Runs the transform on a blank 60x84x3 image twice — once with a fixed
    ``img_scale`` and once with an additional random ``ratio_range`` — and
    prints the resulting ``img_shape`` each time. The two runs previously
    duplicated the build/apply/print boilerplate; it is now factored into
    a single local helper.
    """
    input_shape = (60, 84, 3)

    def _run(transform_cfg):
        # Build the transform from the PIPELINES registry, apply it to a
        # zero-filled image, and report the output shape.
        transform = cv_core.build_from_cfg(transform_cfg, PIPELINES)
        img = np.zeros(input_shape, dtype=np.uint8)
        output = transform(dict(img=img))
        print(output['img_shape'])

    _run(dict(type='Resize', img_scale=(1333, 800), keep_ratio=True))
    _run(
        dict(
            type='Resize',
            img_scale=(1333, 800),
            ratio_range=(0.9, 1.1),
            keep_ratio=True))
예제 #5
0
def build_activation_layer(cfg):
    """Instantiate an activation layer from a registry config.

    Args:
        cfg (dict): Activation layer config containing:
            - type (str): Layer type.
            - any keyword arguments the layer's constructor needs.

    Returns:
        nn.Module: The constructed activation layer.
    """
    layer = build_from_cfg(cfg, ACTIVATION_LAYERS)
    return layer
예제 #6
0
    def register_hook_from_cfg(self, hook_cfg):
        """Build a hook from its config and register it.

        Args:
            hook_cfg (dict): Hook config. It should have at least a 'type'
              key; an optional 'priority' key (default 'NORMAL') controls
              the registration priority.

        Notes:
            The hook class being registered must not accept 'type' or
            'priority' as constructor arguments — both are consumed here.
        """
        # Copy first so the caller's config dict is left untouched.
        cfg = hook_cfg.copy()
        hook_priority = cfg.pop('priority', 'NORMAL')
        built_hook = cv_core.build_from_cfg(cfg, HOOKS)
        self.register_hook(built_hook, priority=hook_priority)
예제 #7
0
 def register_lr_hook(self, lr_config):
     """Register the learning-rate updater hook.

     Args:
         lr_config (dict | Hook | None): If a dict, it must contain a
             'policy' key; the policy name is mapped to a
             ``<Policy>LrUpdaterHook`` type and built through the HOOKS
             registry. A non-dict value is registered as-is. ``None`` is
             a no-op, matching the other ``register_*_hook`` methods
             (previously ``None`` fell through to the else branch and was
             registered as a hook).
     """
     if lr_config is None:
         return
     if isinstance(lr_config, dict):
         assert 'policy' in lr_config
         # Work on a copy so the caller's config dict is not mutated
         # (the original popped 'policy' and injected 'type' in place,
         # making the config unusable a second time).
         lr_config = lr_config.copy()
         policy_type = lr_config.pop('policy')
         # If the type of policy is all in lower case, e.g., 'cyclic',
         # then its first letter will be capitalized, e.g., to be 'Cyclic'.
         # This is for the convenient usage of Lr updater.
         # Since this is not applicable for `CosineAnnealingLrUpdater`,
         # the string will not be changed if it contains capital letters.
         if policy_type == policy_type.lower():
             policy_type = policy_type.title()
         lr_config['type'] = policy_type + 'LrUpdaterHook'
         hook = cv_core.build_from_cfg(lr_config, HOOKS)
     else:
         hook = lr_config
     self.register_hook(hook)
예제 #8
0
def train_detector(model,
                   dataset,
                   cfg,
                   validate=False,
                   timestamp=None,
                   meta=None):
    """Train a detector with an ``EpochBasedRunner``.

    Builds one data loader per workflow stage, wraps the model in
    ``MMDataParallel`` on ``cfg.gpu_ids``, constructs the optimizer and
    runner, registers training / eval / custom hooks, optionally resumes
    from or loads a checkpoint, then runs the workflow.

    Args:
        model: Detector to train; moved to ``cfg.gpu_ids[0]`` before
            wrapping.
        dataset: A dataset, or a list/tuple of datasets (one per workflow
            stage).
        cfg: Config object providing ``data``, ``optimizer``, the various
            hook configs, ``work_dir``, ``gpu_ids``, ``seed``,
            ``workflow`` and ``total_epochs``.
        validate (bool): If True, register an ``EvalHook`` driven by
            ``cfg.data.val``.
        timestamp (str | None): Shared timestamp attached to the runner so
            the ``.log`` and ``.log.json`` filenames match.
        meta (dict | None): Extra metadata forwarded to the runner.
    """
    logger = get_root_logger(cfg.log_level)

    # prepare data loaders (one per workflow stage)
    dataset = dataset if isinstance(dataset, (list, tuple)) else [dataset]
    data_loaders = [
        build_dataloader(ds,
                         cfg.data.samples_per_gpu,
                         cfg.data.workers_per_gpu,
                         len(cfg.gpu_ids),
                         seed=cfg.seed) for ds in dataset
    ]

    # Important wrapper: beyond data-parallel execution, it also decodes
    # DataContainer-wrapped batch data before it reaches the model.
    model = MMDataParallel(model.cuda(cfg.gpu_ids[0]), device_ids=cfg.gpu_ids)

    # build runner
    optimizer = build_optimizer(model, cfg.optimizer)
    runner = EpochBasedRunner(model,
                              optimizer=optimizer,
                              work_dir=cfg.work_dir,
                              logger=logger,
                              meta=meta)
    # an ugly workaround to make .log and .log.json filenames the same
    runner.timestamp = timestamp

    # A config without an explicit 'type' means the default OptimizerHook;
    # otherwise the config is passed through for the runner to build from
    # the HOOKS registry.
    if 'type' not in cfg.optimizer_config:
        optimizer_config = OptimizerHook(**cfg.optimizer_config)
    else:
        optimizer_config = cfg.optimizer_config

    # register hooks
    runner.register_training_hooks(cfg.lr_config, optimizer_config,
                                   cfg.checkpoint_config, cfg.log_config,
                                   cfg.get('momentum_config', None))
    # register eval hooks
    if validate:
        # Validation/evaluation is implemented through the hook mechanism.
        val_dataset = build_dataset(cfg.data.val, dict(test_mode=True))
        val_dataloader = build_dataloader(
            val_dataset,
            samples_per_gpu=1,
            workers_per_gpu=cfg.data.workers_per_gpu,
            shuffle=False)
        eval_cfg = cfg.get('evaluation', {})
        runner.register_hook(EvalHook(val_dataloader, **eval_cfg))

    # user-defined custom hooks
    if cfg.get('custom_hooks', None):
        custom_hooks = cfg.custom_hooks
        assert isinstance(custom_hooks, list), \
            f'custom_hooks expect list type, but got {type(custom_hooks)}'
        for hook_cfg in cfg.custom_hooks:
            assert isinstance(hook_cfg, dict), \
                'Each item in custom_hooks expects dict type, but got ' \
                f'{type(hook_cfg)}'
            hook_cfg = hook_cfg.copy()
            # 'priority' steers registration order only; it must not reach
            # the hook constructor, hence the pop before building.
            priority = hook_cfg.pop('priority', 'NORMAL')
            hook = build_from_cfg(hook_cfg, HOOKS)
            runner.register_hook(hook, priority=priority)

    if cfg.resume_from:
        runner.resume(cfg.resume_from)
    elif cfg.load_from:
        runner.load_checkpoint(cfg.load_from)
    runner.run(data_loaders, cfg.workflow, cfg.total_epochs)