def test_to_networkx(self, keys):
    """Convert a GraphData to networkx and compare the result with the source.

    ``keys`` is a ``(feature_key, global_attr_key)`` pair; a ``None`` entry
    means that keyword is not passed to ``to_networkx`` at all.
    """
    feature_key, global_attr_key = keys
    kwargs = {}
    if feature_key is not None:
        kwargs["feature_key"] = feature_key
    if global_attr_key is not None:
        kwargs["global_attr_key"] = global_attr_key

    # 10 nodes, 5 edges, one global row; edges pair i with 4 - i.
    data = GraphData(
        torch.randn(10, 5),
        torch.randn(5, 4),
        torch.randn(1, 3),
        torch.tensor([[0, 1, 2, 3, 4], [4, 3, 2, 1, 0]]),
    )
    graph = data.to_networkx(**kwargs)

    assert isinstance(graph, nx.OrderedMultiDiGraph)
    assert graph.number_of_nodes() == 10
    assert graph.number_of_edges() == 5

    # When a key was omitted, compare using "features" / None respectively.
    Comparator.data_to_nx(
        data,
        graph,
        kwargs.get("feature_key", "features"),
        kwargs.get("global_attr_key", None),
    )
def test_node_mask_entire_graph():
    """Masking out every node of one graph in a batch removes its nodes/edges.

    Three graphs with 5 nodes and 3 edges each are batched; the node mask
    drops all 5 nodes of the middle graph and keeps the others untouched.
    """
    data_list = [
        GraphData(
            torch.randn(5, 3),
            torch.randn(3, 3),
            torch.randn(1, 1),
            edges=torch.LongTensor([[0, 1, 2], [2, 3, 4]]),
        )
        for _ in range(3)
    ]
    batch = GraphBatch.from_data_list(data_list)

    # Keep graphs 1 and 3, drop every node of graph 2.
    node_mask = torch.BoolTensor([True] * 5 + [False] * 5 + [True] * 5)
    masked = batch.apply_node_mask(node_mask)

    # 2 surviving graphs x 5 nodes.
    assert masked.x.shape[0] == 10
    # Only the middle graph's edges touch masked nodes: 2 x 3 edges remain.
    assert masked.edges.shape[1] == 6
def test_train_shortest_path():
    """Smoke-train ``Network`` on generated shortest-path graphs.

    Inputs are read from the "_features" node key and targets from
    "_target"; the loss is BCE between node outputs and the first target
    column.
    """
    graphs = [generate_shorest_path_example(100, 0.01, 1000) for _ in range(10)]
    input_data = [GraphData.from_networkx(g, feature_key="_features") for g in graphs]
    target_data = [GraphData.from_networkx(g, feature_key="_target") for g in graphs]
    loader = GraphDataLoader(input_data, target_data, batch_size=32, shuffle=True)

    network = Network()
    # One forward pass before building the optimizer, presumably so lazily
    # sized layers create their parameters first (confirm against Network).
    for input_batch, _ in loader:
        network(input_batch, 10)
        break

    loss_fn = torch.nn.BCELoss()
    optimizer = torch.optim.AdamW(network.parameters())
    for _ in range(10):
        for input_batch, target_batch in loader:
            # Clear gradients from the previous step; without this they
            # accumulate across every batch and epoch.
            optimizer.zero_grad()
            output = network(input_batch, 10)[0]
            x, y = output.x, target_batch.x
            loss = loss_fn(x.flatten(), y[:, 0].flatten())
            loss.backward()
            print(loss.detach())
            optimizer.step()
def test_loader_zipped():
    """Passing two data lists yields pairs of distinct GraphBatch objects."""
    n_items = 32 * 5
    first = [GraphData.random(5, 4, 3) for _ in range(n_items)]
    second = [GraphData.random(5, 4, 3) for _ in range(n_items)]
    loader = GraphDataLoader(first, second, batch_size=32, shuffle=True)
    for batch_a, batch_b in loader:
        assert isinstance(batch_a, GraphBatch)
        assert isinstance(batch_b, GraphBatch)
        assert batch_a is not batch_b
