Esempio n. 1
0
 def __init__(self, cfg):
     """Keep the full config and expose its ``multigrid`` and ``data`` parts.

     Args:
         cfg: Config object supporting ``.get``; both a ``multigrid`` and a
             ``data`` entry must be present (non-None).

     Raises:
         AssertionError: if either section is missing.
     """
     self.cfg = cfg
     self.multi_grid_cfg = cfg.get('multigrid', None)
     self.data_cfg = cfg.get('data', None)
     required_sections = (self.multi_grid_cfg, self.data_cfg)
     assert all(section is not None for section in required_sections)
     self.logger = get_root_logger()
     # Echo the multigrid settings so they appear in the run log.
     self.logger.info(self.multi_grid_cfg)
def main():
    """Build a recognizer from a config and print conv-layer statistics.

    The config can be patched from the command line via ``--update_config``
    and ``update_config``. When the resulting config has a truthy
    ``load_from``, that checkpoint is loaded (non-strict) before the
    statistics of the collected convolution layers are shown.
    """
    arg_parser = ArgumentParser()
    arg_parser.add_argument('config', help='Config file path')
    arg_parser.add_argument(
        '--load_from', help='the checkpoint file to init weights from')
    arg_parser.add_argument(
        '--load2d_from', help='the checkpoint file to init 2D weights from')
    arg_parser.add_argument('--update_config', nargs='+',
                            action=ExtendedDictAction,
                            help='arguments in dict')
    cli_args = arg_parser.parse_args()

    config = mmcv.Config.fromfile(cli_args.config)
    if cli_args.update_config is not None:
        config.merge_from_dict(cli_args.update_config)
    config = update_config(config, cli_args)

    recognizer = build_recognizer(config.model, train_cfg=None,
                                  test_cfg=config.test_cfg)
    recognizer.eval()

    checkpoint_file = config.load_from
    if checkpoint_file:
        load_checkpoint(recognizer, checkpoint_file, strict=False,
                        logger=get_root_logger(log_level=config.log_level),
                        show_converted=True, force_matching=True)

    show_stat(collect_conv_layers(recognizer))
Esempio n. 3
0
    def __init__(self,
                 dataloader,
                 start=None,
                 interval=1,
                 save_best=True,
                 key_indicator='top1_acc',
                 rule=None,
                 **eval_kwargs):
        """Validate evaluation settings and prepare best-score tracking.

        Args:
            dataloader (DataLoader): Loader that supplies evaluation data.
            start (int | None): First epoch to evaluate. Negative values are
                replaced by 0 with a ``UserWarning``.
            interval (int): Evaluation interval. It must be positive.
            save_best (bool): Whether to track the best score.
            key_indicator (str): Metric name used when ``save_best`` is True.
            rule (str | None): A key of ``self.rule_map`` (the error text
                names 'greater' / 'less'). If None while ``save_best`` is
                True, it is inferred by substring-matching ``key_indicator``
                against ``self.greater_keys`` and then ``self.less_keys``.
            **eval_kwargs: Stored unchanged in ``self.eval_kwargs``.

        Raises:
            TypeError: ``dataloader`` is not a DataLoader, or ``save_best``
                is not a bool.
            ValueError: ``save_best`` without ``key_indicator``, a rule that
                cannot be inferred, or a non-positive ``interval``.
            KeyError: ``rule`` is neither None nor a key of ``self.rule_map``.
        """
        if not isinstance(dataloader, DataLoader):
            raise TypeError(f'dataloader must be a pytorch DataLoader, '
                            f'but got {type(dataloader)}')
        if not isinstance(save_best, bool):
            raise TypeError("'save_best' should be a boolean")

        if save_best and not key_indicator:
            raise ValueError('key_indicator should not be None, when '
                             'save_best is set to True.')
        rule_is_known = rule in self.rule_map
        if not rule_is_known and rule is not None:
            raise KeyError(f'rule must be greater, less or None, '
                           f'but got {rule}.')

        if save_best and rule is None:
            # Infer the comparison direction from the metric name.
            if any(frag in key_indicator for frag in self.greater_keys):
                rule = 'greater'
            elif any(frag in key_indicator for frag in self.less_keys):
                rule = 'less'
            else:
                raise ValueError(
                    f'key_indicator must be in {self.greater_keys} '
                    f'or in {self.less_keys} when rule is None, '
                    f'but got {key_indicator}')

        if not (interval > 0):
            raise ValueError(f'interval must be positive, but got {interval}')
        if start is not None and start < 0:
            warnings.warn(
                f'The evaluation start epoch {start} is smaller than 0, '
                f'use 0 instead', UserWarning)
            start = 0

        self.dataloader = dataloader
        self.interval = interval
        self.start = start
        self.eval_kwargs = eval_kwargs
        self.save_best = save_best
        self.key_indicator = key_indicator
        self.rule = rule

        self.logger = get_root_logger()

        if save_best:
            self.compare_func = self.rule_map[rule]
            self.best_score = self.init_value_map[rule]

        self.best_json = dict()
        self.initial_epoch_flag = True
