コード例 #1
0
def test_transform_to_separation_model(musdb_tracks):
    """ToSeparationModel should tensorize kept keys and drop the rest."""
    track = musdb_tracks[10]
    mix, sources = nussl.utils.musdb_track_to_audio_signals(track)

    data = {
        'mix': mix,
        'sources': sources,
        'metadata': {'labels': []}
    }

    msa = transforms.MagnitudeSpectrumApproximation()
    tdl = transforms.ToSeparationModel()
    # str() of a transform should mention its class name.
    assert tdl.__class__.__name__ in str(tdl)

    composed = transforms.Compose([msa, tdl])
    data = composed(data)

    kept_keys = ['mix_magnitude', 'source_magnitudes']
    dropped_keys = ['mix', 'sources', 'metadata']

    assert all(key in data for key in kept_keys)
    assert not any(key in data for key in dropped_keys)

    stft_shape = mix.stft().shape
    for value in data.values():
        assert torch.is_tensor(value)
        # ToSeparationModel swaps the time and frequency axes.
        assert value.shape[0] == stft_shape[1]
        assert value.shape[1] == stft_shape[0]
コード例 #2
0
ファイル: test_base_dataset.py プロジェクト: speechdnn/nussl
def test_dataset_base(benchmark_audio, monkeypatch):
    """Exercise BaseDataset hooks, error paths, and transform plumbing."""
    keys = list(benchmark_audio.values())

    def dummy_get(self, folder):
        # Pretend every dataset folder contains the benchmark files.
        return keys

    # Unpatched BadDataset fails outright.
    pytest.raises(DataSetException, initialize_bad_dataset_and_run)

    monkeypatch.setattr(BadDataset, 'get_items', dummy_get)
    # Still fails: process_item has not been patched yet.
    pytest.raises(DataSetException, initialize_bad_dataset_and_run)

    monkeypatch.setattr(BadDataset, 'process_item', dummy_process_item)
    # Items now load, but the bad transform raises.
    pytest.raises(transforms.TransformException,
                  initialize_bad_dataset_and_run)

    monkeypatch.setattr(BaseDataset, 'get_items', dummy_get)
    monkeypatch.setattr(BaseDataset, 'process_item', dummy_process_item)

    _dataset = BaseDataset('test')
    assert len(_dataset) == len(keys)
    assert _dataset[0]['mix'] == nussl.AudioSignal(keys[0])

    _dataset = BaseDataset('test', transform=BadTransform())
    pytest.raises(transforms.TransformException, _dataset.__getitem__, 0)

    expected_keys = ['source_magnitudes', 'mix_magnitude',
                     'ideal_binary_mask']

    _dataset = BaseDataset(
        'test', transform=transforms.MagnitudeSpectrumApproximation())
    output = _dataset[0]
    for key in expected_keys:
        assert key in output

    # Same checks when items are built directly from audio signals.
    monkeypatch.setattr(BaseDataset, 'process_item',
                        dummy_process_item_by_audio)
    _dataset = BaseDataset(
        'test', transform=transforms.MagnitudeSpectrumApproximation())
    output = _dataset[0]
    for key in expected_keys:
        assert key in output
