def get_graph_tuple(nodes, globals=None):
    """Helper function to create a `graphs.GraphsTuple` from dense batched tensors.

    Each sample of the batch becomes one graph with `num_kpts` nodes. The
    per-graph edge count is set to `num_kpts ** 2`, while `edges`, `senders`
    and `receivers` are left as `None`.

    :param nodes: (Tensor) batch of node values (batch, num_kpts, node_feature_dims)
    :param globals: (Tensor, optional) batch of global values
        (batch, 1, global_feature_dims)
    :return: graph_tuple: `graphs.GraphsTuple` whose nodes (and globals, if
        given) are flattened to (batch * num_kpts, dims) / (batch * 1, dims).
    """
    nodes_shape = nodes.shape
    batch_size = nodes_shape[0]
    num_nodes = tf.ones([batch_size], dtype=tf.int32)
    num_edges = tf.ones([batch_size], dtype=tf.int32)

    # Number of nodes and edges for each sample in the batch of input graphs.
    b_num_nodes = nodes_shape[1] * num_nodes
    b_num_edges = (nodes_shape[1]**2) * num_edges

    # Reshaping (b, num_nodes, dims) -> (b * num_nodes, dims).
    nodes = tf.reshape(nodes, [nodes_shape[0] * nodes_shape[1], nodes_shape[2]])
    if globals is not None:
        globals_shape = globals.shape
        globals = tf.reshape(
            globals, [globals_shape[0] * globals_shape[1], globals_shape[2]])

    # GraphsTuple needs every field. The former `globals` branch passed only
    # nodes/globals, so it could not be constructed.
    graph_tuple = graphs.GraphsTuple(nodes=nodes, globals=globals, edges=None,
                                     n_node=b_num_nodes, n_edge=b_num_edges,
                                     senders=None, receivers=None)
    return graph_tuple
def dtype_shape_from_graphs_tuple(input_graph, with_batch_dim=False,
                                  with_padding=True, debug=False,
                                  with_fixed_size=False):
    """Describe `input_graph` as a pair of GraphsTuples: dtypes and shapes.

    For every field in `graphs.ALL_FIELDS`, records the field's dtype and a
    `tf.TensorShape` of its shape. When neither `with_fixed_size` nor
    `with_padding` is set, some dimensions of non-scalar fields are relaxed
    to `None`:
      * with `with_batch_dim`, the second dimension of every such field;
      * otherwise, the first dimension of `nodes`, `edges`, `senders` and
        `receivers`.

    Args:
      input_graph: a `graphs.GraphsTuple` whose fields all have `.shape` and
        `.dtype` (no `None` fields).
      with_batch_dim: whether fields carry an extra leading batch dimension.
      with_padding: if True, shapes are kept fully static.
      debug: if True, print each field name, final shape and dtype.
      with_fixed_size: if True, shapes are kept fully static.

    Returns:
      `(dtypes, shapes)`: two `graphs.GraphsTuple`s holding the dtypes and
      `tf.TensorShape`s of each field.
    """
    graphs_tuple_dtype = {}
    graphs_tuple_shape = {}
    edge_dim_fields = [graphs.EDGES, graphs.SENDERS, graphs.RECEIVERS]
    relax_dims = not with_fixed_size and not with_padding
    for field_name in graphs.ALL_FIELDS:
        field_sample = getattr(input_graph, field_name)
        shape = list(field_sample.shape)
        dtype = field_sample.dtype
        # Scalar (rank-0) fields are never relaxed.
        if relax_dims and shape:
            if with_batch_dim:
                shape[1] = None
            elif field_name == graphs.NODES or field_name in edge_dim_fields:
                shape[0] = None
        graphs_tuple_dtype[field_name] = dtype
        graphs_tuple_shape[field_name] = tf.TensorShape(shape)
        # Printing is opt-in. A stray unconditional print used to fire here.
        if debug:
            print(field_name, shape, dtype)
    return graphs.GraphsTuple(**graphs_tuple_dtype), graphs.GraphsTuple(
        **graphs_tuple_shape)
def parse_tfrec_function(example_proto):
    """Decode one serialized TFRecord example into `(input, target)` graphs.

    Every GraphsTuple field is stored as a serialized tensor string under two
    keys: `<field>_IN` for the input graph and `<field>_OUT` for the target
    graph. `graph_types` (defined elsewhere in the module) supplies the dtype
    passed to `tf.io.parse_tensor` for each field.
    """
    features_description = {
        field + suffix: tf.io.FixedLenFeature([], tf.string)
        for suffix in ("_IN", "_OUT")
        for field in graphs.ALL_FIELDS
    }
    example = tf.io.parse_single_example(example_proto, features_description)

    def _decode(suffix):
        # Rebuild one GraphsTuple from the fields stored with this suffix.
        return graphs.GraphsTuple(**{
            field: tf.io.parse_tensor(example[field + suffix], graph_types[field])
            for field in graphs.ALL_FIELDS
        })

    input_dd = _decode("_IN")
    out_dd = _decode("_OUT")
    return input_dd, out_dd
def test_map_field_default_value(self):
    """Without `fields`, `GraphsTuple.map` visits exactly edges, globals, nodes."""
    visited = []
    # `list.append` returns None; only the recorded visits matter here.
    graphs.GraphsTuple(**self.graph).map(visited.append)
    expected = sorted([graphs.EDGES, graphs.GLOBALS, graphs.NODES])
    self.assertListEqual(sorted(visited), expected)
