Example #1: mask estimator training on CHiME-3
def config():
    model_class = MaskEstimatorModel
    trainer_opts = deflatten({
        'model.factory': model_class,
        'optimizer.factory': Adam,
        'stop_trigger': (int(1e5), 'iteration'),
        'summary_trigger': (500, 'iteration'),
        'checkpoint_trigger': (500, 'iteration'),
        'storage_dir': None,
    })
    provider_opts = deflatten({
        'factory': SequenceProvider,
        'database.factory': Chime3,
        'audio_keys': [OBSERVATION, NOISE_IMAGE, SPEECH_IMAGE],
        'transform.factory': MaskTransformer,
        'transform.stft': dict(factory=STFT, shift=256, size=1024),
    })
    trainer_opts['model']['transformer'] = provider_opts['transform']

    storage_dir = None
    add_name = None
    if storage_dir is None:
        ex_name = get_experiment_name(trainer_opts['model'])
        if add_name is not None:
            ex_name += f'_{add_name}'
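        # NOTE: model_dir and get_experiment_name are defined at module
        # level in the original script.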
        observer = sacred.observers.FileStorageObserver.create(
            str(model_dir / ex_name))
        storage_dir = observer.basedir
    else:
        sacred.observers.FileStorageObserver.create(storage_dir)

    trainer_opts['storage_dir'] = storage_dir

    if (Path(storage_dir) / 'init.json').exists():
        trainer_opts, provider_opts = compare_configs(storage_dir,
                                                      trainer_opts,
                                                      provider_opts)

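    # Validate/complete the option dicts; the return values are discarded
    # here because sacred captures the local variables of this function.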
    Trainer.get_config(trainer_opts)
    Configurable.get_config(provider_opts)
    validate_checkpoint = 'ckpt_latest.pth'
    validation_kwargs = dict(
        metric='loss',
        maximize=False,
        max_checkpoints=1,
        # number of examples taken from the validation iterator
        validation_length=1000,
    )
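
A note on deflatten, which both option dicts above rely on: it expands dot-separated keys into nested dictionaries. A minimal sketch of that behavior, assuming the paderbox import path and '.' as the default separator:

from paderbox.utils.nested import deflatten

# Dot-separated keys become nested dicts, so 'model.factory' ends up
# under trainer_opts['model']['factory'].
nested = deflatten({
    'model.factory': 'MaskEstimatorModel',
    'optimizer.factory': 'Adam',
})
assert nested == {
    'model': {'factory': 'MaskEstimatorModel'},
    'optimizer': {'factory': 'Adam'},
}
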
Example #2: WaveNet training on LibriSpeech
def config():
    database_json = (
        str((Path(os.environ['NT_DATABASE_JSONS_DIR'])
             / 'librispeech.json').expanduser())
        if 'NT_DATABASE_JSONS_DIR' in os.environ else None
    )
    assert database_json is not None, (
        'database_json cannot be None.\n'
        'Either start the training with "python -m padertorch.contrib.examples.'
        'audio_synthesis.wavenet.train with database_json=</path/to/json>" '
        'or make sure there is an environment variable "NT_DATABASE_JSONS_DIR" '
        'pointing to a directory with a "librispeech.json" in it (see README '
        'for the JSON format).')
    training_sets = ['train_clean_100', 'train_clean_360']
    validation_sets = ['dev_clean']
    audio_reader = {
        'source_sample_rate': 16000,
        'target_sample_rate': 16000,
    }
    stft = {
        'shift': 200,
        'window_length': 800,
        'size': 1024,
        'fading': 'full',
        'pad': True,
    }
    max_length_in_sec = 1.
    batch_size = 3
    number_of_mel_filters = 80
    trainer = {
        'model': {
            'factory': WaveNet,
            'wavenet': {
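                # The conditioning upsampler mirrors the STFT parameters so
                # the mel features are stretched back to waveform resolution.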
                'n_cond_channels': number_of_mel_filters,
                'upsamp_window': stft['window_length'],
                'upsamp_stride': stft['shift'],
                'fading': stft['fading'],
            },
            'sample_rate': audio_reader['target_sample_rate'],
            'stft_size': stft['size'],
            'number_of_mel_filters': number_of_mel_filters,
            'lowest_frequency': 50
        },
        'optimizer': {
            'factory': Adam,
            'lr': 5e-4,
        },
        'storage_dir': get_new_storage_dir(
            'wavenet', id_naming='time', mkdir=False),
        'summary_trigger': (1_000, 'iteration'),
        'checkpoint_trigger': (10_000, 'iteration'),
        'stop_trigger': (200_000, 'iteration'),
    }
    trainer = Trainer.get_config(trainer)
    resume = False
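    # 'ex' is the sacred Experiment instance defined at module level.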
    ex.observers.append(FileStorageObserver.create(trainer['storage_dir']))
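
The 'ex' object in the last line is sacred's Experiment. As a sketch of how such a config function is usually wired into an experiment (names here are illustrative, not taken from the original file):

import sacred

ex = sacred.Experiment('wavenet')

@ex.config
def config():
    ...  # body as shown above; sacred captures all local variables

@ex.automain
def main(trainer, resume):
    # sacred injects config entries by parameter name
    print(trainer['storage_dir'], resume)
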
Example #3: strong-label CRNN training on DESED
def config():
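    # timeStamped, storage_root and database_jsons_dir are module-level
    # helpers/paths in the original script.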
    delay = 0
    debug = False
    timestamp = timeStamped('')[1:] + ('_debug' if debug else '')
    group_name = timestamp
    database_name = 'desed'
    storage_dir = str(storage_root / 'strong_label_crnn' / database_name /
                      'training' / group_name / timestamp)

    init_ckpt_path = None
    frozen_cnn_2d_layers = 0
    frozen_cnn_1d_layers = 0

    # Data provider
    if database_name == 'desed':
        external_data = True
        batch_size = 32
        data_provider = {
            'factory': DESEDProvider,
            'json_path': (
                str(database_jsons_dir
                    / 'desed_pseudo_labeled_with_external.json')
                if external_data else
                str(database_jsons_dir
                    / 'desed_pseudo_labeled_without_external.json')
            ),
            'train_set': {
                'train_weak': 10 if external_data else 20,
                'train_strong': 10 if external_data else 0,
                'train_synthetic20': 2,
                'train_synthetic21': 1,
                'train_unlabel_in_domain': 2,
            },
            'cached_datasets': (
                None if debug else ['train_weak', 'train_synthetic20']
            ),
            'train_fetcher': {
                'batch_size': batch_size,
                'prefetch_workers': batch_size,
                'min_dataset_examples_in_batch': {
                    'train_weak': int(3 * batch_size / 32),
                    'train_strong':
                        int(6 * batch_size / 32) if external_data else 0,
                    'train_synthetic20': int(1 * batch_size / 32),
                    'train_synthetic21': int(2 * batch_size / 32),
                    'train_unlabel_in_domain': 0,
                },
            },
            'storage_dir': storage_dir,
        }
        num_events = 10
        DESEDProvider.get_config(data_provider)

        validation_set_name = 'validation'
        validation_ground_truth_filepath = None
        weak_label_crnn_hyper_params_dir = ''
        eval_set_name = 'eval_public'
        eval_ground_truth_filepath = None

        num_iterations = 45000 if init_ckpt_path is None else 20000
        checkpoint_interval = 1000
        summary_interval = 100
        back_off_patience = None
        lr_decay_step = 30000 if back_off_patience is None else None
        lr_decay_factor = 1 / 5
        lr_rampup_steps = 1000 if init_ckpt_path is None else None
        gradient_clipping = 1e10 if init_ckpt_path is None else 1
    else:
        raise ValueError(f'Unknown database {database_name}.')

    # Trainer configuration
    net_config = 'shallow'
    if net_config == 'shallow':
        m = 1
        cnn = {
            'cnn_2d': {
                'out_channels': [
                    16 * m,
                    16 * m,
                    32 * m,
                    32 * m,
                    64 * m,
                    64 * m,
                    128 * m,
                    128 * m,
                    min(256 * m, 512),
                ],
                'pool_size': 4 * [1, (2, 1)] + [1],
                'kernel_size': 3,
                'norm': 'batch',
                'norm_kwargs': {'eps': 1e-3},
                'activation_fn': 'relu',
                'dropout': .0,
                'output_layer': False,
            },
            'cnn_1d': {
                'out_channels': 3 * [256 * m],
                'kernel_size': 3,
                'norm': 'batch',
                'norm_kwargs': {'eps': 1e-3},
                'activation_fn': 'relu',
                'dropout': .0,
                'output_layer': False,
            },
        }
    elif net_config == 'deep':
        m = 2
        cnn = {
            'cnn_2d': {
                'out_channels': (
                    4 * [16 * m] + 4 * [32 * m] + 4 * [64 * m]
                    + 4 * [128 * m] + [256 * m, min(256 * m, 512)]
                ),
                'pool_size': 4 * [1, 1, 1, (2, 1)] + [1, 1],
                'kernel_size': 9 * [3, 1],
                'residual_connections': [
                    None, None, 4, None, 6, None, 8, None, 10, None, 12, None,
                    14, None, 16, None, None, None,
                ],
                'norm': 'batch',
                'norm_kwargs': {'eps': 1e-3},
                'activation_fn': 'relu',
                'pre_activation': True,
                'dropout': .0,
                'output_layer': False,
            },
            'cnn_1d': {
                'out_channels': 8 * [256 * m],
                'kernel_size': [1] + 3 * [3, 1] + [1],
                'residual_connections':
                    [None, 3, None, 5, None, 7, None, None],
                'norm': 'batch',
                'norm_kwargs': {'eps': 1e-3},
                'activation_fn': 'relu',
                'pre_activation': True,
                'dropout': .0,
                'output_layer': False,
            },
        }
    else:
        raise ValueError(f'Unknown net_config {net_config}')

    if init_ckpt_path is not None:
        cnn['conditional_dims'] = 0

    trainer = {
        'model': {
            'factory': strong_label.CRNN,
            'feature_extractor': {
                'sample_rate':
                    data_provider['audio_reader']['target_sample_rate'],
                'stft_size': data_provider['train_transform']['stft']['size'],
                'number_of_filters': 128,
                'frequency_warping_fn': {
                    'factory': MelWarping,
                    'warp_factor_sampling_fn': {
                        'factory': LogTruncatedNormal,
                        'scale': .08,
                        'truncation': np.log(1.3),
                    },
                    'boundary_frequency_ratio_sampling_fn': {
                        'factory': TruncatedExponential,
                        'scale': .5,
                        'truncation': 5.,
                    },
                    'highest_frequency':
                        data_provider['audio_reader']['target_sample_rate'] / 2,
                },
                # 'blur_sigma': .5,
                'n_time_masks': 1,
                'max_masked_time_steps': 70,
                'max_masked_time_rate': .2,
                'n_frequency_masks': 1,
                'max_masked_frequency_bands': 20,
                'max_masked_frequency_rate': .2,
                'max_noise_scale': .2,
            },
            'cnn': cnn,
            'rnn': {
                'hidden_size': 256 * m,
                'num_layers': 2,
                'dropout': .0,
                'output_net': {
                    'out_channels': [256 * m, num_events],
                    'kernel_size': 1,
                    'norm': 'batch',
                    'activation_fn': 'relu',
                    'dropout': .0,
                }
            },
            'labelwise_metrics': ('fscore_strong', ),
        },
        'optimizer': {
            'factory': Adam,
            'lr': 5e-4,
            'gradient_clipping': gradient_clipping,
            # 'weight_decay': 1e-6,
        },
        'summary_trigger': (summary_interval, 'iteration'),
        'checkpoint_trigger': (checkpoint_interval, 'iteration'),
        'stop_trigger': (num_iterations, 'iteration'),
        'storage_dir': storage_dir,
    }
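    # Sacred captures every local of a config function; delete the temporary
    # alias so the CNN config is not stored twice.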
    del cnn
    use_transformer = False
    if use_transformer:
        trainer['model']['rnn']['factory'] = TransformerStack
        trainer['model']['rnn']['hidden_size'] = 320
        trainer['model']['rnn']['num_heads'] = 10
        trainer['model']['rnn']['num_layers'] = 3
        trainer['model']['rnn']['dropout'] = 0.1
    Trainer.get_config(trainer)

    resume = False
    assert resume or not Path(trainer['storage_dir']).exists()
    ex.observers.append(FileStorageObserver.create(trainer['storage_dir']))
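
All of these examples use the same factory pattern: a dict with a 'factory' key describes how an object is built, get_config completes missing arguments from the factory's signature, and from_config instantiates. A minimal sketch of the round trip, assuming padertorch's Configurable API behaves as in its README; MyModel is a placeholder:

import torch
import padertorch as pt

class MyModel(pt.Model):
    """Placeholder model for illustration only."""
    def __init__(self, hidden_size=128):
        super().__init__()
        self.linear = torch.nn.Linear(hidden_size, hidden_size)

    def forward(self, inputs):
        return self.linear(inputs)

    def review(self, inputs, outputs):
        return {'loss': outputs.sum()}

config = pt.Trainer.get_config({
    'model': {'factory': MyModel},  # hidden_size default is filled in
    'optimizer': {'factory': pt.optimizer.Adam, 'lr': 5e-4},
    'storage_dir': '/tmp/exp',
    'stop_trigger': (100_000, 'iteration'),
})
trainer = pt.Trainer.from_config(config)  # build the Trainer from the dict
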
Example #4: CRNN training with forward/backward tagging heads
def config():
    resume = False

    # Data configuration
    audio_reader = {
        'source_sample_rate': None,
        'target_sample_rate': 44100,
    }
    stft = {
        'shift': 882,
        'window_length': 2 * 882,
        'size': 2048,
        'fading': None,
        'pad': False,
    }

    batch_size = 24
    num_workers = 8
    prefetch_buffer = 10 * batch_size
    max_total_size = None
    max_padding_rate = 0.1
    bucket_expiration = 1000 * batch_size

    # Trainer configuration
    trainer = {
        'model': {
            'factory': CRNN,
            'feature_extractor': {
                'sample_rate': audio_reader['target_sample_rate'],
                'fft_length': stft['size'],
                'n_mels': 128,
                'warping_fn': {
                    'factory': MelWarping,
                    'alpha_sampling_fn': {
                        'factory': LogTruncNormalSampler,
                        'scale': .07,
                        'truncation': np.log(1.3),
                    },
                    'fhi_sampling_fn': {
                        'factory': TruncExponentialSampler,
                        'scale': .5,
                        'truncation': 5.,
                    },
                },
                'max_resample_rate': 1.,
                'n_time_masks': 1,
                'max_masked_time_steps': 70,
                'max_masked_time_rate': .2,
                'n_mel_masks': 1,
                'max_masked_mel_steps': 16,
                'max_masked_mel_rate': .2,
                'max_noise_scale': .0,
            },
            'cnn_2d': {
                'out_channels': [16, 16, 32, 32, 64, 64, 128, 128, 256],
                'pool_size': [1, 2, 1, 2, 1, 2, 1, (2, 1), (2, 1)],
                # 'residual_connections': [None, 3, None, 5, None, 7, None],
                'output_layer': False,
                'kernel_size': 3,
                'norm': 'batch',
                'activation_fn': 'relu',
                # 'pre_activation': True,
                'dropout': .0,
            },
            'cnn_1d': {
                'out_channels': 3 * [512],
                # 'residual_connections': [None, 3, None],
                'input_layer': False,
                'output_layer': False,
                'kernel_size': 3,
                'norm': 'batch',
                'activation_fn': 'relu',
                # 'pre_activation': True,
                'dropout': .0,
            },
            'rnn_fwd': {
                'hidden_size': 512,
                'num_layers': 2,
                'dropout': .0,
            },
            'clf_fwd': {
                'out_channels': [512, 527],
                'input_layer': False,
                'kernel_size': 1,
                'norm': 'batch',
                'activation_fn': 'relu',
                'dropout': .0,
            },
            'rnn_bwd': {
                'hidden_size': 512,
                'num_layers': 2,
                'dropout': .0,
            },
            'clf_bwd': {
                'out_channels': [512, 527],
                'input_layer': False,
                'kernel_size': 1,
                'norm': 'batch',
                'activation_fn': 'relu',
                'dropout': .0,
            },
        },
        'optimizer': {
            'factory': Adam,
            'lr': 3e-4,
            'gradient_clipping': 20.,
        },
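        # storage_dir is assumed to be defined at module level in the
        # original script.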
        'storage_dir': storage_dir,
        'summary_trigger': (100, 'iteration'),
        'checkpoint_trigger': (1000, 'iteration'),
        'stop_trigger': (100000, 'iteration')
    }
    Trainer.get_config(trainer)
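
The batching options in this example (max_padding_rate, bucket_expiration) point to dynamic bucket batching: a sequence is only added to a batch if padding everything to the longest member wastes at most max_padding_rate of the batch, and buckets that do not fill up within bucket_expiration examples are closed early. A hypothetical sketch of the admission criterion, illustrative only and not the actual implementation:

def fits_into_bucket(bucket_lengths, new_length, max_padding_rate=0.1):
    """True if adding the sequence keeps the padding overhead acceptable."""
    lengths = bucket_lengths + [new_length]
    padded = max(lengths) * len(lengths)  # total size after padding
    real = sum(lengths)                   # real (unpadded) frames
    return (padded - real) / padded <= max_padding_rate

assert fits_into_bucket([100, 98], 95)      # little padding needed
assert not fits_into_bucket([100, 98], 40)  # too much padding wasted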