Esempio n. 4
0
 def init_weights(self, pretrained=None):
     """Initialize weights from a checkpoint path or from scratch.

     Args:
         pretrained (str | None): Checkpoint path to load non-strictly, or
             None to Kaiming-init ``st_feat_conv`` / ``lt_feat_conv`` and to
             call ``init_weights`` on every layer named in
             ``self.non_local_layers``.

     Raises:
         TypeError: if ``pretrained`` is neither a str nor None.
     """
     if pretrained is None:
         kaiming_init(self.st_feat_conv)
         kaiming_init(self.lt_feat_conv)
         for name in self.non_local_layers:
             getattr(self, name).init_weights(pretrained=pretrained)
         return
     if not isinstance(pretrained, str):
         raise TypeError('pretrained must be a str or None')
     load_checkpoint(self, pretrained, strict=False, logger=get_root_logger())
Esempio n. 5
0
 def init_weights(self, pretrained=None):
     """Initiate the parameters either from an existing checkpoint or from
     scratch.

     Args:
         pretrained (str | None): Checkpoint path to load non-strictly, or
             None to Kaiming-init every ``nn.Conv3d`` and set every
             ``_BatchNorm`` weight to 1. If ``self.zero_init_out_conv`` is
             truthy, ``self.out_conv`` is additionally set to 0 (bias 0).

     Raises:
         TypeError: if ``pretrained`` is neither a str nor None.
     """
     if pretrained is None:
         for module in self.modules():
             if isinstance(module, nn.Conv3d):
                 kaiming_init(module)
             elif isinstance(module, _BatchNorm):
                 constant_init(module, 1)
         if self.zero_init_out_conv:
             constant_init(self.out_conv, 0, bias=0)
     elif isinstance(pretrained, str):
         ckpt_logger = get_root_logger()
         ckpt_logger.info(f'load model from: {pretrained}')
         load_checkpoint(self, pretrained, strict=False, logger=ckpt_logger)
     else:
         raise TypeError('pretrained must be a str or None')
Esempio n. 6
0
def change_export_func_first_conv(model):
    """Patch the export function of the first quantized Conv to avoid a
    saturation issue.

    At the moment this works only for MobileNet: the quantizer found at
    ``backbone.features.init_block.conv.pre_ops['0'].op`` of the NNCF-wrapped
    model gets its ``run_export_quantization`` replaced. In FAKE_QUANTIZE
    export mode the replacement halves the input and doubles the output
    range. The log message names AVX2/AVX512 as the motivation.

    Args:
        model: NNCF-compressed model exposing ``get_nncf_wrapped_model()``.

    Returns:
        The same ``model`` object. It is patched in place when the module path
        exists, and returned untouched (with an info log) otherwise.
    """
    def run_hacked_export_quantization(self, x):
        from nncf.quantization.layers import (
            ExportQuantizeToFakeQuantize, ExportQuantizeToONNXQuantDequant,
            QuantizerExportMode, get_scale_zp_from_input_low_input_high)
        from nncf.utils import no_jit_trace

        export_mode = self._export_mode
        # Range computations must not be recorded in the traced graph.
        with no_jit_trace():
            input_range = abs(self.scale) + self.eps
            input_low = input_range * self.level_low / self.level_high
            input_high = input_range

            if export_mode == QuantizerExportMode.ONNX_QUANTIZE_DEQUANTIZE_PAIRS:
                y_scale, y_zero_point = get_scale_zp_from_input_low_input_high(
                    self.level_low, self.level_high, input_low, input_high)

        # The export ops themselves are applied outside no_jit_trace so they
        # are traced.
        if export_mode == QuantizerExportMode.ONNX_QUANTIZE_DEQUANTIZE_PAIRS:
            return ExportQuantizeToONNXQuantDequant.apply(
                x, y_scale, y_zero_point)
        if export_mode == QuantizerExportMode.FAKE_QUANTIZE:
            # Halve the input and double the output range to avoid saturation.
            x = x / 2.0
            return ExportQuantizeToFakeQuantize.apply(x, self.levels,
                                                      input_low, input_high,
                                                      input_low * 2,
                                                      input_high * 2)
        raise RuntimeError(
            f'Unsupported quantizer export mode: {export_mode}')

    logger = get_root_logger()
    orig_model = model.get_nncf_wrapped_model()
    try:
        # pylint: disable=protected-access
        module_ = orig_model.backbone.features.init_block.conv.pre_ops._modules[
            '0']
    except (AttributeError, KeyError) as e:
        # Best effort: models without this layout are left untouched.
        logger.info(
            f'Cannot change an export function for the first Conv due to {e}')
        return model
    module_.op.run_export_quantization = partial(
        run_hacked_export_quantization, module_.op)
    logger.info(
        'Change an export function for the first Conv to avoid saturation issue on AVX2, AVX512'
    )
    return model
