Ejemplo n.º 1
0
def music_mix_and_sources(musdb_tracks):
    """
    Return the mix and sources of the first MUSDB18 dataset item.

    A ``datasets.MUSDB18`` dataset is built on ``musdb_tracks.root``
    (with ``download=False``) using a ``SumSources`` transform that
    groups 'drums', 'bass' and 'other'; item 0 is then read from it.

    Parameters
    ----------
    musdb_tracks
        Object exposing a ``root`` attribute that is passed as the
        dataset folder.

    Returns
    -------
    tuple
        ``(item['mix'], item['sources'])`` for ``dataset[0]``.
    """
    sum_group = transforms.SumSources([['drums', 'bass', 'other']])
    dataset = datasets.MUSDB18(
        folder=musdb_tracks.root,
        download=False,
        transform=sum_group,
    )
    first_item = dataset[0]
    mix = first_item['mix']
    sources = first_item['sources']
    return mix, sources
Ejemplo n.º 2
0
def test_transform_sum_sources(musdb_tracks):
    """
    Exercise ``transforms.SumSources`` on track 10 of ``musdb_tracks``.

    For every 3-element combination of source keys, the transform is
    applied to a deep copy of the data and the test asserts that the
    individual keys are gone, that the '+'-joined key is present, and
    that its audio matches the sum of the original sources. Finally it
    checks three calls that are expected to raise
    ``TransformException``.
    """
    mix, sources = nussl.utils.musdb_track_to_audio_signals(musdb_tracks[10])
    data = {'mix': mix, 'sources': sources}

    tfm = None
    for group in itertools.combinations(data['sources'].keys(), 3):
        tfm = transforms.SumSources([group])
        # Each combination is applied to its own deep copy of ``data``.
        result = tfm(copy.deepcopy(data))
        group_key = '+'.join(group)

        assert all(name not in result['sources'] for name in group)
        assert group_key in result['sources']

        expected = sum(sources[name] for name in group)
        assert np.allclose(
            result['sources'][group_key].audio_data,
            expected.audio_data,
        )

    # Uses the transform left over from the last loop iteration.
    with pytest.raises(TransformException):
        tfm({'no_key'})

    with pytest.raises(TransformException):
        transforms.SumSources('test')

    with pytest.raises(TransformException):
        transforms.SumSources(
            [['vocals', 'test'], ['test2', 'test3']], ['mygroup'])
Ejemplo n.º 3
0
def transform(stft_params: nussl.STFTParams,
              sample_rate: int,
              target_instrument: str,
              only_audio_signal: bool,
              mask_type: str = 'msa',
              audio_only: bool = False):
    """
    Builds transforms that get applied to
    training and validation datasets.

    Parameters
    ----------
    stft_params : nussl.STFTParams
        Parameters of STFT (see: signal). Not referenced in the
        body of this function.
    sample_rate : int
        Sample rate of audio signal (see: signal). Not referenced
        in the body of this function.
    target_instrument : str
        Which instrument to learn to separate out of
        a mixture.
    only_audio_signal : bool
        Whether to return only the audio signals, no
        tensors (useful for eval). When True, only the
        SumSources transform is composed.
    mask_type : str, optional
        What type of masking to use. Either phase
        sensitive spectrum approx. (psa) or
        magnitude spectrum approx (msa), by default
        'msa'. Only consulted when both ``only_audio_signal``
        and ``audio_only`` are False.
    audio_only : bool, optional
        Whether or not to only apply GetAudio in
        transform (don't compute STFTs).

    Returns
    -------
    tuple
        ``(composed, new_labels)`` where ``composed`` is the
        ``nussl_tfm.Compose`` of the selected transforms and
        ``new_labels`` is the sorted list made of
        ``target_instrument`` plus the SumSources group names.

    Raises
    ------
    ValueError
        If a mask is required (``only_audio_signal`` and
        ``audio_only`` both False) and ``mask_type`` is neither
        'psa' nor 'msa'.
    """
    # Everything except the target is summed into a single group.
    other_labels = [k for k in LABELS if k != target_instrument]
    sum_others = nussl_tfm.SumSources([other_labels])
    tfm = [sum_others]
    new_labels = sorted([target_instrument] + sum_others.group_names)

    if only_audio_signal:
        return nussl_tfm.Compose(tfm), new_labels

    if not audio_only:
        if mask_type == 'psa':
            tfm.append(nussl_tfm.PhaseSensitiveSpectrumApproximation())
        elif mask_type == 'msa':
            tfm.append(nussl_tfm.MagnitudeSpectrumApproximation())
        else:
            # Previously an unknown value was skipped silently, leaving
            # MagnitudeWeights with no spectrum-approximation transform
            # ahead of it; fail at build time with a clear message.
            raise ValueError(
                f"Unsupported mask_type {mask_type!r}; "
                "expected 'psa' or 'msa'.")
        tfm.append(nussl_tfm.MagnitudeWeights())

    tfm.append(nussl_tfm.GetAudio())
    target_index = new_labels.index(target_instrument)
    # NOTE(review): 'source_magnitudes' is indexed even when audio_only
    # is True, i.e. when no spectrum-approximation transform was added
    # above -- confirm that key exists in that configuration.
    tfm.append(nussl_tfm.IndexSources('source_magnitudes', target_index))
    tfm.append(nussl_tfm.ToSeparationModel())

    return nussl_tfm.Compose(tfm), new_labels
