def __init__(
    self,
    json_path,
    target,
    dset,
    sample_rate=8000,
    single_channel=True,
    segment=4.0,
    nondefault_nsrc=None,
    normalize_audio=False,
):
    """Set up the dataset from a JSON database.

    Args:
        json_path: Path handed to ``JsonDatabase``.
        target: Key into ``SMS_TARGETS`` selecting the task.
        dset: Name of the dataset requested from the database.
        sample_rate: Used to convert ``segment`` (seconds) into samples.
        single_channel: Stored unchanged on the instance.
        segment: Segment length in seconds. ``None`` disables segmenting
            and the filtering of short utterances.
        nondefault_nsrc: Optional number of sources; must not be smaller
            than the task's ``default_nsrc``.
        normalize_audio: Stored unchanged on the instance.

    Raises:
        ValueError: If ``target`` is not a key of ``SMS_TARGETS``.
        AssertionError: If ``nondefault_nsrc`` is below the task default.
    """
    try:
        import sms_wsj  # noqa
    except ModuleNotFoundError:
        import warnings

        # The package is optional: only warn instead of failing.
        warnings.warn(
            "Some of the functionality relies on the sms_wsj package "
            "downloadable from https://github.com/fgnt/sms_wsj . "
            "The user is encouraged to install the package"
        )
    super().__init__()
    if target not in SMS_TARGETS:
        raise ValueError(
            "Unexpected task {}, expected one of "
            "{}".format(target, SMS_TARGETS.keys())
        )

    # Task setting
    self.json_path = json_path
    self.target = target
    self.target_dict = SMS_TARGETS[target]
    self.single_channel = single_channel
    self.sample_rate = sample_rate
    self.normalize_audio = normalize_audio
    self.seg_len = None if segment is None else int(segment * sample_rate)
    if not nondefault_nsrc:
        self.n_src = self.target_dict["default_nsrc"]
    else:
        assert nondefault_nsrc >= self.target_dict["default_nsrc"]
        self.n_src = nondefault_nsrc
    # Without a segment length whole utterances are used.
    self.like_test = self.seg_len is None
    self.dset = dset
    self.EPS = 1e-8

    # Load json files
    from lazy_dataset.database import JsonDatabase

    db = JsonDatabase(json_path)
    dataset = db.get_dataset(dset)

    # Filter out short utterances only when segment is specified
    if not self.like_test:
        def filter_short_examples(example):
            """Keep examples whose observation has at least seg_len samples."""
            return example["num_samples"]["observation"] >= self.seg_len

        dataset = dataset.filter(filter_short_examples, lazy=False)
    self.dataset = dataset
def main(_run, _log, trainer, database_json, training_sets, validation_sets,
         audio_reader, stft, max_length_in_sec, batch_size, resume):
    """Build the trainer and data pipelines, then run the training.

    The configuration is printed and saved as ``config.json`` in the
    trainer's storage directory. A test run precedes the actual training,
    and the validation data is registered as validation hook.
    """
    commands.print_config(_run)
    trainer = Trainer.from_config(trainer)

    storage_dir = Path(trainer.storage_dir)
    storage_dir.mkdir(parents=True, exist_ok=True)
    config_file = storage_dir / 'config.json'
    commands.save_config(_run.config, _log, config_filename=str(config_file))

    db = JsonDatabase(database_json)
    training_data = db.get_dataset(training_sets)
    validation_data = db.get_dataset(validation_sets)

    # Both pipelines share everything except the shuffle flag.
    shared_options = dict(
        audio_reader=audio_reader,
        stft=stft,
        max_length_in_sec=max_length_in_sec,
        batch_size=batch_size,
    )
    training_data = prepare_dataset(
        training_data, shuffle=True, **shared_options)
    validation_data = prepare_dataset(
        validation_data, shuffle=False, **shared_options)

    trainer.test_run(training_data, validation_data)
    trainer.register_validation_hook(validation_data)
    trainer.train(training_data, resume=resume)
def get_datasets(storage_dir, database_json, dataset, batch_size=16,
                 return_indexable=False):
    """Return prepared training, validation and test data.

    All three parts are obtained by splitting the single dataset
    ``dataset`` from the database at ``database_json``.
    """
    def prepare_example(example):
        """Select the observation path and shorten the speaker id."""
        example['audio_path'] = example['audio_path']['observation']
        example['speaker_id'] = example['speaker_id'].split('-')[0]
        return example

    ds = JsonDatabase(database_json).get_dataset(dataset).map(prepare_example)

    speaker_encoder = LabelEncoder(
        label_key='speaker_id', storage_dir=storage_dir, to_array=True)
    speaker_encoder.initialize_labels(dataset=ds, verbose=True)
    ds = ds.map(speaker_encoder)

    # LibriSpeech (the default database) does not share speakers across
    # different datasets, i.e., the datasets, e.g. clean_100 and dev_clean,
    # have a different set of non-overlapping speakers. To guarantee the same
    # set of speakers during training, validation and evaluation, we perform
    # a split of the train set, e.g., clean_100 or clean_360.
    train_set, validate_set, test_set = train_test_split(ds)

    training_data = prepare_dataset(train_set, batch_size, training=True)
    validation_data = prepare_dataset(
        validate_set, batch_size, training=False)
    test_data = prepare_dataset(
        test_set, batch_size, training=False,
        return_indexable=return_indexable,
    )
    return training_data, validation_data, test_data
def prepare_and_train(_run, _log, trainer, train_dataset, validate_dataset,
                      lr_scheduler_step, lr_scheduler_gamma,
                      load_model_from, database_json):
    """Create the trainer, optionally load weights, and start training.

    Weights from ``load_model_from`` are only loaded when no
    ``ckpt_latest.pth`` exists in the trainer's checkpoint directory.
    Training resumes when that checkpoint file exists.
    """
    trainer = pt.Trainer.from_config(trainer)
    checkpoint_path = trainer.checkpoint_dir / 'ckpt_latest.pth'

    start_from_given_weights = (
        load_model_from is not None and not checkpoint_path.is_file()
    )
    if start_from_given_weights:
        _log.info(f'Loading model weights from {load_model_from}')
        checkpoint = torch.load(load_model_from)
        trainer.model.load_state_dict(checkpoint['model'])

    db = JsonDatabase(database_json)

    # Perform a test run to check if everything works
    test_train_data = prepare_iterable_captured(db, train_dataset)
    test_validation_data = prepare_iterable_captured(db, validate_dataset)
    trainer.test_run(test_train_data, test_validation_data)

    # Register hooks and start the actual training
    trainer.register_validation_hook(
        prepare_iterable_captured(db, validate_dataset))

    # Learning rate scheduler
    scheduler = torch.optim.lr_scheduler.StepLR(
        trainer.optimizer.optimizer,
        step_size=lr_scheduler_step,
        gamma=lr_scheduler_gamma,
    )
    trainer.register_hook(pt.train.hooks.LRSchedulerHook(scheduler))

    train_data = prepare_iterable_captured(db, train_dataset)
    trainer.train(train_data, resume=checkpoint_path.is_file())
def main(_run, batch_size, datasets, debug, experiment_dir, database_json):
    """Evaluate the model on ``datasets`` and dump metrics to result.json.

    The examples are distributed with ``dlp_mpi.split_managed``; the
    partial results are gathered and written by the master process.
    """
    experiment_dir = Path(experiment_dir)

    if dlp_mpi.IS_MASTER:
        sacred.commands.print_config(_run)

    model = get_model()
    db = JsonDatabase(json_path=database_json)

    model.eval()
    summary = defaultdict(dict)
    with torch.no_grad():
        for dataset_name in datasets:
            iterable = prepare_iterable(
                db, dataset_name, batch_size,
                return_keys=None,
                prefetch=False,
            )
            managed_batches = dlp_mpi.split_managed(
                iterable,
                is_indexable=False,
                progress_bar=True,
                allow_single_worker=debug,
            )
            for batch in managed_batches:
                model_output = model(pt.data.example_to_device(batch))

                # Only the first entry of each batch field is evaluated.
                example_id = batch['example_id'][0]
                s = batch['s'][0]
                Y = batch['Y'][0]
                mask = model_output[0].numpy()

                # Apply the mask to the observation and go back to time domain
                Z = mask * Y[:, None, :]
                z = istft(
                    einops.rearrange(Z, "t k f -> k t f"),
                    size=512, shift=128,
                )

                # Cut reference and estimate to a common length
                s = s[:, :z.shape[1]]
                z = z[:, :s.shape[1]]

                metrics = pb_bss.evaluation.OutputMetrics(
                    speech_prediction=z, speech_source=s).as_dict()
                summary[dataset_name][example_id] = dict(metrics=metrics)

    summary_list = dlp_mpi.gather(summary, root=dlp_mpi.MASTER)

    if dlp_mpi.IS_MASTER:
        print(f'len(summary_list): {len(summary_list)}')
        # Merge the partial results of all processes
        for partial_summary in summary_list:
            for dataset_name, values in partial_summary.items():
                summary[dataset_name].update(values)

        for dataset_name, values in summary.items():
            print(f'{dataset_name}: {len(values)}')

        result_json_path = experiment_dir / 'result.json'
        print(f"Exporting result: {result_json_path}")
        pb.io.dump_json(summary, result_json_path)
def prepare_and_train(_run, _log, trainer, train_datasets, validate_datasets,
                      lr_scheduler_step, lr_scheduler_gamma,
                      load_model_from, database_jsons):
    """Train on one or several JSON databases.

    ``database_jsons`` may be a comma separated string; it is split into
    a list before it is handed to ``JsonDatabase``. Training resumes when
    the trainer's checkpoint directory exists.
    """
    trainer = get_trainer(trainer, load_model_from)

    if isinstance(database_jsons, str):
        database_jsons = database_jsons.split(',')
    db = JsonDatabase(database_jsons)

    # Perform a test run to check if everything works
    test_train_data = prepare_iterable_captured(db, train_datasets)
    test_validation_data = prepare_iterable_captured(db, validate_datasets)
    trainer.test_run(test_train_data, test_validation_data)

    # Register hooks and start the actual training
    trainer.register_validation_hook(
        prepare_iterable_captured(db, validate_datasets))

    # Learning rate scheduler
    scheduler = torch.optim.lr_scheduler.StepLR(
        trainer.optimizer.optimizer,
        step_size=lr_scheduler_step,
        gamma=lr_scheduler_gamma,
    )
    trainer.register_hook(pt.train.hooks.LRSchedulerHook(scheduler))

    train_data = prepare_iterable_captured(db, train_datasets)
    trainer.train(train_data, resume=trainer.checkpoint_dir.exists())
def test_run(_config, _run, train_dataset, validate_dataset, database_json):
    """Print the config and perform a trainer test run on both datasets."""
    sacred.commands.print_config(_run)

    trainer = pt.Trainer.from_config(_config["trainer"])
    db = JsonDatabase(json_path=database_json)

    train_data = prepare_dataset_captured(db, train_dataset)
    validation_data = prepare_dataset_captured(db, validate_dataset)
    trainer.test_run(train_data, validation_data)
def test_run(_run, _log, trainer, train_dataset, validate_dataset,
             load_model_from, database_json):
    """Perform a trainer test run with shuffled train and validation data."""
    trainer = get_trainer(trainer, load_model_from, _log)
    db = JsonDatabase(database_json)

    # Perform a test run to check if everything works
    train_data = prepare_dataset_captured(db, train_dataset, shuffle=True)
    validation_data = prepare_dataset_captured(
        db, validate_dataset, shuffle=True)
    trainer.test_run(train_data, validation_data)
