def test_round_trip(self, in_features, out_features, bias, random_state):
    """Module -> graph -> module must reproduce the original parameters."""
    torch.manual_seed(random_state)

    original = torch.nn.Linear(in_features, out_features, bias=bias)
    as_graph = Linear.to_graph(original)
    restored = Linear.to_module(as_graph)

    assert torch.equal(original.weight, restored.weight)
    # The bias is only compared when the layer was built with one.
    if bias:
        assert torch.equal(original.bias, restored.bias)
def test_with_bias(self, random_state):
    """Check that ``Linear.to_module`` rebuilds a biased layer from a graph.

    A graph is written out by hand whose edge features are the entries of a
    random 4x2 weight matrix ``w`` followed by the entries of a length-4
    bias ``b``. The converted module is compared against a
    ``torch.nn.Linear(2, 4)`` that carries the same parameters.
    """
    torch.manual_seed(random_state)

    # 2 * 4 | in_features x out_features
    w = torch.randn(4, 2)
    b = torch.randn(4)

    graph = Data(
        # Edges leaving nodes 0 and 1 carry weight entries; edges leaving
        # nodes 6-9 carry bias entries. All of them end in nodes 2-5.
        edge_index=torch.tensor([
            [0, 0, 0, 0, 1, 1, 1, 1, 6, 7, 8, 9],
            [2, 3, 4, 5, 2, 3, 4, 5, 2, 3, 4, 5],
        ]),
        edge_features=torch.tensor([
            [w[0, 0]],
            [w[1, 0]],
            [w[2, 0]],
            [w[3, 0]],
            [w[0, 1]],
            [w[1, 1]],
            [w[2, 1]],
            [w[3, 1]],
            [b[0]],
            [b[1]],
            [b[2]],
            [b[3]],
        ]),
    )
    model = Linear.to_module(graph)

    # Reference layer holding exactly the parameters encoded in the graph.
    model_true = torch.nn.Linear(2, 4, bias=True)
    model_true.weight = torch.nn.Parameter(w)
    model_true.bias = torch.nn.Parameter(b)

    # The checks target `model`, the object under test; `model_true` is only
    # the reference, so asserting on it alone would pass trivially.
    assert isinstance(model, torch.nn.Linear)
    assert torch.equal(model.weight, model_true.weight)
    assert torch.equal(model.bias, model_true.bias)
def test_without_bias(self, random_state):
    """Check that ``Linear.to_module`` rebuilds a bias-free layer from a graph.

    The hand-written graph only contains edges whose features are the
    entries of a random 4x2 weight matrix ``w``. The converted module must
    carry that weight and have no bias.
    """
    torch.manual_seed(random_state)

    # 2 * 4 | in_features x out_features
    w = torch.randn(4, 2)

    graph = Data(
        # Edges leave nodes 0 and 1 and end in nodes 2-5.
        edge_index=torch.tensor([
            [0, 0, 0, 0, 1, 1, 1, 1],
            [2, 3, 4, 5, 2, 3, 4, 5],
        ]),
        edge_features=torch.tensor([
            [w[0, 0]],
            [w[1, 0]],
            [w[2, 0]],
            [w[3, 0]],
            [w[0, 1]],
            [w[1, 1]],
            [w[2, 1]],
            [w[3, 1]],
        ]),
    )
    model = Linear.to_module(graph)

    # Reference layer holding exactly the weight encoded in the graph.
    model_true = torch.nn.Linear(2, 4, bias=False)
    model_true.weight = torch.nn.Parameter(w)

    # The type check targets `model`, the object under test; checking the
    # reference `model_true` would pass trivially.
    assert isinstance(model, torch.nn.Linear)
    assert torch.equal(model_true.weight, model.weight)
    assert model.bias is None
def test_same_as_linear(self, in_features, out_features, bias):
    """A lone linear layer and a one-layer ``Sequential`` give the same graph."""
    layer = torch.nn.Linear(in_features, out_features, bias=bias)

    graph_from_linear = Linear.to_graph(layer)
    graph_from_mlp = MLP.to_graph(torch.nn.Sequential(layer))

    assert graph_from_linear.num_nodes == graph_from_mlp.num_nodes
    # Compare the tensor attributes in the same order as before.
    for attr in ("x", "edge_index", "edge_features"):
        assert torch.allclose(
            getattr(graph_from_linear, attr),
            getattr(graph_from_mlp, attr),
        )
def test_basic_with_bias(self, node_strategy, random_state):
    """Convert a 3 -> 2 biased linear layer and check every graph field."""
    in_features = 3
    out_features = 2
    target = torch.tensor([[3.43]])

    torch.manual_seed(random_state)
    model = torch.nn.Linear(
        in_features=in_features,
        out_features=out_features,
        bias=True,
    )
    graph = Linear.to_graph(model, target=target, node_strategy=node_strategy)

    # Expected node features for each supported strategy.
    if node_strategy is None:
        x_true = None
    elif node_strategy == "constant":
        x_true = torch.ones(7, dtype=torch.float)[:, None]
    elif node_strategy == "proportional":
        x_true = torch.tensor([1, 1, 1, 0.25, 0.25, 1, 1])[:, None]

    # Expected connectivity: sources in the first row, destinations below.
    expected_edges = torch.tensor(
        [
            [0, 0, 1, 1, 2, 2, 5, 6],
            [3, 4, 3, 4, 3, 4, 3, 4],
        ],
        dtype=torch.int64,
    )

    # Expected edge features: weight entries grouped by input column,
    # followed by the bias entries.
    weight = model.weight
    bias_vec = model.bias
    weight_rows = [
        [weight[out_idx, in_idx]]
        for in_idx in range(in_features)
        for out_idx in range(out_features)
    ]
    bias_rows = [[bias_vec[out_idx]] for out_idx in range(out_features)]
    expected_features = torch.tensor(weight_rows + bias_rows)

    assert isinstance(graph, Data)
    assert graph.num_nodes == 7
    assert torch.equal(graph.y, target)
    assert torch.equal(graph.edge_index, expected_edges)
    assert torch.equal(graph.edge_features, expected_features)

    if x_true is None:
        assert graph.x is None
    else:
        assert torch.equal(graph.x, x_true)
def test_incorrect_type(self):
    """Passing something that is not a module must raise ``TypeError``."""
    not_a_module = "wrong_type"
    with pytest.raises(TypeError):
        Linear.to_graph(not_a_module)
def test_unknown_strategy(self):
    """An unrecognised ``node_strategy`` must raise ``ValueError``."""
    layer = torch.nn.Linear(4, 5)
    with pytest.raises(ValueError):
        Linear.to_graph(layer, node_strategy="nonexistent")