コード例 #1
0
ファイル: split_managed.py プロジェクト: fgnt/dlp_mpi
def worker_fails():
    """Test error propagation when a worker raises inside ``split_managed``.

    Intended for exactly 3 MPI processes: rank 0 is the master, ranks 1
    and 2 are workers.  Rank 1 raises a ValueError on its first example;
    rank 2 keeps working and has to receive all remaining examples.

    NOTE(review): relies on a module-level ``RANK`` — presumably an alias
    of ``dlp_mpi.RANK``; confirm against this file's imports (not visible
    in this chunk).
    """
    print(f'worker_fails test {RANK}')

    examples = list(range(5))

    # Sanity check: the assertions below assume SIZE == 3.
    ranks = dlp_mpi.gather(dlp_mpi.RANK)
    if dlp_mpi.IS_MASTER:
        assert ranks == [0, 1, 2], ranks

    processed = []
    try:
        dlp_mpi.barrier()
        if RANK == 2:
            # Delay rank 2, this ensures that rank 1 gets the first example.
            # NOTE: no longer reliable, because split_managed uses
            # COMM.Clone() — hence both orders are accepted below.
            time.sleep(0.1)
        for i in dlp_mpi.split_managed(examples, progress_bar=False):
            processed.append(i)
            if RANK == 1:
                print(f'let {RANK} fail for data {i}')
                raise ValueError('failed')
            # Only the workers execute the loop body, never the master.
            assert dlp_mpi.RANK in [1, 2], (dlp_mpi.RANK, dlp_mpi.SIZE)
    except ValueError:
        # Rank 1 raised on its first example (either example 0 or 1).
        assert RANK in [1], RANK
        assert processed in [[0], [1]], processed
    except AssertionError:
        # The master never runs the loop body (processed stays empty);
        # presumably split_managed raises an AssertionError on the master
        # when a worker failed — confirm in dlp_mpi.
        assert RANK in [0], RANK
        assert processed == [], processed
    else:
        # Rank 2 finishes without error and gets all remaining examples.
        assert RANK in [2], RANK
        assert processed == [1, 2, 3, 4], processed
コード例 #2
0
ファイル: core_chime6.py プロジェクト: yh646492956/pb_chime5
    def enhance_session(self,
                        session_ids,
                        audio_dir,
                        dataset_slice=False,
                        audio_dir_exist_ok=False):
        """Enhance the given sessions and write the results as wav files.

        Args:
            session_ids: Forwarded to ``self.get_iterator`` to select the
                examples to enhance.
            audio_dir: Output directory; one sub-folder per dataset is
                created.
            dataset_slice: Restrict the dataset, e.g. for a test run:
                False (use all examples), True (first two examples),
                an int (first n examples) or a slice object.
            audio_dir_exist_ok:
                When True: It is ok, when the audio dir exists and the files
                inside may be overwritten.

        Returns:
            None.  The enhanced signals are written to
            ``<audio_dir>/<dataset>/<example_id>.wav``.

        >>> enhancer = get_enhancer(wpe=False, bss_iterations=2)
        >>> for x_hat in enhancer.enhance_session('S02'):
        ...     print(x_hat)
        """
        # Avoid thread oversubscription when many MPI processes run on one
        # node (presumably limits BLAS/numpy thread counts — see helper).
        ensure_single_thread_numeric()

        audio_dir = Path(audio_dir)

        it = self.get_iterator(session_ids)

        # Only the master creates the output folders ...
        if dlp_mpi.IS_MASTER:
            audio_dir.mkdir(exist_ok=audio_dir_exist_ok)

            for dataset in set(mapping.session_to_dataset.values()):
                (audio_dir / dataset).mkdir(exist_ok=audio_dir_exist_ok)

        # ... and all workers wait here until the folders exist.
        dlp_mpi.barrier()

        if dataset_slice is not False:
            if dataset_slice is True:
                it = it[:2]
            elif isinstance(dataset_slice, int):
                it = it[:dataset_slice]
            elif isinstance(dataset_slice, slice):
                it = it[dataset_slice]
            else:
                raise ValueError(dataset_slice)

        # Distribute the examples dynamically over the MPI processes.
        for ex in dlp_mpi.split_managed(it, allow_single_worker=True):
            x_hat = self.enhance_example(ex)
            example_id = ex["example_id"]
            session_id = ex["session_id"]
            dataset = mapping.session_to_dataset[session_id]

            if x_hat.ndim == 1:
                save_path = audio_dir / f'{dataset}' / f'{example_id}.wav'
                dump_audio(
                    x_hat,
                    save_path,
                )
            else:
                # Multi-channel output is not supported here.
                raise NotImplementedError(x_hat.shape)