def test_run(storage_dir, database_json):
    """Run the trainer's test run for a ``SimpleMaskEstimator(513)``."""
    model = SimpleMaskEstimator(513)
    print(f'Simple training for the following model: {model}')

    database = JsonDatabase(database_json)
    train_dataset = get_train_dataset(database)
    validation_dataset = get_validation_dataset(database)

    trainer = pt.train.trainer.Trainer(
        model,
        storage_dir,
        optimizer=pt.train.optimizer.Adam(),
        stop_trigger=(100000, 'iteration'),
    )
    trainer.test_run(train_dataset, validation_dataset)
def test_json_database_multiple(self):
    """Check JsonDatabase construction from one or several JSON files."""
    with tempfile.TemporaryDirectory() as tmpdir:
        tmpdir = Path(tmpdir)
        d1_path = tmpdir / 'd1.json'
        d2_path = tmpdir / 'd2.json'

        # Write both fixture dictionaries to disk
        for path, content in ((d1_path, self.json), (d2_path, self.json2)):
            with path.open('w') as fd:
                json.dump(content, fd)

        # Positional paths, a list, and the keyword argument are accepted
        db = JsonDatabase(d1_path, d2_path)
        assert len(db.dataset_names) == 5

        db = JsonDatabase([d1_path, d2_path])
        assert len(db.dataset_names) == 5

        db = JsonDatabase(json_path=d1_path)
        assert len(db.dataset_names) == 3

        db = JsonDatabase(json_path=(d1_path, d2_path))
        assert len(db.dataset_names) == 5

        with pytest.raises(AssertionError):
            # Test metadata check
            _ = JsonDatabase(d2_path, d1_path).dataset_names

        with pytest.raises(AssertionError):
            # Test duplicate dataset name check
            _ = JsonDatabase(d1_path, d1_path).dataset_names

        with pytest.raises(AssertionError):
            DictDatabase([d1_path], d2_path)
def test_run(_run, _log, trainer, train_datasets, validate_datasets,
             load_model_from, database_jsons):
    """Perform a trainer test run on one or several JSON databases.

    ``database_jsons`` may be a comma separated string; it is split into
    a list before it is handed to ``JsonDatabase``.
    """
    trainer = get_trainer(trainer, load_model_from)

    if isinstance(database_jsons, str):
        database_jsons = database_jsons.split(',')
    db = JsonDatabase(database_jsons)

    # Perform a test run to check if everything works
    train_data = prepare_iterable_captured(db, train_datasets)
    validation_data = prepare_iterable_captured(db, validate_datasets)
    trainer.test_run(train_data, validation_data)
def train(storage_dir, database_json):
    """Train a ``SimpleMaskEstimator(513)`` on the given database.

    A test run is performed first; afterwards the validation dataset is
    registered as validation hook and the training is started.
    """
    model = SimpleMaskEstimator(513)
    print(f'Simple training for the following model: {model}')

    database = JsonDatabase(database_json)
    train_dataset = get_train_dataset(database)
    validation_dataset = get_validation_dataset(database)

    trainer = pt.Trainer(
        model,
        storage_dir,
        optimizer=pt.train.optimizer.Adam(),
        stop_trigger=(100000, 'iteration'),
    )
    trainer.test_run(train_dataset, validation_dataset)

    validation_options = dict(
        n_back_off=5,
        lr_update_factor=1 / 10,
        back_off_patience=1,
        early_stopping_patience=None,
    )
    trainer.register_validation_hook(validation_dataset, **validation_options)
    trainer.train(train_dataset)
def create_json(db_dir, intermediate_json_path, write_all, snr_range=(20, 30)):
    """Build the final JSON dictionary from the intermediate JSON.

    Audio paths of every example are rewritten to point below ``db_dir``
    and an ``snr`` entry is drawn uniformly from ``snr_range``.

    Args:
        db_dir: Directory the rewritten audio paths are rooted at.
        intermediate_json_path: Handed to ``JsonDatabase``.
        write_all: If true all keys of ``KEY_MAPPER`` are rewritten,
            otherwise only ``observation``.
        snr_range: Bounds passed to ``rng.uniform``.

    Returns:
        Dictionary with a single ``datasets`` entry.
    """
    db = JsonDatabase(intermediate_json_path)
    key_mapper = KEY_MAPPER if write_all else {'observation': 'observation'}

    # These keys refer to one file per example, the others to one per speaker.
    single_file_keys = ('observation', 'noise_image')

    json_dict = dict(datasets=dict())
    for dataset_name, dataset in db.data['datasets'].items():
        dataset_dict = dict()
        for ex_id, ex in dataset.items():
            audio_path = ex['audio_path']
            for key, data_type in key_mapper.items():
                current_path = db_dir / data_type / dataset_name
                if key in single_file_keys:
                    audio_path[key] = str(current_path / f'{ex_id}.wav')
                else:
                    num_speakers = len(ex['speaker_id'])
                    audio_path[key] = [
                        str(current_path / f'{ex_id}_{k}.wav')
                        for k in range(num_speakers)
                    ]

            # Keep only the last six path components, e.g.
            # .../sms_wsj/cache/wsj_8k_zeromean/13-11.1/wsj1/si_tr_s/4ax/4axc0218.wav
            audio_path['speech_source'] = [
                str(db_dir.joinpath(*Path(source).parts[-6:]))
                for source in audio_path['speech_source']
            ]
            # Keep only the last four path components, e.g.
            # .../sms_wsj/cache/rirs/train_si284/0/h_0.wav
            audio_path['rir'] = [
                str(db_dir.joinpath(*Path(rir).parts[-4:]))
                for rir in audio_path['rir']
            ]

            rng = _example_id_to_rng(ex_id)
            snr = rng.uniform(*snr_range)
            ex.pop('dataset', None)
            ex["snr"] = snr
            dataset_dict[ex_id] = ex
        json_dict['datasets'][dataset_name] = dataset_dict
    return json_dict
def prepare_and_train(_config, _run, train_dataset, validate_dataset,
                      database_json):
    """Prepare the train and validation data from the database and train.

    Training resumes when ``ckpt_latest.pth`` exists in the trainer's
    checkpoint directory.
    """
    sacred.commands.print_config(_run)
    trainer = pt.Trainer.from_config(_config["trainer"])
    checkpoint_path = trainer.checkpoint_dir / 'ckpt_latest.pth'

    db = JsonDatabase(json_path=database_json)
    print(f'{train_dataset!r} {validate_dataset!r}')

    test_train_data = prepare_dataset_captured(db, train_dataset)
    test_validation_data = prepare_dataset_captured(db, validate_dataset)
    trainer.test_run(test_train_data, test_validation_data)

    trainer.register_validation_hook(
        prepare_dataset_captured(db, validate_dataset))

    train_data = prepare_dataset_captured(db, train_dataset)
    trainer.train(train_data, resume=checkpoint_path.is_file())
def prepare_and_train(_run, _log, trainer, train_dataset, validate_dataset,
                      lr_scheduler_step, lr_scheduler_gamma,
                      load_model_from, database_json):
    """Test-run, register the hooks and start the training.

    With a truthy ``lr_scheduler_step`` a ``StepLR`` scheduler is used
    and the validation hook is registered without back-off. Otherwise the
    validation hook is registered with ``n_back_off=5`` and
    ``back_off_patience=3``.
    """
    trainer = get_trainer(trainer, load_model_from)
    db = JsonDatabase(database_json)

    # Perform a test run to check if everything works
    test_train_data = prepare_iterable_captured(db, train_dataset)
    test_validation_data = prepare_iterable_captured(db, validate_dataset)
    trainer.test_run(test_train_data, test_validation_data)

    # Register hooks and start the actual training
    if lr_scheduler_step:
        # Learning rate scheduler, don't use LR back-off
        scheduler = torch.optim.lr_scheduler.StepLR(
            trainer.optimizer.optimizer,
            step_size=lr_scheduler_step,
            gamma=lr_scheduler_gamma,
        )
        trainer.register_hook(pt.train.hooks.LRSchedulerHook(scheduler))
        validation_options = {}
    else:
        # Use LR back-off
        validation_options = dict(n_back_off=5, back_off_patience=3)

    trainer.register_validation_hook(
        prepare_iterable_captured(db, validate_dataset),
        **validation_options,
    )

    train_data = prepare_iterable_captured(db, train_dataset)
    trainer.train(train_data, resume=trainer.checkpoint_dir.exists())
def main(dst_dir, json_path, write_all, new_json_path, snr_range):
    """Write the wav files and optionally a JSON that points to them.

    The master process decides whether the wav files have to be written
    and broadcasts this decision to all processes.

    Args:
        dst_dir: Directory the wav files are written to.
        json_path: Path of the existing JSON database.
        write_all: Forwarded to ``write_wavs`` and ``create_json``; also
            selects the number of wav files expected in ``dst_dir``.
        new_json_path: If truthy, the master process writes the updated
            JSON to this path.
        snr_range: Forwarded to ``write_wavs`` and ``create_json``.

    Raises:
        ValueError: If ``dst_dir`` already contains one of the data type
            directories but not the expected number of wav files.
    """
    json_path = Path(json_path).expanduser().resolve()
    dst_dir = Path(dst_dir).expanduser().resolve()

    if dlp_mpi.IS_MASTER:
        assert json_path.exists(), json_path
        dst_dir.mkdir(exist_ok=True, parents=True)

        structure_exists = any(
            (dst_dir / data_type).exists() for data_type in type_mapper.keys()
        )
        if not structure_exists:
            write_files = True
        else:
            write_files = False
            # Never overwrite: accept only a complete set of wav files.
            num_wav_files = len(check_files(dst_dir))
            if write_all and num_wav_files == (2 * 2 + 2) * 32000:
                print('Wav files seem to exist. They are not overwritten.')
            elif not write_all and num_wav_files == 32000 and (
                    dst_dir / 'observation').exists():
                print('Wav files seem to exist. They are not overwritten.')
            else:
                raise ValueError(
                    'Not all wav files exist. However, the directory structure'
                    ' already exists.')
    else:
        write_files = None

    write_files = dlp_mpi.COMM.bcast(write_files, root=dlp_mpi.MASTER)
    db = JsonDatabase(json_path)

    if write_files:
        write_wavs(dst_dir, db, write_all=write_all, snr_range=snr_range)

    if dlp_mpi.IS_MASTER and new_json_path:
        print(f'Creating a new json and saving it to {new_json_path}')
        new_json_path = Path(new_json_path).expanduser().resolve()
        updated_json = create_json(dst_dir, db, write_all, snr_range=snr_range)
        new_json_path.parent.mkdir(exist_ok=True, parents=True)
        with new_json_path.open('w') as f:
            json.dump(updated_json, f, indent=4, ensure_ascii=False)
        # Report the file that was actually written, not the input JSON.
        print(f'{new_json_path} written')
