Exemplo n.º 1
0
def main(_run, batch_size, datasets, debug, experiment_dir, database_json):
    """Evaluate the separation model and export per-example output metrics.

    Batches are distributed over MPI workers with ``dlp_mpi.split_managed``.
    Each worker collects its results in a nested
    ``{dataset: {example_id: {'metrics': ...}}}`` dict. The partial dicts are
    gathered on the master rank, merged and written to
    ``<experiment_dir>/result.json``.

    Args:
        _run: Sacred run object; its config is printed on the master rank.
        batch_size: Batch size forwarded to ``prepare_iterable``.
        datasets: Names of the datasets in the database to evaluate.
        debug: Forwarded as ``allow_single_worker`` to ``split_managed``.
        experiment_dir: Directory in which ``result.json`` is written.
        database_json: Path of the JSON database passed to ``JsonDatabase``.
    """
    experiment_dir = Path(experiment_dir)

    if dlp_mpi.IS_MASTER:
        sacred.commands.print_config(_run)

    model = get_model()
    db = JsonDatabase(json_path=database_json)

    model.eval()
    with torch.no_grad():
        summary = defaultdict(dict)
        for dataset in datasets:
            iterable = prepare_iterable(
                db,
                dataset,
                batch_size,
                return_keys=None,
                prefetch=False,
            )

            for batch in dlp_mpi.split_managed(iterable,
                                               is_indexable=False,
                                               progress_bar=True,
                                               allow_single_worker=debug):
                entry = dict()
                model_output = model(pt.data.example_to_device(batch))

                # Only element 0 of each batch is evaluated
                # (presumably batch_size == 1 — TODO confirm).
                example_id = batch['example_id'][0]
                s = batch['s'][0]
                Y = batch['Y'][0]
                mask = model_output[0].numpy()

                # Mask is (t, k, f); Y is broadcast over the k axis.
                Z = mask * Y[:, None, :]
                z = istft(einops.rearrange(Z, "t k f -> k t f"),
                          size=512,
                          shift=128)

                # Trim source and estimate to their common length (axis 1).
                s = s[:, :z.shape[1]]
                z = z[:, :s.shape[1]]
                entry['metrics'] \
                    = pb_bss.evaluation.OutputMetrics(speech_prediction=z,
                                                      speech_source=s).as_dict()

                # Record every example. This must stay inside the batch loop
                # so that each example, not just the last one, is stored.
                summary[dataset][example_id] = entry

    summary_list = dlp_mpi.gather(summary, root=dlp_mpi.MASTER)

    if dlp_mpi.IS_MASTER:
        print(f'len(summary_list): {len(summary_list)}')
        for partial_summary in summary_list:
            for dataset, values in partial_summary.items():
                summary[dataset].update(values)

        for dataset, values in summary.items():
            print(f'{dataset}: {len(values)}')

        result_json_path = experiment_dir / 'result.json'
        print(f"Exporting result: {result_json_path}")
        pb.io.dump_json(summary, result_json_path)
Exemplo n.º 2
0
 def test_restore_time_signal_from_torch_stft_and_numpy_istft(self):
     """Check that numpy ``istft`` of the torch STFT output matches ``self.time_signal``.

     The torch STFT output stores real parts in the first ``self.fbins``
     entries of the last axis and imaginary parts in the remaining ones.
     """
     stacked = self.stft(self.torch_signal).numpy()
     real_part = stacked[..., :self.fbins]
     imag_part = stacked[..., self.fbins:]
     complex_stft = real_part + 1j * imag_part
     reconstructed = istft(
         complex_stft,
         size=self.size,
         shift=self.shift,
         window_length=self.window_length,
         window=self.window,
         fading=self.fading,
     )
     # Drop any trailing samples beyond the reference length.
     n_samples = self.time_signal.shape[-1]
     tc.assert_almost_equal(reconstructed[..., :n_samples], self.time_signal)