コード例 #3
0
def rirs(
    database_path,
    datasets,
    sample_rate,
    filter_length,
):
    """Generate the room impulse responses (RIRs) for all scenarios.

    The master process reads ``<database_path>/scenarios.json`` and
    creates the dataset folders; the scenario descriptions are then
    broadcast to every MPI process and the examples are distributed with
    ``split_managed``.  For each example one wav file per source
    (``h_0.wav`` ... ``h_{K-1}.wav``, each with D channels) is written to
    ``<database_path>/<dataset_name>/<example_id>/``.

    Args:
        database_path: Root folder of the database to create.
        datasets: Names of the dataset folders to create.
        sample_rate: Sample rate in Hz of the generated RIRs.
        filter_length: Length of the RIRs in samples.
    """
    database_path = Path(database_path)

    if dlp_mpi.IS_MASTER:
        scenario_json = database_path / "scenarios.json"
        with scenario_json.open() as f:
            database = json.load(f)
        for dataset in datasets:
            dataset_path = database_path / dataset
            dataset_path.mkdir(parents=True, exist_ok=True)
    else:
        database = None
    # Every process needs the scenario descriptions.
    database = dlp_mpi.bcast(database)

    for dataset_name, dataset in database['datasets'].items():
        print(f'RANK={dlp_mpi.RANK}, SIZE={dlp_mpi.SIZE}:'
              f' Starting {dataset_name}.')

        for _example_id, example in dlp_mpi.split_managed(
                list(sorted(dataset.items())),
                progress_bar=True,
                is_indexable=True,
        ):
            h = generate_rir(room_dimensions=example['room_dimensions'],
                             source_positions=example['source_position'],
                             sensor_positions=example['sensor_position'],
                             sound_decay_time=example['sound_decay_time'],
                             sample_rate=sample_rate,
                             filter_length=filter_length,
                             sensor_orientations=None,
                             sensor_directivity=None,
                             sound_velocity=343)
            assert not np.any(np.isnan(
                h)), f"{np.sum(np.isnan(h))} values of {h.size} are NaN."

            # h has shape (K sources, D sensors, T filter taps); the
            # channels argument below uses D (= h.shape[1]).
            K, D, T = h.shape
            directory = database_path / dataset_name / _example_id
            # exist_ok=False: refuse to overwrite previously generated RIRs.
            directory.mkdir(parents=False, exist_ok=False)

            for k in range(K):
                # Although storing as np.float64 does not allow every reader
                # to access the files, it does not require normalization and
                # we are unsure how much precision is needed for RIRs.
                with soundfile.SoundFile(str(directory / f"h_{k}.wav"),
                                         subtype='DOUBLE',
                                         samplerate=sample_rate,
                                         mode='w',
                                         channels=h.shape[1]) as f:
                    f.write(h[k, :, :].T)

        # Let all processes finish a dataset before starting the next one.
        dlp_mpi.barrier()

        print(f'RANK={dlp_mpi.RANK}, SIZE={dlp_mpi.SIZE}:'
              f' Finished {dataset_name}.')
コード例 #4
0
def pbar():
    """Test the progress-bar bookkeeping of ``map_unordered``.

    ``tqdm.tqdm`` is replaced by a mock that records every call, so the
    exact sequence of ``set_description``/``update`` calls can be
    asserted.  Intended for exactly 3 MPI processes; only the master
    (rank 0) drives the progress bar.

    NOTE(review): relies on a module-level ``RANK`` — presumably an alias
    of ``dlp_mpi.RANK``.
    """
    print(f'executable test {RANK}')

    examples = list(range(5))

    # Sanity check: the assertions below assume SIZE == 3.
    ranks = dlp_mpi.gather(dlp_mpi.RANK)
    if dlp_mpi.IS_MASTER:
        assert ranks == [0, 1, 2], ranks

    def bar(i):
        # Work function: runs only on the workers.
        time.sleep(0.04)
        assert dlp_mpi.RANK in [1, 2], (dlp_mpi.RANK, dlp_mpi.SIZE)

    class MockPbar:
        # Class attribute: shared by all instances, so the recorded
        # history survives the context manager below.
        call_history = []

        def __init__(self):
            self.i = 0

        def set_description(self, text):
            self.call_history.append(text)

        def update(self, inc=1):
            self.i += 1
            self.call_history.append(f'update {self.i}')

    import contextlib

    @contextlib.contextmanager
    def mock_pbar(total, disable):
        # Mirrors the tqdm call signature used inside dlp_mpi; the
        # progress bar must not be disabled here.
        assert disable is False, disable
        yield MockPbar()

    import mock

    with mock.patch('tqdm.tqdm', mock_pbar):

        dlp_mpi.barrier()
        if RANK == 2:
            # Small delay to make the work distribution deterministic.
            time.sleep(0.02)

        for i in dlp_mpi.map_unordered(
                bar,
                examples,
                progress_bar=True,
        ):
            # Only the master consumes the results.
            assert dlp_mpi.RANK in [0], (dlp_mpi.RANK, dlp_mpi.SIZE)

    if RANK == 0:
        # 'busy: N' presumably reports the number of still-busy workers;
        # it drops to 0 once all examples are processed.
        assert MockPbar.call_history == [
            'busy: 2', 'update 1', 'update 2', 'update 3', 'update 4',
            'busy: 1', 'update 5', 'busy: 0'
        ], MockPbar.call_history
    else:
        assert MockPbar.call_history == [], MockPbar.call_history
