Exemplo n.º 1
0
    def test_round_trip(self, in_features, out_features, bias, random_state):
        """Converting a ``torch.nn.Linear`` to a graph and back keeps its parameters.

        The weights are always compared; the bias is compared only when the
        layer was created with ``bias=True``.
        """
        torch.manual_seed(random_state)
        original = torch.nn.Linear(in_features, out_features, bias=bias)

        graph = Linear.to_graph(original)
        restored = Linear.to_module(graph)

        assert torch.equal(original.weight, restored.weight)
        if not bias:
            # Nothing else to compare for a bias-free layer.
            return
        assert torch.equal(original.bias, restored.bias)
Exemplo n.º 2
0
    def test_same_as_linear(self, in_features, out_features, bias):
        """A single-layer MLP must be encoded exactly like the bare Linear layer.

        The same ``torch.nn.Linear`` is converted once directly with
        ``Linear.to_graph`` and once wrapped in ``torch.nn.Sequential`` with
        ``MLP.to_graph``; both graphs must agree on node count, node features,
        connectivity and edge features.
        """
        module = torch.nn.Linear(in_features, out_features, bias=bias)
        graph_linear = Linear.to_graph(module)
        graph_mlp = MLP.to_graph(torch.nn.Sequential(module))

        assert graph_linear.num_nodes == graph_mlp.num_nodes

        # ``x`` may be None when the conversion produces no node features;
        # torch.allclose would raise a TypeError on None, so handle it first.
        if graph_linear.x is None:
            assert graph_mlp.x is None
        else:
            assert torch.allclose(graph_linear.x, graph_mlp.x)

        # edge_index holds integer node ids, so compare exactly: allclose's
        # relative tolerance (rtol * |b|) exceeds 1 for large ids and would
        # accept off-by-one indices.
        assert torch.equal(graph_linear.edge_index, graph_mlp.edge_index)
        assert torch.allclose(graph_linear.edge_features,
                              graph_mlp.edge_features)
Exemplo n.º 3
0
    def test_basic_with_bias(self, node_strategy, random_state):
        """Check the exact graph built from a 3 -> 2 ``torch.nn.Linear`` with bias.

        Expected layout, as encoded by the tensors below: nodes 0-2 are the
        inputs, nodes 3-4 the outputs and nodes 5-6 the bias sources (node 5
        feeds output node 3, node 6 feeds output node 4). Every input connects
        to every output, giving 7 nodes and 8 edges. The target is stored on
        the graph as ``y``.
        """
        in_features = 3
        out_features = 2
        target = torch.tensor([[3.43]])

        torch.manual_seed(random_state)
        model = torch.nn.Linear(in_features=in_features,
                                out_features=out_features,
                                bias=True)

        graph = Linear.to_graph(model,
                                target=target,
                                node_strategy=node_strategy)

        # Expected node features for each node strategy this test knows about.
        if node_strategy is None:
            x_true = None
        elif node_strategy == "constant":
            x_true = torch.tensor([1, 1, 1, 1, 1, 1, 1],
                                  dtype=torch.float)[:, None]
        elif node_strategy == "proportional":
            # The output nodes (3, 4) each have 4 incoming edges (3 weights +
            # 1 bias); presumably the strategy assigns 1 / in-degree there.
            x_true = torch.tensor([1, 1, 1, 1 / 4, 1 / 4, 1, 1])[:, None]
        else:
            # Fail loudly instead of hitting an UnboundLocalError on x_true
            # when a new strategy is added to the parametrization.
            raise ValueError(
                f"No expected node features for node_strategy={node_strategy!r}")

        edge_index_true = torch.tensor(
            [[0, 0, 1, 1, 2, 2, 5, 6], [3, 4, 3, 4, 3, 4, 3, 4]],
            dtype=torch.int64)

        # Edge features follow edge_index_true order: input->output edges
        # carry w[out, in], bias->output edges carry b[out].
        w = model.weight
        b = model.bias
        edge_features_true = torch.tensor([
            [w[0, 0]],
            [w[1, 0]],
            [w[0, 1]],
            [w[1, 1]],
            [w[0, 2]],
            [w[1, 2]],
            [b[0]],
            [b[1]],
        ])
        assert isinstance(graph, Data)

        assert graph.num_nodes == 7
        assert torch.equal(graph.y, target)
        assert torch.equal(graph.edge_index, edge_index_true)
        assert torch.equal(graph.edge_features, edge_features_true)
        if x_true is not None:
            assert torch.equal(graph.x, x_true)
        else:
            assert graph.x is None
Exemplo n.º 4
0
 def test_incorrect_type(self):
     """``Linear.to_graph`` raises ``TypeError`` for a non-module argument (a string)."""
     not_a_module = "wrong_type"
     with pytest.raises(TypeError):
         Linear.to_graph(not_a_module)
Exemplo n.º 5
0
 def test_unknown_strategy(self):
     """An unrecognised ``node_strategy`` name makes ``Linear.to_graph`` raise ``ValueError``."""
     layer = torch.nn.Linear(4, 5)
     with pytest.raises(ValueError):
         Linear.to_graph(layer, node_strategy="nonexistent")