Ejemplo n.º 1
0
def test_separation_model_extra_modules(one_item):
    """Check that a registered custom module can be wired into a model.

    A ``test`` module of class ``MyModule`` is added to a copy of
    ``dpcl_config``, fed ``mix_magnitude`` and ``embedding``, and exposed as
    an output. The model is then built from the config given as a dict, as a
    path to a JSON file and as a JSON string.

    The assertions require ``output['test']`` to equal ``mix_magnitude`` by
    default and its negation when ``flip=True`` is passed at call time.
    """
    # Build the variant on a private copy. Editing the module-level
    # ``dpcl_config`` in place would leak the extra module into every other
    # test that uses it and would append duplicate entries on a re-run.
    extended_config = copy.deepcopy(dpcl_config)
    extended_config['modules']['test'] = {'class': 'MyModule'}
    extended_config['connections'].append(('test', ('mix_magnitude', {
        'embedding': 'embedding',
        'flip': False
    })))
    extended_config['output'].append('test')

    # The custom class must be registered before any model references it.
    nussl.ml.register_module(MyModule)

    with tempfile.NamedTemporaryFile(suffix='.json', delete=True) as tmp:
        with open(tmp.name, 'w') as f:
            json.dump(extended_config, f)
        configs = [extended_config, tmp.name, json.dumps(extended_config)]

        for config in configs:
            model = SeparationModel(config)
            output = model(one_item)

            assert (output['embedding'].shape == (
                one_item['mix_magnitude'].shape + (20, )))

            assert torch.allclose(one_item['mix_magnitude'], output['test'])

            # Second pass: override ``flip`` at call time. A deep copy of the
            # item is passed so the shared fixture is not handed to the call
            # that carries the extra keyword.
            model = SeparationModel(config)
            copy_one_item = copy.deepcopy(one_item)
            output = model(copy_one_item, flip=True)

            assert torch.allclose(one_item['mix_magnitude'], -output['test'])
Ejemplo n.º 2
0
def test_separation_model_save_and_load():
    """Round-trip a model through ``save`` and ``SeparationModel.load``.

    The model is saved together with stand-in training data, validation data
    and trainer objects, then loaded back. The test checks that the stored
    ``nussl_version`` matches the installed one and that the loaded model has
    exactly the same parameters (names and values) as the original.
    """
    model = SeparationModel(dpcl_config)

    tfms = datasets.transforms.Compose([
        datasets.transforms.PhaseSensitiveSpectrumApproximation(),
        datasets.transforms.ToSeparationModel(),
        datasets.transforms.Cache('tests/local/sep_model/cache')
    ])

    # Minimal stand-ins for the dataset / trainer objects passed to ``save``.
    # NOTE(review): presumably these are the attributes ``save`` reads when
    # building metadata - confirm against SeparationModel.save.
    class DummyData:
        def __init__(self):
            self.stft_params = None
            self.sample_rate = None
            self.num_channels = None
            self.metadata = {'transforms': tfms}

    class DummyState:
        def __init__(self):
            self.epoch = 0
            self.epoch_length = 100
            self.max_epochs = 100
            self.output = None
            self.metrics = {}
            self.seed = None
            self.epoch_history = {}

    class DummyTrainer:
        def __init__(self):
            self.state = DummyState()

    dummy_data = DummyData()
    dummy_trainer = DummyTrainer()

    with tempfile.NamedTemporaryFile(suffix='.pth', delete=True) as tmp:
        # The return value of ``save`` is not needed: the file is loaded back
        # through the path that was passed in.
        model.save(tmp.name,
                   train_data=dummy_data,
                   val_data=dummy_data,
                   trainer=dummy_trainer)
        new_model, metadata = SeparationModel.load(tmp.name)

        assert metadata['nussl_version'] == nussl.__version__

        new_model_params = dict(new_model.named_parameters())
        old_model_params = dict(model.named_parameters())

        # Compare the name sets first so that a parameter missing from the
        # loaded model fails the test instead of being skipped silently.
        assert new_model_params.keys() == old_model_params.keys()

        for key, new_param in new_model_params.items():
            assert torch.allclose(new_param, old_model_params[key])
