def test_with_bias(self, random_state):
    """Rebuild a biased 2 -> 4 ``torch.nn.Linear`` from a hand-built graph.

    Graph layout used here: edge ``i -> 2 + j`` (i in {0, 1}, j in 0..3)
    carries weight ``w[j, i]``, and edge ``6 + j -> 2 + j`` carries bias
    ``b[j]``. The reconstructed module must match a reference
    ``torch.nn.Linear`` holding the same parameters.
    """
    torch.manual_seed(random_state)
    # Layer maps 2 inputs to 4 outputs; torch stores the weight as
    # (out_features, in_features) = (4, 2).
    w = torch.randn(4, 2)
    b = torch.randn(4)
    graph = Data(
        edge_index=torch.tensor([
            [0, 0, 0, 0, 1, 1, 1, 1, 6, 7, 8, 9],
            [2, 3, 4, 5, 2, 3, 4, 5, 2, 3, 4, 5],
        ]),
        edge_features=torch.tensor([
            [w[0, 0]], [w[1, 0]], [w[2, 0]], [w[3, 0]],
            [w[0, 1]], [w[1, 1]], [w[2, 1]], [w[3, 1]],
            [b[0]], [b[1]], [b[2]], [b[3]],
        ]),
    )
    model = Linear.to_module(graph)

    # Reference layer carrying exactly the parameters encoded in the graph.
    model_true = torch.nn.Linear(2, 4, bias=True)
    model_true.weight = torch.nn.Parameter(w)
    model_true.bias = torch.nn.Parameter(b)

    # Assert on the reconstructed module; checks against model_true alone
    # would be tautological, since it was built from w and b directly.
    assert isinstance(model, torch.nn.Linear)
    assert torch.equal(model.weight, model_true.weight)
    assert torch.equal(model.bias, model_true.bias)
def test_without_bias(self, random_state):
    """Rebuild a bias-free 2 -> 4 ``torch.nn.Linear`` from a hand-built graph.

    Graph layout used here: edge ``i -> 2 + j`` (i in {0, 1}, j in 0..3)
    carries weight ``w[j, i]``. There are no bias edges, so the
    reconstructed module must have ``bias is None``.
    """
    torch.manual_seed(random_state)
    # Layer maps 2 inputs to 4 outputs; torch stores the weight as
    # (out_features, in_features) = (4, 2).
    w = torch.randn(4, 2)
    graph = Data(
        edge_index=torch.tensor([
            [0, 0, 0, 0, 1, 1, 1, 1],
            [2, 3, 4, 5, 2, 3, 4, 5],
        ]),
        edge_features=torch.tensor([
            [w[0, 0]], [w[1, 0]], [w[2, 0]], [w[3, 0]],
            [w[0, 1]], [w[1, 1]], [w[2, 1]], [w[3, 1]],
        ]),
    )
    model = Linear.to_module(graph)

    # Reference layer carrying exactly the weights encoded in the graph.
    model_true = torch.nn.Linear(2, 4, bias=False)
    model_true.weight = torch.nn.Parameter(w)

    # Assert on the reconstructed module, not on the reference.
    assert isinstance(model, torch.nn.Linear)
    assert torch.equal(model.weight, model_true.weight)
    assert model.bias is None
def test_round_trip(self, in_features, out_features, bias, random_state):
    """``Linear.to_graph`` followed by ``Linear.to_module`` must preserve the layer.

    Checks the module type and the weights. When ``bias`` is True, the bias
    must also match. When ``bias`` is False, the rebuilt module must have no
    bias, mirroring the expectation in the bias-free graph test.
    """
    torch.manual_seed(random_state)
    input_model = torch.nn.Linear(in_features, out_features, bias=bias)
    output_model = Linear.to_module(Linear.to_graph(input_model))

    assert isinstance(output_model, torch.nn.Linear)
    assert torch.equal(input_model.weight, output_model.weight)
    if bias:
        assert torch.equal(input_model.bias, output_model.bias)
    else:
        # A bias-free layer must not gain a bias through the round trip.
        assert output_model.bias is None