def object_encoder_graph(self, encodings_per_object):
    """Build a fully connected graph from object-oriented encodings.

    Inputs:
      encodings_per_object: tensor containing the embeddings per object slot
        (size: [batch, n_objects, embedding dim]).
    Outputs:
      graph_representation: one fully connected graph per batch element
        (self-edges included) whose node attributes are the object
        encodings, made runnable in a session.
    """
    dyn_shape = tf.shape(encodings_per_object)
    batch_multiples = dyn_shape[0:1]

    # One graph per batch element: `self.n_objects` nodes, no edges yet.
    n_nodes = tf.tile(tf.constant([self.n_objects]), batch_multiples,
                      name='n_nodes')
    n_edges = tf.tile(tf.constant([0]), batch_multiples, name='n_edges')

    # Flatten [batch, n_objects, dim] -> [batch * n_objects, dim] node rows.
    flat_nodes = tf.reshape(
        encodings_per_object,
        [dyn_shape[0] * dyn_shape[1], self.state_dim_embedding])

    edgeless_graph = graphs.GraphsTuple(
        nodes=flat_nodes, edges=None, globals=None, receivers=None,
        senders=None, n_node=n_nodes, n_edge=n_edges)

    # Connect every node to every node, itself included.
    connected = utils_tf.fully_connect_graph_dynamic(
        edgeless_graph, exclude_self_edges=False)
    # The graph still has None fields; make it runnable in a TF session.
    return utils_tf.make_runnable_in_session(connected)
def _build(self, input_op):
    """Encode, repeatedly process, and decode `input_op`, stacking each step.

    Step `i` of `self._num_processing_steps` applies `self._core`
    `self._proc_hops[i]` times and then `self._decoder`. The decoded edges,
    nodes and globals of all steps are stacked on axis 1, reshaped to
    `(-1, self._n_stacked)`, and combined with the connectivity of
    `input_op` before `self._output_transform` is applied.
    """
    hidden = self._encoder(input_op)
    decoded_per_step = []
    for step in range(self._num_processing_steps):
        for _ in range(self._proc_hops[step]):
            hidden = self._core(hidden)
        decoded_per_step.append(self._decoder(hidden))

    def _stack_field(field):
        # Stack one feature field over steps and flatten to n_stacked columns.
        stacked = tf.stack([getattr(g, field) for g in decoded_per_step], axis=1)
        return tf.reshape(stacked, (-1, self._n_stacked))

    feature_graph = graphs.GraphsTuple(
        nodes=_stack_field("nodes"),
        edges=_stack_field("edges"),
        globals=_stack_field("globals"),
        receivers=input_op.receivers,
        senders=input_op.senders,
        n_node=input_op.n_node,
        n_edge=input_op.n_edge)
    return self._output_transform(feature_graph)
def generator():
    """Yield one subsampled `tf_graphs.GraphsTuple` per root paper index.

    Root indices come from `arrays[f"{prefix}_indices"]`. When
    `ratio_unlabeled_data_to_labeled_data > 0`, extra indices drawn without
    replacement from `range(NUM_PAPERS)` are appended. The roots are
    optionally shuffled. All other names are taken from the enclosing scope.
    """
    roots = arrays[f"{prefix}_indices"]
    if ratio_unlabeled_data_to_labeled_data > 0:
        extra_count = int(ratio_unlabeled_data_to_labeled_data * roots.shape[0])
        extra_roots = np.random.choice(
            NUM_PAPERS, size=extra_count, replace=False)
        roots = np.concatenate([roots, extra_roots])
    if shuffle_indices:
        # Shuffle a copy so the shared index array is never reordered.
        roots = roots.copy()
        np.random.shuffle(roots)

    adjacency_keys = (
        "author_institution_index",
        "institution_author_index",
        "author_paper_index",
        "paper_author_index",
        "paper_paper_index",
        "paper_paper_index_t",
    )
    for root in roots:
        subgraph = sub_sampler.subsample_graph(
            root,
            *(arrays[key] for key in adjacency_keys),
            paper_years=arrays["paper_year"],
            max_nodes=max_nodes,
            max_edges=max_edges,
            **subsampler_kwargs)
        subgraph = add_nodes_label(subgraph, arrays["paper_label"])
        subgraph = add_nodes_year(subgraph, arrays["paper_year"])
        yield tf_graphs.GraphsTuple(*subgraph)
def get_graph(input_graphs, index):
    """Indexes into a graph.

    Given a `graphs.GraphsTuple` containing numpy arrays and an index (either
    an `int` or a `slice`), extract the graphs selected by the index and
    return them in another `graphs.GraphsTuple` containing numpy arrays.
    Negative indices count from the end, as for Python sequences.

    Args:
      input_graphs: A `graphs.GraphsTuple` containing numpy arrays.
      index: An `int` or a `slice`, to index into `graph`. `index` should be
        compatible with the number of graphs in `graphs`.

    Returns:
      A `graphs.GraphsTuple` containing numpy arrays, made of the extracted
      graph(s).

    Raises:
      TypeError: if `index` is not an `int` or a `slice`.
    """
    if isinstance(index, int):
        # For index -1, `index + 1` is 0 and slice(-1, 0) is empty, so the
        # last graph needs an open-ended stop.
        graph_slice = slice(index, None if index == -1 else index + 1)
    elif isinstance(index, slice):
        graph_slice = index
    else:
        raise TypeError("unsupported type: %s" % type(index))
    data_dicts = graphs_tuple_to_data_dicts(input_graphs)[graph_slice]
    return graphs.GraphsTuple(**_concatenate_data_dicts(data_dicts))
def data_dicts_to_graphs_tuple(data_dicts, name="data_dicts_to_graphs_tuple"):
    """Creates a `graphs.GraphsTuple` containing tensors from data dicts.

    All dictionaries must have exactly the same set of keys with non-`None`
    values associated to them. Moreover, this set of keys must define a valid
    graph (i.e. if the `EDGES` are `None`, the `SENDERS` and `RECEIVERS` must
    be `None`, and `SENDERS` and `RECEIVERS` can only be `None` both at the
    same time). The values associated with a key must be convertible to
    `Tensor`s, for instance python lists, numpy arrays, or Tensorflow
    `Tensor`s.

    The input dictionaries are copied first, so the caller's dicts are never
    modified and any iterable (including a generator) is accepted.

    This method may perform a memory copy.

    The `RECEIVERS`, `SENDERS`, `N_NODE` and `N_EDGE` fields are cast to
    `np.int32` type.

    Args:
      data_dicts: An iterable of data dictionaries with keys in `ALL_FIELDS`.
      name: (string, optional) A name for the operation.

    Returns:
      A `graphs.GraphTuple` representing the graphs in `data_dicts`.
    """
    # Copy first. `setdefault` on the originals mutated the caller's dicts,
    # and the nested loop below would exhaust a generator on its first pass.
    data_dicts = [dict(d) for d in data_dicts]
    for key in ALL_FIELDS:
        for data_dict in data_dicts:
            data_dict.setdefault(key, None)
    utils_np._check_valid_sets_of_keys(data_dicts)  # pylint: disable=protected-access
    with tf.name_scope(name):
        data_dicts = _to_compatible_data_dicts(data_dicts)
        return graphs.GraphsTuple(**_concatenate_data_dicts(data_dicts))