コード例 #3
0
def test_transform_msa_psa(musdb_tracks):
    """MSA and PSA transforms: validation, outputs, separation quality.

    Checks that both transforms reject malformed input, compute the mix
    magnitude correctly, and that the masks they produce separate sources
    substantially better than a trivial all-ones mask.
    """
    track = musdb_tracks[10]
    mix, sources = nussl.utils.musdb_track_to_audio_signals(track)

    data = {'mix': mix, 'sources': sources}

    msa = transforms.MagnitudeSpectrumApproximation()
    psa = transforms.PhaseSensitiveSpectrumApproximation()

    # str() of a transform should mention its class name.
    assert msa.__class__.__name__ in str(msa)
    assert psa.__class__.__name__ in str(psa)

    # 'sources' must map names to signals, not be a plain string.
    pytest.raises(TransformException, psa, {'sources': 'blah'})
    pytest.raises(TransformException, msa, {'sources': 'blah'})

    # With only a mix, the transform still computes its magnitude.
    _data = {'mix': mix}
    output = msa(_data)
    assert np.allclose(output['mix_magnitude'], np.abs(mix.stft()))

    output = msa(data)
    assert np.allclose(output['mix_magnitude'], np.abs(mix.stft()))
    # The transform sorts source keys in place.
    assert list(data['sources'].keys()) == sorted(list(sources.keys()))

    # Baseline: an all-ones mask, i.e. no separation at all.
    shape = mix.stft_data.shape + (len(sources), )
    mix_masks = np.ones(shape)
    mix_scores = separate_and_evaluate(mix, data['sources'], mix_masks)

    # Also exercise evaluation with the ideal binary mask that
    # msa(data) added; its scores are not asserted on here.
    separate_and_evaluate(mix, data['sources'], data['ideal_binary_mask'])

    output['source_magnitudes'] += 1e-8  # avoid division by zero below

    mask_data = (output['source_magnitudes'] / np.maximum(
        output['mix_magnitude'][..., None], output['source_magnitudes']))
    msa_scores = separate_and_evaluate(mix, data['sources'], mask_data)

    _data = {'mix': mix}
    output = psa(_data)
    assert np.allclose(output['mix_magnitude'], np.abs(mix.stft()))

    output = psa(data)
    assert np.allclose(output['mix_magnitude'], np.abs(mix.stft()))
    assert list(data['sources'].keys()) == sorted(list(sources.keys()))

    output['source_magnitudes'] += 1e-8  # avoid division by zero below

    mask_data = (output['source_magnitudes'] / np.maximum(
        output['mix_magnitude'][..., None], output['source_magnitudes']))
    psa_scores = separate_and_evaluate(mix, data['sources'], mask_data)

    # Both mask types should improve markedly over the all-ones baseline.
    for key in msa_scores:
        if key in ['SI-SDR', 'SI-SIR', 'SI-SAR']:
            msa_diff = np.array(msa_scores[key]) - np.array(mix_scores[key])
            psa_diff = np.array(psa_scores[key]) - np.array(mix_scores[key])
            assert msa_diff.mean() > 10
            assert psa_diff.mean() > 10
コード例 #4
0
def transform(stft_params: nussl.STFTParams,
              sample_rate: int,
              target_instrument,
              only_audio_signal: bool,
              mask_type: str = 'msa',
              audio_only: bool = False):
    """
    Builds transforms that get applied to
    training and validation datasets.

    Parameters
    ----------
    stft_params : nussl.STFTParams
        Parameters of STFT (see: signal). Not used directly in this
        function; presumably consumed elsewhere by the caller —
        TODO confirm.
    sample_rate : int
        Sample rate of audio signal (see: signal). Not used directly
        in this function; presumably consumed elsewhere by the
        caller — TODO confirm.
    target_instrument : str
        Which instrument to learn to separate out of
        a mixture.
    only_audio_signal : bool
        Whether to return only the audio signals, no
        tensors (useful for eval).
    mask_type : str, optional
        What type of masking to use. Either phase
        sensitive spectrum approx. (psa) or
        magnitude spectrum approx (msa), by default
        'msa'.
    audio_only : bool, optional
        Whether or not to only apply GetAudio in
        transform (don't compute STFTs), by default False.

    Returns
    -------
    tuple of (nussl_tfm.Compose, list of str)
        The composed transform and the sorted list of source labels
        after the non-target instruments are summed into one group.

    Raises
    ------
    ValueError
        If spectral targets are requested (``only_audio_signal`` and
        ``audio_only`` both False) and ``mask_type`` is neither
        'msa' nor 'psa'.
    """
    tfm = []

    # Group every non-target instrument into a single summed source.
    other_labels = [k for k in LABELS if k != target_instrument]
    tfm.append(nussl_tfm.SumSources([other_labels]))
    new_labels = [target_instrument] + tfm[-1].group_names
    new_labels = sorted(new_labels)

    if not only_audio_signal:
        if not audio_only:
            if mask_type == 'psa':
                tfm.append(nussl_tfm.PhaseSensitiveSpectrumApproximation())
            elif mask_type == 'msa':
                tfm.append(nussl_tfm.MagnitudeSpectrumApproximation())
            else:
                # Silently skipping here would leave 'source_magnitudes'
                # undefined for IndexSources below — fail fast instead.
                raise ValueError(
                    f"mask_type must be 'msa' or 'psa', got {mask_type!r}")
            tfm.append(nussl_tfm.MagnitudeWeights())

        tfm.append(nussl_tfm.GetAudio())
        # Keep only the target's magnitudes out of the grouped sources.
        target_index = new_labels.index(target_instrument)
        tfm.append(nussl_tfm.IndexSources('source_magnitudes', target_index))

        tfm.append(nussl_tfm.ToSeparationModel())

    return nussl_tfm.Compose(tfm), new_labels
