# NOTE(review): a later function with this same name is defined in this
# module and shadows this one, so pytest collects only the later definition.
# Consider deleting or renaming this copy.
def test_load_lazy_linear(dim1, dim2):
    """Check that loading a ``Linear`` state dict transfers (lazy) state.

    ``lin1`` is built with ``dim1`` input channels and ``lin2`` with
    ``dim2``. ``-1`` requests a lazily-initialized layer. After
    ``lin2.load_state_dict(lin1.state_dict())``, both layers are expected
    to be in the state dictated by ``lin1``:

    * ``dim1 != -1``: the parameters match and neither layer keeps its
      lazy-initialization ``_hook`` attribute.
    * ``dim1 == -1``: both weights are ``UninitializedParameter`` and
      both layers still carry ``_hook``.
    """
    lin1 = Linear(dim1, 32)
    # Use dim2 here; previously lin2 was also built from dim1, which made
    # the dim2 parametrization a no-op and skipped the mixed
    # lazy/initialized loading cases.
    lin2 = Linear(dim2, 32)
    lin2.load_state_dict(lin1.state_dict())

    if dim1 != -1:
        assert torch.allclose(lin1.weight, lin2.weight)
        assert torch.allclose(lin1.bias, lin2.bias)
        assert not hasattr(lin1, '_hook')
        assert not hasattr(lin2, '_hook')
    else:
        assert isinstance(lin1.weight, UninitializedParameter)
        assert isinstance(lin2.weight, UninitializedParameter)
        assert hasattr(lin1, '_hook')
        assert hasattr(lin2, '_hook')
def test_load_lazy_linear(dim1, dim2):
    """Check state-dict loading between (possibly lazy) ``Linear`` layers.

    ``lin1`` is built with ``dim1`` input channels and ``lin2`` with
    ``dim2``. ``-1`` requests a lazily-initialized layer. After
    ``lin2.load_state_dict(lin1.state_dict())``, both layers are expected
    to be in the state dictated by ``lin1``:

    * ``dim1 != -1``: the parameters match and neither layer keeps its
      lazy-initialization ``_hook`` attribute.
    * ``dim1 == -1``: both weights are ``UninitializedParameter`` and
      both layers still carry ``_hook``.

    Finally, loading an empty state dict must raise ``RuntimeError`` with
    a message containing "in state_dict" when ``strict=True``, and must
    succeed when ``strict=False``.
    """
    lin1 = Linear(dim1, 32)
    # Use dim2 here; previously lin2 was also built from dim1, which made
    # the dim2 parametrization a no-op and skipped the mixed
    # lazy/initialized loading cases.
    lin2 = Linear(dim2, 32)
    lin2.load_state_dict(lin1.state_dict())

    if dim1 != -1:
        assert torch.allclose(lin1.weight, lin2.weight)
        assert torch.allclose(lin1.bias, lin2.bias)
        assert not hasattr(lin1, '_hook')
        assert not hasattr(lin2, '_hook')
    else:
        assert isinstance(lin1.weight, UninitializedParameter)
        assert isinstance(lin2.weight, UninitializedParameter)
        assert hasattr(lin1, '_hook')
        assert hasattr(lin2, '_hook')

    # Missing keys are an error only in strict mode.
    with pytest.raises(RuntimeError, match="in state_dict"):
        lin1.load_state_dict({}, strict=True)
    lin1.load_state_dict({}, strict=False)