Exemplo n.º 1
0
def create(model_name: str = None,
           model: Union[str, Model] = None,
           folder_path: str = None,
           dataset_name: str = None,
           dataset: Union[str, Dataset] = None,
           config: Config = config,
           class_dict: dict[str, type[Model]] = None,
           **kwargs) -> Model:
    r"""Create a model instance.

    Dataset and model names are resolved from the explicit arguments,
    the passed modules, or command-line flags (via :func:`get_name`);
    missing values fall back to the defaults stored in :attr:`config`.

    Args:
        model_name (str): The model name.
        model (Union[str, Model]): Model instance or model name
            (alias of `model_name`).
        folder_path (str): Folder path for the model. When ``None`` and
            :attr:`dataset` is a :class:`Dataset`, defaults to
            ``'{model_dir}/{dataset.data_type}/{dataset.name}'``.
        dataset_name (str): The dataset name.
        dataset (Union[str, Dataset]): Dataset instance or dataset name
            (alias of `dataset_name`).
        config (Config): The default parameter config.
        class_dict (dict[str, type[Model]]): Map from model name to
            model class. Defaults to ``{}``.
        **kwargs: Keyword arguments passed to the model init method.

    Returns:
        Model: The model instance.
    """
    # ``None`` sentinel instead of a shared mutable ``{}`` default.
    class_dict = {} if class_dict is None else class_dict
    dataset_name = get_name(name=dataset_name,
                            module=dataset,
                            arg_list=['-d', '--dataset'])
    model_name = get_name(name=model_name,
                          module=model,
                          arg_list=['-m', '--model'])
    if dataset_name is None:
        dataset_name = config.get_full_config()['dataset']['default_dataset']
    result = config.get_config(
        dataset_name=dataset_name)['model']._update(kwargs)
    # Fall back to the config-provided default model name.
    model_name = model_name if model_name is not None else result[
        'default_model']

    ModelType: type[Model] = class_dict[model_name]
    if folder_path is None and isinstance(dataset, Dataset):
        folder_path = os.path.join(result['model_dir'], dataset.data_type,
                                   dataset.name)
    return ModelType(name=model_name,
                     dataset=dataset,
                     folder_path=folder_path,
                     **result)
Exemplo n.º 2
0
def create(dataset_name: str = None, dataset: Dataset = None, model: Model = None,
           tensorboard: bool = None, config: Config = config, **kwargs) -> Trainer:
    r"""Create a :class:`Trainer` bundling optimizer, scheduler and writer.

    Keys of the merged trainer config are routed to
    ``model.define_optimizer``, ``model._train`` or ``SummaryWriter``
    based on each callable's parameter names; unrecognized keys are
    silently skipped.

    NOTE(review): the original return annotation
    ``tuple[Optimizer, _LRScheduler, dict]`` did not match the actual
    ``Trainer`` return value below; fixed to ``Trainer``.
    """
    assert isinstance(model, Model)
    dataset_name = get_name(name=dataset_name, module=dataset, arg_list=['-d', '--dataset'])
    result = config.get_config(dataset_name=dataset_name)['trainer']._update(kwargs)

    # Parameter names of each consumer, used to route config entries.
    optim_keys = model.define_optimizer.__code__.co_varnames
    train_keys = model._train.__code__.co_varnames
    writer_keys = SummaryWriter.__init__.__code__.co_varnames   # log_dir, flush_secs, ...
    optim_args = {}
    train_args = {}
    writer_args = {}
    for key, value in result.items():
        # First match wins: optimizer args > train args > writer args.
        if key in optim_keys:
            _dict = optim_args
        elif key in train_keys:
            _dict = train_args
        elif key in writer_keys:
            _dict = writer_args
        else:
            continue
        _dict[key] = value
    optimizer, lr_scheduler = model.define_optimizer(**optim_args)
    writer = SummaryWriter(**writer_args) if tensorboard else None
    return Trainer(optim_args=optim_args, train_args=train_args,
                   optimizer=optimizer, lr_scheduler=lr_scheduler,
                   writer=writer)
Exemplo n.º 3
0
def add_argument(parser: argparse.ArgumentParser,
                 model_name: str = None,
                 model: Union[str, Model] = None,
                 config: Config = config,
                 class_dict: dict[str, type[Model]] = None):
    r"""Add model-specific arguments to ``parser``.

    Resolves the model name (from arguments, module or command line,
    falling back to the config default), creates an argument group
    titled with the model name, and delegates to the model class's
    own ``add_argument``.

    Fixes vs. original: the original overwrote ``model_name`` with the
    result of ``get_model_class`` *before* using it as the group
    description, so the description showed the class name and
    ``get_model_class`` was called twice. Now the description shows the
    resolved model name (matching the sibling ``add_argument``
    implementation) and the class is looked up once. Also replaces the
    shared mutable ``{}`` default with a ``None`` sentinel.
    """
    class_dict = {} if class_dict is None else class_dict
    dataset_name = get_name(arg_list=['-d', '--dataset'])
    if dataset_name is None:
        dataset_name = config.get_full_config()['dataset']['default_dataset']
    model_name = get_name(name=model_name,
                          module=model,
                          arg_list=['-m', '--model'])
    if model_name is None:
        model_name = config.get_config(
            dataset_name=dataset_name)['model']['default_model']

    group = parser.add_argument_group('{yellow}model{reset}'.format(**ansi),
                                      description=model_name)
    model_class_name = get_model_class(model_name, class_dict=class_dict)
    try:
        ModelType = class_dict[model_class_name]
    except KeyError as e:
        print(f'{model_class_name} not in \n{list(class_dict.keys())}')
        raise e
    return ModelType.add_argument(group)
