def backup_state_dict(trainer: pt.Trainer):
    """Snapshot the trainer's state and restore it on exit.

    Generator intended for context-manager use (NOTE(review): expected to be
    decorated with ``contextlib.contextmanager`` — the decorator is not
    visible in this chunk, confirm at the definition site). Any mutation of
    ``trainer`` inside the managed block is undone afterwards, even if an
    exception propagates.

    Args:
        trainer: trainer whose ``state_dict()`` is deep-copied up front and
            restored via ``load_state_dict()`` in the ``finally`` clause.
    """
    # Deep copy so in-place mutations of the live trainer cannot leak into
    # the backup while the managed block runs.
    state_dict = copy.deepcopy(trainer.state_dict())
    try:
        yield
    finally:
        # Always restore, on both normal exit and exception.
        trainer.load_state_dict(state_dict)
def main():
    """Build the WALNet model, a trainer and the data pipelines, then train."""
    net = WALNet(44100, 2048, 527)
    trainer = Trainer(
        model=net,
        optimizer=optimizer.Adam(lr=3e-4, gradient_clipping=60.),
        storage_dir=storage_dir,
        summary_trigger=(100, 'iteration'),
        stop_trigger=(50000, 'iteration'),
        checkpoint_trigger=(1000, 'iteration'),
    )
    reader_opts = dict(source_sample_rate=44100, target_sample_rate=44100)
    stft_opts = dict(
        shift=882, window_length=2 * 882, size=2048, fading=None, pad=False)
    train_ds, dev_ds = get_datasets(
        audio_reader=reader_opts,
        stft=stft_opts,
        num_workers=8,
        batch_size=24,
        max_padding_rate=.1,
        storage_dir=storage_dir,
    )
    trainer.register_validation_hook(dev_ds, metric='macro_fscore', maximize=True)
    # Dry-run a few steps first to surface data/model problems early.
    trainer.test_run(train_ds, dev_ds)
    trainer.train(train_ds)
def main(_run, _log, trainer, database_json, training_set, validation_metric,
         maximize_metric, audio_reader, stft, num_workers, batch_size,
         max_padding_rate, resume):
    """Sacred main: instantiate the trainer, build the datasets and train."""
    commands.print_config(_run)
    trainer = Trainer.from_config(trainer)
    storage_dir = Path(trainer.storage_dir)
    storage_dir.mkdir(parents=True, exist_ok=True)
    config_path = storage_dir / 'config.json'
    commands.save_config(_run.config, _log, config_filename=str(config_path))
    train_ds, dev_ds, _ = get_datasets(
        database_json=database_json,
        min_signal_length=1.5,
        audio_reader=audio_reader,
        stft=stft,
        num_workers=num_workers,
        batch_size=batch_size,
        max_padding_rate=max_padding_rate,
        training_set=training_set,
        storage_dir=storage_dir,
        stft_stretch_factor_sampling_fn=Uniform(low=0.5, high=1.5),
        stft_segment_length=audio_reader['target_sample_rate'],
        stft_segment_shuffle_prob=0.,
        mixup_probs=(1 / 2, 1 / 2),
        max_mixup_length=15.,
        min_mixup_overlap=.8,
    )
    # Dry-run a few steps before registering the validation hook and training.
    trainer.test_run(train_ds, dev_ds)
    trainer.register_validation_hook(
        dev_ds, metric=validation_metric, maximize=maximize_metric)
    trainer.train(train_ds, resume=resume)
def test_run(trainer, database_json, dataset, batch_size, num_speakers):
    """Perform a few training and validation steps to test whether data
    preparation and the model are working.

    Note:
        ``num_speakers`` is unused here; presumably injected by the
        experiment framework together with the other config values —
        TODO confirm against the config definition.
    """
    trainer = Trainer.from_config(trainer)
    # First positional arg (storage_dir) is None: nothing is persisted
    # during a test run.
    train_set, validate_set, _ = get_datasets(None, database_json, dataset, batch_size)
    trainer.test_run(train_set, validate_set)