def run(_config, egs_path, json_path, stage, end_stage, gmm_dir,
        ali_data_type, train_data_type, target_speaker, channels, kaldi_cmd,
        num_jobs):
    """Run the staged Kaldi recipe from ``stage`` up to ``end_stage``.

    The working directory is ``egs_path / train_data_type / 's5'``.

    Raises:
        ValueError: In stage 1 if ``kaldi_cmd`` is neither ``ssh.pl`` nor
            ``run.pl``.
    """
    sms_db = JsonDatabase(json_path)
    sms_kaldi_dir = Path(egs_path).resolve().expanduser()
    sms_kaldi_dir = sms_kaldi_dir / train_data_type / 's5'
    kaldi_dir_str = str(sms_kaldi_dir)

    if stage <= 1 < end_stage:
        create_kaldi_dir(sms_kaldi_dir)
        if kaldi_cmd == 'ssh.pl':
            if 'CCS_NODEFILE' in os.environ:
                pc2_environ(sms_kaldi_dir)
        elif kaldi_cmd != 'run.pl':
            raise ValueError(kaldi_cmd)
        # Append the chosen command to cmd.sh
        with (sms_kaldi_dir / 'cmd.sh').open('a') as fd:
            fd.write(f'export train_cmd="{kaldi_cmd}"')

    if gmm_dir is None:
        gmm = 'tri4b'
    else:
        gmm_dir = Path(gmm_dir)
        gmm = gmm_dir.name

    if stage <= 2 < end_stage:
        if gmm_dir is None:
            create_data_dir(
                sms_kaldi_dir, db=sms_db, data_type='wsj_8k',
                target_speaker=target_speaker,
            )
            print('Start training tri3 model on wsj_8k')
            run_process(
                [
                    f'{sms_kaldi_dir}/local_sms/get_tri3_model.bash',
                    '--dest_dir', f'{sms_kaldi_dir}',
                    '--nj', str(num_jobs),
                ],
                cwd=kaldi_dir_str, stdout=None, stderr=None,
            )
        else:
            # Reuse an existing GMM instead of training a new one
            assert gmm_dir.exists()
            gmm_parent_dir = sms_kaldi_dir / 'exp' / 'wsj_8k'
            gmm_parent_dir.mkdir(parents=True)
            shutil.copytree(gmm_dir, gmm_parent_dir / gmm)

    if stage <= 3 < end_stage and not ali_data_type == train_data_type:
        create_data_dir(
            sms_kaldi_dir, db=sms_db, data_type=ali_data_type,
            ref_channels=channels, target_speaker=target_speaker,
        )

    if stage <= 4 < end_stage:
        create_data_dir(
            sms_kaldi_dir, db=sms_db, data_type=train_data_type,
            ref_channels=channels, target_speaker=target_speaker,
        )

    # Options shared by both nnet3 scripts (they follow --cv_sets)
    nnet3_options = [
        '--stage', str(stage),
        '--gmm_data_type', 'wsj_8k',
        '--gmm', gmm,
        '--ali_data_type', ali_data_type,
        '--dataset', train_data_type,
        '--nj', str(num_jobs),
    ]

    if stage <= 16 < end_stage:
        print('Prepare data for nnet3 model training on sms_wsj')
        run_process(
            [
                f'{sms_kaldi_dir}/local_sms/prepare_nnet3_model_training.bash',
                '--dest_dir', f'{sms_kaldi_dir}',
                '--cv_sets', "cv_dev93",
                *nnet3_options,
            ],
            cwd=kaldi_dir_str, stdout=None, stderr=None,
        )

    if stage <= 20 and end_stage >= 17:
        print('Start training nnet3 model on sms_wsj')
        run_process(
            [
                f'{sms_kaldi_dir}/local_sms/get_nnet3_model.bash',
                '--dest_dir', f'{sms_kaldi_dir}',
                '--cv_sets', '"cv_dev93"',
                *nnet3_options,
            ],
            cwd=kaldi_dir_str, stdout=None, stderr=None,
        )
def _create_data_dir(
        get_wer_command_fn, kaldi_dir, db=None, json_path=None,
        dataset_names=None, data_type='wsj_8k', target_speaker=0,
        ref_channels=0,
):
    """Write the Kaldi data directories for the requested datasets.

    For each dataset name the files ``utt2spk``, ``text``, ``utt2dur``,
    ``reco2dur``, ``wav.scp`` and ``spk2gender`` are written to
    ``kaldi_dir / 'data' / data_type / dataset_name``; afterwards
    ``utils/fix_data_dir.sh`` is run on that directory.

    Args:
        get_wer_command_fn: Called as
            ``get_wer_command_fn(example, ref_ch=ref_ch, spk=t_spk)``;
            its result becomes the ``wav.scp`` entry.
        kaldi_dir: Base directory; also the working directory of
            ``fix_data_dir.sh``.
        db: Database object. If ``None`` it is created from ``json_path``.
        json_path: Only used when ``db`` is ``None``.
        dataset_names: A name or a sequence of names. Defaults to
            ``('train_si284', 'cv_dev93', 'test_eval92')``.
        data_type: Name of the sub directory below ``data``.
        target_speaker: Speaker index or list/tuple of indices.
        ref_channels: Channel or list/tuple of channels. With more than
            one channel ``_c{channel}`` is appended to the example id.

    Returns:
        None

    Raises:
        AssertionError: If neither ``db`` nor ``json_path`` is given, a
            target directory already exists, or no entry was collected.
    """
    assert not (db is None and json_path is None), (db, json_path)
    if db is None:
        db = JsonDatabase(json_path)

    kaldi_dir = Path(kaldi_dir).expanduser().resolve()
    data_dir = kaldi_dir / 'data' / data_type
    data_dir.mkdir(exist_ok=True, parents=True)

    if not isinstance(ref_channels, (list, tuple)):
        ref_channels = [ref_channels]
    example_id_to_wav = dict()
    example_id_to_speaker = dict()
    example_id_to_trans = dict()
    example_id_to_duration = dict()
    speaker_to_gender = defaultdict(lambda: defaultdict(list))
    dataset_to_example_id = defaultdict(list)

    if dataset_names is None:
        dataset_names = ('train_si284', 'cv_dev93', 'test_eval92')
    elif isinstance(dataset_names, str):
        dataset_names = [dataset_names]
    if not isinstance(target_speaker, (list, tuple)):
        target_speaker = [target_speaker]

    assert not any([
        (data_dir / dataset_name).exists() for dataset_name in dataset_names
    ]), (
        'One of the following directories already exists: '
        f'{[data_dir / ds_name for ds_name in dataset_names]}\n'
        'Delete them if you want to restart this stage'
    )

    print(
        'Create data dir for '
        f'{", ".join([f"{data_type}/{ds_name}" for ds_name in dataset_names])} '
        'data'
    )

    dataset = db.get_dataset(dataset_names)
    for example in dataset:
        for ref_ch in ref_channels:
            org_example_id = example['example_id']
            dataset_name = example['dataset']
            for t_spk in target_speaker:
                speaker_id = example['speaker_id'][t_spk]
                # The speaker id is used as prefix of the utterance id
                example_id = speaker_id + '_' + org_example_id
                example_id += f'_c{ref_ch}' if len(ref_channels) > 1 else ''
                example_id_to_wav[example_id] = get_wer_command_fn(
                    example, ref_ch=ref_ch, spk=t_spk)
                # Prefer the Kaldi specific transcription when available
                try:
                    transcription = example['kaldi_transcription'][t_spk]
                except KeyError:
                    transcription = example['transcription'][t_spk]
                example_id_to_trans[example_id] = transcription

                example_id_to_speaker[example_id] = speaker_id
                gender = example['gender'][t_spk]
                speaker_to_gender[dataset_name][speaker_id] = gender
                if isinstance(example['num_samples'], dict):
                    num_samples = example['num_samples']['observation']
                else:
                    num_samples = example['num_samples']
                example_id_to_duration[
                    example_id] = f"{num_samples / SAMPLE_RATE:.2f}"
                dataset_to_example_id[dataset_name].append(example_id)

    assert len(example_id_to_speaker) > 0, dataset
    for dataset_name in dataset_names:
        path = data_dir / dataset_name
        path.mkdir(exist_ok=False, parents=False)
        # Build the id set once: membership tests on the list made the
        # filtering below quadratic in the number of examples.
        dataset_example_ids = set(dataset_to_example_id[dataset_name])
        for name, dictionary in (
                ("utt2spk", example_id_to_speaker),
                ("text", example_id_to_trans),
                ("utt2dur", example_id_to_duration),
                ("wav.scp", example_id_to_wav)
        ):
            dictionary = {key: value for key, value in dictionary.items()
                          if key in dataset_example_ids}

            assert len(dictionary) > 0, (dataset_name, name)
            if name == 'utt2dur':
                # The same durations are also written as reco2dur
                dump_keyed_lines(dictionary, path / 'reco2dur')
            dump_keyed_lines(dictionary, path / name)
        dictionary = speaker_to_gender[dataset_name]
        assert len(dictionary) > 0, (dataset_name, 'spk2gender')
        dump_keyed_lines(dictionary, path / 'spk2gender')
        run_process([
            'utils/fix_data_dir.sh', f'{path}'],
            cwd=str(kaldi_dir), stdout=None, stderr=None
        )
def run(_config, _run, audio_dir, kaldi_data_dir, json_path):
    """Create the Kaldi data directories and decode every dataset.

    The data is taken from ``audio_dir`` (together with ``json_path``),
    from a predefined ``kaldi_data_dir``, or from ``json_path`` alone, in
    this order of precedence. The base directory is the ``basedir`` of
    the single observer of the current run.

    Raises:
        ValueError: If none of the three sources is given, or if
            ``kaldi_cmd`` is neither ``ssh.pl`` nor ``run.pl`` when the
            base directory is set up.
        AssertionError: If a required path does not exist or the number
            of observers is not one.
    """
    assert Path(kaldi_root).exists(), kaldi_root

    assert len(ex.current_run.observers) == 1, (
        '`FileObserver` missing. Add a `FileObserver` with `-F foo/bar/`.')
    base_dir = Path(ex.current_run.observers[0].basedir)
    base_dir = base_dir.expanduser().resolve()

    if audio_dir is not None:
        audio_dir = Path(audio_dir).expanduser().resolve()
        assert audio_dir.exists(), audio_dir
        json_path = Path(json_path).expanduser().resolve()
        assert json_path.exists(), json_path
        db = JsonDatabase(json_path)
    elif kaldi_data_dir is not None:
        # No database is needed in this case; db stays undefined.
        kaldi_data_dir = Path(kaldi_data_dir).expanduser().resolve()
        assert kaldi_data_dir.exists(), kaldi_data_dir
        assert json_path is None, json_path
    elif json_path is not None:
        json_path = Path(json_path).expanduser().resolve()
        assert json_path.exists(), json_path
        db = JsonDatabase(json_path)
    else:
        raise ValueError('Either json_path, audio_dir or kaldi_data_dir has '
                         'to be defined.')

    if _config['model_egs_dir'] is None:
        model_egs_dir = kaldi_root / 'egs' / 'sms_wsj' / 's5'
    else:
        model_egs_dir = Path(_config['model_egs_dir']).expanduser().resolve()
    assert model_egs_dir.exists(), model_egs_dir

    dataset_names = _config['dataset_names']
    if not isinstance(dataset_names, (tuple, list)):
        dataset_names = [dataset_names]
    data_type = _config['data_type']
    if not isinstance(data_type, (tuple, list)):
        data_type = [data_type]

    kaldi_cmd = _config['kaldi_cmd']
    if not base_dir == model_egs_dir and not (base_dir / 'steps').exists():
        create_kaldi_dir(base_dir, model_egs_dir, exist_ok=True)
        if kaldi_cmd == 'ssh.pl':
            # Copy the node file to <base_dir>/.queue/machines
            CCS_NODEFILE = Path(os.environ['CCS_NODEFILE'])
            (base_dir / '.queue').mkdir()
            (base_dir / '.queue' / 'machines').write_text(
                CCS_NODEFILE.read_text())
        elif kaldi_cmd == 'run.pl':
            pass
        else:
            raise ValueError(kaldi_cmd)

    for d_type in data_type:
        for dset in dataset_names:
            dataset_dir = base_dir / 'data' / d_type / dset
            if audio_dir is not None:
                assert len(data_type) == 1, data_type
                create_dir(audio_dir, base_dir=base_dir, db=db,
                           dataset_names=dset)
            elif kaldi_data_dir is None:
                create_data_dir(
                    base_dir, db=db, data_type=d_type,
                    dataset_names=dset,
                    ref_channels=_config['ref_channels'],
                    target_speaker=_config['target_speaker'])
            else:
                assert len(data_type) == 1, (
                    'when using a predefined kaldi_data_dir not more then one '
                    'data_type should be defined. Better use the decode '
                    'command directly')
                copytree(kaldi_data_dir / dset, dataset_dir, symlinks=True)
                run_process(['utils/fix_data_dir.sh', f'{dataset_dir}'],
                            cwd=str(base_dir), stdout=None, stderr=None)

            decode(base_dir=base_dir, model_egs_dir=model_egs_dir,
                   dataset_dir=dataset_dir,
                   model_dir=check_config_element(_config['model_dir']),
                   ivector_dir=check_config_element(_config['ivector_dir']),
                   extractor_dir=check_config_element(
                       _config['extractor_dir']),
                   data_type=d_type)