Exemplo n.º 4
0
def create(dataset_name: str = None,
           dataset: Dataset = None,
           model: Model = None,
           config: Config = config,
           **kwargs) -> Trainer:
    r"""Create a :class:`Trainer` bundling optimizer and scheduler.

    Config keys are routed to ``model.define_optimizer`` or
    ``model._train`` based on each callable's parameter names;
    unrecognized keys are silently skipped.

    NOTE(review): the original return annotation
    ``tuple[Optimizer, _LRScheduler, dict]`` did not match the actual
    ``Trainer`` return value below; fixed to ``Trainer``.
    """
    assert isinstance(model, Model)
    dataset_name = get_name(name=dataset_name,
                            module=dataset,
                            arg_list=['-d', '--dataset'])
    result = config.get_config(
        dataset_name=dataset_name)['trainer']._update(kwargs)

    # Parameter names of each consumer, used to route config entries.
    func_keys = model.define_optimizer.__code__.co_varnames
    train_keys = model._train.__code__.co_varnames
    optim_args = {}
    train_args = {}
    for key, value in result.items():
        # First match wins: optimizer args take precedence over train args.
        if key in func_keys:
            _dict = optim_args
        elif key in train_keys:
            _dict = train_args
        else:
            continue  # unrecognized keys are dropped (was: raise KeyError(key))
        _dict[key] = value
    optimizer, lr_scheduler = model.define_optimizer(**optim_args)
    return Trainer(optim_args=optim_args,
                   train_args=train_args,
                   optimizer=optimizer,
                   lr_scheduler=lr_scheduler)
Exemplo n.º 5
0
def create(model_name: str = None,
           model: Union[str, Model] = None,
           dataset_name: str = None,
           dataset: Union[str, Dataset] = None,
           folder_path: str = None,
           config: Config = config,
           class_dict: dict[str, type[Model]] = None,
           **kwargs) -> Model:
    r"""Create a model instance.

    Resolves dataset and model names from arguments, modules or
    command-line flags, validates the model name against the available
    models, then instantiates the matching class from :attr:`class_dict`.

    Args:
        model_name (str): The model name.
        model (Union[str, Model]): Model instance or model name
            (alias of `model_name`).
        dataset_name (str): The dataset name.
        dataset (Union[str, Dataset]): Dataset instance or dataset name
            (alias of `dataset_name`).
        folder_path (str): Folder path for the model. When ``None`` and
            :attr:`dataset` is a :class:`Dataset`, defaults to
            ``'{model_dir}/{dataset.data_type}/{dataset.name}'``.
        config (Config): The default parameter config.
        class_dict (dict[str, type[Model]]): Map from model class name
            to model class. Defaults to ``{}``.
        **kwargs: Keyword arguments passed to the model init method.

    Returns:
        Model: The model instance.
    """
    # ``None`` sentinel instead of a shared mutable ``{}`` default.
    class_dict = {} if class_dict is None else class_dict
    dataset_name = get_name(name=dataset_name,
                            module=dataset,
                            arg_list=['-d', '--dataset'])
    model_name = get_name(name=model_name,
                          module=model,
                          arg_list=['-m', '--model'])
    if dataset_name is None:
        dataset_name = config.get_full_config()['dataset']['default_dataset']
    if model_name is None:
        model_name = config.get_config(
            dataset_name=dataset_name)['model']['default_model']
    result = config.get_config(
        dataset_name=dataset_name)['model'].update(kwargs)
    # NOTE(review): likely dead — model_name is already resolved above
    # unless the config's 'default_model' itself is None. Kept for safety.
    model_name = model_name if model_name is not None else result[
        'default_model']

    # Validate against the flat, sorted list of all registered model names.
    name_list = sorted(
        name
        for sub_list in get_available_models(class_dict=class_dict).values()
        for name in sub_list)
    assert model_name in name_list, f'{model_name} not in \n{name_list}'
    model_class_name = get_model_class(model_name, class_dict=class_dict)
    try:
        ModelType = class_dict[model_class_name]
    except KeyError as e:
        print(f'{model_class_name} not in \n{list(class_dict.keys())}')
        raise e
    if folder_path is None and isinstance(dataset, Dataset):
        folder_path = os.path.join(result['model_dir'], dataset.data_type,
                                   dataset.name)
    return ModelType(name=model_name,
                     dataset=dataset,
                     folder_path=folder_path,
                     **result)