def train(speaker_clf):
    """Train the given speaker classifier with a fixed Adam/trigger setup."""
    train_ds, dev_ds = get_datasets()
    trainer = Trainer(
        model=speaker_clf,
        optimizer=Adam(lr=3e-4),
        storage_dir=str(storage_dir),
        summary_trigger=(100, 'iteration'),
        checkpoint_trigger=(1000, 'iteration'),
        stop_trigger=(100000, 'iteration'),
    )
    trainer.register_validation_hook(dev_ds)
    # Sanity-check data and model on a few steps before the real run.
    trainer.test_run(train_ds, dev_ds)
    trainer.train(train_ds)
def train(model):
    """Train ``model`` with Adam; shorter schedule when DEBUG is set."""
    train_ds, dev_ds = get_datasets()
    # 5k iterations in debug mode, 50k otherwise (same values as before).
    max_iterations = 5000 if DEBUG else 50000
    trainer = Trainer(
        model=model,
        optimizer=Adam(lr=1e-3),
        storage_dir=str(storage_dir),
        summary_trigger=(100, 'iteration'),
        checkpoint_trigger=(1000, 'iteration'),
        stop_trigger=(max_iterations, 'iteration'),
    )
    trainer.register_validation_hook(dev_ds)
    # Sanity-check data and model on a few steps before the real run.
    trainer.test_run(train_ds, dev_ds)
    trainer.train(train_ds)
def main():
    """Build WALNet and a trainer, fetch the datasets, then train."""
    net = WALNet(128, 527)
    trainer = Trainer(
        model=net,
        optimizer=optimizer.Adam(lr=3e-4, gradient_clipping=60.),
        storage_dir=storage_dir,
        summary_trigger=(100, 'iteration'),
        stop_trigger=(20000, 'iteration'),
        checkpoint_trigger=(1000, 'iteration'),
    )
    train_ds, dev_ds = get_datasets()
    trainer.register_validation_hook(dev_ds)
    # Dry-run a few steps first to surface data/model problems early.
    trainer.test_run(train_ds, dev_ds)
    trainer.train(train_ds)
def config():
    """Sacred config for WALNet audio-tagging training on AudioSet.

    All locals defined here are captured by sacred as config entries;
    names and values are therefore kept exactly as before.
    """
    database_json = (
        str((Path(os.environ['NT_DATABASE_JSONS_DIR']) / 'audio_set.json').expanduser())
        if 'NT_DATABASE_JSONS_DIR' in os.environ else None
    )
    # Fixed message: a space was missing between '...JSONS_DIR"' and
    # 'pointing', and the module path referenced the wavenet example
    # instead of this audio_tagging training (copy-paste remnant).
    assert database_json is not None, (
        'database_json cannot be None.\n'
        'Either start the training with "python -m padertorch.contrib.examples.'
        'audio_tagging.train with database_json=</path/to/json>" '
        'or make sure there is an environment variable "NT_DATABASE_JSONS_DIR" '
        'pointing to a directory with a "audio_set.json" in it (see README '
        'for the JSON format).')
    training_set = 'balanced_train'
    audio_reader = {
        'source_sample_rate': 44_100,
        'target_sample_rate': 44_100,
    }
    # STFT: 20 ms shift at 44.1 kHz (882 samples), 50% overlap.
    stft = {
        'shift': 882,
        'window_length': 2 * 882,
        'size': 2048,
        'fading': None,
        'pad': False,
    }
    num_workers = 8
    batch_size = 24
    max_padding_rate = .05
    trainer = {
        'model': {
            'factory': WALNet,
            'sample_rate': audio_reader['target_sample_rate'],
            'stft_size': stft['size'],
            'output_size': 527,  # number of AudioSet classes
        },
        'optimizer': {
            'factory': Adam,
            'lr': 3e-4,
            'gradient_clipping': 60.,
        },
        # mkdir=False: directory is only created once training really starts.
        'storage_dir': get_new_storage_dir(
            'audio_tagging', id_naming='time', mkdir=False),
        'summary_trigger': (100, 'iteration'),
        'checkpoint_trigger': (1_000, 'iteration'),
        'stop_trigger': (50_000, 'iteration'),
    }
    # Fill in the factory defaults so the full config is stored by sacred.
    trainer = Trainer.get_config(trainer)
    validation_metric = 'map'
    maximize_metric = True
    resume = False
    ex.observers.append(FileStorageObserver.create(trainer['storage_dir']))
