Example #1
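This and the following examples appear to be pytest fixtures and tests drawn from nussl's test suite; the module-level imports and constants they rely on (and the scaper_folder fixture, a folder of Scaper-generated mixtures defined elsewhere in that suite) are not shown. Below is a minimal sketch of that shared setup; the values chosen for DEVICE, EPOCH_LENGTH, and fix_dir are assumptions, not the originals.

import os
import tempfile

import numpy as np
import pytest
import torch
from torch import optim

import nussl
from nussl import datasets, ml, separation, evaluation
from nussl.separation import SeparationException

# Assumed module-level constants; the values in the original suite may differ.
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
EPOCH_LENGTH = 100   # iterations per training epoch when overfitting
fix_dir = None       # set to a directory path to keep checkpoints between runs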
@pytest.fixture  # decorator assumed; this generator function is consumed as a fixture
def overfit_model(scaper_folder):
    nussl.utils.seed(0)
    # Spectrogram-domain training targets: phase-sensitive approximation,
    # magnitude weights for the embedding loss, tensors for SeparationModel,
    # and 100-frame excerpts.
    tfms = datasets.transforms.Compose([
        datasets.transforms.PhaseSensitiveSpectrumApproximation(),
        datasets.transforms.MagnitudeWeights(),
        datasets.transforms.ToSeparationModel(),
        datasets.transforms.GetExcerpt(100)
    ])
    dataset = datasets.Scaper(scaper_folder, transform=tfms)
    # Restrict the dataset to a single item so the model can overfit it quickly.
    dataset.items = [dataset.items[5]]

    dataloader = torch.utils.data.DataLoader(dataset)

    # Number of frequency bins in the mixture spectrogram.
    n_features = dataset[0]['mix_magnitude'].shape[1]
    # Small recurrent Chimera network: a BLSTM feeding both an embedding head
    # (for deep clustering) and a mask head for the two sources.
    config = ml.networks.builders.build_recurrent_chimera(
        n_features,
        50,
        1,
        True,
        0.3,
        20,
        'sigmoid',
        2,
        'sigmoid',
        normalization_class='InstanceNorm')
    model = ml.SeparationModel(config)
    model = model.to(DEVICE)

    optimizer = optim.Adam(model.parameters(), lr=1e-2)

    # Chimera-style combined loss: deep clustering on the embedding (weight 0.2)
    # plus a permutation-invariant L1 loss on the source estimates (weight 0.8).
    loss_dictionary = {
        'DeepClusteringLoss': {
            'weight': 0.2
        },
        'PermutationInvariantLoss': {
            'args': ['L1Loss'],
            'weight': 0.8
        }
    }

    train_closure = ml.train.closures.TrainClosure(loss_dictionary, optimizer,
                                                   model)

    trainer, _ = ml.train.create_train_and_validation_engines(train_closure,
                                                              device=DEVICE)

    with tempfile.TemporaryDirectory() as tmpdir:
        # Write checkpoints to fix_dir if it is set, otherwise to a throwaway folder.
        _dir = fix_dir if fix_dir else tmpdir

        ml.train.add_stdout_handler(trainer)
        ml.train.add_validate_and_checkpoint(_dir, model, optimizer, dataset,
                                             trainer)

        trainer.run(dataloader, max_epochs=1, epoch_length=EPOCH_LENGTH)

        # Hand the best checkpoint and a processed dataset item to the test.
        model_path = os.path.join(trainer.state.output_folder, 'checkpoints',
                                  'best.model.pth')
        yield model_path, dataset.process_item(dataset.items[0])
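A test built on this fixture would load the overfit checkpoint into a separator. The following is a minimal sketch of such a consumer, assuming the checkpoint is compatible with nussl's separation.deep.DeepMaskEstimation; the test name is hypothetical and not part of the original example.

def test_deep_mask_estimation_sketch(overfit_model, scaper_folder):
    model_path, _ = overfit_model
    item = datasets.Scaper(scaper_folder)[0]
    mix, sources = item['mix'], item['sources']

    # Load the overfit checkpoint and separate the mixture with soft masks.
    separator = separation.deep.DeepMaskEstimation(
        mix, model_path=model_path, mask_type='soft')
    estimates = separator()

    assert len(estimates) == len(sources)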
Example #2
@pytest.fixture  # decorator assumed; this generator function is consumed as a fixture
def overfit_audio_model(scaper_folder):
    nussl.utils.seed(0)
    # Waveform-domain training data: raw audio, tensors for SeparationModel,
    # and 32000-sample excerpts of the mixture and source audio.
    tfms = datasets.transforms.Compose([
        datasets.transforms.GetAudio(),
        datasets.transforms.ToSeparationModel(),
        datasets.transforms.GetExcerpt(
            32000, time_dim=1, tf_keys=['mix_audio', 'source_audio'])
    ])
    dataset = datasets.Scaper(
        scaper_folder, transform=tfms)
    # Restrict the dataset to a single item so the model can overfit it quickly.
    dataset.items = [dataset.items[5]]

    dataloader = torch.utils.data.DataLoader(
        dataset, batch_size=1)

    # End-to-end recurrent model: a 256-point sqrt-Hann filterbank front end
    # (hop 64), a two-layer BLSTM mask estimator, and an inverse transform back
    # to a waveform for each of the two sources.
    config = ml.networks.builders.build_recurrent_end_to_end(
        256, 256, 64, 'sqrt_hann', 50, 2,
        True, 0.3, 2, 'sigmoid', num_audio_channels=1,
        mask_complex=False, rnn_type='lstm',
        mix_key='mix_audio')

    model = ml.SeparationModel(config)
    model = model.to(DEVICE)

    optimizer = optim.Adam(model.parameters(), lr=1e-3)

    # Permutation-invariant SI-SDR loss on the raw waveforms; the keys map the
    # model's 'audio' output to the estimates and 'source_audio' to the targets.
    loss_dictionary = {
        'PermutationInvariantLoss': {
            'args': ['SISDRLoss'],
            'weight': 1.0,
            'keys': {'audio': 'estimates', 'source_audio': 'targets'}
        }
    }

    train_closure = ml.train.closures.TrainClosure(
        loss_dictionary, optimizer, model)
    
    trainer, _ = ml.train.create_train_and_validation_engines(
        train_closure, device=DEVICE
    )

    with tempfile.TemporaryDirectory() as tmpdir:
        # Keep checkpoints under fix_dir/dae if fix_dir is set, otherwise use a
        # throwaway folder.
        _dir = os.path.join(fix_dir, 'dae') if fix_dir else tmpdir

        ml.train.add_stdout_handler(trainer)
        ml.train.add_validate_and_checkpoint(
            _dir, model, optimizer, dataset, trainer)
        ml.train.add_progress_bar_handler(trainer)

        trainer.run(dataloader, max_epochs=1, epoch_length=EPOCH_LENGTH)

        # Hand the best checkpoint and a processed dataset item to the test.
        model_path = os.path.join(
            trainer.state.output_folder, 'checkpoints', 'best.model.pth')
        yield model_path, dataset.process_item(dataset.items[0])
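As with the first fixture, a downstream test would presumably load this end-to-end checkpoint. A minimal sketch using nussl's waveform separator (assuming the checkpoint is compatible with separation.deep.DeepAudioEstimation); the test name is hypothetical.

def test_deep_audio_estimation_sketch(overfit_audio_model, scaper_folder):
    model_path, _ = overfit_audio_model
    item = datasets.Scaper(scaper_folder)[0]
    mix, sources = item['mix'], item['sources']

    # The end-to-end model maps the mixture waveform directly to source waveforms.
    separator = separation.deep.DeepAudioEstimation(mix, model_path=model_path)
    estimates = separator()

    assert len(estimates) == len(sources)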
Example #3
def test_clustering_separation_base(scaper_folder, monkeypatch):
    dataset = datasets.Scaper(scaper_folder)
    item = dataset[5]
    mix = item['mix']
    sources = item['sources']

    # An unknown clustering_type should be rejected at construction time.
    pytest.raises(SeparationException, separation.ClusteringSeparationBase,
        mix, 2, clustering_type='not allowed')

    clustering_types = (
        separation.base.clustering_separation_base.ALLOWED_CLUSTERING_TYPES)

    # Features of the wrong shape should also raise when running the separator.
    separator = separation.ClusteringSeparationBase(mix, 2)
    bad_features = np.ones(100)
    pytest.raises(SeparationException, separator.run, bad_features)

    # Oracle features: stack the source magnitude spectrograms and one-hot
    # encode which source dominates each time-frequency bin.
    good_features = np.stack([
            np.abs(s.stft()) for _, s in sources.items()
        ], axis=-1)

    good_features = (
        good_features == good_features.max(axis=-1, keepdims=True))

    # Patch extract_features (meant to be implemented by subclasses) so the
    # separator clusters on these oracle features.
    def dummy_extract(self):
        return good_features

    monkeypatch.setattr(
        separation.ClusteringSeparationBase, 'extract_features', dummy_extract)

    for clustering_type in clustering_types:
        separator = separation.ClusteringSeparationBase(
            mix, 2, clustering_type=clustering_type)

        # confidence() needs the features computed during a run, so it raises here.
        pytest.raises(SeparationException, separator.confidence)

        estimates = separator()
        # With one-hot oracle features the clustering is unambiguous.
        confidence = separator.confidence()
        assert confidence == 1.0

        # Masks built from oracle features should separate the sources well.
        evaluator = evaluation.BSSEvalScale(
            list(sources.values()), estimates, compute_permutation=True)
        scores = evaluator.evaluate()

        for key in evaluator.source_labels:
            for metric in ['SI-SDR', 'SI-SIR']:
                _score = scores[key][metric]
                for val in _score:
                    assert val > 5

        # Binary masks from the same oracle features should do even better.
        separator = separation.ClusteringSeparationBase(
            mix, 2, clustering_type=clustering_type, mask_type='binary')

        estimates = separator()

        evaluator = evaluation.BSSEvalScale(
            list(sources.values()), estimates, compute_permutation=True)
        scores = evaluator.evaluate()

        for key in evaluator.source_labels:
            for metric in ['SI-SDR', 'SI-SIR']:
                _score = scores[key][metric]
                for val in _score:
                    assert val > 9
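In practice, extract_features is supplied by a subclass rather than a monkeypatch. The following is a minimal sketch of such a subclass, mirroring the oracle features built in the test above; the class name and its sources argument are hypothetical and not part of nussl.

class OracleClustering(separation.ClusteringSeparationBase):
    """Hypothetical subclass that clusters on ideal-binary-mask features."""

    def __init__(self, input_audio_signal, num_sources, sources, **kwargs):
        super().__init__(input_audio_signal, num_sources, **kwargs)
        self.sources = sources

    def extract_features(self):
        # Same construction as the test above: one-hot encode the dominant
        # source in each time-frequency bin.
        features = np.stack(
            [np.abs(s.stft()) for s in self.sources.values()], axis=-1)
        return features == features.max(axis=-1, keepdims=True)

# estimates = OracleClustering(mix, 2, sources)() would then separate the mixture.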
Example #4
@pytest.fixture  # decorator assumed; this helper is consumed as a fixture
def mix_and_sources(scaper_folder):
    # Return the mixture AudioSignal and the dict of ground-truth sources
    # for the first item in the Scaper dataset.
    dataset = datasets.Scaper(scaper_folder)
    item = dataset[0]
    return item['mix'], item['sources']
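A minimal sketch of a test consuming this fixture, using nussl's ideal-binary-mask benchmark as the separator; the test name is hypothetical, and passing the sources as a list is an assumption about the benchmark's expected input.

def test_ideal_binary_mask_sketch(mix_and_sources):
    mix, sources = mix_and_sources

    # Oracle benchmark: build binary masks directly from the true sources.
    separator = separation.benchmark.IdealBinaryMask(
        mix, list(sources.values()))
    estimates = separator()

    assert len(estimates) == len(sources)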