Example #1
# Assumed imports for this snippet; `dpcl_config` is a deep clustering
# model configuration built elsewhere in the test module.
import tempfile

import torch

import nussl
from nussl import datasets
from nussl.ml import SeparationModel


def test_separation_model_save_and_load():
    model = SeparationModel(dpcl_config)

    tfms = datasets.transforms.Compose([
        datasets.transforms.PhaseSensitiveSpectrumApproximation(),
        datasets.transforms.ToSeparationModel(),
        datasets.transforms.Cache('tests/local/sep_model/cache')
    ])

    # Minimal stand-ins for the dataset and trainer objects whose attributes
    # SeparationModel.save reads when building checkpoint metadata.
    class DummyData:
        def __init__(self):
            self.stft_params = None
            self.sample_rate = None
            self.num_channels = None
            self.metadata = {'transforms': tfms}

    class DummyState:
        def __init__(self):
            self.epoch = 0
            self.epoch_length = 100
            self.max_epochs = 100
            self.output = None
            self.metrics = {}
            self.seed = None
            self.epoch_history = {}

    class DummyTrainer:
        def __init__(self):
            self.state = DummyState()

    dummy_data = DummyData()
    dummy_trainer = DummyTrainer()

    with tempfile.NamedTemporaryFile(suffix='.pth', delete=True) as tmp:
        # Save with full metadata (transforms, trainer state), then load it
        # back through SeparationModel.load.
        loc = model.save(tmp.name,
                         train_data=dummy_data,
                         val_data=dummy_data,
                         trainer=dummy_trainer)
        new_model, metadata = SeparationModel.load(tmp.name)

        assert metadata['nussl_version'] == nussl.__version__

        # The loaded model's parameters should match the original's exactly.
        new_model_params = {}
        old_model_params = {}

        for name, param in new_model.named_parameters():
            new_model_params[name] = param

        for name, param in model.named_parameters():
            old_model_params[name] = param

        for key in new_model_params:
            assert torch.allclose(new_model_params[key], old_model_params[key])
Example #2
# Uses the same assumed imports and `dpcl_config` as Example #1.
def test_separation_model_save():
    model = SeparationModel(dpcl_config)

    with tempfile.NamedTemporaryFile(suffix='.pth', delete=True) as tmp:
        loc = model.save(tmp.name)
        # The checkpoint is a plain dict holding 'config', 'state_dict',
        # and 'metadata'.
        checkpoint = torch.load(loc)

        assert checkpoint['metadata']['nussl_version'] == nussl.__version__

        # Rebuild the model from the saved config and restore its weights.
        new_model = SeparationModel(checkpoint['config'])
        new_model.load_state_dict(checkpoint['state_dict'])

        new_model_params = {}
        old_model_params = {}

        for name, param in new_model.named_parameters():
            new_model_params[name] = param

        for name, param in model.named_parameters():
            old_model_params[name] = param

        for key in new_model_params:
            assert torch.allclose(new_model_params[key], old_model_params[key])
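
Both examples verify the save/load round trip with the same parameter
comparison. As a minimal sketch, that check could be factored into a reusable
helper (the `assert_same_parameters` function below is hypothetical and not
part of nussl):

import torch


def assert_same_parameters(model_a, model_b):
    """Assert that two torch models have identical named parameters."""
    params_a = dict(model_a.named_parameters())
    params_b = dict(model_b.named_parameters())
    # Both models must expose the same parameter names.
    assert params_a.keys() == params_b.keys()
    # Every parameter tensor must match element-wise.
    for name in params_a:
        assert torch.allclose(params_a[name], params_b[name])


# Usage in either example above:
#     assert_same_parameters(new_model, model)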