Code example #1
def config():
    model_class = MaskEstimatorModel
    trainer_opts = deflatten({
        'model.factory': model_class,
        'optimizer.factory': Adam,
        'stop_trigger': (int(1e5), 'iteration'),
        'summary_trigger': (500, 'iteration'),
        'checkpoint_trigger': (500, 'iteration'),
        'storage_dir': None,
    })
    provider_opts = deflatten({
        'factory': SequenceProvider,
        'database.factory': Chime3,
        'audio_keys': [OBSERVATION, NOISE_IMAGE, SPEECH_IMAGE],
        'transform.factory': MaskTransformer,
        'transform.stft': dict(factory=STFT, shift=256, size=1024),
    })
    trainer_opts['model']['transformer'] = provider_opts['transform']

    storage_dir = None
    add_name = None
    if storage_dir is None:
        ex_name = get_experiment_name(trainer_opts['model'])
        if add_name is not None:
            ex_name += f'_{add_name}'
        observer = sacred.observers.FileStorageObserver.create(
            str(model_dir / ex_name))
        storage_dir = observer.basedir
    else:
        sacred.observers.FileStorageObserver.create(storage_dir)

    trainer_opts['storage_dir'] = storage_dir

    if (Path(storage_dir) / 'init.json').exists():
        trainer_opts, provider_opts = compare_configs(storage_dir,
                                                      trainer_opts,
                                                      provider_opts)

    Trainer.get_config(trainer_opts)
    Configurable.get_config(provider_opts)
    validate_checkpoint = 'ckpt_latest.pth'
    validation_kwargs = dict(
        metric='loss',
        maximize=False,
        max_checkpoints=1,
        # number of examples taken from the validation iterator
        validation_length=1000,
    )
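
The two deflatten(...) calls above expand dot-separated keys into nested dicts before Trainer.get_config completes them (note the nested access trainer_opts['model']['transformer'] further up). A minimal sketch of that expansion, assuming the deflatten imported here is paderbox's (fgnt/paderbox) nested-dict helper:

# Sketch only: `deflatten` is assumed to be paderbox's helper that
# expands dot-separated keys into nested dicts (default separator '.').
from paderbox.utils.nested import deflatten

opts = deflatten({
    'optimizer.factory': 'Adam',
    'optimizer.lr': 5e-4,
    'stop_trigger': (int(1e5), 'iteration'),
})
assert opts == {
    'optimizer': {'factory': 'Adam', 'lr': 5e-4},
    'stop_trigger': (int(1e5), 'iteration'),
}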
Code example #2
File: train.py Project: jensheit/padertorch
def main(_run, _log, trainer, database_json, training_sets, validation_sets,
         audio_reader, stft, max_length_in_sec, batch_size, resume):
    commands.print_config(_run)
    trainer = Trainer.from_config(trainer)
    storage_dir = Path(trainer.storage_dir)
    storage_dir.mkdir(parents=True, exist_ok=True)
    commands.save_config(_run.config,
                         _log,
                         config_filename=str(storage_dir / 'config.json'))

    db = JsonDatabase(database_json)
    training_data = db.get_dataset(training_sets)
    validation_data = db.get_dataset(validation_sets)
    training_data = prepare_dataset(training_data,
                                    audio_reader=audio_reader,
                                    stft=stft,
                                    max_length_in_sec=max_length_in_sec,
                                    batch_size=batch_size,
                                    shuffle=True)
    validation_data = prepare_dataset(validation_data,
                                      audio_reader=audio_reader,
                                      stft=stft,
                                      max_length_in_sec=max_length_in_sec,
                                      batch_size=batch_size,
                                      shuffle=False)

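    # Brief test run over a few batches to surface shape/loss errors early,
    # then attach validation and start the actual training.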
    trainer.test_run(training_data, validation_data)
    trainer.register_validation_hook(validation_data)
    trainer.train(training_data, resume=resume)
Code example #3
File: advanced.py Project: yisiying/padertorch
def train(
    _run,
    audio_reader,
    stft,
    num_workers,
    batch_size,
    max_padding_rate,
    trainer,
    resume,
):

    print_config(_run)
    trainer = Trainer.from_config(trainer)

    train_iter, validation_iter = get_datasets(
        audio_reader=audio_reader,
        stft=stft,
        num_workers=num_workers,
        batch_size=batch_size,
        max_padding_rate=max_padding_rate,
        storage_dir=trainer.storage_dir)
    trainer.test_run(train_iter, validation_iter)

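    # Track the best checkpoint by macro F-score (higher is better)
    # rather than by the loss.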
    trainer.register_validation_hook(validation_iter,
                                     metric='macro_fscore',
                                     maximize=True)

    trainer.train(train_iter, resume=resume)
Code example #4
File: train.py Project: jensheit/padertorch
def config():
    database_json = (str((Path(os.environ['NT_DATABASE_JSONS_DIR']) /
                          'librispeech.json').expanduser())
                     if 'NT_DATABASE_JSONS_DIR' in os.environ else None)
    assert database_json is not None, (
        'database_json cannot be None.\n'
        'Either start the training with "python -m padertorch.contrib.examples.'
        'audio_synthesis.wavenet.train with database_json=</path/to/json>" '
        'or make sure there is an environment variable "NT_DATABASE_JSONS_DIR" '
        'pointing to a directory with a "librispeech.json" in it (see README '
        'for the JSON format).')
    training_sets = ['train_clean_100', 'train_clean_360']
    validation_sets = ['dev_clean']
    audio_reader = {
        'source_sample_rate': 16000,
        'target_sample_rate': 16000,
    }
    stft = {
        'shift': 200,
        'window_length': 800,
        'size': 1024,
        'fading': 'full',
        'pad': True,
    }
    max_length_in_sec = 1.
    batch_size = 3
    number_of_mel_filters = 80
    trainer = {
        'model': {
            'factory': WaveNet,
            'wavenet': {
                'n_cond_channels': number_of_mel_filters,
                'upsamp_window': stft['window_length'],
                'upsamp_stride': stft['shift'],
                'fading': stft['fading'],
            },
            'sample_rate': audio_reader['target_sample_rate'],
            'stft_size': stft['size'],
            'number_of_mel_filters': number_of_mel_filters,
            'lowest_frequency': 50
        },
        'optimizer': {
            'factory': Adam,
            'lr': 5e-4,
        },
        'storage_dir': get_new_storage_dir(
            'wavenet', id_naming='time', mkdir=False),
        'summary_trigger': (1_000, 'iteration'),
        'checkpoint_trigger': (10_000, 'iteration'),
        'stop_trigger': (200_000, 'iteration'),
    }
    trainer = Trainer.get_config(trainer)
    resume = False
    ex.observers.append(FileStorageObserver.create(trainer['storage_dir']))
Code example #5
def train(model, storage_dir):
    train_set, validate_set, _ = get_datasets()

    trainer = Trainer(model=model,
                      optimizer=Adam(lr=5e-4),
                      storage_dir=str(storage_dir),
                      summary_trigger=(1000, 'iteration'),
                      checkpoint_trigger=(10000, 'iteration'),
                      stop_trigger=(100000, 'iteration'))

    trainer.test_run(train_set, validate_set)
    trainer.register_validation_hook(validate_set)
    trainer.train(train_set)
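
This is the smallest complete Trainer lifecycle: construct, test_run, register_validation_hook, train. For orientation, a hypothetical sketch of the model side, following padertorch's documented pt.Model interface (a forward plus a review that returns the loss); the class name, feature sizes, and batch keys are invented for illustration:

# Hypothetical model sketch; the 'observation'/'target' keys are assumptions.
import torch
import padertorch as pt


class DenoisingModel(pt.Model):
    def __init__(self, num_features=513, hidden=128):
        super().__init__()
        self.net = torch.nn.Sequential(
            torch.nn.Linear(num_features, hidden),
            torch.nn.ReLU(),
            torch.nn.Linear(hidden, num_features),
        )

    def forward(self, example):
        # One batch as yielded by the dataset passed to trainer.train().
        return self.net(example['observation'])

    def review(self, example, outputs):
        # The Trainer backpropagates the scalar stored under 'loss'.
        loss = torch.mean((outputs - example['target'])**2)
        return dict(loss=loss)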
Code example #6
def initialize_trainer_provider(task, trainer_opts, provider_opts, _run):

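    # 'init.json' marks an already initialized experiment directory: if it
    # exists, only 'restart' and 'validate' are valid tasks; a fresh 'train'
    # or 'create_checkpoint' writes it first.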
    storage_dir = Path(trainer_opts['storage_dir'])
    if (storage_dir / 'init.json').exists():
        assert task in ['restart', 'validate'], task
    elif task in ['train', 'create_checkpoint']:
        dump_json(
            dict(trainer_opts=recursive_class_to_str(trainer_opts),
                 provider_opts=recursive_class_to_str(provider_opts)),
            storage_dir / 'init.json')
    else:
        raise ValueError(task, storage_dir)
    sacred.commands.print_config(_run)

    trainer = Trainer.from_config(trainer_opts)
    assert isinstance(trainer, Trainer)
    provider = config_to_instance(provider_opts)
    return trainer, provider
Code example #7
File: training.py Project: fgnt/pb_sed
def config():
    delay = 0
    debug = False
    timestamp = timeStamped('')[1:] + ('_debug' if debug else '')
    group_name = timestamp
    database_name = 'desed'
    storage_dir = str(storage_root / 'strong_label_crnn' / database_name /
                      'training' / group_name / timestamp)

    init_ckpt_path = None
    frozen_cnn_2d_layers = 0
    frozen_cnn_1d_layers = 0

    # Data provider
    if database_name == 'desed':
        external_data = True
        batch_size = 32
        data_provider = {
            'factory': DESEDProvider,
            'json_path': str(
                database_jsons_dir / 'desed_pseudo_labeled_with_external.json'
            ) if external_data else str(
                database_jsons_dir / 'desed_pseudo_labeled_without_external.json'
            ),
            'train_set': {
                'train_weak': 10 if external_data else 20,
                'train_strong': 10 if external_data else 0,
                'train_synthetic20': 2,
                'train_synthetic21': 1,
                'train_unlabel_in_domain': 2,
            },
            'cached_datasets':
                None if debug else ['train_weak', 'train_synthetic20'],
            'train_fetcher': {
                'batch_size': batch_size,
                'prefetch_workers': batch_size,
                'min_dataset_examples_in_batch': {
                    'train_weak': int(3 * batch_size / 32),
                    'train_strong':
                        int(6 * batch_size / 32) if external_data else 0,
                    'train_synthetic20': int(1 * batch_size / 32),
                    'train_synthetic21': int(2 * batch_size / 32),
                    'train_unlabel_in_domain': 0,
                },
            },
            'storage_dir': storage_dir,
        }
        num_events = 10
        DESEDProvider.get_config(data_provider)

        validation_set_name = 'validation'
        validation_ground_truth_filepath = None
        weak_label_crnn_hyper_params_dir = ''
        eval_set_name = 'eval_public'
        eval_ground_truth_filepath = None

        num_iterations = 45000 if init_ckpt_path is None else 20000
        checkpoint_interval = 1000
        summary_interval = 100
        back_off_patience = None
        lr_decay_step = 30000 if back_off_patience is None else None
        lr_decay_factor = 1 / 5
        lr_rampup_steps = 1000 if init_ckpt_path is None else None
        gradient_clipping = 1e10 if init_ckpt_path is None else 1
    else:
        raise ValueError(f'Unknown database {database_name}.')

    # Trainer configuration
    net_config = 'shallow'
    if net_config == 'shallow':
        m = 1
        cnn = {
            'cnn_2d': {
                'out_channels': [
                    16 * m, 16 * m, 32 * m, 32 * m, 64 * m, 64 * m,
                    128 * m, 128 * m, min(256 * m, 512),
                ],
                'pool_size': 4 * [1, (2, 1)] + [1],
                'kernel_size': 3,
                'norm': 'batch',
                'norm_kwargs': {'eps': 1e-3},
                'activation_fn': 'relu',
                'dropout': .0,
                'output_layer': False,
            },
            'cnn_1d': {
                'out_channels': 3 * [256 * m],
                'kernel_size': 3,
                'norm': 'batch',
                'norm_kwargs': {'eps': 1e-3},
                'activation_fn': 'relu',
                'dropout': .0,
                'output_layer': False,
            },
        }
    elif net_config == 'deep':
        m = 2
        cnn = {
            'cnn_2d': {
                'out_channels': (
                    4 * [16 * m] + 4 * [32 * m] + 4 * [64 * m]
                    + 4 * [128 * m] + [256 * m, min(256 * m, 512)]
                ),
                'pool_size': 4 * [1, 1, 1, (2, 1)] + [1, 1],
                'kernel_size': 9 * [3, 1],
                'residual_connections': [
                    None, None, 4, None, 6, None, 8, None, 10, None,
                    12, None, 14, None, 16, None, None, None,
                ],
                'norm': 'batch',
                'norm_kwargs': {'eps': 1e-3},
                'activation_fn': 'relu',
                'pre_activation': True,
                'dropout': .0,
                'output_layer': False,
            },
            'cnn_1d': {
                'out_channels': 8 * [256 * m],
                'kernel_size': [1] + 3 * [3, 1] + [1],
                'residual_connections': [None, 3, None, 5, None, 7, None, None],
                'norm': 'batch',
                'norm_kwargs': {'eps': 1e-3},
                'activation_fn': 'relu',
                'pre_activation': True,
                'dropout': .0,
                'output_layer': False,
            },
        }
    else:
        raise ValueError(f'Unknown net_config {net_config}')

    if init_ckpt_path is not None:
        cnn['conditional_dims'] = 0

    trainer = {
        'model': {
            'factory': strong_label.CRNN,
            'feature_extractor': {
                'sample_rate':
                    data_provider['audio_reader']['target_sample_rate'],
                'stft_size': data_provider['train_transform']['stft']['size'],
                'number_of_filters': 128,
                'frequency_warping_fn': {
                    'factory': MelWarping,
                    'warp_factor_sampling_fn': {
                        'factory': LogTruncatedNormal,
                        'scale': .08,
                        'truncation': np.log(1.3),
                    },
                    'boundary_frequency_ratio_sampling_fn': {
                        'factory': TruncatedExponential,
                        'scale': .5,
                        'truncation': 5.,
                    },
                    'highest_frequency':
                        data_provider['audio_reader']['target_sample_rate'] / 2,
                },
                # 'blur_sigma': .5,
                'n_time_masks': 1,
                'max_masked_time_steps': 70,
                'max_masked_time_rate': .2,
                'n_frequency_masks': 1,
                'max_masked_frequency_bands': 20,
                'max_masked_frequency_rate': .2,
                'max_noise_scale': .2,
            },
            'cnn': cnn,
            'rnn': {
                'hidden_size': 256 * m,
                'num_layers': 2,
                'dropout': .0,
                'output_net': {
                    'out_channels': [256 * m, num_events],
                    'kernel_size': 1,
                    'norm': 'batch',
                    'activation_fn': 'relu',
                    'dropout': .0,
                }
            },
            'labelwise_metrics': ('fscore_strong', ),
        },
        'optimizer': {
            'factory': Adam,
            'lr': 5e-4,
            'gradient_clipping': gradient_clipping,
            # 'weight_decay': 1e-6,
        },
        'summary_trigger': (summary_interval, 'iteration'),
        'checkpoint_trigger': (checkpoint_interval, 'iteration'),
        'stop_trigger': (num_iterations, 'iteration'),
        'storage_dir': storage_dir,
    }
    del cnn
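    # Optionally swap the recurrent block for a Transformer encoder.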
    use_transformer = False
    if use_transformer:
        trainer['model']['rnn']['factory'] = TransformerStack
        trainer['model']['rnn']['hidden_size'] = 320
        trainer['model']['rnn']['num_heads'] = 10
        trainer['model']['rnn']['num_layers'] = 3
        trainer['model']['rnn']['dropout'] = 0.1
    Trainer.get_config(trainer)

    resume = False
    assert resume or not Path(trainer['storage_dir']).exists()
    ex.observers.append(FileStorageObserver.create(trainer['storage_dir']))
Code example #8
File: training.py Project: fgnt/pb_sed
def train(
    _run,
    debug,
    data_provider,
    trainer,
    lr_rampup_steps,
    back_off_patience,
    lr_decay_step,
    lr_decay_factor,
    init_ckpt_path,
    frozen_cnn_2d_layers,
    frozen_cnn_1d_layers,
    resume,
    delay,
    validation_set_name,
    validation_ground_truth_filepath,
    weak_label_crnn_hyper_params_dir,
    eval_set_name,
    eval_ground_truth_filepath,
):
    print()
    print('##### Training #####')
    print()
    print_config(_run)
    assert (back_off_patience is None) or (lr_decay_step is None), (
        back_off_patience, lr_decay_step)
    if delay > 0:
        print(f'Sleep for {delay} seconds.')
        time.sleep(delay)

    data_provider = DESEDProvider.from_config(data_provider)
    data_provider.train_transform.label_encoder.initialize_labels(
        dataset=data_provider.db.get_dataset(data_provider.validate_set),
        verbose=True)
    data_provider.test_transform.label_encoder.initialize_labels()
    trainer = Trainer.from_config(trainer)
    trainer.model.label_mapping = []
    for idx, label in sorted(data_provider.train_transform.label_encoder
                             .inverse_label_mapping.items()):
        assert idx == len(trainer.model.label_mapping), (
            idx, label, len(trainer.model.label_mapping))
        trainer.model.label_mapping.append(
            label.replace(', ', '__').replace(' ', '')
            .replace('(', '_').replace(')', '_').replace("'", ''))
    print('Params', sum(p.numel() for p in trainer.model.parameters()))

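    # Optionally warm-start the CNN from a pre-trained checkpoint and freeze
    # its first layers.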
    if init_ckpt_path is not None:
        print('Load init params')
        state_dict = deflatten(torch.load(init_ckpt_path,
                                          map_location='cpu')['model'],
                               maxdepth=1)
        trainer.model.cnn.load_state_dict(state_dict['cnn'])
    if frozen_cnn_2d_layers:
        print(f'Freeze {frozen_cnn_2d_layers} cnn_2d layers')
        trainer.model.cnn.cnn_2d.freeze(frozen_cnn_2d_layers)
    if frozen_cnn_1d_layers:
        print(f'Freeze {frozen_cnn_1d_layers} cnn_1d layers')
        trainer.model.cnn.cnn_1d.freeze(frozen_cnn_1d_layers)

    def add_tag_condition(example):
        example["tag_condition"] = example["weak_targets"]
        return example

    train_set = data_provider.get_train_set().map(add_tag_condition)
    validate_set = data_provider.get_validate_set().map(add_tag_condition)

    if validate_set is not None:
        trainer.test_run(train_set, validate_set)
        trainer.register_validation_hook(
            validate_set,
            metric='macro_fscore_strong',
            maximize=True,
        )

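    # Piecewise-linear LR schedule: an optional warm-up ramp plus a one-step
    # decay, applied through an LRAnnealingHook per optimizer.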
    breakpoints = []
    if lr_rampup_steps is not None:
        breakpoints += [(0, 0.), (lr_rampup_steps, 1.)]
    if lr_decay_step is not None:
        breakpoints += [(lr_decay_step, 1.), (lr_decay_step, lr_decay_factor)]
    if len(breakpoints) > 0:
        if isinstance(trainer.optimizer, dict):
            names = sorted(trainer.optimizer.keys())
        else:
            names = [None]
        for name in names:
            trainer.register_hook(
                LRAnnealingHook(
                    trigger=AllTrigger(
                        (100, 'iteration'),
                        NotTrigger(
                            EndTrigger(breakpoints[-1][0] + 100, 'iteration')),
                    ),
                    breakpoints=breakpoints,
                    unit='iteration',
                    name=name,
                ))
    trainer.train(train_set, resume=resume)

    if validation_set_name:
        tuning.run(
            config_updates={
                'debug': debug,
                'weak_label_crnn_hyper_params_dir':
                    weak_label_crnn_hyper_params_dir,
                'strong_label_crnn_dirs': [str(trainer.storage_dir)],
                'validation_set_name': validation_set_name,
                'validation_ground_truth_filepath':
                    validation_ground_truth_filepath,
                'eval_set_name': eval_set_name,
                'eval_ground_truth_filepath': eval_ground_truth_filepath,
            })
Code example #9
File: training.py Project: fgnt/pb_sed
def train(
    _run,
    debug,
    data_provider,
    filter_desed_test_clips,
    trainer,
    lr_rampup_steps,
    back_off_patience,
    lr_decay_step,
    lr_decay_factor,
    init_ckpt_path,
    frozen_cnn_2d_layers,
    frozen_cnn_1d_layers,
    track_emissions,
    resume,
    delay,
    validation_set_name,
    validation_ground_truth_filepath,
    eval_set_name,
    eval_ground_truth_filepath,
):
    print()
    print('##### Training #####')
    print()
    print_config(_run)
    assert (back_off_patience is None) or (lr_decay_step is None), (
        back_off_patience, lr_decay_step)
    if delay > 0:
        print(f'Sleep for {delay} seconds.')
        time.sleep(delay)

    data_provider = DataProvider.from_config(data_provider)
    data_provider.train_transform.label_encoder.initialize_labels(
        dataset=data_provider.db.get_dataset(data_provider.validate_set),
        verbose=True)
    data_provider.test_transform.label_encoder.initialize_labels()
    trainer = Trainer.from_config(trainer)
    trainer.model.label_mapping = []
    for idx, label in sorted(data_provider.train_transform.label_encoder
                             .inverse_label_mapping.items()):
        assert idx == len(trainer.model.label_mapping), (
            idx, label, len(trainer.model.label_mapping))
        trainer.model.label_mapping.append(
            label.replace(', ', '__').replace(' ', '')
            .replace('(', '_').replace(')', '_').replace("'", ''))
    print('Params', sum(p.numel() for p in trainer.model.parameters()))
    print('CNN Params', sum(p.numel() for p in trainer.model.cnn.parameters()))

    if init_ckpt_path is not None:
        print('Load init params')
        state_dict = deflatten(torch.load(init_ckpt_path,
                                          map_location='cpu')['model'],
                               maxdepth=2)
        trainer.model.cnn.load_state_dict(flatten(state_dict['cnn']))
        trainer.model.rnn_fwd.rnn.load_state_dict(state_dict['rnn_fwd']['rnn'])
        trainer.model.rnn_bwd.rnn.load_state_dict(state_dict['rnn_bwd']['rnn'])
        # pop output layer from checkpoint
        param_keys = sorted(state_dict['rnn_fwd']['output_net'].keys())
        layer_indices = [key.split('.')[1] for key in param_keys]
        last_layer_idx = layer_indices[-1]
        for key, layer_idx in zip(param_keys, layer_indices):
            if layer_idx == last_layer_idx:
                state_dict['rnn_fwd']['output_net'].pop(key)
                state_dict['rnn_bwd']['output_net'].pop(key)
        trainer.model.rnn_fwd.output_net.load_state_dict(
            state_dict['rnn_fwd']['output_net'], strict=False)
        trainer.model.rnn_bwd.output_net.load_state_dict(
            state_dict['rnn_bwd']['output_net'], strict=False)
    if frozen_cnn_2d_layers:
        print(f'Freeze {frozen_cnn_2d_layers} cnn_2d layers')
        trainer.model.cnn.cnn_2d.freeze(frozen_cnn_2d_layers)
    if frozen_cnn_1d_layers:
        print(f'Freeze {frozen_cnn_1d_layers} cnn_1d layers')
        trainer.model.cnn.cnn_1d.freeze(frozen_cnn_1d_layers)

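    # Optionally drop training clips whose IDs also occur in the DESED
    # validation or eval_public sets, to avoid evaluation leakage.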
    if filter_desed_test_clips:
        with (database_jsons_dir / 'desed.json').open() as fid:
            desed_json = json.load(fid)
        filter_example_ids = {
            clip_id.rsplit('_', maxsplit=2)[0][1:]
            for clip_id in (list(desed_json['datasets']['validation'].keys()) +
                            list(desed_json['datasets']['eval_public'].keys()))
        }
    else:
        filter_example_ids = None
    train_set = data_provider.get_train_set(
        filter_example_ids=filter_example_ids)
    validate_set = data_provider.get_validate_set()

    if validate_set is not None:
        trainer.test_run(train_set, validate_set)
        trainer.register_validation_hook(
            validate_set,
            metric='macro_fscore_weak',
            maximize=True,
            back_off_patience=back_off_patience,
            n_back_off=0 if back_off_patience is None else 1,
            lr_update_factor=lr_decay_factor,
            early_stopping_patience=back_off_patience,
        )

    breakpoints = []
    if lr_rampup_steps is not None:
        breakpoints += [(0, 0.), (lr_rampup_steps, 1.)]
    if lr_decay_step is not None:
        breakpoints += [(lr_decay_step, 1.), (lr_decay_step, lr_decay_factor)]
    if len(breakpoints) > 0:
        if isinstance(trainer.optimizer, dict):
            names = sorted(trainer.optimizer.keys())
        else:
            names = [None]
        for name in names:
            trainer.register_hook(
                LRAnnealingHook(
                    trigger=AllTrigger(
                        (100, 'iteration'),
                        NotTrigger(
                            EndTrigger(breakpoints[-1][0] + 100, 'iteration')),
                    ),
                    breakpoints=breakpoints,
                    unit='iteration',
                    name=name,
                ))
    trainer.train(train_set, resume=resume, track_emissions=track_emissions)

    if validation_set_name is not None:
        tuning.run(
            config_updates={
                'debug': debug,
                'crnn_dirs': [str(trainer.storage_dir)],
                'validation_set_name': validation_set_name,
                'validation_ground_truth_filepath':
                    validation_ground_truth_filepath,
                'eval_set_name': eval_set_name,
                'eval_ground_truth_filepath': eval_ground_truth_filepath,
                'data_provider': {
                    'test_fetcher': {
                        'batch_size': data_provider.train_fetcher.batch_size,
                    }
                },
            })
Code example #10
File: advanced.py Project: yisiying/padertorch
def config():
    resume = False

    # Data configuration
    audio_reader = {
        'source_sample_rate': None,
        'target_sample_rate': 44100,
    }
    stft = {
        'shift': 882,
        'window_length': 2 * 882,
        'size': 2048,
        'fading': None,
        'pad': False,
    }

    batch_size = 24
    num_workers = 8
    prefetch_buffer = 10 * batch_size
    max_total_size = None
    max_padding_rate = 0.1
    bucket_expiration = 1000 * batch_size

    # Trainer configuration
    trainer = {
        'model': {
            'factory': CRNN,
            'feature_extractor': {
                'sample_rate': audio_reader['target_sample_rate'],
                'fft_length': stft['size'],
                'n_mels': 128,
                'warping_fn': {
                    'factory': MelWarping,
                    'alpha_sampling_fn': {
                        'factory': LogTruncNormalSampler,
                        'scale': .07,
                        'truncation': np.log(1.3),
                    },
                    'fhi_sampling_fn': {
                        'factory': TruncExponentialSampler,
                        'scale': .5,
                        'truncation': 5.,
                    },
                },
                'max_resample_rate': 1.,
                'n_time_masks': 1,
                'max_masked_time_steps': 70,
                'max_masked_time_rate': .2,
                'n_mel_masks': 1,
                'max_masked_mel_steps': 16,
                'max_masked_mel_rate': .2,
                'max_noise_scale': .0,
            },
            'cnn_2d': {
                'out_channels': [16, 16, 32, 32, 64, 64, 128, 128, 256],
                'pool_size': [1, 2, 1, 2, 1, 2, 1, (2, 1), (2, 1)],
                # 'residual_connections': [None, 3, None, 5, None, 7, None],
                'output_layer': False,
                'kernel_size': 3,
                'norm': 'batch',
                'activation_fn': 'relu',
                # 'pre_activation': True,
                'dropout': .0,
            },
            'cnn_1d': {
                'out_channels': 3 * [512],
                # 'residual_connections': [None, 3, None],
                'input_layer': False,
                'output_layer': False,
                'kernel_size': 3,
                'norm': 'batch',
                'activation_fn': 'relu',
                # 'pre_activation': True,
                'dropout': .0,
            },
            'rnn_fwd': {
                'hidden_size': 512,
                'num_layers': 2,
                'dropout': .0,
            },
            'clf_fwd': {
                'out_channels': [512, 527],
                'input_layer': False,
                'kernel_size': 1,
                'norm': 'batch',
                'activation_fn': 'relu',
                'dropout': .0,
            },
            'rnn_bwd': {
                'hidden_size': 512,
                'num_layers': 2,
                'dropout': .0,
            },
            'clf_bwd': {
                'out_channels': [512, 527],
                'input_layer': False,
                'kernel_size': 1,
                'norm': 'batch',
                'activation_fn': 'relu',
                'dropout': .0,
            },
        },
        'optimizer': {
            'factory': Adam,
            'lr': 3e-4,
            'gradient_clipping': 20.,
        },
        'storage_dir': storage_dir,
        'summary_trigger': (100, 'iteration'),
        'checkpoint_trigger': (1000, 'iteration'),
        'stop_trigger': (100000, 'iteration')
    }
    Trainer.get_config(trainer)
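
Here, as in examples #4 and #7, Trainer.get_config completes the nested dict: every sub-dict carrying a 'factory' key is filled with that factory's default arguments, and Trainer.from_config (used in the train functions above) then instantiates the whole tree. A hedged sketch of that round trip:

# Sketch of the Configurable round trip used throughout these examples.
full_config = Trainer.get_config(trainer)            # defaults filled in
trainer_instance = Trainer.from_config(full_config)  # objects built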