コード例 #5
0
def test_transform_compose(musdb_tracks):
    """Compose should chain transforms and reject non-dict outputs."""
    track = musdb_tracks[10]
    mix, sources = nussl.utils.musdb_track_to_audio_signals(track)

    data = {
        'mix': mix,
        'sources': sources,
        'metadata': {
            'labels': ['bass', 'drums', 'other', 'vocals']
        }
    }

    class _BadTransform(object):
        def __call__(self, data):
            return 'not a dictionary'

    # A transform returning a non-dict should make Compose raise.
    bad_com = transforms.Compose([_BadTransform()])
    pytest.raises(TransformException, bad_com, data)

    msa = transforms.MagnitudeSpectrumApproximation()
    tfm = transforms.SumSources(
        [['other', 'drums', 'bass']],
        group_names=['accompaniment']
    )
    assert isinstance(str(tfm), str)

    com = transforms.Compose([tfm, msa])
    # Compose's repr should mention each contained transform.
    for inner in (msa, tfm):
        assert inner.__class__.__name__ in str(com)

    data = com(data)

    assert np.allclose(data['mix_magnitude'], np.abs(mix.stft()))
    assert data['metadata']['labels'] == [
        'bass', 'drums', 'other', 'vocals', 'accompaniment']

    source_mags = data['source_magnitudes']
    mask_data = source_mags / np.maximum(
        data['mix_magnitude'][..., None], source_mags)
    msa_scores = separate_and_evaluate(mix, data['sources'], mask_data)

    # Baseline: an all-ones mask, i.e. no separation at all.
    all_ones = np.ones(mix.stft_data.shape + (len(sources),))
    mix_scores = separate_and_evaluate(mix, data['sources'], all_ones)

    # MSA masks should beat the no-op mask by a wide margin.
    for key in msa_scores:
        if key in ['SI-SDR', 'SI-SIR', 'SI-SAR']:
            diff = np.array(msa_scores[key]) - np.array(mix_scores[key])
            assert diff.mean() > 10
コード例 #6
0
def test_transforms_magnitude_weights(mix_source_folder):
    """MagnitudeWeights works from the mix alone or after MSA."""
    dataset = nussl.datasets.MixSourceFolder(mix_source_folder)
    item = dataset[0]

    tfm = transforms.MagnitudeWeights()
    # A dict without a usable mix should be rejected.
    pytest.raises(TransformException, tfm, {'sources': []})

    # Weights computed straight from the mix...
    item_from_mix = tfm(item)

    # ...should match weights computed after the MSA transform ran.
    msa = transforms.MagnitudeSpectrumApproximation()
    item = tfm(msa(item))

    assert item['weights'].shape == item['mix_magnitude'].shape
    assert np.allclose(item_from_mix['weights'], item['weights'])
