Example #1
def test_separation_deep_clustering(overfit_model):
    model_path, item = overfit_model
    dpcl = separation.deep.DeepClustering(item['mix'],
                                          2,
                                          model_path,
                                          mask_type='binary')
    dpcl.forward()  # calls extract_features, for coverage

    item['mix'].write_audio_to_file('tests/local/dpcl_mix.wav')
    sources = item['sources']
    estimates = dpcl()
    for i, e in enumerate(estimates):
        e.write_audio_to_file(f'tests/local/dpcl_overfit{i}.wav')

    evaluator = evaluation.BSSEvalScale(list(sources.values()),
                                        estimates,
                                        compute_permutation=True)
    scores = evaluator.evaluate()

    for key in evaluator.source_labels:
        for metric in ['SI-SDR', 'SI-SIR']:
            _score = scores[key][metric]
            for val in _score:
                assert val > SDR_CUTOFF

    # Emptying output_keys removes the embedding from the model output,
    # so feature extraction should raise.
    dpcl.model.output_keys = []
    pytest.raises(SeparationException, dpcl.extract_features)
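These snippets are lifted from nussl-style tests and recipes, so they omit their module-level setup. A minimal sketch of the shared preamble they assume follows; the constant values and directory layout are illustrative assumptions, not nussl's actual test configuration, and overfit_model, overfit_audio_model, and scaper_folder are pytest fixtures defined elsewhere in the suite.

import os
import json

import numpy as np
import pytest

from nussl import datasets, evaluation, separation
from nussl.core.masks import SoftMask
from nussl.separation import SeparationException  # defined in nussl.separation.base

SDR_CUTOFF = 5                       # assumption: SI-SDR/SI-SIR pass threshold, in dB
RESULTS_DIR = 'tests/local/results'  # assumption: where per-item score JSONs are written
stft_tol = 1e-6                      # assumption: STFT round-trip tolerance
os.makedirs(RESULTS_DIR, exist_ok=True)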
Example #2
def test_separation_deep_mask_estimation(overfit_model):
    model_path, item = overfit_model
    for mask_type in ['soft', 'binary']:
        dme = separation.deep.DeepMaskEstimation(item['mix'],
                                                 model_path,
                                                 mask_type=mask_type)

        # confidence cannot be computed before the separator has been run.
        pytest.raises(SeparationException, dme.confidence)

        item['mix'].write_audio_to_file('tests/local/dme_mix.wav')
        sources = item['sources']
        estimates = dme()
        for i, e in enumerate(estimates):
            e.write_audio_to_file(f'tests/local/dme_overfit{i}.wav')

        evaluator = evaluation.BSSEvalScale(list(sources.values()),
                                            estimates,
                                            compute_permutation=True)
        scores = evaluator.evaluate()

        for key in evaluator.source_labels:
            for metric in ['SI-SDR', 'SI-SIR']:
                _score = scores[key][metric]
                for val in _score:
                    assert val > SDR_CUTOFF

        confidence = dme.confidence()  # works once the model has been run
        dme.model.output_keys = ['mask']
        dme()  # separation still works from masks alone...
        pytest.raises(SeparationException, dme.confidence)  # ...but confidence needs embeddings

        dme.model.output_keys = []
        pytest.raises(SeparationException, dme.run)  # no mask output at all -> run fails
Example #3
def separate_and_evaluate(item, masks):
    # Wrap precomputed masks in a DeepMaskEstimation separator to produce
    # estimates, then score them against the reference sources.
    separator = separation.deep.DeepMaskEstimation(item['mix'])
    estimates = separator(masks)

    evaluator = evaluation.BSSEvalScale(list(item['sources'].values()),
                                        estimates,
                                        compute_permutation=True)
    scores = evaluator.evaluate()
    output_path = os.path.join(RESULTS_DIR, f"{item['mix'].file_name}.json")
    with open(output_path, 'w') as f:
        json.dump(scores, f)
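Each separate_and_evaluate variant writes one JSON of scores per mixture into RESULTS_DIR. After a full run over a dataset, the per-item files are typically collected into one table; a hedged sketch, assuming nussl's evaluation.aggregate_score_files helper (verify the signature against your nussl version):

import glob

json_files = glob.glob(os.path.join(RESULTS_DIR, '*.json'))
# Collects the per-item BSSEval scores into a single pandas DataFrame.
report = evaluation.aggregate_score_files(json_files)
print(report.mean(numeric_only=True))  # dataset-level SI-SDR / SI-SIR averages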
Example #4
def separate_and_evaluate(item_):
    # Oracle soft-mask (ideal ratio mask) separation; APPROACH and KWARGS
    # are module-level settings for the benchmark configuration being run.
    separator = separation.benchmark.IdealRatioMask(
        item_['mix'], item_['sources'], approach=APPROACH,
        mask_type='soft', **KWARGS)
    estimates = separator()

    evaluator = evaluation.BSSEvalScale(
        list(item_['sources'].values()), estimates, compute_permutation=True)
    scores = evaluator.evaluate()
    output_path = os.path.join(RESULTS_DIR, f"{item_['mix'].file_name}.json")
    with open(output_path, 'w') as f:
        json.dump(scores, f)
Example #5
    def separate_and_evaluate(item):
        # Nested variant: `val` (output dir, approach, kwargs) is supplied
        # by the enclosing loop over benchmark configurations.
        output_path = os.path.join(
            RESULTS_DIR, val['dir'], f"{item['mix'].file_name}.json")
        separator = separation.benchmark.IdealRatioMask(
            item['mix'], item['sources'], approach=val['approach'],
            mask_type='soft', **val['kwargs'])
        estimates = separator()

        evaluator = evaluation.BSSEvalScale(
            list(item['sources'].values()), estimates, compute_permutation=True)
        scores = evaluator.evaluate()
        with open(output_path, 'w') as f:
            json.dump(scores, f)
Example #6
def separate_and_evaluate(item):
    # Oracle separation with an ideal binary mask built from the
    # ground-truth sources.
    separator = separation.benchmark.IdealBinaryMask(item['mix'],
                                                     item['sources'],
                                                     mask_type='binary')
    estimates = separator()

    evaluator = evaluation.BSSEvalScale(list(item['sources'].values()),
                                        estimates,
                                        compute_permutation=True)
    scores = evaluator.evaluate()
    output_path = os.path.join(RESULTS_DIR, f"{item['mix'].file_name}.json")
    with open(output_path, 'w') as f:
        json.dump(scores, f)
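The benchmark-style variants above are written to be mapped over a dataset one item at a time. A minimal sequential driver sketch (the originals are typically dispatched through a process pool; the Scaper folder path is an assumption):

dataset = datasets.Scaper('path/to/scaper_folder')  # assumed dataset location
for i in range(len(dataset)):
    separate_and_evaluate(dataset[i])  # any of the single-argument variants above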
Example #7
def separate_and_evaluate(mix, sources, mask_data):
    # Turn each slice of mask_data (last axis indexes sources) into a
    # SoftMask, apply it to the mix's STFT, and invert back to audio.
    estimates = []
    mask_data = normalize_masks(mask_data)
    for i in range(mask_data.shape[-1]):
        mask = SoftMask(mask_data[..., i])
        estimate = mix.apply_mask(mask)
        estimate.istft()
        estimates.append(estimate)

    # Normalized soft masks partition the mixture, so the estimates should
    # sum back to the mix up to STFT round-trip error.
    assert np.allclose(
        sum(estimates).audio_data, mix.audio_data, atol=stft_tol)

    sources = [sources[k] for k in sources]
    evaluator = evaluation.BSSEvalScale(
        sources, estimates)
    scores = evaluator.evaluate()
    return scores
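normalize_masks is not shown in this excerpt; for the reconstruction assertion to hold, the soft masks must partition the mixture, i.e. sum to one across the source axis. A plausible sketch, with the name and exact behavior assumed:

def normalize_masks(mask_data, eps=1e-8):
    # Assumed helper: clamp to non-negative, then renormalize so the masks
    # sum to 1.0 across the last (source) axis at every time-frequency bin.
    mask_data = np.maximum(mask_data, 0)
    return mask_data / (mask_data.sum(axis=-1, keepdims=True) + eps)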
Example #8
def test_separation_deep_audio_estimation(overfit_audio_model):
    model_path, item = overfit_audio_model
    dae = separation.deep.DeepAudioEstimation(item['mix'], model_path)

    item['mix'].write_audio_to_file('tests/local/dae_mix.wav')
    sources = item['sources']
    estimates = dae()
    for i, e in enumerate(estimates):
        e.write_audio_to_file(f'tests/local/dae_overfit{i}.wav')

    evaluator = evaluation.BSSEvalScale(
        list(sources.values()), estimates, compute_permutation=True)
    scores = evaluator.evaluate()

    for key in evaluator.source_labels:
        for metric in ['SI-SDR', 'SI-SIR']:
            _score = scores[key][metric]
            for val in _score:
                assert val > SDR_CUTOFF

    # With no output keys the model returns no audio, so run should raise.
    dae.model.output_keys = []
    pytest.raises(SeparationException, dae.run)
Example #9
def test_clustering_separation_base(scaper_folder, monkeypatch):
    dataset = datasets.Scaper(scaper_folder)
    item = dataset[5]
    mix = item['mix']
    sources = item['sources']

    # An unknown clustering_type is rejected at construction time.
    pytest.raises(
        SeparationException, separation.ClusteringSeparationBase,
        mix, 2, clustering_type='not allowed')

    clustering_types = (
        separation.base.clustering_separation_base.ALLOWED_CLUSTERING_TYPES)

    separator = separation.ClusteringSeparationBase(mix, 2)
    bad_features = np.ones(100)  # wrong shape: features must match the STFT
    pytest.raises(SeparationException, separator.run, bad_features)

    # Oracle features: one-hot source assignments per time-frequency bin,
    # built from the ground-truth source magnitudes.
    good_features = np.stack(
        [np.abs(s.stft()) for _, s in sources.items()], axis=-1)
    good_features = (
        good_features == good_features.max(axis=-1, keepdims=True))

    def dummy_extract(self):
        return good_features

    # Bypass feature extraction so the clustering machinery itself is tested.
    monkeypatch.setattr(
        separation.ClusteringSeparationBase, 'extract_features', dummy_extract)

    for clustering_type in clustering_types:
        separator = separation.ClusteringSeparationBase(
            mix, 2, clustering_type=clustering_type)

        pytest.raises(SeparationException, separator.confidence)  # not run yet

        estimates = separator()
        confidence = separator.confidence()
        assert confidence == 1.0  # oracle one-hot features cluster perfectly

        evaluator = evaluation.BSSEvalScale(
            list(sources.values()), estimates, compute_permutation=True)
        scores = evaluator.evaluate()

        for key in evaluator.source_labels:
            for metric in ['SI-SDR', 'SI-SIR']:
                _score = scores[key][metric]
                for val in _score:
                    assert val > 5

        separator = separation.ClusteringSeparationBase(
            mix, 2, clustering_type=clustering_type, mask_type='binary')

        estimates = separator()

        evaluator = evaluation.BSSEvalScale(
            list(sources.values()), estimates, compute_permutation=True)
        scores = evaluator.evaluate()

        for key in evaluator.source_labels:
            for metric in ['SI-SDR', 'SI-SIR']:
                _score = scores[key][metric]
                for val in _score:
                    assert val > 9
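Because the monkeypatched features are oracle one-hot assignments per time-frequency bin, every allowed clustering type recovers the sources essentially perfectly; that is why confidence comes out at exactly 1.0, and why the binary-mask pass is held to a stricter SDR bar (9 dB) than the soft-mask pass (5 dB).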