def create_feed_dict(rand, batch_size, traj_idx_min_max_tr, static_graph,
                     trajectory, input_ph, target_ph):
    """Sample a batch of input/target graphs and key them by placeholder.

    Args:
      rand: Random source forwarded to `generate_graphs_dicts`.
      batch_size: Number of graphs to sample.
      traj_idx_min_max_tr: Trajectory index bounds forwarded to
        `generate_graphs_dicts`.
      static_graph: Forwarded to `generate_graphs_dicts`.
      trajectory: Forwarded to `generate_graphs_dicts`.
      input_ph: Key under which the input GraphsTuple is stored.
      target_ph: Key under which the target GraphsTuple is stored.

    Returns:
      A dict `{input_ph: <input GraphsTuple>, target_ph: <target GraphsTuple>}`.
    """
    sampled = generate_graphs_dicts(
        rand, batch_size, traj_idx_min_max_tr, static_graph, trajectory)
    inputs_dicts, targets_dicts = sampled
    # Inputs are converted before targets (dict literals evaluate in order).
    return {
        input_ph: utils_np.data_dicts_to_graphs_tuple(inputs_dicts),
        target_ph: utils_np.data_dicts_to_graphs_tuple(targets_dicts),
    }
def make_feed_dict(self, seqIn, patternIn, target=None):
    """Build a feed dict for the sequence, pattern and optional target graphs.

    Args:
      seqIn: List of graph data dicts fed to `self.seq_input_ph`.
      patternIn: List of graph data dicts fed to `self.pattern_input_ph`.
      target: Optional list of graph data dicts fed to `self.target_ph`;
        the target entries are omitted when this is `None`.

    Returns:
      The merged feed dict produced by `utils_tf.get_feed_dict`.
    """
    seq_graphs = utils_np.data_dicts_to_graphs_tuple(seqIn)
    pattern_graphs = utils_np.data_dicts_to_graphs_tuple(patternIn)
    feed_dict = utils_tf.get_feed_dict(self.seq_input_ph, seq_graphs)
    feed_dict.update(
        utils_tf.get_feed_dict(self.pattern_input_ph, pattern_graphs))
    # Identity test rather than `== None`, which can misbehave for
    # array-like values that overload `==`.
    if target is not None:
        target_graphs = utils_np.data_dicts_to_graphs_tuple(target)
        feed_dict.update(utils_tf.get_feed_dict(self.target_ph, target_graphs))
    return feed_dict
def create_feed_dict(generator, batch_size, input_ph, target_ph, is_trained=True):
    """Draw a batch from `generator` and key the resulting GraphsTuples.

    Args:
      generator: Callable `(batch_size, is_trained) -> (inputs, targets)`.
        The element type of `inputs` selects the converter: dicts are
        treated as graph data dicts, anything else as networkx graphs.
      batch_size: Number of graphs requested from `generator`.
      input_ph: Key for the input GraphsTuple.
      target_ph: Key for the target GraphsTuple.
      is_trained: Forwarded to `generator`.

    Returns:
      A dict `{input_ph: <input GraphsTuple>, target_ph: <target GraphsTuple>}`.
    """
    inputs, targets = generator(batch_size, is_trained)
    # Inspect the first input only; targets are assumed to share its type.
    to_graphs = (utils_np.data_dicts_to_graphs_tuple
                 if isinstance(inputs[0], dict)
                 else utils_np.networkxs_to_graphs_tuple)
    input_graphs = to_graphs(inputs)
    target_graphs = to_graphs(targets)
    return {input_ph: input_graphs, target_ph: target_graphs}
def generate_graph(file_path, batch_size, keep_features, existence_as_vector=True):
    """
    Endlessly yield batches of batch_size graphs read from the given file,
    restarting from the top of the file whenever its end is reached.
    :param file_path: The path to the graph json file (one JSON graph per line;
        blank lines are skipped)
    :param batch_size: The number of graph in each batch.
    :param keep_features: Whether to keep all features of the graph. It is advised to do so in case of
        input graphs.
    :param existence_as_vector: Whether to represent existence feature as a vector with the length of 2 in
        an individual feature vector. It should be False when processing the input graphs.
    :raises ValueError: If the file contains no non-blank lines (previously this
        looped forever).
    :yield: GraphTuples of the given size
    """

    def _iter_graph_dicts():
        """Yield processed graph dicts from the file forever, wrapping at EOF."""
        while True:
            found_graph = False
            # Each pass owns its own handle and `with` closes it, so wrapping
            # around no longer leaks an open file per pass.
            with open(file_path) as json_file:
                for raw_line in json_file:
                    line = raw_line.strip()
                    if not line:
                        # Blank lines carry no graph; skip them rather than
                        # treating them as end-of-file.
                        continue
                    json_dict = json.loads(line)
                    json_dict = process_line(json_dict, keep_features,
                                             existence_as_vector)
                    is_valid_graph(json_dict)
                    found_graph = True
                    yield json_dict
            if not found_graph:
                raise ValueError("No graphs found in {}".format(file_path))

    graph_source = _iter_graph_dicts()
    while True:
        # Pull exactly batch_size graphs; no line is read and discarded
        # between consecutive batches.
        graph_dicts = [next(graph_source) for _ in range(batch_size)]
        yield utils_np.data_dicts_to_graphs_tuple(graph_dicts)
