Example #1
0
def cross_communication():
    """Check how ``split_managed`` distributes examples between two workers.

    Requires exactly three MPI processes (ranks 0, 1 and 2). Rank 0 never
    receives an example (asserted inside both loops). Artificial sleeps
    make the distribution predictable enough to assert on it: rank 2
    sleeps 0.2 s after each example, so rank 1 is expected to process all
    remaining examples once both workers received their first one.
    """
    print(f'cross_communication test {RANK}')

    examples = list(range(5))

    ranks = dlp_mpi.gather(dlp_mpi.RANK)
    if dlp_mpi.IS_MASTER:
        assert ranks == [0, 1, 2], ranks

    # Synchronize all ranks, then delay rank 1 so that rank 2 may request
    # the first example (the expectations below allow both orders).
    dlp_mpi.barrier()
    if RANK == 1:
        time.sleep(0.1)
    elif RANK == 2:
        pass

    results = []
    for i in dlp_mpi.split_managed(examples, progress_bar=False):
        # Only the worker ranks 1 and 2 receive examples.
        assert dlp_mpi.RANK in [1, 2], (dlp_mpi.RANK, dlp_mpi.SIZE)

        if RANK == 1:
            results.append(i)
        elif RANK == 2:
            results.append(i)
            # Keep rank 2 busy so that rank 1 picks up the other examples.
            time.sleep(0.2)

    # Rank 2 gets exactly one example (0 or 1), rank 1 gets the other four.
    if RANK == 1:
        assert results in [[0, 2, 3, 4], [1, 2, 3, 4]], results
    elif RANK == 2:
        assert results in [[1], [0]], results

    # Second round without a barrier. ``results`` is intentionally not
    # reset: the expectations below contain the results of both loops.
    for i in dlp_mpi.split_managed(examples, progress_bar=False):
        assert dlp_mpi.RANK in [1, 2], (dlp_mpi.RANK, dlp_mpi.SIZE)
        if RANK == 1:
            results.append(i)
            time.sleep(0.001)
        elif RANK == 2:
            results.append(i)
            time.sleep(0.2)

    if RANK == 1:
        assert results in [
            [0, 2, 3, 4, 0, 2, 3, 4],
            [0, 2, 3, 4, 1, 2, 3, 4],
            [1, 2, 3, 4, 0, 2, 3, 4],
            [1, 2, 3, 4, 1, 2, 3, 4],
        ], results
    elif RANK == 2:
        assert results in [
            [1, 1],
            [1, 0],
            [0, 1],
            [0, 0],
        ], results
Example #2
0
def main(_run, batch_size, datasets, debug, experiment_dir, database_json):
    """Evaluate the model on ``datasets`` and export the per-example metrics.

    The batches are distributed over the MPI processes with
    ``dlp_mpi.split_managed``. Each process collects its metrics in
    ``summary[dataset][example_id]``; the master gathers all partial
    summaries, merges them and writes them to
    ``<experiment_dir>/result.json``.

    Args:
        _run: Sacred run object; its config is printed on the master.
        batch_size: Batch size forwarded to ``prepare_iterable``.
        datasets: Names of the datasets to evaluate.
        debug: Forwarded as ``allow_single_worker`` to ``split_managed``.
        experiment_dir: Output directory for ``result.json``.
        database_json: Path of the JSON database.
    """
    experiment_dir = Path(experiment_dir)

    if dlp_mpi.IS_MASTER:
        sacred.commands.print_config(_run)

    model = get_model()
    db = JsonDatabase(json_path=database_json)

    model.eval()
    with torch.no_grad():
        summary = defaultdict(dict)
        for dataset in datasets:
            iterable = prepare_iterable(
                db,
                dataset,
                batch_size,
                return_keys=None,
                prefetch=False,
            )

            for batch in dlp_mpi.split_managed(iterable,
                                               is_indexable=False,
                                               progress_bar=True,
                                               allow_single_worker=debug):
                entry = dict()
                model_output = model(pt.data.example_to_device(batch))

                # Only index 0 of each batch is evaluated.
                # NOTE(review): with batch_size > 1 the other examples of a
                # batch are ignored - confirm this is intended.
                example_id = batch['example_id'][0]
                s = batch['s'][0]
                Y = batch['Y'][0]
                mask = model_output[0].numpy()

                # Apply the estimated masks to Y and go back to the time
                # domain.
                Z = mask * Y[:, None, :]
                z = istft(einops.rearrange(Z, "t k f -> k t f"),
                          size=512,
                          shift=128)

                # Cut source and estimate to their common length.
                s = s[:, :z.shape[1]]
                z = z[:, :s.shape[1]]
                entry['metrics'] \
                    = pb_bss.evaluation.OutputMetrics(speech_prediction=z,
                                                      speech_source=s).as_dict()

                # Must be inside the batch loop: every example is stored,
                # not only the last one of each dataset.
                summary[dataset][example_id] = entry

    summary_list = dlp_mpi.gather(summary, root=dlp_mpi.MASTER)

    if dlp_mpi.IS_MASTER:
        print(f'len(summary_list): {len(summary_list)}')
        for partial_summary in summary_list:
            for dataset, values in partial_summary.items():
                summary[dataset].update(values)

        for dataset, values in summary.items():
            print(f'{dataset}: {len(values)}')

        result_json_path = experiment_dir / 'result.json'
        print(f"Exporting result: {result_json_path}")
        pb.io.dump_json(summary, result_json_path)
Example #3
0
def worker_fails():
    """Check the behaviour of ``split_managed`` when a worker raises.

    Requires exactly three MPI processes. Rank 1 raises a ``ValueError`` on
    the first example it receives. The test expects that:

    - rank 1 ends in the ``ValueError`` branch after one example,
    - rank 0 ends in the ``AssertionError`` branch without processing an
      example (presumably raised by ``split_managed`` once it notices the
      failed worker - confirm),
    - rank 2 finishes normally with ``processed == [1, 2, 3, 4]``.
    """
    print(f'worker_fails test {RANK}')

    examples = list(range(5))

    ranks = dlp_mpi.gather(dlp_mpi.RANK)
    if dlp_mpi.IS_MASTER:
        assert ranks == [0, 1, 2], ranks

    processed = []
    try:
        dlp_mpi.barrier()
        if RANK == 2:
            # Delay rank 2 so that rank 1 gets the first example.
            # This no longer works reliably, because split_managed uses
            # COMM.Clone (hence rank 1 may see example 0 or 1 below).
            time.sleep(0.1)
        for i in dlp_mpi.split_managed(examples, progress_bar=False):
            processed.append(i)
            if RANK == 1:
                print(f'let {RANK} fail for data {i}')
                raise ValueError('failed')
            assert dlp_mpi.RANK in [1, 2], (dlp_mpi.RANK, dlp_mpi.SIZE)
    except ValueError:
        assert RANK in [1], RANK
        assert processed in [[0], [1]], processed
    except AssertionError:
        assert RANK in [0], RANK
        assert processed == [], processed
    else:
        assert RANK in [2], RANK
        assert processed == [1, 2, 3, 4], processed
Example #4
0
    def enhance_session(self,
                        session_ids,
                        audio_dir,
                        dataset_slice=False,
                        audio_dir_exist_ok=False):
        """Enhance all examples of the given sessions and write wav files.

        The examples are distributed with ``dlp_mpi.split_managed``. The
        master creates the output directories before any worker writes.

        Args:
            session_ids: Sessions, forwarded to ``self.get_iterator``.
            audio_dir: Output directory. The master creates one sub
                directory per dataset of ``mapping.session_to_dataset``.
            dataset_slice: Restricts the processed examples:
                ``False`` -> all, ``True`` -> the first two,
                an ``int`` n -> the first n, a ``slice`` -> that slice.
                Any other value raises ``ValueError``.
            audio_dir_exist_ok:
                When True: it is ok when the audio dir exists, and the
                files inside may be overwritten.

        Returns:
            None. Each enhanced signal is written to
            ``<audio_dir>/<dataset>/<example_id>.wav``.

        Raises:
            ValueError: For an unsupported ``dataset_slice``.
            NotImplementedError: If an enhanced signal is not 1-D.

        Example (writes files, therefore not a doctest)::

            enhancer = get_enhancer(wpe=False, bss_iterations=2)
            enhancer.enhance_session('S02', 'enhanced_audio')
        """
        ensure_single_thread_numeric()

        audio_dir = Path(audio_dir)
        examples = self.get_iterator(session_ids)

        if dlp_mpi.IS_MASTER:
            # Build the output tree once, on a single process.
            audio_dir.mkdir(exist_ok=audio_dir_exist_ok)
            for dataset_name in set(mapping.session_to_dataset.values()):
                (audio_dir / dataset_name).mkdir(exist_ok=audio_dir_exist_ok)

        # No worker may write before the directories exist.
        dlp_mpi.barrier()

        # ``True`` has to be tested before ``int``: bool is a subclass of int.
        if dataset_slice is False:
            pass
        elif dataset_slice is True:
            examples = examples[:2]
        elif isinstance(dataset_slice, int):
            examples = examples[:dataset_slice]
        elif isinstance(dataset_slice, slice):
            examples = examples[dataset_slice]
        else:
            raise ValueError(dataset_slice)

        for example in dlp_mpi.split_managed(examples,
                                             allow_single_worker=True):
            x_hat = self.enhance_example(example)
            example_id = example["example_id"]
            session_id = example["session_id"]
            dataset = mapping.session_to_dataset[session_id]

            if x_hat.ndim != 1:
                raise NotImplementedError(x_hat.shape)

            dump_audio(x_hat, audio_dir / f'{dataset}' / f'{example_id}.wav')
