Ejemplo n.º 1
0
 def test_unknown_strategy(self):
     """An unrecognised ``node_strategy`` makes ``MLP.to_graph`` raise."""
     # The layer sizes must chain (4 -> 5 -> 7) so that the strategy is the
     # only invalid input; mismatched layers could raise ValueError on their
     # own and let this test pass for the wrong reason.
     module = torch.nn.Sequential(torch.nn.Linear(4, 5),
                                  torch.nn.Linear(5, 7))
     with pytest.raises(ValueError):
         MLP.to_graph(module, node_strategy="nonexistent")
Ejemplo n.º 2
0
    def test_round_trip(self, module):
        """Converting to a graph and back reproduces every layer's parameters.

        ``module`` is supplied by a fixture outside this view; every
        sub-module is accessed via ``.weight``/``.bias`` below, so it is
        presumably a stack of ``torch.nn.Linear`` layers — TODO confirm.
        """
        module_rt = MLP.to_module(MLP.to_graph(module))

        # ``modules()`` includes the container itself; keep only sub-modules.
        layers = [m for m in module.modules() if m is not module]
        layers_rt = [m for m in module_rt.modules() if m is not module_rt]

        assert len(layers) == len(layers_rt)

        for linear, linear_rt in zip(layers, layers_rt):
            # ``torch.allclose`` broadcasts, so compare shapes explicitly
            # before comparing values.
            assert linear.weight.shape == linear_rt.weight.shape
            assert torch.allclose(linear.weight, linear_rt.weight)
            if linear.bias is not None:
                assert linear.bias.shape == linear_rt.bias.shape
                assert torch.allclose(linear.bias, linear_rt.bias)
Ejemplo n.º 3
0
    def test_n_edges(self, module):
        """The graph has exactly one edge per trainable parameter entry."""
        trainable = [p for p in module.parameters() if p.requires_grad]
        expected = sum(p.numel() for p in trainable)

        graph = MLP.to_graph(module)
        n_index_edges = graph.edge_index.shape[1]
        n_feature_rows = len(graph.edge_features)

        assert n_feature_rows == expected
        assert n_index_edges == expected
Ejemplo n.º 4
0
    def test_same_as_linear(self, in_features, out_features, bias):
        """A one-layer ``Sequential`` converts like its bare ``Linear``.

        ``MLP.to_graph`` on ``Sequential(linear)`` must produce the same node
        count, node features, edge index and edge features as
        ``Linear.to_graph`` on ``linear`` itself.
        """
        module = torch.nn.Linear(in_features, out_features, bias=bias)
        graph_linear = Linear.to_graph(module)
        graph_mlp = MLP.to_graph(torch.nn.Sequential(module))

        assert graph_linear.num_nodes == graph_mlp.num_nodes

        # Float features: compare approximately, but check shapes first
        # because ``torch.allclose`` broadcasts.
        assert graph_linear.x.shape == graph_mlp.x.shape
        assert torch.allclose(graph_linear.x, graph_mlp.x)

        # Integer connectivity: require exact equality (shape included).
        assert torch.equal(graph_linear.edge_index, graph_mlp.edge_index)

        assert (graph_linear.edge_features.shape
                == graph_mlp.edge_features.shape)
        assert torch.allclose(graph_linear.edge_features,
                              graph_mlp.edge_features)
Ejemplo n.º 5
0
    def test_biases_correct(self):
        """Bias nodes have no incoming edges and exactly one outgoing edge."""
        module = torch.nn.Sequential(
            torch.nn.Linear(2, 3, bias=True),
            torch.nn.Linear(3, 4, bias=True),
            torch.nn.Linear(4, 2, bias=True),
        )

        # Positions of the 3 + 4 + 2 bias nodes among the 20 graph nodes.
        bias_ids = {5, 6, 7, 12, 13, 14, 15, 18, 19}

        graph = MLP.to_graph(module)
        sources = graph.edge_index[0, :].tolist()
        targets = graph.edge_index[1, :].tolist()

        # No edge ends at a bias node.
        assert bias_ids.isdisjoint(targets)

        # Every bias node is the source of exactly one edge.
        for node in sorted(bias_ids):
            assert sources.count(node) == 1
