예제 #1
0
def test_loader_submodule():
    """Load a wrapper module's state dict into the bare inner architecture.

    The state dict comes from a module that stores its network under an
    attribute, so every key carries that attribute name as a prefix. The test
    asserts that strict loading raises, that non-strict loading does not make
    the two models equal, and that ``torch_utils.load_state_dict_in`` does.
    """
    class SuperModule(nn.Module):
        """Wrapper module keeping the given model as ``self.whoever``."""

        def __init__(self, sub_model):
            super().__init__()
            self.whoever = sub_model

    wrapped = SuperModule(nn.Sequential(nn.Linear(10, 10)))
    # Prefixed keys: whoever.0.weight, whoever.0.bias
    prefixed_state = wrapped.state_dict()
    # Destination model exposes un-prefixed keys: 0.weight, 0.bias
    target = nn.Sequential(nn.Linear(10, 10))

    # Strict loading is expected to fail because the key sets differ.
    with pytest.raises(RuntimeError):
        target.load_state_dict(prefixed_state)

    # Non-strict loading goes through, but the models must still differ.
    target.load_state_dict(prefixed_state, strict=False)
    assert not torch_utils.are_models_equal(wrapped, target)

    # The project helper is expected to make both models equal.
    target = torch_utils.load_state_dict_in(prefixed_state, target)
    assert torch_utils.are_models_equal(wrapped, target)
예제 #2
0
def test_loader_module():
    """Load a state dict into a model whose keys already match.

    Two identically shaped models are built; after passing the first one's
    state dict and the second model through
    ``torch_utils.load_state_dict_in``, the two must compare equal.
    """
    source = nn.Sequential(nn.Linear(10, 10))
    target = nn.Sequential(nn.Linear(10, 10))
    # Keys match here (0.weight, 0.bias on both sides).
    reloaded = torch_utils.load_state_dict_in(source.state_dict(), target)
    assert torch_utils.are_models_equal(source, reloaded)
예제 #3
0
def test_model_equal():
    """Exercise ``torch_utils.are_models_equal`` on same and distinct models.

    A model compared with itself must be reported equal; a second,
    separately constructed model of the same architecture must not.
    """
    reference = nn.Sequential(nn.Linear(10, 10))
    same_as_itself = torch_utils.are_models_equal(reference, reference)
    assert same_as_itself

    # Separately constructed instance with the same layer layout.
    other = nn.Sequential(nn.Linear(10, 10))
    same_as_other = torch_utils.are_models_equal(reference, other)
    assert not same_as_other