def test_separation_deep_clustering(overfit_model):
    """An overfit DeepClustering model should separate its own training mix
    well above SDR_CUTOFF on SI-SDR and SI-SIR, and raise once its model has
    no output keys to extract features from."""
    model_path, item = overfit_model
    mix, references = item['mix'], item['sources']

    dpcl = separation.deep.DeepClustering(
        mix, 2, model_path, mask_type='binary')
    dpcl.forward()  # calls extract_features, for coverage
    mix.write_audio_to_file('tests/local/dpcl_mix.wav')

    estimates = dpcl()
    for idx, estimate in enumerate(estimates):
        estimate.write_audio_to_file(f'tests/local/dpcl_overfit{idx}.wav')

    evaluator = evaluation.BSSEvalScale(
        list(references.values()), estimates, compute_permutation=True)
    scores = evaluator.evaluate()
    for label in evaluator.source_labels:
        for metric in ('SI-SDR', 'SI-SIR'):
            assert all(v > SDR_CUTOFF for v in scores[label][metric])

    # Without output keys the separator cannot extract features.
    dpcl.model.output_keys = []
    pytest.raises(SeparationException, dpcl.extract_features)
def test_separation_deep_mask_estimation(overfit_model):
    """DeepMaskEstimation with soft and binary masks should separate the
    overfit mix above SDR_CUTOFF; confidence is only valid after a full run,
    and running with missing output keys must raise."""
    model_path, item = overfit_model
    references = item['sources']

    for mask_type in ['soft', 'binary']:
        dme = separation.deep.DeepMaskEstimation(
            item['mix'], model_path, mask_type=mask_type)
        # Confidence is undefined before the separator has been run.
        pytest.raises(SeparationException, dme.confidence)
        item['mix'].write_audio_to_file('tests/local/dme_mix.wav')

        estimates = dme()
        for idx, estimate in enumerate(estimates):
            estimate.write_audio_to_file(f'tests/local/dme_overfit{idx}.wav')

        evaluator = evaluation.BSSEvalScale(
            list(references.values()), estimates, compute_permutation=True)
        scores = evaluator.evaluate()
        for label in evaluator.source_labels:
            for metric in ('SI-SDR', 'SI-SIR'):
                assert all(v > SDR_CUTOFF for v in scores[label][metric])

        confidence = dme.confidence()

    # A model emitting only a mask lacks what confidence needs.
    dme.model.output_keys = ['mask']
    dme()
    pytest.raises(SeparationException, dme.confidence)

    # A model with no outputs at all cannot run.
    dme.model.output_keys = []
    pytest.raises(SeparationException, dme.run)
def separate_and_evaluate(item, masks):
    """Separate ``item['mix']`` using precomputed ``masks`` via
    DeepMaskEstimation, evaluate against the reference sources, and dump the
    BSSEval scores to a JSON file in RESULTS_DIR named after the mix."""
    result_file = os.path.join(RESULTS_DIR, f"{item['mix'].file_name}.json")

    estimator = separation.deep.DeepMaskEstimation(item['mix'])
    source_estimates = estimator(masks)

    scores = evaluation.BSSEvalScale(
        list(item['sources'].values()),
        source_estimates,
        compute_permutation=True,
    ).evaluate()

    with open(result_file, 'w') as handle:
        json.dump(scores, handle)
def separate_and_evaluate(item_):
    """Run an IdealRatioMask separation (module-level APPROACH/KWARGS config)
    on ``item_``, score it against the references, and write the BSSEval
    results to RESULTS_DIR as ``<mix file name>.json``."""
    result_file = os.path.join(RESULTS_DIR, f"{item_['mix'].file_name}.json")

    irm = separation.benchmark.IdealRatioMask(
        item_['mix'], item_['sources'],
        approach=APPROACH, mask_type='soft', **KWARGS)
    source_estimates = irm()

    scores = evaluation.BSSEvalScale(
        list(item_['sources'].values()),
        source_estimates,
        compute_permutation=True,
    ).evaluate()

    with open(result_file, 'w') as handle:
        json.dump(scores, handle)
def separate_and_evaluate(item):
    """Run an IdealRatioMask separation configured by the enclosing ``val``
    dict (approach/kwargs/output dir), score against the references, and write
    the BSSEval results to ``RESULTS_DIR/<val dir>/<mix file name>.json``."""
    # NOTE(review): ``val`` is a closure variable from the surrounding scope.
    result_file = os.path.join(
        RESULTS_DIR, val['dir'], f"{item['mix'].file_name}.json")

    irm = separation.benchmark.IdealRatioMask(
        item['mix'], item['sources'],
        approach=val['approach'], mask_type='soft', **val['kwargs'])
    source_estimates = irm()

    scores = evaluation.BSSEvalScale(
        list(item['sources'].values()),
        source_estimates,
        compute_permutation=True,
    ).evaluate()

    with open(result_file, 'w') as handle:
        json.dump(scores, handle)