Exemplo n.º 6
0
def create(dataset_name: str = None, dataset: str = None, folder_path: str = None,
           config: Config = config, class_dict: dict[str, type[Dataset]] = None, **kwargs) -> Dataset:
    r"""Create a dataset instance.

    Args:
        dataset_name (str): The dataset name.
        dataset (str): The alias of `dataset_name`.
        folder_path (str): Dataset folder path. Defaults to
            ``'{data_dir}/{data_type}/{name}'``.
        config (Config): The default parameter config.
        class_dict (dict[str, type[Dataset]]): Map from dataset name to
            dataset class. Defaults to ``{}``.
        **kwargs: Keyword arguments passed to dataset init method.

    Returns:
        Dataset: The dataset instance.
    """
    # ``None`` sentinel instead of a shared mutable ``{}`` default.
    class_dict = {} if class_dict is None else class_dict
    dataset_name = get_name(name=dataset_name, module=dataset, arg_list=['-d', '--dataset'])
    if dataset_name is None:
        dataset_name = config.get_full_config()['dataset']['default_dataset']
    result = config.get_config(dataset_name=dataset_name)['dataset']._update(kwargs)

    DatasetType = class_dict[dataset_name]
    if folder_path is None:
        folder_path = os.path.join(result['data_dir'],
                                   DatasetType.data_type, DatasetType.name)
    return DatasetType(folder_path=folder_path, **result)
Exemplo n.º 7
0
def create(config_path: str = None,
           dataset_name: str = None,
           dataset: str = None,
           seed: int = None,
           benchmark: bool = None,
           config: Config = config,
           cache_threshold: float = None,
           verbose: int = None,
           color: bool = None,
           tqdm: bool = None,
           **kwargs) -> Env:
    r"""Load :attr:`env` values from config and command line, seed the
    RNGs and resolve the compute device.

    Returns:
        Env: The global :attr:`env` mapping, updated in place.
    """
    # Env-level options collected so they can be merged into the
    # 'env' config section below.
    other_kwargs = {
        'cache_threshold': cache_threshold,
        'verbose': verbose,
        'color': color,
        'tqdm': tqdm
    }
    config.update_cmd(config_path)  # merge the command-line config file, if any
    dataset_name = get_name(name=dataset_name,
                            module=dataset,
                            arg_list=['-d', '--dataset'])
    dataset_name = dataset_name if dataset_name is not None else config.get_full_config(
    )['dataset']['default_dataset']
    result = config.get_config(
        dataset_name=dataset_name)['env']._update(other_kwargs)
    env.update(config_path=config_path, **result)
    # Fall back to a previously stored seed, then seed every RNG source.
    if seed is None and 'seed' in env.keys():
        seed = env['seed']
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

    num_gpus: int = torch.cuda.device_count()
    device = result['device']
    if device == 'none':
        # Explicit string 'none' keeps the device unset in env.
        device = None
    else:
        if device is None or device == 'auto':
            device = 'cuda' if num_gpus else 'cpu'
        if isinstance(device, str):
            device = torch.device(device)
        if device.type == 'cpu':
            num_gpus = 0
        # A concrete device index pins execution to a single GPU.
        if device.index is not None and torch.cuda.is_available():
            num_gpus = 1
    if benchmark is None and 'benchmark' in env.keys():
        benchmark = env['benchmark']
    if benchmark:
        torch.backends.cudnn.benchmark = benchmark
    env.update(seed=seed,
               device=device,
               benchmark=benchmark,
               num_gpus=num_gpus)
    return env
