import functools
import json
from collections import defaultdict
from copy import copy
from pathlib import Path

import numpy as np
import soundfile

# Project-specific helpers such as get_rng, extend_composition_example_greedy,
# filter_punctuation_pronunciation and JsonDatabase are assumed to be
# importable from the surrounding package; their exact module paths are not
# shown here.


def combine_rirs_and_sources(
    rir_dataset,
    source_dataset,
    num_speakers,
    dataset_name,
):
    # The keys of rir_dataset are integers; sort the RIRs numerically by this
    # key.
    rir_dataset = rir_dataset.sort(sort_fn=functools.partial(sorted, key=int))

    assert len(rir_dataset) % len(source_dataset) == 0, (len(rir_dataset),
                                                         len(source_dataset))
    repetitions = len(rir_dataset) // len(source_dataset)

    source_dataset = source_dataset.sort()
    source_dataset = list(source_dataset.tile(repetitions))

    speaker_ids = [example['speaker_id'] for example in source_dataset]

    rng = get_rng(dataset_name, 'example_compositions')

    composition_examples = None
    for _ in range(num_speakers):
        composition_examples = extend_composition_example_greedy(
            rng,
            speaker_ids,
            example_compositions=composition_examples,
        )

    ex_dict = dict()
    assert len(rir_dataset) == len(composition_examples), (
        len(rir_dataset), len(composition_examples))
    for rir_example, composition_example in zip(rir_dataset,
                                                composition_examples):
        source_examples = [source_dataset[i] for i in composition_example]

        example = get_randomized_example(
            rir_example,
            source_examples,
            rng,
            dataset_name,
        )
        ex_dict[example['example_id']] = example

    return ex_dict
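
# The helper extend_composition_example_greedy used above is provided by the
# surrounding project. The sketch below only illustrates one possible greedy
# strategy, assuming its job is to grow every composition by one source index
# such that no speaker appears twice within a composition (the name and exact
# behaviour here are hypothetical, not the project's implementation):
def _greedy_extend_sketch(rng, speaker_ids, example_compositions=None):
    n = len(speaker_ids)
    if example_compositions is None:
        # Start with one randomly chosen source index per composition.
        return [[int(i)] for i in rng.permutation(n)]
    candidates = [int(i) for i in rng.permutation(n)]
    extended = []
    for composition in example_compositions:
        used_speakers = {speaker_ids[i] for i in composition}
        for position, candidate in enumerate(candidates):
            if speaker_ids[candidate] not in used_speakers:
                extended.append(composition + [candidates.pop(position)])
                break
        else:
            # No speaker-disjoint candidate is left; accept a collision.
            extended.append(composition + [candidates.pop(0)])
    return extended
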
def main(json_path: Path, rir_dir: Path, wsj_json_path: Path, num_speakers):
    wsj_json_path = Path(wsj_json_path).expanduser().resolve()
    json_path = Path(json_path).expanduser().resolve()
    if json_path.exists():
        raise FileExistsError(json_path)
    rir_dir = Path(rir_dir).expanduser().resolve()
    assert wsj_json_path.is_file(), wsj_json_path
    assert rir_dir.exists(), rir_dir

    setup = dict(
        train_si284=dict(source_dataset_name="train_si284"),
        cv_dev93=dict(source_dataset_name="cv_dev93"),
        test_eval92=dict(source_dataset_name="test_eval92"),
    )

    rir_db = JsonDatabase(rir_dir / "scenarios.json")

    source_db = JsonDatabase(wsj_json_path)

    target_db = dict()
    target_db['datasets'] = defaultdict(dict)

    for dataset_name in setup.keys():
        source_dataset_name = setup[dataset_name]["source_dataset_name"]
        source_iterator = source_db.get_dataset(source_dataset_name)
        print(f'length of source {dataset_name}: {len(source_iterator)}')
        source_iterator = source_iterator.filter(
            filter_fn=filter_punctuation_pronunciation, lazy=False)
        print(f'length of source {dataset_name}: {len(source_iterator)} '
              '(after punctuation filter)')

        rir_iterator = rir_db.get_dataset(dataset_name)

        assert len(rir_iterator) % len(source_iterator) == 0, (
            f'To avoid a bias towards certain utterances, the length of '
            f'rir_iterator ({len(rir_iterator)}) should be an integer '
            f'multiple of the length of source_iterator '
            f'({len(source_iterator)}).')

        print(f'length of rir {dataset_name}: {len(rir_iterator)}')

        probe_path = rir_dir / dataset_name / "0"
        available_speaker_positions = len(list(probe_path.glob('h_*.wav')))
        assert num_speakers <= available_speaker_positions, (
            f'Requested {num_speakers} num_speakers, while found only '
            f'{available_speaker_positions} rirs in {probe_path}.')

        info = soundfile.info(str(rir_dir / dataset_name / "0" / "h_0.wav"))
        sample_rate_rir = info.samplerate

        ex_wsj = source_iterator.random_choice(1)[0]
        info = soundfile.info(ex_wsj['audio_path']['observation'])
        sample_rate_wsj = info.samplerate
        assert sample_rate_rir == sample_rate_wsj, (sample_rate_rir,
                                                    sample_rate_wsj)

        rir_iterator = rir_iterator.sort(
            sort_fn=functools.partial(sorted, key=int))

        source_iterator = source_iterator.sort()
        assert len(rir_iterator) % len(source_iterator) == 0
        repeats = len(rir_iterator) // len(source_iterator)
        source_iterator = source_iterator.tile(repeats)

        speaker_ids = [example['speaker_id'] for example in source_iterator]

        rng = get_rng(dataset_name, 'example_compositions')

        example_compositions = None
        for _ in range(num_speakers):
            example_compositions = extend_composition_example_greedy(
                rng,
                speaker_ids,
                example_compositions=example_compositions,
            )

        ex_dict = dict()
        assert len(rir_iterator) == len(example_compositions)
        for rir_example, example_composition in zip(rir_iterator,
                                                    example_compositions):
            source_examples = source_iterator[example_composition]

            example_id = "_".join([
                rir_example['example_id'],
                *[ex["example_id"] for ex in source_examples],
            ])

            rng = get_rng(dataset_name, example_id)
            example = get_randomized_example(
                rir_example,
                source_examples,
                rng,
                dataset_name,
                rir_dir,
            )
            ex_dict[example_id] = example

        target_db['datasets'][dataset_name] = ex_dict

    json_path.parent.mkdir(exist_ok=True, parents=True)
    with json_path.open('w') as f:
        json.dump(target_db, f, indent=2, ensure_ascii=False)
    print(f'{json_path} written')
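
# get_rng is assumed to return a numpy RandomState seeded deterministically
# from its string arguments, so that repeated runs produce the same database.
# A minimal sketch of such a helper (hypothetical, not the project's
# implementation):
import hashlib

def _get_rng_sketch(*keys):
    digest = hashlib.sha256('_'.join(keys).encode()).digest()
    # RandomState seeds must fit into 32 bits, so use the first four bytes.
    return np.random.RandomState(int.from_bytes(digest[:4], 'little'))
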
def get_randomized_example(rir_example, source_examples, rng, dataset_name):
    example_id = "_".join([
        rir_example['example_id'],
        *[source_ex["example_id"] for source_ex in source_examples],
    ])
    rng = get_rng(dataset_name, example_id)

    example = copy(rir_example)
    example['example_id'] = example_id
    example['dataset'] = dataset_name

    assert len(source_examples) <= len(example['source_position'][0])
    num_speakers = len(source_examples)

    # Remove unused source positions and RIRs (scenarios.json may have been
    # generated for more speakers than are requested here).
    example['source_position'] = [
        v[:num_speakers] for v in example['source_position']
    ]
    example['audio_path']['rir'] = example['audio_path']['rir'][:num_speakers]

    example['num_speakers'] = num_speakers

    example['speaker_id'] = [exa['speaker_id'] for exa in source_examples]

    # Assert that no speaker_id is used twice within the mixture.
    assert len(set(example['speaker_id'])) == example['num_speakers']

    example["source_id"] = [exa['example_id'] for exa in source_examples]

    for k in ('gender', 'kaldi_transcription'):
        example[k] = [exa[k] for exa in source_examples]

    example['log_weights'] = rng.uniform(0,
                                         5,
                                         size=(example['num_speakers'], ))
    example['log_weights'] -= np.mean(example['log_weights'])
    example['log_weights'] = example['log_weights'].tolist()

    # This way, at least the first speaker can have proper alignments; all
    # other speakers cannot be used for ASR.
    def _get_num_samples(num_samples):
        if isinstance(num_samples, dict):
            return num_samples['observation']
        else:
            return num_samples

    example['num_samples'] = dict()
    example['num_samples']['original_source'] = [
        _get_num_samples(exa['num_samples']) for exa in source_examples
    ]
    example['num_samples']['observation'] = max(
        example['num_samples']['original_source'])

    example["offset"] = []
    for k in range(example['num_speakers']):
        excess_samples = (example['num_samples']['observation'] -
                          example['num_samples']['original_source'][k])
        assert excess_samples >= 0, excess_samples
        example["offset"].append(rng.randint(0, excess_samples + 1))

    example['audio_path']['original_source'] = [
        exa['audio_path']['observation'] for exa in source_examples
    ]
    # example['audio_path']['rir']: Already defined in rir_example.
    return example
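
# The example dict above only stores mixing metadata; the audio is assembled
# elsewhere in the pipeline (each source is also convolved with its RIR,
# which is omitted here). A rough sketch of how log_weights and offset could
# be applied, assuming log_weights are interpreted in dB (illustrative only):
def _mix_sketch(original_sources, log_weights, offsets, num_samples):
    observation = np.zeros(num_samples)
    for source, log_weight, offset in zip(original_sources, log_weights,
                                          offsets):
        scaled = 10 ** (log_weight / 20) * source
        # offset was drawn from [0, excess_samples], so the shifted source
        # always fits into the observation.
        observation[offset:offset + len(scaled)] += scaled
    return observation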