コード例 #7
0
def test_transform_cache(musdb_tracks):
    """Cache should round-trip items by index and reject unknown ones."""
    track = musdb_tracks[10]
    mix, sources = nussl.utils.musdb_track_to_audio_signals(track)

    data = {
        'mix': mix,
        'sources': sources,
        'metadata': {
            'labels': sorted(list(sources.keys()))
        },
        'index': 0
    }

    with tempfile.TemporaryDirectory() as tmpdir:
        cache_path = os.path.join(tmpdir, 'cache')
        tfm = transforms.Cache(cache_path, cache_size=2, overwrite=True)

        written = tfm(data)
        _info_a = tfm.info

        # Flip to read mode and fetch the cached entry back by index.
        tfm.overwrite = False
        read_back = tfm({'index': 0})

        # Missing 'index' key, and an index that was never written.
        pytest.raises(TransformException, tfm, {})
        pytest.raises(TransformException, tfm, {'index': 1})

        for key in written:
            assert written[key] == read_back[key]

        com = transforms.Compose([
            transforms.MagnitudeSpectrumApproximation(),
            transforms.ToSeparationModel(),
            transforms.Cache(cache_path, overwrite=True),
        ])

        written = com(data)
        com.transforms[-1].overwrite = False
        read_back = com.transforms[-1]({'index': 0})

        for key in written:
            if torch.is_tensor(written[key]):
                assert torch.allclose(written[key], read_back[key])
            else:
                assert written[key] == read_back[key]
コード例 #8
0
def test_transforms_index_sources(mix_source_folder):
    """IndexSources should slice out a single source's magnitudes."""
    dataset = nussl.datasets.MixSourceFolder(mix_source_folder)
    item = dataset[0]

    index = 1
    tfm = transforms.IndexSources('source_magnitudes', index)

    # Needs 'source_magnitudes' with enough entries on the last axis.
    pytest.raises(TransformException, tfm, {'sources': []})
    pytest.raises(TransformException, tfm,
                  {'source_magnitudes': np.random.randn(100, 100, 1)})

    msa = transforms.MagnitudeSpectrumApproximation()
    expected = copy.deepcopy(msa(item))

    item = tfm(msa(item))

    assert np.allclose(
        item['source_magnitudes'],
        expected['source_magnitudes'][..., index, None])
コード例 #9
0
def test_transform_get_excerpt(musdb_tracks):
    """GetExcerpt should crop spectrograms and audio to a fixed length."""
    track = musdb_tracks[10]
    mix, sources = nussl.utils.musdb_track_to_audio_signals(track)

    msa = transforms.MagnitudeSpectrumApproximation()
    tdl = transforms.ToSeparationModel()

    def _spectrogram_data():
        # Fresh input dict every time; transforms mutate their input.
        return {
            'mix': mix,
            'sources': sources,
            'metadata': {'labels': []}
        }

    for excerpt_length in (400, 1000, 2000):
        exc = transforms.GetExcerpt(excerpt_length=excerpt_length)
        assert isinstance(str(exc), str)

        # Tensor path: excerpt taken after ToSeparationModel.
        data = transforms.Compose([msa, tdl, exc])(_spectrogram_data())

        num_freq = mix.stft().shape[0]
        for value in data.values():
            assert torch.is_tensor(value)
            assert value.shape[0] == excerpt_length
            assert value.shape[1] == num_freq

        assert torch.mean((data['source_magnitudes'].sum(dim=-1) -
                           data['mix_magnitude']) ** 2).item() < 1e-5

        # Numpy path: the same excerpt applied to arrays, not tensors.
        exc = transforms.GetExcerpt(excerpt_length=excerpt_length)
        assert isinstance(str(exc), str)
        data = transforms.Compose([msa, tdl])(_spectrogram_data())
        data = {key: val.cpu().data.numpy() for key, val in data.items()}
        data = exc(data)

        for value in data.values():
            assert value.shape[0] == excerpt_length
            assert value.shape[1] == num_freq

        assert np.mean((data['source_magnitudes'].sum(axis=-1) -
                        data['mix_magnitude']) ** 2) < 1e-5

        # Anything that is not an array or tensor is rejected.
        pytest.raises(TransformException, exc,
                      {'mix_magnitude': 'not an array or tensor'})

    ga = transforms.GetAudio()
    for excerpt_length in (1009, 16000, 612140):
        data = {
            'mix': sum(sources.values()),
            'sources': sources,
            'metadata': {'labels': []}
        }

        exc = transforms.GetExcerpt(
            excerpt_length=excerpt_length,
            tf_keys=['mix_audio', 'source_audio'],
            time_dim=1,
        )
        data = transforms.Compose([ga, tdl, exc])(data)

        for value in data.values():
            assert torch.is_tensor(value)
            assert value.shape[1] == excerpt_length

        # Sources should still (nearly) sum to the mix after cropping.
        assert torch.allclose(
            data['source_audio'].sum(dim=-1), data['mix_audio'], atol=1e-3)