def main(_run, batch_size, datasets, debug, experiment_dir, database_json,
         _log):
    """Evaluate the model and dump per-example results and their means.

    The examples are distributed with ``dlp_mpi.split_managed``. The
    master process merges the partial results, checks that every dataset
    is complete and writes ``result.json`` and ``means.json`` to
    ``experiment_dir``.

    Raises:
        AssertionError: If the number of results of a dataset differs
            from the length of that dataset.
    """
    experiment_dir = Path(experiment_dir)

    if dlp_mpi.IS_MASTER:
        sacred.commands.print_config(_run)

    model = get_model()
    db = JsonDatabase(json_path=database_json)

    model.eval()
    with torch.no_grad():
        summary = defaultdict(dict)
        for dataset_name in datasets:
            dataset = prepare_dataset(
                db, dataset_name, batch_size,
                return_keys=None,
                prefetch=False,
                shuffle=False,
            )

            for batch in dlp_mpi.split_managed(
                    dataset, is_indexable=True,
                    progress_bar=True,
                    allow_single_worker=debug,
            ):
                entry = dict()
                model_output = model(model.example_to_device(batch))

                # Only the first entry of each batch field is evaluated.
                example_id = batch['example_id'][0]
                s = batch['s'][0]
                Y = batch['Y'][0]
                mask = model_output[0].numpy()

                # Apply the mask and transform back to the time domain
                Z = mask * Y[:, None, :]
                z = istft(
                    einops.rearrange(Z, "t k f -> k t f"),
                    size=512, shift=128,
                )

                # Cut reference and estimate to a common length
                s = s[:, :z.shape[1]]
                z = z[:, :s.shape[1]]

                input_metrics = pb_bss.evaluation.InputMetrics(
                    observation=batch['y'][0][None, :],
                    speech_source=s,
                    sample_rate=8000,
                    enable_si_sdr=False,
                )
                output_metrics = pb_bss.evaluation.OutputMetrics(
                    speech_prediction=z,
                    speech_source=s,
                    sample_rate=8000,
                    enable_si_sdr=False,
                )
                entry['input'] = dict(mir_eval=input_metrics.mir_eval)
                # 'selection' is stored separately because it is not part
                # of the output/input difference below.
                entry['output'] = dict(mir_eval={
                    k: v for k, v in output_metrics.mir_eval.items()
                    if k != 'selection'
                })
                entry['improvement'] = pb.utils.nested.nested_op(
                    operator.sub, entry['output'], entry['input'],
                )
                entry['selection'] = output_metrics.mir_eval['selection']

                # Key by the dataset name, not by the dataset object:
                # the name is what is looked up in the database below.
                summary[dataset_name][example_id] = entry

    summary_list = dlp_mpi.gather(summary, root=dlp_mpi.MASTER)

    if dlp_mpi.IS_MASTER:
        _log.info(f'len(summary_list): {len(summary_list)}')
        summary = pb.utils.nested.nested_merge(*summary_list)

        for dataset, values in summary.items():
            _log.info(f'{dataset}: {len(values)}')
            assert len(values) == len(
                db.get_dataset(dataset)
            ), 'Number of results needs to match length of dataset!'

        result_json_path = experiment_dir / 'result.json'
        _log.info(f"Exporting result: {result_json_path}")
        pb.io.dump_json(summary, result_json_path)

        # Compute and save mean of metrics
        means = compute_means(summary)
        mean_json_path = experiment_dir / 'means.json'
        _log.info(f"Saving means to: {mean_json_path}")
        pb.io.dump_json(means, mean_json_path)
def main(_run, datasets, debug, experiment_dir, export_audio, sample_rate,
         _log, database_json, oracle_num_spk, max_iterations):
    """Decode all datasets, compute metrics and dump results and means.

    Every process handles the slice ``mpi.RANK::mpi.SIZE`` of each
    dataset (limited to the first 20 examples when ``debug`` is set).
    The master process combines the partial summaries and writes
    ``result.json`` and ``means.json`` to ``experiment_dir``. With
    ``export_audio`` the estimated signals are written below
    ``experiment_dir / 'audio' / dataset``.
    """
    experiment_dir = Path(experiment_dir)

    if mpi.IS_MASTER:
        sacred.commands.print_config(_run)
        dump_config_and_makefile()

    model = get_model()
    db = JsonDatabase(database_json)

    model.eval()
    with torch.no_grad():
        summary = defaultdict(dict)
        for dataset in datasets:
            # Batch size 1; each process gets its own strided slice
            iterable = prepare_iterable(
                db, dataset, 1,
                chunk_size=-1,
                prefetch=False,
                shuffle=False,
                iterator_slice=slice(mpi.RANK, 20 if debug else None, mpi.SIZE),
            )

            if export_audio:
                (experiment_dir / 'audio' / dataset).mkdir(parents=True, exist_ok=True)

            for batch in tqdm(
                    iterable, total=len(iterable), disable=not mpi.IS_MASTER,
                    desc=dataset,
            ):
                example_id = batch['example_id'][0]
                # The entry is registered first and filled in place below
                summary[dataset][example_id] = entry = dict()
                oracle_speaker_count = \
                    entry['oracle_speaker_count'] = batch['s'][0].shape[0]

                try:
                    # The oracle speaker count is only passed on request
                    model_output = model.decode(
                        pt.data.example_to_device(batch),
                        max_iterations=max_iterations,
                        oracle_num_speakers=oracle_speaker_count if oracle_num_spk else None)

                    # Bring to numpy float64 for evaluation metrics computation
                    s = batch['s'][0].astype(np.float64)
                    z = model_output['out'][0].cpu().numpy().astype(np.float64)

                    estimated_speaker_count = \
                        entry['estimated_speaker_count'] = z.shape[0]
                    entry['source_counting_accuracy'] = \
                        estimated_speaker_count == oracle_speaker_count

                    if oracle_speaker_count == estimated_speaker_count:
                        # These evaluations don't work if the number of
                        # speakers in s and z don't match
                        entry['mir_eval'] = pb_bss.evaluation.mir_eval_sources(
                            s, z, return_dict=True)

                        # Get the correct order for si_sdr and saving
                        z = z[entry['mir_eval']['selection']]
                        entry['si_sdr'] = pb_bss.evaluation.si_sdr(s, z)
                    else:
                        warnings.warn(
                            'The number of speakers is estimated incorrectly '
                            'for some examples! The calculated SDR values '
                            'might not be representative!')

                    if export_audio:
                        # One wav file per estimated signal
                        entry['audio_path'] = batch['audio_path']
                        entry['audio_path'].setdefault('estimated', [])

                        for k, audio in enumerate(z):
                            audio_path = (experiment_dir / 'audio' / dataset / f'{example_id}_{k}.wav')
                            pb.io.dump_audio(audio, audio_path, sample_rate=sample_rate)
                            entry['audio_path']['estimated'].append(audio_path)
                except:
                    # Log which example failed, then propagate the error
                    _log.error(f'Exception was raised in example with ID '
                               f'"{example_id}"')
                    raise

    summary_list = mpi.gather(summary, root=mpi.MASTER)

    if mpi.IS_MASTER:
        # Combine all summaries to one
        for partial_summary in summary_list:
            for dataset, values in partial_summary.items():
                summary[dataset].update(values)

        for dataset, values in summary.items():
            _log.info(f'{dataset}: {len(values)}')

        # Write summary to JSON
        result_json_path = experiment_dir / 'result.json'
        _log.info(f"Exporting result: {result_json_path}")
        pb.io.dump_json(summary, result_json_path)

        # Compute means for some metrics
        mean_keys = [
            'mir_eval.sdr', 'mir_eval.sar', 'mir_eval.sir', 'si_sdr',
            'source_counting_accuracy'
        ]
        means = {}
        for dataset, dataset_results in summary.items():
            means[dataset] = {}
            # Flatten each entry so that the dotted mean_keys can be used
            flattened = {
                k: pb.utils.nested.flatten(v) for k, v in dataset_results.items()
            }
            for mean_key in mean_keys:
                try:
                    means[dataset][mean_key] = np.mean(
                        np.array([v[mean_key] for v in flattened.values()]))
                except KeyError:
                    # Raised when an entry lacks the key, e.g. entries
                    # with a wrong speaker count have no mir_eval values
                    warnings.warn(f'Couldn\'t compute mean for {mean_key}.')
            means[dataset] = pb.utils.nested.deflatten(means[dataset])

        mean_json_path = experiment_dir / 'means.json'
        _log.info(f'Exporting means: {mean_json_path}')
        pb.io.dump_json(means, mean_json_path)

        _log.info('Resulting means:')
        pprint(means)
