def data_dicts_to_graphs_tuple(data_dicts, name="data_dicts_to_graphs_tuple"):
  """Creates a `graphs.GraphsTuple` containing tensors from data dicts.

  All dictionaries must have exactly the same set of keys with non-`None`
  values associated to them. Moreover, this set of keys must define a valid
  graph (i.e. if the `EDGES` are `None`, the `SENDERS` and `RECEIVERS` must be
  `None`, and `SENDERS` and `RECEIVERS` can only be `None` both at the same
  time). The values associated with a key must be convertible to `Tensor`s,
  for instance python lists, numpy arrays, or Tensorflow `Tensor`s.

  This method may perform a memory copy.

  The `RECEIVERS`, `SENDERS`, `N_NODE` and `N_EDGE` fields are cast to
  `np.int32` type.

  The input dictionaries are not modified: each one is shallow-copied before
  its missing fields are filled in with `None`.

  Args:
    data_dicts: An iterable of data dictionaries with keys in `ALL_FIELDS`.
      It may be a one-shot iterator (e.g. a generator); it is consumed once.
    name: (string, optional) A name for the operation.

  Returns:
    A `graphs.GraphsTuple` representing the graphs in `data_dicts`.
  """
  # Materialize the iterable exactly once: iterating a generator inside the
  # per-key loop below would exhaust it after the first key. The shallow
  # `dict` copies keep `setdefault` from mutating the caller's dictionaries.
  data_dicts = [dict(d) for d in data_dicts]
  for key in ALL_FIELDS:
    for data_dict in data_dicts:
      data_dict.setdefault(key, None)
  utils_np._check_valid_sets_of_keys(data_dicts)  # pylint: disable=protected-access
  with tf.name_scope(name):
    data_dicts = _to_compatible_data_dicts(data_dicts)
    return graphs.GraphsTuple(**_concatenate_data_dicts(data_dicts))
def concat(
    input_graphs,
    axis,
    use_edges=True,
    use_nodes=True,
    use_globals=True,
    name="graph_concat",
):
  """Concatenates a list of `graphs.GraphsTuple`s along `axis`.

  The EDGES, NODES and GLOBALS fields are each concatenated along `axis` with
  `_nested_concatenate`, unless the matching `use_*` flag is `False`, in which
  case that field of the first graph is reused unchanged.

  When `axis == 0` the graphs are stacked along the batch dimension: N_NODE
  and N_EDGE are concatenated, and RECEIVERS / SENDERS are concatenated and
  then shifted by the offsets returned by `_compute_stacked_offsets`. For any
  other axis, those four fields are taken from the first graph unchanged.

  Args:
    input_graphs: A non-empty list of `graphs.GraphsTuple`s which all have the
      same set of non-`None` fields.
    axis: The axis to concatenate along.
    use_edges: If `False`, keep the first graph's EDGES instead of
      concatenating.
    use_nodes: If `False`, keep the first graph's NODES instead of
      concatenating.
    use_globals: If `False`, keep the first graph's GLOBALS instead of
      concatenating.
    name: (string, optional) A name for the operation.

  Returns:
    The concatenated `graphs.GraphsTuple`, or `input_graphs[0]` itself when
    the list holds a single graph.

  Raises:
    ValueError: If `input_graphs` is empty, or if the fields which are `None`
      are not the same for all the graphs.
  """
  if not input_graphs:
    raise ValueError("List argument `input_graphs` is empty")
  utils_np._check_valid_sets_of_keys(  # pylint: disable=protected-access
      [graph._asdict() for graph in input_graphs])
  if len(input_graphs) == 1:
    return input_graphs[0]

  first = input_graphs[0]

  def _merge_field(field, enabled):
    # A disabled field is copied from the first graph rather than merged.
    if enabled:
      return _nested_concatenate(input_graphs, field, axis)
    return getattr(first, field)

  with tf.name_scope(name):
    # Keep the edges -> nodes -> globals order so ops are created in the same
    # sequence (and therefore get the same auto-generated names).
    merged_edges = _merge_field(EDGES, use_edges)
    merged_nodes = _merge_field(NODES, use_nodes)
    merged_globals = _merge_field(GLOBALS, use_globals)
    output = first.replace(
        nodes=merged_nodes, edges=merged_edges, globals=merged_globals)
    if axis != 0:
      return output

    # Batch-axis concatenation: the graph structure is merged as well.
    node_totals = tf.stack(
        [tf.reduce_sum(graph.n_node) for graph in input_graphs])
    edge_totals = tf.stack(
        [tf.reduce_sum(graph.n_edge) for graph in input_graphs])
    offsets = _compute_stacked_offsets(node_totals, edge_totals)

    merged_n_node = tf.concat(
        [graph.n_node for graph in input_graphs], axis=0, name="concat_n_node")
    merged_n_edge = tf.concat(
        [graph.n_edge for graph in input_graphs], axis=0, name="concat_n_edge")

    receiver_parts = [
        graph.receivers for graph in input_graphs
        if graph.receivers is not None
    ]
    merged_receivers = (
        tf.concat(receiver_parts, axis, name="concat_receivers") + offsets
        if receiver_parts else None)

    sender_parts = [
        graph.senders for graph in input_graphs if graph.senders is not None
    ]
    merged_senders = (
        tf.concat(sender_parts, axis, name="concat_senders") + offsets
        if sender_parts else None)

    return output.replace(
        receivers=merged_receivers,
        senders=merged_senders,
        n_node=merged_n_node,
        n_edge=merged_n_edge)