Ejemplo n.º 6
0
    def test_basic_with_bias(self, node_strategy, random_state):
        """Check the exact graph built from a 3 -> 4 -> 2 MLP with biases.

        The expected tensors below imply this 15-node layout: inputs 0-2,
        hidden neurons 3-6, hidden-layer biases 7-10, outputs 11-12 and
        output-layer biases 13-14.
        """
        in_features = 3
        hidden_features = 4
        out_features = 2

        target = torch.tensor([[3.43]])

        torch.manual_seed(random_state)

        linear_1 = torch.nn.Linear(in_features=in_features,
                                   out_features=hidden_features,
                                   bias=True)
        linear_2 = torch.nn.Linear(in_features=hidden_features,
                                   out_features=out_features,
                                   bias=True)
        module = torch.nn.Sequential(linear_1, linear_2)
        graph = MLP.to_graph(module,
                             target=target,
                             node_strategy=node_strategy)

        # Expected node features for each supported strategy.
        if node_strategy is None:
            x_true = None
        elif node_strategy == "constant":
            x_true = torch.ones((15, 1), dtype=torch.float)
        elif node_strategy == "proportional":
            # These values match 1 / in-degree for nodes with incoming
            # edges (hidden: 3 inputs + 1 bias; outputs: 4 hidden + 1 bias)
            # and 1 for nodes without any (inputs and biases).
            x_true = torch.tensor([
                1, 1, 1,  # inputs 0-2
                1 / 4, 1 / 4, 1 / 4, 1 / 4,  # hidden 3-6
                1, 1, 1, 1,  # hidden-layer biases 7-10
                1 / 5, 1 / 5,  # outputs 11-12
                1, 1,  # output-layer biases 13-14
            ])[:, None]
        else:
            # Fail with a clear message instead of a NameError on x_true.
            pytest.fail(
                "no expected node features for node_strategy={!r}".format(
                    node_strategy))

        # Row 0: source nodes, row 1: target nodes (26 edges).
        edge_index_true = torch.tensor(
            [
                [0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2, 7, 8, 9, 10,
                 3, 3, 4, 4, 5, 5, 6, 6, 13, 14],
                [3, 4, 5, 6, 3, 4, 5, 6, 3, 4, 5, 6, 3, 4, 5, 6,
                 11, 12, 11, 12, 11, 12, 11, 12, 11, 12],
            ],
            dtype=torch.int64,
        )

        w_1 = linear_1.weight
        b_1 = linear_1.bias
        w_2 = linear_2.weight
        b_2 = linear_2.bias
        # One row per edge, in the same order as ``edge_index_true``.
        edge_features_true = torch.tensor([
            # inputs 0-2 -> hidden 3-6
            [w_1[0, 0]], [w_1[1, 0]], [w_1[2, 0]], [w_1[3, 0]],
            [w_1[0, 1]], [w_1[1, 1]], [w_1[2, 1]], [w_1[3, 1]],
            [w_1[0, 2]], [w_1[1, 2]], [w_1[2, 2]], [w_1[3, 2]],
            # hidden-layer biases 7-10 -> hidden 3-6
            [b_1[0]], [b_1[1]], [b_1[2]], [b_1[3]],
            # hidden 3-6 -> outputs 11-12
            [w_2[0, 0]], [w_2[1, 0]],
            [w_2[0, 1]], [w_2[1, 1]],
            [w_2[0, 2]], [w_2[1, 2]],
            [w_2[0, 3]], [w_2[1, 3]],
            # output-layer biases 13-14 -> outputs 11-12
            [b_2[0]], [b_2[1]],
        ])
        assert isinstance(graph, Data)

        assert graph.num_nodes == 15
        assert torch.equal(graph.y, target)
        assert torch.equal(graph.edge_index, edge_index_true)
        assert torch.equal(graph.edge_features, edge_features_true)
        if x_true is not None:
            assert torch.equal(graph.x, x_true)
        else:
            assert graph.x is None
Ejemplo n.º 7
0
 def test_incorrect_type(self):
     """Passing a plain string to ``MLP.to_graph`` raises TypeError."""
     not_a_module = "wrong_type"
     with pytest.raises(TypeError):
         MLP.to_graph(not_a_module)
Ejemplo n.º 8
0
 def test_is_mlp(self, module, out):
     """``MLP._is_mlp(module)`` equals the expected flag ``out``."""
     detected = MLP._is_mlp(module)
     assert detected == out