def evaluate(checkpoint_path, eval_dir, database_json):
    """Evaluate a checkpoint on the test dataset and write result.json.

    For every example three signals are scored against the speech
    source: the masked observation, the beamformed signal, and the
    unprocessed observation. The master process merges the results of
    all processes, adds the mean of every metric and dumps the summary
    to ``eval_dir / 'result.json'``.

    NOTE(review): ``checkpoint_path`` and ``eval_dir`` are used with
    ``.resolve()`` and ``/``, so they presumably have to be
    ``pathlib.Path`` objects - confirm against the caller.
    """
    model = SimpleMaskEstimator(513)

    model.load_checkpoint(
        checkpoint_path=checkpoint_path,
        in_checkpoint_path='model',
        consider_mpi=True
    )
    model.eval()
    if dlp_mpi.IS_MASTER:
        print(f'Start to evaluate the checkpoint {checkpoint_path.resolve()} '
              f'and will write the evaluation result to'
              f' {eval_dir / "result.json"}')
    database = JsonDatabase(database_json)
    test_dataset = get_test_dataset(database)
    with torch.no_grad():
        # The key order must match the signal list zipped against it below
        summary = dict(masked=dict(), beamformed=dict(), observed=dict())
        for batch in dlp_mpi.split_managed(
                test_dataset, is_indexable=True,
                progress_bar=True,
                allow_single_worker=True
        ):
            model_output = model(pt.data.example_to_device(batch))

            example_id = batch['example_id']
            s = batch['speech_source'][0][None]

            # Masking: first mask applied to the first observation entry
            speech_mask = model_output['speech_mask_prediction'].numpy()
            Y = batch['observation_stft']
            Z_mask = speech_mask[0] * Y[0]
            z_mask = pb.transform.istft(Z_mask)[None]

            # Beamforming: masks are reduced with a median over axis 0
            speech_mask = np.median(speech_mask, axis=0).T
            noise_mask = model_output['noise_mask_prediction'].numpy()
            noise_mask = np.median(noise_mask, axis=0).T
            Y = rearrange(Y, 'c t f -> f c t')
            target_psd = pb_bss.extraction.get_power_spectral_density_matrix(
                Y, speech_mask,
            )
            noise_psd = pb_bss.extraction.get_power_spectral_density_matrix(
                Y, noise_mask,
            )
            beamformer = pb_bss.extraction.get_bf_vector(
                'mvdr_souden',
                target_psd_matrix=target_psd,
                noise_psd_matrix=noise_psd
            )
            Z_bf = pb_bss.extraction.apply_beamforming_vector(beamformer, Y).T
            z_bf = pb.transform.istft(Z_bf)[None]

            y = batch['observation'][0][None]
            # Cut the reference to the length of the beamformed signal
            s = s[:, :z_bf.shape[1]]
            for key, signal in zip(summary.keys(), [z_mask, z_bf, y]):
                signal = signal[:, :s.shape[1]]
                entry = pb_bss.evaluation.OutputMetrics(
                    speech_prediction=signal, speech_source=s,
                    sample_rate=16000
                ).as_dict()
                entry.pop('mir_eval_selection')
                summary[key][example_id] = entry

    summary_list = dlp_mpi.COMM.gather(summary, root=dlp_mpi.MASTER)

    if dlp_mpi.IS_MASTER:
        print(f'\n len(summary_list): {len(summary_list)}')
        # Merge the partial summaries of all processes
        summary = dict(masked=dict(), beamformed=dict(), observed=dict())
        for partial_summary in summary_list:
            for signal_type, metric in partial_summary.items():
                summary[signal_type].update(metric)
        for signal_type, values in summary.items():
            print(signal_type)
            # The metric names are taken from the first entry
            for metric in next(iter(values.values())).keys():
                # Skip the '<metric>_mean' keys added in earlier iterations
                mean = np.mean([value[metric] for key, value in values.items()
                                if '_mean' not in key])
                values[metric + '_mean'] = mean
                print(f'{metric}: {mean}')

        result_json_path = eval_dir / 'result.json'
        print(f"Exporting result: {result_json_path}")
        pb.io.dump_json(summary, result_json_path)
def get_test_dataset(database: JsonDatabase):
    """Return the ``et05_simu`` dataset with ``prepare_data`` mapped on it."""
    return database.get_dataset('et05_simu').map(prepare_data)
def main(json_path: Path, rir_dir: Path, wsj_json_path: Path, num_speakers):
    """Combine RIR scenarios with WSJ utterances and write a database JSON.

    For each of the datasets ``train_si284``, ``cv_dev93`` and
    ``test_eval92`` the source utterances are filtered with
    ``filter_punctuation_pronunciation``, tiled to the number of RIR
    examples and greedily composed into ``num_speakers``-speaker mixtures.
    The resulting examples are stored under
    ``target_db['datasets'][dataset_name]`` and dumped to ``json_path``.

    Args:
        json_path: Destination JSON file; must not exist yet.
        rir_dir: Directory holding ``scenarios.json`` and the RIR wav files.
        wsj_json_path: JSON file describing the WSJ source database.
        num_speakers: Number of speakers per generated example.

    Raises:
        FileExistsError: If ``json_path`` already exists.
        AssertionError: If an input path is missing, the dataset lengths
            are not compatible, too few RIRs are available or the sample
            rates of RIRs and source audio differ.
    """
    wsj_json_path = Path(wsj_json_path).expanduser().resolve()
    json_path = Path(json_path).expanduser().resolve()
    if json_path.exists():
        raise FileExistsError(json_path)
    rir_dir = Path(rir_dir).expanduser().resolve()
    # Report the path that was actually checked (the original message
    # showed json_path here).
    assert wsj_json_path.is_file(), wsj_json_path
    assert rir_dir.exists(), rir_dir

    setup = dict(
        train_si284=dict(source_dataset_name="train_si284"),
        cv_dev93=dict(source_dataset_name="cv_dev93"),
        test_eval92=dict(source_dataset_name="test_eval92"),
    )

    rir_db = JsonDatabase(rir_dir / "scenarios.json")
    source_db = JsonDatabase(wsj_json_path)

    target_db = dict()
    target_db['datasets'] = defaultdict(dict)

    for dataset_name, dataset_setup in setup.items():
        source_dataset_name = dataset_setup["source_dataset_name"]
        source_iterator = source_db.get_dataset(source_dataset_name)
        print(f'length of source {dataset_name}: {len(source_iterator)}')
        source_iterator = source_iterator.filter(
            filter_fn=filter_punctuation_pronunciation, lazy=False)
        print(f'length of source {dataset_name}: {len(source_iterator)} '
              '(after punctuation filter)')

        rir_iterator = rir_db.get_dataset(dataset_name)
        assert len(rir_iterator) % len(source_iterator) == 0, (
            f'To avoid a bias towards certain utterance the len '
            f'rir_iterator ({len(rir_iterator)}) should be an integer '
            f'multiple of len source_iterator ({len(source_iterator)}).')
        print(f'length of rir {dataset_name}: {len(rir_iterator)}')

        # The first RIR example is used as a probe for the number of
        # available speaker positions and for the RIR sample rate.
        probe_path = rir_dir / dataset_name / "0"
        available_speaker_positions = len(list(probe_path.glob('h_*.wav')))
        assert num_speakers <= available_speaker_positions, (
            f'Requested {num_speakers} num_speakers, while found only '
            f'{available_speaker_positions} rirs in {probe_path}.')

        info = soundfile.info(str(rir_dir / dataset_name / "0" / "h_0.wav"))
        sample_rate_rir = info.samplerate

        ex_wsj = source_iterator.random_choice(1)[0]
        # Close the probe file again instead of leaking the handle.
        with soundfile.SoundFile(ex_wsj['audio_path']['observation']) as f:
            sample_rate_wsj = f.samplerate
        assert sample_rate_rir == sample_rate_wsj, (sample_rate_rir,
                                                    sample_rate_wsj)

        # RIR example ids are sorted numerically, source ids with the
        # default order.
        rir_iterator = rir_iterator.sort(
            sort_fn=functools.partial(sorted, key=int))
        source_iterator = source_iterator.sort()
        assert len(rir_iterator) % len(source_iterator) == 0
        repeats = len(rir_iterator) // len(source_iterator)
        source_iterator = source_iterator.tile(repeats)

        speaker_ids = [example['speaker_id'] for example in source_iterator]

        rng = get_rng(dataset_name, 'example_compositions')

        # Add one speaker per round to every composition.
        example_compositions = None
        for _ in range(num_speakers):
            example_compositions = extend_example_composition_greedy(
                rng, speaker_ids, example_compositions=example_compositions,
            )

        ex_dict = dict()
        assert len(rir_iterator) == len(example_compositions)
        for rir_example, example_composition in zip(
                rir_iterator, example_compositions):
            source_examples = source_iterator[example_composition]

            example_id = "_".join([
                rir_example['example_id'],
                *[ex["example_id"] for ex in source_examples],
            ])

            # A dedicated generator per example, derived from its id.
            rng = get_rng(dataset_name, example_id)
            example = get_randomized_example(
                rir_example,
                source_examples,
                rng,
                dataset_name,
                rir_dir,
            )
            ex_dict[example_id] = example

        target_db['datasets'][dataset_name] = ex_dict

    json_path.parent.mkdir(exist_ok=True, parents=True)
    with json_path.open('w') as f:
        json.dump(target_db, f, indent=2, ensure_ascii=False)
    print(f'{json_path} written')