def data_dicts_to_graphs_tuple(data_dicts):
    """Batch an iterable of data dicts into one numpy `graphs.GraphsTuple`.

    Args:
      data_dicts: An iterable of dictionaries with keys `GRAPH_DATA_FIELDS`,
        plus, potentially, a subset of `GRAPH_NUMBER_FIELDS`. The NODES and
        EDGES fields should be numpy arrays of rank at least 2, while the
        RECEIVERS, SENDERS are numpy arrays of rank 1 and same dimension as
        the EDGES field first dimension. The GLOBALS field is a numpy array of
        rank at least 1.

    Returns:
      An instance of `graphs.GraphsTuple` containing numpy arrays. The
      `RECEIVERS`, `SENDERS`, `N_NODE` and `N_EDGE` fields are cast to
      `np.int32` type.
    """
    completed = []
    for source in data_dicts:
        # Work on a copy; missing data fields are filled with None.
        filled = dict(source)
        for field in graphs.GRAPH_DATA_FIELDS:
            filled.setdefault(field, None)
        completed.append(filled)
    _check_valid_sets_of_keys(completed)
    compatible = _to_compatible_data_dicts(completed)
    return graphs.GraphsTuple(**_concatenate_data_dicts(compatible))
def test_self_attention(self):
    """Checks `modules.SelfAttention` against hand-computed attention mixes.

    Builds per-node values, keys and queries with two heads and two features
    per head, runs `SelfAttention` over the connectivity given by the class
    fixtures `self.RECEIVERS`, `self.SENDERS`, `self.N_NODE` and
    `self.N_EDGE`, and compares the mixed node features with expected values.
    The per-row comments describe which nodes send/receive in those fixtures.
    """
    # Just one feature per node: node i has value i + 1.
    values_np = np.arange(sum(self.N_NODE)) + 1.
    # Multiple heads, one positive values, one negative values.
    values_np = np.stack([values_np, values_np*-1.], axis=-1)
    # Multiple features per node, per head, at different scales.
    values_np = np.stack([values_np, values_np*0.1], axis=-1)
    values = tf.constant(values_np, dtype=tf.float32)
    keys_np = [
        [[0.3, 0.4]]*2,  # Irrelevant (only sender to one node)
        [[0.1, 0.5]]*2,  # Not used (is not a sender)
        [[1, 0], [0, 1]],
        [[0, 1], [1, 0]],
        [[1, 1], [1, 1]],
        [[0.4, 0.3]]*2,  # Not used (is not a sender)
        [[0.3, 0.2]]*2]  # Not used (is not a sender)
    keys = tf.constant(keys_np, dtype=tf.float32)
    # Log-valued queries make the softmax weights simple fractions.
    queries_np = [
        [[0.2, 0.7]]*2,  # Not used (is not a receiver)
        [[0.3, 0.2]]*2,  # Irrelevant (only receives from one node)
        [[0.2, 0.8]]*2,  # Not used (is not a receiver)
        [[0.2, 0.4]]*2,  # Not used (is not a receiver)
        [[0.3, 0.9]]*2,  # Not used (is not a receiver)
        [[0, np.log(2)], [np.log(3), 0]],
        [[np.log(2), 0], [0, np.log(3)]]]
    queries = tf.constant(queries_np, dtype=tf.float32)

    # Only connectivity is needed: all feature fields are None.
    attention_graph = graphs.GraphsTuple(
        nodes=None,
        edges=None,
        globals=None,
        receivers=tf.constant(self.RECEIVERS, dtype=tf.int32),
        senders=tf.constant(self.SENDERS, dtype=tf.int32),
        n_node=tf.constant(self.N_NODE, dtype=tf.int32),
        n_edge=tf.constant(self.N_EDGE, dtype=tf.int32),)

    self_attention = modules.SelfAttention()
    output_graph = self_attention(values, keys, queries, attention_graph)
    mixed_nodes = output_graph.nodes

    with self.test_session() as sess:
        mixed_nodes_output = sess.run(mixed_nodes)

    # Node values are n0=1, n1=2, n2=3, n3=4, n4=5. For example,
    # 11/3 = 3*(1/3) + 4*(2/3). The second feature is 0.1x the first, and
    # head two uses negated values.
    expected_mixed_nodes = [
        [[0., 0.], [0., 0.]],  # Does not receive any edges
        [[1., 0.1], [-1., -0.1]],  # Only receives from n0.
        [[0., 0.], [0., 0.]],  # Does not receive any edges
        [[0., 0.], [0., 0.]],  # Does not receive any edges
        [[0., 0.], [0., 0.]],  # Does not receive any edges
        [[11/3, 11/3*0.1],  # Head one, receives from n2(1/3) n3(2/3)
         [-15/4, -15/4*0.1]],  # Head two, receives from n2(1/4) n3(3/4)
        [[20/5, 20/5*0.1],  # Head one, receives from n2(2/5) n3(1/5) n4(2/5)
         [-28/7, -28/7*0.1]],  # Head two, receives from n2(3/7) n3(1/7) n4(3/7)
    ]
    self.assertAllClose(expected_mixed_nodes, mixed_nodes_output)