Esempio n. 7
0
    def __init__(self,
                 dataloader,
                 interval=1,
                 gpu_collect=False,
                 save_best=False,
                 key_indicator=None,
                 rule=None,
                 **eval_kwargs):
        """Check the evaluation arguments and set up best-checkpoint tracking.

        Args:
            dataloader (DataLoader): Loader that supplies evaluation data.
            interval (int): Evaluation interval (stored, not validated here).
            gpu_collect (bool): Stored in ``self.gpu_collect``.
            save_best (bool): Whether to track the best score.
            key_indicator (str | None): Metric name. It is required when
                ``save_best`` is truthy.
            rule (str | None): A key of ``self.rule_map``. When None and
                ``save_best`` is truthy, it is inferred by substring-matching
                ``key_indicator`` against ``self.greater_keys`` and then
                ``self.less_keys``.
            **eval_kwargs: Stored unchanged in ``self.eval_kwargs``.

        Raises:
            TypeError: ``dataloader`` is not a DataLoader.
            ValueError: ``save_best`` without ``key_indicator``, or the rule
                cannot be inferred.
            KeyError: ``rule`` is neither None nor a key of ``self.rule_map``.
        """
        if not isinstance(dataloader, DataLoader):
            raise TypeError(f'dataloader must be a pytorch DataLoader, '
                            f'but got {type(dataloader)}')
        if save_best and not key_indicator:
            raise ValueError('key_indicator should not be None, when '
                             'save_best is set to True.')
        rule_is_known = rule in self.rule_map
        if not rule_is_known and rule is not None:
            raise KeyError(f'rule must be greater, less or None, '
                           f'but got {rule}.')

        if save_best and rule is None:
            # Pick the comparison direction from the metric name.
            if any(frag in key_indicator for frag in self.greater_keys):
                rule = 'greater'
            elif any(frag in key_indicator for frag in self.less_keys):
                rule = 'less'
            else:
                raise ValueError(
                    f'key_indicator must be in {self.greater_keys} '
                    f'or in {self.less_keys} when rule is None, '
                    f'but got {key_indicator}')

        self.dataloader = dataloader
        self.interval = interval
        self.gpu_collect = gpu_collect
        self.eval_kwargs = eval_kwargs
        self.save_best = save_best
        self.key_indicator = key_indicator
        self.rule = rule

        self.logger = get_root_logger()

        if self.save_best:
            self.compare_func = self.rule_map[rule]
            self.best_score = self.init_value_map[rule]
            self.best_ckpt = None

        self.best_json = dict()
Esempio n. 8
0
def get_nncf_config_from_meta(path):
    """
    Restore the NNCF part of the model config from metadata stored in a
    checkpoint.

    Args:
        path (str): Path to a checkpoint saved with NNCF compression enabled.

    Returns:
        dict: ``{'nncf_config': <dict>, 'find_unused_parameters': True}``.
        If the restored ``nncf_config`` has a truthy ``log_dir``, it is
        replaced by a newly created temporary directory.

    Raises:
        AssertionError: the checkpoint meta does not mark NNCF compression as
            enabled, or the stored config has no dict ``nncf_config``.
        KeyError: the checkpoint meta has no ``config`` entry.
    """
    logger = get_root_logger()
    # NOTE: torch.load unpickles the file -- only call this on trusted
    # checkpoints.
    checkpoint = torch.load(path, map_location='cpu')
    meta = checkpoint.get('meta', {})

    nncf_enable_compression = meta.get('nncf_enable_compression', False)
    assert nncf_enable_compression, \
            'get_nncf_config_from_meta should be run for NNCF-compressed checkpoints only'

    config_text = meta['config']

    # The stored config text is parsed through a temporary .py file. The
    # file is removed in `finally` so it is not leaked when writing or
    # parsing fails.
    tmp_name = None
    try:
        with tempfile.NamedTemporaryFile(prefix='config_',
                                         suffix='.py',
                                         mode='w',
                                         delete=False) as f_tmp:
            tmp_name = f_tmp.name
            f_tmp.write(config_text)
        cfg = mmcv.Config.fromfile(tmp_name)
    finally:
        if tmp_name is not None:
            os.unlink(tmp_name)

    nncf_config = cfg.get('nncf_config')

    assert isinstance(
        nncf_config,
        dict), (f'Wrong nncf_config part of the config saved in the metainfo'
                f' of the snapshot {path}:'
                f' nncf_config={nncf_config}')

    nncf_config_part = {
        'nncf_config': nncf_config,
        'find_unused_parameters': True
    }
    if nncf_config_part['nncf_config'].get('log_dir'):
        # Do not reuse the log_dir recorded at training time; use a fresh one.
        log_dir = tempfile.mkdtemp(prefix='nncf_output_')
        nncf_config_part['nncf_config']['log_dir'] = log_dir

    logger.info(
        f'Read nncf config from meta nncf_config_part={nncf_config_part}')
    return nncf_config_part
