def config():
    """Sacred config scope for mask-estimator training on CHiME-3.

    NOTE: this is a sacred config function — every local variable name
    becomes a config entry, so names must not be changed.
    """
    model_class = MaskEstimatorModel
    trainer_opts = deflatten({
        'model.factory': model_class,
        'optimizer.factory': Adam,
        'stop_trigger': (int(1e5), 'iteration'),
        'summary_trigger': (500, 'iteration'),
        'checkpoint_trigger': (500, 'iteration'),
        'storage_dir': None,
    })
    provider_opts = deflatten({
        'factory': SequenceProvider,
        'database.factory': Chime3,
        'audio_keys': [OBSERVATION, NOISE_IMAGE, SPEECH_IMAGE],
        'transform.factory': MaskTransformer,
        'transform.stft': dict(factory=STFT, shift=256, size=1024),
    })
    # The model shares the provider's transform configuration
    # (presumably so both use identical STFT parameters — TODO confirm).
    trainer_opts['model']['transformer'] = provider_opts['transform']
    storage_dir = None
    add_name = None
    if storage_dir is None:
        # No storage dir given: derive an experiment name from the model
        # config and let the FileStorageObserver create a fresh directory.
        ex_name = get_experiment_name(trainer_opts['model'])
        if add_name is not None:
            ex_name += f'_{add_name}'
        observer = sacred.observers.FileStorageObserver.create(
            str(model_dir / ex_name))
        storage_dir = observer.basedir
    else:
        sacred.observers.FileStorageObserver.create(storage_dir)
    trainer_opts['storage_dir'] = storage_dir

    # If this directory was used before, reconcile the new options with the
    # ones dumped at initialization time.
    if (Path(storage_dir) / 'init.json').exists():
        trainer_opts, provider_opts = compare_configs(
            storage_dir, trainer_opts, provider_opts)

    # Normalize/complete the nested configs in place.
    Trainer.get_config(trainer_opts)
    Configurable.get_config(provider_opts)
    validate_checkpoint = 'ckpt_latest.pth'
    validation_kwargs = dict(
        metric='loss',
        maximize=False,
        max_checkpoints=1,
        validation_length=1000  # number of examples taken from the validation iterator
    )
def main(_run, _log, trainer, database_json, training_sets, validation_sets, audio_reader, stft, max_length_in_sec, batch_size, resume): commands.print_config(_run) trainer = Trainer.from_config(trainer) storage_dir = Path(trainer.storage_dir) storage_dir.mkdir(parents=True, exist_ok=True) commands.save_config(_run.config, _log, config_filename=str(storage_dir / 'config.json')) db = JsonDatabase(database_json) training_data = db.get_dataset(training_sets) validation_data = db.get_dataset(validation_sets) training_data = prepare_dataset(training_data, audio_reader=audio_reader, stft=stft, max_length_in_sec=max_length_in_sec, batch_size=batch_size, shuffle=True) validation_data = prepare_dataset(validation_data, audio_reader=audio_reader, stft=stft, max_length_in_sec=max_length_in_sec, batch_size=batch_size, shuffle=False) trainer.test_run(training_data, validation_data) trainer.register_validation_hook(validation_data) trainer.train(training_data, resume=resume)
def train(
        _run,
        audio_reader,
        stft,
        num_workers,
        batch_size,
        max_padding_rate,
        trainer,
        resume,
):
    """Instantiate the trainer, fetch datasets and run training.

    Validation selects the checkpoint maximizing ``macro_fscore``.
    """
    print_config(_run)
    trainer = Trainer.from_config(trainer)

    dataset_kwargs = dict(
        audio_reader=audio_reader,
        stft=stft,
        num_workers=num_workers,
        batch_size=batch_size,
        max_padding_rate=max_padding_rate,
        storage_dir=trainer.storage_dir,
    )
    train_set, dev_set = get_datasets(**dataset_kwargs)

    # Smoke test before committing to the full training loop.
    trainer.test_run(train_set, dev_set)
    trainer.register_validation_hook(
        dev_set, metric='macro_fscore', maximize=True)
    trainer.train(train_set, resume=resume)