def test_append_edges(self):
    """Appending 3 edges grows both the edge features and the edge indices."""
    data = GraphData(
        torch.randn(10, 5),
        torch.randn(5, 4),
        torch.randn(1, 3),
        torch.randint(0, 10, torch.Size([2, 5])),
    )
    assert data.e.shape[0] == 5
    assert data.edges.shape[1] == 5

    e = torch.randn(3, 4)
    edges = torch.randint(0, 10, torch.Size([2, 3]))
    data.append_edges(e, edges)

    # Mirrors test_append_nodes: the append is checked on the same object.
    assert data.e.shape[0] == 8
    assert data.edges.shape[1] == 8
def test_append_nodes(self):
    """``append_nodes`` grows the node-feature tensor of the same object."""
    data = GraphData(
        torch.randn(10, 5),
        torch.randn(5, 4),
        torch.randn(1, 3),
        torch.randint(0, 10, torch.Size([2, 5])),
    )
    n_before = data.x.shape[0]
    assert n_before == 10
    data.append_nodes(torch.randn(2, 5))
    assert data.x.shape[0] == n_before + 2
def test_invalid_append_edges(self):
    """Appending edge features and indices with mismatched counts fails."""
    data = GraphData(
        torch.randn(10, 5),
        torch.randn(5, 4),
        torch.randn(1, 3),
        torch.randint(0, 10, torch.Size([2, 5])),
    )
    new_features = torch.randn(3, 4)
    # 4 index columns but only 3 feature rows.
    mismatched_edges = torch.randint(0, 10, torch.Size([2, 4]))
    with pytest.raises(RuntimeError):
        data.append_edges(new_features, mismatched_edges)
def test_not_eq(self):
    """Graphs that differ only in their node tensor compare unequal."""
    nodes = torch.randn(20, 5)
    edge_attr = torch.randn(3, 4)
    global_attr = torch.randn(1, 3)
    edges = torch.randint(0, 10, torch.Size([2, 3]))
    data1 = GraphData(nodes, edge_attr, global_attr, edges)
    # Full-slice views of the same tensors, except only the first 10 nodes.
    data2 = GraphData(nodes[:10], edge_attr[:], global_attr[:], edges[:])
    assert not data1 == data2
    assert data1 is not data2
def test_generate_shortest_path_example():
    """Generated examples flag source/target nodes and convert to GraphData."""
    g = generate_shorest_path_example(100, 0.01, 10)
    for _, node_attrs in g.nodes(data=True):
        assert "source" in node_attrs
        assert "target" in node_attrs

    features = GraphData.from_networkx(g, feature_key="_features")
    targets = GraphData.from_networkx(g, feature_key="_target")
    assert tuple(features.shape) == (4, 1, 1)
    assert tuple(targets.shape) == (2, 2, 1)
def test_mask_all_nodes():
    """Masking out every node leaves a graph with no nodes and no edges."""
    deterministic_seed(0)
    data = GraphData(
        torch.randn(5, 5),
        torch.randn(3, 2),
        torch.randn(1, 1),
        edges=torch.LongTensor([[0, 0, 0], [1, 2, 3]]),
    )
    keep_none = torch.BoolTensor([False] * 5)
    result = data.apply_node_mask(keep_none)
    assert result.num_nodes == 0
    assert result.num_edges == 0
def test_mask_no_edges():
    """An all-True edge mask keeps every node and every edge."""
    deterministic_seed(0)
    data = GraphData(
        torch.randn(5, 5),
        torch.randn(3, 2),
        torch.randn(1, 1),
        edges=torch.LongTensor([[0, 0, 0], [1, 2, 3]]),
    )
    keep_all = torch.BoolTensor([True] * 3)
    result = data.apply_edge_mask(keep_all)
    assert result.num_nodes == 5
    assert result.num_edges == 3