Ejemplo n.º 3
0
def test_separation_end_to_end(one_item):
    """Run both end-to-end configs and check the shape of the audio output.

    Each of the real and complex configs is supplied to ``SeparationModel``
    as a dict, as a path to a JSON file and as a JSON string. In every case
    ``output['audio']`` must have the shape of ``one_item['source_audio']``.
    """
    for base_config in [end_to_end_real_config, end_to_end_complex_config]:
        with tempfile.NamedTemporaryFile(suffix='.json', delete=True) as tmp:
            with open(tmp.name, 'w') as handle:
                json.dump(base_config, handle)

            # Same configuration in its three accepted forms.
            as_json_string = json.dumps(base_config)
            variants = [base_config, tmp.name, as_json_string]

            for variant in variants:
                net = SeparationModel(variant)
                result = net(one_item)

                expected_shape = one_item['source_audio'].shape
                assert result['audio'].shape == expected_shape
Ejemplo n.º 4
0
def test_separation_dprnn(one_item):
    """Build the dual-path recurrent model and check its audio output shape.

    The config is supplied as a dict, as a path to a JSON file and as a JSON
    string. Each model is built with ``verbose=True`` and its
    ``output['audio']`` must match the shape of ``one_item['source_audio']``.
    """
    with tempfile.NamedTemporaryFile(suffix='.json', delete=True) as tmp:
        with open(tmp.name, 'w') as handle:
            json.dump(dual_path_recurrent_config, handle)

        as_json_string = json.dumps(dual_path_recurrent_config)
        variants = [dual_path_recurrent_config, tmp.name, as_json_string]

        for variant in variants:
            net = SeparationModel(variant, verbose=True)
            result = net(one_item)

            expected_shape = one_item['source_audio'].shape
            assert result['audio'].shape == expected_shape
Ejemplo n.º 5
0
def test_separation_model_dpcl(one_item):
    """Build the deep clustering model and check its embedding shape.

    The config is supplied as a dict, as a path to a JSON file and as a JSON
    string. In every case ``output['embedding']`` must have the shape of
    ``mix_magnitude`` with one extra trailing dimension of size 20.
    """
    with tempfile.NamedTemporaryFile(suffix='.json', delete=True) as tmp:
        with open(tmp.name, 'w') as f:
            json.dump(dpcl_config, f)
        configs = [dpcl_config, tmp.name, json.dumps(dpcl_config)]

        for config in configs:
            model = SeparationModel(config)
            output = model(one_item)

            assert (output['embedding'].shape == (
                one_item['mix_magnitude'].shape + (20, )))
Ejemplo n.º 6
0
def test_separation_model_save():
    """Save a model and rebuild it by hand from the raw checkpoint.

    The checkpoint written by ``save`` is read with ``torch.load``. The test
    checks the stored ``nussl_version``, rebuilds a model from the
    checkpoint's ``config`` and ``state_dict``, and requires the rebuilt
    model to have exactly the same parameters (names and values) as the
    original.
    """
    model = SeparationModel(dpcl_config)

    with tempfile.NamedTemporaryFile(suffix='.pth', delete=True) as tmp:
        loc = model.save(tmp.name)
        checkpoint = torch.load(loc)

        assert checkpoint['metadata']['nussl_version'] == nussl.__version__

        new_model = SeparationModel(checkpoint['config'])
        new_model.load_state_dict(checkpoint['state_dict'])

        new_model_params = dict(new_model.named_parameters())
        old_model_params = dict(model.named_parameters())

        # Compare the name sets first so that a parameter missing from the
        # rebuilt model fails the test instead of being skipped silently.
        assert new_model_params.keys() == old_model_params.keys()

        for key, new_param in new_model_params.items():
            assert torch.allclose(new_param, old_model_params[key])
Ejemplo n.º 7
0
def test_separation_model_split(one_item):
    """Build the model from ``split_config`` and check its outputs.

    The config is supplied as a dict, as a path to a JSON file and as a JSON
    string. For each model the test requires that:

    * ``estimates`` has the shape of ``mix_magnitude`` plus a trailing
      dimension of size 2, and sums over that dimension back to
      ``mix_magnitude``;
    * the ``split:0`` and ``split:1`` outputs have sizes 100 and 157 along
      dimension 2.
    """
    with tempfile.NamedTemporaryFile(suffix='.json', delete=True) as tmp:
        with open(tmp.name, 'w') as f:
            json.dump(split_config, f)
        configs = [split_config, tmp.name, json.dumps(split_config)]

        for config in configs:
            model = SeparationModel(config)
            output = model(one_item)

            assert (output['estimates'].shape == (
                one_item['mix_magnitude'].shape + (2, )))
            assert (torch.allclose(output['estimates'].sum(dim=-1),
                                   one_item['mix_magnitude']))

            # NOTE(review): 100 and 157 presumably mirror the split sizes
            # declared in ``split_config`` - confirm against that config.
            assert (output['split:0'].shape[2] == 100)
            assert (output['split:1'].shape[2] == 157)
