def data_dicts_to_graphs_tuple(data_dicts, name="data_dicts_to_graphs_tuple"):
  """Creates a `graphs.GraphsTuple` containing tensors from data dicts.

  All dictionaries must have exactly the same set of keys with non-`None`
  values associated to them. Moreover, this set of keys must define a valid
  graph (i.e. if the `EDGES` are `None`, the `SENDERS` and `RECEIVERS` must be
  `None`, and `SENDERS` and `RECEIVERS` can only be `None` both at the same
  time). The values associated with a key must be convertible to `Tensor`s,
  for instance python lists, numpy arrays, or Tensorflow `Tensor`s.

  This method may perform a memory copy.

  The `RECEIVERS`, `SENDERS`, `N_NODE` and `N_EDGE` fields are cast to
  `np.int32` type.

  Args:
    data_dicts: An iterable of data dictionaries with keys in `ALL_FIELDS`.
    name: (string, optional) A name for the operation.

  Returns:
    A `graphs.GraphsTuple` representing the graphs in `data_dicts`.
  """
  for key in ALL_FIELDS:
    for data_dict in data_dicts:
      data_dict.setdefault(key, None)
  utils_np._check_valid_sets_of_keys(data_dicts)  # pylint: disable=protected-access
  with tf.name_scope(name):
    data_dicts = _to_compatible_data_dicts(data_dicts)
    return graphs.GraphsTuple(**_concatenate_data_dicts(data_dicts))
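# A minimal usage sketch (not part of the original module): it shows that
# plain python lists are accepted and converted to `Tensor`s. The helper name
# `_example_data_dicts_to_tensors` and the feature sizes are illustrative.
def _example_data_dicts_to_tensors():
  data_dict = {
      "nodes": [[0.1, 0.2], [0.3, 0.4], [0.5, 0.6]],  # 3 nodes, 2 features each.
      "edges": [[1.0], [2.0]],                        # 2 edges, 1 feature each.
      "senders": [0, 1],                              # Edge source node indices.
      "receivers": [1, 2],                            # Edge target node indices.
      "globals": [0.0, 0.0, 1.0],                     # 3 global features.
  }
  # Batching the same graph twice yields a `GraphsTuple` of tensors whose
  # `nodes` has shape [6, 2] and whose `n_node` is [3, 3].
  return data_dicts_to_graphs_tuple([data_dict, dict(data_dict)])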
def get_graph(input_graphs, index):
  """Indexes into a graph.

  Given a `graphs.GraphsTuple` containing arrays and an index (either an `int`
  or a `slice`), index into the nodes, edges and globals to extract the graphs
  specified by the slice, and returns them into another instance of a
  `graphs.GraphsTuple` containing numpy arrays.

  Args:
    input_graphs: A `graphs.GraphsTuple` containing numpy arrays.
    index: An `int` or a `slice`, to index into `graph`. `index` should be
      compatible with the number of graphs in `graphs`.

  Returns:
    A `graphs.GraphsTuple` containing numpy arrays, made of the extracted
      graph(s).

  Raises:
    TypeError: if `index` is not an `int` or a `slice`.
  """
  if isinstance(index, int):
    graph_slice = slice(index, index + 1)
  elif isinstance(index, slice):
    graph_slice = index
  else:
    raise TypeError("unsupported type: %s" % type(index))
  data_dicts = graphs_tuple_to_data_dicts(input_graphs)[graph_slice]
  return graphs.GraphsTuple(**_concatenate_data_dicts(data_dicts))
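# A minimal usage sketch (not part of the original module): extracting graphs
# from a batched numpy `GraphsTuple`. The helper name and the argument
# `batched_graphs_np` are hypothetical.
def _example_get_graph_numpy(batched_graphs_np):
  # `batched_graphs_np` is assumed to hold numpy arrays for at least two graphs.
  second_graph = get_graph(batched_graphs_np, 1)              # One-graph GraphsTuple.
  first_two_graphs = get_graph(batched_graphs_np, slice(0, 2))  # Two-graph GraphsTuple.
  return second_graph, first_two_graphs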
def _build_placeholders_from_specs(dtypes,
                                   shapes,
                                   force_dynamic_num_graphs=True):
  """Creates a `graphs.GraphsTuple` of placeholders with `dtypes` and `shapes`.

  The dtypes and shapes arguments are instances of `graphs.GraphsTuple` that
  contain dtypes and shapes, or `None` values for the fields for which no
  placeholder should be created.

  The leading dimensions of the nodes and edges are dynamic because the
  numbers of nodes and edges can vary. If `force_dynamic_num_graphs` is
  `True`, then the number of graphs is assumed to be dynamic and all fields'
  leading dimensions are set to `None`. If `force_dynamic_num_graphs` is
  `False`, then the `GRAPH_NUMBER_FIELDS` leading dimensions are statically
  defined.

  Args:
    dtypes: A `graphs.GraphsTuple` that contains `tf.dtype`s or `None`s.
    shapes: A `graphs.GraphsTuple` that contains `list`s of integers,
      `tf.TensorShape`s, or `None`s.
    force_dynamic_num_graphs: A `bool` that forces the batch dimension to be
      dynamic. Defaults to `True`.

  Returns:
    A `graphs.GraphsTuple` containing placeholders.

  Raises:
    ValueError: The `None` fields in `dtypes` and `shapes` do not match.
  """
  dct = {}
  for field in ALL_FIELDS:
    dtype = getattr(dtypes, field)
    shape = getattr(shapes, field)
    if dtype is None or shape is None:
      if not (shape is None and dtype is None):
        raise ValueError(
            "only one of dtype and shape are None for field {}".format(field))
      dct[field] = None
    elif not shape:
      raise ValueError("Shapes must have at least rank 1")
    else:
      shape = list(shape)
      if field in GRAPH_DATA_FIELDS or force_dynamic_num_graphs:
        shape[0] = None
      dct[field] = tf.placeholder(dtype, shape=shape, name=field)
  return graphs.GraphsTuple(**dct)
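# A hypothetical usage sketch (the helper name, dtypes and feature sizes are
# illustrative, not part of the original module): building a `GraphsTuple` of
# placeholders for graphs with 3-d node features, 2-d edge features and 4-d
# globals, leaving the numbers of nodes, edges and graphs dynamic.
def _example_build_placeholders():
  dtypes = graphs.GraphsTuple(
      nodes=tf.float32, edges=tf.float32, globals=tf.float32,
      receivers=tf.int32, senders=tf.int32, n_node=tf.int32, n_edge=tf.int32)
  shapes = graphs.GraphsTuple(
      nodes=[None, 3], edges=[None, 2], globals=[None, 4],
      receivers=[None], senders=[None], n_node=[None], n_edge=[None])
  return _build_placeholders_from_specs(dtypes, shapes)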
def data_dicts_to_graphs_tuple(data_dicts):
  """Constructs a `graphs.GraphsTuple` from an iterable of data dicts.

  The graphs represented by the `data_dicts` argument are batched to form a
  single instance of `graphs.GraphsTuple` containing numpy arrays.

  Args:
    data_dicts: An iterable of dictionaries with keys `GRAPH_DATA_FIELDS`,
      plus, potentially, a subset of `GRAPH_NUMBER_FIELDS`. The NODES and
      EDGES fields should be numpy arrays of rank at least 2, while the
      RECEIVERS and SENDERS fields are numpy arrays of rank 1, with the same
      dimension as the first dimension of the EDGES field. The GLOBALS field
      is a numpy array of rank at least 1.

  Returns:
    An instance of `graphs.GraphsTuple` containing numpy arrays. The
    `RECEIVERS`, `SENDERS`, `N_NODE` and `N_EDGE` fields are cast to
    `np.int32` type.
  """
  for key in graphs.GRAPH_DATA_FIELDS:
    for data_dict in data_dicts:
      data_dict.setdefault(key, None)
  _check_valid_sets_of_keys(data_dicts)
  data_dicts = _to_compatible_data_dicts(data_dicts)
  return graphs.GraphsTuple(**_concatenate_data_dicts(data_dicts))
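# A minimal usage sketch (not part of the original module): batching two
# graphs with different numbers of nodes and edges into a single numpy
# `GraphsTuple`. The helper name `_example_batch_numpy_graphs` and the
# feature sizes are illustrative.
def _example_batch_numpy_graphs():
  import numpy as np  # Local import so the sketch is self-contained.
  graph_a = {
      "nodes": np.zeros((3, 2), dtype=np.float32),  # 3 nodes, 2 features each.
      "edges": np.zeros((2, 1), dtype=np.float32),  # 2 edges, 1 feature each.
      "senders": np.array([0, 1]),
      "receivers": np.array([1, 2]),
      "globals": np.zeros((4,), dtype=np.float32),  # 4 global features.
  }
  graph_b = {
      "nodes": np.zeros((5, 2), dtype=np.float32),
      "edges": np.zeros((4, 1), dtype=np.float32),
      "senders": np.array([0, 1, 2, 3]),
      "receivers": np.array([1, 2, 3, 4]),
      "globals": np.zeros((4,), dtype=np.float32),
  }
  batched = data_dicts_to_graphs_tuple([graph_a, graph_b])
  # batched.nodes has shape [8, 2]; batched.n_node is [3, 5] and batched.n_edge
  # is [2, 4], both as np.int32 arrays. The senders/receivers of the second
  # graph are offset by the first graph's node count.
  return batched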
def get_graph(input_graphs, index, name="get_graph"):
  """Indexes into a graph.

  Given a `graphs.GraphsTuple` containing `Tensor`s and an index (either an
  `int` or a `slice`), index into the nodes, edges and globals to extract the
  graphs specified by the slice, and returns them into another instance of a
  `graphs.GraphsTuple` containing `Tensor`s.

  Args:
    input_graphs: A `graphs.GraphsTuple` containing `Tensor`s.
    index: An `int` or a `slice`, to index into `graph`. `index` should be
      compatible with the number of graphs in `graphs`.
    name: (string, optional) A name for the operation.

  Returns:
    A `graphs.GraphsTuple` containing `Tensor`s, made of the extracted
      graph(s).

  Raises:
    TypeError: if `index` is not an `int` or a `slice`.
  """

  def safe_slice_none(value, slice_):
    if value is None:
      return value
    return value[slice_]

  if isinstance(index, int):
    graph_slice = slice(index, index + 1)
  elif isinstance(index, slice):
    graph_slice = index
  else:
    raise TypeError("unsupported type: %s" % type(index))
  start_slice = slice(0, graph_slice.start)

  with tf.name_scope(name):
    start_node_index = tf.reduce_sum(
        input_graphs.n_node[start_slice], name="start_node_index")
    start_edge_index = tf.reduce_sum(
        input_graphs.n_edge[start_slice], name="start_edge_index")
    end_node_index = start_node_index + tf.reduce_sum(
        input_graphs.n_node[graph_slice], name="end_node_index")
    end_edge_index = start_edge_index + tf.reduce_sum(
        input_graphs.n_edge[graph_slice], name="end_edge_index")
    nodes_slice = slice(start_node_index, end_node_index)
    edges_slice = slice(start_edge_index, end_edge_index)

    sliced_graphs_dict = {}

    for field in set(GRAPH_NUMBER_FIELDS) | {"globals"}:
      sliced_graphs_dict[field] = safe_slice_none(
          getattr(input_graphs, field), graph_slice)

    field = "nodes"
    sliced_graphs_dict[field] = safe_slice_none(
        getattr(input_graphs, field), nodes_slice)

    for field in {"edges", "senders", "receivers"}:
      sliced_graphs_dict[field] = safe_slice_none(
          getattr(input_graphs, field), edges_slice)
      if (field in {"senders", "receivers"} and
          sliced_graphs_dict[field] is not None):
        # Shift the node indices so they index into the sliced nodes.
        sliced_graphs_dict[field] = sliced_graphs_dict[field] - start_node_index

  return graphs.GraphsTuple(**sliced_graphs_dict)
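# A minimal usage sketch (not part of the original module): extracting a
# single graph and a contiguous range of graphs from a batched `GraphsTuple`
# of tensors. The helper name `_example_get_graph_tensors` and the argument
# `batched_graphs_tf` are hypothetical.
def _example_get_graph_tensors(batched_graphs_tf):
  # `batched_graphs_tf` is assumed to contain tensors for at least three graphs.
  third_graph = get_graph(batched_graphs_tf, 2, name="third_graph")
  first_two = get_graph(batched_graphs_tf, slice(0, 2), name="first_two")
  # In both results the senders/receivers are re-indexed relative to the
  # extracted nodes, so they start again from 0.
  return third_graph, first_two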