def test_data_dicts_to_graphs_tuple_infer_n_node(self):
    """Node features may be omitted when `n_node` is supplied explicitly."""
    for data_dict in self.graphs_dicts_in:
        node_count = data_dict["nodes"].shape[0]
        data_dict["n_node"] = node_count
        data_dict["nodes"] = None
    graphs_tuple = utils_np.data_dicts_to_graphs_tuple(self.graphs_dicts_in)
    expected_n_node = [0, 1, 1, 1, 2, 2, 2]
    self.assertAllEqual(expected_n_node, graphs_tuple.n_node)
def test_data_dicts_to_graphs_tuple_from_lists(self):
    """Tests creating a GraphsTuple when the index fields are python lists."""
    for data_dict in self.graphs_dicts_in:
        for index_field in ("receivers", "senders"):
            data_dict[index_field] = data_dict[index_field].tolist()
    result = utils_np.data_dicts_to_graphs_tuple(self.graphs_dicts_in)
    self._assert_graph_equals_np(self.reference_graph, result)
def cons_graph(self, obs, poss, chs=None, act=None):
    """Build a one-graph GraphsTuple linking chosen agents to their neighbours.

    Args:
      obs: Per-agent observations. A 13x13x1 window of the chosen-agent map is
        concatenated onto the last axis (assumes `obs` is (n, 13, 13, C) —
        TODO confirm).
      poss: Per-agent positions, indexed as `poss[i][0]` and `poss[i][1]`.
      chs: Indices of the chosen agents. `None` is treated as "no agents
        chosen"; the previous code crashed on `len(None)` for this default.
      act: Optional per-agent actions; defaults to a list of zeros.

    Returns:
      The result of `utils_np.data_dicts_to_graphs_tuple` on the single data
      dict.
    """
    chs = [] if chs is None else list(chs)
    n_agents = len(obs)

    # Mark chosen agents on a padded map. With the +6 offset, the 13-wide
    # window starting at an agent's raw position is centred on that agent.
    # NOTE(review): the 52x52 size assumes positions fall in [0, 39] — confirm.
    mini = np.zeros((52, 52))
    for ch in chs:
        mini[int(poss[ch][0]) + 6, int(poss[ch][1]) + 6] = 1
    windows = []
    for pos in poss:
        row, col = int(pos[0]), int(pos[1])
        # Truncate before adding 13 on both axes so every window is 13x13
        # (the column end used to be int(pos[1] + 13)).
        windows.append(mini[row:row + 13, col:col + 13])
    windows = np.reshape(windows, (n_agents, 13, 13, 1))
    obs = np.concatenate((obs, windows), axis=-1)

    datadict = {
        "obs": obs,
        "senders": [], "receivers": [], "edges": [],
        "lsenders": [], "lreceivers": [], "ledges": [],
        "hedges": [], "hsenders": [], "hreceivers": [],
    }
    indexs = self.get_indexs(poss)
    datadict["nodes"] = [[0] * 256 for _ in range(n_agents)]
    datadict["lnodes"] = [[0] * 32 for _ in range(n_agents)]
    datadict["hnodes"] = [[0] * 64 for _ in range(n_agents)]
    datadict["globals"] = [0, 0, 0]
    datadict["q"] = [[0] * self.num_actions for _ in range(n_agents)]

    unique_chs = list(set(chs))
    for ch in unique_chs:
        # Connect `ch` with every agent whose entry in indexs[ch] is 0
        # (the meaning of 0 is defined by get_indexs — verify).
        for ind in np.where(indexs[ch] == 0)[0]:
            datadict["senders"].append(ind)
            datadict["receivers"].append(ch)
            datadict["lsenders"].append(ch)
            datadict["lreceivers"].append(ind)
            datadict["ledges"].append([0] * 9)
            datadict["edges"].append([0] * 9)
    # "h" edges fully connect distinct chosen agents.
    for ch in unique_chs:
        for c in unique_chs:
            if ch != c:
                datadict["hedges"].append([0] * 9)
                datadict["hsenders"].append(ch)
                datadict["hreceivers"].append(c)

    # Empty edge sets become (0, 9) arrays so they still carry a feature width.
    for key in ("edges", "hedges", "ledges"):
        if len(datadict[key]) == 0:
            datadict[key] = np.zeros(shape=(0, 9))

    datadict["act"] = act if act is not None else [0] * n_agents
    return utils_np.data_dicts_to_graphs_tuple([datadict])