def main(_run, _log, trainer, database_json, dataset, batch_size):
    """Sacred main: set up trainer and datasets, then run the training."""
    commands.print_config(_run)
    trainer = Trainer.from_config(trainer)
    storage_dir = Path(trainer.storage_dir)
    storage_dir.mkdir(parents=True, exist_ok=True)
    config_path = storage_dir / 'config.json'
    commands.save_config(_run.config, _log, config_filename=str(config_path))
    train_ds, dev_ds, _ = get_datasets(
        storage_dir, database_json, dataset, batch_size)
    # Early stopping if loss is not decreasing after three consecutive
    # validation runs. Typically around 20k iterations (13 epochs) with an
    # accuracy >98% on the test set.
    trainer.register_validation_hook(dev_ds, early_stopping_patience=3)
    trainer.test_run(train_ds, dev_ds)
    trainer.train(train_ds)
def train(
        _run,
        trainer,
        device,
):
    """Train the model, then finalize: SWA weight swap, batch-norm
    re-estimation, and a final checkpoint dump."""
    print_config(_run)
    trainer = Trainer.from_config(trainer)
    train_iter, validate_iter, bn_tuning_iter = get_datasets()
    # Validation is optional; only register the hook when data exists.
    if validate_iter is not None:
        trainer.register_validation_hook(validate_iter)
    trainer.train(train_iter, device=device)

    # finalize
    # Swap in the stochastic-weight-averaged parameters when SWA was active.
    if trainer.optimizer.swa_start is not None:
        trainer.optimizer.swap_swa_sgd()
    # Re-estimate batch-norm statistics with the (possibly swapped) weights.
    batch_norm_update(
        trainer.model, bn_tuning_iter, feature_key='features', device=device)
    final_ckpt = storage_dir / 'checkpoints' / 'ckpt_final.pth'
    torch.save(trainer.model.state_dict(), final_ckpt)
def defaults():
    """Sacred default config for LibriSpeech speaker classification."""
    # Resolve the database JSON from the environment, if available.
    database_json = (str(
        Path(os.environ['NT_DATABASE_JSONS_DIR']) / 'librispeech.json')
        if 'NT_DATABASE_JSONS_DIR' in os.environ else None)
    assert database_json is not None, (
        'database_json cannot be None.\n'
        'Either start the training with "python -m padertorch.contrib.examples.'
        'speaker_classification.train with database_json=</path/to/json>" '
        'or export "NT_DATABASE_JSONS_DIR" which points to a directory with a '
        '"librispeech.json" prior to training start (see README for the '
        'JSON format).')
    dataset = 'train_clean_100'
    batch_size = 16
    num_speakers = 251  # speakers in LibriSpeech train-clean-100
    # Nested factory config: CNN front-end -> GRU encoder -> FC classifier.
    trainer = {
        'model': {
            'factory': SpeakerClf,
            'feature_extractor': {
                'factory': Normalization,
                'data_format': 'bft',
                'shape': (None, 64, None),
                'statistics_axis': 'bt',
                'independent_axis': None
            },
            'cnn': {
                'factory': CNN1d,
                'in_channels': 64,
                'out_channels': 4 * [512],  # four conv layers, 512 each
                'output_layer': False,
                'kernel_size': 5,
                'norm': 'batch'
            },
            'enc': {
                'factory': GRU,
                'input_size': 512,
                'hidden_size': 256,
                'num_layers': 2,
                'batch_first': True
            },
            'fcn': {
                'factory': fully_connected_stack,
                'input_size': 256,
                'hidden_size': [256],
                'output_size': num_speakers,
                'dropout': 0.
            }
        },
        'optimizer': {
            'factory': Adam,
            'lr': 3e-4,
        },
        'storage_dir': get_new_storage_dir(  # do not create when performing test_run
            'speaker_clf', id_naming='time', mkdir=False),
        'summary_trigger': (100, 'iteration'),
        'checkpoint_trigger': (1000, 'iteration'),
        'stop_trigger': (100_000, 'iteration'),
    }
    # Fill in the factory defaults so the full config is stored by sacred.
    trainer = Trainer.get_config(trainer)
