示例#1
0
    def test_transform_networkx(self):
        """Check ``Graph.apply_tensor`` on a graph built from networkx.

        Attributes are attached to the networkx graph first, then wrapped
        in ``Graph``. Each ``apply_tensor`` call is checked against the
        exact expected values derived from the untouched originals.
        """
        # NOTE: this method was previously also named ``test_transform``.
        # Another method with that name follows it, and in a class body
        # the later ``def`` rebinds the name, so this test never ran.
        G, x, y, edge_x, edge_y, _edge_index, graph_x, graph_y = (
            simple_networkx_graph()
        )
        Graph.add_edge_attr(G, "edge_feature", edge_x)
        Graph.add_edge_attr(G, "edge_label", edge_y)
        Graph.add_node_attr(G, "node_feature", x)
        Graph.add_node_attr(G, "node_label", y)
        Graph.add_graph_attr(G, "graph_feature", graph_x)
        Graph.add_graph_attr(G, "graph_label", graph_y)

        dg = Graph(G)

        feature_names = ("edge_feature", "node_feature", "graph_feature")
        # clone() so the snapshot stays independent of whatever
        # apply_tensor does to the tensors stored on ``dg``.
        originals = {
            name: getattr(dg, name).clone() for name in feature_names
        }

        def assert_features(expected_fn):
            # Each feature must equal expected_fn(original) exactly;
            # torch.equal also rejects shape mismatches that ``eq``
            # would silently broadcast away.
            for name in feature_names:
                actual = getattr(dg, name)
                expected = expected_fn(originals[name])
                self.assertTrue(torch.equal(actual, expected), msg=name)

        # Identity transform leaves every feature unchanged.
        dg.apply_tensor(lambda t: t, *feature_names)
        assert_features(lambda t: t)

        # Each later call acts on the result of the previous one.
        dg.apply_tensor(lambda t: t + 10, *feature_names)
        assert_features(lambda t: t + 10)

        dg.apply_tensor(lambda t: t + 100, *feature_names)
        assert_features(lambda t: t + 10 + 100)

        dg.apply_tensor(lambda t: t * 2, *feature_names)
        assert_features(lambda t: (t + 10 + 100) * 2)
    def test_transform(self):
        """Check ``Graph.apply_tensor`` on a graph built directly from tensors.

        The graph is constructed from keyword tensors with
        ``directed=True``. Each ``apply_tensor`` call is checked against
        the exact expected values derived from the untouched originals.
        """
        # The networkx graph returned first is not needed here.
        _, x, y, edge_x, edge_y, edge_index, graph_x, graph_y = (
            simple_networkx_graph()
        )
        dg = Graph(
            node_feature=x, node_label=y, edge_index=edge_index,
            edge_feature=edge_x, edge_label=edge_y,
            graph_feature=graph_x, graph_label=graph_y, directed=True
        )

        feature_names = ("edge_feature", "node_feature", "graph_feature")
        # clone() so the snapshot stays independent of whatever
        # apply_tensor does to the tensors stored on ``dg``.
        originals = {
            name: getattr(dg, name).clone() for name in feature_names
        }

        def assert_features(expected_fn):
            # Each feature must equal expected_fn(original) exactly;
            # torch.equal also rejects shape mismatches that ``eq``
            # would silently broadcast away.
            for name in feature_names:
                actual = getattr(dg, name)
                expected = expected_fn(originals[name])
                self.assertTrue(torch.equal(actual, expected), msg=name)

        # Identity transform leaves every feature unchanged.
        dg.apply_tensor(lambda t: t, *feature_names)
        assert_features(lambda t: t)

        # Each later call acts on the result of the previous one.
        dg.apply_tensor(lambda t: t + 10, *feature_names)
        assert_features(lambda t: t + 10)

        dg.apply_tensor(lambda t: t + 100, *feature_names)
        assert_features(lambda t: t + 10 + 100)

        dg.apply_tensor(lambda t: t * 2, *feature_names)
        assert_features(lambda t: (t + 10 + 100) * 2)