def test_graphs_tuple_to_networkxs(self, none_fields):
    """Round-trip to networkx; cleared fields must come back as None."""

    def check_features(field, expected, actual):
        # A cleared field must be None; any other field must match the reference.
        if field in none_fields:
            self.assertEqual(None, actual)
        else:
            self.assertAllClose(expected, actual)

    if "nodes" in none_fields:
        # Node counts must be explicit once node features are dropped.
        for data_dict in self.graphs_dicts_in:
            data_dict["n_node"] = data_dict["nodes"].shape[0]
    graphs_tuple = utils_np.data_dicts_to_graphs_tuple(self.graphs_dicts_in)
    graphs_tuple = graphs_tuple.map(lambda _: None, none_fields)
    nx_graphs = utils_np.graphs_tuple_to_networkxs(graphs_tuple)

    for expected, nx_graph in zip(self.graphs_dicts_out, nx_graphs):
        check_features("globals", expected["globals"], nx_graph.graph["features"])

        node_items = nx_graph.nodes(data=True)
        for position, (expected_node, (node_id, attrs)) in enumerate(
                zip(expected["nodes"], node_items)):
            self.assertEqual(position, node_id)
            check_features("nodes", expected_node, attrs["features"])

        # Order edges by their stored "index" so they line up with the reference.
        edge_items = sorted(nx_graph.edges(data=True),
                            key=lambda item: item[2]["index"])
        for expected_edge, (_, _, attrs) in zip(expected["edges"], edge_items):
            check_features("edges", expected_edge, attrs["features"])
        for receiver, sender, (src, dst, _) in zip(
                expected["receivers"], expected["senders"], edge_items):
            self.assertEqual(sender, src)
            self.assertEqual(receiver, dst)
def test_none_throws_error(self, none_field):
    """Tests that an error is thrown if a GraphsTuple field is None."""
    single_graph = utils_np.data_dicts_to_graphs_tuple([self.graphs_dicts[1]])
    broken_graph = single_graph.replace(**{none_field: None})
    expected_pattern = "`{}` was `None`. All fields of the `G".format(none_field)
    with self.assertRaisesRegex(ValueError, expected_pattern):
        utils_tf.specs_from_graphs_tuple(broken_graph)
def graphs_tuple_loads(string_dump):
    """Deserialize a JSON list of graph data dicts into a GraphsTuple.

    Every value of every dict is wrapped in `np.array` before the dicts are
    passed to `utils_np.data_dicts_to_graphs_tuple`.
    """
    raw_dicts = json.loads(string_dump)
    array_dicts = [
        {key: np.array(value) for key, value in raw.items()}
        for raw in raw_dicts
    ]
    return utils_np.data_dicts_to_graphs_tuple(array_dicts)
def test_graphs_tuple_to_data_dicts(self, none_fields):
    """Cleared fields must be None after conversion; the rest must match."""
    graphs_tuple = utils_np.data_dicts_to_graphs_tuple(self.graphs_dicts_in)
    graphs_tuple = graphs_tuple.map(lambda _: None, none_fields)
    round_tripped = utils_np.graphs_tuple_to_data_dicts(graphs_tuple)
    for none_field, data_dict in itertools.product(none_fields, round_tripped):
        self.assertEqual(None, data_dict[none_field])
    for expected, actual in zip(self.graphs_dicts_out, round_tripped):
        for key, value in expected.items():
            if key in none_fields:
                continue
            self.assertAllClose(value, actual[key])
def make_feed_dict(val):
    """Build a feed dict for the enclosing `placeholders` from `val`.

    `val` is either a GraphsTuple (fed as-is) or an iterable of GraphsTuples,
    from each of which only the first graph is kept and re-batched.
    """
    if not isinstance(val, GraphsTuple):
        first_graphs = [
            utils_np.graphs_tuple_to_data_dicts(item)[0] for item in val
        ]
        val = utils_np.data_dicts_to_graphs_tuple(first_graphs)
    return utils_tf.get_feed_dict(placeholders, val)
def test_get_single_item(self):
    """`get_graph` with an int index returns exactly that one graph."""
    index = 2
    expected = self.graphs_dicts_out[index]
    batched = utils_np.data_dicts_to_graphs_tuple(self.graphs_dicts_in)
    (actual,) = utils_np.graphs_tuple_to_data_dicts(
        utils_np.get_graph(batched, index))
    for key, value in expected.items():
        self.assertAllClose(value, actual[key])
def test_get_many_items(self):
    """`get_graph` with a slice returns the matching run of graphs."""
    index = slice(1, 3)
    expected_dicts = self.graphs_dicts_out[index]
    batched = utils_np.data_dicts_to_graphs_tuple(self.graphs_dicts_in)
    sliced = utils_np.get_graph(batched, index)
    actual_dicts = utils_np.graphs_tuple_to_data_dicts(sliced)
    for expected, actual in zip(expected_dicts, actual_dicts):
        for key in expected:
            self.assertAllClose(expected[key], actual[key])
