Пример #1
0
def test_transform_get_audio(mix_source_folder):
    """Check that GetAudio exposes mixture and source waveforms as arrays.

    Verifies that:

    * ``str(tfm)`` returns a string;
    * calling the transform on ``{'sources': []}`` raises
      ``TransformException``;
    * ``'mix_audio'`` matches the mixture's ``audio_data``;
    * ``'source_audio'[..., i]`` matches the ``audio_data`` of the i-th
      source when sources are ordered by sorted key;
    * ``'mix_audio'`` is still produced after the sources are removed.
    """
    dataset = nussl.datasets.MixSourceFolder(mix_source_folder)
    item = dataset[0]

    tfm = transforms.GetAudio()
    assert isinstance(str(tfm), str)
    pytest.raises(TransformException, tfm, {'sources': []})

    ga_output = tfm(item)

    assert np.allclose(ga_output['mix_audio'], item['mix'].audio_data)

    # Sources are stacked along the last axis in alphabetical key order.
    for i, key in enumerate(sorted(item['sources'])):
        assert np.allclose(
            ga_output['source_audio'][..., i], item['sources'][key].audio_data)

    # 'source_audio' is popped from ``item`` itself, so GetAudio presumably
    # writes its outputs back into the dict it receives. Remove both the
    # input sources and that earlier output before re-running mix-only.
    item.pop('sources')
    item.pop('source_audio')

    ga_output = tfm(item)

    assert np.allclose(ga_output['mix_audio'], item['mix'].audio_data)
Пример #2
0
def transform(stft_params: nussl.STFTParams,
              sample_rate: int,
              target_instrument,
              only_audio_signal: bool,
              mask_type: str = 'msa',
              audio_only: bool = False):
    """
    Builds transforms that get applied to
    training and validation datasets.

    Parameters
    ----------
    stft_params : nussl.STFTParams
        Parameters of STFT (see: signal). Not used by this
        function; kept for interface compatibility.
    sample_rate : int
        Sample rate of audio signal (see: signal). Not used by
        this function; kept for interface compatibility.
    target_instrument : str
        Which instrument to learn to separate out of
        a mixture. Every other label in ``LABELS`` is summed
        into a single combined source.
    only_audio_signal : bool
        Whether to return only the audio signals, no
        tensors (useful for eval). When True, only the
        source-summing transform is applied.
    mask_type : str, optional
        What type of masking to use. Either phase
        sensitive spectrum approx. (psa) or
        magnitude spectrum approx (msa), by default
        'msa'. Ignored when ``only_audio_signal`` or
        ``audio_only`` is True.
    audio_only : bool, optional
        Whether or not to only apply GetAudio in
        transform (don't compute STFTs). In this mode the
        target source is selected from ``source_audio``
        rather than ``source_magnitudes``.

    Returns
    -------
    tuple
        ``(transform, labels)``. ``transform`` is a
        ``nussl_tfm.Compose`` of the chosen transforms.
        ``labels`` is the sorted list holding
        ``target_instrument`` plus the group name(s) reported
        by the ``SumSources`` transform.

    Raises
    ------
    ValueError
        If a spectrogram pipeline is requested (both
        ``only_audio_signal`` and ``audio_only`` are False) and
        ``mask_type`` is neither ``'psa'`` nor ``'msa'``.
    """
    tfm = []

    # Collapse every non-target label into one summed source group.
    other_labels = [k for k in LABELS if k != target_instrument]
    tfm.append(nussl_tfm.SumSources([other_labels]))
    # Sorted so that target_index below lines up with the (presumably
    # alphabetical) order in which sources are stacked -- confirm against
    # SumSources/GetAudio if that ordering ever changes.
    new_labels = sorted([target_instrument] + tfm[-1].group_names)

    if not only_audio_signal:
        if audio_only:
            # No spectrogram transform runs in this mode, so
            # 'source_magnitudes' is never produced; pick the target from
            # the waveforms instead.
            target_key = 'source_audio'
        else:
            if mask_type == 'psa':
                tfm.append(nussl_tfm.PhaseSensitiveSpectrumApproximation())
            elif mask_type == 'msa':
                tfm.append(nussl_tfm.MagnitudeSpectrumApproximation())
            else:
                # Previously an unknown value silently produced a pipeline
                # with no mask transform at all.
                raise ValueError(
                    "mask_type must be 'psa' or 'msa', got {!r}".format(
                        mask_type))
            tfm.append(nussl_tfm.MagnitudeWeights())
            target_key = 'source_magnitudes'

        tfm.append(nussl_tfm.GetAudio())
        target_index = new_labels.index(target_instrument)
        tfm.append(nussl_tfm.IndexSources(target_key, target_index))

        tfm.append(nussl_tfm.ToSeparationModel())

    return nussl_tfm.Compose(tfm), new_labels