Exemplo n.º 3
0
def main(_run, batch_size, datasets, debug, experiment_dir):
    """Run mask-based separation on MERL mixtures and export mir_eval scores.

    Each MPI rank processes a strided share of every dataset
    (``slice(RANK, stop, SIZE)``). The per-rank
    ``{dataset: {example_id: {'mir_eval': ...}}}`` dicts are gathered on the
    master rank, merged and dumped to ``<experiment_dir>/result.json``.

    Args:
        _run: Sacred run object; its config is printed on the master rank.
        batch_size: Batch size forwarded to ``prepare_iterable``.
        datasets: Names of the ``MerlMixtures`` datasets to evaluate.
        debug: If true, the strided slice stops at position 20.
        experiment_dir: Output directory for ``result.json``.
    """
    out_dir = Path(experiment_dir)

    if IS_MASTER:
        sacred.commands.print_config(_run)

    # TODO: Substantially faster to load the model once and distribute via MPI
    model = get_model()
    db = MerlMixtures()

    # In debug mode the strided iteration stops at position 20.
    stop = 20 if debug else None

    model.eval()
    with torch.no_grad():
        summary = defaultdict(dict)
        for dataset in datasets:
            shard = prepare_iterable(
                db,
                dataset,
                batch_size,
                return_keys=None,
                prefetch=False,
                iterator_slice=slice(RANK, stop, SIZE),
            )
            for batch in tqdm(shard, total=len(shard), disable=not IS_MASTER):
                model_output = model(pt.data.example_to_device(batch))

                example_id = batch['example_id'][0]
                source = batch['s'][0]
                observation = batch['Y'][0]
                mask = model_output[0].numpy()

                # Masked STFT is (t, k, f); it is rearranged to (k, t, f)
                # before the inverse STFT.
                masked = mask * observation[:, None, :]
                estimate = istft(
                    einops.rearrange(masked, "t k f -> k t f"),
                    size=512,
                    shift=128,
                )

                # Cut both signals to their common length along axis 1.
                source = source[:, :estimate.shape[1]]
                estimate = estimate[:, :source.shape[1]]

                summary[dataset][example_id] = {
                    'mir_eval': pb.evaluation.mir_eval_sources(
                        source, estimate, return_dict=True),
                }

    partials = COMM.gather(summary, root=MASTER)
    if not IS_MASTER:
        return

    print(f'len(summary_list): {len(partials)}')
    for partial in partials:
        for name, results in partial.items():
            summary[name].update(results)

    for name, results in summary.items():
        print(f'{name}: {len(results)}')

    result_json_path = out_dir / 'result.json'
    print(f"Exporting result: {result_json_path}")
    pb.io.dump_json(summary, result_json_path)
Exemplo n.º 4
0
def play(
    data,
    channel=0,
    sample_rate=16000,
    size=1024,
    shift=256,
    window='blackman',
    window_length: int = None,
    *,
    scale=1,
    name=None,
    stereo=False,
    normalize=True,
):
    """ Tries to guess, what the input data is. Plays time series and stft.

    Provides an easy to use interface to play back sound in an IPython Notebook.

    :param data: Time series with shape (frames,),
        multichannel time series with shape (channels, frames),
        stft with shape (frames, channels, bins) or (frames, bins),
        string or Path pointing to an audio file,
        or a dict mapping names to any of the above (each value is played
        separately with its key used as ``name``).
    :param channel: Channel, if you have a multichannel stft signal or a
        multichannel audio file.
    :param sample_rate: Sampling rate in Hz.
    :param size: STFT window size
    :param shift: STFT shift
    :param window: STFT analysis window
    :param window_length: STFT window length, forwarded to ``istft``.
    :param scale: Scale the Volume, currently only amplification with clip
        is supported.
    :param name: If name is set, then in ipynb table with name and audio is
                 displayed
    :param stereo: If set to true, you can listen to channel as defined by
        `channel` parameter and the next channel at the same time.
    :param normalize: If true, IPython rescales the data for playback.
        If false, the data is passed to IPython without rescaling. This can
        only be disabled with newer IPython versions (``Audio`` needs a
        ``normalize`` argument).
    :return:
    """
    if isinstance(data, dict):
        assert name is None, name
        for k, v in data.items():
            # Forward every option, including ``normalize``, so dict input
            # behaves exactly like playing each value on its own.
            play(
                data=v,
                name=k,
                channel=channel,
                sample_rate=sample_rate,
                size=size,
                shift=shift,
                window=window,
                window_length=window_length,
                scale=scale,
                stereo=stereo,
                normalize=normalize,
            )
        return

    if stereo:
        if isinstance(channel, int):
            channel = (channel, channel + 1)
    else:
        assert isinstance(channel, int), (type(channel), channel)

    if isinstance(data, Path):
        data = str(data)

    if isinstance(data, str):
        assert os.path.exists(data), ('File does not exist.', data)
        data = load_audio(data, expected_sample_rate=sample_rate)
        if len(data.shape) == 2:
            data = data[channel, :]
    elif np.iscomplexobj(data):
        from paderbox.transform import istft

        assert data.shape[-1] == size // 2 + \
            1, ('Wrong number of frequency bins', data.shape, size)

        if len(data.shape) == 3:
            # NOTE(review): with stereo=True this yields (frames, 2, bins);
            # confirm that istft treats the channel axis as intended here.
            data = data[:, channel, :]

        data = istft(
            data,
            size=size,
            shift=shift,
            window=window,
            window_length=window_length,
        )
    elif np.isrealobj(data):
        if len(data.shape) == 2:
            data = data[channel, :]

    assert np.isrealobj(data), data.dtype
    assert stereo or len(data.shape) == 1, data.shape

    if scale != 1:
        assert scale > 1 or (not normalize), \
            'Only Amplification with clipping is supported. \n' \
            'Note: IPython.display.Audio scales the input, therefore a ' \
            'np.clip can increase the power, but not decrease it. ' \
            f'scale={scale}'
        max_abs_data = np.max(np.abs(data))
        data = np.clip(data, -max_abs_data / scale, max_abs_data / scale)

    if stereo:
        assert len(data.shape) == 2, data.shape
        assert data.shape[0] == 2, data.shape

    if normalize:
        # ToDo: disable this version specific check
        # ipython 7.3.0 has no `normalize` argument and normalize couldn't
        # be disabled
        kwargs = {}
    else:
        # ipython 7.12.0 `Audio` has a `normalize` argument see
        # https://github.com/ipython/ipython/pull/11650
        kwargs = {'normalize': normalize}

    from IPython.display import display
    from IPython.display import Audio
    if name is None:
        display(Audio(data, rate=sample_rate, **kwargs))
    else:

        class NamedAudio(Audio):
            # Label rendered in the left table cell next to the player.
            name = None

            def _repr_html_(self):
                audio_html = super()._repr_html_()

                assert self.name is not None

                return """
                <table style="width:100%">
                    <tr>
                        <td style="width:25%">
                            {}
                        </td>
                        <td style="width:75%">
                            {}
                        </td>
                    </tr>
                </table>
                """.format(self.name, audio_html)

        na = NamedAudio(data, rate=sample_rate, **kwargs)
        na.name = name
        display(na)