class DataProvider(Configurable):
    """Build training and validation data pipelines from a JSON database.

    The pipeline assembled by ``get_dataset`` is:

    1. ``get_raw``: fetch raw datasets and filter examples,
    2. ``_load_audio``: map ``audio_reader`` over them (optionally cached),
    3. ``prepare_audio``: tile/intersperse the datasets, optionally repeat
       examples of rare labels, and scale/mix training examples,
    4. ``segment_transform_and_fetch``: apply segmenter, transform and
       fetcher.

    NOTE(review): the class relies on ``__post_init__`` and annotated class
    attributes, so it is presumably decorated as a dataclass outside this
    view - confirm.
    """
    json_path: str
    audio_reader: Callable
    # Dataset name(s) or raw dataset(s); a dict maps name -> repetitions
    # (see _load_audio / get_raw).
    train_set: dict
    validate_set: str = None
    # Names of datasets whose audio is cached after loading.
    cached_datasets: list = None
    # Examples whose 'audio_length' is not greater than this are dropped.
    min_audio_length: float = 1.
    # NOTE(review): annotated as float but passed to ``dataset.map`` in
    # segment_transform_and_fetch, i.e. used as a callable - confirm.
    train_segmenter: float = None
    test_segmenter: float = None
    train_transform: Callable = None
    test_transform: Callable = None
    train_fetcher: Callable = None
    test_fetcher: Callable = None
    label_key: str = 'events'
    discard_labelless_train_examples: bool = True
    storage_dir: str = None
    # augmentation
    min_class_examples_per_epoch: int = 0
    scale_sampling_fn: Callable = None
    mix_interval: float = 1.5
    mix_fn: Callable = None

    def __post_init__(self):
        """Open the JSON database at ``self.json_path`` as ``self.db``."""
        assert self.json_path is not None
        self.db = JsonDatabase(json_path=self.json_path)

    def get_train_set(self, filter_example_ids=None):
        """Return the full training pipeline for ``self.train_set``."""
        return self.get_dataset(self.train_set, train=True,
                                filter_example_ids=filter_example_ids)

    def get_validate_set(self, filter_example_ids=None):
        """Return the full validation pipeline for ``self.validate_set``."""
        return self.get_dataset(self.validate_set, train=False,
                                filter_example_ids=filter_example_ids)

    def get_dataset(self, dataset_names_or_raw_datasets, train=False,
                    filter_example_ids=None):
        """Prepare audio and then segment, transform and fetch it.

        Args:
            dataset_names_or_raw_datasets: A dataset name, a
                ``lazy_dataset.Dataset``, or a dict/list/tuple of those
                (see ``_load_audio`` for the accepted forms).
            train: Selects the train or test variants of segmenter,
                transform and fetcher and enables training-only steps.
            filter_example_ids: Optional container of example ids to
                exclude.
        """
        ds = self.prepare_audio(dataset_names_or_raw_datasets, train=train,
                                filter_example_ids=filter_example_ids)
        ds = self.segment_transform_and_fetch(ds, train=train)
        return ds

    def prepare_audio(self, dataset_names_or_raw_datasets, train=False,
                      filter_example_ids=None):
        """Load audio and combine all datasets into a single dataset.

        When training with ``min_class_examples_per_epoch > 0`` the
        examples are regrouped by label-dependent repetition factors
        before being interspersed. Training data is additionally passed
        through ``scale_and_mix``. Prints the resulting dataset length.
        """
        individual_audio_datasets = self._load_audio(
            dataset_names_or_raw_datasets, train=train,
            filter_example_ids=filter_example_ids)
        # Normalize a single dataset to the list-of-(dataset, reps) form.
        if not isinstance(individual_audio_datasets, list):
            assert isinstance(
                individual_audio_datasets,
                lazy_dataset.Dataset), type(individual_audio_datasets)
            individual_audio_datasets = [(individual_audio_datasets, 1)]
        combined_audio_dataset = self._tile_and_intersperse(
            individual_audio_datasets, shuffle=train)

        if train and self.min_class_examples_per_epoch > 0:
            assert self.label_key is not None
            # NOTE(review): for a single (non-list) input ``get_raw``
            # returns a flat dataset, so ``labels`` has one entry per
            # example while ``individual_audio_datasets`` has length 1;
            # the length assertion in _build_repetition_groups would then
            # fail - confirm that only list/dict inputs are used here.
            raw_datasets = self.get_raw(
                dataset_names_or_raw_datasets,
                discard_labelless_examples=self.
                discard_labelless_train_examples,
                filter_example_ids=filter_example_ids,
            )
            label_counts, labels = self._count_labels(raw_datasets,
                                                      self.label_key)
            label_reps = self._compute_label_repetitions(
                label_counts, min_counts=self.min_class_examples_per_epoch)
            repetition_groups = self._build_repetition_groups(
                individual_audio_datasets, labels, label_reps)
            dataset = self._tile_and_intersperse(repetition_groups,
                                                 shuffle=train)
        else:
            dataset = combined_audio_dataset

        if train:
            # dataset = self.scale_and_mix(dataset, combined_audio_dataset)
            dataset = self.scale_and_mix(dataset, dataset)
        print(f'Total data set length:', len(dataset))
        return dataset

    def _load_audio(self, dataset_names_or_raw_datasets, train=False,
                    filter_example_ids=None, idx=None):
        """Load raw dataset(s) and map ``self.audio_reader`` over them.

        For a dict/list/tuple input a list of ``(dataset, num_reps)``
        tuples is returned. The repetition count comes from the dict
        value, from the second element of a list/tuple entry, or defaults
        to 1; entries with 0 repetitions are skipped. For a single name or
        dataset the mapped dataset itself is returned.

        Args:
            idx: Position inside the enclosing collection; only used in
                the printed dataset description.
        """
        if isinstance(dataset_names_or_raw_datasets, (dict, list, tuple)):
            ds = []
            # Iterating a dict yields its keys (the dataset names).
            for i, name_or_ds in enumerate(dataset_names_or_raw_datasets):
                num_reps = (
                    dataset_names_or_raw_datasets[name_or_ds] if isinstance(
                        dataset_names_or_raw_datasets, dict) else
                    name_or_ds[1] if isinstance(name_or_ds, (list, tuple))
                    else 1)
                if num_reps == 0:
                    continue
                ds.append((self._load_audio(
                    name_or_ds[0] if isinstance(name_or_ds, (list, tuple))
                    else name_or_ds,
                    train=train,
                    filter_example_ids=filter_example_ids,
                    idx=i,
                ), num_reps))
            return ds
        ds = self.get_raw(
            dataset_names_or_raw_datasets,
            discard_labelless_examples=(train and
                                        self.discard_labelless_train_examples),
            filter_example_ids=filter_example_ids,
        ).map(self.audio_reader)
        # Only datasets given by name can be listed in cached_datasets.
        cache = (self.cached_datasets is not None
                 and isinstance(dataset_names_or_raw_datasets, str)
                 and dataset_names_or_raw_datasets in self.cached_datasets)
        if cache:
            ds = ds.cache(lazy=False)
        if isinstance(dataset_names_or_raw_datasets, str):
            ds_name = " " + dataset_names_or_raw_datasets
        else:
            ds_name = ""
        if idx is not None:
            ds_name += f" [{idx}]"
        print(f'Single data set length{ds_name}:', len(ds))
        return ds

    def get_raw(
            self, dataset_names_or_raw_datasets,
            discard_labelless_examples=False,
            filter_example_ids=None,
    ):
        """Return raw dataset(s) after example filtering, without audio.

        For a dict/list/tuple input a list of ``(dataset, num_reps)``
        tuples with ``num_reps > 0`` is returned (same repetition rules as
        in ``_load_audio``). Otherwise a single dataset is returned from
        which the following examples have been removed:

        * examples with a missing or falsy ``self.label_key`` entry, if
          ``discard_labelless_examples`` is set,
        * examples whose ``'example_id'`` is in ``filter_example_ids``,
        * examples without ``'audio_length'`` or with an ``'audio_length'``
          not greater than ``self.min_audio_length``.
        """
        if isinstance(dataset_names_or_raw_datasets, (dict, list, tuple)):
            return list(
                filter(lambda x: x[1] > 0, [(
                    self.get_raw(
                        name_or_ds[0] if isinstance(name_or_ds, (list, tuple))
                        else name_or_ds,
                        discard_labelless_examples=discard_labelless_examples,
                        filter_example_ids=filter_example_ids,
                    ),
                    (dataset_names_or_raw_datasets[name_or_ds] if isinstance(
                        dataset_names_or_raw_datasets, dict) else
                     name_or_ds[1] if isinstance(name_or_ds, (list, tuple))
                     else 1),
                ) for name_or_ds in dataset_names_or_raw_datasets]))
        elif isinstance(dataset_names_or_raw_datasets, str):
            ds = self.db.get_dataset(dataset_names_or_raw_datasets)
        else:
            assert isinstance(
                dataset_names_or_raw_datasets,
                lazy_dataset.Dataset), type(dataset_names_or_raw_datasets)
            ds = dataset_names_or_raw_datasets
        if discard_labelless_examples:
            ds = ds.filter(
                lambda ex: self.label_key in ex and ex[self.label_key],
                lazy=False)
        if filter_example_ids is not None:
            ds = ds.filter(
                lambda ex: ex['example_id'] not in filter_example_ids,
                lazy=False)
        return ds.filter(lambda ex: 'audio_length' in ex and ex['audio_length']
                         > self.min_audio_length,
                         lazy=False)

    @staticmethod
    def _tile_and_intersperse(datasets, shuffle=False):
        """Tile each ``(dataset, reps)`` pair and intersperse the results.

        With ``shuffle`` each dataset is first wrapped with
        ``shuffle(reshuffle=True)``.
        """
        if shuffle:
            datasets = [(ds.shuffle(reshuffle=True), reps)
                        for ds, reps in datasets]
        return lazy_dataset.intersperse(
            *[ds.tile(reps) for ds, reps in datasets])

    def scale_and_mix(self, dataset, mixin_dataset=None):
        """Apply random scaling and mixing to ``dataset``.

        If ``self.scale_sampling_fn`` is set, ``'audio_data'`` of every
        example (of both datasets) is multiplied by a freshly sampled
        factor. If ``self.mix_interval`` is set, the dataset is wrapped in
        a ``MixtureDataset`` that uses ``self.mix_fn``.

        Args:
            mixin_dataset: Dataset to mix in; defaults to ``dataset``.
        """
        if mixin_dataset is None:
            mixin_dataset = dataset
        if self.scale_sampling_fn is not None:
            def scale(example):
                w = self.scale_sampling_fn()
                example['audio_data'] = example['audio_data'] * w
                return example
            dataset = dataset.map(scale)
            mixin_dataset = mixin_dataset.map(scale)
        if self.mix_interval is not None:
            # mixin_dataset = mixin_dataset.tile(
            #     math.ceil(len(dataset)/len(combined_audio_dataset)))
            assert self.mix_fn is not None
            dataset = MixtureDataset(dataset, mixin_dataset,
                                     mix_interval=self.mix_interval,
                                     mix_fn=self.mix_fn)
        return dataset

    def _count_labels(self, raw_datasets, label_key, label_counts=None,
                      reps=1):
        """Count label occurrences, weighted by dataset repetitions.

        Returns:
            ``(label_counts, labels)``. ``label_counts`` maps a label to
            its accumulated count (each example contributes ``reps`` per
            distinct label). ``labels`` holds one sorted list of distinct
            labels per example; for a list of ``(dataset, reps)`` tuples
            it is a list of such per-dataset lists.
        """
        if label_counts is None:
            label_counts = defaultdict(lambda: 0)
        if isinstance(raw_datasets, list):
            labels = []
            for ds, ds_reps in raw_datasets:
                # The same label_counts mapping is updated across datasets.
                label_counts, cur_labels = self._count_labels(
                    ds, label_key, label_counts=label_counts,
                    reps=ds_reps * reps)
                labels.append(cur_labels)
            return label_counts, labels
        labels = []
        for example in raw_datasets:
            cur_labels = sorted(set(to_list(example[label_key])))
            labels.append(cur_labels)
            for label in cur_labels:
                label_counts[label] += reps
        # print(label_counts)
        return label_counts, labels

    @staticmethod
    def _compute_label_repetitions(label_counts, min_counts):
        """Compute a repetition factor for every label.

        Args:
            label_counts: Mapping from label to its count.
            min_counts: Integer greater than 1, or a float in ``(0, 1)``
                that is converted to ``ceil(max_count * min_counts)``.

        Returns:
            Dict mapping each label to ``ceil(min_counts / count)`` where
            ``min_counts`` has been scaled by ``base_rep`` (see below).
        """
        max_count = max(label_counts.values())
        if isinstance(min_counts, float):
            assert 0. < min_counts < 1., min_counts
            min_counts = math.ceil(max_count * min_counts)
        assert isinstance(min_counts, int) and min_counts > 1, min_counts
        assert min_counts - 1 <= 0.9 * max_count, (min_counts, max_count)
        # Floor division with a float operand: base_rep is a float, hence
        # min_counts is a float from here on.
        # NOTE(review): presumably this raises the target so that the most
        # frequent label also gets repeated when min_counts is close to
        # max_count - confirm the intended derivation.
        base_rep = 1 // (1 - (min_counts - 1) / max_count)
        min_counts *= base_rep
        label_repetitions = {
            label: math.ceil(min_counts / count)
            for label, count in label_counts.items()
        }
        return label_repetitions

    def _build_repetition_groups(self, dataset, labels, label_repetitions):
        """Split dataset(s) into groups of examples sharing a repetition.

        Each example gets the maximum repetition factor over its labels.
        Examples with the same factor form one sub-dataset.

        Returns:
            List of ``(sub_dataset, n_reps)`` tuples ordered by ascending
            ``n_reps``. For a list of ``(dataset, reps)`` inputs the groups
            of all datasets are concatenated and the repetitions
            multiplied.
        """
        assert len(dataset) == len(labels), (len(dataset), len(labels))
        if isinstance(dataset, list):
            return [(group_ds, ds_reps * group_reps)
                    for (ds, ds_reps), cur_labels in zip(dataset, labels)
                    for group_ds, group_reps in self._build_repetition_groups(
                        ds, cur_labels, label_repetitions)]
        # NOTE(review): max() raises ValueError for an example without
        # labels - confirm labelless examples are always discarded before.
        idx_reps = [
            max([label_repetitions[label] for label in idx_labels])
            for idx_labels in labels
        ]
        rep_groups = {}
        for n_reps in set(idx_reps):
            rep_groups[n_reps] = np.argwhere(
                np.array(idx_reps) == n_reps).flatten().tolist()
        datasets = []
        for n_reps, indices in sorted(rep_groups.items(), key=lambda x: x[0]):
            datasets.append((dataset[sorted(indices)], n_reps))
        # ds = lazy_dataset.intersperse(*datasets)
        return datasets

    def segment_transform_and_fetch(
            self, dataset, segment=True, transform=True, fetch=True,
            train=False,
    ):
        """Apply the configured segmenter, transform and fetcher.

        Args:
            dataset: Dataset to process.
            segment: Apply the segmenter; ignored if none is configured.
            transform: Apply the transform; uses ``batch_map`` when the
                data was segmented and ``map`` otherwise.
            fetch: Wrap the dataset with the fetcher, which is told via
                ``batched_input`` whether segmentation took place.
            train: Select the train (else test) components.
        """
        segmenter = self.train_segmenter if train else self.test_segmenter
        segment = segment and segmenter is not None
        if segment:
            dataset = dataset.map(segmenter)
        if transform:
            transform = self.train_transform if train else self.test_transform
            assert transform is not None
            if segment:
                dataset = dataset.batch_map(transform)
            else:
                dataset = dataset.map(transform)
        if fetch:
            fetcher = self.train_fetcher if train else self.test_fetcher
            assert fetcher is not None
            dataset = fetcher(dataset, batched_input=segment)
        return dataset

    @classmethod
    def finalize_dogmatic_config(cls, config):
        """Fill ``config`` with the defaults of all sub-components.

        Sets ``audio_reader``, ``train_transform``, ``test_transform``,
        ``train_fetcher``, ``test_fetcher`` and ``scale_sampling_fn``, and
        ``mix_fn`` if ``config['mix_interval']`` is not None. Reads
        ``config['storage_dir']`` for the label encoder.
        """
        config['audio_reader'] = {
            'factory': AudioReader,
            'source_sample_rate': None,
            'target_sample_rate': 16000,
            'average_channels': True,
            'normalization_domain': 'instance',
            'normalization_type': 'max',
            'alignment_keys': ['events'],
        }
        config['train_transform'] = {
            'factory': Transform,
            'stft': {
                'factory': STFT,
                'shift': 320,
                'window_length': 960,
                'size': 1024,
                'fading': 'half',
                'pad': True,
                'alignment_keys': ['events'],
            },
            'label_encoder': {
                'factory': MultiHotAlignmentEncoder,
                'label_key': 'events',
                'storage_dir': config['storage_dir'],
            },
            'anchor_sampling_fn': {
                'factory': Uniform,
                'low': 0.4,
                'high': 0.6,
            },
            'anchor_shift_sampling_fn': {
                'factory': Uniform,
                'low': -0.1,
                'high': 0.1,
            },
        }
        # The test transform reuses STFT and label encoder settings but has
        # no anchor sampling.
        config['test_transform'] = {
            'factory': Transform,
            'stft': config['train_transform']['stft'],
            'label_encoder': config['train_transform']['label_encoder'],
        }
        config['train_fetcher'] = {
            'factory': DataFetcher,
            'prefetch_workers': 16,
            'batch_size': 16,
            'max_padding_rate': .05,
            'drop_incomplete': True,
            'global_shuffle': False,  # already shuffled in prepare_audio
        }
        config['train_fetcher']['bucket_expiration'] = (
            2000 * config['train_fetcher']['batch_size'])
        # The test fetcher mirrors the train fetcher with twice the batch
        # size and without dropping incomplete batches.
        config['test_fetcher'] = {
            'factory': DataFetcher,
            'prefetch_workers': config['train_fetcher']['prefetch_workers'],
            'batch_size': 2 * config['train_fetcher']['batch_size'],
            'max_padding_rate': config['train_fetcher']['max_padding_rate'],
            'bucket_expiration': config['train_fetcher']['bucket_expiration'],
            'drop_incomplete': False,
            'global_shuffle': False,
        }
        config['scale_sampling_fn'] = {
            'factory': LogTruncatedNormal,
            'loc': 0.,
            'scale': 1.,
            'truncation': np.log(3.),
        }
        if config['mix_interval'] is not None:
            config['mix_fn'] = {
                'factory': SuperposeEvents,
                'min_overlap': 1.,
                'fade_length': config['train_transform']['stft']['window_length'],
                'label_key': 'events',
            }
