def test_transform_to_separation_model(musdb_tracks):
    """
    Check that ``ToSeparationModel`` (after ``MagnitudeSpectrumApproximation``)
    keeps only the magnitude keys as torch tensors, with the first two axes
    swapped relative to ``mix.stft()``.
    """
    track = musdb_tracks[10]
    mix, sources = nussl.utils.musdb_track_to_audio_signals(track)
    data = {
        'mix': mix,
        'sources': sources,
        'metadata': {'labels': []}
    }

    msa = transforms.MagnitudeSpectrumApproximation()
    tdl = transforms.ToSeparationModel()
    assert tdl.__class__.__name__ in str(tdl)

    com = transforms.Compose([msa, tdl])
    data = com(data)

    accepted_keys = ['mix_magnitude', 'source_magnitudes']
    rejected_keys = ['mix', 'sources', 'metadata']

    for a in accepted_keys:
        assert a in data
    for r in rejected_keys:
        assert r not in data

    # Reference shape of the mixture STFT, computed once rather than
    # re-running the STFT twice for every key checked below.
    stft_shape = mix.stft().shape
    for key, value in data.items():
        assert torch.is_tensor(value)
        assert value.shape[0] == stft_shape[1]
        assert value.shape[1] == stft_shape[0]
def transform(stft_params: nussl.STFTParams,
              sample_rate: int,
              target_instrument,
              only_audio_signal: bool,
              mask_type: str = 'msa',
              audio_only: bool = False):
    """
    Builds transforms that get applied to training and validation datasets.

    Parameters
    ----------
    stft_params : nussl.STFTParams
        Parameters of STFT (see: signal).
    sample_rate : int
        Sample rate of audio signal (see: signal).
    target_instrument : str
        Which instrument to learn to separate out of a mixture.
    only_audio_signal : bool
        Whether to return only the audio signals, no tensors (useful for eval).
    mask_type : str, optional
        What type of masking to use. Either phase sensitive spectrum approx.
        (psa) or magnitude spectrum approx (msa), by default 'msa'.
    audio_only : bool, optional
        Whether or not to only apply GetAudio in transform (don't compute STFTs).

    Returns
    -------
    tuple
        ``(nussl_tfm.Compose, list)``: the composed transform, and the sorted
        list made of ``target_instrument`` plus the group names reported by
        the ``SumSources`` transform.

    Raises
    ------
    ValueError
        If STFT-based transforms are requested (``only_audio_signal`` and
        ``audio_only`` both False) and ``mask_type`` is neither ``'psa'``
        nor ``'msa'``.
    """
    # NOTE(review): stft_params and sample_rate are not used in this body;
    # presumably they are consumed by the dataset instead -- confirm.
    tfm = []

    # Collapse every label except the target into a single summed source.
    other_labels = [k for k in LABELS if k != target_instrument]
    tfm.append(nussl_tfm.SumSources([other_labels]))
    new_labels = [target_instrument] + tfm[-1].group_names
    new_labels = sorted(new_labels)

    if not only_audio_signal:
        if not audio_only:
            if mask_type == 'psa':
                tfm.append(nussl_tfm.PhaseSensitiveSpectrumApproximation())
            elif mask_type == 'msa':
                tfm.append(nussl_tfm.MagnitudeSpectrumApproximation())
            else:
                # Fail fast: without a spectrum approximation there is no
                # 'source_magnitudes' key, and IndexSources below would only
                # fail later, when data flows through the pipeline.
                raise ValueError(
                    f"mask_type must be 'psa' or 'msa', got {mask_type!r}")
            tfm.append(nussl_tfm.MagnitudeWeights())

        tfm.append(nussl_tfm.GetAudio())
        # Keep only the target instrument's slot (labels were sorted above,
        # so its position is looked up in new_labels).
        # NOTE(review): with audio_only=True no spectrum approximation is
        # appended, yet 'source_magnitudes' is still indexed here -- confirm
        # this path is intended / exercised.
        target_index = new_labels.index(target_instrument)
        tfm.append(nussl_tfm.IndexSources('source_magnitudes', target_index))
        tfm.append(nussl_tfm.ToSeparationModel())

    return nussl_tfm.Compose(tfm), new_labels
