def create_data(batch_size, num_elements_min_max):
    """Returns graphs containing the inputs and targets for classification.

    Refer to create_graph_dicts_tf and create_linked_list_target for more
    details.

    Args:
      batch_size: batch size for the `input_graphs`.
      num_elements_min_max: a 2-`tuple` of `int`s which define the
        [lower, upper) range of the number of elements per list.

    Returns:
      inputs: a `graphs.GraphsTuple` which contains the input list as a graph.
      targets: a `graphs.GraphsTuple` which contains the target as a graph.
      sort_indices: a `graphs.GraphsTuple` which contains the sort indices of
        the list elements as a graph.
      ranks: a `graphs.GraphsTuple` which contains the ranks of the list
        elements as a graph.
    """
    inputs, sort_indices, ranks = create_graph_dicts_tf(
        batch_size, num_elements_min_max)

    inputs = utils_tf.data_dicts_to_graphs_tuple(inputs)
    sort_indices = utils_tf.data_dicts_to_graphs_tuple(sort_indices)
    ranks = utils_tf.data_dicts_to_graphs_tuple(ranks)

    inputs = utils_tf.fully_connect_graph_dynamic(inputs)
    sort_indices = utils_tf.fully_connect_graph_dynamic(sort_indices)
    ranks = utils_tf.fully_connect_graph_dynamic(ranks)

    targets = create_linked_list_target(batch_size, sort_indices)
    nodes = tf.concat((targets.nodes, 1.0 - targets.nodes), axis=1)
    edges = tf.concat((targets.edges, 1.0 - targets.edges), axis=1)
    targets = targets._replace(nodes=nodes, edges=edges)

    return inputs, targets, sort_indices, ranks
    # input node[7,1] edge; target edge[49,2] node[7,2]
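A minimal usage sketch of the function above; the batch size and element range are illustrative values, not taken from the original source:

# Hypothetical call: a batch of 32 lists, each holding between 4 and 9 elements.
inputs, targets, sort_indices, ranks = create_data(
    batch_size=32, num_elements_min_max=(4, 10))
# All four returned values are `graphs.GraphsTuple`s; `inputs` and `targets`
# can be fed to a graph_nets model and to the loss, respectively.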
def make_graph(self, event, debug=False):
    """Convert the event into a graphs_tuple."""
    edge_name = self.edge_name
    n_nodes = event['x'].shape[0]
    n_edges = event[edge_name].shape[1]
    nodes = event['x']
    edges = np.zeros((n_edges, 1), dtype=np.float32)
    senders = event[edge_name][0, :]
    receivers = event[edge_name][1, :]
    edge_target = event[self.truth_name].numpy().astype(np.float32)

    input_datadict = {
        "n_node": n_nodes,
        "n_edge": n_edges,
        "nodes": nodes,
        "edges": edges,
        "senders": senders,
        "receivers": receivers,
        "globals": np.array([n_nodes], dtype=np.float32)
    }
    n_edges_target = 1
    target_datadict = {
        "n_node": 1,
        "n_edge": n_edges_target,
        "nodes": np.zeros((1, 1), dtype=np.float32),
        "edges": edge_target,
        "senders": np.zeros((n_edges_target,), dtype=np.int32),
        "receivers": np.zeros((n_edges_target,), dtype=np.int32),
        "globals": np.zeros((1,), dtype=np.float32),
    }
    input_graph = utils_tf.data_dicts_to_graphs_tuple([input_datadict])
    target_graph = utils_tf.data_dicts_to_graphs_tuple([target_datadict])
    return [(input_graph, target_graph)]
def make_graph(self, event, debug):
    inputs_tr, _ = event
    # apply the GNN model and filter out the edges with a score less than
    # the threshold 0.5.
    outputs_tr = self.model(inputs_tr, self.num_mp)
    output_graph = outputs_tr[-1]

    # calculate similar variables for GNN-based reconstruction
    # method-one, place a threshold on edge score
    edge_predict = np.squeeze(output_graph.edges.numpy())
    edge_passed = edge_predict > self.edge_cut

    nodes_sel = np.unique(
        np.concatenate([output_graph.receivers.numpy()[edge_passed],
                        output_graph.senders.numpy()[edge_passed]], axis=0))
    n_nodes = nodes_sel.shape[0]
    n_edges = sum(edge_passed)
    nodes = inputs_tr.nodes.numpy()[nodes_sel]
    edges = inputs_tr.edges.numpy()[edge_passed]

    node_dicts = {}
    for idx, val in enumerate(nodes_sel):
        node_dicts[val] = idx
    senders = np.array(
        [node_dicts[x] for x in inputs_tr.senders.numpy()[edge_passed]])
    receivers = np.array(
        [node_dicts[x] for x in inputs_tr.receivers.numpy()[edge_passed]])

    # print("n-nodes:", n_nodes)
    # print("n-edges:", n_edges)
    # print("nodes:", nodes.shape)
    # print("edges:", edges.shape)
    # print("senders:", senders.shape)
    # print("receivers:", receivers.shape)
    # print(senders)
    # print(receivers)

    input_datadict = {
        "n_node": n_nodes,
        "n_edge": n_edges,
        "nodes": nodes,
        "edges": edges,
        "senders": senders,
        "receivers": receivers,
        "globals": np.array([n_nodes], dtype=np.float32)
    }
    target_datadict = {
        "n_node": n_nodes,
        "n_edge": n_edges,
        "nodes": nodes,
        "edges": edges,
        "senders": senders,
        "receivers": receivers,
        "globals": np.array([float(self.is_signal)], dtype=np.float32)
    }
    input_graph = utils_tf.data_dicts_to_graphs_tuple([input_datadict])
    target_graph = utils_tf.data_dicts_to_graphs_tuple([target_datadict])
    return [(input_graph, target_graph)]
def make_subgraph(mask):
    hit_id = hits[mask].hit_id.values
    sub_doublets = segments[segments.hit_id_in.isin(hit_id)
                            & segments.hit_id_out.isin(hit_id)]

    # TODO: include all edges, uncomment following lines. <>
    # sub_doublets = segments[segments.hit_id_in.isin(hit_id)]
    # # extend the hits to include the hits used in the sub-doublets.
    # hit_id = hits[mask | hits.hit_id.isin(sub_doublets.hit_id_out.values)].hit_id.values

    n_nodes = hit_id.shape[0]
    n_edges = sub_doublets.shape[0]
    nodes = hits[mask][node_features].values.astype(f_dtype)
    if edge_features is None:
        edges = np.expand_dims(np.array([0.0] * n_edges, dtype=np.float32), axis=1)
    else:
        edges = sub_doublets[edge_features].values.astype(f_dtype)
    # print(nodes.dtype)

    hits_id_dict = dict([(hit_id[idx], idx) for idx in range(n_nodes)])
    in_hit_ids = sub_doublets.hit_id_in.values
    out_hit_ids = sub_doublets.hit_id_out.values
    senders = [hits_id_dict[in_hit_ids[idx]] for idx in range(n_edges)]
    receivers = [hits_id_dict[out_hit_ids[idx]] for idx in range(n_edges)]
    if verbose:
        print("\t{} nodes and {} edges".format(n_nodes, n_edges))

    senders = np.array(senders)
    receivers = np.array(receivers)

    input_datadict = {
        "n_node": n_nodes,
        "n_edge": n_edges,
        "nodes": nodes,
        "edges": edges,
        "senders": senders,
        "receivers": receivers,
        "globals": zeros
    }
    target_datadict = {
        "n_node": n_nodes,
        "n_edge": n_edges,
        "nodes": np.zeros((n_nodes, n_node_features), dtype=f_dtype),
        "edges": np.expand_dims(sub_doublets.solution.values.astype(f_dtype), axis=1),
        "senders": senders,
        "receivers": receivers,
        "globals": zeros
    }
    input_graph = utils_tf.data_dicts_to_graphs_tuple([input_datadict])
    target_graph = utils_tf.data_dicts_to_graphs_tuple([target_datadict])
    if with_pad:
        input_graph = padding(input_graph, N_MAX_NODES, N_MAX_EDGES)
        target_graph = padding(target_graph, N_MAX_NODES, N_MAX_EDGES)
    return [(input_graph, target_graph)]
def make_graph(event, debug=False):
    # n_max_nodes = 60
    n_nodes = len(event.jet_pt)
    nodes = np.hstack(
        (event.jet_pt, event.jet_eta, event.jet_phi, event.jet_e))
    nodes = nodes.reshape((n_nodes, 4))
    if debug:
        print(np.array(event.jet_pt).shape)
        print(n_nodes)
        print(nodes)

    # edges
    all_edges = list(itertools.combinations(range(n_nodes), 2))
    senders = np.array([x[0] for x in all_edges])
    receivers = np.array([x[1] for x in all_edges])
    n_edges = len(all_edges)
    edges = np.expand_dims(np.array([0.0] * n_edges, dtype=np.float32), axis=1)

    true_edges = set(list(itertools.combinations(event.reco_triplet_0, 2))
                     + list(itertools.combinations(event.reco_triplet_1, 2)))
    truth_labels = [int(x in true_edges) for x in all_edges]
    if debug:
        print(all_edges)
        print(event.reco_triplet_0)
        print(event.reco_triplet_1)
        print(truth_labels)

    truth_labels = np.array(truth_labels, dtype=np.float32)
    zeros = np.array([0.0], dtype=np.float32)
    input_datadict = {
        "n_node": n_nodes,
        "n_edge": n_edges,
        "nodes": nodes,
        "edges": edges,
        "senders": senders,
        "receivers": receivers,
        "globals": np.array([n_nodes], dtype=np.float32)
    }
    target_datadict = {
        "n_node": n_nodes,
        "n_edge": n_edges,
        "nodes": zeros,
        "edges": truth_labels,
        "senders": senders,
        "receivers": receivers,
        "globals": zeros
    }
    input_graph = utils_tf.data_dicts_to_graphs_tuple([input_datadict])
    target_graph = utils_tf.data_dicts_to_graphs_tuple([target_datadict])
    return [(input_graph, target_graph)]
def make_graph(event, debug: Optional[bool] = False):
    n_max_nodes = 200
    n_nodes = 0
    nodes = []
    for inode in range(n_max_nodes):
        E_name = 'E_{}'.format(inode)
        if event[E_name] < 0.1:
            continue
        f_keynames = ['{}_{}'.format(x, inode) for x in features]
        n_nodes += 1
        nodes.append(event[f_keynames].values * scale)

    nodes = np.array(nodes, dtype=np.float32)
    # print(n_nodes, "nodes")
    # print("node features:", nodes.shape)

    # edges: 1) fully connected, 2) objects nearby in eta/phi are connected
    # TODO: implement 2). <xju>
    all_edges = list(itertools.combinations(range(n_nodes), 2))
    senders = np.array([x[0] for x in all_edges])
    receivers = np.array([x[1] for x in all_edges])
    n_edges = len(all_edges)
    edges = np.expand_dims(np.array([0.0] * n_edges, dtype=np.float32), axis=1)
    # print(n_edges, "edges")
    # print("senders:", senders)
    # print("receivers:", receivers)

    input_datadict = {
        "n_node": n_nodes,
        "n_edge": n_edges,
        "nodes": nodes,
        "edges": edges,
        "senders": senders,
        "receivers": receivers,
        "globals": np.array([n_nodes], dtype=np.float32)
    }
    target_datadict = {
        "n_node": n_nodes,
        "n_edge": n_edges,
        "nodes": nodes,
        "edges": edges,
        "senders": senders,
        "receivers": receivers,
        "globals": np.array([event[solution]], dtype=np.float32)
    }
    input_graph = utils_tf.data_dicts_to_graphs_tuple([input_datadict])
    target_graph = utils_tf.data_dicts_to_graphs_tuple([target_datadict])
    return [(input_graph, target_graph)]
def test_unpack_graphs_tuple(self):
    graph_0 = {
        graphs.N_NODE: 4,
        graphs.N_EDGE: 3,
        graphs.NODES: [[0, 0.1], [1, 1.1], [2, 2.1], [3, 3.1]],
        graphs.EDGES: [[1, 1.2], [2, 2.2], [3, 3.2]],
        graphs.SENDERS: [0, 0, 0],
        graphs.RECEIVERS: [1, 2, 3],
    }
    graph_1 = {
        graphs.N_NODE: 3,
        graphs.N_EDGE: 2,
        graphs.NODES: [[0, 0.1], [1, 1.1], [2, 2.1]],
        graphs.EDGES: [[1, 1.2], [2, 2.2]],
        graphs.SENDERS: [0, 1],
        graphs.RECEIVERS: [1, 2],
    }
    graphs_tuple = utils_tf.data_dicts_to_graphs_tuple([graph_0, graph_1])
    nodes, edges = graph_networks.GraphNet._unpack_graphs_tuple(
        graphs_tuple=graphs_tuple, max_n_node=5, max_n_edge=6)
    self.assertAllClose(
        nodes,
        [[[0, 0.1], [1, 1.1], [2, 2.1], [3, 3.1], [0, 0]],
         [[0, 0.1], [1, 1.1], [2, 2.1], [0, 0], [0, 0]]])
    self.assertAllClose(
        edges,
        [[[1, 1.2], [2, 2.2], [3, 3.2], [0, 0], [0, 0], [0, 0]],
         [[1, 1.2], [2, 2.2], [0, 0], [0, 0], [0, 0], [0, 0]]])
def test_output_values(self, broadcaster, expected):
    """Test the broadcasted output value."""
    input_graph = utils_tf.data_dicts_to_graphs_tuple(
        [SMALL_GRAPH_1, SMALL_GRAPH_2])
    broadcasted_out = broadcaster(input_graph)
    self.assertNDArrayNear(
        np.array(expected, dtype=np.float32), broadcasted_out, err=1e-4)
def test_field_must_not_be_none(self, none_fields):
    """Tests that the model cannot be built if required fields are missing."""
    input_graph = utils_tf.data_dicts_to_graphs_tuple([SMALL_GRAPH_1])
    input_graph = input_graph.map(lambda _: None, none_fields)
    relation_network = self._get_model()
    with self.assertRaises(ValueError):
        relation_network(input_graph)
def test_field_must_not_be_none(self, none_field):
    """Tests that the model cannot be built if required fields are missing."""
    input_graph = utils_tf.data_dicts_to_graphs_tuple([SMALL_GRAPH_1])
    input_graph = input_graph.replace(**{none_field: None})
    deep_sets = self._get_model()
    with self.assertRaises(ValueError):
        deep_sets(input_graph)
def _get_input_graph(self, none_fields=None):
    if none_fields is None:
        none_fields = []
    input_graph = utils_tf.data_dicts_to_graphs_tuple(
        [SMALL_GRAPH_1, SMALL_GRAPH_2, SMALL_GRAPH_3, SMALL_GRAPH_4])
    input_graph = input_graph.map(lambda _: None, none_fields)
    return input_graph
def test_gns_dataset(self):
    if sys.platform != "win32":
        self.assertTrue(True)
        return

    experience_dir = os.path.join("../experience", "data")
    case_name = "rte_case5_example"
    agent_name = "agent-mip"
    env_dc = True
    case, collector = load_experience(case_name, agent_name, experience_dir,
                                      env_dc=env_dc)
    obses, actions, rewards, dones = collector.aggregate_data()

    n_batch = 16
    max_length = 10 * n_batch + 1
    n_window = 2

    graphs_dict_list = obses_to_lgraphs(obses, dones, case,
                                        max_length=max_length,
                                        n_window=n_window)
    cgraphs = lgraphs_to_cgraphs(graphs_dict_list)
    labels = is_do_nothing_action(actions, case.env)

    graph_dims = get_graph_feature_dimensions(cgraphs=cgraphs)
    graph_dataset = tf_batched_graph_dataset(cgraphs, n_batch=n_batch,
                                             **graph_dims)
    label_dataset = tf.data.Dataset.from_tensor_slices(labels).batch(n_batch)
    dataset = tf.data.Dataset.zip((graph_dataset, label_dataset))
    dataset = dataset.repeat(1)

    for batch_idx, (graph_batch, label_batch) in enumerate(dataset):
        graph_batch_from_list = utils_tf.data_dicts_to_graphs_tuple(
            graphs_dict_list[(n_batch * batch_idx):(n_batch * (batch_idx + 1))])

        check = tf.squeeze(
            equal_graphs(graph_batch, graph_batch_from_list)).numpy()
        pprint("Batch:", batch_idx, check)
        if not check:
            for field in ["globals", "nodes", "edges"]:
                print_matrix(getattr(graph_batch, field))
                print_matrix(getattr(graph_batch_from_list, field))

        self.assertTrue(check)
def test_raises(self, index, error_type, message, use_constant, use_slice):
    if use_constant:
        index = tf.constant(index)
    if use_slice:
        index = slice(index)
    graphs_tuple = utils_tf.data_dicts_to_graphs_tuple(self.graphs_dicts_in)
    with self.assertRaisesRegexp(error_type, message):
        utils_tf.get_graph(graphs_tuple, index)
def test_fill_global_state(self, global_size):
    """Tests for filling the global state with a constant content."""
    for g in self.graphs_dicts_in:
        g.pop("globals")
    graphs_tuple = utils_tf.data_dicts_to_graphs_tuple(self.graphs_dicts_in)
    n_graphs = self.reference_graph.n_edge.shape[0]
    graphs_tuple = utils_tf.set_zero_global_features(graphs_tuple, global_size)
    self.assertAllEqual((n_graphs, global_size),
                        graphs_tuple.globals.get_shape().as_list())
def test_fill_edge_state(self, edge_size):
    """Tests for filling the edge state with a constant content."""
    for g in self.graphs_dicts_in:
        g.pop("edges")
    graphs_tuple = utils_tf.data_dicts_to_graphs_tuple(self.graphs_dicts_in)
    n_edges = np.sum(self.reference_graph.n_edge)
    graphs_tuple = utils_tf.set_zero_edge_features(graphs_tuple, edge_size)
    self.assertAllEqual((n_edges, edge_size),
                        graphs_tuple.edges.get_shape().as_list())
def setUp(self):
    super(RunGraphWithNoneInSessionTest, self).setUp()
    self._graph = utils_tf.data_dicts_to_graphs_tuple([{
        "senders": tf.random_uniform([10], maxval=10, dtype=tf.int32),
        "receivers": tf.random_uniform([10], maxval=10, dtype=tf.int32),
        "nodes": tf.random_uniform([5, 7]),
        "edges": tf.random_uniform([10, 6]),
        "globals": tf.random_uniform([1, 8])
    }])
def test_fill_edge_state_with_missing_fields_raises(self):
    """Edge field cannot be filled if receivers or senders are missing."""
    for g in self.graphs_dicts_in:
        g.pop("receivers")
        g.pop("senders")
        g.pop("edges")
    graphs_tuple = utils_tf.data_dicts_to_graphs_tuple(self.graphs_dicts_in)
    with self.assertRaisesRegexp(ValueError, "receivers"):
        graphs_tuple = utils_tf.set_zero_edge_features(graphs_tuple, edge_size=1)
def get_empty_graph(nodeshape, edgeshape, glblshape, senders, receivers):
    dic = {
        "globals": np.zeros(glblshape, dtype=np.float),
        "nodes": np.zeros(nodeshape, dtype=np.float),
        "edges": np.zeros(edgeshape, dtype=np.float),
        "senders": senders,
        "receivers": receivers
    }
    return utils_tf.data_dicts_to_graphs_tuple([dic])
def setUp(self):
    super(TestNestToNumpy, self).setUp()
    self._graph = utils_tf.data_dicts_to_graphs_tuple([{
        "senders": tf.random.uniform([10], maxval=10, dtype=tf.int32),
        "receivers": tf.random.uniform([10], maxval=10, dtype=tf.int32),
        "nodes": tf.random.uniform([5, 7]),
        "edges": tf.random.uniform([10, 6]),
        "globals": tf.random.uniform([1, 8])
    }])
def setUp(self):
    super(StopGradientsGraphTest, self).setUp()
    self._graph = utils_tf.data_dicts_to_graphs_tuple([{
        "senders": tf.zeros([10], dtype=tf.int32),
        "receivers": tf.zeros([10], dtype=tf.int32),
        "nodes": tf.ones([5, 7]),
        "edges": tf.zeros([10, 6]),
        "globals": tf.zeros([1, 8])
    }])
def test_fill_node_state(self, node_size):
    """Tests for filling the node state with a constant content."""
    for g in self.graphs_dicts_in:
        g["n_node"] = g["nodes"].shape[0]
        g.pop("nodes")
    graphs_tuple = utils_tf.data_dicts_to_graphs_tuple(self.graphs_dicts_in)
    n_nodes = np.sum(self.reference_graph.n_node)
    graphs_tuple = utils_tf.set_zero_node_features(graphs_tuple, node_size)
    self.assertAllEqual((n_nodes, node_size),
                        graphs_tuple.nodes.get_shape().as_list())
def load_batch(self, mode, repeat=None):
    """Return a batch loaded from this dataset."""
    params = self.dataset_params
    opts = self.opts
    assert mode in params.sizes, "Mode {} not supported".format(mode)
    other_keys = list(set(self.features.keys()) - set(GRAPH_KEYS))
    item_keys = GRAPH_KEYS + other_keys
    data_source_name = mode + '-*.tfrecords'
    data_sources = glob.glob(
        os.path.join(self.data_dir, mode, data_source_name))
    if opts.shuffle_data and mode != 'test':
        np.random.shuffle(data_sources)  # Added to help the shuffle

    # Build dataset provider
    keys_to_features = {}
    for k, v in self.features.items():
        keys_to_features.update(v.get_feature_read())
    items_to_descriptions = {
        k: v.description for k, v in self.features.items()
    }

    def parser_op(record):
        example = tf.parse_single_example(record, keys_to_features)
        return [
            self.features[k].tensors_to_item(example) for k in item_keys
        ]

    dataset = tf.data.TFRecordDataset(data_sources)
    dataset = dataset.map(parser_op)
    dataset = dataset.repeat(repeat)
    if opts.shuffle_data and mode != 'test':
        dataset = dataset.shuffle(buffer_size=5 * opts.batch_size)
    # dataset = dataset.prefetch(buffer_size=opts.batch_size)
    iterator = dataset.make_one_shot_iterator()

    batch_graphs = []
    batch_other = []
    for b in range(opts.batch_size):
        sample_ = iterator.get_next()
        # Extracting other keys outside of the graph
        sample_other_ = {
            k: sample_[i + len(GRAPH_KEYS)] for i, k in enumerate(other_keys)
        }
        batch_other.append(sample_other_)
        # Constructing graph using relevant graph keys
        sample_graph = {k: sample_[i] for i, k in enumerate(GRAPH_KEYS)}
        batch_graphs.append(sample_graph)

    # Constructing output sample using known order of the keys
    sample = {}
    for k in other_keys:
        sample[k] = self.features[k].stack(
            [batch_other[b][k] for b in range(opts.batch_size)])
    sample['graph'] = utils_tf.data_dicts_to_graphs_tuple(batch_graphs)
    return sample
def test_output_values_larger_rank(self, broadcaster, expected):
    """Test the broadcasted output value."""
    input_graph = utils_tf.data_dicts_to_graphs_tuple(
        [SMALL_GRAPH_1, SMALL_GRAPH_2])
    input_graph = input_graph.map(
        lambda v: tf.reshape(v, [v.get_shape().as_list()[0]] + [2, -1]))
    broadcasted_out = broadcaster(input_graph)
    self.assertNDArrayNear(
        np.reshape(np.array(expected, dtype=np.float32),
                   [len(expected)] + [2, -1]),
        broadcasted_out,
        err=1e-4)
def test_fill_edge_state_dynamic(self, edge_size):
    """Tests for filling the edge state with a constant content."""
    for g in self.graphs_dicts_in:
        g.pop("edges")
    graphs_tuple = utils_tf.data_dicts_to_graphs_tuple(self.graphs_dicts_in)
    graphs_tuple = graphs_tuple._replace(
        n_edge=tf.constant(
            graphs_tuple.n_edge, shape=graphs_tuple.n_edge.get_shape()))
    n_edges = np.sum(self.reference_graph.n_edge)
    graphs_tuple = utils_tf.set_zero_edge_features(graphs_tuple, edge_size)
    actual_edges = graphs_tuple.edges
    self.assertNDArrayNear(
        np.zeros((n_edges, edge_size)), actual_edges, err=1e-4)
def test_fill_state_user_specified_types(self, dtype):
    """Tests that the features are created with the user-specified type."""
    for g in self.graphs_dicts_in:
        g.pop("nodes")
        g.pop("globals")
        g.pop("edges")
    graphs_tuple = utils_tf.data_dicts_to_graphs_tuple(self.graphs_dicts_in)
    graphs_tuple = utils_tf.set_zero_edge_features(graphs_tuple, 1, dtype)
    graphs_tuple = utils_tf.set_zero_node_features(graphs_tuple, 1, dtype)
    graphs_tuple = utils_tf.set_zero_global_features(graphs_tuple, 1, dtype)
    self.assertEqual(dtype, graphs_tuple.edges.dtype)
    self.assertEqual(dtype, graphs_tuple.nodes.dtype)
    self.assertEqual(dtype, graphs_tuple.globals.dtype)
def get_batched_graphs(train_set):
    """Converts the inputs in each batch to complete (fully connected) graphs.

    :param train_set: training set containing tuples (batch_input, batch_target)
    :return: yields tuples of (input GraphsTuple, batch_target)
    """
    for batch_input, batch_target in train_set:
        input_dict = create_graph_dicts(batch_input)
        targets = batch_target
        input_dict = utils_tf.data_dicts_to_graphs_tuple(input_dict)
        input_dict = utils_tf.fully_connect_graph_dynamic(input_dict)
        yield input_dict, targets
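A hedged usage sketch for the generator above; `train_set` and `model` are placeholders here, not names from the original code:

# Hypothetical training loop over the fully connected input graphs.
for input_graphs, targets in get_batched_graphs(train_set):
    # `input_graphs` is a GraphsTuple; pass it to a graph_nets model here
    # and compare the outputs against `targets`.
    outputs = model(input_graphs)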
def test_fill_state_default_types(self):
    """Tests that the features are created with the correct default type."""
    for g in self.graphs_dicts_in:
        g.pop("nodes")
        g.pop("globals")
        g.pop("edges")
    graphs_tuple = utils_tf.data_dicts_to_graphs_tuple(self.graphs_dicts_in)
    graphs_tuple = utils_tf.set_zero_edge_features(graphs_tuple, edge_size=1)
    graphs_tuple = utils_tf.set_zero_node_features(graphs_tuple, node_size=1)
    graphs_tuple = utils_tf.set_zero_global_features(
        graphs_tuple, global_size=1)
    self.assertEqual(tf.float32, graphs_tuple.edges.dtype)
    self.assertEqual(tf.float32, graphs_tuple.nodes.dtype)
    self.assertEqual(tf.float32, graphs_tuple.globals.dtype)
def test_fill_node_state_dynamic(self, node_size):
    """Tests for filling the node state with a constant content."""
    for g in self.graphs_dicts_in:
        g["n_node"] = g["nodes"].shape[0]
        g.pop("nodes")
    graphs_tuple = utils_tf.data_dicts_to_graphs_tuple(self.graphs_dicts_in)
    graphs_tuple = graphs_tuple._replace(
        n_node=tf.constant(
            graphs_tuple.n_node, shape=graphs_tuple.n_node.get_shape()))
    n_nodes = np.sum(self.reference_graph.n_node)
    graphs_tuple = utils_tf.set_zero_node_features(graphs_tuple, node_size)
    actual_nodes = graphs_tuple.nodes.numpy()
    self.assertNDArrayNear(
        np.zeros((n_nodes, node_size)), actual_nodes, err=1e-4)
def test_fill_global_state_dynamic(self, global_size):
    """Tests for filling the global state with a constant content."""
    for g in self.graphs_dicts_in:
        g.pop("globals")
    graphs_tuple = utils_tf.data_dicts_to_graphs_tuple(self.graphs_dicts_in)
    # Hide global shape information
    graphs_tuple = graphs_tuple._replace(
        n_node=tf.placeholder_with_default(graphs_tuple.n_node, shape=[None]))
    n_graphs = self.reference_graph.n_edge.shape[0]
    graphs_tuple = utils_tf.set_zero_global_features(graphs_tuple, global_size)
    with self.test_session() as sess:
        actual_globals = sess.run(graphs_tuple.globals)
    self.assertNDArrayNear(
        np.zeros((n_graphs, global_size)), actual_globals, err=1e-4)
def load_graphs(file_path, train_ratio, keep_features, existence_as_vector=True):
    """Extracts the graphs from the given file.

    :param file_path: The path to the graph json file
    :param train_ratio: How much of the data should be used for training.
        The remainder is used for testing.
    :param keep_features: Whether to keep all features of the graph.
        It is advised to do so in the case of input graphs.
    :param existence_as_vector: Whether to represent the existence feature as a
        vector of length 2 in an individual feature vector. It should be False
        when processing the input graphs.
    :return: Training and testing GraphsTuples
    """
    graph_dicts = []
    with open(file_path) as json_file:
        line = json_file.readline().strip()
        while line != '' and line is not None:
            json_dict = json.loads(line)
            json_dict = process_line(json_dict, keep_features,
                                     existence_as_vector)
            is_valid_graph(json_dict)
            graph_dicts.append(json_dict)
            line = json_file.readline().strip()

    train_dicts = graph_dicts[:int(train_ratio * len(graph_dicts))]
    test_dicts = graph_dicts[int(train_ratio * len(graph_dicts)):]

    graphs_tuple_train = utils_tf.data_dicts_to_graphs_tuple(train_dicts)
    graphs_tuple_test = utils_tf.data_dicts_to_graphs_tuple(test_dicts)
    return graphs_tuple_train, graphs_tuple_test
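A small usage sketch for the loader above; the file name and split ratio are illustrative assumptions:

# Hypothetical call: keep all features for the input graphs, use 80% of the
# graphs for training, and (per the docstring) disable the existence vector
# when processing input graphs.
train_graphs, test_graphs = load_graphs(
    "graphs.json", train_ratio=0.8, keep_features=True,
    existence_as_vector=False)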