def concat(input_graphs, axis, name="graph_concat", *,
           use_edges=True, use_nodes=True, use_globals=True):
  """Concatenates graphs along a given axis.

  The NODES, EDGES and GLOBALS fields are concatenated along `axis` (if a
  field is `None` in the graphs, the result for that field is also `None`).
  Passing `use_nodes=False`, `use_edges=False` or `use_globals=False` skips
  the concatenation of that field and reuses the value from the first graph.

  If `axis` == 0, then the graphs are concatenated along the (underlying)
  batch dimension, i.e. the RECEIVERS, SENDERS, N_NODE and N_EDGE fields of
  the tuples are also concatenated together. RECEIVERS and SENDERS are
  additionally shifted by the offsets from `_compute_stacked_offsets`.

  If `axis` != 0, then there is an underlying assumption that the RECEIVERS,
  SENDERS, N_NODE and N_EDGE fields of the graphs in `input_graphs` all
  match. This is not checked by this op, and those fields are taken from the
  first graph.

  The graphs in `input_graphs` should have the same set of keys for which the
  corresponding field is not `None`.

  Args:
    input_graphs: A list of `graphs.GraphsTuple` objects containing `Tensor`s
      and satisfying the constraints outlined above.
    axis: An axis to concatenate on.
    name: (string, optional) A name for the operation.
    use_edges: (keyword-only, default `True`) If `False`, keep the EDGES of
      the first graph instead of concatenating them.
    use_nodes: (keyword-only, default `True`) If `False`, keep the NODES of
      the first graph instead of concatenating them.
    use_globals: (keyword-only, default `True`) If `False`, keep the GLOBALS
      of the first graph instead of concatenating them.

  Returns:
    A `graphs.GraphsTuple` holding the concatenated graphs, or
    `input_graphs[0]` itself when the list contains a single graph.

  Raises:
    ValueError: If `input_graphs` is an empty list, or if the fields which are
      `None` in `input_graphs` are not the same for all the graphs.
  """
  if not input_graphs:
    raise ValueError("List argument `input_graphs` is empty")
  utils_np._check_valid_sets_of_keys([gr._asdict() for gr in input_graphs])  # pylint: disable=protected-access
  if len(input_graphs) == 1:
    return input_graphs[0]

  first = input_graphs[0]

  def _concat_field(values, enabled, first_value, op_name):
    # Concatenates the non-`None` `values`; a disabled field keeps the value
    # of the first graph. The key check above ensures a field is either set
    # on every graph or on none, so an empty list means "all `None`".
    if not enabled:
      return first_value
    present = [value for value in values if value is not None]
    return tf.concat(present, axis, name=op_name) if present else None

  with tf.name_scope(name):
    # Same nodes -> edges -> globals op-creation order as before.
    nodes = _concat_field([gr.nodes for gr in input_graphs], use_nodes,
                          first.nodes, "concat_nodes")
    edges = _concat_field([gr.edges for gr in input_graphs], use_edges,
                          first.edges, "concat_edges")
    globals_ = _concat_field([gr.globals for gr in input_graphs], use_globals,
                             first.globals, "concat_globals")
    output = first.replace(nodes=nodes, edges=edges, globals=globals_)
    if axis != 0:
      return output

    # Batch-axis concatenation: merge the graph structure as well.
    n_node_per_tuple = tf.stack(
        [tf.reduce_sum(gr.n_node) for gr in input_graphs])
    n_edge_per_tuple = tf.stack(
        [tf.reduce_sum(gr.n_edge) for gr in input_graphs])
    offsets = _compute_stacked_offsets(n_node_per_tuple, n_edge_per_tuple)
    n_node = tf.concat(
        [gr.n_node for gr in input_graphs], axis=0, name="concat_n_node")
    n_edge = tf.concat(
        [gr.n_edge for gr in input_graphs], axis=0, name="concat_n_edge")

    receivers = [
        gr.receivers for gr in input_graphs if gr.receivers is not None
    ]
    receivers = (
        tf.concat(receivers, axis, name="concat_receivers") + offsets
        if receivers else None)
    senders = [gr.senders for gr in input_graphs if gr.senders is not None]
    senders = (
        tf.concat(senders, axis, name="concat_senders") + offsets
        if senders else None)
    return output.replace(
        receivers=receivers, senders=senders, n_node=n_node, n_edge=n_edge)