Exemplo n.º 8
0
def create(attack_name: str = None,
           attack: Union[str, Attack] = None,
           folder_path: str = None,
           dataset_name: str = None,
           dataset: Union[str, Dataset] = None,
           model_name: str = None,
           model: Union[str, Model] = None,
           config: Config = config,
           class_dict: dict[str, type[Attack]] = None,
           **kwargs) -> Attack:
    r"""Create an attack instance.

    The attack config is built by layering the dataset-specific general
    'attack' section, the attack-specific section, and :attr:`kwargs`
    (later layers win).

    Args:
        attack_name (str): The attack name.
        attack (Union[str, Attack]): Attack instance or attack name
            (alias of `attack_name`).
        folder_path (str): Attack folder path. Defaults to
            ``'{attack_dir}/{data_type}/{dataset}/{model}/{attack}'``
            (components included when available).
        dataset_name (str): The dataset name.
        dataset (Union[str, Dataset]): Dataset instance or dataset name
            (alias of `dataset_name`).
        model_name (str): The model name.
        model (Union[str, Model]): Model instance or model name
            (alias of `model_name`).
        config (Config): The default parameter config.
        class_dict (dict[str, type[Attack]]): Map from attack name to
            attack class. Defaults to ``{}``.
        **kwargs: Keyword arguments passed to the attack init method.

    Returns:
        Attack: The attack instance.
    """
    # ``None`` sentinel instead of a shared mutable ``{}`` default.
    class_dict = {} if class_dict is None else class_dict
    dataset_name = get_name(name=dataset_name,
                            module=dataset,
                            arg_list=['-d', '--dataset'])
    model_name = get_name(name=model_name,
                          module=model,
                          arg_list=['-m', '--model'])
    attack_name = get_name(name=attack_name,
                           module=attack,
                           arg_list=['--attack'])
    if dataset_name is None:
        dataset_name = config.get_full_config()['dataset']['default_dataset']
    general_config = config.get_config(dataset_name=dataset_name)['attack']
    specific_config = config.get_config(dataset_name=dataset_name)[attack_name]
    result = general_config.update(specific_config).update(kwargs)
    try:
        AttackType = class_dict[attack_name]
    except KeyError as e:
        print(f'{attack_name} not in \n{list(class_dict.keys())}')
        raise e
    if folder_path is None:
        folder_path = result['attack_dir']
        if isinstance(dataset, Dataset):
            folder_path = os.path.join(folder_path, dataset.data_type,
                                       dataset.name)
        if model_name is not None:
            folder_path = os.path.join(folder_path, model_name)
        folder_path = os.path.join(folder_path, AttackType.name)
    return AttackType(name=attack_name,
                      dataset=dataset,
                      model=model,
                      folder_path=folder_path,
                      **result)
Exemplo n.º 9
0
def create(defense_name: str = None,
           defense: Union[str, Defense] = None,
           folder_path: str = None,
           dataset_name: str = None,
           dataset: Union[str, Dataset] = None,
           model_name: str = None,
           model: Union[str, Model] = None,
           config: Config = config,
           class_dict: dict[str, type[Defense]] = None,
           **kwargs) -> Defense:
    r"""Create a defense instance.

    The defense config is built by layering the dataset-specific
    general 'defense' section, the defense-specific section, and
    :attr:`kwargs` (later layers win).

    Args:
        defense_name (str): The defense name.
        defense (Union[str, Defense]): Defense instance or defense name
            (alias of `defense_name`).
        folder_path (str): Defense folder path. Defaults to
            ``'{defense_dir}/{data_type}/{dataset}/{model}/{defense}'``
            (components included when available).
        dataset_name (str): The dataset name.
        dataset (Union[str, Dataset]): Dataset instance or dataset name
            (alias of `dataset_name`).
        model_name (str): The model name.
        model (Union[str, Model]): Model instance or model name
            (alias of `model_name`).
        config (Config): The default parameter config.
        class_dict (dict[str, type[Defense]]): Map from defense name to
            defense class. Defaults to ``{}``.
        **kwargs: Keyword arguments passed to the defense init method.

    Returns:
        Defense: The defense instance.
    """
    # ``None`` sentinel instead of a shared mutable ``{}`` default.
    class_dict = {} if class_dict is None else class_dict
    dataset_name = get_name(name=dataset_name,
                            module=dataset,
                            arg_list=['-d', '--dataset'])
    model_name = get_name(name=model_name,
                          module=model,
                          arg_list=['-m', '--model'])
    defense_name = get_name(name=defense_name,
                            module=defense,
                            arg_list=['--defense'])
    if dataset_name is None:
        dataset_name = config.get_full_config()['dataset']['default_dataset']
    general_config = config.get_config(dataset_name=dataset_name)['defense']
    specific_config = config.get_config(
        dataset_name=dataset_name)[defense_name]
    result = general_config._update(specific_config)._update(kwargs)

    DefenseType: type[Defense] = class_dict[defense_name]
    if folder_path is None:
        folder_path = result['defense_dir']
        if isinstance(dataset, Dataset):
            folder_path = os.path.join(folder_path, dataset.data_type,
                                       dataset.name)
        if model_name is not None:
            folder_path = os.path.join(folder_path, model_name)
        folder_path = os.path.join(folder_path, DefenseType.name)
    return DefenseType(name=defense_name,
                       dataset=dataset,
                       model=model,
                       folder_path=folder_path,
                       **result)
Exemplo n.º 10
0
def add_argument(parser: argparse.ArgumentParser, model_name: str = None, model: Union[str, Model] = None,
                 config: Config = config, class_dict: dict[str, type[Model]] = None) -> argparse._ArgumentGroup:
    r"""Add model-specific arguments to ``parser``.

    Resolves the model name (from arguments, module or command line,
    falling back to the config default), creates an argument group
    titled with the model name, and delegates to the model class's own
    ``add_argument``.

    Raises:
        KeyError: If the resolved name is missing from
            :attr:`class_dict` (the available names are printed first,
            matching the sibling implementation).
    """
    dataset_name = get_name(arg_list=['-d', '--dataset'])
    if dataset_name is None:
        dataset_name = config.get_full_config()['dataset']['default_dataset']
    model_name = get_name(name=model_name, module=model, arg_list=['-m', '--model'])
    if model_name is None:
        model_name = config.get_config(dataset_name=dataset_name)['model']['default_model']

    group = parser.add_argument_group('{yellow}model{reset}'.format(**ansi), description=model_name)
    try:
        ModelType = class_dict[model_name]
    except KeyError:
        print(f'{model_name} not in \n{list(class_dict.keys())}')
        raise
    return ModelType.add_argument(group)