コード例 #5
0
ファイル: split_managed.py プロジェクト: fgnt/dlp_mpi
def cross_communication():
    """Test the dynamic load balancing of ``split_managed``.

    Intended for exactly 3 MPI processes.  Worker speed is skewed with
    sleeps: the slow worker should receive only a single example while
    the fast worker processes the rest.

    NOTE(review): relies on a module-level ``RANK`` — presumably an alias
    of ``dlp_mpi.RANK``.
    """
    print(f'cross_communication test {RANK}')

    examples = list(range(5))

    # Sanity check: the assertions below assume SIZE == 3.
    ranks = dlp_mpi.gather(dlp_mpi.RANK)
    if dlp_mpi.IS_MASTER:
        assert ranks == [0, 1, 2], ranks

    dlp_mpi.barrier()
    if RANK == 1:
        # Delay rank 1 so rank 2 tends to pick up the first example.
        time.sleep(0.1)
    elif RANK == 2:
        pass

    results = []
    for i in dlp_mpi.split_managed(examples, progress_bar=False):
        # Only the workers execute the loop body, never the master.
        assert dlp_mpi.RANK in [1, 2], (dlp_mpi.RANK, dlp_mpi.SIZE)

        if RANK == 1:
            results.append(i)
        elif RANK == 2:
            # Rank 2 is slow, so it should get only one example.
            results.append(i)
            time.sleep(0.2)

    # Which of example 0/1 each worker got first is timing dependent,
    # therefore both orders are accepted.
    if RANK == 1:
        assert results in [[0, 2, 3, 4], [1, 2, 3, 4]], results
    elif RANK == 2:
        assert results in [[1], [0]], results

    # Second round without the initial delay; results accumulate into the
    # same list.
    for i in dlp_mpi.split_managed(examples, progress_bar=False):
        assert dlp_mpi.RANK in [1, 2], (dlp_mpi.RANK, dlp_mpi.SIZE)
        if RANK == 1:
            results.append(i)
            time.sleep(0.001)
        elif RANK == 2:
            results.append(i)
            time.sleep(0.2)

    if RANK == 1:
        assert results in [
            [0, 2, 3, 4, 0, 2, 3, 4],
            [0, 2, 3, 4, 1, 2, 3, 4],
            [1, 2, 3, 4, 0, 2, 3, 4],
            [1, 2, 3, 4, 1, 2, 3, 4],
        ], results
    elif RANK == 2:
        assert results in [
            [1, 1],
            [1, 0],
            [0, 1],
            [0, 0],
        ], results
コード例 #6
0
def evaluate_model(dataset, model, get_sad_fn,
                   get_target_fn=lambda x: x['activation'],
                   num_thresholds=201, buffer_zone=0.5,
                   is_indexable=True, allow_single_worker=True,
                   sample_rate=8000):
    """Accumulate TP/FP/TN/FN counts per detection threshold over a dataset.

    The examples are distributed over the MPI workers with
    ``dlp_mpi.split_managed``; afterwards the per-worker count matrices
    are gathered and summed on the master.

    Args:
        dataset: Iterable/indexable of examples.
        model: Callable ``example -> model output`` (SAD scores).
        get_sad_fn: Callable ``(model_out, threshold, example) -> sad``.
        get_target_fn: Extracts the target activity from an example.
        num_thresholds: Number of thresholds in ``linspace(0, 1, ...)``.
        buffer_zone: Seconds passed to ``adjust_annotation_fn``.
        is_indexable / allow_single_worker: Forwarded to ``split_managed``.
        sample_rate: Sample rate of the activity signals.

    Returns:
        On the master: int array of shape ``(num_thresholds, 4)`` with the
        summed counts per threshold.  On every other process: None.
    """
    tp_fp_tn_fn = np.zeros((num_thresholds, 4), dtype=int)

    import dlp_mpi

    thresholds = np.linspace(0, 1, num_thresholds)
    for example in dlp_mpi.split_managed(
        dataset, is_indexable=is_indexable,
        allow_single_worker=allow_single_worker,
    ):
        # Widen the annotation by the buffer zone before scoring.
        adjusted_target = adjust_annotation_fn(
            get_target_fn(example),
            buffer_zone=buffer_zone,
            sample_rate=sample_rate,
        )
        model_out = model(example)
        for idx, threshold in enumerate(thresholds):
            threshold = np.round(threshold, 2)
            sad = get_sad_fn(model_out, threshold, example)
            counts = get_tp_fp_tn_fn(
                adjusted_target, sad,
                sample_rate=sample_rate, adjust_annotation=False,
            )
            # Element-wise accumulation into this threshold's row.
            for idy, count in enumerate(counts):
                tp_fp_tn_fn[idx][idy] = tp_fp_tn_fn[idx][idy] + count

    dlp_mpi.barrier()
    per_worker_counts = dlp_mpi.gather(tp_fp_tn_fn, root=dlp_mpi.MASTER)
    if not dlp_mpi.IS_MASTER:
        # Workers return nothing; only the master holds the total.
        return None
    total = np.zeros((num_thresholds, 4), dtype=int)
    for worker_counts in per_worker_counts:
        total += worker_counts
    return total