Пример #3
0
def one_item(scaper_folder):
    """Yield one transformed Scaper item shaped like a batch of size one.

    Builds a Scaper dataset over ``scaper_folder`` with a 512-sample window /
    128-sample hop STFT and a PhaseSensitiveSpectrumApproximation ->
    GetAudio -> ToSeparationModel pipeline. It then draws one item at an
    index chosen with ``np.random.randint``. Every torch-tensor value in the
    item gets a new leading axis of size 1, in place. Presumably registered
    as a pytest fixture; no decorator is visible here.
    """
    params = nussl.STFTParams(window_length=512, hop_length=128)
    pipeline = transforms.Compose([
        transforms.PhaseSensitiveSpectrumApproximation(),
        transforms.GetAudio(),
        transforms.ToSeparationModel(),
    ])
    scaper_data = nussl.datasets.Scaper(
        scaper_folder, transform=pipeline, stft_params=params)

    sample = scaper_data[np.random.randint(len(scaper_data))]
    for name, value in sample.items():
        if not torch.is_tensor(value):
            continue
        # Fake a batch dimension.
        sample[name] = value.unsqueeze(0)
    yield sample
Пример #4
0
def test_transform_get_excerpt(musdb_tracks):
    """Check GetExcerpt on spectrogram tensors, numpy arrays and waveforms.

    For each spectrogram excerpt length, the excerpt is taken twice:

    * from torch tensors, with GetExcerpt composed after ToSeparationModel;
    * from the same data converted to numpy arrays.

    Each output key must have ``excerpt_length`` frames on axis 0 and the
    mixture's frequency-bin count on axis 1. The source magnitudes must
    still sum to the mixture magnitude. A non-array ``'mix_magnitude'`` must
    raise ``TransformException``.

    Finally, waveform excerpts of ``'mix_audio'`` / ``'source_audio'`` are
    cut along axis 1 (``time_dim=1``). The sources must still sum to the
    mixture.
    """
    track = musdb_tracks[10]
    mix, sources = nussl.utils.musdb_track_to_audio_signals(track)
    # The mixture's frequency-bin count never changes, so compute the
    # full-track STFT once instead of once per key in every loop below.
    n_freq_bins = mix.stft().shape[0]

    msa = transforms.MagnitudeSpectrumApproximation()
    tdl = transforms.ToSeparationModel()
    excerpt_lengths = [400, 1000, 2000]
    for excerpt_length in excerpt_lengths:
        exc = transforms.GetExcerpt(excerpt_length=excerpt_length)
        assert isinstance(str(exc), str)

        # --- Excerpt applied to torch tensors (inside the Compose). ---
        data = {
            'mix': mix,
            'sources': sources,
            'metadata': {'labels': []}
        }
        com = transforms.Compose([msa, tdl, exc])

        data = com(data)

        for key in data:
            assert torch.is_tensor(data[key])
            assert data[key].shape[0] == excerpt_length
            assert data[key].shape[1] == n_freq_bins

        assert torch.mean((data['source_magnitudes'].sum(dim=-1) -
                           data['mix_magnitude']) ** 2).item() < 1e-5

        # --- Same excerpt applied to numpy arrays. ---
        data = {
            'mix': mix,
            'sources': sources,
            'metadata': {'labels': []}
        }
        com = transforms.Compose([msa, tdl])

        data = com(data)
        for key in data:
            data[key] = data[key].cpu().data.numpy()

        data = exc(data)

        for key in data:
            assert data[key].shape[0] == excerpt_length
            assert data[key].shape[1] == n_freq_bins

        assert np.mean((data['source_magnitudes'].sum(axis=-1) -
                        data['mix_magnitude']) ** 2) < 1e-5

        # --- Non-array input must be rejected. ---
        data = {
            'mix_magnitude': 'not an array or tensor'
        }

        pytest.raises(TransformException, exc, data)

    # --- Time-domain excerpts of the waveform keys (time axis is 1). ---
    excerpt_lengths = [1009, 16000, 612140]
    ga = transforms.GetAudio()
    for excerpt_length in excerpt_lengths:
        data = {
            'mix': sum(sources.values()),
            'sources': sources,
            'metadata': {'labels': []}
        }

        exc = transforms.GetExcerpt(
            excerpt_length=excerpt_length,
            tf_keys=['mix_audio', 'source_audio'],
            time_dim=1,
        )
        com = transforms.Compose([ga, tdl, exc])

        data = com(data)

        for key in data:
            assert torch.is_tensor(data[key])
            assert data[key].shape[1] == excerpt_length

        assert torch.allclose(
            data['source_audio'].sum(dim=-1), data['mix_audio'], atol=1e-3)