Exemplo n.º 11
0
def create(dataset_name: str = None,
           dataset: Dataset = None,
           model: Model = None,
           ClassType: type[Trainer] = Trainer,
           tensorboard: bool = None,
           config: Config = config,
           **kwargs):
    r"""Build a trainer instance of :attr:`ClassType`.

    Merged trainer-config keys are routed to
    ``model.define_optimizer`` or ``model._train`` based on each
    callable's parameter names (optimizer keys take precedence);
    remaining keys are ignored. A ``SummaryWriter`` is attached only
    when :attr:`tensorboard` is truthy.
    """
    assert isinstance(model, Model)
    dataset_name = get_name(name=dataset_name,
                            module=dataset,
                            arg_list=['-d', '--dataset'])
    result = config.get_config(
        dataset_name=dataset_name)['trainer'].update(kwargs)

    # Route config entries by the parameter names of each consumer.
    optim_keys = model.define_optimizer.__code__.co_varnames
    train_keys = model._train.__code__.co_varnames
    optim_args: dict[str, Any] = {
        k: v for k, v in result.items() if k in optim_keys}
    train_args: dict[str, Any] = {
        k: v for k, v in result.items()
        if k in train_keys and k not in optim_keys}
    optimizer, lr_scheduler = model.define_optimizer(T_max=result['epoch'],
                                                     **optim_args)

    writer = None
    writer_args: dict[str, Any] = {}
    if tensorboard:
        from torch.utils.tensorboard import SummaryWriter
        # SummaryWriter accepts log_dir, flush_secs, ...
        writer_keys = SummaryWriter.__init__.__code__.co_varnames
        writer_args = {k: v for k, v in result.items() if k in writer_keys}
        writer = SummaryWriter(**writer_args)
    return ClassType(optim_args=optim_args,
                     train_args=train_args,
                     writer_args=writer_args,
                     optimizer=optimizer,
                     lr_scheduler=lr_scheduler,
                     writer=writer)
Exemplo n.º 12
0
def create(dataset_name: str = None, dataset: str = None,
           config: Config = config,
           class_dict: dict[str, type[Dataset]] = None,
           **kwargs) -> Dataset:
    r"""
    | Create a dataset instance.
    | For arguments not included in :attr:`kwargs`,
      use the default values in :attr:`config`.
    | The default value of :attr:`folder_path` is
      ``'{data_dir}/{data_type}/{name}'``.
    | For dataset implementation, see :class:`Dataset`.

    Args:
        dataset_name (str): The dataset name.
        dataset (str): The alias of `dataset_name`.
        config (Config): The default parameter config.
        class_dict (dict[str, type[Dataset]]):
            Map from dataset name to dataset class.
            Defaults to ``{}``.
        **kwargs: Keyword arguments
            passed to dataset init method.

    Returns:
        Dataset: Dataset instance.
    """
    # ``None`` sentinel instead of a shared mutable ``{}`` default.
    class_dict = {} if class_dict is None else class_dict
    dataset_name = get_name(
        name=dataset_name, module=dataset, arg_list=['-d', '--dataset'])
    dataset_name = dataset_name if dataset_name is not None \
        else config.full_config['dataset']['default_dataset']
    result = config.get_config(dataset_name=dataset_name)[
        'dataset'].update(kwargs)
    try:
        DatasetType = class_dict[dataset_name]
    except KeyError:
        print(f'{dataset_name} not in \n{list(class_dict.keys())}')
        raise
    if 'folder_path' not in result.keys():
        result['folder_path'] = os.path.join(result['data_dir'],
                                             DatasetType.data_type,
                                             DatasetType.name)
    return DatasetType(**result)
Exemplo n.º 13
0
def create(dataset_name: str = None,
           dataset: str = None,
           folder_path: str = None,
           config: Config = config,
           class_dict: dict[str, type[Dataset]] = None,
           **kwargs):
    r"""Create a dataset instance.

    Args:
        dataset_name (str): The dataset name.
        dataset (str): The alias of `dataset_name`.
        folder_path (str): Dataset folder path. Defaults to
            ``'{data_dir}/{data_type}/{name}'``.
        config (Config): The default parameter config.
        class_dict (dict[str, type[Dataset]]): Map from dataset name to
            dataset class. Defaults to ``{}``.
        **kwargs: Keyword arguments passed to dataset init method.
    """
    # ``None`` sentinel instead of a shared mutable ``{}`` default.
    class_dict = {} if class_dict is None else class_dict
    dataset_name = get_name(name=dataset_name,
                            module=dataset,
                            arg_list=['-d', '--dataset'])
    dataset_name = dataset_name if dataset_name is not None else config.get_full_config(
    )['dataset']['default_dataset']
    result = config.get_config(
        dataset_name=dataset_name)['dataset'].update(kwargs)
    try:
        DatasetType = class_dict[dataset_name]
    except KeyError as e:
        print(f'{dataset_name} not in \n{list(class_dict.keys())}')
        raise e
    folder_path = folder_path if folder_path is not None \
        else os.path.join(result['data_dir'], DatasetType.data_type, DatasetType.name)
    return DatasetType(folder_path=folder_path, **result)