def graphTupleValEqual(self, feat, val0, val1):
    """Return True if two graph values are equal under the feature spec `feat`.

    Each value may be a GraphsTuple or a graph data dict. Both are normalized
    the same way into a single data dict (dicts are round-tripped through a
    GraphsTuple so their field types match).

    Args:
      feat: Spec whose `features[key]` is forwarded to `isEqualSampleVals`.
      val0: First value (GraphsTuple or data dict).
      val1: Second value (GraphsTuple or data dict).

    Returns:
      False on the first differing field, True otherwise. A None field counts
      as equal to a field whose leading dimension is 0.
    """

    def _as_data_dict(value):
        # Normalize one value; only the first graph of a GraphsTuple is used.
        if isinstance(value, graphs.GraphsTuple):
            return utils_np.graphs_tuple_to_data_dicts(value)[0]
        graphs_tuple = utils_np.data_dicts_to_graphs_tuple([value])
        return utils_np.graphs_tuple_to_data_dicts(graphs_tuple)[0]

    # Both sides use the same rule. Previously val0 was only recognized as a
    # GraphsTuple when val1 was a dict, so a GraphsTuple val0 compared against
    # another GraphsTuple was passed to data_dicts_to_graphs_tuple as a dict.
    val0 = _as_data_dict(val0)
    val1 = _as_data_dict(val1)
    for key in val0.keys():
        sub_feat = feat.features[key]
        if val0[key] is None or val1[key] is None:
            val0_empty = val0[key] is None or val0[key].shape[0] == 0
            val1_empty = val1[key] is None or val1[key].shape[0] == 0
            if val0_empty != val1_empty:
                return False
        elif not self.isEqualSampleVals(sub_feat, val0[key], val1[key]):
            return False
    return True
def test_networkxs_to_graphs_tuple(self):
    """Converting via networkx must reproduce the direct conversion."""
    expected = utils_np.data_dicts_to_graphs_tuple(self.graphs_dicts_in)
    nx_graphs = [_single_data_dict_to_networkx(data_dict)
                 for data_dict in self.graphs_dicts_in]
    # Shape/dtype hints are read from the last input graph (presumably all
    # fixture graphs share them — verify if the fixture changes).
    reference = self.graphs_dicts_in[-1]
    hints = dict(
        edge_shape_hint=reference["edges"].shape[1:],
        node_shape_hint=reference["nodes"].shape[1:],
        data_type_hint=reference["nodes"].dtype,
    )
    actual = utils_np.networkxs_to_graphs_tuple(nx_graphs, **hints)
    self._assert_graph_equals_np(expected, actual, force_edges_ordering=True)
def test_dynamic_batch_sizes(self, block_constructor):
    """Checks that all batch sizes are as expected through a GraphNetwork."""
    # Placeholders are unnecessary in TF2; the batch is fed as constant tensors.
    numpy_graph = utils_np.data_dicts_to_graphs_tuple(
        [SMALL_GRAPH_1, SMALL_GRAPH_2])
    input_graph = numpy_graph.map(tf.constant, fields=graphs.ALL_FIELDS)
    mlp_factory = functools.partial(snt.nets.MLP, output_sizes=[10])
    model = block_constructor(mlp_factory)
    actual = utils_tf.nest_to_numpy(model(input_graph))
    for field_name, value in input_graph._asdict().items():
        self.assertEqual(value.shape[0], getattr(actual, field_name).shape[0])
def test_dynamic_batch_sizes(self):
    """Checks that all batch sizes are as expected through a GraphNetwork."""
    template_graph = self._get_input_graph()
    # Placeholders with a masked leading dimension accept any batch size.
    placeholders = template_graph.map(_mask_leading_dimension, graphs.ALL_FIELDS)
    output = self._get_model()(placeholders)
    with self.test_session() as sess:
        sess.run(tf.global_variables_initializer())
        # Feed a freshly built two-graph batch through the placeholders.
        batch = utils_np.data_dicts_to_graphs_tuple(
            [SMALL_GRAPH_1, SMALL_GRAPH_2])
        result = sess.run(output, {placeholders: batch})
    for field_name, value in batch._asdict().items():
        self.assertEqual(value.shape[0], getattr(result, field_name).shape[0])