Example #5
0
def rirs(
    database_path,
    datasets,
    sample_rate,
    filter_length,
):
    """Generate and store the room impulse responses of all scenarios.

    The master reads ``<database_path>/scenarios.json`` and creates one
    directory per entry of ``datasets``. The parsed JSON is then broadcast
    to all ranks. For every dataset in ``database['datasets']`` the examples
    are distributed with ``dlp_mpi.split_managed``. For each example, one
    64 bit float wav file ``h_<k>.wav`` is written per index ``k`` of the
    first axis of the RIR array (presumably the sources) into
    ``<database_path>/<dataset>/<example_id>/``.

    NOTE(review): the loop iterates ``database['datasets']``, not the
    ``datasets`` argument. A dataset present in the JSON but missing from
    ``datasets`` has no parent directory, and the ``mkdir`` with
    ``parents=False`` would fail. Confirm that both always agree.
    """
    database_path = Path(database_path)

    database = None
    if dlp_mpi.IS_MASTER:
        with (database_path / "scenarios.json").open() as fd:
            database = json.load(fd)
        for dataset in datasets:
            (database_path / dataset).mkdir(parents=True, exist_ok=True)
    database = dlp_mpi.bcast(database)

    for dataset_name, scenarios in database['datasets'].items():
        print(f'RANK={dlp_mpi.RANK}, SIZE={dlp_mpi.SIZE}:'
              f' Starting {dataset_name}.')

        for example_id, example in dlp_mpi.split_managed(
                sorted(scenarios.items()),
                progress_bar=True,
                is_indexable=True,
        ):
            h = generate_rir(
                room_dimensions=example['room_dimensions'],
                source_positions=example['source_position'],
                sensor_positions=example['sensor_position'],
                sound_decay_time=example['sound_decay_time'],
                sample_rate=sample_rate,
                filter_length=filter_length,
                sensor_orientations=None,
                sensor_directivity=None,
                sound_velocity=343,
            )
            assert not np.any(np.isnan(h)), (
                f"{np.sum(np.isnan(h))} values of {h.size} are NaN.")

            num_rirs, num_channels, _ = h.shape
            directory = database_path / dataset_name / example_id
            directory.mkdir(parents=False, exist_ok=False)

            for k in range(num_rirs):
                # Storing as float64 is not supported by every reader, but it
                # avoids normalization, and it is unclear how much precision
                # RIRs need.
                wav_path = str(directory / f"h_{k}.wav")
                with soundfile.SoundFile(
                        wav_path,
                        subtype='DOUBLE',
                        samplerate=sample_rate,
                        mode='w',
                        channels=num_channels,
                ) as fd:
                    fd.write(h[k].T)

        # All ranks finish a dataset before the next one starts.
        dlp_mpi.barrier()

        print(f'RANK={dlp_mpi.RANK}, SIZE={dlp_mpi.SIZE}:'
              f' Finished {dataset_name}.')
Example #6
0
def overhead():
    """Measure the per-example overhead of ``split_managed``.

    This test shows the overhead of map_unordered.
    A simple for loop with map can process around 1 000 000 examples per
    second. When using map_unordered, the number obviously decreases.

    When your code processes fewer than 1000 examples per second, you can
    expect a gain from map_unordered. When your serial code processes more
    than 10000 examples per second, a gain from map_unordered is unlikely.
    Think about chunking to get fewer than 100 examples per second.

    Requires exactly three MPI processes (asserted below).

    NOTE(review): both rates below divide the elapsed time by 1000, while
    ``examples`` has 10000 entries (and each worker only sees a part of
    them). The asserted bounds are calibrated against this formula, so the
    printed values are relative figures, not literal examples/second.
    Confirm before changing the divisor or the bounds.
    """
    print(f'executable test {RANK}')

    examples = list(range(10000))

    ranks = dlp_mpi.gather(dlp_mpi.RANK)
    if dlp_mpi.IS_MASTER:
        assert ranks == [0, 1, 2], ranks

    total = 0
    # def bar(i):
    #     nonlocal total
    #     total += i
    #     return i

    start = time.perf_counter()
    for i in dlp_mpi.split_managed(examples, progress_bar=False):
        total += i
    elapsed = time.perf_counter() - start

    # Rank 0 receives no examples. Each worker must have summed a large
    # share of sum(range(10000)) == 49_995_000.
    if RANK == 0:
        assert total == 0, total
    elif RANK == 1:
        assert total > 10_000_000, total
    elif RANK == 2:
        assert total > 10_000_000, total
    else:
        raise ValueError(RANK)

    time_per_example = elapsed / 1000
    mpi_examples_per_second = 1 / time_per_example

    assert mpi_examples_per_second >= 10_000, mpi_examples_per_second
    assert mpi_examples_per_second <= 300_000, mpi_examples_per_second

    print('split_managed examples/second =', mpi_examples_per_second)

    # Reference: the same work as a plain Python loop (``total`` keeps
    # accumulating; its value is not checked afterwards).
    start = time.perf_counter()
    for i in examples:
        total += i
    elapsed = time.perf_counter() - start

    time_per_example = elapsed / 1000
    py_examples_per_second = 1 / time_per_example

    assert py_examples_per_second >= 250_000, py_examples_per_second
    assert py_examples_per_second <= 9_000_000, py_examples_per_second
Example #7
0
def executable():
    """Smoke test: ``split_managed`` only hands out examples to workers.

    Requires exactly three MPI processes. Rank 0 must never receive an
    example, so the assertion inside the loop only admits ranks 1 and 2.
    """
    print(f'executable test {RANK}')

    examples = list(range(5))

    all_ranks = dlp_mpi.gather(dlp_mpi.RANK)
    if dlp_mpi.IS_MASTER:
        assert all_ranks == [0, 1, 2], all_ranks

    for _ in dlp_mpi.split_managed(examples, progress_bar=False):
        assert dlp_mpi.RANK in [1, 2], (dlp_mpi.RANK, dlp_mpi.SIZE)
def main(_run, out):
    """Compute scores for every pair of prediction and source signal.

    For each example of ``get_dataset()`` (distributed with
    ``dlp_mpi.split_managed``), ``get_scores`` is evaluated for every
    combination of the signal names below as prediction and as source.
    The master flattens the gathered rows, optionally writes them to the
    JSON file ``out`` and passes them to ``summary``.

    Args:
        _run: Sacred run object; its config is printed on the master.
        out: ``None`` or the path of a ``.json`` file for the detailed rows.
    """
    if dlp_mpi.IS_MASTER:
        from sacred.commands import print_config
        print_config(_run)

    ds = get_dataset()

    # The same names are used for the prediction and for the source.
    signal_names = [
        'source',
        'early_0',
        'early_1',
        'image_0',
        'image_1',
        'image_0_noise',
        'image_1_noise',
    ]

    rows = []
    for ex in dlp_mpi.split_managed(ds.sort(), allow_single_worker=True):
        for prediction in signal_names:
            for source in signal_names:
                scores = get_scores(ex, prediction=prediction, source=source)
                rows.extend(
                    dict(
                        score_name=score_name,
                        prediction=prediction,
                        source=source,
                        example_id=ex['example_id'],
                        value=score_value,
                    )
                    for score_name, score_value in scores.items()
                )

    gathered = dlp_mpi.gather(rows)

    if dlp_mpi.IS_MASTER:
        # One list per process -> one flat list of rows.
        data = [row for worker_rows in gathered for row in worker_rows]

        if out is not None:
            assert isinstance(out, str), out
            assert out.endswith('.json'), out
            print(f'Write details to {out}.')
            dump_json(data, out)

        summary(data)
