Beispiel #1
0
def data_dicts_to_graphs_tuple(data_dicts, name="data_dicts_to_graphs_tuple"):
  """Creates a `graphs.GraphsTuple` containing tensors from data dicts.

  All dictionaries must have exactly the same set of keys with non-`None`
  values associated to them. Moreover, this set of keys must define a valid
  graph (i.e. if the `EDGES` are `None`, the `SENDERS` and `RECEIVERS` must be
  `None`, and `SENDERS` and `RECEIVERS` can only be `None` both at the same
  time). The values associated with a key must be convertible to `Tensor`s,
  for instance python lists, numpy arrays, or Tensorflow `Tensor`s.

  This method may perform a memory copy.

  The `RECEIVERS`, `SENDERS`, `N_NODE` and `N_EDGE` fields are cast to
  `np.int32` type.

  Args:
    data_dicts: An iterable of data dictionaries with keys in `ALL_FIELDS`.
      Any iterable is accepted, including a one-shot generator. The input
      dictionaries themselves are not modified; keys missing from a dictionary
      are treated as `None`.
    name: (string, optional) A name for the operation.

  Returns:
    A `graphs.GraphsTuple` representing the graphs in `data_dicts`.
  """
  # Work on shallow copies collected in a list: `data_dicts` is traversed
  # several times below (key filling, validation, conversion), which would
  # silently exhaust a generator, and filling in missing keys must not leak
  # `None` entries into the caller's dictionaries.
  data_dicts = [dict(d) for d in data_dicts]
  for data_dict in data_dicts:
    for key in ALL_FIELDS:
      data_dict.setdefault(key, None)
  utils_np._check_valid_sets_of_keys(data_dicts)  # pylint: disable=protected-access
  with tf.name_scope(name):
    data_dicts = _to_compatible_data_dicts(data_dicts)
    return graphs.GraphsTuple(**_concatenate_data_dicts(data_dicts))
Beispiel #2
0
def get_graph(input_graphs, index):
  """Indexes into a graph.

  Given a `graphs.GraphsTuple` containing numpy arrays and an index (either
  an `int` or a `slice`), index into the nodes, edges and globals to extract the
  graphs specified by the index, and return them in another instance of
  `graphs.GraphsTuple` containing numpy arrays.

  Args:
    input_graphs: A `graphs.GraphsTuple` containing numpy arrays.
    index: An `int` or a `slice`, to index into `input_graphs`. `index` should
      be compatible with the number of graphs in `input_graphs`. Negative
      integers count from the end, as with Python sequences.

  Returns:
    A `graphs.GraphsTuple` containing numpy arrays, made of the extracted
      graph(s).

  Raises:
    TypeError: if `index` is not an `int` or a `slice`.
  """
  if isinstance(index, int):
    # `index + 1 or None` keeps `-1` meaningful: slice(-1, 0) is always empty,
    # whereas slice(-1, None) selects the last graph.
    graph_slice = slice(index, index + 1 or None)
  elif isinstance(index, slice):
    graph_slice = index
  else:
    raise TypeError("unsupported type: %s" % type(index))
  data_dicts = graphs_tuple_to_data_dicts(input_graphs)[graph_slice]
  return graphs.GraphsTuple(**_concatenate_data_dicts(data_dicts))
Beispiel #3
0
def _build_placeholders_from_specs(dtypes,
                                   shapes,
                                   force_dynamic_num_graphs=True):
  """Creates a `graphs.GraphsTuple` of placeholders with `dtypes` and `shapes`.

  The dtypes and shapes arguments are instances of `graphs.GraphsTuple` that
  contain dtypes and shapes, or `None` values for the fields for which no
  placeholder should be created. The leading dimension of every field in
  `GRAPH_DATA_FIELDS` is always made dynamic (`None`), because the numbers of
  nodes and edges can vary.
  If `force_dynamic_num_graphs` is True, then the number of graphs is assumed to
  be dynamic and all fields leading dimensions are set to `None`.
  If `force_dynamic_num_graphs` is False, then the leading dimensions of the
  remaining fields (those not in `GRAPH_DATA_FIELDS`) are kept as given.

  Args:
    dtypes: A `graphs.GraphsTuple` that contains `tf.dtype`s or `None`s.
    shapes: A `graphs.GraphsTuple` that contains `list`s of integers,
      `tf.TensorShape`s, or `None`s.
    force_dynamic_num_graphs: A `bool` that forces the batch dimension to be
      dynamic. Defaults to `True`.

  Returns:
    A `graphs.GraphsTuple` containing placeholders, with `None` for the fields
    whose dtype and shape are both `None`.

  Raises:
    ValueError: If exactly one of the dtype and shape of a field is `None`, or
      if a shape has rank 0 (or is empty).
  """
  dct = {}
  for field in ALL_FIELDS:
    dtype = getattr(dtypes, field)
    shape = getattr(shapes, field)
    if dtype is None or shape is None:
      if not (shape is None and dtype is None):
        raise ValueError(
            "only one of dtype and shape are None for field {}".format(field))
      dct[field] = None
      continue
    # Check the rank on the list form: a `tf.TensorShape` of known rank is
    # truthy even when it is a scalar shape, so testing `not shape` alone would
    # let rank 0 through and fail with an IndexError at `shape[0]` below.
    # A falsy `shape` (an empty sequence, or a shape of unknown rank) is
    # mapped to [] without being iterated.
    shape = list(shape) if shape else []
    if not shape:
      raise ValueError("Shapes must have at least rank 1")
    if field in GRAPH_DATA_FIELDS or force_dynamic_num_graphs:
      shape[0] = None
    dct[field] = tf.placeholder(dtype, shape=shape, name=field)

  return graphs.GraphsTuple(**dct)