def test_map_fields_as_expected(self, fields_to_map):
    """Only the fields listed in `fields_to_map` are transformed by `map`."""
    doubled = graphs.GraphsTuple(**self.graph).map(
        lambda value: value + value, fields_to_map)
    for field in graphs.ALL_FIELDS:
        expected = field + field if field in fields_to_map else field
        self.assertEqual(expected, getattr(doubled, field))
def test_replace_with_valid_none_fields(self, none_fields):
    """`replace` accepts `None` for the fields in `none_fields`."""
    # Start from a graph whose every value differs from the reference.
    doubled = graphs.GraphsTuple(
        **{name: value + value for name, value in self.graph.items()})
    # The replacement holds the reference values, with `none_fields` set to None.
    self.graph.update({name: None for name in none_fields})
    replaced = doubled.replace(**self.graph)
    for name, value in self.graph.items():
        self.assertEqual(value, getattr(replaced, name))
def test_map_field_called_only_once(self):
    """The mapping function runs exactly once for every field."""
    seen = []

    def record(value):
        seen.append(value)
        return value

    graphs.GraphsTuple(**self.graph).map(record, graphs.ALL_FIELDS)
    self.assertListEqual(sorted(seen), sorted(graphs.ALL_FIELDS))
def _get_shaped_input_graph(self):
    """Return a two-graph GraphsTuple (2 + 1 nodes, 3 + 2 edges) whose
    zero-valued features carry extra inner dimensions."""
    edge_ids = tf.range(5, dtype=tf.int32)
    return graphs.GraphsTuple(
        nodes=tf.zeros([3, 4, 5, 11], dtype=tf.float32),
        edges=tf.zeros([5, 4, 5, 12], dtype=tf.float32),
        globals=tf.zeros([2, 4, 5, 13], dtype=tf.float32),
        receivers=edge_ids // 3,
        senders=edge_ids % 3,
        n_node=tf.constant([2, 1], dtype=tf.int32),
        n_edge=tf.constant([3, 2], dtype=tf.int32),
    )
def _get_shaped_input_graph(self):
    """Return a two-graph GraphsTuple (2 + 1 nodes, 3 + 2 edges) whose
    zero-valued features carry extra inner dimensions (PyTorch tensors)."""
    # `torch.range(0, 5)` is deprecated and includes its end point, giving 6
    # ids for the 5 declared edges. `torch.arange(5)` gives 0..4.
    edge_ids = torch.arange(5, dtype=torch.int64)
    return graphs.GraphsTuple(
        nodes=torch.zeros([3, 4, 5, 11], dtype=torch.float32),
        edges=torch.zeros([5, 4, 5, 12], dtype=torch.float32),
        globals=torch.zeros([2, 4, 5, 13], dtype=torch.float32),
        receivers=edge_ids // 3,
        senders=edge_ids % 3,
        n_node=torch.tensor([2, 1], dtype=torch.int64),
        n_edge=torch.tensor([3, 2], dtype=torch.int64),
    )
def _get_graphs_tuple(self):
    """Return a single-graph GraphsTuple built from the test system fields."""
    def as_float(value):
        return tf.constant(value, dtype=tf.float32)

    def as_int(value):
        return tf.constant(value, dtype=tf.int32)

    return graphs.GraphsTuple(
        nodes=as_float(self._nodes),
        edges=as_float(self._edges),
        globals=as_float(np.array([[0.0]])),  # a single zero global feature
        receivers=as_int(self._receivers),
        senders=as_int(self._senders),
        n_node=as_int([len(self._nodes)]),
        n_edge=as_int([len(self._edges)]))
def specs_from_graphs_tuple(
    graphs_tuple_sample,
    with_batch_dim=False,
    dynamic_num_graphs=False,
    dynamic_num_nodes=True,
    dynamic_num_edges=True,
    description_fn=tf.TensorSpec,
):
    """Build a `graphs.GraphsTuple` of specs (e.g. `tf.TensorSpec`s).

    Fields with an empty shape (constants, typically replaced `None`s) keep
    their shape unchanged. For the other fields:
      * with `with_batch_dim`, the second dimension is set to `None` and the
        `dynamic_*` flags are ignored;
      * otherwise the first dimension is set to `None` for every field if
        `dynamic_num_graphs`, for `nodes` if `dynamic_num_nodes`, and for
        `edges`/`senders`/`receivers` if `dynamic_num_edges`.

    NOTE(review): `with_batch_dim` is the second positional parameter, so
    callers passing the `dynamic_*` options positionally bind them shifted by
    one. Prefer keyword arguments.

    Args:
      graphs_tuple_sample: a `graphs.GraphsTuple` with no `None` fields.
      with_batch_dim: relax the second dimension instead of the first.
      dynamic_num_graphs: relax the leading dimension of all fields.
      dynamic_num_nodes: relax the leading dimension of `nodes`.
      dynamic_num_edges: relax the leading dimension of edge-sized fields.
      description_fn: callable invoked as `description_fn(shape=, dtype=)`.

    Returns:
      A `graphs.GraphsTuple` whose fields are `description_fn` results.

    Raises:
      ValueError: if any field of `graphs_tuple_sample` is `None`.
    """
    graphs_tuple_description_fields = {}
    edge_dim_fields = [graphs.EDGES, graphs.SENDERS, graphs.RECEIVERS]

    for field_name in graphs.ALL_FIELDS:
        field_sample = getattr(graphs_tuple_sample, field_name)
        if field_sample is None:
            raise ValueError(
                "The `GraphsTuple` field `{}` was `None`. All fields of the "
                "`GraphsTuple` must be specified to create valid signatures that "
                "work with `tf.function`. This can be achieved with `input_graph = "
                "utils_tf.set_zero_{{node,edge,global}}_features(input_graph, 0)` "
                "to replace None's by empty features in your graph. Alternatively "
                "`None`s can be replaced by empty lists by doing `input_graph = "
                "input_graph.replace({{nodes,edges,globals}}=[])`. To ensure "
                "correct execution of the program, it is recommended to restore "
                "the None's once inside of the `tf.function` by doing "
                "`input_graph = input_graph.replace({{nodes,edges,globals}}=None)`"
                "".format(field_name))

        shape = list(field_sample.shape)
        dtype = field_sample.dtype

        # A non-None field with an empty shape is treated as a replaced None
        # and left unchanged.
        if shape:
            if with_batch_dim:
                shape[1] = None
            elif (dynamic_num_graphs
                  or (dynamic_num_nodes and field_name == graphs.NODES)
                  or (dynamic_num_edges and field_name in edge_dim_fields)):
                shape[0] = None

        # (A stray unconditional debug print used to be here.)
        graphs_tuple_description_fields[field_name] = description_fn(
            shape=shape, dtype=dtype)

    return graphs.GraphsTuple(**graphs_tuple_description_fields)