Example #9
0
def pbar():
    """Check the progress bar calls made by ``split_managed``.

    Requires exactly three MPI processes. ``tqdm.tqdm`` is patched with a
    context manager that yields a ``MockPbar``, which records every call
    in a class-level list. Only rank 0 is expected to have recorded calls
    (busy counter descriptions and one update per example); the workers'
    histories must stay empty.
    """
    print(f'executable test {RANK}')

    examples = list(range(5))

    ranks = dlp_mpi.gather(dlp_mpi.RANK)
    if dlp_mpi.IS_MASTER:
        assert ranks == [0, 1, 2], ranks

    class MockPbar:
        # Class attribute on purpose: shared by all instances, so the
        # history can be inspected after the patched context has exited.
        call_history = []

        def __init__(self):
            self.i = 0

        def set_description(self, text):
            self.call_history.append(text)

        def update(self, inc=1):
            # NOTE(review): ``inc`` is ignored; the counter always grows by
            # one. This is fine as long as split_managed calls update() with
            # the default increment.
            self.i += 1
            self.call_history.append(f'update {self.i}')

    import contextlib

    @contextlib.contextmanager
    def mock_pbar(total, mininterval, smoothing):
        yield MockPbar()

    # NOTE(review): ``mock`` is the third-party backport; ``unittest.mock``
    # from the standard library provides the same ``patch``.
    import mock

    with mock.patch('tqdm.tqdm', mock_pbar):

        dlp_mpi.barrier()
        if RANK == 2:
            time.sleep(0.02)

        for i in dlp_mpi.split_managed(examples):
            time.sleep(0.04)
            assert dlp_mpi.RANK in [1, 2], (dlp_mpi.RANK, dlp_mpi.SIZE)

    if RANK == 0:
        assert MockPbar.call_history == [
            'busy: 2', 'update 1', 'update 2', 'update 3', 'update 4',
            'busy: 1', 'update 5', 'busy: 0'
        ], MockPbar.call_history
    else:
        assert MockPbar.call_history == [], MockPbar.call_history
Example #10
0
def evaluate_model(dataset, model, get_sad_fn,
                   get_target_fn=lambda x: x['activation'],
                   num_thresholds=201, buffer_zone=0.5,
                   is_indexable=True, allow_single_worker=True,
                   sample_rate=8000, threshold_decimals=2):
    """Accumulate detection counts of a SAD model over a threshold grid.

    The examples of ``dataset`` are distributed with
    ``dlp_mpi.split_managed``. Each process accumulates its counts, and the
    master sums the counts of all processes.

    Args:
        dataset: Examples to evaluate.
        model: Callable mapping an example to the model output.
        get_sad_fn: ``get_sad_fn(model_out, threshold, example)`` returns
            the detection for one threshold.
        get_target_fn: Extracts the reference annotation from an example.
        num_thresholds: Number of thresholds, evenly spaced in [0, 1].
        buffer_zone: Forwarded to ``adjust_annotation_fn``.
        is_indexable: Forwarded to ``split_managed``.
        allow_single_worker: Forwarded to ``split_managed``.
        sample_rate: Forwarded to ``adjust_annotation_fn`` and
            ``get_tp_fp_tn_fn``.
        threshold_decimals: Decimals the thresholds are rounded to. The
            default 2 keeps the previous behaviour. Note that with
            ``num_thresholds > 101`` the grid step is below 0.01, so
            rounding to 2 decimals makes some neighbouring thresholds equal.
            Pass a larger value to avoid duplicated thresholds.

    Returns:
        On the master, an int array of shape ``(num_thresholds, 4)`` with
        the summed counts in the order returned by ``get_tp_fp_tn_fn``
        (presumably tp, fp, tn, fn). On all other ranks, ``None``.
    """
    tp_fp_tn_fn = np.zeros((num_thresholds, 4), dtype=int)

    # The threshold grid is independent of the example: compute it once
    # instead of once per example and threshold.
    thresholds = np.round(
        np.linspace(0, 1, num_thresholds), threshold_decimals
    )

    import dlp_mpi
    for example in dlp_mpi.split_managed(
        dataset, is_indexable=is_indexable,
        allow_single_worker=allow_single_worker,
    ):
        target = get_target_fn(example)
        adjusted_target = adjust_annotation_fn(
            target, buffer_zone=buffer_zone,
            sample_rate=sample_rate
        )
        model_out = model(example)
        for idx, th in enumerate(thresholds):
            sad = get_sad_fn(model_out, th, example)
            out = get_tp_fp_tn_fn(
                adjusted_target, sad,
                sample_rate=sample_rate, adjust_annotation=False
            )
            tp_fp_tn_fn[idx] = [tp_fp_tn_fn[idx][idy] + o for idy, o in
                                enumerate(out)]

    dlp_mpi.barrier()
    tp_fp_tn_fn_gather = dlp_mpi.gather(tp_fp_tn_fn, root=dlp_mpi.MASTER)
    if dlp_mpi.IS_MASTER:
        tp_fp_tn_fn = np.zeros((num_thresholds, 4), dtype=int)
        for array in tp_fp_tn_fn_gather:
            tp_fp_tn_fn += array
    else:
        tp_fp_tn_fn = None
    return tp_fp_tn_fn
Example #11
0
def speedup():
    """Check that two workers roughly halve the runtime of a sleeping loop.

    Requires exactly three MPI processes: one manager (rank 0) and two
    workers. Each of the 4 examples sleeps 0.1 s, so the serial runtime is
    0.4 s and the expected parallel runtime is about 0.2 s.
    """
    print(f'speedup test {RANK}')
    examples = list(range(4))

    ranks = dlp_mpi.gather(dlp_mpi.RANK)
    if dlp_mpi.IS_MASTER:
        assert ranks == [0, 1, 2], ranks

    sleep_time = 0.1

    start = time.perf_counter()
    for i in dlp_mpi.split_managed(examples, progress_bar=False):
        time.sleep(sleep_time)
        # print(f'Callback from {RANK}')
        assert dlp_mpi.RANK in [1, 2], (dlp_mpi.RANK, dlp_mpi.SIZE)
    elapsed = time.perf_counter() - start

    serial_time = sleep_time * len(examples)

    # Two workers and one manager (3 MPI processes) should halve the
    # runtime. The upper bound (0.6 instead of 0.5) leaves room for Python
    # and MPI overhead.
    assert elapsed < 0.6 * serial_time, (elapsed, serial_time)
    assert elapsed >= 0.5 * serial_time, (elapsed, serial_time)
Example #12
0
def write_wavs(dst_dir: Path, wsj0_root: Path, wsj1_root: Path, sample_rate):
    """Convert the WSJ0/WSJ1 corpora to wav files below ``dst_dir``.

    The master copies the text and index files (``.pl``, ``.ndx``,
    ``.ptx``, ``.dot``, ``.txt``) of every CD directory and lists all
    ``.wv1`` NIST files. That list is broadcast, and the files are
    converted in parallel with ``dlp_mpi.split_managed``. Each file is
    read, resampled from 16 kHz to ``sample_rate``, peak-normalized and
    written as a 32 bit float wav file. Finally, the master checks that one
    wav file was written per NIST file.

    Args:
        dst_dir: Output directory. It must not exist yet (created with
            ``exist_ok=False`` so that nothing is overwritten).
        wsj0_root: Root directory of WSJ0.
        wsj1_root: Root directory of WSJ1.
        sample_rate: Sample rate of the written wav files.
    """
    wsj0_root = Path(wsj0_root).expanduser().resolve()
    wsj1_root = Path(wsj1_root).expanduser().resolve()
    dst_dir = Path(dst_dir).expanduser().resolve()
    assert wsj0_root.exists(), wsj0_root
    assert wsj1_root.exists(), wsj1_root

    assert not dst_dir == wsj0_root, (wsj0_root, dst_dir)
    assert not dst_dir == wsj1_root, (wsj1_root, dst_dir)

    if dlp_mpi.IS_MASTER:
        # Expect that dst_dir does not exist, to make sure not to overwrite.
        dst_dir.mkdir(parents=True, exist_ok=False)

        # CD directories are found with the glob pattern "*-*.*".
        cds_0 = list(wsj0_root.rglob("*-*.*"))
        cds_1 = list(wsj1_root.rglob("*-*.*"))
        cds = set(cds_0 + cds_1)
        for suffix in 'pl ndx ptx dot txt'.split():
            files_0 = list(wsj0_root.rglob(f"*.{suffix}"))
            files_1 = list(wsj1_root.rglob(f"*.{suffix}"))
            files = set(files_0 + files_1)
            # the readme.txt file in the parent directory is not copied
            print(f"About to write ca. {len(files)} {suffix} files.")
            for cd in cds:
                for file in cd.rglob(f"*.{suffix}"):
                    target = dst_dir / file.relative_to(cd.parent)
                    target.parent.mkdir(parents=True, exist_ok=True)
                    if not target.is_file():
                        shutil.copy(file, target.parent)
            # Files outside the CD directories (e.g. the readme.txt above)
            # are not copied, so this count may differ from ``files``.
            written_files = list(dst_dir.rglob(f"*.{suffix}"))
            print(f"Writing {len(written_files)} {suffix} files.")

        # Ignore .wv2 files since they are not referenced in our database
        # anyway
        wsj_nist_files = [(cd, nist_file) for cd in cds
                          for nist_file in cd.rglob("*.wv1")]
        print(f"About to write {len(wsj_nist_files)} wav files.")
    else:
        wsj_nist_files = None

    wsj_nist_files = dlp_mpi.bcast(wsj_nist_files)

    # Normalization target: the allowed values are in the range [-1, 1),
    # so "1" itself is not a valid value.
    correction = (2 ** 15 - 1) / (2 ** 15)

    for cd, nist_file in dlp_mpi.split_managed(wsj_nist_files):
        assert isinstance(nist_file, Path), nist_file
        signal = read_nist_wsj(nist_file, expected_sample_rate=16000)
        file = nist_file.with_suffix('.wav')
        target = dst_dir / file.relative_to(cd.parent)
        assert not target == nist_file, (nist_file, target)
        target.parent.mkdir(parents=True, exist_ok=True)
        signal = resample_with_sox(signal, rate_in=16000, rate_out=sample_rate)
        # Peak normalization. A silent file has a peak of 0; dividing by it
        # would fill the written file with NaN, so it is kept as is.
        peak = np.amax(np.abs(signal))
        if peak > 0:
            signal = signal * (correction / peak)
        with soundfile.SoundFile(
                str(target), samplerate=sample_rate, channels=1,
                subtype='FLOAT', mode='w',
        ) as f:
            f.write(signal.T)

    dlp_mpi.barrier()
    if dlp_mpi.IS_MASTER:
        created_files = set(dst_dir.rglob("*.wav"))
        print(f"Written {len(created_files)} wav files.")
        assert len(wsj_nist_files) == len(created_files), (
            len(wsj_nist_files), len(created_files))