def separate_and_evaluate(item):
    """Separate ``item['mix']`` with an IdealBinaryMask benchmark, evaluate
    against the reference sources, and dump the BSSEval scores to a JSON file
    in RESULTS_DIR named after the mix."""
    result_file = os.path.join(RESULTS_DIR, f"{item['mix'].file_name}.json")

    ibm = separation.benchmark.IdealBinaryMask(
        item['mix'], item['sources'], mask_type='binary')
    source_estimates = ibm()

    scores = evaluation.BSSEvalScale(
        list(item['sources'].values()),
        source_estimates,
        compute_permutation=True,
    ).evaluate()

    with open(result_file, 'w') as handle:
        json.dump(scores, handle)
def separate_and_evaluate(mix, sources, mask_data):
    """Apply each normalized soft mask in ``mask_data`` to ``mix``, check the
    estimates reconstruct the mix (within the STFT tolerance), and return the
    BSSEval scores against the reference ``sources``."""
    normalized = normalize_masks(mask_data)

    estimates = []
    for src_idx in range(normalized.shape[-1]):
        masked = mix.apply_mask(SoftMask(normalized[..., src_idx]))
        masked.istft()
        estimates.append(masked)

    # Normalized soft masks should partition the mixture exactly, so the
    # estimates must sum back to the original audio.
    assert np.allclose(
        sum(estimates).audio_data, mix.audio_data, atol=stft_tol)

    references = [sources[key] for key in sources]
    evaluator = evaluation.BSSEvalScale(references, estimates)
    return evaluator.evaluate()
def test_separation_deep_audio_estimation(overfit_audio_model):
    """An overfit DeepAudioEstimation model should separate its own training
    mix above SDR_CUTOFF, and raise when its model has no output keys."""
    model_path, item = overfit_audio_model
    mix, references = item['mix'], item['sources']

    dae = separation.deep.DeepAudioEstimation(mix, model_path)
    mix.write_audio_to_file('tests/local/dae_mix.wav')

    estimates = dae()
    for idx, estimate in enumerate(estimates):
        estimate.write_audio_to_file(f'tests/local/dae_overfit{idx}.wav')

    evaluator = evaluation.BSSEvalScale(
        list(references.values()), estimates, compute_permutation=True)
    scores = evaluator.evaluate()
    for label in evaluator.source_labels:
        for metric in ('SI-SDR', 'SI-SIR'):
            assert all(v > SDR_CUTOFF for v in scores[label][metric])

    # A model with no output keys cannot run at all.
    dae.model.output_keys = []
    pytest.raises(SeparationException, dae.run)
def test_clustering_separation_base(scaper_folder, monkeypatch):
    """Exercise ClusteringSeparationBase: invalid clustering types and bad
    features must raise; with ideal (oracle) features patched in, every
    allowed clustering type should separate well in both soft and binary
    mask modes."""
    dataset = datasets.Scaper(scaper_folder)
    item = dataset[5]
    mix, references = item['mix'], item['sources']

    # An unknown clustering type is rejected at construction time.
    pytest.raises(
        SeparationException, separation.ClusteringSeparationBase,
        mix, 2, clustering_type='not allowed')

    clustering_types = (
        separation.base.clustering_separation_base.ALLOWED_CLUSTERING_TYPES)

    # Features with the wrong shape are rejected at run time.
    separator = separation.ClusteringSeparationBase(mix, 2)
    bad_features = np.ones(100)
    pytest.raises(SeparationException, separator.run, bad_features)

    # Oracle features: a one-hot assignment of each TF bin to its loudest source.
    good_features = np.stack(
        [np.abs(s.stft()) for _, s in references.items()], axis=-1)
    good_features = (
        good_features == good_features.max(axis=-1, keepdims=True))

    def dummy_extract(self):
        return good_features

    monkeypatch.setattr(
        separation.ClusteringSeparationBase, 'extract_features', dummy_extract)

    def _check_scores(estimates, cutoff):
        evaluator = evaluation.BSSEvalScale(
            list(references.values()), estimates, compute_permutation=True)
        scores = evaluator.evaluate()
        for label in evaluator.source_labels:
            for metric in ('SI-SDR', 'SI-SIR'):
                assert all(v > cutoff for v in scores[label][metric])

    for clustering_type in clustering_types:
        separator = separation.ClusteringSeparationBase(
            mix, 2, clustering_type=clustering_type)
        # Confidence is undefined before the separator has been run.
        pytest.raises(SeparationException, separator.confidence)
        estimates = separator()
        # Oracle features are perfectly separable, so confidence is maximal.
        confidence = separator.confidence()
        assert confidence == 1.0
        _check_scores(estimates, 5)

        separator = separation.ClusteringSeparationBase(
            mix, 2, clustering_type=clustering_type, mask_type='binary')
        estimates = separator()
        _check_scores(estimates, 9)