def music_mix_and_sources(musdb_tracks):
    """Load the first MUSDB18 item with drums/bass/other summed into one source.

    Returns the (mix, sources) pair from the dataset's first entry; the
    SumSources transform collapses the three non-vocal stems together.
    """
    sum_tfm = transforms.SumSources([['drums', 'bass', 'other']])
    musdb = datasets.MUSDB18(
        folder=musdb_tracks.root, download=False, transform=sum_tfm)
    first_item = musdb[0]
    return first_item['mix'], first_item['sources']
def test_transform_sum_sources(musdb_tracks):
    """SumSources must merge each 3-source group under a '+'-joined key,
    remove the originals, and reject malformed constructor arguments."""
    mix, sources = nussl.utils.musdb_track_to_audio_signals(musdb_tracks[10])
    base = {'mix': mix, 'sources': sources}

    tfm = None
    for combo in itertools.combinations(base['sources'].keys(), 3):
        tfm = transforms.SumSources([combo])
        out = tfm(copy.deepcopy(base))
        # the grouped keys are consumed and replaced by one joined key
        assert all(name not in out['sources'] for name in combo)
        joined_key = '+'.join(combo)
        assert joined_key in out['sources']
        expected = sum([sources[name] for name in combo])
        assert np.allclose(
            out['sources'][joined_key].audio_data, expected.audio_data)

    # applying the transform to a dict without the required keys fails
    pytest.raises(TransformException, tfm, {'no_key'})
    # groups must be a list of lists, not a string
    pytest.raises(TransformException, transforms.SumSources, 'test')
    # group_names must have one entry per group
    pytest.raises(
        TransformException, transforms.SumSources,
        [['vocals', 'test'], ['test2', 'test3']], ['mygroup'])
def transform(stft_params: nussl.STFTParams, sample_rate: int,
              target_instrument, only_audio_signal: bool,
              mask_type: str = 'msa', audio_only: bool = False):
    """
    Builds transforms that get applied to training and validation
    datasets.

    Parameters
    ----------
    stft_params : nussl.STFTParams
        Parameters of STFT (see: signal).
        NOTE(review): not referenced in this body — presumably consumed by the
        caller or by the dataset; confirm.
    sample_rate : int
        Sample rate of audio signal (see: signal).
        NOTE(review): also not referenced in this body.
    target_instrument : str
        Which instrument to learn to separate out of a mixture.
    only_audio_signal : bool
        Whether to return only the audio signals, no tensors
        (useful for eval).
    mask_type : str, optional
        What type of masking to use. Either phase sensitive spectrum approx.
        (psa) or magnitude spectrum approx (msa), by default 'msa'.
    audio_only : bool, optional
        Whether or not to only apply GetAudio in transform
        (don't compute STFTs).

    Returns
    -------
    tuple
        (nussl_tfm.Compose of the selected transforms,
         sorted list of post-grouping source labels).
    """
    tfm = []
    # Collapse every non-target stem into a single grouped source so the
    # task becomes target-vs-everything-else.
    other_labels = [k for k in LABELS if k != target_instrument]
    tfm.append(nussl_tfm.SumSources([other_labels]))
    # Labels after grouping: the target plus the name of the summed group.
    new_labels = [target_instrument] + tfm[-1].group_names
    new_labels = sorted(new_labels)
    if not only_audio_signal:
        if not audio_only:
            # Spectrogram-domain targets for training.
            if mask_type == 'psa':
                tfm.append(nussl_tfm.PhaseSensitiveSpectrumApproximation())
            elif mask_type == 'msa':
                tfm.append(nussl_tfm.MagnitudeSpectrumApproximation())
            tfm.append(nussl_tfm.MagnitudeWeights())
        tfm.append(nussl_tfm.GetAudio())
        # Pick out only the target source's magnitudes from the stacked
        # source_magnitudes array (sorted order fixes the index).
        # NOTE(review): when audio_only is True no spectrum approximation is
        # appended above, yet IndexSources('source_magnitudes', ...) is still
        # added — verify 'source_magnitudes' exists in that configuration.
        target_index = new_labels.index(target_instrument)
        tfm.append(nussl_tfm.IndexSources('source_magnitudes', target_index))
        tfm.append(nussl_tfm.ToSeparationModel())
    return nussl_tfm.Compose(tfm), new_labels
def test_transform_compose(musdb_tracks):
    """Compose must reject transforms that don't return dicts, chain
    SumSources + MagnitudeSpectrumApproximation correctly, and produce an
    ideal mask that beats the trivial all-ones mask by a wide margin."""
    mix, sources = nussl.utils.musdb_track_to_audio_signals(musdb_tracks[10])
    data = {
        'mix': mix,
        'sources': sources,
        'metadata': {'labels': ['bass', 'drums', 'other', 'vocals']},
    }

    class _BadTransform(object):
        # deliberately violates the transform contract (must return a dict)
        def __call__(self, data):
            return 'not a dictionary'

    bad_pipeline = transforms.Compose([_BadTransform()])
    pytest.raises(TransformException, bad_pipeline, data)

    msa = transforms.MagnitudeSpectrumApproximation()
    sum_tfm = transforms.SumSources(
        [['other', 'drums', 'bass']],
        group_names=['accompaniment'],
    )
    assert isinstance(str(sum_tfm), str)

    pipeline = transforms.Compose([sum_tfm, msa])
    # both member transforms should show up in the composed repr
    for member in (msa, sum_tfm):
        assert member.__class__.__name__ in str(pipeline)

    data = pipeline(data)
    assert np.allclose(data['mix_magnitude'], np.abs(mix.stft()))
    assert data['metadata']['labels'] == [
        'bass', 'drums', 'other', 'vocals', 'accompaniment']

    # ideal soft mask derived from the ground-truth source magnitudes
    ideal_mask = data['source_magnitudes'] / np.maximum(
        data['mix_magnitude'][..., None], data['source_magnitudes'])
    msa_scores = separate_and_evaluate(mix, data['sources'], ideal_mask)

    # trivial baseline: pass the whole mixture through for every source
    ones_mask = np.ones(mix.stft_data.shape + (len(sources),))
    mix_scores = separate_and_evaluate(mix, data['sources'], ones_mask)

    for key in msa_scores:
        if key not in ['SI-SDR', 'SI-SIR', 'SI-SAR']:
            continue
        improvement = np.array(msa_scores[key]) - np.array(mix_scores[key])
        assert improvement.mean() > 10
def test_dataset_hook_scaper_folder(scaper_folder):
    """Items from a Scaper dataset should sum to the mix and keep valid
    labels, both before and after a SumSources transform."""
    scaper_dataset = nussl.datasets.Scaper(scaper_folder)
    item = scaper_dataset[0]

    stems = list(item['sources'].values())
    assert np.allclose(sum(stems).audio_data, item['mix'].audio_data)
    labels = item['metadata']['labels']
    # source keys look like 'label::suffix'; the label part must be known
    for key in item['sources']:
        assert key.split('::')[0] in labels

    # make sure SumSources transform works
    sum_tfm = transforms.SumSources([['050', '051']], group_names=['both'])
    item = sum_tfm(item)
    for key in item['sources']:
        assert key.split('::')[0] in item['metadata']['labels']
    stems = list(item['sources'].values())
    assert np.allclose(sum(stems).audio_data, item['mix'].audio_data)