Example #13
0
def main(
    _run,
    out,
    mask_estimator,
    Observation,
    beamformer,
    postfilter,
    normalize_audio=True,
):
    """Enhance every example and report the scores per dataset.

    For each example (distributed with ``dlp_mpi.split_managed``), a mask
    is obtained according to ``mask_estimator`` and ``get_scores`` computes
    the metric and the score. The two selected speech predictions are
    written to ``<out>/<dataset>/<example_id>_<0|1>.wav``. The master
    gathers all scores and writes them to ``<out>/<dataset>_scores.json``.
    It also writes the mean of every score key to
    ``<out>/<dataset>_summary.json``.

    Args:
        _run: Sacred run object; its config is printed on the master.
        out: Output directory.
        mask_estimator: ``None`` (no mask), ``'cacgmm'``, or a name that is
            forwarded to ``get_mask_from_oracle``.
        Observation, beamformer, postfilter: Forwarded to ``get_scores``.
        normalize_audio: Forwarded as ``normalize`` to ``dump_audio``.
    """
    if dlp_mpi.IS_MASTER:
        from sacred.commands import print_config
        print_config(_run)

    ds = get_dataset()

    rows = []

    out = Path(out)

    for ex in dlp_mpi.split_managed(ds.sort(), allow_single_worker=True):
        if mask_estimator is None:
            mask = None
        elif mask_estimator == 'cacgmm':
            mask = get_mask_from_cacgmm(ex)
        else:
            mask = get_mask_from_oracle(ex, mask_estimator)

        metric, score = get_scores(
            ex,
            mask,
            Observation=Observation,
            beamformer=beamformer,
            postfilter=postfilter,
        )

        # Exactly two estimates are expected (the unpacking enforces it).
        est0, est1 = metric.speech_prediction_selection
        for idx, estimate in enumerate((est0, est1)):
            dump_audio(estimate,
                       out / ex['dataset'] / f"{ex['example_id']}_{idx}.wav",
                       normalize=normalize_audio)

        rows.append(dict(
            example_id=ex['example_id'],
            value=score,
            dataset=ex['dataset'],
        ))

    gathered = dlp_mpi.gather(rows)

    if dlp_mpi.IS_MASTER:
        flat = [row for worker_rows in gathered for row in worker_rows]

        # Group the rows by dataset name.
        # NOTE(review): presumably ``from_list(...).groupby`` returns a
        # dict of groups and, unlike itertools.groupby, needs no sorting.
        grouped = {
            dataset: list(subset)
            for dataset, subset in from_list(flat).groupby(
                lambda row: row['dataset']).items()
        }

        for dataset, sub_data in grouped.items():
            print(f'Write details to {out}.')
            dump_json(sub_data, out / f'{dataset}_scores.json')

        for dataset, sub_data in grouped.items():
            means = {}
            for key in sub_data[0]['value'].keys():
                mean = np.mean([row['value'][key] for row in sub_data])
                print(dataset, key, mean)
                means[key] = mean
            dump_json(means, out / f'{dataset}_summary.json')
Example #14
0
def write_wavs(dst_dir, json_path, write_all=False, snr_range=(20, 30)):
    """Simulate the database examples and write the signals as wav files.

    For each of the datasets ``train_si284``, ``cv_dev93`` and
    ``test_eval92``, the examples are loaded from the JSON database, read
    with ``audio_read`` and passed through ``scenario_map_fn``. The signals
    are written in parallel (``dlp_mpi.split_managed``) to
    ``<dst_dir>/<data_type>/<dataset>/``, where ``data_type`` is the
    ``KEY_MAPPER`` value of the signal key.

    Args:
        dst_dir: Output directory (``str`` or ``Path``). The data type sub
            directories must not exist yet.
        json_path: Path of the JSON database.
        write_all: If true, write every signal listed in ``KEY_MAPPER``
            (including early and tail reverberation). Otherwise only the
            ``observation`` signals are written.
        snr_range: Forwarded to ``scenario_map_fn``.
    """
    from pathlib import Path

    db = JsonDatabase(json_path)
    dst_dir = Path(dst_dir)

    if write_all:
        data_types = list(KEY_MAPPER.values())
        map_fn = partial(scenario_map_fn,
                         snr_range=snr_range,
                         sync_speech_source=True,
                         add_speech_reverberation_early=True,
                         add_speech_reverberation_tail=True)
    else:
        # Only the observation is written, so only its directories are
        # created. Previously the directories of every KEY_MAPPER value were
        # created per dataset, which fails because their parents are missing.
        data_types = ['observation']
        map_fn = partial(scenario_map_fn,
                         snr_range=snr_range,
                         sync_speech_source=True,
                         add_speech_reverberation_early=False,
                         add_speech_reverberation_tail=False)

    if dlp_mpi.IS_MASTER:
        for data_type in data_types:
            (dst_dir / data_type).mkdir(exist_ok=False)

    def get_abs_max(a):
        """Return the largest absolute value in a nested signal structure."""
        if isinstance(a, np.ndarray):
            if a.dtype == object:
                # Object array: recurse into its elements.
                # (``np.object`` was removed in NumPy 1.24.)
                return np.max(list(map(get_abs_max, a)))
            return np.max(np.abs(a))
        if isinstance(a, (tuple, list)):
            return np.max(list(map(get_abs_max, a)))
        if isinstance(a, dict):
            return np.max(list(map(get_abs_max, a.values())))
        raise TypeError(a)

    for dataset in ['train_si284', 'cv_dev93', 'test_eval92']:
        if dlp_mpi.IS_MASTER:
            for data_type in data_types:
                (dst_dir / data_type / dataset).mkdir(exist_ok=False)
        ds = db.get_dataset(dataset).map(audio_read).map(map_fn)
        for example in dlp_mpi.split_managed(
                ds,
                is_indexable=True,
                allow_single_worker=True,
                progress_bar=True,
        ):
            audio_dict = example['audio_data']
            example_id = example['example_id']
            if not write_all:
                del audio_dict['speech_reverberation_early']
                del audio_dict['speech_reverberation_tail']
                del audio_dict['noise_image']

            # A completely silent example indicates a simulation problem.
            assert get_abs_max(audio_dict), (example_id, {
                k: get_abs_max(v)
                for k, v in audio_dict.items()
            })
            for key, value in audio_dict.items():
                # Skip signals without a mapping and signals whose data
                # type is not written in this mode.
                if key not in KEY_MAPPER or KEY_MAPPER[key] not in data_types:
                    continue
                path = dst_dir / KEY_MAPPER[key] / dataset
                if key in ['observation', 'noise_image']:
                    # Single signal: add a leading axis so that the loop
                    # below writes exactly one file.
                    value = value[None]
                for idx, signal in enumerate(value):
                    appendix = f'_{idx}' if len(value) > 1 else ''
                    filename = example_id + appendix + '.wav'
                    audio_path = str(path / filename)
                    # NOTE(review): the sample rate is hard-coded to 8 kHz.
                    with soundfile.SoundFile(audio_path,
                                             subtype='FLOAT',
                                             mode='w',
                                             samplerate=8000,
                                             channels=1 if signal.ndim == 1
                                             else signal.shape[0]) as f:
                        f.write(signal.T)

        dlp_mpi.barrier()

    if dlp_mpi.IS_MASTER:
        created_files = check_files(dst_dir)
        print(f"Written {len(created_files)} wav files.")
        if write_all:
            # TODO Less, if you do a test run.
            num_speakers = 2  # todo infer num_speakers from json
            # 2 files for: early, tail, speech_source
            # 1 file for: observation, noise
            expect = (3 * num_speakers + 2) * 35875
            assert len(created_files) == expect, (len(created_files), expect)
        else:
            assert len(created_files) == 35875, len(created_files)