Esempio n. 9
0
    def init_weights(self):
        """Initiate the parameters either from an existing checkpoint or from
        scratch.

        ``self.pretrained`` selects the mode:

        * str: load that checkpoint non-strictly.
        * None: Kaiming-init every ``nn.Conv2d`` and set every ``_BatchNorm``
          weight to 1. If ``self.zero_init_residual`` is truthy, also zero the
          ``conv3.bn`` of each ``Bottleneck2dAudio``.

        Raises:
            TypeError: if ``self.pretrained`` is neither a str nor None.
        """
        source = self.pretrained
        if isinstance(source, str):
            ckpt_logger = get_root_logger()
            ckpt_logger.info(f'load model from: {source}')
            load_checkpoint(self, source, strict=False, logger=ckpt_logger)
            return
        if source is not None:
            raise TypeError('pretrained must be a str or None')

        for module in self.modules():
            if isinstance(module, nn.Conv2d):
                kaiming_init(module)
            elif isinstance(module, _BatchNorm):
                constant_init(module, 1)

        if not self.zero_init_residual:
            return
        for module in self.modules():
            if isinstance(module, Bottleneck2dAudio):
                constant_init(module.conv3.bn, 0)
def main():
    """Benchmark data loading of a config's training pipeline.

    Writes the first 256 lines of ``cfg.ann_file_train`` to ``benchlist.txt``
    (only if that file does not exist yet), points the training dataset at
    it, and iterates the loader in the main process (``workers_per_gpu=0``).
    The first 5 batches are treated as warm-up; after that the progress bar
    advances once per item in ``data['imgs']``.
    """
    parser = argparse.ArgumentParser(description='Benchmark dataloading')
    parser.add_argument('config', help='train config file path')
    cfg = Config.fromfile(parser.parse_args().config)

    # init logger before other steps
    logger = get_root_logger()
    logger.info(f'MMAction2 Version: {__version__}')
    logger.info(f'Config: {cfg.text}')

    # Build (once) a small annotation list to benchmark on.
    bench_list = 'benchlist.txt'
    if not os.path.exists(bench_list):
        with open(cfg.ann_file_train) as src:
            head_lines = src.readlines()[:256]
        with open(bench_list, 'w') as dst:
            dst.writelines(head_lines)
    cfg.data.train.ann_file = bench_list

    videos_per_gpu = cfg.data.videos_per_gpu
    dataset = build_dataset(cfg.data.train)
    data_loader = build_dataloader(dataset,
                                   videos_per_gpu=videos_per_gpu,
                                   workers_per_gpu=0,
                                   num_gpus=1,
                                   dist=False)

    warmup_batches = 5
    prog_bar = mmcv.ProgressBar(len(dataset) - warmup_batches * videos_per_gpu,
                                start=False)
    for batch_idx, batch in enumerate(data_loader):
        if batch_idx < warmup_batches:
            continue
        if batch_idx == warmup_batches:
            prog_bar.start()
        for _ in batch['imgs']:
            prog_bar.update()