def get_placeholders(self):
    """Build placeholders for every feature and group the graph fields.

    Returns:
      `(sample, placeholders)`. `placeholders` merges the dicts returned by
      each feature's `get_placeholder`. `sample` maps every non-graph feature
      key to its placeholder, plus `'graph'` -> a `graphs.GraphsTuple` built
      from the placeholders of `GRAPH_KEYS`.
    """
    # TODO: Make this work with utils_tf for graph_nets for GraphTuple
    placeholders = {}
    for name, feature in self.features.items():
        # Only non-graph features are requested with a batch dimension.
        placeholders.update(feature.get_placeholder(batch=name not in GRAPH_KEYS))

    non_graph_keys = set(self.features.keys()) - set(GRAPH_KEYS)
    sample = {name: placeholders[name] for name in non_graph_keys}
    sample['graph'] = graphs.GraphsTuple(
        **{name: placeholders[name] for name in GRAPH_KEYS})
    return sample, placeholders
def test_build_placeholders_from_specs(self,
                                       none_fields,
                                       force_dynamic_num_graphs=False):
    """Placeholders get the requested dtypes and the expected leading dims.

    `n_node`, `n_edge` and `globals` keep a static leading dimension unless
    `force_dynamic_num_graphs` is set. All other fields always become dynamic.
    """
    num_graphs = 3
    shapes = graphs.GraphsTuple(
        nodes=[3, 4],
        edges=[2],
        globals=[num_graphs, 4, 6],
        receivers=[None],
        senders=[18],
        n_node=[num_graphs],
        n_edge=[num_graphs],
    )
    dtypes = graphs.GraphsTuple(
        nodes=tf.float64,
        edges=tf.int32,
        globals=tf.float32,
        receivers=tf.int64,
        senders=tf.int64,
        n_node=tf.int32,
        n_edge=tf.int64)
    dtypes = dtypes.map(lambda _: None, none_fields)
    shapes = shapes.map(lambda _: None, none_fields)

    placeholders = utils_tf._build_placeholders_from_specs(
        dtypes, shapes, force_dynamic_num_graphs=force_dynamic_num_graphs)

    static_batch_fields = ("n_node", "n_edge", "globals")
    for field in graphs.ALL_FIELDS:
        placeholder = getattr(placeholders, field)
        if field in none_fields:
            self.assertEqual(None, placeholder)
            continue
        self.assertEqual(getattr(dtypes, field), placeholder.dtype)
        keeps_static_batch = (field in static_batch_fields
                              and not force_dynamic_num_graphs)
        leading_dim = num_graphs if keeps_static_batch else None
        self.assertAllEqual([leading_dim] + getattr(shapes, field)[1:],
                            placeholder.shape.as_list())
def _build_placeholders_from_specs(dtypes,
                                   shapes,
                                   force_dynamic_num_graphs=True):
    """Creates a `graphs.GraphsTuple` of placeholders with `dtypes` and `shapes`.

    Both arguments are `graphs.GraphsTuple`s. A field whose dtype and shape
    are both `None` gets no placeholder (`None`). The leading dimension of
    every placeholder is made `None`, except for `N_NODE`, `N_EDGE` and
    `GLOBALS` when `force_dynamic_num_graphs` is False. Those keep the
    statically defined number of graphs.

    Args:
      dtypes: A `graphs.GraphsTuple` that contains `tf.dtype`s or `None`s.
      shapes: A `graphs.GraphsTuple` that contains `list`s of integers,
        `tf.TensorShape`s, or `None`s.
      force_dynamic_num_graphs: A `bool` that forces the batch dimension to be
        dynamic. Defaults to `True`.

    Returns:
      A `graphs.GraphsTuple` containing placeholders.

    Raises:
      ValueError: The `None` fields in `dtypes` and `shapes` do not match, or
        a shape has rank 0.
    """
    static_batch_fields = (N_NODE, N_EDGE, GLOBALS)
    placeholders = {}
    for field in ALL_FIELDS:
        dtype = getattr(dtypes, field)
        shape = getattr(shapes, field)
        dtype_missing = dtype is None
        shape_missing = shape is None
        if dtype_missing or shape_missing:
            if dtype_missing != shape_missing:
                raise ValueError(
                    "only one of dtype and shape are None for field {}".format(
                        field))
            placeholders[field] = None
            continue
        if not shape:
            raise ValueError("Shapes must have at least rank 1")
        shape = list(shape)
        if force_dynamic_num_graphs or field not in static_batch_fields:
            shape[0] = None
        placeholders[field] = tf.compat.v1.placeholder(
            dtype, shape=shape, name=field)
    return graphs.GraphsTuple(**placeholders)
def get_placeholder_and_feature(self, batch):
    """Collect placeholders and feature values for every graph field.

    The `batch` argument is ignored. The batch flag is instead chosen per
    key: only 'globals', 'n_node' and 'n_edge' are requested with a batch
    dimension (per the original note, because of how graphs are
    concatenated).

    Returns:
      `(placeholders, sample)`: merged placeholder dict and a
      `graphs.GraphsTuple` of the per-feature values.
    """
    del batch  # Unused; decided per key below.
    batched_keys = ('globals', 'n_node', 'n_edge')
    placeholders = {}
    sample_dict = {}
    for key, feat in self.features.items():
        ph, val = feat.get_placeholder_and_feature(batch=key in batched_keys)
        placeholders.update(ph)
        sample_dict[key] = val
    return placeholders, graphs.GraphsTuple(**sample_dict)