def test_from_networkx_no_edge(self, keys):
    """Convert a two-node, edgeless networkx graph and compare attributes.

    ``keys`` is a ``(feature_key, global_attr_key)`` pair; ``None`` entries
    are left out of the ``from_networkx`` call.
    """
    feature_key, global_attr_key = keys
    kwargs = {}
    if feature_key is not None:
        kwargs["feature_key"] = feature_key
    if global_attr_key is not None:
        kwargs["global_attr_key"] = global_attr_key
    fkey = kwargs.get("feature_key", "features")
    gkey = kwargs.get("global_attr_key", None)

    g = nx.OrderedMultiDiGraph()
    for name in ("node1", "node2"):
        g.add_node(name, **{fkey: np.random.randn(5)})
    # Deliberately no edges.
    g.ordered_edges = []
    g.set_global({fkey: np.random.randn(3)}, gkey)

    data = GraphData.from_networkx(g, **kwargs)
    Comparator.data_to_nx(data, g, fkey, gkey)
def data(request):
    """Seeded fixture: a random GraphData, or a 10-graph GraphBatch otherwise."""
    deterministic_seed(0)
    if request.param is not GraphData:
        return GraphBatch.random_batch(10, 5, 4, 3)
    return GraphData.random(5, 4, 3)
def run():
    """Scatter the edge indices of a large random graph into a node x node mask.

    The result is discarded; presumably this exists as a timing/profiling
    harness (TODO confirm with caller).
    """
    graph = GraphData.random(5, 4, 3, min_edges=1000, min_nodes=1000)
    # Concatenating a single tensor yields a copy of graph.edges.
    indices = torch.cat([graph.edges])
    n = graph.num_nodes
    sparse_mask = scatter_coo(indices, 1, expand=True, size=(n, n))
def data(request):
    """Fixture: a fixed-size random GraphData, or a batch of 100 of them."""
    dims = (5, 4, 3)
    sizes = {"min_nodes": 10, "max_nodes": 10, "min_edges": 5, "max_edges": 5}
    if request.param is not GraphData:
        return GraphBatch.random_batch(100, *dims, **sizes)
    return GraphData.random(*dims, **sizes)
def test_fully_connected_singe_graph_batch():
    """FullyConnected adds edges to a batch holding a single graph."""
    deterministic_seed(0)
    batch = GraphBatch.from_data_list([GraphData.random(5, 4, 3)])
    transform = FullyConnected()
    connected = transform(batch)
    assert connected.edges.shape[1] > batch.edges.shape[1]
def test_view_graph_data(self, slices):
    """A view built from ``slices`` shares storage with its source graph."""
    source = GraphData.random(5, 5, 5, min_nodes=10, max_nodes=10)
    assert source.shape == (5, 5, 5)
    view = source.view(*slices)
    print(view.shape)
    # TODO: assert the expected view shape for each ``slices`` parameter.
    assert view.share_storage(source)
def test_invalid_n_edges(self):
    """An edge index tensor with 3 rows (other tests use 2) is rejected."""
    node_attr = torch.randn(10, 5)
    edge_attr = torch.randn(5, 4)
    global_attr = torch.randn(1, 3)
    bad_edges = torch.randint(0, 10, torch.Size([3, 5]))
    with pytest.raises(RuntimeError):
        GraphData(node_attr, edge_attr, global_attr, bad_edges)
def test_loader_first():
    """``first()`` returns one full batch of the requested size."""
    batch_size = 32
    datalist = [GraphData.random(5, 4, 3) for _ in range(batch_size * 5)]
    loader = GraphDataLoader(datalist, batch_size=batch_size, shuffle=True)
    first_batch = loader.first()
    assert isinstance(first_batch, GraphBatch)
    assert first_batch.shape == (5, 4, 3)
    assert first_batch.num_graphs == batch_size