Esempio n. 11
0
def main():
    """Train a recognizer as configured by the command line and config file.

    Parses arguments, loads the config and applies ``--cfg-options``
    overrides, resolves ``work_dir`` / ``resume_from`` / ``gpu_ids``, sets up
    the (optionally distributed) environment and logger, seeds RNGs, builds
    the model and datasets, and finally calls ``train_model``.
    """
    args = parse_args()

    cfg = Config.fromfile(args.config)

    # merge_from_dict iterates a mapping; skip it when no overrides were given.
    if args.cfg_options is not None:
        cfg.merge_from_dict(args.cfg_options)

    # set cudnn_benchmark
    if cfg.get('cudnn_benchmark', False):
        torch.backends.cudnn.benchmark = True

    # work_dir is determined in this priority:
    # CLI > config file > default (base filename)
    if args.work_dir is not None:
        # update configs according to CLI args if args.work_dir is not None
        cfg.work_dir = args.work_dir
    elif cfg.get('work_dir', None) is None:
        # use config filename as default work_dir if cfg.work_dir is None
        cfg.work_dir = osp.join('./work_dirs',
                                osp.splitext(osp.basename(args.config))[0])
    if args.resume_from is not None:
        cfg.resume_from = args.resume_from
    if args.gpu_ids is not None:
        cfg.gpu_ids = args.gpu_ids
    else:
        cfg.gpu_ids = range(1) if args.gpus is None else range(args.gpus)

    # init distributed env first, since logger depends on the dist info.
    if args.launcher == 'none':
        distributed = False
    else:
        distributed = True
        init_dist(args.launcher, **cfg.dist_params)

    # The flag is used to determine whether it is omnisource training
    cfg.setdefault('omnisource', False)

    # The flag is used to register module's hooks
    cfg.setdefault('module_hooks', [])

    # create work_dir
    mmcv.mkdir_or_exist(osp.abspath(cfg.work_dir))
    # dump config
    cfg.dump(osp.join(cfg.work_dir, osp.basename(args.config)))
    # init logger before other steps
    timestamp = time.strftime('%Y%m%d_%H%M%S', time.localtime())
    log_file = osp.join(cfg.work_dir, f'{timestamp}.log')
    logger = get_root_logger(log_file=log_file, log_level=cfg.log_level)

    # init the meta dict to record some important information such as
    # environment info and seed, which will be logged
    meta = dict()
    # log env info
    env_info_dict = collect_env()
    env_info = '\n'.join([f'{k}: {v}' for k, v in env_info_dict.items()])
    dash_line = '-' * 60 + '\n'
    logger.info('Environment info:\n' + dash_line + env_info + '\n' +
                dash_line)
    meta['env_info'] = env_info

    # log some basic info
    logger.info(f'Distributed training: {distributed}')
    logger.info(f'Config: {cfg.text}')

    # set random seeds
    if args.seed is not None:
        logger.info(f'Set random seed to {args.seed}, '
                    f'deterministic: {args.deterministic}')
        set_random_seed(args.seed, deterministic=args.deterministic)
    cfg.seed = args.seed
    meta['seed'] = args.seed
    meta['config_name'] = osp.basename(args.config)
    meta['work_dir'] = osp.basename(cfg.work_dir.rstrip('/\\'))

    model = build_model(cfg.model,
                        train_cfg=cfg.get('train_cfg'),
                        test_cfg=cfg.get('test_cfg'))

    register_module_hooks(model.backbone, cfg.module_hooks)

    if cfg.omnisource:
        # If omnisource flag is set, cfg.data.train should be a list
        assert isinstance(cfg.data.train, list)
        datasets = [build_dataset(dataset) for dataset in cfg.data.train]
    else:
        datasets = [build_dataset(cfg.data.train)]

    if len(cfg.workflow) == 2:
        # For simplicity, omnisource is not compatible with val workflow,
        # we recommend you to use `--validate`
        assert not cfg.omnisource
        if args.validate:
            warnings.warn('val workflow is duplicated with `--validate`, '
                          'it is recommended to use `--validate`. see '
                          'https://github.com/open-mmlab/mmaction2/pull/123')
        val_dataset = copy.deepcopy(cfg.data.val)
        datasets.append(build_dataset(val_dataset))
    if cfg.checkpoint_config is not None:
        # save mmaction version, config file content and class names in
        # checkpoints as meta data
        cfg.checkpoint_config.meta = dict(mmaction_version=__version__ +
                                          get_git_hash(digits=7),
                                          config=cfg.text)

    train_model(model,
                datasets,
                cfg,
                distributed=distributed,
                validate=args.validate,
                timestamp=timestamp,
                meta=meta)