def config():
    """Sacred config scope for WaveNet training on LibriSpeech.

    NOTE: this is a sacred config function — every local variable name
    becomes a config entry, so names must not be changed.
    """
    # Resolve the database JSON from the environment if available.
    database_json = (
        str((Path(os.environ['NT_DATABASE_JSONS_DIR'])
             / 'librispeech.json').expanduser())
        if 'NT_DATABASE_JSONS_DIR' in os.environ else None)
    # Fix: the original message was missing a space between the implicitly
    # concatenated literals ('..."NT_DATABASE_JSONS_DIR"' 'pointing...'),
    # rendering as '"NT_DATABASE_JSONS_DIR"pointing'.
    assert database_json is not None, (
        'database_json cannot be None.\n'
        'Either start the training with "python -m padertorch.contrib.examples.'
        'audio_synthesis.wavenet.train with database_json=</path/to/json>" '
        'or make sure there is an environment variable "NT_DATABASE_JSONS_DIR" '
        'pointing to a directory with a "librispeech.json" in it (see README '
        'for the JSON format).')
    training_sets = ['train_clean_100', 'train_clean_360']
    validation_sets = ['dev_clean']
    audio_reader = {
        'source_sample_rate': 16000,
        'target_sample_rate': 16000,
    }
    stft = {
        'shift': 200,
        'window_length': 800,
        'size': 1024,
        'fading': 'full',
        'pad': True,
    }
    max_length_in_sec = 1.
    batch_size = 3
    number_of_mel_filters = 80
    trainer = {
        'model': {
            'factory': WaveNet,
            'wavenet': {
                'n_cond_channels': number_of_mel_filters,
                # Upsampling of the mel condition mirrors the STFT geometry.
                'upsamp_window': stft['window_length'],
                'upsamp_stride': stft['shift'],
                'fading': stft['fading'],
            },
            'sample_rate': audio_reader['target_sample_rate'],
            'stft_size': stft['size'],
            'number_of_mel_filters': number_of_mel_filters,
            'lowest_frequency': 50
        },
        'optimizer': {
            'factory': Adam,
            'lr': 5e-4,
        },
        'storage_dir': get_new_storage_dir(
            'wavenet', id_naming='time', mkdir=False),
        'summary_trigger': (1_000, 'iteration'),
        'checkpoint_trigger': (10_000, 'iteration'),
        'stop_trigger': (200_000, 'iteration'),
    }
    trainer = Trainer.get_config(trainer)
    resume = False
    ex.observers.append(FileStorageObserver.create(trainer['storage_dir']))
def train(model, storage_dir):
    """Train ``model``, writing checkpoints and summaries to ``storage_dir``."""
    tr_set, dev_set, _ = get_datasets()
    trainer = Trainer(
        model=model,
        optimizer=Adam(lr=5e-4),
        storage_dir=str(storage_dir),
        summary_trigger=(1000, 'iteration'),
        checkpoint_trigger=(10000, 'iteration'),
        stop_trigger=(100000, 'iteration'),
    )
    # Dry run to surface shape/type errors before the real loop starts.
    trainer.test_run(tr_set, dev_set)
    trainer.register_validation_hook(dev_set)
    trainer.train(tr_set)
def initialize_trainer_provider(task, trainer_opts, provider_opts, _run):
    """Instantiate the trainer and data provider for the given task.

    On first use of a storage dir ('train'/'create_checkpoint') the options
    are dumped to ``init.json``; on reuse only 'restart'/'validate' are
    permitted. Returns ``(trainer, provider)``.
    """
    storage_dir = Path(trainer_opts['storage_dir'])
    init_file = storage_dir / 'init.json'
    if init_file.exists():
        # Storage dir was initialized before — only continuation tasks allowed.
        assert task in ['restart', 'validate'], task
    elif task in ['train', 'create_checkpoint']:
        # Fresh storage dir: persist the (stringified) options for later runs.
        init_payload = dict(
            trainer_opts=recursive_class_to_str(trainer_opts),
            provider_opts=recursive_class_to_str(provider_opts),
        )
        dump_json(init_payload, init_file)
    else:
        raise ValueError(task, storage_dir)

    sacred.commands.print_config(_run)
    trainer = Trainer.from_config(trainer_opts)
    assert isinstance(trainer, Trainer)
    provider = config_to_instance(provider_opts)
    return trainer, provider