def test_dynamic_batch_sizes(self, block_constructor):
    """Checks that all batch sizes are as expected through a GraphNetwork."""
    import torch  # local import: only this test builds tensors from numpy

    model = block_constructor(functools.partial(MLP, output_sizes=[10]))
    other_input_graph = utils_np.data_dicts_to_graphs_tuple(
        [SMALL_GRAPH_1, SMALL_GRAPH_2])
    # Run the model on the same batch whose sizes are compared below. The
    # output must not come from an unrelated (e.g. unsqueezed) input graph.
    # NOTE(review): assumes the model accepts tensors created by
    # torch.as_tensor from the numpy fields (index fields stay int32) — confirm.
    tensor_graph = other_input_graph.map(torch.as_tensor, graphs.ALL_FIELDS)
    output = model(tensor_graph)
    for k, v in other_input_graph._asdict().items():
        self.assertEqual(v.shape[0], getattr(output, k).shape[0])
def test_data_dicts_to_graphs_tuple_cast_types(self):
    """Index and count fields are cast to int32; float64 features are kept."""
    for graph_dict in self.graphs_dicts_in:
        graph_dict["n_node"] = np.array(
            graph_dict["nodes"].shape[0], dtype=np.int64)
        graph_dict["receivers"] = graph_dict["receivers"].astype(np.int16)
        graph_dict["senders"] = graph_dict["senders"].astype(np.float64)
        graph_dict["nodes"] = graph_dict["nodes"].astype(np.float64)
        graph_dict["edges"] = graph_dict["edges"].astype(np.float64)
    out = utils_np.data_dicts_to_graphs_tuple(self.graphs_dicts_in)
    for key in ["n_node", "n_edge", "receivers", "senders"]:
        self.assertEqual(np.int32, getattr(out, key).dtype)
    # Compare against the numpy dtype, consistent with the int32 check above,
    # instead of relying on tf.DType equality with a numpy dtype.
    for key in ["nodes", "edges"]:
        self.assertEqual(np.float64, getattr(out, key).dtype)
def convert_to_graph_data(features, topology):
    """Convert per-circuit features and topology into a batched GraphsTuple.

    Circuits are read as `circuit_1`, `circuit_2`, ... (one per key of
    `features`). For each circuit:
      * every `(component, receiver)` entry in `topology[circuit][sender]`
        becomes an edge with features `[component_id, component_value]`,
        where the component type is the first character of `component`;
      * node `i`, for `i` in `range(len(topology[circuit]))`, gets its
        `poles_re` values followed by its `zeros_re` values, zero-padded to
        length 7; nodes absent from `features[circuit]` get all zeros.

    Returns:
      A GraphsTuple whose non-None fields are `tf.constant` tensors.

    Raises:
      ValueError: If a node has more than 7 pole/zero values (previously this
        produced a ragged node array).
    """
    comp_to_id = {'R': 1, 'L': 2, 'C': 3, 'V': 4}
    comp_to_value = {'R': 50, 'L': 1e-3, 'C': 1e-6, 'V': 0}
    node_feature_size = 7
    graphs_list = []
    for circuit in range(len(features.keys())):
        circuit_name = "circuit_{}".format(circuit + 1)
        circuit_features = features[circuit_name]
        circuit_topology = topology[circuit_name]

        edges, senders, receivers = [], [], []
        for sender, connections in circuit_topology.items():
            for connection in connections:
                component_type = connection[0][0]
                senders.append(float(sender))
                receivers.append(float(connection[1]))
                edges.append([float(comp_to_id[component_type]),
                              float(comp_to_value[component_type])])

        nodes = []
        for i in range(len(circuit_topology.keys())):
            node_name = 'node_{}'.format(i)
            if node_name in circuit_features:
                node_data = circuit_features[node_name]
                feat = list(node_data['poles_re']) + list(node_data['zeros_re'])
            else:
                feat = []
            if len(feat) > node_feature_size:
                raise ValueError(
                    "{} of {} has {} pole/zero values; at most {} are "
                    "supported".format(node_name, circuit_name, len(feat),
                                       node_feature_size))
            nodes.append(feat + [0.0] * (node_feature_size - len(feat)))

        graph = {
            "nodes": nodes,
            "edges": edges,
            "senders": senders,
            "receivers": receivers,
            "globals": [0.0, 0.0, 0.0]
        }
        graphs_list.append(graph)
    graphs_tuple = utils_np.data_dicts_to_graphs_tuple(graphs_list)
    graphs_tuple = tree.map_structure(
        lambda x: tf.constant(x) if x is not None else None, graphs_tuple)
    return graphs_tuple