def one_item(scaper_folder):
    """
    Yield one randomly chosen item from the Scaper dataset in
    ``scaper_folder`` (PSA + audio + ToSeparationModel transforms), with a
    leading size-1 axis added to every tensor value.
    """
    stft_params = nussl.STFTParams(window_length=512, hop_length=128)
    pipeline = transforms.Compose([
        transforms.PhaseSensitiveSpectrumApproximation(),
        transforms.GetAudio(),
        transforms.ToSeparationModel(),
    ])
    dataset = nussl.datasets.Scaper(
        scaper_folder, transform=pipeline, stft_params=stft_params)

    pick = np.random.randint(len(dataset))
    item = dataset[pick]

    # Fake a batch dimension; only values are replaced, so mutating the dict
    # while iterating over its items is safe.
    for name, value in item.items():
        if torch.is_tensor(value):
            item[name] = value.unsqueeze(0)

    yield item
def test_transform_cache(musdb_tracks):
    """
    Round-trip an item through ``transforms.Cache``: once as a standalone
    transform, and once at the end of a Compose that produces tensors.
    """
    track = musdb_tracks[10]
    mix, sources = nussl.utils.musdb_track_to_audio_signals(track)
    data = {
        'mix': mix,
        'sources': sources,
        'metadata': {'labels': sorted(sources.keys())},
        'index': 0,
    }

    with tempfile.TemporaryDirectory() as tmpdir:
        cache_path = os.path.join(tmpdir, 'cache')

        # Standalone cache with two slots: write index 0, then read it back.
        cache = transforms.Cache(cache_path, cache_size=2, overwrite=True)
        written = cache(data)
        _ = cache.info  # exercise the ``info`` property
        cache.overwrite = False
        read_back = cache({'index': 0})

        # No index, or an index that was never written, must be rejected.
        pytest.raises(TransformException, cache, {})
        pytest.raises(TransformException, cache, {'index': 1})

        for name in written:
            assert written[name] == read_back[name]

        # Cache as the last step of a pipeline whose outputs are tensors.
        pipeline = transforms.Compose([
            transforms.MagnitudeSpectrumApproximation(),
            transforms.ToSeparationModel(),
            transforms.Cache(cache_path, overwrite=True),
        ])
        written = pipeline(data)
        tail_cache = pipeline.transforms[-1]
        tail_cache.overwrite = False
        read_back = tail_cache({'index': 0})

        for name, value in written.items():
            if torch.is_tensor(value):
                assert torch.allclose(value, read_back[name])
            else:
                assert value == read_back[name]
def test_transform_get_excerpt(musdb_tracks):
    """
    Check ``transforms.GetExcerpt`` on magnitude tensors, on numpy arrays,
    and on time-domain audio (``time_dim=1``). In each case it must cut the
    requested length and keep the sources consistent with the mix.
    """
    track = musdb_tracks[10]
    mix, sources = nussl.utils.musdb_track_to_audio_signals(track)
    # Frequency-axis size of the mixture STFT, computed once instead of
    # re-running the STFT for every key checked in the loops below.
    n_freq = mix.stft().shape[0]

    msa = transforms.MagnitudeSpectrumApproximation()
    tdl = transforms.ToSeparationModel()

    def fresh_data():
        # A new input dict for each pipeline run (the transforms replace keys).
        return {
            'mix': mix,
            'sources': sources,
            'metadata': {'labels': []}
        }

    for excerpt_length in [400, 1000, 2000]:
        # GetExcerpt applied to torch tensors at the end of a Compose.
        exc = transforms.GetExcerpt(excerpt_length=excerpt_length)
        assert isinstance(str(exc), str)
        data = transforms.Compose([msa, tdl, exc])(fresh_data())

        for value in data.values():
            assert torch.is_tensor(value)
            assert value.shape[0] == excerpt_length
            assert value.shape[1] == n_freq

        # Mix and sources must be cut at the same frames.
        assert torch.mean(
            (data['source_magnitudes'].sum(dim=-1)
             - data['mix_magnitude']) ** 2).item() < 1e-5

        # GetExcerpt applied standalone to numpy arrays.
        exc = transforms.GetExcerpt(excerpt_length=excerpt_length)
        assert isinstance(str(exc), str)
        data = transforms.Compose([msa, tdl])(fresh_data())
        data = {key: value.detach().cpu().numpy()
                for key, value in data.items()}
        data = exc(data)

        for value in data.values():
            assert value.shape[0] == excerpt_length
            assert value.shape[1] == n_freq

        assert np.mean(
            (data['source_magnitudes'].sum(axis=-1)
             - data['mix_magnitude']) ** 2) < 1e-5

        # Values that are neither arrays nor tensors are rejected.
        with pytest.raises(TransformException):
            exc({'mix_magnitude': 'not an array or tensor'})

    ga = transforms.GetAudio()
    for excerpt_length in [1009, 16000, 612140]:
        data = {
            'mix': sum(sources.values()),
            'sources': sources,
            'metadata': {'labels': []}
        }
        # Excerpt along the time axis (dim 1) of the audio tensors.
        exc = transforms.GetExcerpt(
            excerpt_length=excerpt_length,
            tf_keys=['mix_audio', 'source_audio'],
            time_dim=1,
        )
        data = transforms.Compose([ga, tdl, exc])(data)

        for value in data.values():
            assert torch.is_tensor(value)
            assert value.shape[1] == excerpt_length

        assert torch.allclose(
            data['source_audio'].sum(dim=-1), data['mix_audio'], atol=1e-3)