def config():
    """Sacred config scope for strong-label CRNN training on DESED.

    NOTE: this is a sacred config function — every local variable name
    becomes a config entry, so names must not be changed.
    """
    delay = 0  # seconds to sleep before training starts
    debug = False
    timestamp = timeStamped('')[1:] + ('_debug' if debug else '')
    group_name = timestamp
    database_name = 'desed'
    storage_dir = str(
        storage_root / 'strong_label_crnn' / database_name / 'training'
        / group_name / timestamp)
    init_ckpt_path = None  # optional checkpoint to warm-start the CNN from
    frozen_cnn_2d_layers = 0
    frozen_cnn_1d_layers = 0

    # Data provider
    if database_name == 'desed':
        external_data = True
        batch_size = 32
        data_provider = {
            'factory': DESEDProvider,
            'json_path': str(
                database_jsons_dir / 'desed_pseudo_labeled_with_external.json'
            ) if external_data else str(
                database_jsons_dir / 'desed_pseudo_labeled_without_external.json'
            ),
            # Sampling weights per training subset (presumably relative
            # repetition factors — TODO confirm against DESEDProvider).
            'train_set': {
                'train_weak': 10 if external_data else 20,
                'train_strong': 10 if external_data else 0,
                'train_synthetic20': 2,
                'train_synthetic21': 1,
                'train_unlabel_in_domain': 2,
            },
            'cached_datasets':
                None if debug else ['train_weak', 'train_synthetic20'],
            'train_fetcher': {
                'batch_size': batch_size,
                'prefetch_workers': batch_size,
                # Minimum examples per subset in each batch; the counts are
                # scaled linearly with batch_size (reference size 32).
                'min_dataset_examples_in_batch': {
                    'train_weak': int(3 * batch_size / 32),
                    'train_strong':
                        int(6 * batch_size / 32) if external_data else 0,
                    'train_synthetic20': int(1 * batch_size / 32),
                    'train_synthetic21': int(2 * batch_size / 32),
                    'train_unlabel_in_domain': 0,
                },
            },
            'storage_dir': storage_dir,
        }
        num_events = 10
        DESEDProvider.get_config(data_provider)
        validation_set_name = 'validation'
        validation_ground_truth_filepath = None
        weak_label_crnn_hyper_params_dir = ''
        eval_set_name = 'eval_public'
        eval_ground_truth_filepath = None
        # Several schedule parameters switch depending on whether training
        # starts from scratch (init_ckpt_path is None) or is fine-tuning.
        num_iterations = 45000 if init_ckpt_path is None else 20000
        checkpoint_interval = 1000
        summary_interval = 100
        back_off_patience = None
        lr_decay_step = 30000 if back_off_patience is None else None
        lr_decay_factor = 1 / 5
        lr_rampup_steps = 1000 if init_ckpt_path is None else None
        gradient_clipping = 1e10 if init_ckpt_path is None else 1
    else:
        raise ValueError(f'Unknown database {database_name}.')

    # Trainer configuration
    net_config = 'shallow'
    if net_config == 'shallow':
        m = 1  # width multiplier
        cnn = {
            'cnn_2d': {
                'out_channels': [
                    16 * m, 16 * m, 32 * m, 32 * m, 64 * m, 64 * m,
                    128 * m, 128 * m, min(256 * m, 512),
                ],
                'pool_size': 4 * [1, (2, 1)] + [1],
                'kernel_size': 3,
                'norm': 'batch',
                'norm_kwargs': {'eps': 1e-3},
                'activation_fn': 'relu',
                'dropout': .0,
                'output_layer': False,
            },
            'cnn_1d': {
                'out_channels': 3 * [256 * m],
                'kernel_size': 3,
                'norm': 'batch',
                'norm_kwargs': {'eps': 1e-3},
                'activation_fn': 'relu',
                'dropout': .0,
                'output_layer': False,
            },
        }
    elif net_config == 'deep':
        m = 2  # width multiplier
        cnn = {
            'cnn_2d': {
                'out_channels': (
                    4 * [16 * m] + 4 * [32 * m] + 4 * [64 * m]
                    + 4 * [128 * m] + [256 * m, min(256 * m, 512)]),
                'pool_size': 4 * [1, 1, 1, (2, 1)] + [1, 1],
                'kernel_size': 9 * [3, 1],
                # Skip connections between the indexed layers.
                'residual_connections': [
                    None, None, 4, None, 6, None, 8, None, 10, None, 12,
                    None, 14, None, 16, None, None, None
                ],
                'norm': 'batch',
                'norm_kwargs': {'eps': 1e-3},
                'activation_fn': 'relu',
                'pre_activation': True,
                'dropout': .0,
                'output_layer': False,
            },
            'cnn_1d': {
                'out_channels': 8 * [256 * m],
                'kernel_size': [1] + 3 * [3, 1] + [1],
                'residual_connections': [None, 3, None, 5, None, 7, None, None],
                'norm': 'batch',
                'norm_kwargs': {'eps': 1e-3},
                'activation_fn': 'relu',
                'pre_activation': True,
                'dropout': .0,
                'output_layer': False,
            },
        }
    else:
        raise ValueError(f'Unknown net_config {net_config}')
    if init_ckpt_path is not None:
        cnn['conditional_dims'] = 0
    trainer = {
        'model': {
            'factory': strong_label.CRNN,
            'feature_extractor': {
                'sample_rate':
                    data_provider['audio_reader']['target_sample_rate'],
                'stft_size': data_provider['train_transform']['stft']['size'],
                'number_of_filters': 128,
                # Random mel-warping data augmentation.
                'frequency_warping_fn': {
                    'factory': MelWarping,
                    'warp_factor_sampling_fn': {
                        'factory': LogTruncatedNormal,
                        'scale': .08,
                        'truncation': np.log(1.3),
                    },
                    'boundary_frequency_ratio_sampling_fn': {
                        'factory': TruncatedExponential,
                        'scale': .5,
                        'truncation': 5.,
                    },
                    'highest_frequency':
                        data_provider['audio_reader']['target_sample_rate'] / 2
                },
                # 'blur_sigma': .5,
                # SpecAugment-style time/frequency masking parameters.
                'n_time_masks': 1,
                'max_masked_time_steps': 70,
                'max_masked_time_rate': .2,
                'n_frequency_masks': 1,
                'max_masked_frequency_bands': 20,
                'max_masked_frequency_rate': .2,
                'max_noise_scale': .2,
            },
            'cnn': cnn,
            'rnn': {
                'hidden_size': 256 * m,
                'num_layers': 2,
                'dropout': .0,
                'output_net': {
                    'out_channels': [256 * m, num_events],
                    'kernel_size': 1,
                    'norm': 'batch',
                    'activation_fn': 'relu',
                    'dropout': .0,
                }
            },
            'labelwise_metrics': ('fscore_strong',),
        },
        'optimizer': {
            'factory': Adam,
            'lr': 5e-4,
            'gradient_clipping': gradient_clipping,
            # 'weight_decay': 1e-6,
        },
        'summary_trigger': (summary_interval, 'iteration'),
        'checkpoint_trigger': (checkpoint_interval, 'iteration'),
        'stop_trigger': (num_iterations, 'iteration'),
        'storage_dir': storage_dir,
    }
    del cnn  # now owned by trainer['model']; remove from the config scope
    use_transformer = False
    if use_transformer:
        # Replace the recurrent block by a transformer stack.
        trainer['model']['rnn']['factory'] = TransformerStack
        trainer['model']['rnn']['hidden_size'] = 320
        trainer['model']['rnn']['num_heads'] = 10
        trainer['model']['rnn']['num_layers'] = 3
        trainer['model']['rnn']['dropout'] = 0.1
    Trainer.get_config(trainer)
    resume = False
    # Refuse to overwrite an existing run unless resuming.
    assert resume or not Path(trainer['storage_dir']).exists()
    ex.observers.append(FileStorageObserver.create(trainer['storage_dir']))