Exemplo n.º 14
0
def create(attack_name: str = None,
           attack: str | Attack = None,
           dataset_name: str = None,
           dataset: str | Dataset = None,
           model_name: str = None,
           model: str | Model = None,
           config: Config = config,
           class_dict: dict[str, type[Attack]] = None,
           **kwargs) -> Attack:
    r"""
    | Create an attack instance.
    | For arguments not included in :attr:`kwargs`,
      use the default values in :attr:`config`.
    | The default value of :attr:`folder_path` is
      ``'{attack_dir}/{dataset.data_type}/{dataset.name}/{model.name}/{attack.name}'``.
    | For attack implementation, see :class:`Attack`.

    Args:
        attack_name (str): The attack name.
        attack (str | Attack): The attack instance or attack name
            (as the alias of `attack_name`).
        dataset_name (str): The dataset name.
        dataset (str | Dataset):
            Dataset instance or dataset name
            (as the alias of `dataset_name`).
        model_name (str): The model name.
        model (str | Model): The model instance or model name
            (as the alias of `model_name`).
        config (Config): The default parameter config.
        class_dict (dict[str, type[Attack]]):
            Map from attack name to attack class.
            Defaults to ``{}``.
        **kwargs: The keyword arguments
            passed to attack init method.

    Returns:
        Attack: The attack instance.
    """
    # ``None`` sentinel instead of a shared mutable ``{}`` default.
    class_dict = {} if class_dict is None else class_dict
    dataset_name = get_name(name=dataset_name,
                            module=dataset,
                            arg_list=['-d', '--dataset'])
    model_name = get_name(name=model_name,
                          module=model,
                          arg_list=['-m', '--model'])
    attack_name = get_name(name=attack_name,
                           module=attack,
                           arg_list=['--attack'])
    if dataset_name is None:
        dataset_name = config.full_config['dataset']['default_dataset']
    general_config = config.get_config(dataset_name=dataset_name)['attack']
    specific_config = config.get_config(dataset_name=dataset_name)[attack_name]
    result = general_config.update(specific_config).update(kwargs)
    try:
        AttackType = class_dict[attack_name]
    except KeyError:
        print(f'{attack_name} not in \n{list(class_dict.keys())}')
        raise
    if 'folder_path' not in result.keys():
        folder_path = result['attack_dir']
        if isinstance(dataset, Dataset):
            folder_path = os.path.join(folder_path, dataset.data_type,
                                       dataset.name)
        if model_name is not None:
            folder_path = os.path.join(folder_path, model_name)
        folder_path = os.path.join(folder_path, AttackType.name)
        result['folder_path'] = folder_path
    return AttackType(name=attack_name, dataset=dataset, model=model, **result)
Exemplo n.º 15
0
def create(defense_name: None | str = None,
           defense: None | str | Defense = None,
           folder_path: None | str = None,
           dataset_name: str = None,
           dataset: None | str | Dataset = None,
           model_name: str = None,
           model: None | str | Model = None,
           config: Config = config,
           class_dict: dict[str, type[Defense]] = None,
           **kwargs):
    r"""
    | Create a defense instance.
    | For arguments not included in :attr:`kwargs`,
      use the default values in :attr:`config`.
    | The default value of :attr:`folder_path` is
      ``'{defense_dir}/{dataset.data_type}/{dataset.name}/{model.name}/{defense.name}'``.
    | For defense implementation, see :class:`Defense`.

    Args:
        defense_name (str): The defense name.
        defense (str | Defense): The defense instance or defense name
            (as the alias of `defense_name`).
        dataset_name (str): The dataset name.
        dataset (str | trojanzoo.datasets.Dataset):
            Dataset instance or dataset name
            (as the alias of `dataset_name`).
        model_name (str): The model name.
        model (str | Model): The model instance or model name
            (as the alias of `model_name`).
        config (Config): The default parameter config.
        class_dict (dict[str, type[Defense]]):
            Map from defense name to defense class.
            Defaults to ``{}``.
        **kwargs: The keyword arguments
            passed to defense init method.

    Returns:
        Defense: The defense instance.
    """
    # ``None`` sentinel instead of a shared mutable ``{}`` default.
    class_dict = {} if class_dict is None else class_dict
    dataset_name = get_name(name=dataset_name,
                            module=dataset,
                            arg_list=['-d', '--dataset'])
    model_name = get_name(name=model_name,
                          module=model,
                          arg_list=['-m', '--model'])
    defense_name = get_name(name=defense_name,
                            module=defense,
                            arg_list=['--defense'])
    if dataset_name is None:
        dataset_name = config.full_config['dataset']['default_dataset']
    general_config = config.get_config(dataset_name=dataset_name)['defense']
    specific_config = config.get_config(
        dataset_name=dataset_name)[defense_name]
    # Layer general section, defense-specific section, then kwargs.
    result = general_config.update(specific_config).update(kwargs)
    try:
        DefenseType = class_dict[defense_name]
    except KeyError:
        print(f'{defense_name} not in \n{list(class_dict.keys())}')
        raise
    if 'folder_path' not in result.keys():
        folder_path = result['defense_dir']
        if isinstance(dataset, Dataset):
            folder_path = os.path.join(folder_path, dataset.data_type,
                                       dataset.name)
        if model_name is not None:
            folder_path = os.path.join(folder_path, model_name)
        folder_path = os.path.join(folder_path, DefenseType.name)
        result['folder_path'] = folder_path
    return DefenseType(name=defense_name,
                       dataset=dataset,
                       model=model,
                       **result)