def test_loader_dataset():
    """A GraphDataset can be fed straight into GraphDataLoader."""
    batch_size = 32
    dataset = GraphDataset([GraphData.random(5, 4, 3) for _ in range(batch_size * 4)])
    loader = GraphDataLoader(dataset, shuffle=True, batch_size=batch_size)
    for batch in loader:
        print(batch.size)
        assert isinstance(batch, GraphBatch)
        assert batch.size[-1] == batch_size
def test_floyd_warshall_neighbors(self, nodes, depth, return_matrix):
    """Smoke test for floyd_warshall_neighbors on a random graph.

    The graph is generated with min_nodes=100 and min_edges=300; results are
    only printed, not asserted.
    """
    graph = GraphData.random(5, 4, 3, min_edges=300, min_nodes=100)
    print(graph.density())
    neighbors = floyd_warshall_neighbors(
        graph, nodes, depth, return_matrix=return_matrix
    )
    print(neighbors)
def test_():
    """Batch a single-node graph with one self-loop together with an edgeless one."""
    # One node with a self-loop edge.
    with_edge = GraphData(
        torch.randn(1, 5),
        torch.randn(1, 4),
        torch.randn(1, 3),
        torch.LongTensor([[0], [0]]),
    )
    # One node, zero edges (empty edge features and empty index tensor).
    without_edge = GraphData(
        torch.randn(1, 5),
        torch.randn(0, 4),
        torch.randn(1, 3),
        torch.LongTensor([[], []]),
    )
    batch = GraphBatch.from_data_list([with_edge, without_edge])

    assert batch.num_graphs == 2
    assert batch.x.shape[0] == 2
    assert batch.e.shape[0] == 1
    assert batch.edges.shape[1] == 1
def test_invalid_batch(self, offsets):
    """Graphs whose feature widths differ by ``offsets`` cannot be batched.

    ``offsets`` holds the extra width added to node, edge and global
    features (in that order) of the second graph.
    """
    reference = GraphData(
        torch.randn(10, 10),
        torch.randn(3, 4),
        torch.randn(1, 3),
        torch.randint(0, 10, torch.Size([2, 3])),
    )
    node_extra, edge_extra, global_extra = offsets[0], offsets[1], offsets[2]
    mismatched = GraphData(
        torch.randn(12, 10 + node_extra),
        torch.randn(4, 4 + edge_extra),
        torch.randn(1, 3 + global_extra),
        torch.randint(0, 10, torch.Size([2, 4])),
    )
    with pytest.raises(RuntimeError):
        GraphBatch.from_data_list([reference, mismatched])
def test_invalid_global_ndims(self):
    """A 1-D global tensor is rejected (valid globals here are shaped (1, k))."""
    node_attr = torch.randn(10, 5)
    edge_attr = torch.randn(5, 4)
    flat_global = torch.randn(1)
    edges = torch.randint(0, 10, torch.Size([2, 5]))
    with pytest.raises(RuntimeError):
        GraphData(node_attr, edge_attr, flat_global, edges)
def test_invalid_number_of_nodes(self):
    """Edge indices pointing past the last node are rejected.

    The edge count (5) matches the number of edge-feature rows, so the only
    defect is the out-of-range node index (11 with only 10 nodes).
    """
    with pytest.raises(RuntimeError):
        GraphData(
            torch.randn(10, 5),
            torch.randn(5, 4),
            torch.randn(1, 3),
            torch.randint(11, 12, torch.Size([2, 5])),
        )
def test_invalid_global_shape(self):
    """A global tensor of shape (3,) instead of (1, 3) is rejected.

    Nodes, edge features and edge indices are all valid, so the global
    tensor is the only thing that can trigger the error.
    """
    with pytest.raises(RuntimeError):
        GraphData(
            torch.randn(10, 5),
            torch.randn(5, 4),
            torch.randn(3),
            torch.randint(0, 10, torch.Size([2, 5])),
        )