def train(
        _run,
        debug,
        data_provider,
        trainer,
        lr_rampup_steps,
        back_off_patience,
        lr_decay_step,
        lr_decay_factor,
        init_ckpt_path,
        frozen_cnn_2d_layers,
        frozen_cnn_1d_layers,
        resume,
        delay,
        validation_set_name,
        validation_ground_truth_filepath,
        weak_label_crnn_hyper_params_dir,
        eval_set_name,
        eval_ground_truth_filepath,
):
    """Train the strong-label CRNN, then (optionally) run hyper-parameter
    tuning on the validation set.

    Kept as a single sequence of stateful trainer/provider mutations; the
    statement order is significant.
    """
    print()
    print('##### Training #####')
    print()
    print_config(_run)
    # Back-off LR scheduling and fixed-step LR decay are mutually exclusive.
    assert (back_off_patience is None) or (lr_decay_step is None), (
        back_off_patience, lr_decay_step)
    if delay > 0:
        print(f'Sleep for {delay} seconds.')
        time.sleep(delay)

    data_provider = DESEDProvider.from_config(data_provider)
    data_provider.train_transform.label_encoder.initialize_labels(
        dataset=data_provider.db.get_dataset(data_provider.validate_set),
        verbose=True)
    data_provider.test_transform.label_encoder.initialize_labels()
    trainer = Trainer.from_config(trainer)
    # Build an index -> sanitized-label mapping; the assert verifies the
    # encoder's indices are contiguous starting at 0.
    trainer.model.label_mapping = []
    for idx, label in sorted(data_provider.train_transform.label_encoder.
                             inverse_label_mapping.items()):
        assert idx == len(trainer.model.label_mapping), (
            idx, label, len(trainer.model.label_mapping))
        # Sanitize label names (remove spaces/quotes, map separators).
        trainer.model.label_mapping.append(
            label.replace(', ', '__').replace(' ', '').replace('(', '_')
            .replace(')', '_').replace("'", ''))
    print('Params', sum(p.numel() for p in trainer.model.parameters()))

    if init_ckpt_path is not None:
        # Warm-start only the CNN part from the given checkpoint.
        print('Load init params')
        state_dict = deflatten(
            torch.load(init_ckpt_path, map_location='cpu')['model'],
            maxdepth=1)
        trainer.model.cnn.load_state_dict(state_dict['cnn'])
    if frozen_cnn_2d_layers:
        print(f'Freeze {frozen_cnn_2d_layers} cnn_2d layers')
        trainer.model.cnn.cnn_2d.freeze(frozen_cnn_2d_layers)
    if frozen_cnn_1d_layers:
        print(f'Freeze {frozen_cnn_1d_layers} cnn_1d layers')
        trainer.model.cnn.cnn_1d.freeze(frozen_cnn_1d_layers)

    def add_tag_condition(example):
        # Use the weak (clip-level) targets as the tag-conditioning input.
        example["tag_condition"] = example["weak_targets"]
        return example

    train_set = data_provider.get_train_set().map(add_tag_condition)
    validate_set = data_provider.get_validate_set().map(add_tag_condition)

    if validate_set is not None:
        trainer.test_run(train_set, validate_set)
        trainer.register_validation_hook(
            validate_set,
            metric='macro_fscore_strong',
            maximize=True,
        )

    # Piece-wise linear LR schedule: optional ramp-up at the start and an
    # optional step decay at lr_decay_step.
    breakpoints = []
    if lr_rampup_steps is not None:
        breakpoints += [(0, 0.), (lr_rampup_steps, 1.)]
    if lr_decay_step is not None:
        breakpoints += [(lr_decay_step, 1.), (lr_decay_step, lr_decay_factor)]
    if len(breakpoints) > 0:
        if isinstance(trainer.optimizer, dict):
            names = sorted(trainer.optimizer.keys())
        else:
            names = [None]
        for name in names:
            trainer.register_hook(
                LRAnnealingHook(
                    # Update every 100 iterations until shortly after the
                    # last breakpoint, then stop annealing.
                    trigger=AllTrigger(
                        (100, 'iteration'),
                        NotTrigger(
                            EndTrigger(breakpoints[-1][0] + 100, 'iteration')),
                    ),
                    breakpoints=breakpoints,
                    unit='iteration',
                    name=name,
                ))
    trainer.train(train_set, resume=resume)

    if validation_set_name:
        # Chain hyper-parameter tuning on the freshly trained model.
        tuning.run(
            config_updates={
                'debug': debug,
                'weak_label_crnn_hyper_params_dir':
                    weak_label_crnn_hyper_params_dir,
                'strong_label_crnn_dirs': [str(trainer.storage_dir)],
                'validation_set_name': validation_set_name,
                'validation_ground_truth_filepath':
                    validation_ground_truth_filepath,
                'eval_set_name': eval_set_name,
                'eval_ground_truth_filepath': eval_ground_truth_filepath,
            })