Exemplo n.º 5
0
def main(_run, batch_size, datasets, debug, experiment_dir, database_json,
         _log):
    """Evaluate the separation model and export mir_eval input/output scores.

    For each example, mir_eval metrics are computed for the observation
    (input) and for the separated estimate (output), together with their
    difference (improvement). Results are collected per MPI worker as
    ``{dataset_name: {example_id: entry}}``, gathered and merged on the master
    rank, checked for completeness and written to
    ``<experiment_dir>/result.json``. Metric means are written to
    ``<experiment_dir>/means.json``.

    Args:
        _run: Sacred run object; its config is printed on the master rank.
        batch_size: Batch size forwarded to ``prepare_dataset``.
        datasets: Names of the datasets in the database to evaluate.
        debug: Forwarded as ``allow_single_worker`` to ``split_managed``.
        experiment_dir: Output directory for the JSON result files.
        database_json: Path of the JSON database passed to ``JsonDatabase``.
        _log: Logger used for progress messages.
    """
    experiment_dir = Path(experiment_dir)

    if dlp_mpi.IS_MASTER:
        sacred.commands.print_config(_run)

    model = get_model()
    db = JsonDatabase(json_path=database_json)

    model.eval()
    with torch.no_grad():
        summary = defaultdict(dict)
        for dataset_name in datasets:
            dataset = prepare_dataset(db,
                                      dataset_name,
                                      batch_size,
                                      return_keys=None,
                                      prefetch=False,
                                      shuffle=False)

            for batch in dlp_mpi.split_managed(dataset,
                                               is_indexable=True,
                                               progress_bar=True,
                                               allow_single_worker=debug):
                entry = dict()
                model_output = model(model.example_to_device(batch))

                # Only element 0 of each batch is evaluated
                # (presumably batch_size == 1 — TODO confirm).
                example_id = batch['example_id'][0]
                s = batch['s'][0]
                Y = batch['Y'][0]
                mask = model_output[0].numpy()

                # Mask is (t, k, f); Y is broadcast over the k axis.
                Z = mask * Y[:, None, :]
                z = istft(einops.rearrange(Z, "t k f -> k t f"),
                          size=512,
                          shift=128)

                # Trim source and estimate to their common length (axis 1).
                s = s[:, :z.shape[1]]
                z = z[:, :s.shape[1]]

                input_metrics = pb_bss.evaluation.InputMetrics(
                    observation=batch['y'][0][None, :],
                    speech_source=s,
                    sample_rate=8000,
                    enable_si_sdr=False,
                )

                output_metrics = pb_bss.evaluation.OutputMetrics(
                    speech_prediction=z,
                    speech_source=s,
                    sample_rate=8000,
                    enable_si_sdr=False,
                )
                entry['input'] = dict(mir_eval=input_metrics.mir_eval, )
                # 'selection' is not a score, so it is excluded from the
                # improvement computation and stored separately below.
                entry['output'] = dict(mir_eval={
                    k: v
                    for k, v in output_metrics.mir_eval.items()
                    if k != 'selection'
                }, )

                entry['improvement'] = pb.utils.nested.nested_op(
                    operator.sub,
                    entry['output'],
                    entry['input'],
                )
                entry['selection'] = output_metrics.mir_eval['selection']

                # Key by the dataset *name*, not the prepared dataset object:
                # the key is used for JSON export and db.get_dataset below.
                summary[dataset_name][example_id] = entry

    summary_list = dlp_mpi.gather(summary, root=dlp_mpi.MASTER)

    if dlp_mpi.IS_MASTER:
        _log.info(f'len(summary_list): {len(summary_list)}')
        summary = pb.utils.nested.nested_merge(*summary_list)

        for dataset, values in summary.items():
            _log.info(f'{dataset}: {len(values)}')
            assert len(values) == len(
                db.get_dataset(dataset)
            ), 'Number of results needs to match length of dataset!'
        result_json_path = experiment_dir / 'result.json'
        _log.info(f"Exporting result: {result_json_path}")
        pb.io.dump_json(summary, result_json_path)

        # Compute and save mean of metrics
        means = compute_means(summary)
        mean_json_path = experiment_dir / 'means.json'
        _log.info(f"Saving means to: {mean_json_path}")
        pb.io.dump_json(means, mean_json_path)