Exemplo n.º 16
0
def create(cmd_config_path: str = None, dataset_name: str = None, dataset: str = None,
           seed: int = None, data_seed: int = None, cudnn_benchmark: bool = None,
           config: Config = config,
           cache_threshold: float = None, verbose: int = 0,
           color: bool = None, device: str | int | torch.device = None, tqdm: bool = None,
           **kwargs) -> Env:
    r"""
    | Load :attr:`env` values from config and command line,
      seed the RNGs and resolve the compute device.

    Args:
        cmd_config_path (str): Path of a config file supplied on the
            command line, stored on :attr:`config`.
        dataset_name (str): The dataset name.
        dataset (str | trojanzoo.datasets.Dataset):
            Dataset instance or dataset name
            (as the alias of `dataset_name`).
        seed (int): Seed for ``random``, ``numpy`` and ``torch`` RNGs.
            Falls back to ``env['seed']`` when available.
        data_seed (int): Extra seed merged into the 'env' config
            section (semantics defined by consumers — not visible here).
        cudnn_benchmark (bool): Whether to enable
            ``torch.backends.cudnn.benchmark``.
            Falls back to ``env['cudnn_benchmark']`` when available.
        config (Config): The default parameter config.
        cache_threshold (float): Merged into the 'env' config section.
        verbose (int): Verbosity level merged into the 'env' section.
            Defaults to ``0``.
        color (bool): Whether to enable ANSI colors (via ``ansi.switch``).
        device (str | int | torch.device): Requested compute device;
            ``None``/``'auto'`` picks CUDA when available, else CPU.
        tqdm (bool): Merged into the 'env' config section.
        **kwargs: Additional keyword arguments (currently unused here).

    Returns:
        Env: The global :attr:`env` instance, updated in place.
    """
    # Env-level options collected so they can be merged into the
    # 'env' config section below.
    other_kwargs = {'data_seed': data_seed, 'cache_threshold': cache_threshold,
                    'verbose': verbose, 'color': color, 'device': device, 'tqdm': tqdm}
    config.cmd_config_path = cmd_config_path
    dataset_name = get_name(
        name=dataset_name, module=dataset, arg_list=['-d', '--dataset'])
    dataset_name = dataset_name if dataset_name is not None \
        else config.full_config['dataset']['default_dataset']
    result = config.get_config(dataset_name=dataset_name)[
        'env'].update(other_kwargs)
    env.clear().update(**result)
    ansi.switch(env['color'])
    # Fall back to a previously stored seed, then seed every RNG source.
    if seed is None and 'seed' in env.keys():
        seed = env['seed']
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

    num_gpus: int = torch.cuda.device_count()
    device: str | int | torch.device = result['device']
    if device is None:
        device = 'auto'
    match device:
        case torch.device():
            pass  # already a concrete device
        case 'auto':
            device = torch.device('cuda' if num_gpus else 'cpu')
        case 'gpu':
            device = torch.device('cuda')
        case _:
            # Any other str/int is handed to torch.device directly.
            device = torch.device(device)
    if device.type == 'cpu':
        num_gpus = 0
    # A concrete device index pins execution to a single GPU.
    if device.index is not None and torch.cuda.is_available():
        num_gpus = 1
    if cudnn_benchmark is None and 'cudnn_benchmark' in env.keys():
        cudnn_benchmark = env['cudnn_benchmark']
    if cudnn_benchmark:
        torch.backends.cudnn.benchmark = cudnn_benchmark
    env.update(seed=seed, device=device,
               cudnn_benchmark=cudnn_benchmark, num_gpus=num_gpus)

    env['world_size'] = 1   # TODO: distributed training not wired up yet
    return env
