Exemple #1
0
    def test_with_bias(self, random_state):
        """Check that a graph with weight and bias edges converts to a biased Linear.

        The graph is built by hand from a random weight ``w`` (4 x 2) and a
        random bias ``b`` (4,), converted with ``Linear.to_module`` and compared
        against a reference ``torch.nn.Linear`` carrying the same parameters.
        """
        torch.manual_seed(random_state)

        # 2 in_features, 4 out_features; torch stores the weight as (out, in).
        w = torch.randn(4, 2)
        b = torch.randn(4)

        graph = Data(
            # Edges 0->{2..5} and 1->{2..5} carry the weight entries;
            # edges 6->2, 7->3, 8->4, 9->5 carry the bias entries.
            edge_index=torch.tensor([
                [0, 0, 0, 0, 1, 1, 1, 1, 6, 7, 8, 9],
                [2, 3, 4, 5, 2, 3, 4, 5, 2, 3, 4, 5],
            ]),
            # One feature per edge, in the same order as ``edge_index``:
            # first column 0 of ``w``, then column 1 of ``w``, then ``b``.
            edge_features=torch.tensor([
                [w[0, 0]],
                [w[1, 0]],
                [w[2, 0]],
                [w[3, 0]],
                [w[0, 1]],
                [w[1, 1]],
                [w[2, 1]],
                [w[3, 1]],
                [b[0]],
                [b[1]],
                [b[2]],
                [b[3]],
            ]),
        )
        model = Linear.to_module(graph)

        # Reference module holding exactly the parameters encoded in the graph.
        model_true = torch.nn.Linear(2, 4, bias=True)
        model_true.weight = torch.nn.Parameter(w)
        model_true.bias = torch.nn.Parameter(b)

        # Assert on the converted module, not on the reference: checking
        # ``model_true`` against itself would pass regardless of ``to_module``.
        assert isinstance(model, torch.nn.Linear)
        assert torch.equal(model.weight, model_true.weight)
        assert torch.equal(model.bias, model_true.bias)
Exemple #2
0
    def test_without_bias(self, random_state):
        """Check that a graph with only weight edges converts to a bias-free Linear.

        The graph is built by hand from a random weight ``w`` (4 x 2), converted
        with ``Linear.to_module`` and compared against a reference
        ``torch.nn.Linear`` created with ``bias=False``.
        """
        torch.manual_seed(random_state)

        # 2 in_features, 4 out_features; torch stores the weight as (out, in).
        w = torch.randn(4, 2)
        graph = Data(
            # Edges 0->{2..5} and 1->{2..5}; there are no further edges.
            edge_index=torch.tensor([[0, 0, 0, 0, 1, 1, 1, 1],
                                     [2, 3, 4, 5, 2, 3, 4, 5]]),
            # One feature per edge, in the same order as ``edge_index``:
            # first column 0 of ``w``, then column 1 of ``w``.
            edge_features=torch.tensor([
                [w[0, 0]],
                [w[1, 0]],
                [w[2, 0]],
                [w[3, 0]],
                [w[0, 1]],
                [w[1, 1]],
                [w[2, 1]],
                [w[3, 1]],
            ]),
        )
        model = Linear.to_module(graph)

        # Reference module holding exactly the weight encoded in the graph.
        model_true = torch.nn.Linear(2, 4, bias=False)
        model_true.weight = torch.nn.Parameter(w)

        # Type-check the converted module; checking ``model_true`` here would
        # be trivially true and say nothing about ``to_module``.
        assert isinstance(model, torch.nn.Linear)
        assert torch.equal(model_true.weight, model.weight)
        assert model.bias is None
Exemple #3
0
    def test_round_trip(self, in_features, out_features, bias, random_state):
        """Check that module -> graph -> module preserves the parameters.

        A freshly initialised ``torch.nn.Linear`` is passed through
        ``Linear.to_graph`` and back through ``Linear.to_module``; the weight
        must come back unchanged, and the bias must come back unchanged when
        present and stay absent when the input module had none.
        """
        torch.manual_seed(random_state)
        input_model = torch.nn.Linear(in_features, out_features, bias=bias)

        output_model = Linear.to_module(Linear.to_graph(input_model))

        assert torch.equal(input_model.weight, output_model.weight)
        if bias:
            assert torch.equal(input_model.bias, output_model.bias)
        else:
            # Without this branch a round trip that invents a bias would pass.
            assert output_model.bias is None