def train(
        _run,
        debug,
        data_provider,
        filter_desed_test_clips,
        trainer,
        lr_rampup_steps,
        back_off_patience,
        lr_decay_step,
        lr_decay_factor,
        init_ckpt_path,
        frozen_cnn_2d_layers,
        frozen_cnn_1d_layers,
        track_emissions,
        resume,
        delay,
        validation_set_name,
        validation_ground_truth_filepath,
        eval_set_name,
        eval_ground_truth_filepath,
):
    """Train the (weak-label) CRNN, then (optionally) run hyper-parameter
    tuning on the validation set.

    Kept as a single sequence of stateful trainer/provider mutations; the
    statement order is significant.
    """
    print()
    print('##### Training #####')
    print()
    print_config(_run)
    # Back-off LR scheduling and fixed-step LR decay are mutually exclusive.
    assert (back_off_patience is None) or (lr_decay_step is None), (
        back_off_patience, lr_decay_step)
    if delay > 0:
        print(f'Sleep for {delay} seconds.')
        time.sleep(delay)

    data_provider = DataProvider.from_config(data_provider)
    data_provider.train_transform.label_encoder.initialize_labels(
        dataset=data_provider.db.get_dataset(data_provider.validate_set),
        verbose=True)
    data_provider.test_transform.label_encoder.initialize_labels()
    trainer = Trainer.from_config(trainer)
    # Build an index -> sanitized-label mapping; the assert verifies the
    # encoder's indices are contiguous starting at 0.
    trainer.model.label_mapping = []
    for idx, label in sorted(data_provider.train_transform.label_encoder.
                             inverse_label_mapping.items()):
        assert idx == len(trainer.model.label_mapping), (
            idx, label, len(trainer.model.label_mapping))
        # Sanitize label names (remove spaces/quotes, map separators).
        trainer.model.label_mapping.append(
            label.replace(', ', '__').replace(' ', '').replace('(', '_')
            .replace(')', '_').replace("'", ''))
    print('Params', sum(p.numel() for p in trainer.model.parameters()))
    print('CNN Params', sum(p.numel() for p in trainer.model.cnn.parameters()))

    if init_ckpt_path is not None:
        # Warm-start CNN and RNN parts from the checkpoint; the output
        # layers are dropped because the number of classes may differ.
        print('Load init params')
        state_dict = deflatten(
            torch.load(init_ckpt_path, map_location='cpu')['model'],
            maxdepth=2)
        trainer.model.cnn.load_state_dict(flatten(state_dict['cnn']))
        trainer.model.rnn_fwd.rnn.load_state_dict(state_dict['rnn_fwd']['rnn'])
        trainer.model.rnn_bwd.rnn.load_state_dict(state_dict['rnn_bwd']['rnn'])
        # pop output layer from checkpoint
        # NOTE(review): the loop variable below shadows the `layer_idx`
        # list; zip() binds the list before the shadowing takes effect,
        # so the behavior is correct — just a naming smell.
        param_keys = sorted(state_dict['rnn_fwd']['output_net'].keys())
        layer_idx = [key.split('.')[1] for key in param_keys]
        last_layer_idx = layer_idx[-1]
        for key, layer_idx in zip(param_keys, layer_idx):
            if layer_idx == last_layer_idx:
                state_dict['rnn_fwd']['output_net'].pop(key)
                state_dict['rnn_bwd']['output_net'].pop(key)
        trainer.model.rnn_fwd.output_net.load_state_dict(
            state_dict['rnn_fwd']['output_net'], strict=False)
        trainer.model.rnn_bwd.output_net.load_state_dict(
            state_dict['rnn_bwd']['output_net'], strict=False)
    if frozen_cnn_2d_layers:
        print(f'Freeze {frozen_cnn_2d_layers} cnn_2d layers')
        trainer.model.cnn.cnn_2d.freeze(frozen_cnn_2d_layers)
    if frozen_cnn_1d_layers:
        print(f'Freeze {frozen_cnn_1d_layers} cnn_1d layers')
        trainer.model.cnn.cnn_1d.freeze(frozen_cnn_1d_layers)

    if filter_desed_test_clips:
        # Exclude clips that appear in the DESED validation/eval sets to
        # avoid test-set leakage into training.
        with (database_jsons_dir / 'desed.json').open() as fid:
            desed_json = json.load(fid)
        filter_example_ids = {
            clip_id.rsplit('_', maxsplit=2)[0][1:]
            for clip_id in (
                list(desed_json['datasets']['validation'].keys())
                + list(desed_json['datasets']['eval_public'].keys()))
        }
    else:
        filter_example_ids = None
    train_set = data_provider.get_train_set(
        filter_example_ids=filter_example_ids)
    validate_set = data_provider.get_validate_set()

    if validate_set is not None:
        trainer.test_run(train_set, validate_set)
        trainer.register_validation_hook(
            validate_set,
            metric='macro_fscore_weak',
            maximize=True,
            back_off_patience=back_off_patience,
            n_back_off=0 if back_off_patience is None else 1,
            lr_update_factor=lr_decay_factor,
            early_stopping_patience=back_off_patience,
        )

    # Piece-wise linear LR schedule: optional ramp-up at the start and an
    # optional step decay at lr_decay_step.
    breakpoints = []
    if lr_rampup_steps is not None:
        breakpoints += [(0, 0.), (lr_rampup_steps, 1.)]
    if lr_decay_step is not None:
        breakpoints += [(lr_decay_step, 1.), (lr_decay_step, lr_decay_factor)]
    if len(breakpoints) > 0:
        if isinstance(trainer.optimizer, dict):
            names = sorted(trainer.optimizer.keys())
        else:
            names = [None]
        for name in names:
            trainer.register_hook(
                LRAnnealingHook(
                    # Update every 100 iterations until shortly after the
                    # last breakpoint, then stop annealing.
                    trigger=AllTrigger(
                        (100, 'iteration'),
                        NotTrigger(
                            EndTrigger(breakpoints[-1][0] + 100, 'iteration')),
                    ),
                    breakpoints=breakpoints,
                    unit='iteration',
                    name=name,
                ))
    trainer.train(train_set, resume=resume, track_emissions=track_emissions)

    if validation_set_name is not None:
        # Chain hyper-parameter tuning on the freshly trained model.
        tuning.run(
            config_updates={
                'debug': debug,
                'crnn_dirs': [str(trainer.storage_dir)],
                'validation_set_name': validation_set_name,
                'validation_ground_truth_filepath':
                    validation_ground_truth_filepath,
                'eval_set_name': eval_set_name,
                'eval_ground_truth_filepath': eval_ground_truth_filepath,
                'data_provider': {
                    'test_fetcher': {
                        'batch_size': data_provider.train_fetcher.batch_size,
                    }
                },
            })
