def combine_rirs_and_sources(
        rir_dataset,
        source_dataset,
        num_speakers,
        dataset_name,
):
    """Pair each RIR example with ``num_speakers`` source utterances.

    The RIR dataset is sorted by its integer keys, the source dataset is
    sorted and tiled so both have equal length, and a greedy algorithm
    picks speaker-disjoint utterance combinations for every RIR.

    Returns:
        dict mapping the composed example_id to the randomized example.
    """
    # The keys of rir_dataset are integers; sort the rirs numerically.
    rir_dataset = rir_dataset.sort(sort_fn=functools.partial(sorted, key=int))

    n_rirs, n_sources = len(rir_dataset), len(source_dataset)
    assert n_rirs % n_sources == 0, (n_rirs, n_sources)

    # Tile the (sorted) sources so every utterance is used equally often.
    sources = list(source_dataset.sort().tile(n_rirs // n_sources))
    speaker_ids = [src['speaker_id'] for src in sources]

    rng = get_rng(dataset_name, 'example_compositions')

    # Greedily grow the compositions one speaker column at a time.
    composition_examples = None
    for _ in range(num_speakers):
        composition_examples = extend_composition_example_greedy(
            rng, speaker_ids,
            example_compositions=composition_examples,
        )

    assert len(rir_dataset) == len(composition_examples), (
        len(rir_dataset), len(composition_examples))

    ex_dict = dict()
    for rir_example, composition in zip(rir_dataset, composition_examples):
        chosen_sources = [sources[idx] for idx in composition]
        example = get_randomized_example(
            rir_example, chosen_sources, rng, dataset_name,
        )
        ex_dict[example['example_id']] = example
    return ex_dict
def main(json_path: Path, rir_dir: Path, wsj_json_path: Path, num_speakers):
    """Create the mixture database json from WSJ sources and pre-computed RIRs.

    Args:
        json_path: Output json path; must not exist yet.
        rir_dir: Directory with ``scenarios.json`` and per-dataset RIR wavs.
        wsj_json_path: Path to the WSJ source database json.
        num_speakers: Number of simultaneously active speakers per mixture.

    Raises:
        FileExistsError: If ``json_path`` already exists.
    """
    wsj_json_path = Path(wsj_json_path).expanduser().resolve()
    json_path = Path(json_path).expanduser().resolve()
    if json_path.exists():
        raise FileExistsError(json_path)
    rir_dir = Path(rir_dir).expanduser().resolve()
    # Fix: the assert message previously showed json_path although
    # wsj_json_path is the path being checked.
    assert wsj_json_path.is_file(), wsj_json_path
    assert rir_dir.exists(), rir_dir

    setup = dict(
        train_si284=dict(source_dataset_name="train_si284"),
        cv_dev93=dict(source_dataset_name="cv_dev93"),
        test_eval92=dict(source_dataset_name="test_eval92"),
    )

    rir_db = JsonDatabase(rir_dir / "scenarios.json")
    source_db = JsonDatabase(wsj_json_path)

    target_db = dict()
    target_db['datasets'] = defaultdict(dict)

    for dataset_name in setup.keys():
        source_dataset_name = setup[dataset_name]["source_dataset_name"]
        source_iterator = source_db.get_dataset(source_dataset_name)
        print(f'length of source {dataset_name}: {len(source_iterator)}')
        source_iterator = source_iterator.filter(
            filter_fn=filter_punctuation_pronunciation, lazy=False)
        print(f'length of source {dataset_name}: {len(source_iterator)} '
              '(after punctuation filter)')

        rir_iterator = rir_db.get_dataset(dataset_name)
        # Each utterance must be reused an integer number of times so no
        # utterance is over-represented in the mixtures.
        assert len(rir_iterator) % len(source_iterator) == 0, (
            f'To avoid a bias towards certain utterance the len '
            f'rir_iterator ({len(rir_iterator)}) should be an integer '
            f'multiple of len source_iterator ({len(source_iterator)}).')
        print(f'length of rir {dataset_name}: {len(rir_iterator)}')

        # Probe the first scenario folder to count available RIR positions.
        probe_path = rir_dir / dataset_name / "0"
        available_speaker_positions = len(list(probe_path.glob('h_*.wav')))
        assert num_speakers <= available_speaker_positions, (
            f'Requested {num_speakers} num_speakers, while found only '
            f'{available_speaker_positions} rirs in {probe_path}.')

        # Sanity check: the RIRs and the WSJ audio must share a sample rate.
        info = soundfile.info(str(rir_dir / dataset_name / "0" / "h_0.wav"))
        sample_rate_rir = info.samplerate
        ex_wsj = source_iterator.random_choice(1)[0]
        # Fix: use soundfile.info instead of soundfile.SoundFile; SoundFile
        # opens the file and the handle was never closed.
        info = soundfile.info(str(ex_wsj['audio_path']['observation']))
        sample_rate_wsj = info.samplerate
        assert sample_rate_rir == sample_rate_wsj, (
            sample_rate_rir, sample_rate_wsj)

        # The keys of the rir dataset are integers; sort numerically.
        rir_iterator = rir_iterator.sort(
            sort_fn=functools.partial(sorted, key=int))
        source_iterator = source_iterator.sort()
        assert len(rir_iterator) % len(source_iterator) == 0
        repeats = len(rir_iterator) // len(source_iterator)
        source_iterator = source_iterator.tile(repeats)

        speaker_ids = [example['speaker_id'] for example in source_iterator]
        rng = get_rng(dataset_name, 'example_compositions')
        example_compositions = None
        for _ in range(num_speakers):
            # Fix: the helper is named extend_composition_example_greedy
            # (see combine_rirs_and_sources); the previous name
            # extend_example_composition_greedy does not exist.
            example_compositions = extend_composition_example_greedy(
                rng, speaker_ids,
                example_compositions=example_compositions,
            )

        ex_dict = dict()
        assert len(rir_iterator) == len(example_compositions)
        for rir_example, example_composition in zip(
                rir_iterator, example_compositions):
            source_examples = source_iterator[example_composition]
            example_id = "_".join([
                rir_example['example_id'],
                *[ex["example_id"] for ex in source_examples],
            ])
            rng = get_rng(dataset_name, example_id)
            # Fix: dropped the stray rir_dir argument; the function is
            # defined as get_randomized_example(rir_example, source_examples,
            # rng, dataset_name) and a fifth argument raises TypeError.
            example = get_randomized_example(
                rir_example, source_examples, rng, dataset_name,
            )
            ex_dict[example_id] = example
        target_db['datasets'][dataset_name] = ex_dict

    json_path.parent.mkdir(exist_ok=True, parents=True)
    with json_path.open('w') as f:
        json.dump(target_db, f, indent=2, ensure_ascii=False)
    print(f'{json_path} written')
def get_randomized_example(rir_example, source_examples, rng, dataset_name):
    """Compose one multi-speaker mixture example from a RIR example and
    ``len(source_examples)`` source utterances.

    Args:
        rir_example: Example dict from the scenarios database; provides
            ``source_position`` and ``audio_path['rir']``.
        source_examples: One source (WSJ) example dict per speaker.
        rng: Unused; a deterministic rng is re-derived from the example id
            below so the result only depends on (dataset_name, example_id).
        dataset_name: Name of the dataset the example belongs to.

    Returns:
        A new example dict; ``rir_example`` is not modified.
    """
    example_id = "_".join([
        rir_example['example_id'],
        *[source_ex["example_id"] for source_ex in source_examples],
    ])
    rng = get_rng(dataset_name, example_id)

    example = copy(rir_example)
    # Fix: ``copy`` is shallow, so ``example['audio_path']`` was the same
    # dict object as ``rir_example['audio_path']`` and the truncation below
    # mutated the caller's input. Re-create the nested dict before mutating.
    example['audio_path'] = dict(example['audio_path'])
    example['example_id'] = example_id
    example['dataset'] = dataset_name

    assert len(source_examples) <= len(example['source_position'][0])
    num_speakers = len(source_examples)
    # Remove unused source positions and rirs (i.e. the scenarios.json was
    # maybe generated with more speakers).
    example['source_position'] = [
        v[:num_speakers] for v in example['source_position']
    ]
    example['audio_path']['rir'] = example['audio_path']['rir'][:num_speakers]
    example['num_speakers'] = num_speakers

    example['speaker_id'] = [exa['speaker_id'] for exa in source_examples]
    # asserts that no speaker_id is used twice
    assert len(set(example['speaker_id'])) == example['num_speakers']
    example["source_id"] = [exa['example_id'] for exa in source_examples]
    for k in ('gender', 'kaldi_transcription'):
        example[k] = [exa[k] for exa in source_examples]

    # Per-speaker gains, centered to zero mean (in log domain).
    example['log_weights'] = rng.uniform(
        0, 5, size=(example['num_speakers'],))
    example['log_weights'] -= np.mean(example['log_weights'])
    example['log_weights'] = example['log_weights'].tolist()

    def _get_num_samples(num_samples):
        # Some source databases store num_samples as a dict per stream.
        if isinstance(num_samples, dict):
            return num_samples['observation']
        else:
            return num_samples

    example['num_samples'] = dict()
    example['num_samples']['original_source'] = [
        _get_num_samples(exa['num_samples']) for exa in source_examples
    ]
    # The mixture is as long as the longest source utterance.
    example['num_samples']['observation'] = max(
        example['num_samples']['original_source'])

    # This way, at least the first speaker can have proper alignments,
    # all other speakers can not be used for ASR.
    example["offset"] = []
    for k in range(example['num_speakers']):
        excess_samples = (
            example['num_samples']['observation']
            - example['num_samples']['original_source'][k]
        )
        assert excess_samples >= 0, excess_samples
        example["offset"].append(rng.randint(0, excess_samples + 1))

    example['audio_path']['original_source'] = [
        exa['audio_path']['observation'] for exa in source_examples
    ]
    # example['audio_path']['rir']: Already defined in rir_example.
    return example