def visualize_original_graph(graph_dict, file_name, use_edges=True):
    """
    Creates a visualization of the given graph
    :param graph_dict: An instance of a graph dictionary
    :param file_name: The path to save the image; a Graphviz ``.dot`` file with the same
        stem is written next to it
    :param use_edges: This parameter is not used, only exist so it has the same format as the
        visualize_graph function
    """
    graphs_nx = utils_np.graphs_tuple_to_networkxs(
        utils_np.data_dicts_to_graphs_tuple([graph_dict]))
    # A fresh figure per call: plt.figure(1) returns the existing figure 1, so
    # repeated calls used to draw on top of earlier graphs.
    fig = plt.figure(figsize=(25, 25))
    # Label each node with its first two feature values.
    # NOTE(review): nodes sharing the same label are merged by relabel_nodes.
    mapping0 = {i: "; ".join([str(n[0]), str(n[1])])
                for i, n in enumerate(graph_dict["nodes"])}
    graph_nx = nx.relabel_nodes(graphs_nx[0], mapping0)
    nx.draw_networkx(graph_nx, node_color='red')
    plt.savefig(file_name)
    nx.drawing.nx_pydot.write_dot(
        graph_nx, "{}.dot".format(os.path.splitext(file_name)[0]))
    plt.show()
    # Release the figure so repeated calls do not accumulate open figures.
    plt.close(fig)
def gts_to_graph_tuple(graph_gts, vps, eps, gps, data_type_hint=np.float32):
    """Convert an iterable of graph objects into a single GraphsTuple.

    The inputs are presumably graph-tool graphs (see `gt_to_data_dict`).

    Args:
      graph_gts: Iterable of graphs to convert.
      vps: Passed to `_get_features(..., "nodes")` for each graph.
      eps: Passed to `_get_features(..., "edges")` for each graph.
      gps: Passed to `_get_features(..., "global")` for each graph.
      data_type_hint: Currently unused; kept for interface compatibility.

    Returns:
      The GraphsTuple built by `data_dicts_to_graphs_tuple`.

    Raises:
      ValueError: If converting an element raises a TypeError (the original
        TypeError is chained as the cause).
    """
    data_dicts = []
    try:
        for graph_gt in graph_gts:
            features_v = _get_features(graph_gt, vps, "nodes")
            features_e = _get_features(graph_gt, eps, "edges")
            features_g = _get_features(graph_gt, gps, "global")
            data_dict = gt_to_data_dict(graph_gt, features_v, features_e,
                                        features_g)
            data_dicts.append(data_dict)
    except TypeError as err:
        raise ValueError("Could not convert some elements of `graph_gts`. "
                         "Did you pass an iterable of graph-tool Graph "
                         "instances?") from err
    return data_dicts_to_graphs_tuple(data_dicts)
def test_data_dicts_to_graphs_tuple(self, none_fields):
    """Fields in `none_fields` are cleared before conversion and stay None."""
    reference = self.reference_graph
    for field in none_fields:
        for graph_dict in self.graphs_dicts_in:
            if field not in graph_dict:
                continue
            if field == "nodes":
                # Node counts must be explicit once node features are removed.
                graph_dict["n_node"] = graph_dict["nodes"].shape[0]
            graph_dict[field] = None
        reference = reference._replace(**{field: None})
        if field == "senders":
            # With senders cleared, the expected n_edge is all zeros.
            reference = reference._replace(n_edge=np.zeros_like(reference.n_edge))
    self.reference_graph = reference

    converted = utils_np.data_dicts_to_graphs_tuple(self.graphs_dicts_in)
    for field in none_fields:
        self.assertEqual(None, getattr(converted, field))
    self._assert_graph_equals_np(self.reference_graph, converted)
def snap2graph(h5file, day, tg, use_tf=False, placeholder=False, name=None,
               normalize=True):
    """Load one snapshot from an HDF5 file as a single-graph GraphsTuple.

    Args:
      h5file: Open HDF5 file (or mapping) holding the feature datasets plus
        `senders` and `receivers`.
      day: Day component of the snapshot key `day{day}tg{tg}`.
      tg: Second component of the snapshot key.
      use_tf: If False, return a numpy GraphsTuple via `utils_np`; otherwise
        build a TF object via `utils_tf`.
      placeholder: With `use_tf`, return placeholders instead of constants.
      name: Optional TF name; defaults depend on `placeholder`.
      normalize: Read the `nn_` (normalized) node/global datasets if True.

    Returns:
      A GraphsTuple (or TF placeholders) for the snapshot.
    """
    snapstr = 'day' + str(day) + 'tg' + str(tg)
    # NOTE(review): both branches read edges from 'nn_edge_features/'; the
    # un-normalized branch may have been meant to use 'edge_features/' — confirm.
    edges = h5file['nn_edge_features/' + snapstr]
    if normalize:
        nodes = h5file['nn_node_features/' + snapstr]
        glbls = h5file['nn_glbl_features/' + snapstr]
    else:
        nodes = h5file['node_features/' + snapstr]
        glbls = h5file['glbl_features/' + snapstr]
    senders = h5file['senders']
    receivers = h5file['receivers']
    node_arr = nodes[:]
    edge_arr = edges[:]
    glbl_arr = glbls[0]
    # `np.float` was an alias of the builtin float (float64) and was removed
    # in NumPy 1.24; np.float64 yields the same dtype.
    graphdat_dict = {
        "globals": glbl_arr.astype(np.float64),
        "nodes": node_arr.astype(np.float64),
        "edges": edge_arr.astype(np.float64),
        "senders": senders[:],
        "receivers": receivers[:],
        "n_node": node_arr.shape[0],
        "n_edge": edge_arr.shape[0]
    }
    if not use_tf:
        graphs_tuple = utils_np.data_dicts_to_graphs_tuple([graphdat_dict])
    elif placeholder:
        name = "placeholders_from_data_dicts" if not name else name
        graphs_tuple = utils_tf.placeholders_from_data_dicts(
            [graphdat_dict], name=name)
    else:
        name = "tuple_from_data_dicts" if not name else name
        graphs_tuple = utils_tf.data_dicts_to_graphs_tuple(
            [graphdat_dict], name=name)
    return graphs_tuple