Exemplo n.º 17
0
def create(dataset_name: None | str = None,
           dataset: None | str | Dataset = None,
           model: None | Model = None,
           model_ema: None | bool = False,
           pre_conditioner: None | str = None,
           tensorboard: None | bool = None,
           ClassType: type[Trainer] = Trainer,
           config: Config = config, **kwargs):
    r"""
    | Create a trainer instance.
    | For arguments not included in :attr:`kwargs`,
      use the default values in :attr:`config`.
    | For trainer implementation, see :class:`Trainer`.

    Args:
        dataset_name (str): The dataset name.
        dataset (str | trojanzoo.datasets.Dataset):
            Dataset instance
            (required for :attr:`model_ema`)
            or dataset name
            (as the alias of `dataset_name`).
        model (trojanzoo.models.Model): Model instance.
        model_ema (bool): Whether to use
            :class:`~trojanzoo.utils.model.ExponentialMovingAverage`.
            Defaults to ``False``.
        pre_conditioner (str): Choose from

            * ``None``
            * ``'kfac'``: :class:`~trojanzoo.utils.fim.KFAC`
            * ``'ekfac'``: :class:`~trojanzoo.utils.fim.EKFAC`

            Defaults to ``None``.
        tensorboard (bool): Whether to use
            :any:`torch.utils.tensorboard.writer.SummaryWriter`.
            Defaults to ``False``.
        ClassType (type[Trainer]): The trainer class type.
            Defaults to :class:`Trainer`.
        config (Config): The default parameter config.
        **kwargs: The keyword arguments in keys of
            ``['optim_args', 'train_args', 'writer_args']``.

    Returns:
        Trainer: The trainer instance.
    """
    assert isinstance(model, Model)
    dataset_name = get_name(name=dataset_name, module=dataset,
                            arg_list=['-d', '--dataset'])
    result = config.get_config(dataset_name=dataset_name
                               )['trainer'].update(kwargs)

    # Route config entries by the parameter names of each consumer;
    # optimizer args take precedence, and 'verbose' is never a train arg.
    optim_keys = model.define_optimizer.__code__.co_varnames
    train_keys = model._train.__code__.co_varnames
    optim_args: dict[str, Any] = {}
    train_args: dict[str, Any] = {}
    for key, value in result.items():
        if key in optim_keys:
            optim_args[key] = value
        elif key in train_keys and key != 'verbose':
            train_args[key] = value
    # Always forwarded, regardless of signature matching above.
    train_args['epochs'] = result['epochs']
    train_args['lr_warmup_epochs'] = result['lr_warmup_epochs']

    optimizer, lr_scheduler = model.define_optimizer(**optim_args)

    # Restrict the pre-conditioner to the sub-module being optimized.
    # NOTE(review): assumes 'parameters' is always present in optim_args
    # (i.e. a parameter of define_optimizer supplied by config) — confirm.
    module = model._model
    match optim_args['parameters']:
        case 'features':
            module = module.features
        case 'classifier':
            module = module.classifier

    # https://github.com/pytorch/vision/blob/main/references/classification/train.py
    model_ema_module = None
    if model_ema:
        # Decay adjustment that aims to keep the decay independent from other hyper-parameters originally proposed at:
        # https://github.com/facebookresearch/pycls/blob/f8cd9627/pycls/core/net.py#L123
        #
        # total_ema_updates = (Dataset_size / n_GPUs) * epochs / (batch_size_per_gpu * EMA_steps)
        # We consider constant = Dataset_size for a given dataset/setup and ommit it. Thus:
        # adjust = 1 / total_ema_updates ~= n_GPUs * batch_size_per_gpu * EMA_steps / epochs
        adjust = env['world_size'] * dataset.batch_size * \
            result['model_ema_steps'] / result['epochs']
        alpha = 1.0 - result['model_ema_decay']
        alpha = min(1.0, alpha * adjust)
        model_ema_module = ExponentialMovingAverage(
            model._model, decay=1.0 - alpha)

    match pre_conditioner:
        case 'kfac':
            kfac_optimizer = KFAC(module)
        case 'ekfac':
            kfac_optimizer = EKFAC(module)
        case _:
            # Any other value (including None) disables pre-conditioning.
            kfac_optimizer = None

    writer = None
    writer_args: dict[str, Any] = {}
    if tensorboard:
        from torch.utils.tensorboard import SummaryWriter
        # log_dir, flush_secs, ...
        writer_keys = SummaryWriter.__init__.__code__.co_varnames
        for key, value in result.items():
            if key in writer_keys:
                writer_args[key] = value
        writer = SummaryWriter(**writer_args)
    return ClassType(optim_args=optim_args, train_args=train_args,
                     writer_args=writer_args,
                     optimizer=optimizer, lr_scheduler=lr_scheduler,
                     model_ema=model_ema_module,
                     pre_conditioner=kfac_optimizer, writer=writer)