コード例 #7
0
def worker_fails():
    """Test ``map_unordered`` when the work function raises on a worker.

    Intended for exactly 3 MPI processes: rank 0 is the master, ranks 1
    and 2 are workers.  Rank 1 fails on its first example; rank 2 and the
    master continue with the remaining examples.

    NOTE(review): relies on a module-level ``RANK`` — presumably an alias
    of ``dlp_mpi.RANK``.
    """
    print(f'executable test {RANK}')

    examples = list(range(5))

    # Sanity check: the assertions below assume SIZE == 3.
    ranks = dlp_mpi.gather(dlp_mpi.RANK)
    if dlp_mpi.IS_MASTER:
        assert ranks == [0, 1, 2], ranks

    def bar(i):
        # Work function: raises on rank 1, succeeds on rank 2.
        if RANK == 1:
            print(f'let {RANK} fail for data {i}')
            raise ValueError('failed')
        assert dlp_mpi.RANK in [1, 2], (dlp_mpi.RANK, dlp_mpi.SIZE)
        return i, RANK

    processed = []
    try:
        dlp_mpi.barrier()
        if RANK == 2:
            # Delay rank 2, this ensures that rank 1 gets the first example
            time.sleep(0.1)
        for i, worker_rank in dlp_mpi.map_unordered(bar, examples):
            print(
                f'Loop body from {RANK} for data {i} that was processed by {worker_rank}'
            )
            # Only the master consumes the results.
            assert dlp_mpi.RANK in [0], (dlp_mpi.RANK, dlp_mpi.SIZE)
            processed.append(i)
    except ValueError:
        # Rank 1 is the process whose work function raised.
        assert RANK in [1], RANK
    except AssertionError:
        assert RANK in [0], RANK
        # Example zero failed for worker 1, but the master process fails at the
        # end of the for loop. So examples 1 to 4 are processed
        assert processed == [1, 2, 3, 4], processed
    else:
        # Rank 2 finishes without error.
        assert RANK in [2], RANK