Ejemplo n.º 8
0
def test_separation_model_open_unmix_like(one_item):
    """Build the model from ``open_unmix_like_config`` and check shapes.

    The config is supplied as a dict, as a path to a JSON file and as a JSON
    string. For each model, ``estimates`` and ``mask`` must have the shape of
    ``mix_magnitude`` plus a trailing dimension of size 2, and ``embedding``
    the same shape plus a trailing dimension of size 20.
    """
    with tempfile.NamedTemporaryFile(suffix='.json', delete=True) as tmp:
        with open(tmp.name, 'w') as f:
            json.dump(open_unmix_like_config, f)
        configs = [
            open_unmix_like_config, tmp.name,
            json.dumps(open_unmix_like_config)
        ]

        for config in configs:
            model = SeparationModel(config)
            output = model(one_item)

            assert (output['estimates'].shape == (
                one_item['mix_magnitude'].shape + (2, )))
            assert (output['mask'].shape == (one_item['mix_magnitude'].shape +
                                             (2, )))
            assert (output['embedding'].shape == (
                one_item['mix_magnitude'].shape + (20, )))
Ejemplo n.º 9
0
def test_separation_model_mask_inference(one_item):
    """Build the mask inference model and check its input and output checks.

    The config is supplied as a dict, as a path to a JSON file and as a JSON
    string. For each model the test requires that:

    * calling it on an item without ``mix_magnitude`` raises ``ValueError``;
    * ``estimates`` has the shape of ``mix_magnitude`` plus a trailing
      dimension of size 2, and sums over that dimension back to
      ``mix_magnitude``.
    """
    with tempfile.NamedTemporaryFile(suffix='.json', delete=True) as tmp:
        with open(tmp.name, 'w') as f:
            json.dump(mi_config, f)
        configs = [mi_config, tmp.name, json.dumps(mi_config)]

        for config in configs:
            model = SeparationModel(config)

            # Remove the key from a copy so the shared fixture stays intact.
            bad_item = copy.deepcopy(one_item)
            bad_item.pop('mix_magnitude')
            with pytest.raises(ValueError):
                model(bad_item)

            output = model(one_item)

            assert (output['estimates'].shape == (
                one_item['mix_magnitude'].shape + (2, )))
            assert (torch.allclose(output['estimates'].sum(dim=-1),
                                   one_item['mix_magnitude']))
Ejemplo n.º 10
0
def test_separation_model_gmm_unfold(one_item):
    """Build the model from ``gmm_unfold_config`` and check its outputs.

    The config is supplied as a dict, as a path to a JSON file and as a JSON
    string. Each model is called on the item extended with a random
    ``init_means`` tensor. The test requires that:

    * ``estimates`` has the shape of ``mix_magnitude`` plus a trailing
      dimension of size 2, and sums over that dimension back to
      ``mix_magnitude``;
    * ``embedding`` has the shape of ``mix_magnitude`` plus a trailing
      dimension of size 20.
    """
    with tempfile.NamedTemporaryFile(suffix='.json', delete=True) as tmp:
        with open(tmp.name, 'w') as f:
            json.dump(gmm_unfold_config, f)
        configs = [gmm_unfold_config, tmp.name, json.dumps(gmm_unfold_config)]

        mix_magnitude = one_item['mix_magnitude']

        for config in configs:
            model = SeparationModel(config)

            # Add the extra input to a shallow copy rather than to the
            # fixture itself, so ``init_means`` does not leak into other
            # tests that receive the same ``one_item`` object.
            # Assumes ``one_item`` is a plain mapping - TODO confirm.
            item = dict(one_item)
            # The sizes 2 and 20 equal the trailing dimensions asserted
            # below for ``estimates`` and ``embedding``.
            item['init_means'] = torch.randn(
                mix_magnitude.shape[0], 2, 20).to(mix_magnitude.device)
            output = model(item)

            assert (output['estimates'].shape == (
                mix_magnitude.shape + (2, )))
            assert (torch.allclose(output['estimates'].sum(dim=-1),
                                   mix_magnitude))

            assert (output['embedding'].shape == (
                mix_magnitude.shape + (20, )))
Ejemplo n.º 11
0
def test_separation_model_repr_and_verbose(one_item):
    """Smoke test: print a verbose model and run it once.

    There are no assertions. The test passes as long as building the model
    with ``verbose=True``, printing it and calling it on ``one_item`` do not
    raise.
    """
    verbose_model = SeparationModel(end_to_end_real_config, verbose=True)
    print(verbose_model)
    verbose_model(one_item)