def main(_run, exp_dir, storage_dir, database_json, test_set, max_examples,
         device):
    """Evaluate a trained model by its reconstruction RMSE and dump audio.

    Stage 1 (all MPI ranks): reconstruct every test example and collect
    the summed squared error per example.

    Stage 2 (master only): print the overall RMSE, write the sorted
    per-example RMSE to ``storage_dir / 'rmse.json'`` and write the
    reconstructions of the 10 best and 10 worst examples as wav files to
    ``storage_dir / 'audio'``.

    Args:
        _run: Experiment run object handed to ``commands.print_config``.
        exp_dir: Directory with ``config.json`` and the stored model.
        storage_dir: Output directory; ``storage_dir / 'audio'`` must not
            exist yet.
        database_json: Passed to ``JsonDatabase``.
        test_set: Name of the dataset to evaluate.
        max_examples: If not None, evaluate only this many examples drawn
            with a fixed-seed shuffle.
        device: Device the model and examples are moved to.
    """
    if IS_MASTER:
        commands.print_config(_run)

    exp_dir = Path(exp_dir)
    storage_dir = Path(storage_dir)
    audio_dir = storage_dir / 'audio'
    audio_dir.mkdir(parents=True)
    config = load_json(exp_dir / 'config.json')

    model = Model.from_storage_dir(exp_dir, consider_mpi=True)
    model.to(device)
    model.eval()

    db = JsonDatabase(database_json)
    test_data = db.get_dataset(test_set)
    if max_examples is not None:
        test_data = test_data.shuffle(
            rng=np.random.RandomState(0))[:max_examples]
    test_data = prepare_dataset(test_data,
                                audio_reader=config['audio_reader'],
                                stft=config['stft'],
                                max_length=None,
                                batch_size=1,
                                shuffle=True)

    # Entries are (example_id, summed squared error, number of samples).
    squared_err = list()
    with torch.no_grad():
        for example in split_managed(test_data,
                                     is_indexable=False,
                                     progress_bar=True,
                                     allow_single_worker=True):
            example = model.example_to_device(example, device)
            target = example['audio_data'].squeeze(1)
            x = model.feature_extraction(example['stft'], example['seq_len'])
            x = model.wavenet.infer(
                x.squeeze(1),
                chunk_length=80_000,
                chunk_overlap=16_000,
            )
            assert target.shape == x.shape, (target.shape, x.shape)
            squared_err.extend(
                (ex_id, mse.cpu().detach().numpy(), x.shape[1])
                for ex_id, mse in zip(example['example_id'],
                                      ((x - target)**2).sum(1))
            )

    squared_err_list = COMM.gather(squared_err, root=MASTER)

    if IS_MASTER:
        print(f'\nlen(squared_err_list): {len(squared_err_list)}')
        squared_err = []
        for rank_errors in squared_err_list:
            squared_err.extend(rank_errors)
        _, err, t = zip(*squared_err)
        print('rmse:', np.sqrt(np.sum(err) / np.sum(t)))
        rmse = sorted([(ex_id, np.sqrt(ex_err / ex_len))
                       for ex_id, ex_err, ex_len in squared_err],
                      key=lambda x: x[1])
        dump_json(rmse, storage_dir / 'rmse.json', indent=4, sort_keys=False)

        ex_ids_ordered = [x[0] for x in rmse]
        # Build the lookup once instead of concatenating two lists for
        # every example inside the filter.
        selected_ids = set(ex_ids_ordered[:10] + ex_ids_ordered[-10:])
        # Use the evaluated dataset here; the ids above stem from
        # ``test_set``, so a hard-coded dataset name would select nothing
        # (or the wrong examples) for any other test set.
        test_data = db.get_dataset(test_set).shuffle(
            rng=np.random.RandomState(0))[:max_examples].filter(
                lambda x: x['example_id'] in selected_ids, lazy=False)
        test_data = prepare_dataset(test_data,
                                    audio_reader=config['audio_reader'],
                                    stft=config['stft'],
                                    max_length=10.,
                                    batch_size=1,
                                    shuffle=True)
        with torch.no_grad():
            for example in test_data:
                example = model.example_to_device(example, device)
                x = model.feature_extraction(example['stft'],
                                             example['seq_len'])
                x = model.wavenet.infer(
                    x.squeeze(1),
                    chunk_length=80_000,
                    chunk_overlap=16_000,
                )
                for i, audio in enumerate(x.cpu().detach().numpy()):
                    wavfile.write(
                        str(audio_dir / f'{example["example_id"][i]}.wav'),
                        model.sample_rate, audio)
def main(_run, datasets, debug, experiment_dir, dump_audio, sample_rate, _log,
         database_json):
    """Evaluate the model on ``datasets`` and export metrics as JSON.

    Every MPI rank processes its own slice of each dataset. Per example
    the input metrics, output metrics, their difference (``improvement``)
    and the ``selection`` of the output ``mir_eval`` result are stored. The
    master merges the results of all ranks and writes ``result.json`` and
    ``means.json`` into ``experiment_dir``.

    Args:
        _run: Experiment run object handed to ``sacred.commands``.
        datasets: Iterable of dataset names to evaluate.
        debug: If true, only indices below 20 are evaluated.
        experiment_dir: Directory for results (and audio, if requested).
        dump_audio: If true, write the estimated signals as wav files.
        sample_rate: Sample rate for metrics and dumped audio.
        _log: Logger.
        database_json: Passed to ``JsonDatabase``.
    """
    experiment_dir = Path(experiment_dir)

    if dlp_mpi.IS_MASTER:
        sacred.commands.print_config(_run)
        dump_config_and_makefile()

    model = get_model()
    db = JsonDatabase(database_json)
    model.eval()

    # Every rank takes each SIZE-th example, starting at its own RANK.
    rank_slice = slice(dlp_mpi.RANK, 20 if debug else None, dlp_mpi.SIZE)

    results = defaultdict(dict)
    with torch.no_grad():
        for dataset in datasets:
            examples = prepare_dataset(
                db, dataset, 1,
                chunk_size=-1,
                prefetch=False,
                shuffle=False,
                dataset_slice=rank_slice,
            )

            if dump_audio:
                (experiment_dir / 'audio' / dataset).mkdir(
                    parents=True, exist_ok=True)

            progress = tqdm(examples, total=len(examples),
                            disable=not dlp_mpi.IS_MASTER)
            for batch in progress:
                example_id = batch['example_id'][0]
                entry = {}
                results[dataset][example_id] = entry

                try:
                    model_output = model(model.example_to_device(batch))

                    # The evaluation metrics are computed on float64 numpy
                    # arrays.
                    observation = batch['y'][0].astype(np.float64)[None, ]
                    speech_prediction = (
                        model_output['out'][0].cpu().numpy().astype(
                            np.float64))
                    speech_source = batch['s'][0].astype(np.float64)

                    shared_kwargs = dict(
                        speech_source=speech_source,
                        sample_rate=sample_rate,
                        enable_si_sdr=True,
                    )
                    input_metrics = pb_bss.evaluation.InputMetrics(
                        observation=observation, **shared_kwargs)
                    output_metrics = pb_bss.evaluation.OutputMetrics(
                        speech_prediction=speech_prediction, **shared_kwargs)

                    # Metrics to compute.
                    # TODO: stoi fails with short speech segments
                    #  (https://github.com/mpariente/pystoi/issues/21)
                    # TODO: pesq creates "Processing error" messages
                    entry['input'] = {
                        'mir_eval': input_metrics.mir_eval,
                        'si_sdr': input_metrics.si_sdr,
                    }

                    # 'selection' is dropped from the mir_eval dict so the
                    # improvement can be computed recursively.
                    entry['output'] = {
                        'mir_eval': {
                            name: value
                            for name, value in output_metrics.mir_eval.items()
                            if name != 'selection'
                        },
                        'si_sdr': output_metrics.si_sdr,
                    }

                    entry['improvement'] = pb.utils.nested.nested_op(
                        operator.sub, entry['output'], entry['input'],
                    )
                    entry['selection'] = output_metrics.mir_eval['selection']

                    if dump_audio:
                        entry['audio_path'] = batch['audio_path']
                        estimated_paths = entry['audio_path']['estimated'] = []
                        for k, audio in enumerate(speech_prediction):
                            audio_path = (experiment_dir / 'audio' / dataset
                                          / f'{example_id}_{k}.wav')
                            pb.io.dump_audio(audio, audio_path,
                                             sample_rate=sample_rate)
                            estimated_paths.append(audio_path)
                except:
                    # Name the failing example, then propagate unchanged.
                    _log.error(f'Exception was raised in example with ID '
                               f'"{example_id}"')
                    raise

    results = dlp_mpi.gather(results, root=dlp_mpi.MASTER)

    if dlp_mpi.IS_MASTER:
        # Combine the results of all ranks. This function raises an
        # exception if it finds duplicate keys.
        results = pb.utils.nested.nested_merge(*results)

        for dataset, values in results.items():
            _log.info(f'{dataset}: {len(values)}')

        # Per-example results
        result_json_path = experiment_dir / 'result.json'
        _log.info(f"Exporting result: {result_json_path}")
        pb.io.dump_json(results, result_json_path)

        # Means over the examples
        means = compute_means(results)
        mean_json_path = experiment_dir / 'means.json'
        _log.info(f'Exporting means: {mean_json_path}')
        pb.io.dump_json(means, mean_json_path)

        _log.info('Resulting means:')
        pprint(means)