コード例 #8
0
ファイル: write_wav.py プロジェクト: theLittleTiger/sms_wsj
def write_wavs(dst_dir: Path, wsj0_root: Path, wsj1_root: Path, sample_rate):
    """Copy the WSJ annotation files and convert the nist files to wav.

    The master process copies the ``pl``/``ndx``/``ptx``/``dot``/``txt``
    files from the WSJ0 and WSJ1 CD folders into ``dst_dir`` and checks
    the file counts.  The list of ``*.wv1`` nist files is then broadcast
    and the wav conversion (resampling to ``sample_rate`` plus mean and
    peak normalization) is distributed over all MPI processes.

    Args:
        dst_dir: Target directory; must not exist yet (no overwriting).
        wsj0_root: Root folder of the WSJ0 part of WSJ COMPLETE.
        wsj1_root: Root folder of the WSJ1 part of WSJ COMPLETE.
        sample_rate: Sample rate of the written wav files.

    Raises:
        RuntimeError: When only one ``pl`` file is found (typically only
            WSJ0 is available instead of WSJ0+1).
    """
    wsj0_root = Path(wsj0_root).expanduser().resolve()
    wsj1_root = Path(wsj1_root).expanduser().resolve()
    dst_dir = Path(dst_dir).expanduser().resolve()
    assert wsj0_root.exists(), wsj0_root
    assert wsj1_root.exists(), wsj1_root

    assert not dst_dir == wsj0_root, (wsj0_root, dst_dir)
    assert not dst_dir == wsj1_root, (wsj1_root, dst_dir)
    # Expect, that the dst_dir does not exist to make sure to not overwrite.
    if dlp_mpi.IS_MASTER:
        dst_dir.mkdir(parents=True, exist_ok=False)

    if dlp_mpi.IS_MASTER:
        # Search for CD numbers, e.g. "13-34.1"
        # CD stands for compact disk.
        cds_0 = list(wsj0_root.rglob("*-*.*"))
        cds_1 = list(wsj1_root.rglob("*-*.*"))
        cds = set(cds_0 + cds_1)

        expected_number_of_files = {
            'pl': 3,
            'ndx': 106,
            'ptx': 3547,
            'dot': 3585,
            'txt': 256
        }
        number_of_written_files = dict()
        for suffix in expected_number_of_files.keys():
            files_0 = list(wsj0_root.rglob(f"*.{suffix}"))
            files_1 = list(wsj1_root.rglob(f"*.{suffix}"))
            files = set(files_0 + files_1)
            # Filter files that do not have a folder that matches "*-*.*".
            files = {
                file
                for file in files
                if any([fnmatch.fnmatch(part, "*-*.*") for part in file.parts])
            }

            # the readme.txt file in the parent directory is not copied
            print(f"About to write ca. {len(files)} {suffix} files.")
            for cd in cds:
                cd_files = list(cd.rglob(f"*.{suffix}"))
                for file in cd_files:
                    target = dst_dir / file.relative_to(cd.parent)
                    target.parent.mkdir(parents=True, exist_ok=True)
                    if not target.is_file():
                        shutil.copy(file, target.parent)
            number_of_written_files[suffix] = len(
                list(dst_dir.rglob(f"*.{suffix}")))
            print(f"Writing {number_of_written_files[suffix]} {suffix} files.")
            print(
                f'Expected {expected_number_of_files[suffix]} {suffix} files.')

        for suffix in expected_number_of_files.keys():
            # Bug fix: the warning previously interpolated the whole
            # ``number_of_written_files`` dict instead of the count for
            # this suffix.
            message = (f'Expected that '
                       f'{expected_number_of_files[suffix]} '
                       f'files with the {suffix} are written. '
                       f'But only {number_of_written_files[suffix]} '
                       f'are written. ')
            if (number_of_written_files[suffix] !=
                    expected_number_of_files[suffix]):
                warnings.warn(message)

            if suffix == 'pl' and number_of_written_files[suffix] == 1:
                raise RuntimeError(
                    'Found only one pl file although we expected three. '
                    'A typical reason is having only WSJ0. '
                    'Please make sure you have WSJ0+1 = WSJ COMPLETE.')

    if dlp_mpi.IS_MASTER:
        # Ignore .wv2 files since they are not referenced in our database
        # anyway
        wsj_nist_files = [(cd, nist_file) for cd in cds
                          for nist_file in cd.rglob("*.wv1")]
        print(f"About to write {len(wsj_nist_files)} wav files.")
    else:
        wsj_nist_files = None

    # Every process needs the full file list for split_managed.
    wsj_nist_files = dlp_mpi.bcast(wsj_nist_files)

    for nist_file_tuple in dlp_mpi.split_managed(wsj_nist_files):
        cd, nist_file = nist_file_tuple
        assert isinstance(nist_file, Path), nist_file
        signal = read_nist_wsj(nist_file, expected_sample_rate=16000)
        file = nist_file.with_suffix('.wav')
        target = dst_dir / file.relative_to(cd.parent)
        assert not target == nist_file, (nist_file, target)
        target.parent.mkdir(parents=True, exist_ok=True)
        signal = resample_with_sox(signal, rate_in=16000, rate_out=sample_rate)
        # normalization to mean 0:
        signal = signal - np.mean(signal)
        # normalization:
        #   Correction, because the allowed values are in the range [-1, 1).
        #       => "1" is not a valid value
        correction = (2**15 - 1) / (2**15)
        signal = signal * (correction / np.amax(np.abs(signal)))
        with soundfile.SoundFile(
                str(target),
                samplerate=sample_rate,
                channels=1,
                subtype='FLOAT',
                mode='w',
        ) as f:
            f.write(signal.T)

    dlp_mpi.barrier()
    if dlp_mpi.IS_MASTER:
        created_files = list(set(list(dst_dir.rglob("*.wav"))))
        print(f"Written {len(created_files)} wav files.")
        assert len(wsj_nist_files) == len(created_files), (len(wsj_nist_files),
                                                           len(created_files))
コード例 #9
0
    def enhance_session(
        self,
        session_ids,
        audio_dir,
        dataset_slice=False,
        audio_dir_exist_ok=False,
        is_chime=True,
    ):
        """Enhance the given sessions and write the results as wav files.

        Args:
            session_ids: Forwarded to ``self.get_dataset`` to select the
                examples to enhance.
            audio_dir: Output directory for the enhanced wav files.
            dataset_slice: Restrict the dataset, e.g. for a test run:
                False (use all examples), True (first two examples),
                an int (first n examples) or a slice object.
            audio_dir_exist_ok:
                When True: It is ok, when the audio dir exists and the files
                inside may be overwritten.
            is_chime:
                If true, map the session_id to the dataset name for the folder
                naming. Otherwise keep the session_id for the folder name.
        Returns:
            None.  The enhanced signals are written to
            ``<audio_dir>/<dataset>/<example_id>.wav``.

        >>> enhancer = get_enhancer(wpe=False, bss_iterations=2)
        >>> for x_hat in enhancer.enhance_session('S02'):
        ...     print(x_hat)
        """
        # Avoid thread oversubscription when many MPI processes run on one
        # node (presumably limits BLAS/numpy thread counts — see helper).
        ensure_single_thread_numeric()

        audio_dir = Path(audio_dir)

        it = self.get_dataset(session_ids)

        # Only the master creates the output folder ...
        if dlp_mpi.IS_MASTER:
            audio_dir.mkdir(exist_ok=audio_dir_exist_ok)
            # for dataset in self.db.data['alias']:
            #     (audio_dir / dataset).mkdir(exist_ok=audio_dir_exist_ok)

        # ... and all workers wait here until it exists.
        dlp_mpi.barrier()

        if dataset_slice is not False:
            if dataset_slice is True:
                it = it[:2]
            elif isinstance(dataset_slice, int):
                it = it[:dataset_slice]
            elif isinstance(dataset_slice, slice):
                it = it[dataset_slice]
            else:
                raise ValueError(dataset_slice)

        # Distribute the examples dynamically over the MPI processes.
        for ex in dlp_mpi.split_managed(it, allow_single_worker=True):
            try:
                x_hat = self.enhance_example(ex)
                example_id = ex["example_id"]
                session_id = ex["session_id"]
                if is_chime:
                    dataset = mapping.session_to_dataset[session_id]
                else:
                    dataset = session_id

                if x_hat.ndim == 1:
                    save_path = audio_dir / f'{dataset}' / f'{example_id}.wav'
                    dump_audio(x_hat, save_path, mkdir=True)
                else:
                    # Multi-channel output is not supported here.
                    raise NotImplementedError(x_hat.shape)
            except Exception:
                # Report which example failed before re-raising.
                print('ERROR: Failed example:', ex['example_id'])
                raise