Beispiel #4
0
def data_dicts_to_graphs_tuple(data_dicts):
  """Constructs a `graphs.GraphsTuple` from an iterable of data dicts.

  The graphs represented by the `data_dicts` argument are batched to form a
  single instance of `graphs.GraphsTuple` containing numpy arrays.

  Args:
    data_dicts: An iterable of dictionaries with keys `GRAPH_DATA_FIELDS`, plus,
      potentially, a subset of `GRAPH_NUMBER_FIELDS`. The NODES and EDGES fields
      should be numpy arrays of rank at least 2, while the RECEIVERS, SENDERS
      are numpy arrays of rank 1 and same dimension as the EDGES field first
      dimension. The GLOBALS field is a numpy array of rank at least 1.
      Missing `GRAPH_DATA_FIELDS` keys are treated as `None`. Any iterable is
      accepted, including a one-shot generator; the input dictionaries
      themselves are not modified.

  Returns:
    An instance of `graphs.GraphsTuple` containing numpy arrays. The
    `RECEIVERS`, `SENDERS`, `N_NODE` and `N_EDGE` fields are cast to `np.int32`
    type.
  """
  # Work on shallow copies collected in a list: `data_dicts` is traversed
  # several times below (key filling, validation, conversion), which would
  # silently exhaust a generator, and filling in missing keys must not leak
  # `None` entries into the caller's dictionaries.
  data_dicts = [dict(d) for d in data_dicts]
  for data_dict in data_dicts:
    for key in graphs.GRAPH_DATA_FIELDS:
      data_dict.setdefault(key, None)
  _check_valid_sets_of_keys(data_dicts)
  data_dicts = _to_compatible_data_dicts(data_dicts)
  return graphs.GraphsTuple(**_concatenate_data_dicts(data_dicts))
Beispiel #5
0
def get_graph(input_graphs, index, name="get_graph"):
  """Indexes into a graph.

  Given a `graphs.GraphsTuple` containing `Tensor`s and an index (either
  an `int` or a `slice`), index into the nodes, edges and globals to extract the
  graphs specified by the index, and return them in another instance of
  `graphs.GraphsTuple` containing `Tensor`s.

  Args:
    input_graphs: A `graphs.GraphsTuple` containing `Tensor`s.
    index: An `int` or a `slice`, to index into `input_graphs`. `index` should
      be compatible with the number of graphs in `input_graphs`. Negative values
      count from the end. Slices must select consecutive graphs (step `None`
      or 1).
    name: (string, optional) A name for the operation.

  Returns:
    A `graphs.GraphsTuple` containing `Tensor`s, made of the extracted
      graph(s). `senders` and `receivers` are re-based so that they index into
      the returned `nodes`.

  Raises:
    TypeError: if `index` is not an `int` or a `slice`.
    ValueError: if `index` is a `slice` whose step is neither `None` nor 1.
  """

  def safe_slice_none(value, slice_):
    """Returns `value[slice_]`, or `None` when `value` is `None`."""
    if value is None:
      return value
    return value[slice_]

  if isinstance(index, int):
    # `index + 1 or None` keeps `-1` meaningful: slice(-1, 0) is always empty,
    # whereas slice(-1, None) selects the last graph.
    graph_slice = slice(index, index + 1 or None)
  elif isinstance(index, slice):
    graph_slice = index
  else:
    raise TypeError("unsupported type: %s" % type(index))

  if graph_slice.step not in (None, 1):
    # Nodes and edges are extracted below as one contiguous range, which only
    # lines up with the per-graph fields when consecutive graphs are selected.
    raise ValueError(
        "get_graph only supports slices with step None or 1, got step "
        "{}".format(graph_slice.step))

  # The graphs preceding the selection. A `None` start means "from the first
  # graph", so the preceding range must be empty; slice(0, None) would instead
  # cover every graph and push the node/edge offsets past the selection.
  start = 0 if graph_slice.start is None else graph_slice.start
  start_slice = slice(0, start)

  with tf.name_scope(name):
    # Node/edge offsets are cumulative sums of `n_node`/`n_edge`, i.e. the
    # nodes and edges of the selected graphs are taken as one contiguous range.
    start_node_index = tf.reduce_sum(
        input_graphs.n_node[start_slice], name="start_node_index")
    start_edge_index = tf.reduce_sum(
        input_graphs.n_edge[start_slice], name="start_edge_index")
    end_node_index = start_node_index + tf.reduce_sum(
        input_graphs.n_node[graph_slice], name="end_node_index")
    end_edge_index = start_edge_index + tf.reduce_sum(
        input_graphs.n_edge[graph_slice], name="end_edge_index")
    nodes_slice = slice(start_node_index, end_node_index)
    edges_slice = slice(start_edge_index, end_edge_index)

    sliced_graphs_dict = {}

    # Per-graph fields are indexed directly by the graph slice.
    for field in set(GRAPH_NUMBER_FIELDS) | {"globals"}:
      sliced_graphs_dict[field] = safe_slice_none(
          getattr(input_graphs, field), graph_slice)

    field = "nodes"
    sliced_graphs_dict[field] = safe_slice_none(
        getattr(input_graphs, field), nodes_slice)

    for field in {"edges", "senders", "receivers"}:
      sliced_graphs_dict[field] = safe_slice_none(
          getattr(input_graphs, field), edges_slice)
      if (field in {"senders", "receivers"} and
          sliced_graphs_dict[field] is not None):
        # Shift node indices so they refer to positions in the sliced nodes.
        sliced_graphs_dict[field] = sliced_graphs_dict[field] - start_node_index

    return graphs.GraphsTuple(**sliced_graphs_dict)