def sigmoid_circuit(cls, data_size, batch_size):
    """Build a loader of random directed trees labelled by a sigmoid circuit.

    Each sample is a random tree with 2-19 nodes whose edges are randomly
    oriented. Root nodes get target ``np.array(10.0)``. Every other node,
    visited in topological order, gets ``1 - sigmoid(sum of its
    predecessors' targets)``.

    Args:
        data_size: number of graphs to generate.
        batch_size: batch size of the returned loader.

    Returns:
        A GraphDataLoader over ``(input, target)`` GraphData pairs built from
        the "features" and "target" node keys respectively.
    """
    import math

    def func(x):
        # 1 - sigmoid(x)
        return 1 - 1.0 / (1 + math.exp(-x))

    input_data = []
    output_data = []
    for _ in range(data_size):
        n_size = np.random.randint(2, 20)
        tree = nx.random_tree(n_size)

        # randomize node directions
        g = nx.DiGraph()
        for n1, n2 in tree.edges():
            if np.random.randint(2) == 0:
                g.add_edge(n1, n2)
            else:
                g.add_edge(n2, n1)
        cls._default_g(g)

        for n in nx_utils.iter_roots(g):
            g.nodes[n]["target"] = np.array(10.0)

        for n in nx.topological_sort(g):
            ndata = g.nodes[n]
            if "target" not in ndata:
                # Root targets are 0-d arrays and computed ones are scalars;
                # np.concatenate rejects 0-d inputs, so promote to 1-d first.
                incoming = np.concatenate(
                    [np.atleast_1d(g.nodes[p]["target"]) for p in g.predecessors(n)]
                )
                # NOTE(review): non-root targets are Python floats while roots
                # are 0-d arrays; confirm from_networkx handles both.
                ndata["target"] = func(incoming.sum())
        input_data.append(GraphData.from_networkx(g, feature_key="features"))
        output_data.append(GraphData.from_networkx(g, feature_key="target"))
    return GraphDataLoader(list(zip(input_data, output_data)), batch_size=batch_size)
def test_mask_one_edges():
    """Masking out the middle edge removes exactly that edge and its features."""
    deterministic_seed(0)
    edges = torch.LongTensor([[0, 0, 0], [1, 2, 3]])
    expected_edges = torch.LongTensor([[0, 0], [1, 3]])
    e = torch.randn(3, 2)
    keep = torch.BoolTensor([True, False, True])
    # Boolean indexing selects the same rows as indexing with torch.where(keep).
    expected_e = e[keep]

    data = GraphData(torch.randn(5, 5), e, torch.randn(1, 1), edges=edges)
    result = data.apply_edge_mask(keep)

    assert torch.all(result.edges == expected_edges)
    assert torch.all(result.e == expected_e)
    # Node and global features are untouched by an edge mask.
    assert torch.all(result.g == data.g)
    assert torch.all(result.x == data.x)
def test_from_networkx_missing_node_data(self):
    """Nodes lacking the feature key produce zero-width node features."""
    edge_feature = np.array([10.0])
    global_feature = np.array([12.0])

    g = nx.DiGraph()
    g.add_node(1)
    # Node 2 is created implicitly by the edge, also without features.
    g.add_edge(1, 2, feature=edge_feature)
    g.set_global({"feature": global_feature})

    data = GraphData.from_networkx(g, feature_key="feature")
    assert data.x.shape == (2, 0)
    assert torch.all(data.e == torch.tensor([[10.0]]))
    assert torch.all(data.g == torch.tensor([[12.0]]))
def test_from_networkx_missing_glob_data(self):
    """A graph with no global attributes yields a (1, 0) global tensor."""
    g = nx.DiGraph()
    for node, value in ((1, 10.0), (2, 11.0)):
        g.add_node(node, feature=np.array([value]))
    g.add_edge(1, 2, feature=np.array([12.0]))

    data = GraphData.from_networkx(g, feature_key="feature")
    assert torch.all(data.x == torch.tensor([[10.0], [11.0]]))
    assert torch.all(data.e == torch.tensor([[12.0]]))
    assert data.g.shape == (1, 0)