def make_graph_from_static_structure(
        positions, types, box, edge_threshold, global_features=None):
    """Returns graph representing the structure of the collective.

    Each particle is represented by a node in the graph. The particle motion
    direction is stored as a node feature. Two particles at a distance less
    than the threshold are connected by an edge, and the relative distance
    vector is stored as an edge feature. Distances use periodic boundary
    conditions given by `box`.

    Args:
      positions: particle positions with shape [n_particles, 2].
      types: particle motion direction with shape [n_particles].
      box: dimensions of the box that contains the particles with shape [2].
      edge_threshold: particles at distance less than threshold are connected
        by an edge. A particle is at distance 0 from itself, so with a
        positive threshold every particle also gets a self-edge.
      global_features: optional global feature values, cast to float32.
        Defaults to `[[0.03]]`, the value previously hard-coded.

    Returns:
      A single-graph `graphs.GraphsTuple`: float32 nodes of shape
      [n_particles, 1], float32 edges holding
      `positions[receiver] - positions[sender]` (after periodic wrapping), and
      int32 senders/receivers.
    """
    if global_features is None:
        global_features = [[0.03]]

    # Pairwise relative vectors: cross_positions[i, j] = positions[j] - positions[i].
    cross_positions = positions[tf.newaxis, :, :] - positions[:, tf.newaxis, :]
    # Enforce periodic boundary conditions (minimum-image wrapping).
    # NOTE(review): the float32 casts assume float32 positions/box — confirm.
    box_ = box[tf.newaxis, tf.newaxis, :]
    cross_positions += tf.cast(cross_positions < -box_ / 2., tf.float32) * box_
    cross_positions -= tf.cast(cross_positions > box_ / 2., tf.float32) * box_

    # Sparse adjacency (index pairs) from distances under the threshold.
    distances = tf.norm(cross_positions, axis=-1)
    indices = tf.where(distances < edge_threshold)

    nodes = types[:, tf.newaxis]
    senders = indices[:, 0]
    receivers = indices[:, 1]
    edges = tf.gather_nd(cross_positions, indices)

    return graphs.GraphsTuple(
        nodes=tf.cast(nodes, tf.float32),
        n_node=tf.reshape(tf.shape(nodes)[0], [1]),
        edges=tf.cast(edges, tf.float32),
        n_edge=tf.reshape(tf.shape(edges)[0], [1]),
        globals=tf.cast(global_features, dtype=tf.float32),
        receivers=tf.cast(receivers, tf.int32),
        senders=tf.cast(senders, tf.int32)
    )
def test_received_edges_normalizer(self, logits, expected_normalized,
                                   normalizer):
    """`_received_edges_normalizer` normalizes edge logits per receiver."""
    def int_constant(values):
        return tf.constant(values, dtype=tf.int32)

    graph = graphs.GraphsTuple(
        nodes=None,
        edges=logits,
        globals=None,
        receivers=int_constant(self.RECEIVERS),
        senders=int_constant(self.SENDERS),
        n_node=int_constant(self.N_NODE),
        n_edge=int_constant(self.N_EDGE),
    )
    normalized = modules._received_edges_normalizer(graph, normalizer)
    with self.test_session() as sess:
        normalized_value = sess.run(normalized)
    self.assertAllClose(expected_normalized, normalized_value)
def get_signature(graphs_tuple_sample):
    """Build a `graphs.GraphsTuple` of `PerReplicaSpec`s for a distributed sample.

    Each field of `graphs_tuple_sample` is expected to be a per-replica value
    exposing `.values`. Every replica tensor becomes a `tf.TensorSpec` with
    the same dtype and shape, except that the second dimension of non-scalar
    tensors is set to `None`.

    NOTE(review): a rank-1 replica tensor makes `shape[1]` raise IndexError.
    Confirm all fields carry at least two dimensions.
    """
    from graph_nets import graphs

    def spec_from_value(v):
        shape = list(v.shape)
        if shape:
            shape[1] = None
        # `tf.DType` is not iterable: the former `list(v.dtype)` raised
        # TypeError on every call.
        return tf.TensorSpec(shape=shape, dtype=v.dtype)

    graphs_tuple_description_fields = {}
    for field_name in graphs.ALL_FIELDS:
        per_replica_sample = getattr(graphs_tuple_sample, field_name)
        graphs_tuple_description_fields[field_name] = (
            tf.distribute.values.PerReplicaSpec(
                *(spec_from_value(v) for v in per_replica_sample.values)))
    return graphs.GraphsTuple(**graphs_tuple_description_fields)
def _concat_batch_dim(G):
    """Concatenate a batch of GraphsTuples stored along a leading axis.

    `G` is a GraphsTuple of `Tensor`s whose fields carry an additional leading
    batch dimension, i.e. `G.<field>[i]` is the field of the i-th batch
    element. The per-element GraphsTuples are concatenated with
    `utils_tf.concat(..., axis=0)`.

    Args:
      G: batched GraphsTuple. The batch size (leading dim of `G.n_node`)
        must be statically known.

    Returns:
      A single `graphs.GraphsTuple` containing every batch element.

    Raises:
      ValueError: if the batch dimension is not statically known.
    """
    # The previous version hard-coded two batch elements and returned a
    # (counter, list) tuple, which left the documented concatenation
    # unreachable.
    batch_size = tf.compat.dimension_value(G.n_node.shape[0])
    if batch_size is None:
        raise ValueError("_concat_batch_dim needs a static batch dimension.")
    input_graphs = [
        graphs.GraphsTuple(
            nodes=G.nodes[ibatch],
            edges=G.edges[ibatch],
            receivers=G.receivers[ibatch],
            senders=G.senders[ibatch],
            globals=G.globals[ibatch],
            n_node=G.n_node[ibatch],
            n_edge=G.n_edge[ibatch])
        for ibatch in range(batch_size)
    ]
    return utils_tf.concat(input_graphs, axis=0)
