def run(_run, storage_dir, job_id, number_of_jobs, session_id, test_run=False):
    print_config(_run)
    assert 1 <= job_id <= number_of_jobs, (job_id, number_of_jobs)

    enhancer = get_enhancer()

    if test_run:
        print('Database', enhancer.db)

    if test_run is False:
        dataset_slice = slice(job_id - 1, None, number_of_jobs)
    else:
        dataset_slice = test_run

    if mpi.IS_MASTER:
        print('Enhancer:', enhancer)
        print(session_id)

    enhancer.enhance_session(
        session_id,
        Path(storage_dir) / 'audio',
        dataset_slice=dataset_slice,
        audio_dir_exist_ok=True,
    )

    if mpi.IS_MASTER:
        print('Finished experiment dir:', storage_dir)

def run(_run, test_run=False):
    if mpi.IS_MASTER:
        print_config(_run)
        _dir = get_dir()
        print('Experiment dir:', _dir)
    else:
        _dir = None
    _dir = mpi.bcast(_dir, mpi.MASTER)

    enhancer = get_enhancer()

    if test_run:
        print('Database', enhancer.db)

    session_ids = get_session_ids()

    if mpi.IS_MASTER:
        print('Enhancer:', enhancer)
        print(session_ids)

    enhancer.enhance_session(
        session_ids,
        _dir / 'audio',
        test_run=test_run,
    )

    if mpi.IS_MASTER:
        print('Finished experiment dir:', _dir)

def train(
        _run, audio_reader, stft, num_workers, batch_size, max_padding_rate,
        trainer, resume,
):
    print_config(_run)
    trainer = Trainer.from_config(trainer)
    train_iter, validation_iter = get_datasets(
        audio_reader=audio_reader,
        stft=stft,
        num_workers=num_workers,
        batch_size=batch_size,
        max_padding_rate=max_padding_rate,
        storage_dir=trainer.storage_dir,
    )
    trainer.test_run(train_iter, validation_iter)
    trainer.register_validation_hook(
        validation_iter, metric='macro_fscore', maximize=True)
    trainer.train(train_iter, resume=resume)

def main(_run, _log, trainer, database_json, training_sets, validation_sets,
         audio_reader, stft, max_length_in_sec, batch_size, resume):
    commands.print_config(_run)
    trainer = Trainer.from_config(trainer)
    storage_dir = Path(trainer.storage_dir)
    storage_dir.mkdir(parents=True, exist_ok=True)
    commands.save_config(
        _run.config, _log, config_filename=str(storage_dir / 'config.json'))

    db = JsonDatabase(database_json)
    training_data = db.get_dataset(training_sets)
    validation_data = db.get_dataset(validation_sets)
    training_data = prepare_dataset(
        training_data, audio_reader=audio_reader, stft=stft,
        max_length_in_sec=max_length_in_sec, batch_size=batch_size,
        shuffle=True)
    validation_data = prepare_dataset(
        validation_data, audio_reader=audio_reader, stft=stft,
        max_length_in_sec=max_length_in_sec, batch_size=batch_size,
        shuffle=False)

    trainer.test_run(training_data, validation_data)
    trainer.register_validation_hook(validation_data)
    trainer.train(training_data, resume=resume)

def run(_run, chime6, test_run=False):
    if dlp_mpi.IS_MASTER:
        print_config(_run)
        _dir = get_dir()
        print('Experiment dir:', _dir)
    else:
        _dir = None
    _dir = dlp_mpi.bcast(_dir, dlp_mpi.MASTER)

    if chime6:
        enhancer = get_enhancer_chime6()
    else:
        enhancer = get_enhancer()

    if test_run:
        print('Database', enhancer.db)

    session_ids = get_session_ids()

    if dlp_mpi.IS_MASTER:
        print('Enhancer:', enhancer)
        print(session_ids)

    enhancer.enhance_session(
        session_ids,
        _dir / 'audio',
        dataset_slice=test_run,
        audio_dir_exist_ok=True,
    )

    if dlp_mpi.IS_MASTER:
        print('Finished experiment dir:', _dir)

def main(_run, _log, trainer, database_json, training_set, validation_metric,
         maximize_metric, audio_reader, stft, num_workers, batch_size,
         max_padding_rate, resume):
    commands.print_config(_run)
    trainer = Trainer.from_config(trainer)
    storage_dir = Path(trainer.storage_dir)
    storage_dir.mkdir(parents=True, exist_ok=True)
    commands.save_config(
        _run.config, _log, config_filename=str(storage_dir / 'config.json'))

    training_data, validation_data, _ = get_datasets(
        database_json=database_json,
        min_signal_length=1.5,
        audio_reader=audio_reader,
        stft=stft,
        num_workers=num_workers,
        batch_size=batch_size,
        max_padding_rate=max_padding_rate,
        training_set=training_set,
        storage_dir=storage_dir,
        stft_stretch_factor_sampling_fn=Uniform(low=0.5, high=1.5),
        stft_segment_length=audio_reader['target_sample_rate'],
        stft_segment_shuffle_prob=0.,
        mixup_probs=(1 / 2, 1 / 2),
        max_mixup_length=15.,
        min_mixup_overlap=.8,
    )

    trainer.test_run(training_data, validation_data)
    trainer.register_validation_hook(
        validation_data, metric=validation_metric, maximize=maximize_metric)
    trainer.train(training_data, resume=resume)

def train(logdir, device, iterations, resume_iteration, checkpoint_interval,
          batch_size, sequence_length, model_complexity, learning_rate,
          learning_rate_decay_steps, learning_rate_decay_rate, leave_one_out,
          clip_gradient_norm, validation_length, validation_interval):
    print_config(ex.current_run)

    os.makedirs(logdir, exist_ok=True)
    writer = SummaryWriter(logdir)

    train_groups, validation_groups = ['train'], ['validation']
    if leave_one_out is not None:
        all_years = {'2004', '2006', '2008', '2009', '2011', '2013', '2014',
                     '2015', '2017'}
        train_groups = list(all_years - {str(leave_one_out)})
        validation_groups = [str(leave_one_out)]

    dataset = MAESTRO(groups=train_groups, sequence_length=sequence_length)
    loader = DataLoader(dataset, batch_size, shuffle=True)
    validation_dataset = MAESTRO(groups=validation_groups,
                                 sequence_length=validation_length)

    if resume_iteration is None:
        model = OnsetsAndFrames(N_MELS, MAX_MIDI - MIN_MIDI + 1,
                                model_complexity).to(device)
        optimizer = torch.optim.Adam(model.parameters(), learning_rate)
        resume_iteration = 0
    else:
        model_path = os.path.join(logdir, f'model-{resume_iteration}.pt')
        model = torch.load(model_path)
        optimizer = torch.optim.Adam(model.parameters(), learning_rate)
        optimizer.load_state_dict(
            torch.load(os.path.join(logdir, 'last-optimizer-state.pt')))

    summary(model)
    scheduler = StepLR(optimizer, step_size=learning_rate_decay_steps,
                       gamma=learning_rate_decay_rate)

    loop = tqdm(range(resume_iteration + 1, iterations + 1))
    for i, batch in zip(loop, cycle(loader)):
        predictions, losses = model.run_on_batch(batch)
        loss = sum(losses.values())

        optimizer.zero_grad()
        loss.backward()
        # Clip gradients before the optimizer step so the clipped gradients
        # are what the update actually uses.
        if clip_gradient_norm:
            clip_grad_norm_(model.parameters(), clip_gradient_norm)
        optimizer.step()
        # Step the scheduler after the optimizer step, per the
        # PyTorch >= 1.1 API.
        scheduler.step()

        for key, value in {'loss': loss, **losses}.items():
            writer.add_scalar(key, value.item(), global_step=i)

        if i % validation_interval == 0:
            model.eval()
            with torch.no_grad():
                for key, value in evaluate(validation_dataset, model).items():
                    writer.add_scalar('validation/' + key.replace(' ', '_'),
                                      np.mean(value), global_step=i)
            model.train()

        if i % checkpoint_interval == 0:
            torch.save(model, os.path.join(logdir, f'model-{i}.pt'))
            torch.save(optimizer.state_dict(),
                       os.path.join(logdir, 'last-optimizer-state.pt'))

def main(_run, seed, save_folder, config_filename):
    set_global_seeds(seed)
    logger.info("Run id: {}".format(_run._id))
    print_config(ex.current_run)
    # saving config
    save_config(ex.current_run.config, ex.logger,
                config_filename=save_folder + config_filename)
    train()

def main(_run, out):
    if dlp_mpi.IS_MASTER:
        from sacred.commands import print_config
        print_config(_run)

    ds = get_dataset()

    data = []
    for ex in dlp_mpi.split_managed(ds.sort(), allow_single_worker=True):
        for prediction in [
                'source',
                'early_0', 'early_1',
                'image_0', 'image_1',
                'image_0_noise', 'image_1_noise',
        ]:
            for source in [
                    'source',
                    'early_0', 'early_1',
                    'image_0', 'image_1',
                    'image_0_noise', 'image_1_noise',
            ]:
                scores = get_scores(ex, prediction=prediction, source=source)
                for score_name, score_value in scores.items():
                    data.append(dict(
                        score_name=score_name,
                        prediction=prediction,
                        source=source,
                        example_id=ex['example_id'],
                        value=score_value,
                    ))

    data = dlp_mpi.gather(data)

    if dlp_mpi.IS_MASTER:
        data = [entry for worker_data in data for entry in worker_data]

        if out is not None:
            assert isinstance(out, str), out
            assert out.endswith('.json'), out
            print(f'Write details to {out}.')
            dump_json(data, out)

        summary(data)

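# A minimal sketch of the dlp_mpi scatter/gather idiom used above, assuming
# dlp_mpi is installed: the master serves work items to workers via
# split_managed, each worker collects partial results, gather returns one
# list per worker on the master, which then flattens them. The `square`
# workload and the input range are hypothetical placeholders.
import dlp_mpi


def square(x):
    return x * x


partial_results = []
for item in dlp_mpi.split_managed(list(range(16)), allow_single_worker=True):
    partial_results.append(square(item))

per_worker = dlp_mpi.gather(partial_results)
if dlp_mpi.IS_MASTER:
    # One result list per worker; flatten before further processing.
    flat = [r for worker_results in per_worker for r in worker_results]
    print(sorted(flat))
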
def main(_run, _log, trainer, database_json, dataset, batch_size):
    commands.print_config(_run)
    trainer = Trainer.from_config(trainer)
    storage_dir = Path(trainer.storage_dir)
    storage_dir.mkdir(parents=True, exist_ok=True)
    commands.save_config(
        _run.config, _log, config_filename=str(storage_dir / 'config.json'))

    train_set, validate_set, _ = get_datasets(
        storage_dir, database_json, dataset, batch_size)

    # Stop early if the loss has not decreased over three consecutive
    # validation runs. Training then typically ends after about 20k
    # iterations (13 epochs) with an accuracy above 98% on the test set.
    trainer.register_validation_hook(validate_set, early_stopping_patience=3)

    trainer.test_run(train_set, validate_set)
    trainer.train(train_set)

def train(
        _run, trainer, device,
):
    print_config(_run)
    trainer = Trainer.from_config(trainer)
    train_iter, validate_iter, batch_norm_tuning_iter = get_datasets()
    if validate_iter is not None:
        trainer.register_validation_hook(validate_iter)
    trainer.train(train_iter, device=device)

    # finalize
    if trainer.optimizer.swa_start is not None:
        trainer.optimizer.swap_swa_sgd()
    batch_norm_update(
        trainer.model, batch_norm_tuning_iter,
        feature_key='features', device=device,
    )
    torch.save(
        trainer.model.state_dict(),
        storage_dir / 'checkpoints' / 'ckpt_final.pth',
    )

def run(_run, storage_dir, job_id, number_of_jobs, session_id, test_run=False):
    if dlp_mpi.IS_MASTER:
        print_config(_run)

    assert 1 <= job_id <= number_of_jobs, (job_id, number_of_jobs)

    enhancer = get_enhancer()

    if test_run:
        print('Database', enhancer.db)

    if test_run is False:
        dataset_slice = slice(job_id - 1, None, number_of_jobs)
    else:
        dataset_slice = test_run

    if dlp_mpi.IS_MASTER:
        print('Enhancer:', enhancer)
        print(session_id)

    if session_id is None:
        session_ids = sorted(get_sessions())
    elif isinstance(session_id, str):
        session_ids = [session_id]
    elif isinstance(session_id, (tuple, list)):
        session_ids = session_id
    else:
        raise TypeError(type(session_id), session_id)

    for session_id in session_ids:
        enhancer.enhance_session(
            session_id,
            Path(storage_dir) / 'audio',
            dataset_slice=dataset_slice,
            audio_dir_exist_ok=True,
            is_chime=False,
        )

    if dlp_mpi.IS_MASTER:
        print('Finished experiment dir:', storage_dir)

def main(_run, exp_dir, storage_dir, database_json, test_set, max_examples,
         device):
    if IS_MASTER:
        commands.print_config(_run)

    exp_dir = Path(exp_dir)
    storage_dir = Path(storage_dir)
    audio_dir = storage_dir / 'audio'
    audio_dir.mkdir(parents=True)

    config = load_json(exp_dir / 'config.json')

    model = Model.from_storage_dir(exp_dir, consider_mpi=True)
    model.to(device)
    model.eval()

    db = JsonDatabase(database_json)
    test_data = db.get_dataset(test_set)
    if max_examples is not None:
        test_data = test_data.shuffle(
            rng=np.random.RandomState(0))[:max_examples]
    test_data = prepare_dataset(
        test_data, audio_reader=config['audio_reader'], stft=config['stft'],
        max_length=None, batch_size=1, shuffle=True)

    squared_err = list()
    with torch.no_grad():
        for example in split_managed(test_data, is_indexable=False,
                                     progress_bar=True,
                                     allow_single_worker=True):
            example = model.example_to_device(example, device)
            target = example['audio_data'].squeeze(1)
            x = model.feature_extraction(example['stft'], example['seq_len'])
            x = model.wavenet.infer(
                x.squeeze(1),
                chunk_length=80_000,
                chunk_overlap=16_000,
            )
            assert target.shape == x.shape, (target.shape, x.shape)
            squared_err.extend([
                (ex_id, mse.cpu().detach().numpy(), x.shape[1])
                for ex_id, mse in zip(
                    example['example_id'], ((x - target) ** 2).sum(1))
            ])

    squared_err_list = COMM.gather(squared_err, root=MASTER)

    if IS_MASTER:
        print(f'\nlen(squared_err_list): {len(squared_err_list)}')
        squared_err = []
        for i in range(len(squared_err_list)):
            squared_err.extend(squared_err_list[i])
        _, err, t = list(zip(*squared_err))
        print('rmse:', np.sqrt(np.sum(err) / np.sum(t)))
        rmse = sorted(
            [(ex_id, np.sqrt(err / t)) for ex_id, err, t in squared_err],
            key=lambda x: x[1])
        dump_json(rmse, storage_dir / 'rmse.json', indent=4, sort_keys=False)

        ex_ids_ordered = [x[0] for x in rmse]
        test_data = db.get_dataset('test_clean').shuffle(
            rng=np.random.RandomState(0))[:max_examples].filter(
                lambda x: x['example_id'] in (
                    ex_ids_ordered[:10] + ex_ids_ordered[-10:]),
                lazy=False)
        test_data = prepare_dataset(
            test_data, audio_reader=config['audio_reader'],
            stft=config['stft'], max_length=10., batch_size=1, shuffle=True)
        with torch.no_grad():
            for example in test_data:
                example = model.example_to_device(example, device)
                x = model.feature_extraction(
                    example['stft'], example['seq_len'])
                x = model.wavenet.infer(
                    x.squeeze(1),
                    chunk_length=80_000,
                    chunk_overlap=16_000,
                )
                for i, audio in enumerate(x.cpu().detach().numpy()):
                    wavfile.write(
                        str(audio_dir / f'{example["example_id"][i]}.wav'),
                        model.sample_rate, audio)

def main(run_dir, data_dir, nb_epoch, early_stopping_patience,
         desired_sample_rate, fragment_length, batch_size, fragment_stride,
         nb_output_bins, keras_verbose, _log, seed, _config, debug,
         learn_all_outputs, train_only_in_receptive_field, _run, use_ulaw,
         train_with_soft_target_stdev):
    if run_dir is None:
        if not os.path.exists("models"):
            os.mkdir("models")
        run_dir = os.path.join(
            'models', datetime.datetime.now().strftime('run_%Y%m%d_%H%M%S'))
    _config['run_dir'] = run_dir

    print_config(_run)
    _log.info('Running with seed %d' % seed)

    checkpoint_dir = os.path.join(run_dir, 'checkpoints')
    if not debug:
        if not os.path.exists(run_dir):
            os.mkdir(run_dir)
            os.mkdir(checkpoint_dir)
        json.dump(_config, open(os.path.join(run_dir, 'config.json'), 'w'))

    _log.info('Loading data...')
    data_generators, nb_examples = get_generators()

    _log.info('Building model...')
    model = build_model(fragment_length)
    _log.info(model.summary())

    optim = make_optimizer()
    _log.info('Compiling Model...')

    loss = objectives.categorical_crossentropy
    all_metrics = [
        metrics.categorical_accuracy,
        categorical_mean_squared_error,
    ]
    if train_with_soft_target_stdev:
        loss = make_targets_soft(loss)
    if train_only_in_receptive_field:
        loss = skip_out_of_receptive_field(loss)
        all_metrics = [skip_out_of_receptive_field(m) for m in all_metrics]

    model.compile(optimizer=optim, loss=loss, metrics=all_metrics)
    # TODO: Consider gradient weighting making last outputs more important.

    callbacks = [
        ReduceLROnPlateau(patience=early_stopping_patience / 2,
                          cooldown=early_stopping_patience / 4, verbose=1),
        EarlyStopping(patience=early_stopping_patience, verbose=1),
    ]
    if not debug:
        callbacks.extend([
            ModelCheckpoint(os.path.join(
                checkpoint_dir, 'checkpoint.{epoch:05d}-{val_loss:.3f}.hdf5'),
                save_best_only=True),
            CSVLogger(os.path.join(run_dir, 'history.csv')),
        ])

    checkpoints = sorted(os.listdir(checkpoint_dir))
    if checkpoints:
        last_checkpoint = checkpoints[-1]
        # Use a %s placeholder so the path is actually rendered by the logger.
        _log.info('Loading existing weights %s',
                  os.path.join(checkpoint_dir, last_checkpoint))
        model.load_weights(os.path.join(checkpoint_dir, last_checkpoint))

    model.fit_generator(data_generators['train'],
                        len(data_generators['train']),
                        epochs=1,
                        validation_data=data_generators['test'],
                        validation_steps=5,
                        callbacks=callbacks,
                        use_multiprocessing=False,
                        verbose=keras_verbose)

def main(run_dir, data_dir, nb_epoch, early_stopping_patience,
         desired_sample_rate, fragment_length, batch_size, fragment_stride,
         nb_output_bins, keras_verbose, _log, seed, _config, debug,
         learn_all_outputs, train_only_in_receptive_field, _run, use_ulaw):
    if run_dir is None:
        run_dir = os.path.join(
            'models',
            datetime.datetime.now().strftime('run_%Y-%m-%d_%H:%M:%S'))
    _config['run_dir'] = run_dir

    print_config(_run)
    _log.info('Running with seed %d' % seed)

    if not debug:
        if os.path.exists(run_dir):
            raise EnvironmentError('Run with seed %d already exists' % seed)
        os.mkdir(run_dir)
        checkpoint_dir = os.path.join(run_dir, 'checkpoints')
        json.dump(_config, open(os.path.join(run_dir, 'config.json'), 'w'))

    _log.info('Loading data...')
    data_generators, nb_examples = dataset.generators(
        data_dir, desired_sample_rate, fragment_length, batch_size,
        fragment_stride, nb_output_bins, learn_all_outputs, use_ulaw)

    _log.info('Building model...')
    model = build_model(fragment_length)
    _log.info(model.summary())

    optim = make_optimizer()
    _log.info('Compiling Model...')

    loss = objectives.categorical_crossentropy
    all_metrics = [
        metrics.categorical_accuracy,
        metrics.categorical_mean_squared_error,
    ]
    if train_only_in_receptive_field:
        loss = skip_out_of_receptive_field(loss)
        all_metrics = [skip_out_of_receptive_field(m) for m in all_metrics]

    model.compile(optimizer=optim, loss=loss, metrics=all_metrics)
    # TODO: Consider gradient weighting making last outputs more important.

    callbacks = [
        ReduceLROnPlateau(patience=early_stopping_patience / 2,
                          cooldown=early_stopping_patience / 4, verbose=1),
        EarlyStopping(patience=early_stopping_patience, verbose=1),
    ]
    if not debug:
        callbacks.extend([
            ModelCheckpoint(os.path.join(
                checkpoint_dir, 'checkpoint.{epoch:05d}-{val_loss:.3f}.hdf5'),
                save_best_only=True),
            CSVLogger(os.path.join(run_dir, 'history.csv')),
        ])
        os.mkdir(checkpoint_dir)

    _log.info('Starting Training...')
    model.fit_generator(data_generators['train'],
                        nb_examples['train'],
                        nb_epoch=nb_epoch,
                        validation_data=data_generators['test'],
                        nb_val_samples=nb_examples['test'],
                        callbacks=callbacks,
                        verbose=keras_verbose)

def train(
        _run, debug, data_provider, filter_desed_test_clips, trainer,
        lr_rampup_steps, back_off_patience, lr_decay_step, lr_decay_factor,
        init_ckpt_path, frozen_cnn_2d_layers, frozen_cnn_1d_layers,
        track_emissions, resume, delay,
        validation_set_name, validation_ground_truth_filepath,
        eval_set_name, eval_ground_truth_filepath,
):
    print()
    print('##### Training #####')
    print()
    print_config(_run)
    assert (back_off_patience is None) or (lr_decay_step is None), (
        back_off_patience, lr_decay_step)
    if delay > 0:
        print(f'Sleep for {delay} seconds.')
        time.sleep(delay)

    data_provider = DataProvider.from_config(data_provider)
    data_provider.train_transform.label_encoder.initialize_labels(
        dataset=data_provider.db.get_dataset(data_provider.validate_set),
        verbose=True)
    data_provider.test_transform.label_encoder.initialize_labels()

    trainer = Trainer.from_config(trainer)
    trainer.model.label_mapping = []
    for idx, label in sorted(
            data_provider.train_transform.label_encoder
            .inverse_label_mapping.items()):
        assert idx == len(trainer.model.label_mapping), (
            idx, label, len(trainer.model.label_mapping))
        trainer.model.label_mapping.append(
            label.replace(', ', '__').replace(' ', '')
            .replace('(', '_').replace(')', '_').replace("'", ''))
    print('Params', sum(p.numel() for p in trainer.model.parameters()))
    print('CNN Params',
          sum(p.numel() for p in trainer.model.cnn.parameters()))

    if init_ckpt_path is not None:
        print('Load init params')
        state_dict = deflatten(
            torch.load(init_ckpt_path, map_location='cpu')['model'],
            maxdepth=2)
        trainer.model.cnn.load_state_dict(flatten(state_dict['cnn']))
        trainer.model.rnn_fwd.rnn.load_state_dict(state_dict['rnn_fwd']['rnn'])
        trainer.model.rnn_bwd.rnn.load_state_dict(state_dict['rnn_bwd']['rnn'])
        # pop output layer from checkpoint
        param_keys = sorted(state_dict['rnn_fwd']['output_net'].keys())
        layer_indices = [key.split('.')[1] for key in param_keys]
        last_layer_idx = layer_indices[-1]
        for key, layer_idx in zip(param_keys, layer_indices):
            if layer_idx == last_layer_idx:
                state_dict['rnn_fwd']['output_net'].pop(key)
                state_dict['rnn_bwd']['output_net'].pop(key)
        trainer.model.rnn_fwd.output_net.load_state_dict(
            state_dict['rnn_fwd']['output_net'], strict=False)
        trainer.model.rnn_bwd.output_net.load_state_dict(
            state_dict['rnn_bwd']['output_net'], strict=False)

    if frozen_cnn_2d_layers:
        print(f'Freeze {frozen_cnn_2d_layers} cnn_2d layers')
        trainer.model.cnn.cnn_2d.freeze(frozen_cnn_2d_layers)
    if frozen_cnn_1d_layers:
        print(f'Freeze {frozen_cnn_1d_layers} cnn_1d layers')
        trainer.model.cnn.cnn_1d.freeze(frozen_cnn_1d_layers)

    if filter_desed_test_clips:
        with (database_jsons_dir / 'desed.json').open() as fid:
            desed_json = json.load(fid)
        filter_example_ids = {
            clip_id.rsplit('_', maxsplit=2)[0][1:]
            for clip_id in (
                list(desed_json['datasets']['validation'].keys())
                + list(desed_json['datasets']['eval_public'].keys()))
        }
    else:
        filter_example_ids = None
    train_set = data_provider.get_train_set(
        filter_example_ids=filter_example_ids)
    validate_set = data_provider.get_validate_set()

    if validate_set is not None:
        trainer.test_run(train_set, validate_set)
        trainer.register_validation_hook(
            validate_set,
            metric='macro_fscore_weak',
            maximize=True,
            back_off_patience=back_off_patience,
            n_back_off=0 if back_off_patience is None else 1,
            lr_update_factor=lr_decay_factor,
            early_stopping_patience=back_off_patience,
        )

    breakpoints = []
    if lr_rampup_steps is not None:
        breakpoints += [(0, 0.), (lr_rampup_steps, 1.)]
    if lr_decay_step is not None:
        breakpoints += [(lr_decay_step, 1.), (lr_decay_step, lr_decay_factor)]
    if len(breakpoints) > 0:
        if isinstance(trainer.optimizer, dict):
            names = sorted(trainer.optimizer.keys())
        else:
            names = [None]
        for name in names:
            trainer.register_hook(LRAnnealingHook(
                trigger=AllTrigger(
                    (100, 'iteration'),
                    NotTrigger(EndTrigger(
                        breakpoints[-1][0] + 100, 'iteration')),
                ),
                breakpoints=breakpoints,
                unit='iteration',
                name=name,
            ))
    trainer.train(train_set, resume=resume, track_emissions=track_emissions)

    if validation_set_name is not None:
        tuning.run(config_updates={
            'debug': debug,
            'crnn_dirs': [str(trainer.storage_dir)],
            'validation_set_name': validation_set_name,
            'validation_ground_truth_filepath':
                validation_ground_truth_filepath,
            'eval_set_name': eval_set_name,
            'eval_ground_truth_filepath': eval_ground_truth_filepath,
            'data_provider': {
                'test_fetcher': {
                    'batch_size': data_provider.train_fetcher.batch_size,
                }
            },
        })

def main(_run, storage_dir, debug, weak_label_crnn_hyper_params_dir,
         weak_label_crnn_dirs, weak_label_crnn_checkpoints,
         strong_label_crnn_dirs, strong_label_crnn_checkpoints, data_provider,
         validation_set_name, validation_ground_truth_filepath,
         eval_set_name, eval_ground_truth_filepath, medfilt_lengths, device):
    print()
    print('##### Tuning #####')
    print()
    print_config(_run)
    print(storage_dir)
    storage_dir = Path(storage_dir)

    if not isinstance(weak_label_crnn_checkpoints, list):
        assert isinstance(weak_label_crnn_checkpoints, str), \
            weak_label_crnn_checkpoints
        weak_label_crnn_checkpoints = (
            len(weak_label_crnn_dirs) * [weak_label_crnn_checkpoints])
    weak_label_crnns = [
        weak_label.CRNN.from_storage_dir(
            storage_dir=crnn_dir, config_name='1/config.json',
            checkpoint_name=crnn_checkpoint)
        for crnn_dir, crnn_checkpoint in zip(
            weak_label_crnn_dirs, weak_label_crnn_checkpoints)
    ]
    data_provider = DESEDProvider.from_config(data_provider)
    data_provider.test_transform.label_encoder.initialize_labels()
    event_classes = (
        data_provider.test_transform.label_encoder.inverse_label_mapping)
    event_classes = [event_classes[i] for i in range(len(event_classes))]
    frame_shift = data_provider.test_transform.stft.shift
    frame_shift /= data_provider.audio_reader.target_sample_rate

    if (validation_set_name == 'validation'
            and not validation_ground_truth_filepath):
        database_root = Path(
            data_provider.get_raw('validation')[0]['audio_path']
        ).parent.parent.parent.parent
        validation_ground_truth_filepath = (
            database_root / 'metadata' / 'validation' / 'validation.tsv')
    elif (validation_set_name == 'eval_public'
            and not validation_ground_truth_filepath):
        database_root = Path(
            data_provider.get_raw('eval_public')[0]['audio_path']
        ).parent.parent.parent.parent
        validation_ground_truth_filepath = (
            database_root / 'metadata' / 'eval' / 'public.tsv')
    assert isinstance(validation_ground_truth_filepath, (str, Path)) \
        and Path(validation_ground_truth_filepath).exists(), \
        validation_ground_truth_filepath

    dataset = data_provider.get_dataset(validation_set_name)
    audio_durations = {
        example['example_id']: example['audio_length']
        for example in data_provider.db.get_dataset(validation_set_name)
    }

    timestamps = {
        audio_id: np.array([0., audio_durations[audio_id]])
        for audio_id in audio_durations
    }
    tags, tagging_scores, _ = tagging(
        weak_label_crnns, dataset, device, timestamps, event_classes,
        weak_label_crnn_hyper_params_dir, None, None,
    )

    collar_based_params = {
        'onset_collar': .2,
        'offset_collar': .2,
        'offset_collar_rate': .2,
    }
    psds_scenario_1 = {
        'dtc_threshold': 0.7,
        'gtc_threshold': 0.7,
        'cttc_threshold': None,
        'alpha_ct': .0,
        'alpha_st': 1.,
    }
    psds_scenario_2 = {
        'dtc_threshold': 0.1,
        'gtc_threshold': 0.1,
        'cttc_threshold': 0.3,
        'alpha_ct': .5,
        'alpha_st': 1.,
    }
    metrics = {
        'f': partial(
            base.f_collar,
            ground_truth=validation_ground_truth_filepath,
            return_onset_offset_bias=True,
            num_jobs=8,
            **collar_based_params,
        ),
        'auc1': partial(
            base.psd_auc,
            ground_truth=validation_ground_truth_filepath,
            audio_durations=audio_durations,
            num_jobs=8,
            **psds_scenario_1,
        ),
        'auc2': partial(
            base.psd_auc,
            ground_truth=validation_ground_truth_filepath,
            audio_durations=audio_durations,
            num_jobs=8,
            **psds_scenario_2,
        ),
    }

    if not isinstance(strong_label_crnn_checkpoints, list):
        assert isinstance(strong_label_crnn_checkpoints, str), \
            strong_label_crnn_checkpoints
        strong_label_crnn_checkpoints = (
            len(strong_label_crnn_dirs) * [strong_label_crnn_checkpoints])
    strong_label_crnns = [
        strong_label.CRNN.from_storage_dir(
            storage_dir=crnn_dir, config_name='1/config.json',
            checkpoint_name=crnn_checkpoint)
        for crnn_dir, crnn_checkpoint in zip(
            strong_label_crnn_dirs, strong_label_crnn_checkpoints)
    ]

    def add_tag_condition(example):
        example["tag_condition"] = np.array(
            [tags[example_id] for example_id in example["example_id"]])
        return example

    timestamps = np.arange(0, 10000) * frame_shift
    leaderboard = strong_label.crnn.tune_sound_event_detection(
        strong_label_crnns,
        dataset.map(add_tag_condition),
        device,
        timestamps,
        event_classes,
        tags,
        metrics,
        tag_masking={'f': True, 'auc1': '?', 'auc2': '?'},
        medfilt_lengths=medfilt_lengths,
    )
    dump_json(leaderboard['f'][1], storage_dir / 'sed_hyper_params_f.json')

    f, p, r, thresholds, _ = collar_based.best_fscore(
        scores=leaderboard['auc1'][2],
        ground_truth=validation_ground_truth_filepath,
        **collar_based_params, num_jobs=8)
    for event_class in thresholds:
        leaderboard['auc1'][1][event_class]['threshold'] = \
            thresholds[event_class]
    dump_json(leaderboard['auc1'][1],
              storage_dir / 'sed_hyper_params_psds1.json')

    f, p, r, thresholds, _ = collar_based.best_fscore(
        scores=leaderboard['auc2'][2],
        ground_truth=validation_ground_truth_filepath,
        **collar_based_params, num_jobs=8)
    for event_class in thresholds:
        leaderboard['auc2'][1][event_class]['threshold'] = \
            thresholds[event_class]
    dump_json(leaderboard['auc2'][1],
              storage_dir / 'sed_hyper_params_psds2.json')

    for crnn_dir in strong_label_crnn_dirs:
        tuning_dir = Path(crnn_dir) / 'hyper_params'
        os.makedirs(str(tuning_dir), exist_ok=True)
        (tuning_dir / storage_dir.name).symlink_to(storage_dir)

    print(storage_dir)
    if eval_set_name:
        evaluation.run(config_updates={
            'debug': debug,
            'strong_label_crnn_hyper_params_dir': str(storage_dir),
            'dataset_name': eval_set_name,
            'ground_truth_filepath': eval_ground_truth_filepath,
        })

def main(
        _run, storage_dir, strong_label_crnn_hyper_params_dir,
        sed_hyper_params_name, strong_label_crnn_dirs,
        strong_label_crnn_checkpoints, weak_label_crnn_hyper_params_dir,
        weak_label_crnn_dirs, weak_label_crnn_checkpoints, device,
        data_provider, dataset_name, ground_truth_filepath,
        save_scores, save_detections, max_segment_length, segment_overlap,
        strong_pseudo_labeling, pseudo_widening, pseudo_labelled_dataset_name,
):
    print()
    print('##### Inference #####')
    print()
    print_config(_run)
    print(storage_dir)

    emissions_tracker = EmissionsTracker(
        output_dir=storage_dir, on_csv_write="update", log_level='error')
    emissions_tracker.start()
    storage_dir = Path(storage_dir)

    collar_based_params = {
        'onset_collar': .2,
        'offset_collar': .2,
        'offset_collar_rate': .2,
    }
    psds_scenario_1 = {
        'dtc_threshold': 0.7,
        'gtc_threshold': 0.7,
        'cttc_threshold': None,
        'alpha_ct': .0,
        'alpha_st': 1.,
    }
    psds_scenario_2 = {
        'dtc_threshold': 0.1,
        'gtc_threshold': 0.1,
        'cttc_threshold': 0.3,
        'alpha_ct': .5,
        'alpha_st': 1.,
    }

    if not isinstance(weak_label_crnn_checkpoints, list):
        assert isinstance(weak_label_crnn_checkpoints, str), \
            weak_label_crnn_checkpoints
        weak_label_crnn_checkpoints = (
            len(weak_label_crnn_dirs) * [weak_label_crnn_checkpoints])
    weak_label_crnns = [
        weak_label.CRNN.from_storage_dir(
            storage_dir=crnn_dir, config_name='1/config.json',
            checkpoint_name=crnn_checkpoint)
        for crnn_dir, crnn_checkpoint in zip(
            weak_label_crnn_dirs, weak_label_crnn_checkpoints)
    ]
    print('Weak Label CRNN Params',
          sum(p.numel() for crnn in weak_label_crnns
              for p in crnn.parameters()))
    print('Weak Label CNN2d Params',
          sum(p.numel() for crnn in weak_label_crnns
              for p in crnn.cnn.cnn_2d.parameters()))

    if not isinstance(strong_label_crnn_checkpoints, list):
        assert isinstance(strong_label_crnn_checkpoints, str), \
            strong_label_crnn_checkpoints
        strong_label_crnn_checkpoints = (
            len(strong_label_crnn_dirs) * [strong_label_crnn_checkpoints])
    strong_label_crnns = [
        strong_label.CRNN.from_storage_dir(
            storage_dir=crnn_dir, config_name='1/config.json',
            checkpoint_name=crnn_checkpoint)
        for crnn_dir, crnn_checkpoint in zip(
            strong_label_crnn_dirs, strong_label_crnn_checkpoints)
    ]
    print('Strong Label CRNN Params',
          sum(p.numel() for crnn in strong_label_crnns
              for p in crnn.parameters()))
    print('Strong Label CNN2d Params',
          sum(p.numel() for crnn in strong_label_crnns
              for p in crnn.cnn.cnn_2d.parameters()))

    data_provider = DESEDProvider.from_config(data_provider)
    data_provider.test_transform.label_encoder.initialize_labels()
    event_classes = (
        data_provider.test_transform.label_encoder.inverse_label_mapping)
    event_classes = [event_classes[i] for i in range(len(event_classes))]
    frame_shift = data_provider.test_transform.stft.shift
    frame_shift /= data_provider.audio_reader.target_sample_rate

    if not isinstance(dataset_name, list):
        dataset_name = [dataset_name]
    if ground_truth_filepath is None:
        ground_truth_filepath = len(dataset_name) * [ground_truth_filepath]
    elif isinstance(ground_truth_filepath, (str, Path)):
        ground_truth_filepath = [ground_truth_filepath]
    assert len(ground_truth_filepath) == len(dataset_name)
    if not isinstance(strong_pseudo_labeling, list):
        strong_pseudo_labeling = len(dataset_name) * [strong_pseudo_labeling]
    assert len(strong_pseudo_labeling) == len(dataset_name)
    if not isinstance(pseudo_labelled_dataset_name, list):
        pseudo_labelled_dataset_name = [pseudo_labelled_dataset_name]
    assert len(pseudo_labelled_dataset_name) == len(dataset_name)

    database = deepcopy(data_provider.db.data)
    for i in range(len(dataset_name)):
        print()
        print(dataset_name[i])
        if dataset_name[i] == 'eval_public' and not ground_truth_filepath[i]:
            database_root = Path(
                data_provider.get_raw('eval_public')[0]['audio_path']
            ).parent.parent.parent.parent
            ground_truth_filepath[i] = (
                database_root / 'metadata' / 'eval' / 'public.tsv')
        elif dataset_name[i] == 'validation' and not ground_truth_filepath[i]:
            database_root = Path(
                data_provider.get_raw('validation')[0]['audio_path']
            ).parent.parent.parent.parent
            ground_truth_filepath[i] = (
                database_root / 'metadata' / 'validation' / 'validation.tsv')

        dataset = data_provider.get_dataset(dataset_name[i])
        audio_durations = {
            example['example_id']: example['audio_length']
            for example in data_provider.db.get_dataset(dataset_name[i])
        }

        score_storage_dir = storage_dir / 'scores' / dataset_name[i]
        detection_storage_dir = storage_dir / 'detections' / dataset_name[i]

        if max_segment_length is None:
            timestamps = {
                audio_id: np.array([0., audio_durations[audio_id]])
                for audio_id in audio_durations
            }
        else:
            timestamps = {}
            for audio_id in audio_durations:
                ts = np.arange(
                    (2 + max_segment_length) * frame_shift,
                    audio_durations[audio_id],
                    (max_segment_length - segment_overlap) * frame_shift)
                timestamps[audio_id] = np.concatenate((
                    [0.],
                    ts - segment_overlap / 2 * frame_shift,
                    [audio_durations[audio_id]],
                ))
        if max_segment_length is not None:
            dataset = dataset.map(partial(
                segment_batch,
                max_length=max_segment_length,
                overlap=segment_overlap,
            )).unbatch()

        tags, tagging_scores, _ = tagging(
            weak_label_crnns, dataset, device, timestamps, event_classes,
            weak_label_crnn_hyper_params_dir, None, None,
        )

        def add_tag_condition(example):
            example["tag_condition"] = np.array(
                [tags[example_id] for example_id in example["example_id"]])
            return example

        dataset = dataset.map(add_tag_condition)

        timestamps = np.round(np.arange(0, 100000) * frame_shift, decimals=6)
        if not isinstance(sed_hyper_params_name, (list, tuple)):
            sed_hyper_params_name = [sed_hyper_params_name]
        events, sed_results = sound_event_detection(
            strong_label_crnns, dataset, device, timestamps, event_classes,
            tags, strong_label_crnn_hyper_params_dir, sed_hyper_params_name,
            ground_truth_filepath[i], audio_durations,
            collar_based_params, [psds_scenario_1, psds_scenario_2],
            max_segment_length=max_segment_length,
            segment_overlap=segment_overlap,
            pseudo_widening=pseudo_widening,
            score_storage_dir=[
                score_storage_dir / name for name in sed_hyper_params_name
            ] if save_scores else None,
            detection_storage_dir=[
                detection_storage_dir / name
                for name in sed_hyper_params_name
            ] if save_detections else None,
        )
        for j, sed_results_j in enumerate(sed_results):
            if sed_results_j:
                dump_json(
                    sed_results_j,
                    storage_dir /
                    f'sed_{sed_hyper_params_name[j]}_results_{dataset_name[i]}.json')
        if strong_pseudo_labeling[i]:
            database['datasets'][pseudo_labelled_dataset_name[i]] = \
                base.pseudo_label(
                    database['datasets'][dataset_name[i]], event_classes,
                    False, False, strong_pseudo_labeling[i],
                    None, None, events[0],
                )
            with (storage_dir /
                    f'{dataset_name[i]}_pseudo_labeled.tsv').open('w') as fid:
                fid.write('filename\tonset\toffset\tevent_label\n')
                for key, event_list in events[0].items():
                    if len(event_list) == 0:
                        fid.write(f'{key}.wav\t\t\t\n')
                    for t_on, t_off, event_label in event_list:
                        fid.write(
                            f'{key}.wav\t{t_on}\t{t_off}\t{event_label}\n')

    if any(strong_pseudo_labeling):
        dump_json(
            database,
            storage_dir / Path(data_provider.json_path).name,
            create_path=True,
            indent=4,
            ensure_ascii=False,
        )
    inference_dir = Path(strong_label_crnn_hyper_params_dir) / 'inference'
    os.makedirs(str(inference_dir), exist_ok=True)
    (inference_dir / storage_dir.name).symlink_to(storage_dir)
    emissions_tracker.stop()
    print(storage_dir)

def main(_run, model_path, load_ckpt, batch_size, device, store_misclassified):
    if IS_MASTER:
        commands.print_config(_run)

    model_path = Path(model_path)
    eval_dir = get_new_subdir(
        model_path / 'eval', id_naming='time', consider_mpi=True)
    # perform evaluation on a sub-set (10%) of the dataset used for training
    config = load_json(model_path / 'config.json')
    database_json = config['database_json']
    dataset = config['dataset']

    model = pt.Model.from_storage_dir(
        model_path, checkpoint_name=load_ckpt, consider_mpi=True)
    model.to(device)
    # Turn on evaluation mode for, e.g., BatchNorm and Dropout modules
    model.eval()

    _, _, test_set = get_datasets(
        model_path, database_json, dataset, batch_size,
        return_indexable=device == 'cpu')

    with torch.no_grad():
        summary = dict(misclassified_examples=dict(),
                       correct_classified_examples=dict(), hits=list())
        for batch in split_managed(test_set, is_indexable=device == 'cpu',
                                   progress_bar=True,
                                   allow_single_worker=True):
            output = model(pt.data.example_to_device(batch, device))
            prediction = torch.argmax(output, dim=-1).cpu().numpy()
            confidence = torch.softmax(output, dim=-1).max(dim=-1).values \
                .cpu().numpy()
            label = np.array(batch['speaker_id'])
            hits = (label == prediction).astype('bool')
            summary['hits'].extend(hits.tolist())
            summary['misclassified_examples'].update({
                k: {
                    'true_label': v1,
                    'predicted_label': v2,
                    'audio_path': v3,
                    'confidence': f'{v4:.2%}',
                }
                for k, v1, v2, v3, v4 in zip(
                    np.array(batch['example_id'])[~hits],
                    label[~hits],
                    prediction[~hits],
                    np.array(batch['audio_path'])[~hits],
                    confidence[~hits],
                )
            })
            # for each correct predicted label, collect the audio paths
            correct_classified = summary['correct_classified_examples']
            summary['correct_classified_examples'].update({
                k: correct_classified[k] + [v]
                if k in correct_classified.keys() else [v]
                for k, v in zip(
                    prediction[hits], np.array(batch['audio_path'])[hits])
            })

    summary_list = COMM.gather(summary, root=MASTER)

    if IS_MASTER:
        print(f'\nlen(summary_list): {len(summary_list)}')
        if len(summary_list) > 1:
            summary = dict(
                misclassified_examples=dict(),
                correct_classified_examples=dict(),
                hits=list(),
            )
            for partial_summary in summary_list:
                summary['hits'].extend(partial_summary['hits'])
                summary['misclassified_examples'].update(
                    partial_summary['misclassified_examples'])
                for label, audio_path_list in \
                        partial_summary['correct_classified_examples'].items():
                    summary['correct_classified_examples'].update({
                        label:
                            summary['correct_classified_examples'][label]
                            + audio_path_list
                            if label in
                            summary['correct_classified_examples'].keys()
                            else audio_path_list
                    })

        hits = summary['hits']
        misclassified_examples = summary['misclassified_examples']
        correct_classified_examples = summary['correct_classified_examples']
        accuracy = np.array(hits).astype('float').mean()

        if store_misclassified:
            misclassified_dir = eval_dir / 'misclassified_examples'
            for example_id, v in misclassified_examples.items():
                label, prediction_label, audio_path, _ = v.values()
                try:
                    predicted_speaker_audio_path = \
                        correct_classified_examples[prediction_label][0]
                    example_dir = (
                        misclassified_dir /
                        f'{example_id}_{label}_{prediction_label}')
                    example_dir.mkdir(parents=True)
                    os.symlink(audio_path, example_dir / 'example.wav')
                    os.symlink(predicted_speaker_audio_path,
                               example_dir / 'predicted_speaker_example.wav')
                except KeyError:
                    warnings.warn(
                        'There were no correctly predicted inputs from '
                        f'speaker with speaker label {prediction_label}')

        outputs = dict(
            accuracy=f'{accuracy:.2%} ({np.sum(hits)}/{len(hits)})',
            misclassifications=misclassified_examples,
        )
        print(f'Speaker classification accuracy on test set: {accuracy:.2%}')
        print(f'Wrote results to {eval_dir / "results.json"}')
        dump_json(outputs, eval_dir / 'results.json')

def main(_run, storage_dir, debug, crnn_dirs, crnn_checkpoints, data_provider,
         validation_set_name, validation_ground_truth_filepath,
         eval_set_name, eval_ground_truth_filepath, boundaries_filter_lengths,
         tune_detection_scenario_1, detection_window_lengths_scenario_1,
         detection_window_shift_scenario_1,
         detection_medfilt_lengths_scenario_1,
         tune_detection_scenario_2, detection_window_lengths_scenario_2,
         detection_window_shift_scenario_2,
         detection_medfilt_lengths_scenario_2, device):
    print()
    print('##### Tuning #####')
    print()
    print_config(_run)
    print(storage_dir)

    emissions_tracker = EmissionsTracker(
        output_dir=storage_dir, on_csv_write="update", log_level='error')
    emissions_tracker.start()
    storage_dir = Path(storage_dir)

    boundaries_collar_based_params = {
        'onset_collar': .5,
        'offset_collar': .5,
        'offset_collar_rate': .0,
        'min_precision': .8,
    }
    collar_based_params = {
        'onset_collar': .2,
        'offset_collar': .2,
        'offset_collar_rate': .2,
    }
    psds_scenario_1 = {
        'dtc_threshold': 0.7,
        'gtc_threshold': 0.7,
        'cttc_threshold': None,
        'alpha_ct': .0,
        'alpha_st': 1.,
    }
    psds_scenario_2 = {
        'dtc_threshold': 0.1,
        'gtc_threshold': 0.1,
        'cttc_threshold': 0.3,
        'alpha_ct': .5,
        'alpha_st': 1.,
    }

    if not isinstance(crnn_checkpoints, list):
        assert isinstance(crnn_checkpoints, str), crnn_checkpoints
        crnn_checkpoints = len(crnn_dirs) * [crnn_checkpoints]
    crnns = [
        weak_label.CRNN.from_storage_dir(
            storage_dir=crnn_dir, config_name='1/config.json',
            checkpoint_name=crnn_checkpoint)
        for crnn_dir, crnn_checkpoint in zip(crnn_dirs, crnn_checkpoints)
    ]
    data_provider = DataProvider.from_config(data_provider)
    data_provider.test_transform.label_encoder.initialize_labels()
    event_classes = (
        data_provider.test_transform.label_encoder.inverse_label_mapping)
    event_classes = [event_classes[i] for i in range(len(event_classes))]
    frame_shift = data_provider.test_transform.stft.shift
    frame_shift /= data_provider.audio_reader.target_sample_rate

    if (validation_set_name == 'validation'
            and not validation_ground_truth_filepath):
        database_root = Path(
            data_provider.get_raw('validation')[0]['audio_path']
        ).parent.parent.parent.parent
        validation_ground_truth_filepath = (
            database_root / 'metadata' / 'validation' / 'validation.tsv')
    elif (validation_set_name == 'eval_public'
            and not validation_ground_truth_filepath):
        database_root = Path(
            data_provider.get_raw('eval_public')[0]['audio_path']
        ).parent.parent.parent.parent
        validation_ground_truth_filepath = (
            database_root / 'metadata' / 'eval' / 'public.tsv')
    assert isinstance(validation_ground_truth_filepath, (str, Path)) \
        and Path(validation_ground_truth_filepath).exists(), \
        validation_ground_truth_filepath

    dataset = data_provider.get_dataset(validation_set_name)
    audio_durations = {
        example['example_id']: example['audio_length']
        for example in data_provider.db.get_dataset(validation_set_name)
    }

    timestamps = {
        audio_id: np.array([0., audio_durations[audio_id]])
        for audio_id in audio_durations
    }
    metrics = {
        'f': partial(base.f_tag,
                     ground_truth=validation_ground_truth_filepath,
                     num_jobs=8)
    }
    leaderboard = weak_label.crnn.tune_tagging(
        crnns, dataset, device, timestamps, event_classes, metrics,
        storage_dir=storage_dir)
    _, hyper_params, tagging_scores = leaderboard['f']
    tagging_thresholds = np.array([
        hyper_params[event_class]['threshold']
        for event_class in event_classes
    ])
    tags = {
        audio_id:
            tagging_scores[audio_id][event_classes].to_numpy()
            > tagging_thresholds
        for audio_id in tagging_scores
    }

    boundaries_ground_truth = base.boundaries_from_events(
        validation_ground_truth_filepath)
    timestamps = np.arange(0, 10000) * frame_shift
    metrics = {
        'f': partial(
            base.f_collar,
            ground_truth=boundaries_ground_truth,
            return_onset_offset_bias=True,
            num_jobs=8,
            **boundaries_collar_based_params,
        ),
    }
    weak_label.crnn.tune_boundary_detection(
        crnns, dataset, device, timestamps, event_classes, tags, metrics,
        tag_masking=True,
        stepfilt_lengths=boundaries_filter_lengths,
        storage_dir=storage_dir)

    if tune_detection_scenario_1:
        metrics = {
            'f': partial(
                base.f_collar,
                ground_truth=validation_ground_truth_filepath,
                return_onset_offset_bias=True,
                num_jobs=8,
                **collar_based_params,
            ),
            'auc': partial(
                base.psd_auc,
                ground_truth=validation_ground_truth_filepath,
                audio_durations=audio_durations,
                num_jobs=8,
                **psds_scenario_1,
            ),
        }
        leaderboard = weak_label.crnn.tune_sound_event_detection(
            crnns, dataset, device, timestamps, event_classes, tags, metrics,
            tag_masking={'f': True, 'auc': '?'},
            window_lengths=detection_window_lengths_scenario_1,
            window_shift=detection_window_shift_scenario_1,
            medfilt_lengths=detection_medfilt_lengths_scenario_1,
        )
        dump_json(leaderboard['f'][1],
                  storage_dir / 'sed_hyper_params_f.json')
        f, p, r, thresholds, _ = collar_based.best_fscore(
            scores=leaderboard['auc'][2],
            ground_truth=validation_ground_truth_filepath,
            **collar_based_params, num_jobs=8)
        for event_class in thresholds:
            leaderboard['auc'][1][event_class]['threshold'] = \
                thresholds[event_class]
        dump_json(leaderboard['auc'][1],
                  storage_dir / 'sed_hyper_params_psds1.json')

    if tune_detection_scenario_2:
        metrics = {
            'auc': partial(
                base.psd_auc,
                ground_truth=validation_ground_truth_filepath,
                audio_durations=audio_durations,
                num_jobs=8,
                **psds_scenario_2,
            )
        }
        leaderboard = weak_label.crnn.tune_sound_event_detection(
            crnns, dataset, device, timestamps, event_classes, tags, metrics,
            tag_masking=False,
            window_lengths=detection_window_lengths_scenario_2,
            window_shift=detection_window_shift_scenario_2,
            medfilt_lengths=detection_medfilt_lengths_scenario_2,
        )
        f, p, r, thresholds, _ = collar_based.best_fscore(
            scores=leaderboard['auc'][2],
            ground_truth=validation_ground_truth_filepath,
            **collar_based_params, num_jobs=8)
        for event_class in thresholds:
            leaderboard['auc'][1][event_class]['threshold'] = \
                thresholds[event_class]
        dump_json(leaderboard['auc'][1],
                  storage_dir / 'sed_hyper_params_psds2.json')

    for crnn_dir in crnn_dirs:
        tuning_dir = Path(crnn_dir) / 'hyper_params'
        os.makedirs(str(tuning_dir), exist_ok=True)
        (tuning_dir / storage_dir.name).symlink_to(storage_dir)

    emissions_tracker.stop()
    print(storage_dir)

    if eval_set_name:
        if tune_detection_scenario_1:
            evaluation.run(config_updates={
                'debug': debug,
                'hyper_params_dir': str(storage_dir),
                'dataset_name': eval_set_name,
                'ground_truth_filepath': eval_ground_truth_filepath,
            })
        if tune_detection_scenario_2:
            evaluation.run(config_updates={
                'debug': debug,
                'hyper_params_dir': str(storage_dir),
                'dataset_name': eval_set_name,
                'ground_truth_filepath': eval_ground_truth_filepath,
                'sed_hyper_params_name': 'psds2',
            })

def train(logdir, device, iterations, resume_iteration, checkpoint_interval,
          train_on, batch_size, sequence_length, model_complexity,
          learning_rate, learning_rate_decay_steps, learning_rate_decay_rate,
          leave_one_out, clip_gradient_norm, validation_length,
          validation_interval):
    print_config(ex.current_run)

    os.makedirs(logdir, exist_ok=True)
    writer = SummaryWriter(logdir)

    train_groups, validation_groups = ['train'], ['validation']
    if leave_one_out is not None:
        all_years = {'2004', '2006', '2008', '2009', '2011', '2013', '2014',
                     '2015', '2017'}
        train_groups = list(all_years - {str(leave_one_out)})
        validation_groups = [str(leave_one_out)]

    if train_on == 'MAESTRO':
        dataset = MAESTRO(groups=train_groups,
                          sequence_length=sequence_length)
        validation_dataset = JAZZ_EVALUATE(groups=validation_groups,
                                           sequence_length=sequence_length)
    elif train_on == 'MAPS':
        dataset = MAPS(groups=['AkPnBcht', 'AkPnBsdf', 'AkPnCGdD',
                               'AkPnStgb', 'SptkBGAm', 'SptkBGCl',
                               'StbgTGd2'],
                       sequence_length=sequence_length)
        validation_dataset = JAZZ_EVALUATE(groups=['ENSTDkAm', 'ENSTDkCl'],
                                           sequence_length=validation_length)
    elif train_on == 'WJD':
        dataset = WJD(groups=[''], sequence_length=sequence_length)
        validation_dataset = JAZZ_EVALUATE(groups=[''],
                                           sequence_length=validation_length)
    elif train_on == 'DMJ':
        dataset = DMJ(groups=[''], sequence_length=sequence_length)
        validation_dataset = JAZZ_EVALUATE(groups=[''],
                                           sequence_length=validation_length)
    elif train_on == 'MAPS_DMJ':
        dataset = MAPS_DMJ(groups=[''], sequence_length=sequence_length)
        validation_dataset = JAZZ_EVALUATE(groups=[''],
                                           sequence_length=validation_length)
    elif train_on == 'MILLS':
        dataset = MILLS(groups=[''], sequence_length=sequence_length)
        validation_dataset = JAZZ_EVALUATE(groups=[''],
                                           sequence_length=validation_length)
    elif train_on == 'MAESTRO_JAZZ':
        dataset = MAESTRO_JAZZ(groups=[''], sequence_length=sequence_length)
        validation_dataset = JAZZ_EVALUATE(groups=validation_groups,
                                           sequence_length=validation_length)
    elif train_on == 'MAESTRO_JAZZ_AUGMENTED':
        dataset = MAESTRO_JAZZ_AUGMENTED(groups=[''],
                                         sequence_length=sequence_length)
        validation_dataset = JAZZ_EVALUATE(groups=validation_groups,
                                           sequence_length=validation_length)
    elif train_on == 'DMJ_DMJOFFSHIFT':
        dataset = DMJ_DMJOFFSHIFT(groups=[''],
                                  sequence_length=sequence_length)
        validation_dataset = JAZZ_EVALUATE(groups=validation_groups,
                                           sequence_length=validation_length)
    elif train_on == 'MAPS_DMJ_DMJOFFSHIFT':
        dataset = MAPS_DMJ_DMJOFFSHIFT(groups=[''],
                                       sequence_length=sequence_length)
        validation_dataset = JAZZ_EVALUATE(groups=validation_groups,
                                           sequence_length=validation_length)
    elif train_on == 'JAZZ_TRAIN':
        dataset = JAZZ_TRAIN(groups=[''], sequence_length=sequence_length)
        validation_dataset = JAZZ_EVALUATE(groups=validation_groups,
                                           sequence_length=validation_length)
    elif train_on == 'MAESTRO_JAZZ_TRAIN':
        dataset = MAESTRO_JAZZ_TRAIN(groups=[''],
                                     sequence_length=sequence_length)
        validation_dataset = JAZZ_EVALUATE(groups=validation_groups,
                                           sequence_length=validation_length)
    elif train_on == 'MAPS_JAZZ_TRAIN':
        dataset = MAPS_JAZZ_TRAIN(groups=[''],
                                  sequence_length=sequence_length)
        validation_dataset = JAZZ_EVALUATE(groups=validation_groups,
                                           sequence_length=validation_length)
    else:
        # Fail fast instead of silently falling through with an undefined
        # dataset.
        raise ValueError(f'Unknown train_on value: {train_on}')

    loader = DataLoader(dataset, batch_size, shuffle=True, drop_last=True)

    if resume_iteration is None:
        model = OnsetsAndFrames(N_MELS, MAX_MIDI - MIN_MIDI + 1,
                                model_complexity).to(device)
        optimizer = torch.optim.Adam(model.parameters(), learning_rate)
        resume_iteration = 0
    else:
        model_path = os.path.join(logdir, f'model-{resume_iteration}.pt')
        model = torch.load(model_path)
        optimizer = torch.optim.Adam(model.parameters(), learning_rate)
        optimizer.load_state_dict(
            torch.load(os.path.join(logdir, 'last-optimizer-state.pt')))

    # i = 0
    # for child in model.children():
    #     print("Child {}:".format(i))
    #     print(child)
    #     if i == 1:
    #         print("offset model")
    #         count_parameters(child)
    #     i += 1

    # Comment this out to train the model from scratch; otherwise the run
    # fine-tunes the model with frozen stacks.
    model.freeze_stacks()

    count_parameters(model)
    summary(model)
    scheduler = StepLR(optimizer, step_size=learning_rate_decay_steps,
                       gamma=learning_rate_decay_rate)

    loop = tqdm(range(resume_iteration + 1, iterations + 1))
    for i, batch in zip(loop, cycle(loader)):
        predictions, losses = model.run_on_batch(batch)
        loss = sum(losses.values())

        optimizer.zero_grad()
        loss.backward()
        # Clip gradients before the optimizer step so the clipped gradients
        # are what the update actually uses.
        if clip_gradient_norm:
            clip_grad_norm_(model.parameters(), clip_gradient_norm)
        optimizer.step()
        scheduler.step()

        for key, value in {'loss': loss, **losses}.items():
            writer.add_scalar(key, value.item(), global_step=i)

        if i % validation_interval == 0:
            model.eval()
            with torch.no_grad():
                for key, value in evaluate_old(validation_dataset,
                                               model).items():
                    writer.add_scalar('validation/' + key.replace(' ', '_'),
                                      np.mean(value), global_step=i)
            model.train()

        if i % checkpoint_interval == 0:
            torch.save(model, os.path.join(logdir, f'model-{i}.pt'))
            torch.save(optimizer.state_dict(),
                       os.path.join(logdir, 'last-optimizer-state.pt'))

def train(
        _run, debug, data_provider, trainer, lr_rampup_steps,
        back_off_patience, lr_decay_step, lr_decay_factor, init_ckpt_path,
        frozen_cnn_2d_layers, frozen_cnn_1d_layers, resume, delay,
        validation_set_name, validation_ground_truth_filepath,
        weak_label_crnn_hyper_params_dir,
        eval_set_name, eval_ground_truth_filepath,
):
    print()
    print('##### Training #####')
    print()
    print_config(_run)
    assert (back_off_patience is None) or (lr_decay_step is None), (
        back_off_patience, lr_decay_step)
    if delay > 0:
        print(f'Sleep for {delay} seconds.')
        time.sleep(delay)

    data_provider = DESEDProvider.from_config(data_provider)
    data_provider.train_transform.label_encoder.initialize_labels(
        dataset=data_provider.db.get_dataset(data_provider.validate_set),
        verbose=True)
    data_provider.test_transform.label_encoder.initialize_labels()

    trainer = Trainer.from_config(trainer)
    trainer.model.label_mapping = []
    for idx, label in sorted(
            data_provider.train_transform.label_encoder
            .inverse_label_mapping.items()):
        assert idx == len(trainer.model.label_mapping), (
            idx, label, len(trainer.model.label_mapping))
        trainer.model.label_mapping.append(
            label.replace(', ', '__').replace(' ', '')
            .replace('(', '_').replace(')', '_').replace("'", ''))
    print('Params', sum(p.numel() for p in trainer.model.parameters()))

    if init_ckpt_path is not None:
        print('Load init params')
        state_dict = deflatten(
            torch.load(init_ckpt_path, map_location='cpu')['model'],
            maxdepth=1)
        trainer.model.cnn.load_state_dict(state_dict['cnn'])
    if frozen_cnn_2d_layers:
        print(f'Freeze {frozen_cnn_2d_layers} cnn_2d layers')
        trainer.model.cnn.cnn_2d.freeze(frozen_cnn_2d_layers)
    if frozen_cnn_1d_layers:
        print(f'Freeze {frozen_cnn_1d_layers} cnn_1d layers')
        trainer.model.cnn.cnn_1d.freeze(frozen_cnn_1d_layers)

    def add_tag_condition(example):
        example["tag_condition"] = example["weak_targets"]
        return example

    train_set = data_provider.get_train_set().map(add_tag_condition)
    validate_set = data_provider.get_validate_set().map(add_tag_condition)

    if validate_set is not None:
        trainer.test_run(train_set, validate_set)
        trainer.register_validation_hook(
            validate_set, metric='macro_fscore_strong', maximize=True,
        )

    breakpoints = []
    if lr_rampup_steps is not None:
        breakpoints += [(0, 0.), (lr_rampup_steps, 1.)]
    if lr_decay_step is not None:
        breakpoints += [(lr_decay_step, 1.), (lr_decay_step, lr_decay_factor)]
    if len(breakpoints) > 0:
        if isinstance(trainer.optimizer, dict):
            names = sorted(trainer.optimizer.keys())
        else:
            names = [None]
        for name in names:
            trainer.register_hook(LRAnnealingHook(
                trigger=AllTrigger(
                    (100, 'iteration'),
                    NotTrigger(EndTrigger(
                        breakpoints[-1][0] + 100, 'iteration')),
                ),
                breakpoints=breakpoints,
                unit='iteration',
                name=name,
            ))
    trainer.train(train_set, resume=resume)

    if validation_set_name:
        tuning.run(config_updates={
            'debug': debug,
            'weak_label_crnn_hyper_params_dir':
                weak_label_crnn_hyper_params_dir,
            'strong_label_crnn_dirs': [str(trainer.storage_dir)],
            'validation_set_name': validation_set_name,
            'validation_ground_truth_filepath':
                validation_ground_truth_filepath,
            'eval_set_name': eval_set_name,
            'eval_ground_truth_filepath': eval_ground_truth_filepath,
        })

def print_config_option(args, run):
    """Always print the configuration first."""
    print_config(run)
    print("-" * 79)

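# A minimal sketch (assuming Sacred >= 0.8) of how a handler with this
# (args, run) signature can be registered as a custom command-line flag;
# the flag names and the Experiment below are hypothetical.
from sacred import Experiment, cli_option
from sacred.commands import print_config


@cli_option('-o', '--print-config-first', is_flag=True)
def print_config_first(args, run):
    """Print the configuration before the run starts."""
    print_config(run)
    print("-" * 79)


ex = Experiment('demo', additional_cli_options=[print_config_first])
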
def main(run_dir, data_dir, nb_epoch, early_stopping_patience,
         desired_sample_rate, fragment_length, batch_size, fragment_stride,
         nb_output_bins, keras_verbose, _log, seed, _config, debug,
         learn_all_outputs, train_only_in_receptive_field, _run, use_ulaw,
         train_with_soft_target_stdev):
    if run_dir is None:
        if not os.path.exists("models"):
            os.mkdir("models")
        run_dir = os.path.join(
            'models', datetime.datetime.now().strftime('run_%Y%m%d_%H%M%S'))
    _config['run_dir'] = run_dir

    print_config(_run)
    _log.info('Running with seed %d' % seed)

    if not debug:
        if os.path.exists(run_dir):
            raise EnvironmentError('Run with seed %d already exists' % seed)
        os.mkdir(run_dir)
        checkpoint_dir = os.path.join(run_dir, 'checkpoints')
        json.dump(_config, open(os.path.join(run_dir, 'config.json'), 'w'))

    _log.info('Loading data...')
    data_generators, nb_examples = get_generators()

    _log.info('Building model...')
    model = build_model(fragment_length)
    _log.info(model.summary())

    optim = make_optimizer()
    _log.info('Compiling Model...')

    loss = objectives.categorical_crossentropy
    all_metrics = [
        metrics.categorical_accuracy,
        categorical_mean_squared_error,
    ]
    if train_with_soft_target_stdev:
        loss = make_targets_soft(loss)
    if train_only_in_receptive_field:
        loss = skip_out_of_receptive_field(loss)
        all_metrics = [skip_out_of_receptive_field(m) for m in all_metrics]

    model.compile(optimizer=optim, loss=loss, metrics=all_metrics)
    # TODO: Consider gradient weighting making last outputs more important.

    tictoc = strftime("%a_%d_%b_%Y_%H_%M_%S", gmtime())
    directory_name = tictoc
    log_dir = 'wavenet_' + directory_name
    os.mkdir(log_dir)
    tensorboard = TensorBoard(log_dir=log_dir)

    callbacks = [
        tensorboard,
        ReduceLROnPlateau(patience=early_stopping_patience / 2,
                          cooldown=early_stopping_patience / 4, verbose=1),
        EarlyStopping(patience=early_stopping_patience, verbose=1),
    ]
    if not debug:
        callbacks.extend([
            ModelCheckpoint(os.path.join(
                checkpoint_dir, 'checkpoint.{epoch:05d}-{val_loss:.3f}.hdf5'),
                save_best_only=True),
            CSVLogger(os.path.join(run_dir, 'history.csv')),
        ])
        os.mkdir(checkpoint_dir)

    _log.info('Starting Training...')
    print("nb_examples['train'] {0}".format(nb_examples['train']))
    print("nb_examples['test'] {0}".format(nb_examples['test']))
    model.fit_generator(data_generators['train'],
                        steps_per_epoch=nb_examples['train'] // batch_size,
                        epochs=nb_epoch,
                        validation_data=data_generators['test'],
                        validation_steps=nb_examples['test'] // batch_size,
                        callbacks=callbacks,
                        verbose=keras_verbose)

def main(
        _run, out, mask_estimator, Observation, beamformer, postfilter,
        normalize_audio=True,
):
    if dlp_mpi.IS_MASTER:
        from sacred.commands import print_config
        print_config(_run)

    ds = get_dataset()
    data = []
    out = Path(out)

    for ex in dlp_mpi.split_managed(ds.sort(), allow_single_worker=True):
        if mask_estimator is None:
            mask = None
        elif mask_estimator == 'cacgmm':
            mask = get_mask_from_cacgmm(ex)
        else:
            mask = get_mask_from_oracle(ex, mask_estimator)

        metric, score = get_scores(
            ex,
            mask,
            Observation=Observation,
            beamformer=beamformer,
            postfilter=postfilter,
        )

        est0, est1 = metric.speech_prediction_selection
        dump_audio(est0, out / ex['dataset'] / f"{ex['example_id']}_0.wav",
                   normalize=normalize_audio)
        dump_audio(est1, out / ex['dataset'] / f"{ex['example_id']}_1.wav",
                   normalize=normalize_audio)

        data.append(dict(
            example_id=ex['example_id'],
            value=score,
            dataset=ex['dataset'],
        ))
        # print(score, repr(score))

    data = dlp_mpi.gather(data)

    if dlp_mpi.IS_MASTER:
        data = [entry for worker_data in data for entry in worker_data]
        data = {
            # itertools.groupby expects an ordered input
            dataset: list(subset)
            for dataset, subset in from_list(data).groupby(
                lambda ex: ex['dataset']).items()
        }
        for dataset, sub_data in data.items():
            print(f'Write details to {out}.')
            dump_json(sub_data, out / f'{dataset}_scores.json')

        for dataset, sub_data in data.items():
            summary = {}
            for k in sub_data[0]['value'].keys():
                m = np.mean([d['value'][k] for d in sub_data])
                print(dataset, k, m)
                summary[k] = m
            dump_json(summary, out / f'{dataset}_summary.json')

def train(_run, model, device, lr, gradient_clipping, weight_decay, swa_start,
          swa_freq, swa_lr, summary_interval, validation_interval, max_steps):
    print_config(_run)
    os.makedirs(storage_dir / 'checkpoints', exist_ok=True)

    train_iter, validate_iter, batch_norm_tuning_iter = get_datasets()

    model = CRNN(
        cnn_2d=CNN2d(**model['cnn_2d']),
        cnn_1d=CNN1d(**model['cnn_1d']),
        enc=GRU(**model['enc']),
        fcn=fully_connected_stack(**model['fcn']),
        fcn_noisy=None if model['fcn_noisy'] is None
        else fully_connected_stack(**model['fcn_noisy']),
    )
    print(sum(p.numel() for p in model.parameters() if p.requires_grad))
    model = model.to(device)
    model.train()

    optimizer = Adam(tuple(model.parameters()), lr=lr,
                     weight_decay=weight_decay)
    if swa_start is not None:
        optimizer = SWA(optimizer, swa_start=swa_start, swa_freq=swa_freq,
                        swa_lr=swa_lr)

    torch.backends.cudnn.enabled = True
    torch.backends.cudnn.benchmark = False

    # Summary
    summary_writer = tensorboardX.SummaryWriter(str(storage_dir))

    def get_empty_summary():
        return dict(
            scalars=defaultdict(list),
            histograms=defaultdict(list),
            images=dict(),
        )

    def update_summary(review, summary):
        review['scalars']['loss'] = review['loss'].detach()
        for key, value in review['scalars'].items():
            if torch.is_tensor(value):
                value = value.cpu().data.numpy()
            summary['scalars'][key].extend(
                np.array(value).flatten().tolist())
        for key, value in review['histograms'].items():
            if torch.is_tensor(value):
                value = value.cpu().data.numpy()
            summary['histograms'][key].extend(
                np.array(value).flatten().tolist())
        summary['images'] = review['images']

    def dump_summary(summary, prefix, iteration):
        summary = model.modify_summary(summary)
        # write summary
        for key, value in summary['scalars'].items():
            summary_writer.add_scalar(f'{prefix}/{key}', np.mean(value),
                                      iteration)
        for key, values in summary['histograms'].items():
            summary_writer.add_histogram(f'{prefix}/{key}', np.array(values),
                                         iteration)
        for key, image in summary['images'].items():
            summary_writer.add_image(f'{prefix}/{key}', image, iteration)
        return defaultdict(list)

    # Training loop
    train_summary = get_empty_summary()
    i = 0
    while i < max_steps:
        for batch in train_iter:
            optimizer.zero_grad()

            # forward
            batch = batch_to_device(batch, device=device)
            model_out = model(batch)

            # backward
            review = model.review(batch, model_out)
            review['loss'].backward()
            review['histograms']['grad_norm'] = \
                torch.nn.utils.clip_grad_norm_(
                    tuple(model.parameters()), gradient_clipping)
            optimizer.step()

            # update summary
            update_summary(review, train_summary)
            i += 1

            if i % summary_interval == 0:
                dump_summary(train_summary, 'training', i)
                train_summary = get_empty_summary()

            if i % validation_interval == 0 and validate_iter is not None:
                print('Starting Validation')
                model.eval()
                validate_summary = get_empty_summary()
                with torch.no_grad():
                    for batch in validate_iter:
                        batch = batch_to_device(batch, device=device)
                        model_out = model(batch)
                        review = model.review(batch, model_out)
                        update_summary(review, validate_summary)
                dump_summary(validate_summary, 'validation', i)
                print('Finished Validation')
                model.train()

            if i >= max_steps:
                break

    # finalize
    if swa_start is not None:
        optimizer.swap_swa_sgd()
    batch_norm_update(model, batch_norm_tuning_iter, feature_key='features',
                      device=device)
    torch.save(model.state_dict(),
               storage_dir / 'checkpoints' / 'ckpt_final.pth')

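# A minimal sketch of the torchcontrib SWA wrapper used above, assuming
# torchcontrib is installed; the toy model, loader, and hyper-parameters are
# hypothetical placeholders. After training, swap_swa_sgd() replaces the
# weights with the SWA running average, and bn_update() re-estimates
# BatchNorm statistics under the averaged weights (a no-op without BN).
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset
from torchcontrib.optim import SWA

toy_model = nn.Linear(10, 2)
toy_loader = DataLoader(
    TensorDataset(torch.randn(64, 10), torch.randn(64, 2)), batch_size=8)
base_opt = torch.optim.Adam(toy_model.parameters(), lr=1e-3)
opt = SWA(base_opt, swa_start=2, swa_freq=2, swa_lr=5e-4)

for x, y in toy_loader:
    opt.zero_grad()
    loss = nn.functional.mse_loss(toy_model(x), y)
    loss.backward()
    opt.step()

opt.swap_swa_sgd()
opt.bn_update(toy_loader, toy_model)
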
def main(run_dir, data_dir, nb_epoch, early_stopping_patience,
         desired_sample_rate, fragment_length, batch_size, fragment_stride,
         nb_output_bins, keras_verbose, _log, seed, _config, debug,
         learn_all_outputs, train_only_in_receptive_field, _run, use_ulaw,
         train_with_soft_target_stdev):
    if run_dir is None:
        os.makedirs('models', exist_ok=True)
        run_dir = os.path.join(
            'models', datetime.datetime.now().strftime('run_%Y%m%d_%H%M%S'))

    _config['run_dir'] = run_dir

    print_config(_run)

    _log.info('Running with seed %d' % seed)

    if not debug:
        if os.path.exists(run_dir):
            raise EnvironmentError('Run directory %s already exists' % run_dir)
        os.makedirs(run_dir, exist_ok=True)
        checkpoint_dir = os.path.join(run_dir, 'checkpoints')
        json.dump(_config, open(os.path.join(run_dir, 'config.json'), 'w'))

    _log.info('Loading data...')
    data_generators, nb_examples = get_generators()

    _log.info('Building model...')
    model = build_model(fragment_length)
    _log.info(model.summary())

    optim = make_optimizer()
    _log.info('Compiling Model...')

    loss = losses.categorical_crossentropy
    all_metrics = [
        metrics.categorical_accuracy,
        categorical_mean_squared_error,
    ]
    if train_with_soft_target_stdev:
        loss = make_targets_soft(loss)
    if train_only_in_receptive_field:
        loss = skip_out_of_receptive_field(loss)
        all_metrics = [skip_out_of_receptive_field(m) for m in all_metrics]

    model.compile(optimizer=optim, loss=loss, metrics=all_metrics)
    # TODO: Consider gradient weighting making last outputs more important.

    tictoc = strftime("%a_%d_%b_%Y_%H_%M_%S", gmtime())
    log_dir = os.path.join('tensorboard', 'wavenet_' + tictoc)
    os.makedirs(log_dir, exist_ok=True)
    tensorboard = TensorBoard(log_dir=log_dir)

    callbacks = [
        tensorboard,
        ReduceLROnPlateau(patience=early_stopping_patience / 2,
                          cooldown=early_stopping_patience / 4, verbose=1),
        EarlyStopping(patience=early_stopping_patience, verbose=1),
    ]
    if not debug:
        os.makedirs(checkpoint_dir, exist_ok=True)
        callbacks.extend([
            ModelCheckpoint(
                os.path.join(checkpoint_dir,
                             'checkpoint.{epoch:05d}-{val_loss:.3f}.hdf5'),
                save_best_only=True),
            CSVLogger(os.path.join(run_dir, 'history.csv')),
        ])

    _log.info('Starting Training...')
    print("nb_examples['train']: {0}".format(nb_examples['train']))
    print("nb_examples['test']: {0}".format(nb_examples['test']))
    model.fit_generator(data_generators['train'],
                        steps_per_epoch=nb_examples['train'] // batch_size,
                        epochs=nb_epoch,
                        validation_data=data_generators['test'],
                        validation_steps=nb_examples['test'] // batch_size,
                        callbacks=callbacks,
                        verbose=keras_verbose)
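# `skip_out_of_receptive_field` is project code that is not shown here. A
# plausible sketch of what such a wrapper does, assuming (batch, time, bins)
# tensors and a known receptive-field length (both assumptions, not the
# repo's actual API): the first frames of an autoregressive model are
# computed from zero-padded context, so they are excluded from the loss.
def skip_out_of_receptive_field_sketch(loss_fn, receptive_field):
    """Wrap a loss so the first `receptive_field - 1` frames are ignored."""
    def wrapped(y_true, y_pred):
        return loss_fn(y_true[:, receptive_field - 1:, :],
                       y_pred[:, receptive_field - 1:, :])
    return wrapped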
def main(_run, exp_dir, storage_dir, database_json, ckpt_name,
         num_workers, batch_size, max_padding_rate, device):
    commands.print_config(_run)

    exp_dir = Path(exp_dir)
    storage_dir = Path(storage_dir)

    config = load_json(exp_dir / 'config.json')

    model = Model.from_storage_dir(
        exp_dir, consider_mpi=True, checkpoint_name=ckpt_name)
    model.to(device)
    model.eval()

    _, validation_data, test_data = get_datasets(
        database_json=database_json,
        min_signal_length=1.5,
        audio_reader=config['audio_reader'],
        stft=config['stft'],
        num_workers=num_workers,
        batch_size=batch_size,
        max_padding_rate=max_padding_rate,
        storage_dir=exp_dir,
    )

    outputs = []
    with torch.no_grad():
        for example in tqdm(validation_data):
            example = model.example_to_device(example, device)
            (y, seq_len), _ = model(example)
            y = Mean(axis=-1)(y, seq_len)
            outputs.append((
                y.cpu().detach().numpy(),
                example['events'].cpu().detach().numpy(),
            ))

    scores, targets = list(zip(*outputs))
    scores = np.concatenate(scores)
    targets = np.concatenate(targets)

    # Tune per-class decision thresholds on the validation set only.
    thresholds, f1 = instance_based.get_optimal_thresholds(
        targets, scores, metric='f1')
    decisions = scores > thresholds
    f1, p, r = instance_based.fscore(targets, decisions, event_wise=True)
    ap = metrics.average_precision_score(targets, scores, average=None)
    auc = metrics.roc_auc_score(targets, scores, average=None)
    pos_class_indices, precision_at_hits = \
        instance_based.positive_class_precisions(targets, scores)
    lwlrap, per_class_lwlrap, weight_per_class = \
        instance_based.lwlrap_from_precisions(
            precision_at_hits, pos_class_indices,
            num_classes=targets.shape[1])

    overall_results = {
        'validation': {
            'mF1': np.mean(f1),
            'mP': np.mean(p),
            'mR': np.mean(r),
            'mAP': np.mean(ap),
            'mAUC': np.mean(auc),
            'lwlrap': lwlrap,
        }
    }

    event_validation_results = {}
    labels = load_json(exp_dir / 'events.json')
    for i, label in enumerate(labels):
        event_validation_results[label] = {
            'F1': f1[i],
            'P': p[i],
            'R': r[i],
            'AP': ap[i],
            'AUC': auc[i],
            'lwlrap': per_class_lwlrap[i],
        }

    outputs = []
    with torch.no_grad():
        for example in tqdm(test_data):
            example = model.example_to_device(example, device)
            (y, seq_len), _ = model(example)
            y = Mean(axis=-1)(y, seq_len)
            outputs.append((
                example['example_id'],
                y.cpu().detach().numpy(),
                example['events'].cpu().detach().numpy(),
            ))

    example_ids, scores, targets = list(zip(*outputs))
    example_ids = np.concatenate(example_ids).tolist()
    scores = np.concatenate(scores)
    targets = np.concatenate(targets)

    # Apply the validation thresholds unchanged to the test set.
    decisions = scores > thresholds
    f1, p, r = instance_based.fscore(targets, decisions, event_wise=True)
    ap = metrics.average_precision_score(targets, scores, average=None)
    auc = metrics.roc_auc_score(targets, scores, average=None)
    pos_class_indices, precision_at_hits = \
        instance_based.positive_class_precisions(targets, scores)
    lwlrap, per_class_lwlrap, weight_per_class = \
        instance_based.lwlrap_from_precisions(
            precision_at_hits, pos_class_indices,
            num_classes=targets.shape[1])

    overall_results['test'] = {
        'mF1': np.mean(f1),
        'mP': np.mean(p),
        'mR': np.mean(r),
        'mAP': np.mean(ap),
        'mAUC': np.mean(auc),
        'lwlrap': lwlrap,
    }
    dump_json(overall_results, storage_dir / 'overall.json',
              indent=4, sort_keys=False)

    event_results = {}
    for i, label in sorted(
            enumerate(labels), key=lambda x: ap[x[0]], reverse=True):
        event_results[label] = {
            'validation': event_validation_results[label],
            'test': {
                'F1': f1[i],
                'P': p[i],
                'R': r[i],
                'AP': ap[i],
                'AUC': auc[i],
                'lwlrap': per_class_lwlrap[i],
            },
        }
    dump_json(event_results, storage_dir / 'event_wise.json',
              indent=4, sort_keys=False)

    fp = np.argwhere(decisions * (1 - targets))
    dump_json(sorted([(example_ids[n], labels[i]) for n, i in fp]),
              storage_dir / 'fp.json', indent=4, sort_keys=False)
    fn = np.argwhere((1 - decisions) * targets)
    dump_json(sorted([(example_ids[n], labels[i]) for n, i in fn]),
              storage_dir / 'fn.json', indent=4, sort_keys=False)
    pprint(overall_results)
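# The evaluation above selects one decision threshold per event class on the
# validation scores and reuses it on the test set. A minimal sketch of that
# idea using plain sklearn (the project's instance_based helpers are not
# shown here and likely use a smarter search than this grid):
import numpy as np
from sklearn.metrics import f1_score

def per_class_thresholds_sketch(targets, scores,
                                grid=np.linspace(0.0, 1.0, 101)):
    """targets: (N, C) binary labels, scores: (N, C) scores in [0, 1]."""
    thresholds = np.zeros(targets.shape[1])
    for c in range(targets.shape[1]):
        # Pick the threshold that maximizes per-class F1 on validation data.
        f1s = [f1_score(targets[:, c], scores[:, c] > t, zero_division=0)
               for t in grid]
        thresholds[c] = grid[int(np.argmax(f1s))]
    return thresholds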
def apply(cls, args, run):
    print_config(run)
    print('-' * 79)
def main(
        _run, storage_dir, hyper_params_dir, sed_hyper_params_name,
        crnn_dirs, crnn_checkpoints, device,
        data_provider, dataset_name, ground_truth_filepath,
        save_scores, save_detections,
        max_segment_length, segment_overlap,
        weak_pseudo_labeling, boundary_pseudo_labeling,
        strong_pseudo_labeling, pseudo_widening,
        pseudo_labeled_dataset_name,
):
    print()
    print('##### Inference #####')
    print()
    print_config(_run)
    print(storage_dir)
    emissions_tracker = EmissionsTracker(
        output_dir=storage_dir, on_csv_write="update", log_level='error')
    emissions_tracker.start()
    storage_dir = Path(storage_dir)

    boundary_collar_based_params = {
        'onset_collar': .5,
        'offset_collar': .5,
        'offset_collar_rate': .0,
    }
    collar_based_params = {
        'onset_collar': .2,
        'offset_collar': .2,
        'offset_collar_rate': .2,
    }
    psds_scenario_1 = {
        'dtc_threshold': 0.7,
        'gtc_threshold': 0.7,
        'cttc_threshold': None,
        'alpha_ct': .0,
        'alpha_st': 1.,
    }
    psds_scenario_2 = {
        'dtc_threshold': 0.1,
        'gtc_threshold': 0.1,
        'cttc_threshold': 0.3,
        'alpha_ct': .5,
        'alpha_st': 1.,
    }

    # Broadcast a single checkpoint name to all model dirs.
    if not isinstance(crnn_checkpoints, list):
        assert isinstance(crnn_checkpoints, str), crnn_checkpoints
        crnn_checkpoints = len(crnn_dirs) * [crnn_checkpoints]
    crnns = [
        CRNN.from_storage_dir(
            storage_dir=crnn_dir, config_name='1/config.json',
            checkpoint_name=crnn_checkpoint)
        for crnn_dir, crnn_checkpoint in zip(crnn_dirs, crnn_checkpoints)
    ]
    print('Params',
          sum([p.numel() for crnn in crnns for p in crnn.parameters()]))
    print('CNN2d Params',
          sum([p.numel() for crnn in crnns
               for p in crnn.cnn.cnn_2d.parameters()]))

    data_provider = DataProvider.from_config(data_provider)
    data_provider.test_transform.label_encoder.initialize_labels()
    event_classes = \
        data_provider.test_transform.label_encoder.inverse_label_mapping
    event_classes = [event_classes[i] for i in range(len(event_classes))]
    frame_shift = data_provider.test_transform.stft.shift
    frame_shift /= data_provider.audio_reader.target_sample_rate

    # Normalize all per-dataset arguments to lists of the same length.
    if not isinstance(dataset_name, list):
        dataset_name = [dataset_name]
    if ground_truth_filepath is None:
        ground_truth_filepath = len(dataset_name) * [ground_truth_filepath]
    elif isinstance(ground_truth_filepath, (str, Path)):
        ground_truth_filepath = [ground_truth_filepath]
    assert len(ground_truth_filepath) == len(dataset_name)
    if not isinstance(weak_pseudo_labeling, list):
        weak_pseudo_labeling = len(dataset_name) * [weak_pseudo_labeling]
    assert len(weak_pseudo_labeling) == len(dataset_name)
    if not isinstance(boundary_pseudo_labeling, list):
        boundary_pseudo_labeling = \
            len(dataset_name) * [boundary_pseudo_labeling]
    assert len(boundary_pseudo_labeling) == len(dataset_name)
    if not isinstance(strong_pseudo_labeling, list):
        strong_pseudo_labeling = len(dataset_name) * [strong_pseudo_labeling]
    assert len(strong_pseudo_labeling) == len(dataset_name)
    if not isinstance(pseudo_labeled_dataset_name, list):
        pseudo_labeled_dataset_name = [pseudo_labeled_dataset_name]
    assert len(pseudo_labeled_dataset_name) == len(dataset_name)

    database = deepcopy(data_provider.db.data)
    for i in range(len(dataset_name)):
        print()
        print(dataset_name[i])
        if dataset_name[i] == 'eval_public' and not ground_truth_filepath[i]:
            database_root = Path(
                data_provider.get_raw('eval_public')[0]['audio_path']
            ).parent.parent.parent.parent
            ground_truth_filepath[i] = \
                database_root / 'metadata' / 'eval' / 'public.tsv'
        elif dataset_name[i] == 'validation' and not ground_truth_filepath[i]:
            database_root = Path(
                data_provider.get_raw('validation')[0]['audio_path']
            ).parent.parent.parent.parent
            ground_truth_filepath[i] = \
                database_root / 'metadata' / 'validation' / 'validation.tsv'
        dataset = data_provider.get_dataset(dataset_name[i])
        audio_durations = {
            example['example_id']: example['audio_length']
            for example in data_provider.db.get_dataset(dataset_name[i])
        }

        score_storage_dir = storage_dir / 'scores' / dataset_name[i]
        detection_storage_dir = storage_dir / 'detections' / dataset_name[i]

        if max_segment_length is None:
            timestamps = {
                audio_id: np.array([0., audio_durations[audio_id]])
                for audio_id in audio_durations
            }
        else:
            timestamps = {}
            for audio_id in audio_durations:
                ts = np.arange(
                    0, audio_durations[audio_id],
                    (max_segment_length - segment_overlap) * frame_shift)
                timestamps[audio_id] = np.concatenate(
                    (ts, [audio_durations[audio_id]]))
        tags, tagging_scores, tagging_results = tagging(
            crnns, dataset, device, timestamps, event_classes,
            hyper_params_dir, ground_truth_filepath[i], audio_durations,
            [psds_scenario_1, psds_scenario_2],
            max_segment_length=max_segment_length,
            segment_overlap=segment_overlap,
        )
        if tagging_results:
            dump_json(
                tagging_results,
                storage_dir / f'tagging_results_{dataset_name[i]}.json')

        timestamps = np.round(np.arange(0, 100000) * frame_shift, decimals=6)
        if ground_truth_filepath[i] is not None \
                or boundary_pseudo_labeling[i]:
            boundaries, boundaries_detection_results = boundaries_detection(
                crnns, dataset, device, timestamps, event_classes,
                tags, hyper_params_dir, ground_truth_filepath[i],
                boundary_collar_based_params,
                max_segment_length=max_segment_length,
                segment_overlap=segment_overlap,
                pseudo_widening=pseudo_widening,
            )
            if boundaries_detection_results:
                dump_json(
                    boundaries_detection_results,
                    storage_dir /
                    f'boundaries_detection_results_{dataset_name[i]}.json')
        else:
            boundaries = {}

        if not isinstance(sed_hyper_params_name, (list, tuple)):
            sed_hyper_params_name = [sed_hyper_params_name]
        if (ground_truth_filepath[i] is not None) \
                or strong_pseudo_labeling[i] \
                or save_scores or save_detections:
            events, sed_results = sound_event_detection(
                crnns, dataset, device, timestamps, event_classes,
                tags, hyper_params_dir, sed_hyper_params_name,
                ground_truth_filepath[i], audio_durations,
                collar_based_params,
                [psds_scenario_1, psds_scenario_2],
                max_segment_length=max_segment_length,
                segment_overlap=segment_overlap,
                pseudo_widening=pseudo_widening,
                score_storage_dir=[
                    score_storage_dir / name
                    for name in sed_hyper_params_name
                ] if save_scores else None,
                detection_storage_dir=[
                    detection_storage_dir / name
                    for name in sed_hyper_params_name
                ] if save_detections else None,
            )
            for j, sed_results_j in enumerate(sed_results):
                if sed_results_j:
                    dump_json(
                        sed_results_j,
                        storage_dir /
                        f'sed_{sed_hyper_params_name[j]}'
                        f'_results_{dataset_name[i]}.json')
        else:
            events = [{}]
        database['datasets'][pseudo_labeled_dataset_name[i]] = \
            base.pseudo_label(
                database['datasets'][dataset_name[i]], event_classes,
                weak_pseudo_labeling[i], boundary_pseudo_labeling[i],
                strong_pseudo_labeling[i],
                tags, boundaries, events[0],
            )

    if any(weak_pseudo_labeling) or any(boundary_pseudo_labeling) \
            or any(strong_pseudo_labeling):
        dump_json(
            database,
            storage_dir / Path(data_provider.json_path).name,
            create_path=True,
            indent=4,
            ensure_ascii=False,
        )
    inference_dir = Path(hyper_params_dir) / 'inference'
    os.makedirs(str(inference_dir), exist_ok=True)
    (inference_dir / storage_dir.name).symlink_to(storage_dir)
    emissions_tracker.stop()
    print(storage_dir)
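# The argument handling above repeats the same "scalar or list" broadcasting
# five times. A hypothetical helper (not part of the repo) that captures the
# pattern in one place:
def _to_list_sketch(value, length):
    """Broadcast a scalar config value to a list of `length`, or validate an
    already-listed value against the expected length."""
    if not isinstance(value, list):
        value = length * [value]
    assert len(value) == length, (len(value), length)
    return value

# e.g.:
# weak_pseudo_labeling = _to_list_sketch(
#     weak_pseudo_labeling, len(dataset_name))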
def main(run_dir, data_dir, nb_epoch, early_stopping_patience,
         desired_sample_rate, fragment_length, batch_size, fragment_stride,
         nb_output_bins, keras_verbose, _log, seed, _config, debug,
         learn_all_outputs, train_only_in_receptive_field, _run, use_ulaw,
         train_with_soft_target_stdev):
    dataset.set_multi_gpu(True)

    if run_dir is None:
        os.makedirs("models", exist_ok=True)
        run_dir = os.path.join(
            'models', datetime.datetime.now().strftime('run_%Y%m%d_%H%M%S'))

    _config['run_dir'] = run_dir

    print_config(_run)

    _log.info('Running with seed %d' % seed)

    if not debug:
        # exist_ok avoids a race when several Horovod workers create the
        # directory concurrently.
        os.makedirs(run_dir, exist_ok=True)
        checkpoint_dir = os.path.join(run_dir, 'checkpoints')
        json.dump(_config, open(os.path.join(run_dir, 'config.json'), 'w'))

    _log.info('Loading data...')
    data_generators, nb_examples = get_generators()

    _log.info('Building model...')
    model = build_model(fragment_length)
    _log.info(model.summary())

    optim = make_optimizer()
    # Wrap the optimizer so gradients are averaged across Horovod workers.
    optim = hvd.DistributedOptimizer(optim)
    _log.info('Compiling Model...')

    loss = objectives.categorical_crossentropy
    all_metrics = [
        metrics.categorical_accuracy,
        categorical_mean_squared_error,
    ]
    if train_with_soft_target_stdev:
        loss = make_targets_soft(loss)
    if train_only_in_receptive_field:
        loss = skip_out_of_receptive_field(loss)
        all_metrics = [skip_out_of_receptive_field(m) for m in all_metrics]

    model.compile(optimizer=optim, loss=loss, metrics=all_metrics)
    # TODO: Consider gradient weighting making last outputs more important.

    callbacks = [
        # Broadcast initial variable states from rank 0 to all other
        # processes. This is necessary to ensure consistent initialization
        # of all workers when training is started with random weights or
        # restored from a checkpoint.
        hvd.callbacks.BroadcastGlobalVariablesCallback(0),

        # Average metrics among workers at the end of every epoch.
        #
        # Note: This callback must be in the list before the
        # ReduceLROnPlateau, TensorBoard or other metrics-based callbacks.
        hvd.callbacks.MetricAverageCallback(),

        # Using `lr = 1.0 * hvd.size()` from the very beginning leads to
        # worse final accuracy. Scale the learning rate
        # `lr = 1.0` ---> `lr = 1.0 * hvd.size()` during the first five
        # epochs. See https://arxiv.org/abs/1706.02677 for details.
        hvd.callbacks.LearningRateWarmupCallback(warmup_epochs=5, verbose=1),

        ReduceLROnPlateau(patience=early_stopping_patience / 2,
                          cooldown=early_stopping_patience / 4, verbose=1),
        EarlyStopping(patience=early_stopping_patience, verbose=1),
    ]
    # Only rank 0 writes checkpoints and logs to avoid corrupting files.
    if not debug and hvd.rank() == 0:
        callbacks.extend([
            ModelCheckpoint(
                os.path.join(checkpoint_dir,
                             'checkpoint.{epoch:05d}-{val_loss:.3f}.hdf5'),
                save_best_only=True),
            CSVLogger(os.path.join(run_dir, 'history.csv')),
        ])
    if not debug:
        os.makedirs(checkpoint_dir, exist_ok=True)

    _log.info('Starting Training...')
    model.fit_generator(data_generators['train'],
                        steps_per_epoch=nb_examples['train'] // hvd.size(),
                        epochs=nb_epoch,
                        validation_data=data_generators['test'],
                        validation_steps=nb_examples['test'] // hvd.size(),
                        callbacks=callbacks,
                        verbose=keras_verbose if hvd.rank() == 0 else 0)
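# The function above assumes Horovod has already been initialized elsewhere
# in the script. A minimal sketch of the usual setup that precedes it, using
# only the standard horovod.keras API; scaling the base learning rate by the
# worker count is the value the LearningRateWarmupCallback above ramps up to:
import horovod.keras as hvd

def _init_horovod_sketch(base_lr):
    hvd.init()
    # Linear LR scaling with the number of workers (Goyal et al., 2017).
    return base_lr * hvd.size()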