Esempio n. 12
0
def main():
    """Train a model with optional NNCF compression.

    Loads and updates the config (including the data root), sets up the
    (optionally distributed) environment and logger, records NNCF metadata
    when ``nncf_config`` is present, seeds RNGs, builds the datasets and the
    model, and calls ``train_model`` with the configured layer prefixes and
    suffixes to ignore.
    """
    # parse arguments
    args = parse_args()

    # load config
    cfg = Config.fromfile(args.config)
    if args.update_config is not None:
        cfg.merge_from_dict(args.update_config)
    cfg = update_config(cfg, args)
    cfg = propagate_root_dir(cfg, args.data_dir)

    # init distributed env first, since logger depends on the dist info.
    distributed = args.launcher != 'none'
    if distributed:
        init_dist(args.launcher, **cfg.dist_params)

    # create work_dir
    mmcv.mkdir_or_exist(osp.abspath(cfg.work_dir))

    # init logger before other steps
    timestamp = time.strftime('%Y%m%d_%H%M%S', time.localtime())
    log_file = osp.join(cfg.work_dir, f'{timestamp}.log')
    logger = get_root_logger(log_file=log_file, log_level=cfg.log_level)

    # init the meta dict to record some important information such as
    # environment info and seed, which will be logged
    meta = dict()

    # log env info
    env_info_dict = collect_env()
    env_info = '\n'.join([f'{k}: {v}' for k, v in env_info_dict.items()])
    dash_line = '-' * 60 + '\n'
    logger.info('Environment info:\n' + dash_line + env_info + '\n' + dash_line)
    meta['env_info'] = env_info

    # log some basic info
    logger.info(f'Distributed training: {distributed}')
    logger.info(f'Config: {cfg.text}')

    if cfg.get('nncf_config'):
        check_nncf_is_enabled()
        logger.info(f'NNCF config: {cfg.nncf_config}')
        meta.update(get_nncf_metadata())

    # set random seeds
    cfg.seed = args.seed
    meta['seed'] = args.seed
    # Compare against None: a truthiness check would skip seeding for seed=0.
    if cfg.seed is not None:
        logger.info(f'Set random seed to {cfg.seed}, deterministic: {args.deterministic}')
        set_random_seed(cfg.seed, deterministic=args.deterministic)

    # build datasets
    datasets = [build_dataset(cfg.data, 'train', dict(logger=logger))]
    logger.info(f'Train datasets:\n{str(datasets[0])}')

    if len(cfg.workflow) == 2:
        if not args.no_validate:
            warnings.warn('val workflow is duplicated with `--validate`, '
                          'it is recommended to use `--validate`. see '
                          'https://github.com/open-mmlab/mmaction2/pull/123')
        datasets.append(build_dataset(copy.deepcopy(cfg.data), 'val', dict(logger=logger)))
        logger.info(f'Val datasets:\n{str(datasets[-1])}')

    # filter dataset labels
    if cfg.get('classes'):
        datasets = [dataset.filter(cfg.classes) for dataset in datasets]

    # build model
    model = build_model(
        cfg.model,
        train_cfg=cfg.train_cfg,
        test_cfg=cfg.test_cfg,
        class_sizes=datasets[0].class_sizes,
        class_maps=datasets[0].class_maps
    )

    # define ignore layers: only list/tuple values from the config are used
    ignore_prefixes = []
    reset_prefixes = cfg.get('reset_layer_prefixes')
    if isinstance(reset_prefixes, (list, tuple)):
        ignore_prefixes += reset_prefixes
    ignore_suffixes = ['num_batches_tracked']
    reset_suffixes = cfg.get('reset_layer_suffixes')
    if isinstance(reset_suffixes, (list, tuple)):
        ignore_suffixes += reset_suffixes

    # train model
    train_model(
        model,
        datasets,
        cfg,
        distributed=distributed,
        validate=(not args.no_validate),
        timestamp=timestamp,
        meta=meta,
        ignore_prefixes=tuple(ignore_prefixes),
        ignore_suffixes=tuple(ignore_suffixes)
    )