0
ファイル: write_wav.py プロジェクト: suwoncjh/sms_wsj
def write_wavs(dst_dir: Path, wsj0_root: Path, wsj1_root: Path, sample_rate):
    """Copy the WSJ annotation files and convert the nist files to wav.

    The master process copies the ``pl``/``ndx``/``ptx``/``dot``/``txt``
    files from the WSJ0 and WSJ1 CD folders into ``dst_dir``.  The list
    of ``*.wv1`` nist files is then broadcast and the wav conversion
    (resampling to ``sample_rate`` plus peak normalization) is
    distributed over all MPI processes.

    Args:
        dst_dir: Target directory; must not exist yet (no overwriting).
        wsj0_root: Root folder of WSJ0.
        wsj1_root: Root folder of WSJ1.
        sample_rate: Sample rate of the written wav files.
    """
    wsj0_root = Path(wsj0_root).expanduser().resolve()
    wsj1_root = Path(wsj1_root).expanduser().resolve()
    dst_dir = Path(dst_dir).expanduser().resolve()
    assert wsj0_root.exists(), wsj0_root
    assert wsj1_root.exists(), wsj1_root

    assert not dst_dir == wsj0_root, (wsj0_root, dst_dir)
    assert not dst_dir == wsj1_root, (wsj1_root, dst_dir)
    # Expect, that the dst_dir does not exist to make sure to not overwrite.
    if dlp_mpi.IS_MASTER:
        dst_dir.mkdir(parents=True, exist_ok=False)

    if dlp_mpi.IS_MASTER:
        # "*-*.*" matches the CD folder names, e.g. "13-34.1".
        cds_0 = list(wsj0_root.rglob("*-*.*"))
        cds_1 = list(wsj1_root.rglob("*-*.*"))
        cds = set(cds_0 + cds_1)
        for suffix in 'pl ndx ptx dot txt'.split():
            files_0 = list(wsj0_root.rglob(f"*.{suffix}"))
            files_1 = list(wsj1_root.rglob(f"*.{suffix}"))
            files = set(files_0 + files_1)
            # the readme.txt file in the parent directory is not copied
            print(f"About to write ca. {len(files)} {suffix} files.")
            for cd in cds:
                cd_files = list(cd.rglob(f"*.{suffix}"))
                for file in cd_files:
                    target = dst_dir / file.relative_to(cd.parent)
                    target.parent.mkdir(parents=True, exist_ok=True)
                    if not target.is_file():
                        shutil.copy(file, target.parent)
            written_files = list(dst_dir.rglob(f"*.{suffix}"))
            print(f"Writing {len(written_files)} {suffix} files.")
            # assert len(written_files) == len(files), (files, written_files)

    if dlp_mpi.IS_MASTER:
        # Ignore .wv2 files since they are not referenced in our database
        # anyway
        wsj_nist_files = [(cd, nist_file) for cd in cds
                          for nist_file in cd.rglob("*.wv1")]
        print(f"About to write {len(wsj_nist_files)} wav files.")
    else:
        wsj_nist_files = None

    # Every process needs the full file list for split_managed.
    wsj_nist_files = dlp_mpi.bcast(wsj_nist_files)

    for nist_file_tuple in dlp_mpi.split_managed(wsj_nist_files):
        cd, nist_file = nist_file_tuple
        assert isinstance(nist_file, Path), nist_file
        signal = read_nist_wsj(nist_file, expected_sample_rate=16000)
        file = nist_file.with_suffix('.wav')
        target = dst_dir / file.relative_to(cd.parent)
        assert not target == nist_file, (nist_file, target)
        target.parent.mkdir(parents=True, exist_ok=True)
        # normalization:
        #   Correction, because the allowed values are in the range [-1, 1).
        #       => "1" is not a valid value
        signal = resample_with_sox(signal, rate_in=16000, rate_out=sample_rate)
        correction = (2 ** 15 - 1) / (2 ** 15)
        signal = signal * (correction / np.amax(np.abs(signal)))
        with soundfile.SoundFile(
                str(target), samplerate=sample_rate, channels=1,
                subtype='FLOAT', mode='w',
        ) as f:
            f.write(signal.T)

    dlp_mpi.barrier()
    if dlp_mpi.IS_MASTER:
        created_files = list(set(list(dst_dir.rglob("*.wav"))))
        print(f"Written {len(created_files)} wav files.")
        assert len(wsj_nist_files) == len(created_files), (len(wsj_nist_files), len(created_files))