def __call__(self, graph, graph_state, graph_mask,
             edge_kw=None, node_kw=None, global_kw=None):
    """Run the edge, node and global blocks once per time step.

    For each index `t` along axis 1 of `graph.nodes`, slices the nodes, edges
    and globals (and the matching masks) at `t`. It then runs
    `self._edge_block`, `self._node_block` and `self._global_block` in that
    order, carrying the resulting states into the next step.

    Args:
      graph: graph whose `nodes`, `edges` and `globals` have time on axis 1.
      graph_state: initial state with `edges`, `nodes` and `globals`.
      graph_mask: masks laid out like `graph`.
      edge_kw, node_kw, global_kw: optional extra keyword arguments for the
        corresponding block (default: none).

    Returns:
      `(graph_output, graph_state)` produced by the last time step.

    Raises:
      ValueError: if the time dimension has length 0.
    """
    # None defaults instead of shared mutable `{}` default arguments.
    edge_kw = {} if edge_kw is None else edge_kw
    node_kw = {} if node_kw is None else node_kw
    global_kw = {} if global_kw is None else global_kw

    num_of_times = graph.nodes.shape[1]
    if num_of_times == 0:
        # Otherwise `graph_output` would be unbound at the return below.
        raise ValueError("`graph.nodes` has an empty time dimension.")

    for t in range(num_of_times):
        # NOTE(review): tf.squeeze drops *every* size-1 axis (e.g. a batch of
        # one), not only the time axis. Confirm this is intended.
        graph_t = graphs.GraphsTuple(
            tf.squeeze(graph.nodes[:, t, :]),
            tf.squeeze(graph.edges[:, t, :]),
            tf.squeeze(graph.globals[:, t, :]))
        graph_mask_t = GraphMasks(
            tf.squeeze(graph_mask.nodes[:, t, :]),
            tf.squeeze(graph_mask.edges[:, t, :]),
            tf.squeeze(graph_mask.globals[:, t, :]))

        graph_output, edge_state = self._edge_block(
            graph_t, graph_state.edges, graph_mask_t.edges, **edge_kw)
        graph_output, node_state = self._node_block(
            graph_output, graph_state.nodes, graph_mask_t.nodes, **node_kw)
        graph_output, global_state = self._global_block(
            graph_output, graph_state.globals, graph_mask_t.globals,
            **global_kw)
        graph_state = GraphStates(edge_state, node_state, global_state)

    return graph_output, graph_state
def populate_test_data(self, max_size):
    """Populates the class fields with data used for the tests.

    This creates a batch of graphs covering every combination of a number of
    nodes in `[1, max_size]` and a number of edges in `[0, max_size]`, plus
    one empty graph with no nodes and no edges. That is
    `max_size * (max_size + 1) + 1` graphs in total.

    The node, edge and global states of the graphs are created with
    different dtypes and shapes. The graphs are stored both as dictionaries
    (in `self.graphs_dicts_in` without `n_node` and `n_edge`, and in
    `self.graphs_dicts_out` with these two fields filled) and as a
    corresponding numpy `graphs.GraphsTuple` in `self.reference_graph`.

    Args:
      max_size: The maximum number of nodes and edges (inclusive).
    """
    def _is_valid(counts):
        # Edges need at least one node to attach to.
        num_nodes, num_edges = counts
        return (num_nodes > 0) or (num_edges == 0)

    n_node, n_edge = zip(*list(
        filter(_is_valid, itertools.product(
            range(max_size + 1), range(max_size + 1)))))

    graphs_dicts = []
    nodes = []
    edges = []
    receivers = []
    senders = []
    globals_ = []

    def _make_default_state(shape, dtype):
        return np.arange(np.prod(shape)).reshape(shape).astype(dtype)

    for i, (n_node_, n_edge_) in enumerate(zip(n_node, n_edge)):
        n = _make_default_state([n_node_, 7, 11], "f4") + i * 100.
        e = _make_default_state([n_edge_, 13, 14], np.float64) + i * 100. + 1000.
        # Node ids are local to graph i, so they wrap modulo its node count.
        r = _make_default_state([n_edge_], np.int32) % n_node_
        s = (_make_default_state([n_edge_], np.int32) + 1) % n_node_
        g = _make_default_state([5, 3], "f4") - i * 100. - 1000.

        nodes.append(n)
        edges.append(e)
        receivers.append(r)
        senders.append(s)
        globals_.append(g)
        graphs_dict = dict(nodes=n, edges=e, receivers=r, senders=s, globals=g)
        graphs_dicts.append(graphs_dict)

    # Graphs dicts without n_node / n_edge (to be used as inputs).
    self.graphs_dicts_in = graphs_dicts
    # Graphs dicts with n_node / n_edge (to be checked against outputs).
    self.graphs_dicts_out = []
    for dict_ in self.graphs_dicts_in:
        completed_dict = dict_.copy()
        completed_dict["n_node"] = completed_dict["nodes"].shape[0]
        completed_dict["n_edge"] = completed_dict["edges"].shape[0]
        self.graphs_dicts_out.append(completed_dict)

    # Per-edge offsets turning local node ids into ids in the batched graph.
    # pylint: disable=protected-access
    offset = utils_np._compute_stacked_offsets(n_node, n_edge)
    # pylint: enable=protected-access
    self.reference_graph = graphs.GraphsTuple(**dict(
        nodes=np.concatenate(nodes, axis=0),
        edges=np.concatenate(edges, axis=0),
        receivers=np.concatenate(receivers, axis=0) + offset,
        senders=np.concatenate(senders, axis=0) + offset,
        globals=np.stack(globals_),
        n_node=np.array(n_node),
        n_edge=np.array(n_edge)))