Example #15
0
def main(_run, datasets, debug, experiment_dir, dump_audio, sample_rate, _log,
         database_json):
    """Evaluate the model on ``datasets`` and export metrics and means.

    The examples of each dataset are distributed over the MPI processes with
    ``dlp_mpi.split_managed``. For every example, this computes the input
    metrics (observation vs. source) and the output metrics (estimate vs.
    source), and stores their difference as ``improvement``. The master
    merges the results of all processes, writes ``result.json`` and
    ``means.json`` to ``experiment_dir`` and prints the means.

    Args:
        _run: Sacred run object; its config is printed on the master.
        datasets: Names of the datasets to evaluate.
        debug: Not used in this function.
        experiment_dir: Output directory.
        dump_audio: If true, write the estimated signals to
            ``<experiment_dir>/audio/<dataset>/<example_id>_<k>.wav`` and
            record their paths in the result.
        sample_rate: Sample rate for the metrics and the written audio.
        _log: Logger used for progress and error messages.
        database_json: Path of the JSON database.

    Raises:
        Any exception raised while evaluating an example is logged with the
        example ID and re-raised.
    """
    experiment_dir = Path(experiment_dir)

    if dlp_mpi.IS_MASTER:
        sacred.commands.print_config(_run)
        dump_config_and_makefile()

    model = get_model()
    db = JsonDatabase(database_json)

    model.eval()
    results = defaultdict(dict)
    with torch.no_grad():
        for dataset in datasets:
            iterable = prepare_dataset(
                db,
                dataset,
                1,
                chunk_size=-1,
                prefetch=False,
                shuffle=False,
            )

            if dump_audio:
                (experiment_dir / 'audio' / dataset).mkdir(parents=True,
                                                           exist_ok=True)

            for batch in dlp_mpi.split_managed(iterable,
                                               is_indexable=True,
                                               allow_single_worker=True):
                example_id = batch['example_id'][0]
                results[dataset][example_id] = entry = dict()

                try:
                    model_output = model(model.example_to_device(batch))

                    # Bring to numpy float64 for evaluation metrics computation
                    observation = batch['y'][0].astype(np.float64)[None, ]
                    speech_prediction = (
                        model_output['out'][0].cpu().numpy().astype(
                            np.float64))
                    speech_source = batch['s'][0].astype(np.float64)

                    input_metrics = pb_bss.evaluation.InputMetrics(
                        observation=observation,
                        speech_source=speech_source,
                        sample_rate=sample_rate,
                        enable_si_sdr=True,
                    )

                    output_metrics = pb_bss.evaluation.OutputMetrics(
                        speech_prediction=speech_prediction,
                        speech_source=speech_source,
                        sample_rate=sample_rate,
                        enable_si_sdr=True,
                    )

                    # Select the metrics to compute
                    entry['input'] = dict(
                        mir_eval=input_metrics.mir_eval,
                        si_sdr=input_metrics.si_sdr,
                        # TODO: stoi fails with short speech segments (https://github.com/mpariente/pystoi/issues/21)
                        # stoi=input_metrics.stoi,
                        # TODO: pesq creates "Processing error" messages
                        # pesq=input_metrics.pesq,
                    )

                    # Remove selection from mir_eval dict to enable
                    # recursive calculation of improvement
                    entry['output'] = dict(
                        mir_eval={
                            k: v
                            for k, v in output_metrics.mir_eval.items()
                            if k != 'selection'
                        },
                        si_sdr=output_metrics.si_sdr,
                        # stoi=output_metrics.stoi,
                        # pesq=output_metrics.pesq,
                    )

                    entry['improvement'] = pb.utils.nested.nested_op(
                        operator.sub,
                        entry['output'],
                        entry['input'],
                    )
                    entry['selection'] = output_metrics.mir_eval['selection']

                    if dump_audio:
                        # NOTE(review): this is the same object as
                        # batch['audio_path'], so adding 'estimated' also
                        # modifies the batch.
                        entry['audio_path'] = batch['audio_path']
                        entry['audio_path']['estimated'] = []
                        for k, audio in enumerate(speech_prediction):
                            audio_path = (experiment_dir / 'audio' / dataset /
                                          f'{example_id}_{k}.wav')
                            pb.io.dump_audio(audio,
                                             audio_path,
                                             sample_rate=sample_rate)
                            entry['audio_path']['estimated'].append(audio_path)
                except Exception:
                    # Report which example failed, then propagate the error.
                    _log.error(f'Exception was raised in example with ID '
                               f'"{example_id}"')
                    raise

    results = dlp_mpi.gather(results, root=dlp_mpi.MASTER)

    if dlp_mpi.IS_MASTER:
        # Combine all results to one. This function raises an exception if it
        # finds duplicate keys
        results = pb.utils.nested.nested_merge(*results)

        for dataset, values in results.items():
            _log.info(f'{dataset}: {len(values)}')

        # Write results to JSON
        result_json_path = experiment_dir / 'result.json'
        _log.info(f"Exporting result: {result_json_path}")
        pb.io.dump_json(results, result_json_path)

        # Compute means for some metrics
        means = compute_means(results)
        mean_json_path = experiment_dir / 'means.json'
        _log.info(f'Exporting means: {mean_json_path}')
        pb.io.dump_json(means, mean_json_path)

        _log.info('Resulting means:')

        pprint(means)
Example #16
0
    def enhance_session(
        self,
        session_ids,
        audio_dir,
        dataset_slice=False,
        audio_dir_exist_ok=False,
        is_chime=True,
    ):
        """Enhance all examples of the given sessions and write wav files.

        The examples are distributed with ``dlp_mpi.split_managed``. Only
        ``audio_dir`` itself is created up front (on the master); the
        per-dataset sub directories are left to ``dump_audio(mkdir=True)``.

        Args:
            session_ids: Sessions, forwarded to ``self.get_dataset``.
            audio_dir: Output directory.
            dataset_slice: Restricts the processed examples:
                ``False`` -> all, ``True`` -> the first two,
                an ``int`` n -> the first n, a ``slice`` -> that slice.
                Any other value raises ``ValueError``.
            audio_dir_exist_ok:
                When True: it is ok when the audio dir exists, and the
                files inside may be overwritten.
            is_chime:
                If true, map the session_id to the dataset name for the
                folder naming. Otherwise keep the session_id as folder name.

        Returns:
            None. Each enhanced signal is written to
            ``<audio_dir>/<folder>/<example_id>.wav``.

        Raises:
            ValueError: For an unsupported ``dataset_slice``.
            NotImplementedError: If an enhanced signal is not 1-D.
            Any exception of a failing example is re-raised after printing
            its example ID.

        Example (writes files, therefore not a doctest)::

            enhancer = get_enhancer(wpe=False, bss_iterations=2)
            enhancer.enhance_session('S02', 'enhanced_audio')
        """
        ensure_single_thread_numeric()

        audio_dir = Path(audio_dir)
        examples = self.get_dataset(session_ids)

        if dlp_mpi.IS_MASTER:
            audio_dir.mkdir(exist_ok=audio_dir_exist_ok)

        # No worker may write before the output directory exists.
        dlp_mpi.barrier()

        # ``True`` has to be tested before ``int``: bool is a subclass of int.
        if dataset_slice is False:
            pass
        elif dataset_slice is True:
            examples = examples[:2]
        elif isinstance(dataset_slice, int):
            examples = examples[:dataset_slice]
        elif isinstance(dataset_slice, slice):
            examples = examples[dataset_slice]
        else:
            raise ValueError(dataset_slice)

        for ex in dlp_mpi.split_managed(examples, allow_single_worker=True):
            try:
                x_hat = self.enhance_example(ex)
                example_id = ex["example_id"]
                session_id = ex["session_id"]
                folder = (mapping.session_to_dataset[session_id]
                          if is_chime else session_id)

                if x_hat.ndim != 1:
                    raise NotImplementedError(x_hat.shape)

                dump_audio(
                    x_hat,
                    audio_dir / f'{folder}' / f'{example_id}.wav',
                    mkdir=True,
                )
            except Exception:
                print('ERROR: Failed example:', ex['example_id'])
                raise
