def test_unknown_strategy(self):
    with pytest.raises(ValueError):
        MLP.to_graph(
            # Layer sizes chain (5 -> 5) so the only error source is the strategy
            torch.nn.Sequential(torch.nn.Linear(4, 5), torch.nn.Linear(5, 7)),
            node_strategy="nonexistent",
        )
def test_round_trip(self, module):
    module_rt = MLP.to_module(MLP.to_graph(module))

    layers = [m for m in module.modules() if m != module]
    layers_rt = [m for m in module_rt.modules() if m != module_rt]

    assert len(layers) == len(layers_rt)

    for linear, linear_rt in zip(layers, layers_rt):
        assert torch.allclose(linear.weight, linear_rt.weight)
        if linear.bias is not None:
            assert torch.allclose(linear.bias, linear_rt.bias)
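# Sketch of a follow-up check, not in the original suite: if the round trip
# preserves every weight and bias (as asserted above), forward outputs must
# match too. The layer sizes here are illustrative assumptions.
def test_round_trip_forward_sketch(self):
    module = torch.nn.Sequential(torch.nn.Linear(4, 8), torch.nn.Linear(8, 2))
    module_rt = MLP.to_module(MLP.to_graph(module))

    x = torch.randn(5, 4)
    assert torch.allclose(module(x), module_rt(x))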
def test_n_edges(self, module):
    """Number of edges is equal to the number of parameters."""
    n_edges_expected = sum(p.numel() for p in module.parameters() if p.requires_grad)

    graph = MLP.to_graph(module)

    assert len(graph.edge_features) == n_edges_expected
    assert graph.edge_index.shape[1] == n_edges_expected
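# Worked example of the edge count above (a sketch; the layer sizes are
# assumptions): Linear(2, 3) with bias has 2 * 3 weights + 3 biases = 9
# parameters, so its graph should carry exactly 9 edges.
def test_n_edges_small_example_sketch(self):
    module = torch.nn.Sequential(torch.nn.Linear(2, 3, bias=True))
    graph = MLP.to_graph(module)

    assert graph.edge_index.shape[1] == 9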
def test_same_as_linear(self, in_features, out_features, bias):
    module = torch.nn.Linear(in_features, out_features, bias=bias)

    x_linear = Linear.to_graph(module)
    x_mlp = MLP.to_graph(torch.nn.Sequential(module))

    assert x_linear.num_nodes == x_mlp.num_nodes
    assert torch.allclose(x_linear.x, x_mlp.x)
    # Integer indices are compared exactly rather than with allclose
    assert torch.equal(x_linear.edge_index, x_mlp.edge_index)
    assert torch.allclose(x_linear.edge_features, x_mlp.edge_features)
def test_biases_correct(self):
    """Bias nodes have no incoming edges and exactly one outgoing edge."""
    module = torch.nn.Sequential(
        torch.nn.Linear(2, 3, bias=True),
        torch.nn.Linear(3, 4, bias=True),
        torch.nn.Linear(4, 2, bias=True),
    )
    # Node layout: 2 inputs, 3 hidden + 3 biases, 4 hidden + 4 biases,
    # 2 outputs + 2 biases = 20 nodes
    is_bias = (
        [False] * 2  # input nodes
        + [False] * 3  # first hidden layer
        + [True] * 3  # biases of the first layer
        + [False] * 4  # second hidden layer
        + [True] * 4  # biases of the second layer
        + [False] * 2  # output nodes
        + [True] * 2  # biases of the last layer
    )
    bias_ids = {i for i, b in enumerate(is_bias) if b}

    graph = MLP.to_graph(module)

    start_nodes = graph.edge_index[0, :].tolist()
    end_nodes = graph.edge_index[1, :].tolist()

    # There are no incoming edges to bias nodes
    assert not (bias_ids & set(end_nodes))

    # Each bias node is the source of exactly one edge
    for bias_id in bias_ids:
        assert start_nodes.count(bias_id) == 1
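# Complementary sketch (not in the original suite): every non-bias, non-input
# node should receive one edge per incoming weight plus one from its bias node.
def test_incoming_degree_sketch(self):
    module = torch.nn.Sequential(torch.nn.Linear(2, 3, bias=True))
    graph = MLP.to_graph(module)

    end_nodes = graph.edge_index[1, :].tolist()
    # Each of the 3 output nodes: 2 weight edges + 1 bias edge = 3 incoming
    for node_id in set(end_nodes):
        assert end_nodes.count(node_id) == 3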
def test_basic_with_bias(self, node_strategy, random_state):
    in_features = 3
    hidden_features = 4
    out_features = 2
    target = torch.tensor([[3.43]])

    torch.manual_seed(random_state)
    linear_1 = torch.nn.Linear(in_features=in_features, out_features=hidden_features, bias=True)
    linear_2 = torch.nn.Linear(in_features=hidden_features, out_features=out_features, bias=True)
    module = torch.nn.Sequential(linear_1, linear_2)

    graph = MLP.to_graph(module, target=target, node_strategy=node_strategy)

    # Expected node features.
    # Node layout: inputs 0-2, hidden 3-6, layer-1 biases 7-10,
    # outputs 11-12, layer-2 biases 13-14 (15 nodes in total).
    if node_strategy is None:
        x_true = None
    elif node_strategy == "constant":
        x_true = torch.ones((15, 1), dtype=torch.float)
    elif node_strategy == "proportional":
        # 1 for input and bias nodes, 1 / (fan_in + 1) for hidden and output nodes
        x_true = torch.tensor(
            [1, 1, 1]  # inputs
            + [1 / 4] * 4  # hidden nodes: 3 weight + 1 bias edges incoming
            + [1] * 4  # layer-1 biases
            + [1 / 5] * 2  # output nodes: 4 weight + 1 bias edges incoming
            + [1] * 2  # layer-2 biases
        )[:, None]

    edge_index_true = torch.tensor(
        [
            # sources: inputs -> hidden, layer-1 biases -> hidden,
            # hidden -> outputs, layer-2 biases -> outputs
            [0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2, 7, 8, 9, 10,
             3, 3, 4, 4, 5, 5, 6, 6, 13, 14],
            # targets
            [3, 4, 5, 6, 3, 4, 5, 6, 3, 4, 5, 6, 3, 4, 5, 6,
             11, 12, 11, 12, 11, 12, 11, 12, 11, 12],
        ],
        dtype=torch.int64,
    )

    w_1, b_1 = linear_1.weight, linear_1.bias
    w_2, b_2 = linear_2.weight, linear_2.bias
    edge_features_true = torch.tensor(
        [
            # layer-1 weights, column by column (one column per input node)
            [w_1[0, 0]], [w_1[1, 0]], [w_1[2, 0]], [w_1[3, 0]],
            [w_1[0, 1]], [w_1[1, 1]], [w_1[2, 1]], [w_1[3, 1]],
            [w_1[0, 2]], [w_1[1, 2]], [w_1[2, 2]], [w_1[3, 2]],
            # layer-1 biases
            [b_1[0]], [b_1[1]], [b_1[2]], [b_1[3]],
            # layer-2 weights, column by column
            [w_2[0, 0]], [w_2[1, 0]],
            [w_2[0, 1]], [w_2[1, 1]],
            [w_2[0, 2]], [w_2[1, 2]],
            [w_2[0, 3]], [w_2[1, 3]],
            # layer-2 biases
            [b_2[0]], [b_2[1]],
        ]
    )

    assert isinstance(graph, Data)
    assert graph.num_nodes == 15
    assert torch.equal(graph.y, target)
    assert torch.equal(graph.edge_index, edge_index_true)
    assert torch.equal(graph.edge_features, edge_features_true)

    if x_true is not None:
        assert torch.equal(graph.x, x_true)
    else:
        assert graph.x is None
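# Sketch of the node-count bookkeeping implied by the layout above
# (inputs + each layer's units and bias nodes): 3 + (4 + 4) + (2 + 2) = 15.
def test_num_nodes_sketch(self):
    module = torch.nn.Sequential(torch.nn.Linear(3, 4), torch.nn.Linear(4, 2))
    assert MLP.to_graph(module).num_nodes == 15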
def test_incorrect_type(self):
    with pytest.raises(TypeError):
        MLP.to_graph("wrong_type")
def test_is_mlp(self, module, out):
    assert MLP._is_mlp(module) == out
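# Illustrative examples of the predicate (a sketch; the exact acceptance rules
# beyond "Sequential of Linear layers" are an assumption here):
def test_is_mlp_examples_sketch(self):
    assert MLP._is_mlp(torch.nn.Sequential(torch.nn.Linear(2, 3), torch.nn.Linear(3, 1)))
    assert not MLP._is_mlp(torch.nn.Sequential(torch.nn.Conv2d(1, 1, 3)))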