def _get_abs_max(a):
    """Return the largest absolute value found in a nested audio structure.

    Accepts numeric arrays, object arrays, tuples, lists and dicts (values
    only), nested arbitrarily.

    Raises:
        TypeError: For any other leaf or container type.
    """
    if isinstance(a, np.ndarray):
        # ``np.object`` was removed in NumPy 1.24; the builtin ``object``
        # compares equal to the object dtype on every NumPy version.
        if a.dtype == object:
            return np.max(list(map(_get_abs_max, a)))
        return np.max(np.abs(a))
    if isinstance(a, (tuple, list)):
        return np.max(list(map(_get_abs_max, a)))
    if isinstance(a, dict):
        return np.max(list(map(_get_abs_max, a.values())))
    raise TypeError(a)


def write_wavs(dst_dir, json_path, write_all=False, snr_range=(20, 30)):
    """Generate the signals described in the database and write wav files.

    Files are written to ``dst_dir / KEY_MAPPER[key] / dataset`` as float
    wav files with a sample rate of 8000. The work is distributed over
    the MPI ranks; the master creates the directories and finally checks
    the number of created files.

    Args:
        dst_dir: Existing destination directory (``pathlib.Path``).
        json_path: Passed to ``JsonDatabase``.
        write_all: If true, also generate and write the early and tail
            reverberation and the noise image; otherwise those entries are
            removed before writing.
        snr_range: Forwarded to ``scenario_map_fn``.

    Raises:
        FileExistsError: If a target directory already exists.
        AssertionError: If all signals of an example are zero or the number
            of written files differs from the expected one.
    """
    db = JsonDatabase(json_path)
    write_all_flag = bool(write_all)

    if dlp_mpi.IS_MASTER:
        if write_all_flag:
            for data_type in KEY_MAPPER.values():
                (dst_dir / data_type).mkdir(exist_ok=False)
        else:
            (dst_dir / 'observation').mkdir(exist_ok=False)

    map_fn = partial(scenario_map_fn, snr_range=snr_range,
                     sync_speech_source=True,
                     add_speech_reverberation_early=write_all_flag,
                     add_speech_reverberation_tail=write_all_flag)

    for dataset in ['train_si284', 'cv_dev93', 'test_eval92']:
        if dlp_mpi.IS_MASTER:
            # NOTE(review): without write_all only 'observation' was
            # created above, so creating sub-directories for the other
            # data types fails unless their parents already exist -
            # confirm the intended behavior.
            for data_type in KEY_MAPPER.values():
                (dst_dir / data_type / dataset).mkdir(exist_ok=False)

        ds = db.get_dataset(dataset).map(audio_read).map(map_fn)
        for example in dlp_mpi.split_managed(
                ds, is_indexable=True,
                allow_single_worker=True, progress_bar=True,
        ):
            audio_dict = example['audio_data']
            example_id = example['example_id']
            if not write_all:
                del audio_dict['speech_reverberation_early']
                del audio_dict['speech_reverberation_tail']
                del audio_dict['noise_image']

            # Fails if every signal of the example is identically zero.
            assert _get_abs_max(audio_dict), (example_id, {
                k: _get_abs_max(v)
                for k, v in audio_dict.items()
            })

            for key, value in audio_dict.items():
                if key not in KEY_MAPPER:
                    continue
                path = dst_dir / KEY_MAPPER[key] / dataset
                if key in ['observation', 'noise_image']:
                    # Single signals get a leading axis so they can be
                    # handled like the per-speaker signals.
                    value = value[None]
                for idx, signal in enumerate(value):
                    appendix = f'_{idx}' if len(value) > 1 else ''
                    filename = example_id + appendix + '.wav'
                    audio_path = str(path / filename)
                    with soundfile.SoundFile(
                            audio_path, subtype='FLOAT', mode='w',
                            samplerate=8000,
                            channels=1 if signal.ndim == 1
                            else signal.shape[0]
                    ) as f:
                        f.write(signal.T)

    dlp_mpi.barrier()
    if dlp_mpi.IS_MASTER:
        created_files = check_files(dst_dir)
        print(f"Written {len(created_files)} wav files.")
        if write_all:
            # TODO Less, if you do a test run.
            num_speakers = 2  # todo infer num_speakers from json
            # 2 files for: early, tail, speech_source
            # 1 file for: observation, noise
            expect = (3 * num_speakers + 2) * 35875
            assert len(created_files) == expect, (len(created_files), expect)
        else:
            assert len(created_files) == 35875, len(created_files)
def __post_init__(self):
    """Open the JSON database referenced by ``self.json_path``.

    The database is stored as ``self.db``. Fails with an
    ``AssertionError`` if no path is configured.
    """
    json_path = self.json_path
    assert json_path is not None
    self.db = JsonDatabase(json_path=json_path)
def main(
        json_path: Path,
        rir_dir: Path,
        wsj_json_path: Path,
        num_speakers: int,
        debug: bool,
):
    """Combine RIR scenarios with WSJ utterances and write a database JSON.

    For each of the datasets ``train_si284``, ``cv_dev93`` and
    ``test_eval92`` the source utterances are filtered with
    ``filter_punctuation_pronunciation`` and combined with the RIR examples
    by ``combine_rirs_and_sources``. The result is dumped to ``json_path``.

    Args:
        json_path: Destination JSON file.
        rir_dir: Directory holding ``scenarios.json`` and the RIR wav files.
        wsj_json_path: JSON file describing the WSJ source database.
        num_speakers: Number of speakers per generated example.
        debug: If true, only ``DEBUG_EXAMPLE_LIMIT`` RIR examples and a
            strided subset of the source examples are used.

    Raises:
        AssertionError: If an input path is missing, the dataset lengths
            are not compatible, too few RIRs are available or the sample
            rates of RIRs and source audio differ.
    """
    wsj_json_path = Path(wsj_json_path).expanduser().resolve()
    json_path = Path(json_path).expanduser().resolve()
    rir_dir = Path(rir_dir).expanduser().resolve()
    # Report the path that was actually checked (the original message
    # showed json_path here).
    assert wsj_json_path.is_file(), wsj_json_path
    assert rir_dir.exists(), rir_dir

    # ToDo: What was the motivation for defining this "setup"?
    setup = dict(
        train_si284=dict(source_dataset_name="train_si284"),
        cv_dev93=dict(source_dataset_name="cv_dev93"),
        test_eval92=dict(source_dataset_name="test_eval92"),
    )

    rir_db = JsonDatabase(rir_dir / "scenarios.json")
    source_db = JsonDatabase(wsj_json_path)

    target_db = dict()
    target_db['datasets'] = defaultdict(dict)

    for dataset_name in setup.keys():
        source_dataset_name = setup[dataset_name]["source_dataset_name"]
        source_dataset = source_db.get_dataset(source_dataset_name)
        print(f'length of source {dataset_name}: {len(source_dataset)}')
        source_dataset = source_dataset.filter(
            filter_fn=filter_punctuation_pronunciation, lazy=False)
        print(f'length of source {dataset_name}: {len(source_dataset)} '
              '(after punctuation filter)')

        def add_rir_path(rir_ex):
            """Attach the wav paths of the first ``num_speakers`` RIRs."""
            assert 'audio_path' not in rir_ex, rir_ex
            example_id = rir_ex['example_id']
            rir_ex['audio_path'] = {
                'rir': [
                    str(rir_dir / dataset_name / example_id / f"h_{k}.wav")
                    for k in range(num_speakers)
                ]
            }
            return rir_ex

        rir_dataset = rir_db.get_dataset(dataset_name).map(add_rir_path)

        assert len(rir_dataset) % len(source_dataset) == 0, (
            f'To avoid a bias towards certain utterance the len '
            f'rir_dataset ({len(rir_dataset)}) should be an integer '
            f'multiple of len source_dataset ({len(source_dataset)}).')

        print(f'length of rir {dataset_name}: {len(rir_dataset)}')

        # The first RIR example is used as a probe for the number of
        # available speaker positions and for the RIR sample rate.
        probe_path = rir_dir / dataset_name / "0"
        available_speaker_positions = len(list(probe_path.glob('h_*.wav')))
        assert num_speakers <= available_speaker_positions, (
            f'Requested {num_speakers} num_speakers, while found only '
            f'{available_speaker_positions} rirs in {probe_path}.')

        info = soundfile.info(str(rir_dir / dataset_name / "0" / "h_0.wav"))
        sample_rate_rir = info.samplerate

        ex_wsj = source_dataset.random_choice(1)[0]
        # Close the probe file again instead of leaking the handle.
        with soundfile.SoundFile(ex_wsj['audio_path']['observation']) as f:
            sample_rate_wsj = f.samplerate
        assert sample_rate_rir == sample_rate_wsj, (sample_rate_rir,
                                                    sample_rate_wsj)

        if debug:
            rir_dataset = rir_dataset[:DEBUG_EXAMPLE_LIMIT]
            # Use step_size to avoid that only one speaker is in
            # source_iterator. A step of 0 (dataset smaller than the
            # limit) would make the slice raise a ValueError.
            step_size = max(1, len(source_dataset) // DEBUG_EXAMPLE_LIMIT)
            source_dataset = source_dataset[::step_size]

        ex_dict = combine_rirs_and_sources(
            rir_dataset=rir_dataset,
            source_dataset=source_dataset,
            num_speakers=num_speakers,
            dataset_name=dataset_name,
        )

        target_db['datasets'][dataset_name] = ex_dict

    json_path.parent.mkdir(exist_ok=True, parents=True)
    with json_path.open('w') as f:
        json.dump(target_db, f, indent=2, ensure_ascii=False)
    print(f'{json_path} written.')