def config():
    """Sacred config scope for CRNN audio-tagging training.

    NOTE: this is a sacred config function — every local variable name
    becomes a config entry, so names must not be changed.
    """
    resume = False

    # Data configuration
    audio_reader = {
        'source_sample_rate': None,  # None: keep whatever rate the file has
        'target_sample_rate': 44100,
    }
    stft = {
        'shift': 882,
        'window_length': 2 * 882,
        'size': 2048,
        'fading': None,
        'pad': False,
    }
    batch_size = 24
    num_workers = 8
    prefetch_buffer = 10 * batch_size
    max_total_size = None
    max_padding_rate = 0.1
    bucket_expiration = 1000 * batch_size

    # Trainer configuration
    trainer = {
        'model': {
            'factory': CRNN,
            'feature_extractor': {
                'sample_rate': audio_reader['target_sample_rate'],
                'fft_length': stft['size'],
                'n_mels': 128,
                # Random mel-warping data augmentation.
                'warping_fn': {
                    'factory': MelWarping,
                    'alpha_sampling_fn': {
                        'factory': LogTruncNormalSampler,
                        'scale': .07,
                        'truncation': np.log(1.3),
                    },
                    'fhi_sampling_fn': {
                        'factory': TruncExponentialSampler,
                        'scale': .5,
                        'truncation': 5.,
                    },
                },
                'max_resample_rate': 1.,
                # SpecAugment-style time/mel masking parameters.
                'n_time_masks': 1,
                'max_masked_time_steps': 70,
                'max_masked_time_rate': .2,
                'n_mel_masks': 1,
                'max_masked_mel_steps': 16,
                'max_masked_mel_rate': .2,
                'max_noise_scale': .0,
            },
            'cnn_2d': {
                'out_channels': [16, 16, 32, 32, 64, 64, 128, 128, 256],
                'pool_size': [1, 2, 1, 2, 1, 2, 1, (2, 1), (2, 1)],
                # 'residual_connections': [None, 3, None, 5, None, 7, None],
                'output_layer': False,
                'kernel_size': 3,
                'norm': 'batch',
                'activation_fn': 'relu',
                # 'pre_activation': True,
                'dropout': .0,
            },
            'cnn_1d': {
                'out_channels': 3 * [512],
                # 'residual_connections': [None, 3, None],
                'input_layer': False,
                'output_layer': False,
                'kernel_size': 3,
                'norm': 'batch',
                'activation_fn': 'relu',
                # 'pre_activation': True,
                'dropout': .0,
            },
            # Forward/backward RNN branches with separate classifiers;
            # 527 output classes (presumably the AudioSet ontology — TODO
            # confirm).
            'rnn_fwd': {
                'hidden_size': 512,
                'num_layers': 2,
                'dropout': .0,
            },
            'clf_fwd': {
                'out_channels': [512, 527],
                'input_layer': False,
                'kernel_size': 1,
                'norm': 'batch',
                'activation_fn': 'relu',
                'dropout': .0,
            },
            'rnn_bwd': {
                'hidden_size': 512,
                'num_layers': 2,
                'dropout': .0,
            },
            'clf_bwd': {
                'out_channels': [512, 527],
                'input_layer': False,
                'kernel_size': 1,
                'norm': 'batch',
                'activation_fn': 'relu',
                'dropout': .0,
            },
        },
        'optimizer': {
            'factory': Adam,
            'lr': 3e-4,
            'gradient_clipping': 20.,
        },
        'storage_dir': storage_dir,
        'summary_trigger': (100, 'iteration'),
        'checkpoint_trigger': (1000, 'iteration'),
        'stop_trigger': (100000, 'iteration')
    }
    Trainer.get_config(trainer)