Ejemplo n.º 4
0
def test_transform_compose(musdb_tracks):
    """
    Exercise ``transforms.Compose`` on track 10 of ``musdb_tracks``.

    Checks that a composed transform returning a non-dictionary raises
    ``TransformException``, that the string form of a composition
    mentions each member's class name, and that composing SumSources
    with MagnitudeSpectrumApproximation yields the expected
    ``mix_magnitude`` and metadata labels. Lastly, scores obtained
    with the magnitude-ratio mask are compared against scores from an
    all-ones mask: for the SI-SDR / SI-SIR / SI-SAR keys the mean
    difference must exceed 10.
    """
    mix, sources = nussl.utils.musdb_track_to_audio_signals(musdb_tracks[10])
    data = {
        'mix': mix,
        'sources': sources,
        'metadata': {'labels': ['bass', 'drums', 'other', 'vocals']},
    }

    class _BadTransform(object):
        # Returns a string instead of a dictionary on purpose.
        def __call__(self, data):
            return 'not a dictionary'

    com = transforms.Compose([_BadTransform()])
    with pytest.raises(TransformException):
        com(data)

    msa = transforms.MagnitudeSpectrumApproximation()
    tfm = transforms.SumSources(
        [['other', 'drums', 'bass']], group_names=['accompaniment'])
    assert isinstance(str(tfm), str)

    com = transforms.Compose([tfm, msa])
    for member in (msa, tfm):
        assert member.__class__.__name__ in str(com)

    data = com(data)

    assert np.allclose(data['mix_magnitude'], np.abs(mix.stft()))
    expected_labels = ['bass', 'drums', 'other', 'vocals', 'accompaniment']
    assert data['metadata']['labels'] == expected_labels

    # Ratio of each source magnitude to max(mix magnitude, source magnitude).
    source_mag = data['source_magnitudes']
    denominator = np.maximum(data['mix_magnitude'][..., None], source_mag)
    msa_scores = separate_and_evaluate(
        mix, data['sources'], source_mag / denominator)

    # Baseline: a mask of ones with one trailing slot per entry in ``sources``.
    ones_mask = np.ones(mix.stft_data.shape + (len(sources),))
    mix_scores = separate_and_evaluate(mix, data['sources'], ones_mask)

    tracked_metrics = ('SI-SDR', 'SI-SIR', 'SI-SAR')
    for key in msa_scores:
        if key not in tracked_metrics:
            continue
        improvement = np.array(msa_scores[key]) - np.array(mix_scores[key])
        assert improvement.mean() > 10
Ejemplo n.º 5
0
def test_dataset_hook_scaper_folder(scaper_folder):
    """
    Check item 0 of a ``nussl.datasets.Scaper`` dataset before and
    after applying ``SumSources``.

    Both before and after summing the '050' and '051' sources into a
    group named 'both', the sources must add up to the mix and the
    part of every source key preceding '::' must appear in the
    metadata labels.
    """
    def _assert_sources_sum_to_mix(item):
        parts = list(item['sources'].values())
        assert np.allclose(sum(parts).audio_data, item['mix'].audio_data)

    def _assert_source_labels_listed(item):
        for key in item['sources']:
            assert key.split('::')[0] in item['metadata']['labels']

    data = nussl.datasets.Scaper(scaper_folder)[0]

    _assert_sources_sum_to_mix(data)
    _assert_source_labels_listed(data)

    # make sure SumSources transform works
    tfm = transforms.SumSources([['050', '051']], group_names=['both'])
    data = tfm(data)

    _assert_source_labels_listed(data)
    _assert_sources_sum_to_mix(data)