Example #17
0
def evaluate(checkpoint_path, eval_dir, database_json):
    """Evaluate a ``SimpleMaskEstimator`` checkpoint on the test set.

    The test examples are distributed over the MPI ranks with
    ``dlp_mpi.split_managed``. For every example three signals are scored
    with ``pb_bss.evaluation.OutputMetrics`` against
    ``batch['speech_source'][0]``:

    * ``masked``: the predicted speech mask applied to ``Y[0]`` (the first
      channel of the observation STFT),
    * ``beamformed``: the output of an MVDR (Souden) beamformer whose PSD
      matrices are estimated from the median speech / noise masks,
    * ``observed``: ``batch['observation'][0]``, i.e. the unprocessed input.

    The master merges the per-rank results, adds a ``<metric>_mean`` entry
    per signal type and writes everything to ``eval_dir / 'result.json'``.

    Args:
        checkpoint_path: Checkpoint file to load (``Path`` or ``str``).
        eval_dir: Directory of the written ``result.json`` (``Path`` or
            ``str``).
        database_json: JSON database passed to ``JsonDatabase``.
    """
    from pathlib import Path

    # Accept plain strings as well; ``.resolve()`` and ``/`` need Paths.
    checkpoint_path = Path(checkpoint_path)
    eval_dir = Path(eval_dir)

    model = SimpleMaskEstimator(513)

    model.load_checkpoint(
        checkpoint_path=checkpoint_path,
        in_checkpoint_path='model',
        consider_mpi=True
    )
    model.eval()
    if dlp_mpi.IS_MASTER:
        print(f'Start to evaluate the checkpoint {checkpoint_path.resolve()} '
              f'and will write the evaluation result to'
              f' {eval_dir / "result.json"}')
    database = JsonDatabase(database_json)
    test_dataset = get_test_dataset(database)
    with torch.no_grad():
        # Order matters: it is zipped with [z_mask, z_bf, y] below.
        summary = dict(masked=dict(), beamformed=dict(), observed=dict())
        for batch in dlp_mpi.split_managed(
                test_dataset, is_indexable=True,
                progress_bar=True,
                allow_single_worker=True
        ):
            model_output = model(pt.data.example_to_device(batch))

            example_id = batch['example_id']
            s = batch['speech_source'][0][None]

            # Mask-based enhancement of the first channel.
            speech_mask = model_output['speech_mask_prediction'].numpy()
            Y = batch['observation_stft']
            Z_mask = speech_mask[0] * Y[0]
            z_mask = pb.transform.istft(Z_mask)[None]

            # Median over the first mask axis (presumably the channels),
            # transposed for the PSD estimation -- TODO confirm mask layout.
            speech_mask = np.median(speech_mask, axis=0).T
            noise_mask = model_output['noise_mask_prediction'].numpy()
            noise_mask = np.median(noise_mask, axis=0).T
            Y = rearrange(Y, 'c t f -> f c t')
            target_psd = pb_bss.extraction.get_power_spectral_density_matrix(
                Y, speech_mask,
            )
            noise_psd = pb_bss.extraction.get_power_spectral_density_matrix(
                Y, noise_mask,
            )
            beamformer = pb_bss.extraction.get_bf_vector(
                'mvdr_souden',
                target_psd_matrix=target_psd,
                noise_psd_matrix=noise_psd
            )
            Z_bf = pb_bss.extraction.apply_beamforming_vector(beamformer, Y).T
            z_bf = pb.transform.istft(Z_bf)[None]

            y = batch['observation'][0][None]
            # The istft output can differ in length from the source; align
            # the reference and every estimate before scoring.
            s = s[:, :z_bf.shape[1]]
            for key, signal in zip(summary.keys(), [z_mask, z_bf, y]):
                signal = signal[:, :s.shape[1]]
                entry = pb_bss.evaluation.OutputMetrics(
                    speech_prediction=signal, speech_source=s,
                    sample_rate=16000
                ).as_dict()
                entry.pop('mir_eval_selection')
                summary[key][example_id] = entry

    summary_list = dlp_mpi.COMM.gather(summary, root=dlp_mpi.MASTER)

    if dlp_mpi.IS_MASTER:
        print(f'\n len(summary_list): {len(summary_list)}')
        summary = dict(masked=dict(), beamformed=dict(), observed=dict())
        for partial_summary in summary_list:
            for signal_type, per_example in partial_summary.items():
                summary[signal_type].update(per_example)
        for signal_type, values in summary.items():
            print(signal_type)
            # Snapshot the per-example entries before the ``*_mean`` keys are
            # added to the same dict, so that the means only average real
            # examples (an example id containing "_mean" is no longer lost).
            entries = list(values.values())
            if not entries:
                # Nothing was evaluated for this signal type.
                continue
            for metric in entries[0].keys():
                mean = np.mean([entry[metric] for entry in entries])
                values[metric + '_mean'] = mean
                print(f'{metric}: {mean}')

        result_json_path = eval_dir / 'result.json'
        print(f"Exporting result: {result_json_path}")
        pb.io.dump_json(summary, result_json_path)
Example #18
0
def main(_run, exp_dir, storage_dir, database_json, test_set, max_examples,
         device):
    """Evaluate a stored model's ``wavenet.infer`` reconstruction with MPI.

    The examples of ``test_set`` (optionally a fixed random subset of
    ``max_examples``) are distributed over the MPI ranks. For each example,
    the squared error between the reconstruction and ``audio_data`` is
    accumulated. The master then:

    * prints the overall RMSE (sqrt of total squared error / total samples),
    * writes the per-example RMSE, sorted ascending, to
      ``storage_dir / 'rmse.json'``,
    * re-synthesizes the 10 best and the 10 worst examples (cut to 10 s by
      ``max_length=10.``) into ``storage_dir / 'audio'``.

    Args:
        exp_dir: Experiment directory with ``config.json`` and the model.
        storage_dir: Output directory; ``storage_dir / 'audio'`` must not
            exist yet.
        database_json: JSON database passed to ``JsonDatabase``.
        test_set: Name of the dataset to evaluate.
        max_examples: If not None, evaluate only this many examples of a
            shuffled (seed 0) ``test_set``.
        device: Torch device used for inference.
    """
    if IS_MASTER:
        commands.print_config(_run)

    exp_dir = Path(exp_dir)
    storage_dir = Path(storage_dir)
    audio_dir = storage_dir / 'audio'
    if IS_MASTER:
        # Only the master writes audio files (see below). Creating the
        # directory on every rank would raise FileExistsError on all but one.
        audio_dir.mkdir(parents=True)

    config = load_json(exp_dir / 'config.json')

    model = Model.from_storage_dir(exp_dir, consider_mpi=True)
    model.to(device)
    model.eval()

    db = JsonDatabase(database_json)
    test_data = db.get_dataset(test_set)
    if max_examples is not None:
        test_data = test_data.shuffle(
            rng=np.random.RandomState(0))[:max_examples]
    test_data = prepare_dataset(test_data,
                                audio_reader=config['audio_reader'],
                                stft=config['stft'],
                                max_length=None,
                                batch_size=1,
                                shuffle=True)
    # List of (example_id, summed squared error, number of samples).
    squared_err = list()
    with torch.no_grad():
        for example in split_managed(test_data,
                                     is_indexable=False,
                                     progress_bar=True,
                                     allow_single_worker=True):
            example = model.example_to_device(example, device)
            target = example['audio_data'].squeeze(1)
            x = model.feature_extraction(example['stft'], example['seq_len'])
            x = model.wavenet.infer(
                x.squeeze(1),
                chunk_length=80_000,
                chunk_overlap=16_000,
            )
            assert target.shape == x.shape, (target.shape, x.shape)
            per_example_err = ((x - target) ** 2).sum(1)
            squared_err.extend(
                (ex_id, err.cpu().detach().numpy(), x.shape[1])
                for ex_id, err in zip(example['example_id'], per_example_err)
            )

    squared_err_list = COMM.gather(squared_err, root=MASTER)

    if IS_MASTER:
        print(f'\nlen(squared_err_list): {len(squared_err_list)}')
        squared_err = [
            item for partial in squared_err_list for item in partial
        ]
        _, err, t = list(zip(*squared_err))
        print('rmse:', np.sqrt(np.sum(err) / np.sum(t)))
        rmse = sorted([(ex_id, np.sqrt(err / t))
                       for ex_id, err, t in squared_err],
                      key=lambda x: x[1])
        dump_json(rmse, storage_dir / 'rmse.json', indent=4, sort_keys=False)
        ex_ids_ordered = [x[0] for x in rmse]
        # The 10 best and the 10 worst examples; built once instead of on
        # every call of the filter function.
        selected_ids = set(ex_ids_ordered[:10] + ex_ids_ordered[-10:])
        # Re-select from the evaluated ``test_set`` (previously hard-coded to
        # 'test_clean'), using the same subset selection as above.
        test_data = db.get_dataset(test_set)
        if max_examples is not None:
            test_data = test_data.shuffle(
                rng=np.random.RandomState(0))[:max_examples]
        test_data = test_data.filter(
            lambda ex: ex['example_id'] in selected_ids, lazy=False)
        test_data = prepare_dataset(test_data,
                                    audio_reader=config['audio_reader'],
                                    stft=config['stft'],
                                    max_length=10.,
                                    batch_size=1,
                                    shuffle=True)
        with torch.no_grad():
            for example in test_data:
                example = model.example_to_device(example, device)
                x = model.feature_extraction(example['stft'],
                                             example['seq_len'])
                x = model.wavenet.infer(
                    x.squeeze(1),
                    chunk_length=80_000,
                    chunk_overlap=16_000,
                )
                for i, audio in enumerate(x.cpu().detach().numpy()):
                    wavfile.write(
                        str(audio_dir / f'{example["example_id"][i]}.wav'),
                        model.sample_rate, audio)