def test_correct_signature(
        self, dynamic_num_nodes, dynamic_num_edges, dynamic_num_graphs,
        batched, replace_globals_with_constant):
    """Tests that the correct spec is created when using different options."""
    if batched:
        input_data_dicts = [self.graphs_dicts[1], self.graphs_dicts[2]]
    else:
        input_data_dicts = [self.graphs_dicts[1]]
    graph = utils_np.data_dicts_to_graphs_tuple(input_data_dicts)
    num_graphs = len(input_data_dicts)
    num_edges = sum(graph.n_edge).item()
    num_nodes = sum(graph.n_node).item()

    # Manually setting edges and globals fields to give some variety in
    # testing situations.
    # Making edges have rank 1 (one scalar per edge) to cover rank-1 features.
    graph = graph.replace(edges=np.zeros(num_edges))

    # Make a constant field.
    if replace_globals_with_constant:
        graph = graph.replace(globals=np.array(0.0, dtype=np.float32))

    # Pass the options by keyword. Positional arguments misbind when the
    # function has an extra parameter before them (e.g. `with_batch_dim`).
    spec_signature = utils_tf.specs_from_graphs_tuple(
        graph,
        dynamic_num_graphs=dynamic_num_graphs,
        dynamic_num_nodes=dynamic_num_nodes,
        dynamic_num_edges=dynamic_num_edges)

    # Captures if nodes/edges will be dynamic either due to dynamic nodes/edges
    # or dynamic graphs.
    dynamic_nodes_or_graphs = dynamic_num_nodes or dynamic_num_graphs
    dynamic_edges_or_graphs = dynamic_num_edges or dynamic_num_graphs

    num_edges = None if dynamic_edges_or_graphs else num_edges
    num_nodes = None if dynamic_nodes_or_graphs else num_nodes
    num_graphs = None if dynamic_num_graphs else num_graphs

    if replace_globals_with_constant:
        expected_globals_shape = []
    else:
        expected_globals_shape = [num_graphs,] + test_utils.GLOBALS_DIMS

    expected_answer = graphs.GraphsTuple(
        nodes=tf.TensorSpec(
            shape=[num_nodes,] + test_utils.NODES_DIMS,
            dtype=tf.float32),
        edges=tf.TensorSpec(
            shape=[num_edges],  # Edges were manually replaced to have dim 1.
            dtype=tf.float64),
        n_node=tf.TensorSpec(
            shape=[num_graphs],
            dtype=tf.int32),
        n_edge=tf.TensorSpec(
            shape=[num_graphs],
            dtype=tf.int32),
        globals=tf.TensorSpec(
            shape=expected_globals_shape,
            dtype=tf.float32),
        receivers=tf.TensorSpec(
            shape=[num_edges],
            dtype=tf.int32),
        senders=tf.TensorSpec(
            shape=[num_edges],
            dtype=tf.int32),
    )

    with self.subTest(name="Correct Type."):
        self.assertIsInstance(spec_signature, graphs.GraphsTuple)

    with self.subTest(name="Correct Signature."):
        self.assertAllEqual(spec_signature, expected_answer)
def get_graph(input_graphs, index, name="get_graph"):
    """Indexes into a graph.

    Given a `graphs.GraphsTuple` containing `Tensor`s and an index (either an
    `int` or a `slice`), index into the nodes, edges and globals to extract
    the graphs specified by the index, and return them in another
    `graphs.GraphsTuple` containing `Tensor`s. Negative indices count from
    the end. Slices must be contiguous (step `None` or 1).

    Args:
      input_graphs: A `graphs.GraphsTuple` containing `Tensor`s.
      index: An `int` or a `slice`, to index into `graph`. `index` should be
        compatible with the number of graphs in `graphs`.
      name: (string, optional) A name for the operation.

    Returns:
      A `graphs.GraphsTuple` containing `Tensor`s, made of the extracted
      graph(s).

    Raises:
      TypeError: if `index` is not an `int` or a `slice`.
      ValueError: if `index` is a slice with a step other than 1.
    """

    def safe_slice_none(value, slice_):
        if value is None:
            return value
        return value[slice_]

    if isinstance(index, int):
        # For index -1, `index + 1` is 0 and slice(-1, 0) is empty.
        graph_slice = slice(index, None if index == -1 else index + 1)
    elif isinstance(index, slice):
        graph_slice = index
    else:
        raise TypeError("unsupported type: %s" % type(index))

    # Node/edge ranges are computed as contiguous spans, so a stepped slice
    # would silently return the wrong nodes and edges.
    if graph_slice.step not in (None, 1):
        raise ValueError(
            "only contiguous slices are supported, got step %r"
            % (graph_slice.step,))

    # Graphs preceding the selection. slice(0, None) would count *all* graphs,
    # so an omitted start must become 0.
    start = 0 if graph_slice.start is None else graph_slice.start
    start_slice = slice(0, start)

    with tf.name_scope(name):
        start_node_index = tf.reduce_sum(
            input_graphs.n_node[start_slice], name="start_node_index")
        start_edge_index = tf.reduce_sum(
            input_graphs.n_edge[start_slice], name="start_edge_index")
        end_node_index = start_node_index + tf.reduce_sum(
            input_graphs.n_node[graph_slice], name="end_node_index")
        end_edge_index = start_edge_index + tf.reduce_sum(
            input_graphs.n_edge[graph_slice], name="end_edge_index")
        nodes_slice = slice(start_node_index, end_node_index)
        edges_slice = slice(start_edge_index, end_edge_index)

        sliced_graphs_dict = {}

        # Per-graph fields are indexed by the graph slice directly.
        for field in set(GRAPH_NUMBER_FIELDS) | {"globals"}:
            sliced_graphs_dict[field] = safe_slice_none(
                getattr(input_graphs, field), graph_slice)

        field = "nodes"
        sliced_graphs_dict[field] = safe_slice_none(
            getattr(input_graphs, field), nodes_slice)

        for field in {"edges", "senders", "receivers"}:
            sliced_graphs_dict[field] = safe_slice_none(
                getattr(input_graphs, field), edges_slice)
            # Re-base node ids so they index into the sliced nodes.
            if (field in {"senders", "receivers"} and
                    sliced_graphs_dict[field] is not None):
                sliced_graphs_dict[field] = (
                    sliced_graphs_dict[field] - start_node_index)

        return graphs.GraphsTuple(**sliced_graphs_dict)