Esempio n. 13
0
def wrap_nncf_model(model,
                    cfg,
                    data_loader_for_init=None,
                    get_fake_input_func=None,
                    export=False):
    """
    Wrap an mmaction model with NNCF compression.

    Note that the parameter `get_fake_input_func` should be the function
    `get_fake_input` -- this function cannot be imported here explicitly.

    Args:
        model: Model to compress. The device of its first parameter is used
            for the NNCF init data loader and for the fake dummy-forward input.
        cfg: Full config. Reads ``work_dir``, ``nncf_config`` and
            ``log_level``, and optionally ``resume_from``, ``load_from`` and
            ``nncf_compress_postprocessing``.
        data_loader_for_init: Optional loader registered as NNCF init data.
            Each batch is returned by ``get_inputs`` as ``((), batch)``.
        get_fake_input_func: Builds the dummy-forward input. It must be set
            when the dummy forward runs. The input size comes from
            ``nncf_config['input_info']['sample_size']``, which is asserted to
            be 4-D with a leading 1.
        export: If True, the dummy forward goes through
            ``model.reshape_input`` and calls ``model(imgs=...)``. Otherwise it
            calls ``model(imgs=..., return_loss=False)``.

    Returns:
        tuple: ``(compression_ctrl, model)`` from ``create_compressed_model``.

    Raises:
        AssertionError: ``resume_from`` is set to a non-NNCF checkpoint.
        RuntimeError: neither ``data_loader_for_init`` nor an NNCF checkpoint
            is available.
    """

    check_nncf_is_enabled()

    from nncf.config import NNCFConfig
    from nncf.torch import (create_compressed_model,
                            register_default_init_args)
    from nncf.torch.dynamic_graph.io_handling import nncf_model_input
    from nncf.torch.dynamic_graph.trace_tensor import TracedTensor
    from nncf.torch.initialization import DefaultInitializingDataLoader

    class MMInitializeDataLoader(DefaultInitializingDataLoader):
        def get_inputs(self, dataloader_output):
            # Pass the whole batch as keyword arguments, no positional ones.
            return (), dataloader_output

    pathlib.Path(cfg.work_dir).mkdir(parents=True, exist_ok=True)
    nncf_config = NNCFConfig(cfg.nncf_config)
    # Pass log_level by keyword: the first positional parameter of
    # get_root_logger is log_file, so a positional level would be taken as
    # a log file name.
    logger = get_root_logger(log_level=cfg.log_level)

    if data_loader_for_init:
        wrapped_loader = MMInitializeDataLoader(data_loader_for_init)
        nncf_config = register_default_init_args(
            nncf_config,
            wrapped_loader,
            device=next(model.parameters()).device)

    if cfg.get('resume_from'):
        checkpoint_path = cfg.get('resume_from')
        assert is_checkpoint_nncf(checkpoint_path), (
            'It is possible to resume training with NNCF compression from NNCF checkpoints only. '
            'Use "load_from" with non-compressed model for further compression by NNCF.'
        )
    elif cfg.get('load_from'):
        checkpoint_path = cfg.get('load_from')
        if not is_checkpoint_nncf(checkpoint_path):
            checkpoint_path = None
            logger.info('Received non-NNCF checkpoint to start training '
                        '-- initialization of NNCF fields will be done')
    else:
        checkpoint_path = None

    if not data_loader_for_init and not checkpoint_path:
        raise RuntimeError('Either data_loader_for_init or NNCF pre-trained '
                           'model checkpoint should be set')

    if checkpoint_path:
        logger.info(f'Loading NNCF checkpoint from {checkpoint_path}')
        logger.info(
            'Please, note that this first loading is made before addition of '
            'NNCF FakeQuantize nodes to the model, so there may be some '
            'warnings on unexpected keys')
        resuming_state_dict = load_checkpoint(model, checkpoint_path)
        logger.info(f'Loaded NNCF checkpoint from {checkpoint_path}')
    else:
        resuming_state_dict = None

    if "nncf_compress_postprocessing" in cfg:
        # NB: This parameter is used to choose if we should try to make NNCF compression
        #     for a whole model graph including postprocessing (`nncf_compress_postprocessing=True`),
        #     or make NNCF compression of the part of the model without postprocessing
        #     (`nncf_compress_postprocessing=False`).
        #     Our primary goal is to make NNCF compression of such big part of the model as
        #     possible, so `nncf_compress_postprocessing=True` is our primary choice, whereas
        #     `nncf_compress_postprocessing=False` is our fallback decision.
        #     When we manage to enable NNCF compression for sufficiently many models,
        #     we should keep one choice only.
        nncf_compress_postprocessing = cfg.get('nncf_compress_postprocessing')
        logger.debug('set should_compress_postprocessing='
                     f'{nncf_compress_postprocessing}')
    else:
        nncf_compress_postprocessing = True
    # NOTE(review): within this function the value is only logged; it is not
    # passed on to NNCF below.

    def _get_fake_data_for_forward(cfg, nncf_config, get_fake_input_func):
        # sample_size is expected as (1, C, H, W).
        input_size = nncf_config.get("input_info").get('sample_size')
        assert get_fake_input_func is not None
        assert len(input_size) == 4 and input_size[0] == 1
        H, W, C = input_size[2], input_size[3], input_size[1]
        # `model` is resolved at call time, so after wrapping this refers to
        # the compressed model returned by create_compressed_model.
        device = next(model.parameters()).device
        with no_nncf_trace():
            return get_fake_input_func(cfg,
                                       orig_img_shape=tuple([H, W, C]),
                                       device=device)

    def dummy_forward(model):
        fake_data = _get_fake_data_for_forward(cfg, nncf_config,
                                               get_fake_input_func)
        img = fake_data["imgs"]
        img = nncf_model_input(img)
        if export:
            img, _, _ = model.reshape_input(imgs=img)
            model(imgs=img)
        else:
            model(imgs=img, return_loss=False)

    def wrap_inputs(args, kwargs):
        # during dummy_forward
        if not len(kwargs):
            if not isinstance(args[0][0], TracedTensor):
                args[0][0] = nncf_model_input(args[0][0])
            return args, kwargs

        # during building original graph
        if not kwargs.get('return_loss') and kwargs.get('forward_export'):
            return args, kwargs

        # during model's forward
        assert 'imgs' in kwargs, 'During model forward imgs must be in kwargs'
        img = kwargs['imgs']
        if isinstance(img, list):
            assert len(img) == 1, 'Input list must have a length 1'
            assert torch.is_tensor(
                img[0]), 'Input for a model must be a tensor'
            if not isinstance(img[0], TracedTensor):
                img[0] = nncf_model_input(img[0])
        else:
            assert torch.is_tensor(img), 'Input for a model must be a tensor'
            if not isinstance(img, TracedTensor):
                img = nncf_model_input(img)
        kwargs['imgs'] = img
        return args, kwargs

    model.dummy_forward_fn = dummy_forward

    if 'log_dir' in nncf_config:
        os.makedirs(nncf_config['log_dir'], exist_ok=True)
    compression_ctrl, model = create_compressed_model(
        model,
        nncf_config,
        dummy_forward_fn=dummy_forward,
        wrap_inputs_fn=wrap_inputs,
        compression_state=resuming_state_dict)

    return compression_ctrl, model