def config():
    """Sacred config for CRNN tagging training (curated + noisy labels).

    All locals are captured by sacred as config entries; names and values
    are kept exactly as before.
    """
    debug = False

    # Data configuration
    use_noisy = True
    split = 0
    relabeled = False
    fold = None
    curated_reps = 7
    mixup_probs = [1/3, 2/3]
    audio_reader = {
        'input_sample_rate': 44100,
        'target_sample_rate': 44100,
    }
    # STFT: 20 ms shift (882 samples) at 44.1 kHz, 50% overlap.
    stft = {
        'frame_step': 882,
        'frame_length': 1764,
        'fft_length': 2048,
    }
    mel_transform = {
        'sample_rate': audio_reader['target_sample_rate'],
        'fft_length': stft['fft_length'],
        'n_mels': 128,
        'fmin': 50,
        'fmax': 16000,
    }
    # SpecAugment-style warping/masking parameters.
    augmenter = {
        'time_warping_factor_std': None,
        'time_warping_cutoff_std': 0.1,
        'feature_warping_factor_std': 0.07,
        'feature_warping_cutoff_std': 0.5,
        'n_time_masks': 1,
        'n_feat_masks': 1,
    }
    num_workers = 8
    batch_size = 16
    prefetch_buffer = 20 * batch_size
    max_padding_rate = 0.2
    bucket_expiration = 2000 * batch_size
    event_bucketing = True

    # Trainer/Model configuration
    trainer = {
        'model': {
            'factory': CRNN,
            'cnn_2d': {
                'factory': CNN2d,
                'in_channels': 1,
                'hidden_channels': [16, 16, 32, 32, 64, 64, 128, 128, 256],
                'pool_size': [1, 2, 1, 2, 1, 2, 1, (2, 1), (2, 1)],
                'num_layers': 9,
                'out_channels': None,
                'kernel_size': 3,
                'norm': 'batch',
                'activation': 'relu',
                'gated': False,
                'dropout': .0,
            },
            'cnn_1d': {
                'factory': CNN1d,
                'in_channels': 1024,
                'hidden_channels': 256,
                'num_layers': 3,
                'out_channels': None,
                'kernel_size': 3,
                'norm': 'batch',
                'activation': 'relu',
                'dropout': .0
            },
            'enc': {
                'factory': GRU,
                'input_size': 256,
                'hidden_size': 256,
                'num_layers': 2,
                'batch_first': True,
                'bidirectional': False,
                'dropout': 0.,
            },
            'fcn': {
                'factory': fully_connected_stack,
                'input_size': 256,
                'hidden_size': 256,
                'output_size': 80,
                'activation': 'relu',
                'dropout': 0.,
            },
            # Separate head for the noisy-label portion of the data.
            'fcn_noisy': {
                'factory': fully_connected_stack,
                'input_size': 256,
                'hidden_size': 256,
                'output_size': 80,
                'activation': 'relu',
                'dropout': 0.,
            },
            'decision_boundary': .3
        },
        'optimizer': {
            'factory': Adam,
            'lr': 3e-4,
            'gradient_clipping': 15.,
            'weight_decay': 3e-5,
            'swa_start': 750 if debug else 150000,
            'swa_freq': 50 if debug else 1000,
            'swa_lr': 3e-4,
        },
        'storage_dir': storage_dir,
        'summary_trigger': (10 if debug else 100, 'iteration'),
        'checkpoint_trigger': (500 if debug else 5000, 'iteration'),
        'stop_trigger': (1000 if debug else 200000, 'iteration'),
    }
    # Fix: the completed config was previously discarded — the other config
    # functions in this file assign the result of Trainer.get_config so the
    # factory defaults become part of the stored config.
    trainer = Trainer.get_config(trainer)
    device = 0 if torch.cuda.is_available() else 'cpu'