コード例 #10
0
ファイル: test_base_dataset.py プロジェクト: speechdnn/nussl
def test_dataset_base_with_caching(benchmark_audio, monkeypatch):
    keys = [benchmark_audio[k] for k in benchmark_audio]

    def dummy_get(self, folder):
        return keys

    monkeypatch.setattr(BaseDataset, 'get_items', dummy_get)
    monkeypatch.setattr(BaseDataset, 'process_item',
                        dummy_process_item_by_audio)

    with tempfile.TemporaryDirectory() as tmpdir:
        tfm = transforms.Cache(os.path.join(tmpdir, 'cache'), overwrite=True)

        _dataset = BaseDataset('test', transform=tfm, cache_populated=False)
        assert tfm.cache_size == len(_dataset)

        _data_a = _dataset[0]
        _dataset.cache_populated = True
        pytest.raises(transforms.TransformException, _dataset.__getitem__,
                      1)  # haven't written to this yet!
        assert len(_dataset.post_cache_transforms.transforms) == 1
        _data_b = _dataset[0]

        for key in _data_a:
            assert _data_a[key] == _data_b[key]

        _dataset.cache_populated = False

        outputs_a = []
        outputs_b = []

        for i in range(len(_dataset)):
            outputs_a.append(_dataset[i])

        _dataset.cache_populated = True

        for i in range(len(_dataset)):
            outputs_b.append(_dataset[i])

        for _data_a, _data_b in zip(outputs_a, outputs_b):
            for key in _data_a:
                assert _data_a[key] == _data_b[key]

    with tempfile.TemporaryDirectory() as tmpdir:
        tfm = transforms.Compose([
            transforms.MagnitudeSpectrumApproximation(),
            transforms.ToSeparationModel(),
            transforms.Cache(os.path.join(tmpdir, 'cache'), overwrite=True),
        ])
        _dataset = BaseDataset('test', transform=tfm, cache_populated=False)
        assert tfm.transforms[-1].cache_size == len(_dataset)
        _data_a = _dataset[0]

        _dataset.cache_populated = True
        pytest.raises(transforms.TransformException, _dataset.__getitem__,
                      1)  # haven't written to this yet!
        assert len(_dataset.post_cache_transforms.transforms) == 1
        _data_b = _dataset[0]

        for key in _data_a:
            if torch.is_tensor(_data_a[key]):
                assert torch.allclose(_data_a[key], _data_b[key])
            else:
                assert _data_a[key] == _data_b[key]

        _dataset.cache_populated = False

        outputs_a = []
        outputs_b = []

        for i in range(len(_dataset)):
            outputs_a.append(_dataset[i])

        _dataset.cache_populated = True

        for i in range(len(_dataset)):
            outputs_b.append(_dataset[i])

        for _data_a, _data_b in zip(outputs_a, outputs_b):
            for key in _data_a:
                if torch.is_tensor(_data_a[key]):
                    assert torch.allclose(_data_a[key], _data_b[key])
                else:
                    assert _data_a[key] == _data_b[key]

    for L in [100, 400, 1000]:
        with tempfile.TemporaryDirectory() as tmpdir:
            tfm = transforms.Compose([
                transforms.MagnitudeSpectrumApproximation(),
                transforms.ToSeparationModel(),
                transforms.Cache(os.path.join(tmpdir, 'cache'),
                                 overwrite=True),
                transforms.GetExcerpt(L)
            ])
            _dataset = BaseDataset('test',
                                   transform=tfm,
                                   cache_populated=False)
            assert tfm.transforms[-2].cache_size == len(_dataset)
            assert len(_dataset.post_cache_transforms.transforms) == 2

            for i in range(len(_dataset)):
                _ = _dataset[i]

            _dataset.cache_populated = True
            outputs = []
            for i in range(len(_dataset)):
                outputs.append(_dataset[i])

            for _output in outputs:
                for key, val in _output.items():
                    if torch.is_tensor(val):
                        assert val.shape[0] == L