コード例 #11
0
ファイル: split_managed.py プロジェクト: fgnt/dlp_mpi
    for i in examples:
        total += i
    elapsed = time.perf_counter() - start

    time_per_example = elapsed / 1000
    py_examples_per_second = 1 / time_per_example

    assert py_examples_per_second >= 250_000, py_examples_per_second
    assert py_examples_per_second <= 9_000_000, py_examples_per_second


if __name__ == '__main__':
    from dlp_mpi.testing import test_relaunch_with_mpi
    # Presumably re-executes this script under mpiexec when it was
    # started without MPI — see dlp_mpi.testing.
    test_relaunch_with_mpi()

    # A barrier between the individual tests keeps all processes in
    # lockstep, so one test cannot leak pending work into the next.
    dlp_mpi.barrier()
    executable()
    dlp_mpi.barrier()
    speedup()
    dlp_mpi.barrier()
    cross_communication()
    dlp_mpi.barrier()
    worker_fails()
    dlp_mpi.barrier()
    pbar()
    dlp_mpi.barrier()
    overhead()
    dlp_mpi.barrier()

    # ToDo: find a way to test the progress bar. Maybe with mock?
    # NOTE(review): pbar() above already mocks tqdm — this ToDo may be
    # outdated.
コード例 #12
0
def write_wavs(dst_dir, db, write_all=False, snr_range=(20, 30)):
    """Simulate and write the wav files for all datasets of ``db``.

    Args:
        dst_dir: Target directory (Path-like with ``/`` support); the
            per-signal-type sub-folders must not exist yet.
        db: Database object providing ``get_dataset``.
        write_all: When True also write the early/tail reverberation and
            the noise images, otherwise only the observation.
        snr_range: SNR range in dB forwarded to ``scenario_map_fn``.
    """
    if write_all:
        if dlp_mpi.IS_MASTER:
            # Only the master creates the output folders.
            [(dst_dir / data_type).mkdir(exist_ok=False)
             for data_type in type_mapper.values()]
        map_fn = partial(
            scenario_map_fn,
            snr_range=snr_range,
            sync_speech_source=True,
            add_speech_reverberation_early=True,
            add_speech_reverberation_tail=True
        )
    else:
        if dlp_mpi.IS_MASTER:
            (dst_dir / 'observation').mkdir(exist_ok=False)
        map_fn = partial(
            scenario_map_fn,
            snr_range=snr_range,
            sync_speech_source=True,
            add_speech_reverberation_early=False,
            add_speech_reverberation_tail=False
        )
    for dataset in ['train_si284', 'cv_dev93', 'test_eval92']:
        if dlp_mpi.IS_MASTER:
            [(dst_dir / data_type / dataset).mkdir(exist_ok=False)
             for data_type in type_mapper.values()]
        ds = db.get_dataset(dataset).map(audio_read).map(map_fn)
        # Distribute the simulation over the MPI processes.
        for example in dlp_mpi.split_managed(
                ds, is_indexable=True,
                allow_single_worker=True, progress_bar=True):
            audio_dict = example['audio_data']
            example_id = example['example_id']
            if not write_all:
                # Drop the signals that should not be written.
                del audio_dict['speech_reverberation_early']
                del audio_dict['speech_reverberation_tail']
                del audio_dict['noise_image']
            # Guard against clipping: all magnitudes must be <= 1.
            assert all([np.max(np.abs(v)) <= 1 for v in audio_dict.values()]), (
                example_id, [np.max(np.abs(v)) for v in audio_dict.values()])
            for key, value in audio_dict.items():
                if key not in type_mapper:
                    continue
                path = dst_dir / type_mapper[key] / dataset
                if key in ['observation', 'noise_image']:
                    # Add a leading axis so single signals iterate like
                    # the multi-signal entries below.
                    value = value[None]
                for idx, signal in enumerate(value):
                    # Index suffix (presumably one file per speaker) only
                    # when there is more than one signal.
                    appendix = f'_{idx}' if len(value) > 1 else ''
                    filename = example_id + appendix + '.wav'
                    audio_path = str(path / filename)
                    with soundfile.SoundFile(
                            audio_path, subtype='FLOAT', mode='w',
                            samplerate=8000,
                            channels=1 if signal.ndim == 1 else signal.shape[0]
                    ) as f:
                        f.write(signal.T)

        dlp_mpi.barrier()

    if dlp_mpi.IS_MASTER:
        created_files = check_files(dst_dir)
        print(f"Written {len(created_files)} wav files.")
        if write_all:
            # 35875 examples, (2 * 2 + 2) files each — presumably 2
            # speakers for early and tail plus observation; confirm
            # against type_mapper.
            expect = (2 * 2 + 2) * 35875
            assert len(created_files) == expect, (
                len(created_files), expect
            )
        else:
            assert len(created_files) == 35875, len(created_files)
