Example 1
# Imports assumed from the surrounding OpenMMLab (mmcv / MMPose) codebase;
# exact module paths may differ between versions.
import torch
from mmcv.parallel import MMDataParallel, MMDistributedDataParallel
from mmcv.runner import (DistSamplerSeedHook, EpochBasedRunner,
                         Fp16OptimizerHook, OptimizerHook)

from mmpose.core import DistEvalHook, EvalHook, build_optimizers
from mmpose.core.distributed_wrapper import DistributedDataParallelWrapper
from mmpose.datasets import build_dataloader, build_dataset
from mmpose.utils import get_root_logger


def train_model(model,
                dataset,
                cfg,
                distributed=False,
                validate=False,
                timestamp=None,
                meta=None):
    """Train model entry function.

    Args:
        model (nn.Module): The model to be trained.
        dataset (Dataset): Train dataset.
        cfg (dict): The config dict for training.
        distributed (bool): Whether to use distributed training.
            Default: False.
        validate (bool): Whether to do evaluation. Default: False.
        timestamp (str | None): Local time for the runner. Default: None.
        meta (dict | None): Meta dict to record some important information.
            Default: None.
    """
    logger = get_root_logger(cfg.log_level)

    # prepare data loaders
    dataset = dataset if isinstance(dataset, (list, tuple)) else [dataset]
    # step 1: set default values, then override them (if present) from cfg.data
    loader_cfg = {
        **dict(
            seed=cfg.get('seed'),
            drop_last=False,
            dist=distributed,
            num_gpus=len(cfg.gpu_ids)),
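        # Parrots-specific defaults; skipped entirely on vanilla PyTorch.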
        **({} if torch.__version__ != 'parrots' else dict(
               prefetch_num=2,
               pin_memory=False,
           )),
        **dict((k, cfg.data[k]) for k in [
                   'samples_per_gpu',
                   'workers_per_gpu',
                   'shuffle',
                   'seed',
                   'drop_last',
                   'prefetch_num',
                   'pin_memory',
                   'persistent_workers',
               ] if k in cfg.data)
    }
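    # Later entries in the merge above take precedence: hard-coded defaults
    # first, parrots-only extras next, matching cfg.data keys last.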

    # step 2: cfg.data.train_dataloader has the highest priority
    train_loader_cfg = dict(loader_cfg, **cfg.data.get('train_dataloader', {}))

    data_loaders = [build_dataloader(ds, **train_loader_cfg) for ds in dataset]

    # determine whether to use adversarial training or not
    use_adversarial_train = cfg.get('use_adversarial_train', False)

    # put model on gpus
    if distributed:
        # Sets the `find_unused_parameters` parameter of
        # torch.nn.parallel.DistributedDataParallel
        find_unused_parameters = cfg.get('find_unused_parameters', True)

        if use_adversarial_train:
            # Use DistributedDataParallelWrapper for adversarial training
            model = DistributedDataParallelWrapper(
                model,
                device_ids=[torch.cuda.current_device()],
                broadcast_buffers=False,
                find_unused_parameters=find_unused_parameters)
        else:
            model = MMDistributedDataParallel(
                model.cuda(),
                device_ids=[torch.cuda.current_device()],
                broadcast_buffers=False,
                find_unused_parameters=find_unused_parameters)
    else:
        model = MMDataParallel(
            model.cuda(cfg.gpu_ids[0]), device_ids=cfg.gpu_ids)

    # build runner
    optimizer = build_optimizers(model, cfg.optimizer)

    runner = EpochBasedRunner(
        model,
        optimizer=optimizer,
        work_dir=cfg.work_dir,
        logger=logger,
        meta=meta)
    # an ugly workaround to make .log and .log.json filenames the same
    runner.timestamp = timestamp

    if use_adversarial_train:
        # The optimizer step process is included in the train_step function
        # of the model, so the runner should NOT include optimizer hook.
        optimizer_config = None
    else:
        # fp16 setting
        fp16_cfg = cfg.get('fp16', None)
        if fp16_cfg is not None:
            optimizer_config = Fp16OptimizerHook(
                **cfg.optimizer_config, **fp16_cfg, distributed=distributed)
        elif distributed and 'type' not in cfg.optimizer_config:
            optimizer_config = OptimizerHook(**cfg.optimizer_config)
        else:
            optimizer_config = cfg.optimizer_config
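    # Precedence: an explicit fp16 config wins; otherwise distributed runs
    # get a plain OptimizerHook, and any config that names its own hook
    # `type` is passed through as a dict for the runner to build.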

    # register hooks
    runner.register_training_hooks(cfg.lr_config, optimizer_config,
                                   cfg.checkpoint_config, cfg.log_config,
                                   cfg.get('momentum_config', None))
    if distributed:
        runner.register_hook(DistSamplerSeedHook())

    # register eval hooks
    if validate:
        eval_cfg = cfg.get('evaluation', {})
        val_dataset = build_dataset(cfg.data.val, dict(test_mode=True))
        dataloader_setting = dict(
            samples_per_gpu=1,
            workers_per_gpu=cfg.data.get('workers_per_gpu', 1),
            # cfg.gpus will be ignored if distributed
            num_gpus=len(cfg.gpu_ids),
            dist=distributed,
            drop_last=False,
            shuffle=False)
        dataloader_setting = dict(dataloader_setting,
                                  **cfg.data.get('val_dataloader', {}))
        val_dataloader = build_dataloader(val_dataset, **dataloader_setting)
        eval_hook = DistEvalHook if distributed else EvalHook
        runner.register_hook(eval_hook(val_dataloader, **eval_cfg))

    if cfg.resume_from:
        runner.resume(cfg.resume_from)
    elif cfg.load_from:
        runner.load_checkpoint(cfg.load_from)
    runner.run(data_loaders, cfg.workflow, cfg.total_epochs)
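For context, train_model expects a fully populated config object. A minimal sketch of a call site follows; the key names mirror the cfg accesses above, but the concrete values are placeholder assumptions, not a verified MMPose config.

import mmcv

# Hypothetical minimal config (values are illustrative only).
cfg = mmcv.Config(
    dict(
        log_level='INFO',
        work_dir='./work_dirs/demo',
        gpu_ids=[0],
        seed=0,
        data=dict(samples_per_gpu=2, workers_per_gpu=2),
        optimizer=dict(type='Adam', lr=1e-3),
        optimizer_config=dict(grad_clip=None),
        lr_config=dict(policy='step', step=[10]),
        checkpoint_config=dict(interval=1),
        log_config=dict(interval=50, hooks=[dict(type='TextLoggerHook')]),
        workflow=[('train', 1)],
        total_epochs=12,
        resume_from=None,
        load_from=None,
    ))

# train_model(model, train_dataset, cfg)  # model and dataset built elsewhere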
Example 2
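The test below exercises ExampleModel, which is not shown in this listing. A minimal sketch consistent with its assertions (submodules model1 and model2, each contributing exactly one weight and one bias parameter) could look like this; the layer choice is an assumption:

import torch
import torch.nn as nn


class ExampleModel(nn.Module):
    """Hypothetical fixture: submodules `model1` and `model2`, each with a
    single weight and bias, matching the parameter counts asserted below."""

    def __init__(self):
        super().__init__()
        self.model1 = nn.Conv2d(3, 8, kernel_size=3)
        self.model2 = nn.Conv2d(3, 4, kernel_size=3)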
def test_build_optimizers():
    base_lr = 0.0001
    base_wd = 0.0002
    momentum = 0.9

    # basic config with ExampleModel
    optimizer_cfg = dict(model1=dict(type='SGD',
                                     lr=base_lr,
                                     weight_decay=base_wd,
                                     momentum=momentum),
                         model2=dict(type='SGD',
                                     lr=base_lr,
                                     weight_decay=base_wd,
                                     momentum=momentum))
    model = ExampleModel()
    optimizers = build_optimizers(model, optimizer_cfg)
    param_dict = dict(model.named_parameters())
    assert isinstance(optimizers, dict)
    for i in range(2):
        optimizer = optimizers[f'model{i+1}']
        param_groups = optimizer.param_groups[0]
        assert isinstance(optimizer, torch.optim.SGD)
        assert optimizer.defaults['lr'] == base_lr
        assert optimizer.defaults['momentum'] == momentum
        assert optimizer.defaults['weight_decay'] == base_wd
        assert len(param_groups['params']) == 2
        assert torch.equal(param_groups['params'][0],
                           param_dict[f'model{i+1}.weight'])
        assert torch.equal(param_groups['params'][1],
                           param_dict[f'model{i+1}.bias'])

    # basic config with Parallel model
    model = torch.nn.DataParallel(ExampleModel())
    optimizers = build_optimizers(model, optimizer_cfg)
    param_dict = dict(model.named_parameters())
    assert isinstance(optimizers, dict)
    for i in range(2):
        optimizer = optimizers[f'model{i+1}']
        param_groups = optimizer.param_groups[0]
        assert isinstance(optimizer, torch.optim.SGD)
        assert optimizer.defaults['lr'] == base_lr
        assert optimizer.defaults['momentum'] == momentum
        assert optimizer.defaults['weight_decay'] == base_wd
        assert len(param_groups['params']) == 2
        assert torch.equal(param_groups['params'][0],
                           param_dict[f'module.model{i+1}.weight'])
        assert torch.equal(param_groups['params'][1],
                           param_dict[f'module.model{i+1}.bias'])

    # basic config with ExampleModel (one optimizer)
    optimizer_cfg = dict(type='SGD',
                         lr=base_lr,
                         weight_decay=base_wd,
                         momentum=momentum)
    model = ExampleModel()
    optimizer = build_optimizers(model, optimizer_cfg)
    param_dict = dict(model.named_parameters())
    assert isinstance(optimizer, torch.optim.Optimizer)
    param_groups = optimizer.param_groups[0]
    assert isinstance(optimizer, torch.optim.SGD)
    assert optimizer.defaults['lr'] == base_lr
    assert optimizer.defaults['momentum'] == momentum
    assert optimizer.defaults['weight_decay'] == base_wd
    assert len(param_groups['params']) == 4
    assert torch.equal(param_groups['params'][0], param_dict['model1.weight'])
    assert torch.equal(param_groups['params'][1], param_dict['model1.bias'])
    assert torch.equal(param_groups['params'][2], param_dict['model2.weight'])
    assert torch.equal(param_groups['params'][3], param_dict['model2.bias'])

    # basic config with Parallel model (one optimizer)
    model = torch.nn.DataParallel(ExampleModel())
    optimizer = build_optimizers(model, optimizer_cfg)
    param_dict = dict(model.named_parameters())
    assert isinstance(optimizer, torch.optim.Optimizer)
    param_groups = optimizer.param_groups[0]
    assert isinstance(optimizer, torch.optim.SGD)
    assert optimizer.defaults['lr'] == base_lr
    assert optimizer.defaults['momentum'] == momentum
    assert optimizer.defaults['weight_decay'] == base_wd
    assert len(param_groups['params']) == 4
    assert torch.equal(param_groups['params'][0],
                       param_dict['module.model1.weight'])
    assert torch.equal(param_groups['params'][1],
                       param_dict['module.model1.bias'])
    assert torch.equal(param_groups['params'][2],
                       param_dict['module.model2.weight'])
    assert torch.equal(param_groups['params'][3],
                       param_dict['module.model2.bias'])
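For reference, the behavior the test relies on (a dict of per-submodule optimizers when every value in the config is itself a dict, otherwise one optimizer over the whole model) can be sketched as follows, assuming mmcv's build_optimizer; the real implementation may differ in detail.

from mmcv.runner import build_optimizer


def build_optimizers(model, cfgs):
    """Build one optimizer per top-level key, or a single optimizer."""
    # Unwrap (Distributed)DataParallel so sub-modules resolve by name.
    if hasattr(model, 'module'):
        model = model.module
    # Dict-of-dicts config -> one optimizer per named sub-module.
    if all(isinstance(v, dict) for v in cfgs.values()):
        return {
            key: build_optimizer(getattr(model, key), cfg)
            for key, cfg in cfgs.items()
        }
    # Plain config -> one optimizer over all model parameters.
    return build_optimizer(model, cfgs)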