Example #1
def data_dicts_to_graphs_tuple(data_dicts, name="data_dicts_to_graphs_tuple"):
    """Creates a `graphs.GraphsTuple` containing tensors from data dicts.

   All dictionaries must have exactly the same set of keys with non-`None`
   values associated to them. Moreover, this set of this key must define a valid
   graph (i.e. if the `EDGES` are `None`, the `SENDERS` and `RECEIVERS` must be
   `None`, and `SENDERS` and `RECEIVERS` can only be `None` both at the same
   time). The values associated with a key must be convertible to `Tensor`s,
   for instance python lists, numpy arrays, or Tensorflow `Tensor`s.

   This method may perform a memory copy.

   The `RECEIVERS`, `SENDERS`, `N_NODE` and `N_EDGE` fields are cast to
   `np.int32` type.

  Args:
    data_dicts: An iterable of data dictionaries with keys in `ALL_FIELDS`.
    name: (string, optional) A name for the operation.

  Returns:
    A `graphs.GraphTuple` representing the graphs in `data_dicts`.
  """
    for key in ALL_FIELDS:
        for data_dict in data_dicts:
            data_dict.setdefault(key, None)
    utils_np._check_valid_sets_of_keys(data_dicts)  # pylint: disable=protected-access
    with tf.name_scope(name):
        data_dicts = _to_compatible_data_dicts(data_dicts)
        return graphs.GraphsTuple(**_concatenate_data_dicts(data_dicts))
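A minimal usage sketch, assuming the function above is the one exposed as
`utils_tf.data_dicts_to_graphs_tuple` in DeepMind's graph_nets library (the
module-level helpers such as `ALL_FIELDS` and `_concatenate_data_dicts` live
there); the feature sizes below are arbitrary:

import numpy as np
from graph_nets import utils_tf

# One graph with 4 nodes, 2 directed edges and a global feature vector.
data_dict = {
    "nodes": np.zeros((4, 3), dtype=np.float32),   # 4 nodes, 3 features each
    "edges": np.zeros((2, 5), dtype=np.float32),   # 2 edges, 5 features each
    "senders": np.array([0, 1]),        # edge i starts at node senders[i]
    "receivers": np.array([2, 3]),      # edge i ends at node receivers[i]
    "globals": np.zeros((7,), dtype=np.float32),
}

graphs_tuple = utils_tf.data_dicts_to_graphs_tuple([data_dict])
# graphs_tuple.n_node is [4] and graphs_tuple.n_edge is [2], both tf.int32.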
Example #2
def concat(
    input_graphs,
    axis,
    use_edges=True,
    use_nodes=True,
    use_globals=True,
    name="graph_concat",
):
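    """Returns an op that concatenates graphs along a given axis.

    This variant behaves like the plain `concat` in Example #3 below, but adds
    `use_edges`, `use_nodes` and `use_globals` flags: when a flag is `False`,
    the corresponding field is taken from the first graph in `input_graphs`
    rather than concatenated.
    """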
    if not input_graphs:
        raise ValueError("List argument `input_graphs` is empty")
    utils_np._check_valid_sets_of_keys([gr._asdict() for gr in input_graphs])
    if len(input_graphs) == 1:
        return input_graphs[0]

    with tf.name_scope(name):
        if use_edges:
            edges = _nested_concatenate(input_graphs, EDGES, axis)
        else:
            edges = getattr(input_graphs[0], EDGES)
        if use_nodes:
            nodes = _nested_concatenate(input_graphs, NODES, axis)
        else:
            nodes = getattr(input_graphs[0], NODES)
        if use_globals:
            globals_ = _nested_concatenate(input_graphs, GLOBALS, axis)
        else:
            globals_ = getattr(input_graphs[0], GLOBALS)

        output = input_graphs[0].replace(nodes=nodes,
                                         edges=edges,
                                         globals=globals_)
        if axis != 0:
            return output
        n_node_per_tuple = tf.stack(
            [tf.reduce_sum(gr.n_node) for gr in input_graphs])
        n_edge_per_tuple = tf.stack(
            [tf.reduce_sum(gr.n_edge) for gr in input_graphs])
        offsets = _compute_stacked_offsets(n_node_per_tuple, n_edge_per_tuple)
        n_node = tf.concat([gr.n_node for gr in input_graphs],
                           axis=0,
                           name="concat_n_node")
        n_edge = tf.concat([gr.n_edge for gr in input_graphs],
                           axis=0,
                           name="concat_n_edge")
        receivers = [
            gr.receivers for gr in input_graphs if gr.receivers is not None
        ]
        receivers = receivers or None
        if receivers:
            receivers = tf.concat(receivers, axis,
                                  name="concat_receivers") + offsets
        senders = [gr.senders for gr in input_graphs if gr.senders is not None]
        senders = senders or None
        if senders:
            senders = tf.concat(senders, axis, name="concat_senders") + offsets
        return output.replace(receivers=receivers,
                              senders=senders,
                              n_node=n_node,
                              n_edge=n_edge)
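A usage sketch for this flagged variant. `graph_a` and `graph_b` are
hypothetical `graphs.GraphsTuple`s with identical connectivity (same
`senders`, `receivers`, `n_node` and `n_edge`), which is what `axis != 0`
presumes:

# Concatenate node and edge features along the last (feature) axis, while
# keeping the globals of the first graph as-is via `use_globals=False`.
combined = concat([graph_a, graph_b], axis=-1, use_globals=False)
# combined.nodes holds the node features of both graphs side by side;
# combined.globals is graph_a.globals, untouched.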
Example #3
def concat(input_graphs, axis, name="graph_concat"):
    """Returns an op that concatenates graphs along a given axis.

  In all cases, the NODES, EDGES and GLOBALS dimension are concatenated
  along `axis` (if a fields is `None`, the concatenation is just a `None`).
  If `axis` == 0, then the graphs are concatenated along the (underlying) batch
  dimension, i.e. the RECEIVERS, SENDERS, N_NODE and N_EDGE fields of the tuples
  are also concatenated together.
  If `axis` != 0, then there is an underlying asumption that the receivers,
  SENDERS, N_NODE and N_EDGE fields of the graphs in `values` should all match,
  but this is not checked by this op.
  The graphs in `input_graphs` should have the same set of keys for which the
  corresponding fields is not `None`.

  Args:
    input_graphs: A list of `graphs.GraphsTuple` objects containing `Tensor`s
      and satisfying the constraints outlined above.
    axis: An axis to concatenate on.
    name: (string, optional) A name for the operation.

  Returns: An op that returns the concatenated graphs.

  Raises:
    ValueError: If `values` is an empty list, or if the fields which are `None`
      in `input_graphs` are not the same for all the graphs.
  """
    if not input_graphs:
        raise ValueError("List argument `input_graphs` is empty")
    utils_np._check_valid_sets_of_keys([gr._asdict() for gr in input_graphs])  # pylint: disable=protected-access
    if len(input_graphs) == 1:
        return input_graphs[0]
    nodes = [gr.nodes for gr in input_graphs if gr.nodes is not None]
    edges = [gr.edges for gr in input_graphs if gr.edges is not None]
    globals_ = [gr.globals for gr in input_graphs if gr.globals is not None]

    with tf.name_scope(name):
        nodes = tf.concat(nodes, axis, name="concat_nodes") if nodes else None
        edges = tf.concat(edges, axis, name="concat_edges") if edges else None
        if globals_:
            globals_ = tf.concat(globals_, axis, name="concat_globals")
        else:
            globals_ = None
        output = input_graphs[0].replace(nodes=nodes,
                                         edges=edges,
                                         globals=globals_)
        if axis != 0:
            return output
        n_node_per_tuple = tf.stack(
            [tf.reduce_sum(gr.n_node) for gr in input_graphs])
        n_edge_per_tuple = tf.stack(
            [tf.reduce_sum(gr.n_edge) for gr in input_graphs])
        offsets = _compute_stacked_offsets(n_node_per_tuple, n_edge_per_tuple)
        n_node = tf.concat([gr.n_node for gr in input_graphs],
                           axis=0,
                           name="concat_n_node")
        n_edge = tf.concat([gr.n_edge for gr in input_graphs],
                           axis=0,
                           name="concat_n_edge")
        receivers = [
            gr.receivers for gr in input_graphs if gr.receivers is not None
        ]
        receivers = receivers or None
        if receivers:
            receivers = tf.concat(receivers, axis,
                                  name="concat_receivers") + offsets
        senders = [gr.senders for gr in input_graphs if gr.senders is not None]
        senders = senders or None
        if senders:
            senders = tf.concat(senders, axis, name="concat_senders") + offsets
        return output.replace(receivers=receivers,
                              senders=senders,
                              n_node=n_node,
                              n_edge=n_edge)
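A minimal batching sketch, assuming this function is the one exposed as
`utils_tf.concat` in the graph_nets library, and that `graph_a` and `graph_b`
are hypothetical `graphs.GraphsTuple`s sharing the same set of non-`None`
fields:

from graph_nets import utils_tf

# axis=0 merges the two batches into one: features are stacked along the
# batch dimension, and graph_b's senders/receivers are shifted by the total
# number of nodes in graph_a so the edge indices still point at the right
# nodes.
batched = utils_tf.concat([graph_a, graph_b], axis=0)
# batched.n_node == tf.concat([graph_a.n_node, graph_b.n_node], axis=0)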