Esempio n. 14
0
def main():
    """Train a recognizer from a config file.

    The effective ``work_dir`` comes from the CLI, then the config, then
    ``./work_dirs/<config basename without extension>``. ``resume_from`` and
    ``gpu_ids`` are taken from the CLI when given. The function then sets up
    the (optionally distributed) environment and logger, seeds RNGs, builds
    the model and datasets, and calls ``train_model``.
    """
    args = parse_args()

    cfg = Config.fromfile(args.config)
    if cfg.get('cudnn_benchmark', False):
        torch.backends.cudnn.benchmark = True

    # work_dir priority: CLI > config file > default (base filename)
    if args.work_dir is not None:
        cfg.work_dir = args.work_dir
    elif cfg.get('work_dir', None) is None:
        config_stem = osp.splitext(osp.basename(args.config))[0]
        cfg.work_dir = osp.join('./work_dirs', config_stem)
    if args.resume_from is not None:
        cfg.resume_from = args.resume_from
    if args.gpu_ids is None:
        cfg.gpu_ids = range(1) if args.gpus is None else range(args.gpus)
    else:
        cfg.gpu_ids = args.gpu_ids

    # The distributed env is initialised before the logger, which depends on
    # the dist info.
    distributed = args.launcher != 'none'
    if distributed:
        init_dist(args.launcher, **cfg.dist_params)

    work_dir = cfg.work_dir
    mmcv.mkdir_or_exist(osp.abspath(work_dir))
    cfg.dump(osp.join(work_dir, osp.basename(args.config)))

    timestamp = time.strftime('%Y%m%d_%H%M%S', time.localtime())
    logger = get_root_logger(log_file=osp.join(work_dir, f'{timestamp}.log'),
                             log_level=cfg.log_level)

    env_info = '\n'.join(f'{key}: {value}'
                         for key, value in collect_env().items())
    dash_line = '-' * 60 + '\n'
    logger.info('Environment info:\n' + dash_line + env_info + '\n' +
                dash_line)
    # meta records environment info and seed alongside the training run.
    meta = {'env_info': env_info}

    logger.info(f'Distributed training: {distributed}')
    logger.info(f'Config: {cfg.text}')

    seed = args.seed
    if seed is not None:
        logger.info('Set random seed to {}, deterministic: {}'.format(
            seed, args.deterministic))
        set_random_seed(seed, deterministic=args.deterministic)
    cfg.seed = seed
    meta['seed'] = seed

    model = build_model(cfg.model,
                        train_cfg=cfg.train_cfg,
                        test_cfg=cfg.test_cfg)

    datasets = [build_dataset(cfg.data.train)]
    if len(cfg.workflow) == 2:
        datasets.append(build_dataset(copy.deepcopy(cfg.data.val)))
    if cfg.checkpoint_config is not None:
        # Store the mmaction version and config text in checkpoint metadata.
        cfg.checkpoint_config.meta = dict(mmaction_version=__version__,
                                          config=cfg.text)

    train_model(model,
                datasets,
                cfg,
                distributed=distributed,
                validate=args.validate,
                timestamp=timestamp,
                meta=meta)