def cons_allcomm_graph(self, obs, poss, chs=None, act=None):
    """Build a one-graph GraphsTuple in which every agent links to every agent.

    Args:
      obs: Per-agent observations stored unchanged under "obs"; `len(obs)`
        is the agent count.
      poss: Agent positions, passed to `self.get_indexs` (result unused here).
      chs: Accepted but not used by this builder.
      act: Optional per-agent actions; defaults to a list of zeros.

    Returns:
      The result of `utils_np.data_dicts_to_graphs_tuple` on the single data
      dict.
    """
    # The return value is not needed for the all-to-all topology.
    self.get_indexs(poss)
    n_agents = len(obs)
    # Every ordered (i, j) pair, self-loops included.
    pairs = [(i, j) for i in range(n_agents) for j in range(n_agents)]
    datadict = {
        "obs": obs,
        "senders": [i for i, _ in pairs],
        "receivers": [j for _, j in pairs],
        "edges": [[0, 0, 0] for _ in pairs],
        "lsenders": [j for _, j in pairs],
        "lreceivers": [i for i, _ in pairs],
        "ledges": [[0, 0, 0] for _ in pairs],
        # A single self-edge on node 0 for the "h" edge set.
        "hedges": [[0, 0, 0]],
        "hsenders": [0],
        "hreceivers": [0],
        "nodes": [[0] * 256 for _ in range(n_agents)],
        "lnodes": [[0] * 32 for _ in range(n_agents)],
        "hnodes": [[0] * 64 for _ in range(n_agents)],
        "globals": [0, 0, 0],
        "q": [[0] * self.num_actions for _ in range(n_agents)],
    }
    # Empty edge sets become (0, 3) arrays so they still carry a feature width.
    for key in ("edges", "hedges", "ledges"):
        if not datadict[key]:
            datadict[key] = np.zeros(shape=(0, 3))
    datadict["act"] = [0] * n_agents if act is None else act
    return utils_np.data_dicts_to_graphs_tuple([datadict])
def apply_policy(self, session, state, actions):
    """Get a stochastic policy output derived from node-level outputs.

    Applies softmax to node-level outputs representing actions, or a uniform
    distribution if values are out of bounds. Returns the probability
    distribution.

    Args:
      session: Session used to evaluate `self.policy`.
      state: Current state; `state.n_objects` is fed to `self.n_objects` and
        used to skip the leading object entries of the output.
      actions: Legal actions; their count sizes the uniform fallback.

    Returns:
      The policy entries after the first `state.n_objects`, or a uniform list
      over `actions` when those entries do not sum to roughly 1 (including
      when the sum is NaN).
    """
    batched_graphs, _, _ = self.model.prepare_data([(state, actions, None)], 0, 1)
    batched_tuples = utils_np.data_dicts_to_graphs_tuple(batched_graphs)
    batch_dict = {
        get_tensors(self.input_graphs): get_tensors(batched_tuples),
        self.n_objects: [state.n_objects]
    }
    values = session.run(
        {
            "policy": self.policy,
            "logits": self.output_graphs.nodes
        },
        feed_dict=batch_dict)
    nodes = values["policy"].reshape((-1, ))
    distribution = nodes[state.n_objects:]
    total = sum(distribution)
    # Negated range test: a NaN total fails every comparison, so it now also
    # triggers the uniform fallback instead of leaking NaN probabilities.
    if not 0.99 <= total <= 1.01:
        distribution = [1.0 / len(actions)] * len(actions)
    return distribution