コード例 #13
0
def write_wavs(dst_dir, json_path, write_all=False, snr_range=(20, 30)):
    """Simulate and write the wav files for all datasets of the database.

    Args:
        dst_dir: Target directory (Path-like with ``/`` support); the
            per-signal-type sub-folders must not exist yet.
        json_path: Path of the database json, loaded with ``JsonDatabase``.
        write_all: When True also write the early/tail reverberation and
            the noise images, otherwise only the observation.
        snr_range: SNR range in dB forwarded to ``scenario_map_fn``.
    """
    db = JsonDatabase(json_path)
    if write_all:
        if dlp_mpi.IS_MASTER:
            # Only the master creates the output folders.
            [(dst_dir / data_type).mkdir(exist_ok=False)
             for data_type in KEY_MAPPER.values()]
        map_fn = partial(scenario_map_fn,
                         snr_range=snr_range,
                         sync_speech_source=True,
                         add_speech_reverberation_early=True,
                         add_speech_reverberation_tail=True)
    else:
        if dlp_mpi.IS_MASTER:
            (dst_dir / 'observation').mkdir(exist_ok=False)
        map_fn = partial(scenario_map_fn,
                         snr_range=snr_range,
                         sync_speech_source=True,
                         add_speech_reverberation_early=False,
                         add_speech_reverberation_tail=False)

    def get_abs_max(a):
        # Maximum magnitude over an arbitrarily nested container of
        # numpy arrays (object arrays, tuples/lists, dicts).
        # Hoisted out of the example loop: it does not depend on the
        # loop variables, so there is no need to re-create it per example.
        if isinstance(a, np.ndarray):
            # Fix: compare against the builtin ``object``; the alias
            # ``np.object`` was deprecated and removed in NumPy 1.24.
            if a.dtype == object:
                return np.max(list(map(get_abs_max, a)))
            else:
                return np.max(np.abs(a))
        elif isinstance(a, (tuple, list)):
            return np.max(list(map(get_abs_max, a)))
        elif isinstance(a, dict):
            return np.max(list(map(get_abs_max, a.values())))
        else:
            raise TypeError(a)

    for dataset in ['train_si284', 'cv_dev93', 'test_eval92']:
        if dlp_mpi.IS_MASTER:
            [(dst_dir / data_type / dataset).mkdir(exist_ok=False)
             for data_type in KEY_MAPPER.values()]
        ds = db.get_dataset(dataset).map(audio_read).map(map_fn)
        # Distribute the simulation over the MPI processes.
        for example in dlp_mpi.split_managed(
                ds,
                is_indexable=True,
                allow_single_worker=True,
                progress_bar=True,
        ):
            audio_dict = example['audio_data']
            example_id = example['example_id']
            if not write_all:
                del audio_dict['speech_reverberation_early']
                del audio_dict['speech_reverberation_tail']
                del audio_dict['noise_image']

            # Guard against clipping.  Fix: the previous
            # ``assert get_abs_max(audio_dict)`` only checked that the
            # maximum is non-zero; the diagnostic payload (per-key maxima)
            # shows the intent is to guarantee all magnitudes are <= 1.
            assert get_abs_max(audio_dict) <= 1, (example_id, {
                k: get_abs_max(v)
                for k, v in audio_dict.items()
            })
            for key, value in audio_dict.items():
                if key not in KEY_MAPPER:
                    continue
                path = dst_dir / KEY_MAPPER[key] / dataset
                if key in ['observation', 'noise_image']:
                    # Add a leading axis so single signals iterate like
                    # the multi-signal entries below.
                    value = value[None]
                for idx, signal in enumerate(value):
                    appendix = f'_{idx}' if len(value) > 1 else ''
                    filename = example_id + appendix + '.wav'
                    audio_path = str(path / filename)
                    with soundfile.SoundFile(audio_path,
                                             subtype='FLOAT',
                                             mode='w',
                                             samplerate=8000,
                                             channels=1 if signal.ndim == 1
                                             else signal.shape[0]) as f:
                        f.write(signal.T)

        dlp_mpi.barrier()

    if dlp_mpi.IS_MASTER:
        created_files = check_files(dst_dir)
        print(f"Written {len(created_files)} wav files.")
        if write_all:
            # TODO Less, if you do a test run.
            num_speakers = 2  # todo infer num_speakers from json
            # 2 files for: early, tail, speech_source
            # 1 file for: observation, noise
            expect = (3 * num_speakers + 2) * 35875
            assert len(created_files) == expect, (len(created_files), expect)
        else:
            assert len(created_files) == 35875, len(created_files)