Example #1
import contextlib
import copy

import padertorch as pt


@contextlib.contextmanager  # decorator assumed: the yield in try/finally only makes sense in a context manager
def backup_state_dict(trainer: pt.Trainer):
    # Snapshot the trainer state and restore it when the block exits,
    # even if an exception was raised inside.
    state_dict = copy.deepcopy(trainer.state_dict())
    try:
        yield
    finally:
        trainer.load_state_dict(state_dict)
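Usage sketch: with the contextmanager decorator in place, any trainer state changed inside the block is rolled back on exit. The call below is hypothetical and only illustrates the pattern:

# run something destructive, then fall back to the snapshot
with backup_state_dict(trainer):
    trainer.train(training_data)
# here the trainer again holds the state captured before the block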
Example #2
def main():
    model = WALNet(44100, 2048, 527)
    trainer = Trainer(model=model,
                      optimizer=optimizer.Adam(lr=3e-4, gradient_clipping=60.),
                      storage_dir=storage_dir,
                      summary_trigger=(100, 'iteration'),
                      stop_trigger=(50000, 'iteration'),
                      checkpoint_trigger=(1000, 'iteration'))
    training_data, validation_data = get_datasets(
        audio_reader=dict(source_sample_rate=44100, target_sample_rate=44100),
        stft=dict(shift=882,
                  window_length=2 * 882,
                  size=2048,
                  fading=None,
                  pad=False),
        num_workers=8,
        batch_size=24,
        max_padding_rate=.1,
        storage_dir=storage_dir)
    trainer.register_validation_hook(validation_data,
                                     metric='macro_fscore',
                                     maximize=True)

    trainer.test_run(training_data, validation_data)
    trainer.train(training_data)
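The WALNet instance passed to the Trainer here is expected to follow the usual padertorch model convention: a forward() that maps a batch dict to outputs and a review() that turns inputs and outputs into a dict containing at least a 'loss'. A minimal, hypothetical sketch of that interface (the layer, the feature/target keys, and the loss are made up for illustration):

import torch
import padertorch as pt


class ToyTagger(pt.Model):
    # Hypothetical stand-in for WALNet; only the interface matters here.
    def __init__(self, num_features=64, num_classes=527):
        super().__init__()
        self.classifier = torch.nn.Linear(num_features, num_classes)

    def forward(self, example):
        # example: a batch dict as produced by get_datasets()
        return self.classifier(example['features'])

    def review(self, example, outputs):
        # The trainer only needs a dict with a 'loss' entry; summaries are optional.
        loss = torch.nn.functional.binary_cross_entropy_with_logits(
            outputs, example['targets'])
        return dict(loss=loss)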
Example #3
def main(_run, _log, trainer, database_json, training_set, validation_metric,
         maximize_metric, audio_reader, stft, num_workers, batch_size,
         max_padding_rate, resume):
    commands.print_config(_run)
    trainer = Trainer.from_config(trainer)
    storage_dir = Path(trainer.storage_dir)
    storage_dir.mkdir(parents=True, exist_ok=True)
    commands.save_config(_run.config,
                         _log,
                         config_filename=str(storage_dir / 'config.json'))

    training_data, validation_data, _ = get_datasets(
        database_json=database_json,
        min_signal_length=1.5,
        audio_reader=audio_reader,
        stft=stft,
        num_workers=num_workers,
        batch_size=batch_size,
        max_padding_rate=max_padding_rate,
        training_set=training_set,
        storage_dir=storage_dir,
        stft_stretch_factor_sampling_fn=Uniform(low=0.5, high=1.5),
        stft_segment_length=audio_reader['target_sample_rate'],
        stft_segment_shuffle_prob=0.,
        mixup_probs=(1 / 2, 1 / 2),
        max_mixup_length=15.,
        min_mixup_overlap=.8,
    )

    trainer.test_run(training_data, validation_data)
    trainer.register_validation_hook(validation_data,
                                     metric=validation_metric,
                                     maximize=maximize_metric)
    trainer.train(training_data, resume=resume)
Example #4
def test_run(trainer, database_json, dataset, batch_size, num_speakers):
    # Perform a few training and validation steps to check that the data
    # preparation and the model are working.

    trainer = Trainer.from_config(trainer)
    train_set, validate_set, _ = get_datasets(None, database_json, dataset,
                                              batch_size)
    trainer.test_run(train_set, validate_set)
Example #5
def train(speaker_clf):
    train_set, validate_set = get_datasets()

    trainer = Trainer(model=speaker_clf,
                      optimizer=Adam(lr=3e-4),
                      storage_dir=str(storage_dir),
                      summary_trigger=(100, 'iteration'),
                      checkpoint_trigger=(1000, 'iteration'),
                      stop_trigger=(100000, 'iteration'))
    trainer.register_validation_hook(validate_set)
    trainer.test_run(train_set, validate_set)
    trainer.train(train_set)
Example #6
def train(model):
    train_set, validate_set = get_datasets()
    stop_trigger = 50000
    if DEBUG:
        stop_trigger = 5000
    trainer = Trainer(model=model,
                      optimizer=Adam(lr=1e-3),
                      storage_dir=str(storage_dir),
                      summary_trigger=(100, 'iteration'),
                      checkpoint_trigger=(1000, 'iteration'),
                      stop_trigger=(stop_trigger, 'iteration'))
    trainer.register_validation_hook(validate_set)
    trainer.test_run(train_set, validate_set)
    trainer.train(train_set)
Example #7
def main():
    model = WALNet(128, 527)
    trainer = Trainer(
        model=model,
        optimizer=optimizer.Adam(lr=3e-4, gradient_clipping=60.),
        storage_dir=storage_dir,
        summary_trigger=(100, 'iteration'),
        stop_trigger=(20000, 'iteration'),
        checkpoint_trigger=(1000, 'iteration')
    )
    training_data, validation_data = get_datasets()
    trainer.register_validation_hook(validation_data)

    trainer.test_run(training_data, validation_data)
    trainer.train(training_data)
Example #8
def config():
    database_json = (str((Path(os.environ['NT_DATABASE_JSONS_DIR']) /
                          'audio_set.json').expanduser())
                     if 'NT_DATABASE_JSONS_DIR' in os.environ else None)
    assert database_json is not None, (
        'database_json cannot be None.\n'
        'Either start the training with "python -m padertorch.contrib.examples.'
        'audio_synthesis.wavenet.train with database_json=</path/to/json>" '
        'or make sure there is an environment variable "NT_DATABASE_JSONS_DIR" '
        'pointing to a directory with an "audio_set.json" in it (see README '
        'for the JSON format).')
    training_set = 'balanced_train'
    audio_reader = {
        'source_sample_rate': 44_100,
        'target_sample_rate': 44_100,
    }
    stft = {
        'shift': 882,
        'window_length': 2 * 882,
        'size': 2048,
        'fading': None,
        'pad': False,
    }
    num_workers = 8
    batch_size = 24
    max_padding_rate = .05
    trainer = {
        'model': {
            'factory': WALNet,
            'sample_rate': audio_reader['target_sample_rate'],
            'stft_size': stft['size'],
            'output_size': 527,
        },
        'optimizer': {
            'factory': Adam,
            'lr': 3e-4,
            'gradient_clipping': 60.,
        },
        'storage_dir':
        get_new_storage_dir('audio_tagging', id_naming='time', mkdir=False),
        'summary_trigger': (100, 'iteration'),
        'checkpoint_trigger': (1_000, 'iteration'),
        'stop_trigger': (50_000, 'iteration'),
    }
    trainer = Trainer.get_config(trainer)
    validation_metric = 'map'
    maximize_metric = True
    resume = False
    ex.observers.append(FileStorageObserver.create(trainer['storage_dir']))
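The nested 'factory' entries above are a configuration convention: each sub-dict names the class to construct plus its keyword arguments, and Trainer.from_config (see Examples #3 and #9) turns the finalized dict back into objects. A rough, hypothetical sketch of the idea, not padertorch's actual implementation:

def build_from_config(cfg):
    # Hypothetical helper illustrating the 'factory' convention.
    cfg = dict(cfg)                # copy so the original config stays untouched
    factory = cfg.pop('factory')   # class to construct, e.g. WALNet or Adam
    kwargs = {
        key: build_from_config(value)
        if isinstance(value, dict) and 'factory' in value else value
        for key, value in cfg.items()
    }
    return factory(**kwargs)       # e.g. Adam(lr=3e-4, gradient_clipping=60.)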
Example #9
def main(_run, _log, trainer, database_json, dataset, batch_size):
    commands.print_config(_run)
    trainer = Trainer.from_config(trainer)
    storage_dir = Path(trainer.storage_dir)
    storage_dir.mkdir(parents=True, exist_ok=True)
    commands.save_config(_run.config,
                         _log,
                         config_filename=str(storage_dir / 'config.json'))

    train_set, validate_set, _ = get_datasets(storage_dir, database_json,
                                              dataset, batch_size)

    # Stop early if the loss has not decreased for three consecutive
    # validation runs. Training typically stops after around 20k iterations
    # (13 epochs) with an accuracy >98% on the test set.
    trainer.register_validation_hook(validate_set, early_stopping_patience=3)
    trainer.test_run(train_set, validate_set)
    trainer.train(train_set)
Example #10
def train(
        _run, trainer, device,
):
    print_config(_run)
    trainer = Trainer.from_config(trainer)
    train_iter, validate_iter, batch_norm_tuning_iter = get_datasets()
    if validate_iter is not None:
        trainer.register_validation_hook(validate_iter)
    trainer.train(train_iter, device=device)

    # Finalize: if stochastic weight averaging (SWA) was used, swap the
    # averaged weights into the model, then recompute the batch-norm
    # statistics on the tuning data and export the final checkpoint.
    if trainer.optimizer.swa_start is not None:
        trainer.optimizer.swap_swa_sgd()
    batch_norm_update(
        trainer.model, batch_norm_tuning_iter,
        feature_key='features', device=device
    )
    torch.save(
        trainer.model.state_dict(),
        storage_dir / 'checkpoints' / 'ckpt_final.pth'
    )
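The final checkpoint written above can be restored later with plain PyTorch; a short, hypothetical sketch for evaluation:

# reload the exported SWA/BN-updated weights into the model
state_dict = torch.load(storage_dir / 'checkpoints' / 'ckpt_final.pth')
trainer.model.load_state_dict(state_dict)
trainer.model.eval()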
Example #11
def defaults():
    database_json = (str(
        Path(os.environ['NT_DATABASE_JSONS_DIR']) /
        'librispeech.json') if 'NT_DATABASE_JSONS_DIR' in os.environ else None)
    assert database_json is not None, (
        'database_json cannot be None.\n'
        'Either start the training with "python -m padertorch.contrib.examples.'
        'speaker_classification.train with database_json=</path/to/json>" '
        'or export "NT_DATABASE_JSONS_DIR", pointing to a directory with a '
        '"librispeech.json", before starting the training (see README for '
        'the JSON format).')
    dataset = 'train_clean_100'
    batch_size = 16
    num_speakers = 251
    trainer = {
        'model': {
            'factory': SpeakerClf,
            'feature_extractor': {
                'factory': Normalization,
                'data_format': 'bft',
                'shape': (None, 64, None),
                'statistics_axis': 'bt',
                'independent_axis': None
            },
            'cnn': {
                'factory': CNN1d,
                'in_channels': 64,
                'out_channels': 4 * [512],
                'output_layer': False,
                'kernel_size': 5,
                'norm': 'batch'
            },
            'enc': {
                'factory': GRU,
                'input_size': 512,
                'hidden_size': 256,
                'num_layers': 2,
                'batch_first': True
            },
            'fcn': {
                'factory': fully_connected_stack,
                'input_size': 256,
                'hidden_size': [256],
                'output_size': num_speakers,
                'dropout': 0.
            }
        },
        'optimizer': {
            'factory': Adam,
            'lr': 3e-4,
        },
        'storage_dir':
        get_new_storage_dir(
            # do not create when performing test_run
            'speaker_clf',
            id_naming='time',
            mkdir=False),
        'summary_trigger': (100, 'iteration'),
        'checkpoint_trigger': (1000, 'iteration'),
        'stop_trigger': (100_000, 'iteration'),
    }
    trainer = Trainer.get_config(trainer)
Example #12
def config():
    debug = False

    # Data configuration
    use_noisy = True
    split = 0
    relabeled = False
    fold = None
    curated_reps = 7
    mixup_probs = [1/3, 2/3]
    audio_reader = {
        'input_sample_rate': 44100,
        'target_sample_rate': 44100,
    }
    stft = {
        'frame_step': 882,
        'frame_length': 1764,
        'fft_length': 2048,
    }
    mel_transform = {
        'sample_rate': audio_reader['target_sample_rate'],
        'fft_length': stft['fft_length'],
        'n_mels': 128,
        'fmin': 50,
        'fmax': 16000,
    }
    augmenter = {
        'time_warping_factor_std': None,
        'time_warping_cutoff_std': 0.1,
        'feature_warping_factor_std': 0.07,
        'feature_warping_cutoff_std': 0.5,
        'n_time_masks': 1,
        'n_feat_masks': 1,
    }
    num_workers = 8
    batch_size = 16
    prefetch_buffer = 20 * batch_size
    max_padding_rate = 0.2
    bucket_expiration = 2000 * batch_size
    event_bucketing = True

    # Trainer/Model configuration
    trainer = {
        'model': {
            'factory': CRNN,
            'cnn_2d': {
                'factory': CNN2d,
                'in_channels': 1,
                'hidden_channels': [16, 16, 32, 32, 64, 64, 128, 128, 256],
                'pool_size': [1, 2, 1, 2, 1, 2, 1, (2, 1), (2, 1)],
                'num_layers': 9,
                'out_channels': None,
                'kernel_size': 3,
                'norm': 'batch',
                'activation': 'relu',
                'gated': False,
                'dropout': .0,
            },
            'cnn_1d': {
                'factory': CNN1d,
                'in_channels': 1024,
                'hidden_channels': 256,
                'num_layers': 3,
                'out_channels': None,
                'kernel_size': 3,
                'norm': 'batch',
                'activation': 'relu',
                'dropout': .0
            },
            'enc': {
                'factory': GRU,
                'input_size': 256,
                'hidden_size': 256,
                'num_layers': 2,
                'batch_first': True,
                'bidirectional': False,
                'dropout': 0.,
            },
            'fcn': {
                'factory': fully_connected_stack,
                'input_size': 256,
                'hidden_size': 256,
                'output_size': 80,
                'activation': 'relu',
                'dropout': 0.,
            },
            'fcn_noisy': {
                'factory': fully_connected_stack,
                'input_size': 256,
                'hidden_size': 256,
                'output_size': 80,
                'activation': 'relu',
                'dropout': 0.,
            },
            'decision_boundary': .3
        },
        'optimizer': {
            'factory': Adam,
            'lr': 3e-4,
            'gradient_clipping': 15.,
            'weight_decay': 3e-5,
            'swa_start': 750 if debug else 150000,
            'swa_freq': 50 if debug else 1000,
            'swa_lr': 3e-4,
        },
        'storage_dir': storage_dir,
        'summary_trigger': (10 if debug else 100, 'iteration'),
        'checkpoint_trigger': (500 if debug else 5000, 'iteration'),
        'stop_trigger': (1000 if debug else 200000, 'iteration'),
    }
    Trainer.get_config(trainer)

    device = 0 if torch.cuda.is_available() else 'cpu'