def compute_q_batch(self, session, states, actions):
    """Compute q-values for a batch of states and matching lists of legal actions.

    Returns:
      A list with one entry per state: the slice of node outputs belonging to
      that state's trailing `len(actions[i])` nodes (its action nodes).
    """
    state_graphs = [self.model.state_graph(state, actions[i])
                    for i, state in enumerate(states)]
    state_tuples = utils_np.data_dicts_to_graphs_tuple(state_graphs)
    feed = {get_tensors(self.input_graphs): get_tensors(state_tuples)}
    flat_outputs = session.run(self.output_graphs.nodes,
                               feed_dict=feed).reshape((-1, ))

    # Node outputs of all states are concatenated; walk them state by state
    # and keep only each state's trailing action entries.
    q_vectors = []
    offset = 0
    for graph, legal_actions in zip(state_graphs, actions):
        n_nodes = graph['n_node']
        first_action = offset + n_nodes - len(legal_actions)
        q_vectors.append(flat_outputs[first_action:offset + n_nodes])
        offset += n_nodes
    return q_vectors
def visualize_graph(graph_dict, file_name, use_edges=True):
    """
    Creates a visualization of the given graph using only its 1 valued nodes and edges
    :param graph_dict: An instance of a graph dictionary; the last feature of each node and
        edge is treated as its existence value
    :param file_name: The path to save the image; a Graphviz ``.dot`` file with the same
        stem is written next to it
    :param use_edges: Whether to take into account the value of the edges, or just the value
        of the nodes
    """
    graph = {"edges": [], "senders": [], "receivers": [], "nodes": [],
             "globals": [1.0]}
    # Keep nodes whose existence value is >= 0.5 (dropping that value) and
    # remember where each kept node lands in the filtered list.
    new_index = {}
    for old_index, node in enumerate(graph_dict["nodes"]):
        if node[-1] >= 0.5:
            new_index[old_index] = len(graph["nodes"])
            graph["nodes"].append(node[:-1])
    for edge, sender, receiver in zip(graph_dict["edges"],
                                      graph_dict["senders"],
                                      graph_dict["receivers"]):
        if (edge[-1] == 1. or not use_edges) \
                and sender in new_index and receiver in new_index:
            graph["edges"].append(edge[:-1])
            # Direct index lookup instead of list.index, which was O(n) per
            # edge and resolved duplicate feature vectors to the first match.
            graph["senders"].append(new_index[sender])
            graph["receivers"].append(new_index[receiver])

    graphs_nx = utils_np.graphs_tuple_to_networkxs(
        utils_np.data_dicts_to_graphs_tuple([graph]))
    # A fresh figure per call: plt.figure(1) would reuse and overdraw figure 1.
    fig = plt.figure(figsize=(25, 25))
    # Labels use each kept node's first two original features, keyed by its
    # index in the filtered graph (not its index in graph_dict["nodes"]).
    mapping0 = {
        new: "; ".join([str(graph_dict["nodes"][old][0]),
                        str(graph_dict["nodes"][old][1])])
        for old, new in new_index.items()
    }
    graph_nx = nx.relabel_nodes(graphs_nx[0], mapping0)
    nx.draw_networkx(graph_nx, node_color='blue')
    plt.savefig(file_name)
    nx.drawing.nx_pydot.write_dot(
        graph_nx, "{}.dot".format(os.path.splitext(file_name)[0]))
    plt.show()
    # Release the figure so repeated calls do not accumulate open figures.
    plt.close(fig)
def create_no_nodefeatures_graphs(features, topology):
    """Convert circuit topologies into a GraphsTuple with all-zero node features.

    Circuits are read as `circuit_1`, `circuit_2`, ...; `features` only
    determines how many circuits are converted. For each circuit, every
    `(component, receiver)` entry in `topology[circuit][sender]` becomes an
    edge with features `[component_id, component_value]`, where the component
    type is the first character of `component`. Each of the
    `len(topology[circuit])` nodes gets a feature vector of 7 zeros.

    Returns:
      A GraphsTuple whose non-None fields are `tf.constant` tensors.
    """
    comp_to_id = {'R': 1, 'L': 2, 'C': 3, 'V': 4}
    comp_to_value = {'R': 50, 'L': 1e-3, 'C': 1e-6, 'V': 0}
    node_feature_size = 7
    graphs_list = []
    for circuit in range(len(features.keys())):
        circuit_topology = topology["circuit_{}".format(circuit + 1)]

        edges, senders, receivers = [], [], []
        for sender, connections in circuit_topology.items():
            for connection in connections:
                component_type = connection[0][0]
                senders.append(float(sender))
                receivers.append(float(connection[1]))
                edges.append([float(comp_to_id[component_type]),
                              float(comp_to_value[component_type])])

        # One all-zero feature vector per topology entry.
        nodes = [[0.0] * node_feature_size for _ in range(len(circuit_topology))]

        graph = {
            "nodes": nodes,
            "edges": edges,
            "senders": senders,
            "receivers": receivers,
            "globals": [0.0, 0.0, 0.0]
        }
        graphs_list.append(graph)
    graphs_tuple = utils_np.data_dicts_to_graphs_tuple(graphs_list)
    graphs_tuple = tree.map_structure(
        lambda x: tf.constant(x) if x is not None else None, graphs_tuple)
    return graphs_tuple