Example #19
0
def main(_run, batch_size, datasets, debug, experiment_dir, database_json,
         _log):
    """Evaluate the mask model on several datasets and export mir_eval scores.

    For every dataset name in ``datasets``, the batches are distributed over
    the MPI ranks with ``dlp_mpi.split_managed``. The predicted masks are
    applied to the STFT ``Y`` and resynthesized with ``istft``. mir_eval
    metrics are computed for the observation (``input``), for the estimate
    (``output``) and for their difference (``improvement``). The master
    gathers the per-rank summaries, checks that every example of every
    dataset was evaluated, and writes ``result.json`` and ``means.json`` to
    ``experiment_dir``.

    Note:
        Only the first example of each batch is evaluated (everything is
        indexed with ``[0]``), so the length check on the master only passes
        for ``batch_size == 1``.
    """
    experiment_dir = Path(experiment_dir)

    if dlp_mpi.IS_MASTER:
        sacred.commands.print_config(_run)

    model = get_model()
    db = JsonDatabase(json_path=database_json)

    model.eval()
    with torch.no_grad():
        # summary[dataset_name][example_id] -> metrics entry
        summary = defaultdict(dict)
        for dataset_name in datasets:
            dataset = prepare_dataset(db,
                                      dataset_name,
                                      batch_size,
                                      return_keys=None,
                                      prefetch=False,
                                      shuffle=False)

            for batch in dlp_mpi.split_managed(dataset,
                                               is_indexable=True,
                                               progress_bar=True,
                                               allow_single_worker=debug):
                entry = dict()
                model_output = model(model.example_to_device(batch))

                example_id = batch['example_id'][0]
                s = batch['s'][0]
                Y = batch['Y'][0]
                mask = model_output[0].numpy()

                # One masked STFT per mask index k, reordered to k x t x f
                # before the inverse STFT.
                Z = mask * Y[:, None, :]
                z = istft(einops.rearrange(Z, "t k f -> k t f"),
                          size=512,
                          shift=128)

                # Align lengths of reference and estimate.
                s = s[:, :z.shape[1]]
                z = z[:, :s.shape[1]]

                input_metrics = pb_bss.evaluation.InputMetrics(
                    observation=batch['y'][0][None, :],
                    speech_source=s,
                    sample_rate=8000,
                    enable_si_sdr=False,
                )

                output_metrics = pb_bss.evaluation.OutputMetrics(
                    speech_prediction=z,
                    speech_source=s,
                    sample_rate=8000,
                    enable_si_sdr=False,
                )
                entry['input'] = dict(mir_eval=input_metrics.mir_eval, )
                # 'selection' is not a metric, so it is excluded from the
                # subtraction below and stored separately.
                entry['output'] = dict(mir_eval={
                    k: v
                    for k, v in output_metrics.mir_eval.items()
                    if k != 'selection'
                }, )

                entry['improvement'] = pb.utils.nested.nested_op(
                    operator.sub,
                    entry['output'],
                    entry['input'],
                )
                entry['selection'] = output_metrics.mir_eval['selection']

                # Key by the dataset *name*: ``dataset`` is the prepared
                # dataset object, which is neither JSON-serializable nor a
                # valid argument for ``db.get_dataset`` on the master.
                summary[dataset_name][example_id] = entry

    summary_list = dlp_mpi.gather(summary, root=dlp_mpi.MASTER)

    if dlp_mpi.IS_MASTER:
        _log.info(f'len(summary_list): {len(summary_list)}')
        summary = pb.utils.nested.nested_merge(*summary_list)

        for dataset, values in summary.items():
            _log.info(f'{dataset}: {len(values)}')
            assert len(values) == len(
                db.get_dataset(dataset)
            ), 'Number of results needs to match length of dataset!'
        result_json_path = experiment_dir / 'result.json'
        _log.info(f"Exporting result: {result_json_path}")
        pb.io.dump_json(summary, result_json_path)

        # Compute and save mean of metrics
        means = compute_means(summary)
        mean_json_path = experiment_dir / 'means.json'
        _log.info(f"Saving means to: {mean_json_path}")
        pb.io.dump_json(means, mean_json_path)
Example #20
0
def main(_run, model_path, load_ckpt, batch_size, device, store_misclassified):
    """Evaluate a stored speaker classifier on its test split.

    The test batches are distributed over the MPI ranks. Each rank records
    per-example hits, the misclassified examples (true/predicted label,
    audio path, confidence), and for every correctly predicted label the
    audio paths of those examples. The master merges the partial results,
    optionally creates symlinks for every misclassification (the example
    and one correctly classified example of the predicted speaker), and
    writes the accuracy and misclassifications to
    ``<model_path>/eval/<time>/results.json``.

    Args:
        model_path: Storage dir of the trained model (contains
            ``config.json``).
        load_ckpt: Checkpoint name passed to ``pt.Model.from_storage_dir``.
        batch_size: Batch size of the test set.
        device: Torch device; on ``'cpu'`` the dataset is used indexable.
        store_misclassified: If True, write the symlink folders for the
            misclassified examples.
    """
    if IS_MASTER:
        commands.print_config(_run)

    model_path = Path(model_path)
    eval_dir = get_new_subdir(model_path / 'eval',
                              id_naming='time',
                              consider_mpi=True)
    # perform evaluation on a sub-set (10%) of the dataset used for training
    config = load_json(model_path / 'config.json')
    database_json = config['database_json']
    dataset = config['dataset']

    model = pt.Model.from_storage_dir(model_path,
                                      checkpoint_name=load_ckpt,
                                      consider_mpi=True)
    model.to(device)
    # Turn on evaluation mode for, e.g., BatchNorm and Dropout modules
    model.eval()

    _, _, test_set = get_datasets(model_path,
                                  database_json,
                                  dataset,
                                  batch_size,
                                  return_indexable=device == 'cpu')
    with torch.no_grad():
        summary = dict(misclassified_examples=dict(),
                       correct_classified_examples=dict(),
                       hits=list())
        for batch in split_managed(test_set,
                                   is_indexable=device == 'cpu',
                                   progress_bar=True,
                                   allow_single_worker=True):
            output = model(pt.data.example_to_device(batch, device))
            prediction = torch.argmax(output, dim=-1).cpu().numpy()
            confidence = torch.softmax(output, dim=-1).max(dim=-1).values.cpu()\
                .numpy()
            label = np.array(batch['speaker_id'])
            example_ids = np.array(batch['example_id'])
            audio_paths = np.array(batch['audio_path'])
            hits = (label == prediction).astype('bool')
            summary['hits'].extend(hits.tolist())
            summary['misclassified_examples'].update({
                example_id: {
                    'true_label': true_label,
                    'predicted_label': predicted_label,
                    'audio_path': audio_path,
                    'confidence': f'{conf_value:.2%}',
                }
                for example_id, true_label, predicted_label, audio_path,
                conf_value in zip(
                    example_ids[~hits], label[~hits], prediction[~hits],
                    audio_paths[~hits], confidence[~hits])
            })
            # For each correctly predicted label, collect the audio paths.
            # A plain loop is required: several examples of one batch can
            # share a label, and a dict comprehension would keep only the
            # last of them.
            correct_classified = summary['correct_classified_examples']
            for predicted_label, audio_path in zip(prediction[hits],
                                                   audio_paths[hits]):
                correct_classified.setdefault(predicted_label, []).append(
                    audio_path)

    summary_list = COMM.gather(summary, root=MASTER)

    if IS_MASTER:
        print(f'\nlen(summary_list): {len(summary_list)}')
        if len(summary_list) > 1:
            summary = dict(
                misclassified_examples=dict(),
                correct_classified_examples=dict(),
                hits=list(),
            )
            merged_correct = summary['correct_classified_examples']
            for partial_summary in summary_list:
                summary['hits'].extend(partial_summary['hits'])
                summary['misclassified_examples'].update(
                    partial_summary['misclassified_examples'])
                for label, audio_path_list in \
                        partial_summary['correct_classified_examples'].items():
                    merged_correct.setdefault(label, []).extend(
                        audio_path_list)
        hits = summary['hits']
        misclassified_examples = summary['misclassified_examples']
        correct_classified_examples = summary['correct_classified_examples']
        accuracy = np.array(hits).astype('float').mean()
        if store_misclassified:
            misclassified_dir = eval_dir / 'misclassified_examples'
            for example_id, info in misclassified_examples.items():
                label = info['true_label']
                prediction_label = info['predicted_label']
                audio_path = info['audio_path']
                try:
                    predicted_speaker_audio_path = \
                        correct_classified_examples[prediction_label][0]
                except KeyError:
                    warnings.warn(
                        'There were no correctly predicted inputs from speaker '
                        f'with speaker label {prediction_label}')
                    continue
                example_dir = \
                    misclassified_dir / f'{example_id}_{label}_{prediction_label}'
                example_dir.mkdir(parents=True)
                os.symlink(audio_path, example_dir / 'example.wav')
                os.symlink(predicted_speaker_audio_path,
                           example_dir / 'predicted_speaker_example.wav')
        outputs = dict(
            accuracy=f'{accuracy:.2%} ({np.sum(hits)}/{len(hits)})',
            misclassifications=misclassified_examples,
        )
        print(f'Speaker classification accuracy on test set: {accuracy:.2%}')
        print(f'Wrote results to {eval_dir / "results.json"}')
        dump_json(outputs, eval_dir / 'results.json')
