def test_transform(self):
    """Check that ``Graph.apply_tensor`` updates the named tensors with the given function.

    A ``Graph`` is built from a networkx graph carrying edge-, node- and
    graph-level features and labels. A sequence of element-wise functions is
    then applied to the three feature attributes. After every step, each
    attribute is compared exactly against the cumulative composition applied
    to a snapshot taken before any transform.
    """
    G, x, y, edge_x, edge_y, _, graph_x, graph_y = simple_networkx_graph()
    Graph.add_edge_attr(G, "edge_feature", edge_x)
    Graph.add_edge_attr(G, "edge_label", edge_y)
    Graph.add_node_attr(G, "node_feature", x)
    Graph.add_node_attr(G, "node_label", y)
    Graph.add_graph_attr(G, "graph_feature", graph_x)
    Graph.add_graph_attr(G, "graph_label", graph_y)
    dg = Graph(G)

    keys = ("edge_feature", "node_feature", "graph_feature")
    # clone() so later updates on dg cannot alias the reference snapshot.
    originals = {key: getattr(dg, key).clone() for key in keys}

    # (function passed to apply_tensor, expected value as a function of the
    # original snapshot). Expectations are cumulative because apply_tensor is
    # called repeatedly on the same graph. Every step, including "+ 10", is
    # pinned to an exact value rather than only "differs from original".
    steps = (
        (lambda t: t, lambda t: t),
        (lambda t: t + 10, lambda t: t + 10),
        (lambda t: t + 100, lambda t: t + 10 + 100),
        (lambda t: t * 2, lambda t: (t + 10 + 100) * 2),
    )
    for step, (transform, expected) in enumerate(steps):
        dg.apply_tensor(transform, *keys)
        for key in keys:
            with self.subTest(step=step, key=key):
                # torch.equal also checks shape, unlike torch.all(a.eq(b)),
                # which would broadcast.
                self.assertTrue(
                    torch.equal(getattr(dg, key), expected(originals[key])),
                    f"{key} mismatch after step {step}",
                )
def test_transform(self):
    """Check that ``Graph.apply_tensor`` updates the named tensors with the given function.

    A directed ``Graph`` is built directly from tensors (features, labels and
    ``edge_index``) rather than from a networkx graph. A sequence of
    element-wise functions is then applied to the three feature attributes.
    After every step, each attribute is compared exactly against the
    cumulative composition applied to a snapshot taken before any transform.
    """
    _, x, y, edge_x, edge_y, edge_index, graph_x, graph_y = (
        simple_networkx_graph()
    )
    dg = Graph(
        node_feature=x,
        node_label=y,
        edge_index=edge_index,
        edge_feature=edge_x,
        edge_label=edge_y,
        graph_feature=graph_x,
        graph_label=graph_y,
        directed=True,
    )

    keys = ("edge_feature", "node_feature", "graph_feature")
    # clone() so later updates on dg cannot alias the reference snapshot.
    originals = {key: getattr(dg, key).clone() for key in keys}

    # (function passed to apply_tensor, expected value as a function of the
    # original snapshot). Expectations are cumulative because apply_tensor is
    # called repeatedly on the same graph. Every step, including "+ 10", is
    # pinned to an exact value rather than only "differs from original".
    steps = (
        (lambda t: t, lambda t: t),
        (lambda t: t + 10, lambda t: t + 10),
        (lambda t: t + 100, lambda t: t + 10 + 100),
        (lambda t: t * 2, lambda t: (t + 10 + 100) * 2),
    )
    for step, (transform, expected) in enumerate(steps):
        dg.apply_tensor(transform, *keys)
        for key in keys:
            with self.subTest(step=step, key=key):
                # torch.equal also checks shape, unlike torch.all(a.eq(b)),
                # which would broadcast.
                self.assertTrue(
                    torch.equal(getattr(dg, key), expected(originals[key])),
                    f"{key} mismatch after step {step}",
                )