def test_dataset_base_with_caching(benchmark_audio, monkeypatch):
    """
    Check that a ``transforms.Cache`` inside a ``BaseDataset`` returns the
    same items when read back as when they were written. Three setups are
    covered: a bare Cache, a Cache at the end of a tensor-producing Compose,
    and a Cache followed by ``GetExcerpt``.
    """
    keys = [benchmark_audio[k] for k in benchmark_audio]

    def dummy_get(self, folder):
        return keys

    # Make BaseDataset list the benchmark audio and load items with
    # dummy_process_item_by_audio instead of reading a real dataset folder.
    monkeypatch.setattr(BaseDataset, 'get_items', dummy_get)
    monkeypatch.setattr(BaseDataset, 'process_item', dummy_process_item_by_audio)

    def assert_items_equal(item_a, item_b):
        # Key-by-key comparison; tensors use allclose because `==` on a
        # tensor is elementwise and has no single truth value.
        for key in item_a:
            if torch.is_tensor(item_a[key]):
                assert torch.allclose(item_a[key], item_b[key])
            else:
                assert item_a[key] == item_b[key]

    def check_round_trip(dataset, cache):
        # Write item 0 and read it back, then write and read every item.
        assert cache.cache_size == len(dataset)

        first = dataset[0]
        dataset.cache_populated = True
        # Index 1 has not been written to the cache yet.
        pytest.raises(transforms.TransformException, dataset.__getitem__, 1)
        assert len(dataset.post_cache_transforms.transforms) == 1
        assert_items_equal(first, dataset[0])

        dataset.cache_populated = False
        written = [dataset[i] for i in range(len(dataset))]
        dataset.cache_populated = True
        read_back = [dataset[i] for i in range(len(dataset))]

        for item_a, item_b in zip(written, read_back):
            assert_items_equal(item_a, item_b)

    # Bare cache.
    with tempfile.TemporaryDirectory() as tmpdir:
        tfm = transforms.Cache(os.path.join(tmpdir, 'cache'), overwrite=True)
        dataset = BaseDataset('test', transform=tfm, cache_populated=False)
        check_round_trip(dataset, tfm)

    # Cache as the last step of a tensor-producing pipeline.
    with tempfile.TemporaryDirectory() as tmpdir:
        tfm = transforms.Compose([
            transforms.MagnitudeSpectrumApproximation(),
            transforms.ToSeparationModel(),
            transforms.Cache(os.path.join(tmpdir, 'cache'), overwrite=True),
        ])
        dataset = BaseDataset('test', transform=tfm, cache_populated=False)
        check_round_trip(dataset, tfm.transforms[-1])

    # Cache followed by GetExcerpt: cached items are excerpted on read.
    for excerpt_length in [100, 400, 1000]:
        with tempfile.TemporaryDirectory() as tmpdir:
            tfm = transforms.Compose([
                transforms.MagnitudeSpectrumApproximation(),
                transforms.ToSeparationModel(),
                transforms.Cache(os.path.join(tmpdir, 'cache'), overwrite=True),
                transforms.GetExcerpt(excerpt_length)
            ])
            dataset = BaseDataset('test', transform=tfm, cache_populated=False)
            assert tfm.transforms[-2].cache_size == len(dataset)
            assert len(dataset.post_cache_transforms.transforms) == 2

            # Populate the cache, then read every item back through it.
            for i in range(len(dataset)):
                _ = dataset[i]
            dataset.cache_populated = True
            outputs = [dataset[i] for i in range(len(dataset))]

            for output in outputs:
                for value in output.values():
                    if torch.is_tensor(value):
                        assert value.shape[0] == excerpt_length