Example #21
0
def write_wavs(dst_dir: Path, wsj0_root: Path, wsj1_root: Path, sample_rate):
    """Copy the WSJ0/WSJ1 metadata files and convert the NIST audio to wav.

    The master rank creates ``dst_dir`` (it must not exist yet) and copies
    all ``pl``/``ndx``/``ptx``/``dot``/``txt`` files that lie below CD
    folders (folders matching ``*-*.*``). It warns when the number of copied
    files differs from the expected count. The ``.wv1`` files are then
    distributed over the MPI ranks, resampled from 16 kHz to
    ``sample_rate``, made zero-mean, peak-normalized and written as float
    wav files with the same layout relative to the CD folder's parent.

    Raises:
        RuntimeError: If only one ``pl`` file is written, which typically
            means that only WSJ0 (and not WSJ1) is available.
    """
    wsj0_root = Path(wsj0_root).expanduser().resolve()
    wsj1_root = Path(wsj1_root).expanduser().resolve()
    dst_dir = Path(dst_dir).expanduser().resolve()
    assert wsj0_root.exists(), wsj0_root
    assert wsj1_root.exists(), wsj1_root

    assert not dst_dir == wsj0_root, (wsj0_root, dst_dir)
    assert not dst_dir == wsj1_root, (wsj1_root, dst_dir)

    if dlp_mpi.IS_MASTER:
        # Expect that dst_dir does not exist, so nothing is overwritten.
        dst_dir.mkdir(parents=True, exist_ok=False)

        # Search for CD numbers, e.g. "13-34.1" (CD stands for compact disk).
        # ``cds`` only exists on the master; the workers receive the file
        # list via ``bcast`` below.
        cds_0 = list(wsj0_root.rglob("*-*.*"))
        cds_1 = list(wsj1_root.rglob("*-*.*"))
        cds = set(cds_0 + cds_1)

        expected_number_of_files = {
            'pl': 3,
            'ndx': 106,
            'ptx': 3547,
            'dot': 3585,
            'txt': 256
        }
        number_of_written_files = dict()
        for suffix in expected_number_of_files.keys():
            files_0 = list(wsj0_root.rglob(f"*.{suffix}"))
            files_1 = list(wsj1_root.rglob(f"*.{suffix}"))
            files = set(files_0 + files_1)
            # Filter files that do not have a folder that matches "*-*.*".
            files = {
                file
                for file in files
                if any(fnmatch.fnmatch(part, "*-*.*") for part in file.parts)
            }

            # the readme.txt file in the parent directory is not copied
            print(f"About to write ca. {len(files)} {suffix} files.")
            for cd in cds:
                for file in cd.rglob(f"*.{suffix}"):
                    target = dst_dir / file.relative_to(cd.parent)
                    target.parent.mkdir(parents=True, exist_ok=True)
                    if not target.is_file():
                        shutil.copy(file, target.parent)
            number_of_written_files[suffix] = len(
                list(dst_dir.rglob(f"*.{suffix}")))
            print(f"Writing {number_of_written_files[suffix]} {suffix} files.")
            print(
                f'Expected {expected_number_of_files[suffix]} {suffix} files.')

        for suffix, expected in expected_number_of_files.items():
            written = number_of_written_files[suffix]
            if written != expected:
                # Report the count for this suffix, not the whole dict.
                warnings.warn(
                    f'Expected that {expected} files with the suffix '
                    f'{suffix} are written. But {written} were written.')

            if suffix == 'pl' and written == 1:
                raise RuntimeError(
                    'Found only one pl file although we expected three. '
                    'A typical reason is having only WSJ0. '
                    'Please make sure you have WSJ0+1 = WSJ COMPLETE.')

        # Ignore .wv2 files since they are not referenced in our database
        # anyway
        wsj_nist_files = [(cd, nist_file) for cd in cds
                          for nist_file in cd.rglob("*.wv1")]
        print(f"About to write {len(wsj_nist_files)} wav files.")
    else:
        wsj_nist_files = None

    wsj_nist_files = dlp_mpi.bcast(wsj_nist_files)

    for nist_file_tuple in dlp_mpi.split_managed(wsj_nist_files):
        cd, nist_file = nist_file_tuple
        assert isinstance(nist_file, Path), nist_file
        signal = read_nist_wsj(nist_file, expected_sample_rate=16000)
        file = nist_file.with_suffix('.wav')
        target = dst_dir / file.relative_to(cd.parent)
        assert not target == nist_file, (nist_file, target)
        target.parent.mkdir(parents=True, exist_ok=True)
        signal = resample_with_sox(signal, rate_in=16000, rate_out=sample_rate)
        # normalization to mean 0:
        signal = signal - np.mean(signal)
        # normalization:
        #   Correction, because the allowed values are in the range [-1, 1).
        #       => "1" is not a valid value
        correction = (2**15 - 1) / (2**15)
        signal = signal * (correction / np.amax(np.abs(signal)))
        with soundfile.SoundFile(
                str(target),
                samplerate=sample_rate,
                channels=1,
                subtype='FLOAT',
                mode='w',
        ) as f:
            f.write(signal.T)

    dlp_mpi.barrier()
    if dlp_mpi.IS_MASTER:
        created_files = list(dst_dir.rglob("*.wav"))
        print(f"Written {len(created_files)} wav files.")
        assert len(wsj_nist_files) == len(created_files), (len(wsj_nist_files),
                                                           len(created_files))
Example #22
0
def write_wavs(dst_dir, db, write_all=False, snr_range=(20, 30)):
    """Simulate the reverberant scenarios and write them as float wav files.

    For each of ``train_si284``, ``cv_dev93`` and ``test_eval92``, the
    examples are read, passed through ``scenario_map_fn`` and written in
    parallel over the MPI ranks (8 kHz, float). The output is written to
    ``dst_dir / <data_type> / <dataset> / <example_id>[_<idx>].wav``.

    Args:
        dst_dir: Output directory (Path); the per-type subdirectories must
            not exist yet.
        db: Database providing ``get_dataset``.
        write_all: If True, write every audio type listed in
            ``type_mapper`` (including reverberation early/tail and noise).
            Otherwise only the observation is written.
        snr_range: Forwarded to ``scenario_map_fn``.
    """
    # Directories that are created and written to. Previously, with
    # write_all=False only 'observation' was created at the top level, but
    # the per-dataset subdirectories were created for all type_mapper
    # values, which failed with FileNotFoundError (no parents=True).
    if write_all:
        data_types = list(type_mapper.values())
    else:
        # NOTE(review): assumes type_mapper['observation'] == 'observation'.
        data_types = ['observation']

    if dlp_mpi.IS_MASTER:
        for data_type in data_types:
            (dst_dir / data_type).mkdir(exist_ok=False)
    map_fn = partial(
        scenario_map_fn,
        snr_range=snr_range,
        sync_speech_source=True,
        add_speech_reverberation_early=bool(write_all),
        add_speech_reverberation_tail=bool(write_all),
    )
    for dataset in ['train_si284', 'cv_dev93', 'test_eval92']:
        if dlp_mpi.IS_MASTER:
            for data_type in data_types:
                (dst_dir / data_type / dataset).mkdir(exist_ok=False)
        ds = db.get_dataset(dataset).map(audio_read).map(map_fn)
        for example in dlp_mpi.split_managed(
                ds, is_indexable=True,
                allow_single_worker=True, progress_bar=True):
            audio_dict = example['audio_data']
            example_id = example['example_id']
            if not write_all:
                del audio_dict['speech_reverberation_early']
                del audio_dict['speech_reverberation_tail']
                del audio_dict['noise_image']
            assert all([np.max(np.abs(v)) <= 1 for v in audio_dict.values()]), (
                example_id, [np.max(np.abs(v)) for v in audio_dict.values()])
            for key, value in audio_dict.items():
                # Only write signals whose output directory was created.
                if key not in type_mapper or type_mapper[key] not in data_types:
                    continue
                path = dst_dir / type_mapper[key] / dataset
                # These signals are single arrays; add a leading axis so they
                # are written as one file like the stacked signals.
                if key in ['observation', 'noise_image']:
                    value = value[None]
                for idx, signal in enumerate(value):
                    appendix = f'_{idx}' if len(value) > 1 else ''
                    filename = example_id + appendix + '.wav'
                    audio_path = str(path / filename)
                    with soundfile.SoundFile(
                            audio_path, subtype='FLOAT', mode='w',
                            samplerate=8000,
                            channels=1 if signal.ndim == 1 else signal.shape[0]
                    ) as f:
                        f.write(signal.T)

        dlp_mpi.barrier()

    if dlp_mpi.IS_MASTER:
        created_files = check_files(dst_dir)
        print(f"Written {len(created_files)} wav files.")
        if write_all:
            expect = (2 * 2 + 2) * 35875
            assert len(created_files) == expect, (
                len(created_files), expect
            )
        else:
            assert len(created_files) == 35875, len(created_files)