Example #1
def get_graph_tuple(nodes, globals=None):
    """
	Helper function to create a dict w/ relevant fields for graph_data
	:param nodes: (Tensor) batch of node values (batch, num_kpts, node_feature_dims)
	:param globals: (Tensor) batch of global values (batch, 1, global_feature_dims)
	:return:
		graph_tuple: graph.GraphTuple type
	"""
    nodes_shape = nodes.shape
    batch_size = nodes_shape[0]
    num_nodes = tf.ones([batch_size], dtype=tf.int32)
    num_edges = tf.ones([batch_size], dtype=tf.int32)
    # defining num_nodes & num_edges for each sample in a batch of input graphs
    b_num_nodes = nodes_shape[1] * num_nodes
    b_num_edges = (nodes_shape[1]**2) * num_edges
    # reshaping (b, num_nodes, dims) -> (b*num_nodes, dims)
    nodes = tf.reshape(nodes,
                       [nodes_shape[0] * nodes_shape[1], nodes_shape[2]])
    if globals is not None:
        globals_shape = globals.shape
        globals = tf.reshape(
            globals, [globals_shape[0] * globals_shape[1], globals_shape[2]])
    # GraphsTuple is a namedtuple: all seven fields must be supplied.
    graph_tuple = graphs.GraphsTuple(nodes=nodes,
                                     globals=globals,
                                     edges=None,
                                     n_node=b_num_nodes,
                                     n_edge=b_num_edges,
                                     senders=None,
                                     receivers=None)
    return graph_tuple
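A quick smoke test of the helper above (a minimal sketch; the batch and feature sizes are illustrative, and these imports are the ones assumed by the snippets throughout this page):

import tensorflow as tf
from graph_nets import graphs

nodes = tf.random.normal([8, 17, 64])  # (batch, num_kpts, node_feature_dims)
gt = get_graph_tuple(nodes)
print(gt.n_node)        # [17 17 ... 17], one entry per graph in the batch
print(gt.nodes.shape)   # (8 * 17, 64): nodes flattened across the batch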
Example #2
def dtype_shape_from_graphs_tuple(input_graph, with_batch_dim=False,
                                  with_padding=True, debug=False,
                                  with_fixed_size=False):
    graphs_tuple_dtype = {}
    graphs_tuple_shape = {}

    edge_dim_fields = [graphs.EDGES, graphs.SENDERS, graphs.RECEIVERS]
    for field_name in graphs.ALL_FIELDS:
        field_sample = getattr(input_graph, field_name)
        shape = list(field_sample.shape)
        dtype = field_sample.dtype

        if not with_fixed_size and shape and not with_padding:
            if with_batch_dim:
                shape[1] = None
            else:
                if field_name == graphs.NODES or field_name in edge_dim_fields:
                    shape[0] = None

        graphs_tuple_dtype[field_name] = dtype
        graphs_tuple_shape[field_name] = tf.TensorShape(shape)
        if debug:
            print(field_name, shape, dtype)

    return graphs.GraphsTuple(**graphs_tuple_dtype), graphs.GraphsTuple(
        **graphs_tuple_shape)
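One plausible consumer of the returned (dtypes, shapes) pair is tf.data.Dataset.from_generator, which accepts nested structures such as a GraphsTuple; a sketch, where sample_graph and graph_generator are assumptions:

# Hypothetical wiring into tf.data; `graph_generator` is assumed to yield
# GraphsTuples matching the sample used to derive the dtypes and shapes.
dtypes, shapes = dtype_shape_from_graphs_tuple(sample_graph, with_padding=False)
dataset = tf.data.Dataset.from_generator(
    graph_generator, output_types=dtypes, output_shapes=shapes)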
Example #3
File: graph.py Project: xju2/root_gnn
def parse_tfrec_function(example_proto):
    features_description = dict(
        [(key+"_IN",  tf.io.FixedLenFeature([], tf.string)) for key in graphs.ALL_FIELDS] + 
        [(key+"_OUT", tf.io.FixedLenFeature([], tf.string)) for key in graphs.ALL_FIELDS])

    example = tf.io.parse_single_example(example_proto, features_description)
    input_dd = graphs.GraphsTuple(**dict([(key, tf.io.parse_tensor(example[key+"_IN"], graph_types[key]))
        for key in graphs.ALL_FIELDS]))
    out_dd = graphs.GraphsTuple(**dict([(key, tf.io.parse_tensor(example[key+"_OUT"], graph_types[key]))
        for key in graphs.ALL_FIELDS]))
    return input_dd, out_dd
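The parser is meant to be mapped over a TFRecord dataset; a minimal sketch, assuming `filenames` points at records written with the matching `_IN`/`_OUT` keys and that the module-level `graph_types` maps each field name to its dtype:

# Decode (input, target) graph pairs from TFRecord files.
dataset = tf.data.TFRecordDataset(filenames)
dataset = dataset.map(parse_tfrec_function)
for input_graph, target_graph in dataset.take(1):
    print(input_graph.n_node, target_graph.n_node)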
Example #4
 def test_map_field_default_value(self):
   """Tests the default value for the `fields` argument."""
   graph = graphs.GraphsTuple(**self.graph)
   mapped_fields = []
   graph = graph.map(mapped_fields.append)
   self.assertListEqual(sorted(mapped_fields),
                        sorted([graphs.EDGES, graphs.GLOBALS, graphs.NODES]))
Example #5
 def object_encoder_graph(self, encodings_per_object):
     """Function to make a graph object from the object oriented encoding
     Inputs:
         encodings_per_object: tensor containing the embeddings per object slot (size : [batch, n_objects, embedding dim])
     Outputs:
         graph_representation: fully conected graph with as node attributes the object encodings (size : [batch, graph])
     """
     # Specify the number of nodes and edges implied by to the passed n_objects and the encodigs for each object
     n_nodes = tf.tile(tf.constant([self.n_objects]),
                       tf.shape(encodings_per_object)[0:1],
                       name='n_nodes')
     n_edges = tf.tile(tf.constant([0]),
                       tf.shape(encodings_per_object)[0:1],
                       name='n_edges')
     # put the node_attributes in the correct shape to make graph object:
     node_attributes = tf.reshape(encodings_per_object, [
         tf.shape(encodings_per_object)[0] *
         tf.shape(encodings_per_object)[1], self.state_dim_embedding
     ])
     # make graph object with specified node attributes:
     graph = graphs.GraphsTuple(nodes=node_attributes,
                                edges=None,
                                globals=None,
                                receivers=None,
                                senders=None,
                                n_node=n_nodes,
                                n_edge=n_edges)
     # Connect all nodes to the other nodes (i.e. make fully connected):
     fully_connected_graph = utils_tf.fully_connect_graph_dynamic(
         graph, exclude_self_edges=False)
     # Make it runnable in TF because None's are used:
     runnable_fc_graph = utils_tf.make_runnable_in_session(
         fully_connected_graph)
     return runnable_fc_graph
Example #6
    def _build(self, input_op):
        receivers = input_op.receivers
        senders = input_op.senders
        n_node = input_op.n_node
        n_edge = input_op.n_edge

        latent = self._encoder(input_op)
        output_ops = []

        for i in range(self._num_processing_steps):
            for j in range(self._proc_hops[i]):
                latent = self._core(latent)

            decoded_op = self._decoder(latent)
            output_ops.append(decoded_op)

        stacked_edges = tf.stack([g.edges for g in output_ops], axis=1)
        stacked_nodes = tf.stack([g.nodes for g in output_ops], axis=1)
        stacked_globals = tf.stack([g.globals for g in output_ops], axis=1)

        stacked_globals = tf.reshape(stacked_globals, (-1, self._n_stacked))
        stacked_edges = tf.reshape(stacked_edges, (-1, self._n_stacked))
        stacked_nodes = tf.reshape(stacked_nodes, (-1, self._n_stacked))

        feature_graph = graphs.GraphsTuple(nodes=stacked_nodes,
                                           edges=stacked_edges,
                                           globals=stacked_globals,
                                           receivers=receivers,
                                           senders=senders,
                                           n_node=n_node,
                                           n_edge=n_edge)

        return self._output_transform(feature_graph)
Example #7
    def generator():
        labeled_indices = arrays[f"{prefix}_indices"]
        if ratio_unlabeled_data_to_labeled_data > 0:
            num_unlabeled_data_to_add = int(
                ratio_unlabeled_data_to_labeled_data *
                labeled_indices.shape[0])
            unlabeled_indices = np.random.choice(
                NUM_PAPERS, size=num_unlabeled_data_to_add, replace=False)
            root_node_indices = np.concatenate(
                [labeled_indices, unlabeled_indices])
        else:
            root_node_indices = labeled_indices
        if shuffle_indices:
            root_node_indices = root_node_indices.copy()
            np.random.shuffle(root_node_indices)

        for index in root_node_indices:
            graph = sub_sampler.subsample_graph(
                index,
                arrays["author_institution_index"],
                arrays["institution_author_index"],
                arrays["author_paper_index"],
                arrays["paper_author_index"],
                arrays["paper_paper_index"],
                arrays["paper_paper_index_t"],
                paper_years=arrays["paper_year"],
                max_nodes=max_nodes,
                max_edges=max_edges,
                **subsampler_kwargs)

            graph = add_nodes_label(graph, arrays["paper_label"])
            graph = add_nodes_year(graph, arrays["paper_year"])
            graph = tf_graphs.GraphsTuple(*graph)
            yield graph
Example #8
def get_graph(input_graphs, index):
    """Indexes into a graph.

    Given a `graphs.GraphsTuple` containing numpy arrays and an index (either
    an `int` or a `slice`), index into the nodes, edges and globals to extract
    the graphs specified by the slice, and return them in another instance of a
    `graphs.GraphsTuple` containing numpy arrays.

    Args:
      input_graphs: A `graphs.GraphsTuple` containing numpy arrays.
      index: An `int` or a `slice`, to index into `graph`. `index` should be
        compatible with the number of graphs in `graphs`.

    Returns:
      A `graphs.GraphsTuple` containing numpy arrays, made of the extracted
        graph(s).

    Raises:
      TypeError: if `index` is not an `int` or a `slice`.
    """
    if isinstance(index, int):
        graph_slice = slice(index, index + 1)
    elif isinstance(index, slice):
        graph_slice = index
    else:
        raise TypeError("unsupported type: %s" % type(index))
    data_dicts = graphs_tuple_to_data_dicts(input_graphs)[graph_slice]
    return graphs.GraphsTuple(**_concatenate_data_dicts(data_dicts))
Example #9
def data_dicts_to_graphs_tuple(data_dicts, name="data_dicts_to_graphs_tuple"):
    """Creates a `graphs.GraphsTuple` containing tensors from data dicts.

   All dictionaries must have exactly the same set of keys with non-`None`
   values associated to them. Moreover, this set of keys must define a valid
   graph (i.e. if the `EDGES` are `None`, the `SENDERS` and `RECEIVERS` must be
   `None`, and `SENDERS` and `RECEIVERS` can only be `None` both at the same
   time). The values associated with a key must be convertible to `Tensor`s,
   for instance python lists, numpy arrays, or Tensorflow `Tensor`s.

   This method may perform a memory copy.

   The `RECEIVERS`, `SENDERS`, `N_NODE` and `N_EDGE` fields are cast to
   `np.int32` type.

  Args:
    data_dicts: An iterable of data dictionaries with keys in `ALL_FIELDS`.
    name: (string, optional) A name for the operation.

  Returns:
    A `graphs.GraphsTuple` representing the graphs in `data_dicts`.
  """
    for key in ALL_FIELDS:
        for data_dict in data_dicts:
            data_dict.setdefault(key, None)
    utils_np._check_valid_sets_of_keys(data_dicts)  # pylint: disable=protected-access
    with tf.name_scope(name):
        data_dicts = _to_compatible_data_dicts(data_dicts)
        return graphs.GraphsTuple(**_concatenate_data_dicts(data_dicts))
Example #10
def data_dicts_to_graphs_tuple(data_dicts):
    """Constructs a `graphs.GraphsTuple` from an iterable of data dicts.

    The graphs represented by the `data_dicts` argument are batched to form a
    single instance of `graphs.GraphsTuple` containing numpy arrays.

    Args:
      data_dicts: An iterable of dictionaries with keys `GRAPH_DATA_FIELDS`, plus,
        potentially, a subset of `GRAPH_NUMBER_FIELDS`. The NODES and EDGES fields
        should be numpy arrays of rank at least 2, while the RECEIVERS and
        SENDERS are numpy arrays of rank 1 with the same dimension as the first
        dimension of the EDGES field. The GLOBALS field is a numpy array of
        rank at least 1.

    Returns:
      An instance of `graphs.GraphsTuple` containing numpy arrays. The
      `RECEIVERS`, `SENDERS`, `N_NODE` and `N_EDGE` fields are cast to `np.int32`
      type.
    """
    data_dicts = [dict(d) for d in data_dicts]
    for key in graphs.GRAPH_DATA_FIELDS:
        for data_dict in data_dicts:
            data_dict.setdefault(key, None)
    _check_valid_sets_of_keys(data_dicts)
    data_dicts = _to_compatible_data_dicts(data_dicts)
    return graphs.GraphsTuple(**_concatenate_data_dicts(data_dicts))
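For reference, a minimal call of the `utils_np` version above, batching two toy graphs (field sizes are illustrative):

import numpy as np
from graph_nets import utils_np

data_dicts = [
    dict(nodes=np.zeros([3, 2]), edges=np.zeros([2, 1]),
         senders=np.array([0, 1]), receivers=np.array([1, 2]),
         globals=np.zeros([1])),
    dict(nodes=np.zeros([2, 2]), edges=np.zeros([1, 1]),
         senders=np.array([0]), receivers=np.array([1]),
         globals=np.zeros([1])),
]
batched = utils_np.data_dicts_to_graphs_tuple(data_dicts)
print(batched.n_node, batched.n_edge)  # [3 2] [2 1]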
Example #11
  def test_self_attention(self):
    # Just one feature per node.
    values_np = np.arange(sum(self.N_NODE)) + 1.
    # Multiple heads: one with positive values, one with negative values.
    values_np = np.stack([values_np, values_np*-1.], axis=-1)
    # Multiple features per node, per head, at different scales.
    values_np = np.stack([values_np, values_np*0.1], axis=-1)
    values = tf.constant(values_np, dtype=tf.float32)

    keys_np = [
        [[0.3, 0.4]]*2,  # Irrelevant (only sender to one node)
        [[0.1, 0.5]]*2,  # Not used (is not a sender)
        [[1, 0], [0, 1]],
        [[0, 1], [1, 0]],
        [[1, 1], [1, 1]],
        [[0.4, 0.3]]*2,  # Not used (is not a sender)
        [[0.3, 0.2]]*2]  # Not used (is not a sender)
    keys = tf.constant(keys_np, dtype=tf.float32)

    queries_np = [
        [[0.2, 0.7]]*2,  # Not used (is not a receiver)
        [[0.3, 0.2]]*2,  # Irrelevant (only receives from one node)
        [[0.2, 0.8]]*2,  # Not used (is not a receiver)
        [[0.2, 0.4]]*2,  # Not used (is not a receiver)
        [[0.3, 0.9]]*2,  # Not used (is not a receiver)
        [[0, np.log(2)], [np.log(3), 0]],
        [[np.log(2), 0], [0, np.log(3)]]]
    queries = tf.constant(queries_np, dtype=tf.float32)

    attention_graph = graphs.GraphsTuple(
        nodes=None,
        edges=None,
        globals=None,
        receivers=tf.constant(self.RECEIVERS, dtype=tf.int32),
        senders=tf.constant(self.SENDERS, dtype=tf.int32),
        n_node=tf.constant(self.N_NODE, dtype=tf.int32),
        n_edge=tf.constant(self.N_EDGE, dtype=tf.int32),)

    self_attention = modules.SelfAttention()
    output_graph = self_attention(values, keys, queries, attention_graph)
    mixed_nodes = output_graph.nodes

    with self.test_session() as sess:
      mixed_nodes_output = sess.run(mixed_nodes)

    expected_mixed_nodes = [
        [[0., 0.], [0., 0.]],  # Does not receive any edges
        [[1., 0.1], [-1., -0.1]],  # Only receives from n0.
        [[0., 0.], [0., 0.]],  # Does not receive any edges
        [[0., 0.], [0., 0.]],  # Does not receive any edges
        [[0., 0.], [0., 0.]],  # Does not receive any edges
        [[11/3, 11/3*0.1],  # Head one, receives from n2(1/3) n3(2/3)
         [-15/4, -15/4*0.1]],  # Head two, receives from n2(1/4) n3(3/4)
        [[20/5, 20/5*0.1],   # Head one, receives from n2(2/5) n3(1/5) n4(2/5)
         [-28/7, -28/7*0.1]],  # Head two, receives from n2(3/7) n3(1/7) n4(3/7)
    ]

    self.assertAllClose(expected_mixed_nodes, mixed_nodes_output)
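As a worked check of the expectations above: node 5, head one, attends to n2 (key dot query = 0) and n3 (key dot query = log 2), so the softmax weights are 1/3 and 2/3 over node values 3 and 4, giving 3 * 1/3 + 4 * 2/3 = 11/3, which matches the first entry of the sixth row.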
Example #12
 def test_map_fields_as_expected(self, fields_to_map):
     """Tests that the fields are mapped are as expected."""
     graph = graphs.GraphsTuple(**self.graph)
     graph = graph.map(lambda v: v + v, fields_to_map)
     for field in graphs.ALL_FIELDS:
         if field in fields_to_map:
             self.assertEqual(field + field, getattr(graph, field))
         else:
             self.assertEqual(field, getattr(graph, field))
Example #13
 def test_replace_with_valid_none_fields(self, none_fields):
     # Create a graph with different values.
     graph = graphs.GraphsTuple(**{k: v + v for k, v in self.graph.items()})
     # Update with a graph containing the initial values, or Nones.
     for none_field in none_fields:
         self.graph[none_field] = None
     graph = graph.replace(**self.graph)
     for k, v in self.graph.items():
         self.assertEqual(v, getattr(graph, k))
Example #14
 def test_map_field_called_only_once(self):
   """Tests that the mapping function is called exactly once per field."""
   graph = graphs.GraphsTuple(**self.graph)
   mapped_fields = []
   def map_fn(v):
     mapped_fields.append(v)
     return v
   graph = graph.map(map_fn, graphs.ALL_FIELDS)
   self.assertListEqual(sorted(mapped_fields), sorted(graphs.ALL_FIELDS))
Example #15
 def _get_shaped_input_graph(self):
   return graphs.GraphsTuple(
       nodes=tf.zeros([3, 4, 5, 11], dtype=tf.float32),
       edges=tf.zeros([5, 4, 5, 12], dtype=tf.float32),
       globals=tf.zeros([2, 4, 5, 13], dtype=tf.float32),
       receivers=tf.range(5, dtype=tf.int32) // 3,
       senders=tf.range(5, dtype=tf.int32) % 3,
       n_node=tf.constant([2, 1], dtype=tf.int32),
       n_edge=tf.constant([3, 2], dtype=tf.int32),
   )
Example #16
 def _get_shaped_input_graph(self):
     return graphs.GraphsTuple(
         nodes=torch.zeros([3, 4, 5, 11], dtype=torch.float32),
         edges=torch.zeros([5, 4, 5, 12], dtype=torch.float32),
         globals=torch.zeros([2, 4, 5, 13], dtype=torch.float32),
          # torch.arange(5) yields the 5 edge indices; the deprecated
          # torch.range(0, 5) is endpoint-inclusive and would yield 6.
          receivers=torch.arange(5, dtype=torch.int64) // 3,
          senders=torch.arange(5, dtype=torch.int64) % 3,
         n_node=torch.tensor([2, 1], dtype=torch.int64),
         n_edge=torch.tensor([3, 2], dtype=torch.int64),
     )
Example #17
 def _get_graphs_tuple(self):
   """Returns a GraphsTuple containing a graph based on the test system."""
   return graphs.GraphsTuple(
       nodes=tf.constant(self._nodes, dtype=tf.float32),
       edges=tf.constant(self._edges, dtype=tf.float32),
       globals=tf.constant(np.array([[0.0]]), dtype=tf.float32),
       receivers=tf.constant(self._receivers, dtype=tf.int32),
       senders=tf.constant(self._senders, dtype=tf.int32),
       n_node=tf.constant([len(self._nodes)], dtype=tf.int32),
       n_edge=tf.constant([len(self._edges)], dtype=tf.int32))
Example #18
def specs_from_graphs_tuple(
    graphs_tuple_sample,
    with_batch_dim=False,
    dynamic_num_graphs=False,
    dynamic_num_nodes=True,
    dynamic_num_edges=True,
    description_fn=tf.TensorSpec,
):
    graphs_tuple_description_fields = {}
    edge_dim_fields = [graphs.EDGES, graphs.SENDERS, graphs.RECEIVERS]

    for field_name in graphs.ALL_FIELDS:
        field_sample = getattr(graphs_tuple_sample, field_name)
        if field_sample is None:
            raise ValueError(
                "The `GraphsTuple` field `{}` was `None`. All fields of the "
                "`GraphsTuple` must be specified to create valid signatures "
                "that work with `tf.function`. This can be achieved with "
                "`input_graph = utils_tf.set_zero_{{node,edge,global}}_features"
                "(input_graph, 0)` to replace `None`s by empty features in "
                "your graph. Alternatively, `None`s can be replaced by empty "
                "lists by doing `input_graph = input_graph.replace("
                "{{nodes,edges,globals}}=[])`. To ensure correct execution of "
                "the program, it is recommended to restore the `None`s once "
                "inside of the `tf.function` by doing `input_graph = "
                "input_graph.replace({{nodes,edges,globals}}=None)`."
                .format(field_name))

        shape = list(field_sample.shape)
        dtype = field_sample.dtype

        # If the field is not None but has no field shape (i.e. it is a constant)
        # then we consider this to be a replaced `None`.
        # If dynamic_num_graphs, then all fields have a None first dimension.
        # If dynamic_num_nodes, then the "nodes" field needs None first dimension.
        # If dynamic_num_edges, then the "edges", "senders" and "receivers" need
        # a None first dimension.
        if shape:
            if with_batch_dim:
                shape[1] = None
            elif (dynamic_num_graphs
                  or (dynamic_num_nodes and field_name == graphs.NODES)
                  or (dynamic_num_edges and field_name in edge_dim_fields)):
                shape[0] = None

        graphs_tuple_description_fields[field_name] = description_fn(
            shape=shape, dtype=dtype)

    return graphs.GraphsTuple(**graphs_tuple_description_fields)
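The resulting GraphsTuple of TensorSpecs is typically passed as a tf.function input signature; a sketch, where sample_graph and model are assumptions:

signature = specs_from_graphs_tuple(sample_graph)

@tf.function(input_signature=[signature])
def forward(graph):
    return model(graph)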
Example #19
 def get_placeholders(self):
     # TODO: Make this work with utils_tf for graph_nets for GraphTuple
     # Build placeholders
     placeholders = {}
     for k, v in self.features.items():
         use_batch = (k not in GRAPH_KEYS)
         placeholders.update(v.get_placeholder(batch=use_batch))
     # Other placeholders
     other_keys = (set(self.features.keys()) - set(GRAPH_KEYS))
     sample = {k: placeholders[k] for k in other_keys}
     # Build Graph Tuple
     sample['graph'] = \
       graphs.GraphsTuple(**{ k: placeholders[k] for k in GRAPH_KEYS })
     return sample, placeholders
Example #20
 def test_build_placeholders_from_specs(self,
                                        none_fields,
                                        force_dynamic_num_graphs=False):
   num_graphs = 3
   shapes = graphs.GraphsTuple(
       nodes=[3, 4],
       edges=[2],
       globals=[num_graphs, 4, 6],
       receivers=[None],
       senders=[18],
       n_node=[num_graphs],
       n_edge=[num_graphs],
   )
   dtypes = graphs.GraphsTuple(
       nodes=tf.float64,
       edges=tf.int32,
       globals=tf.float32,
       receivers=tf.int64,
       senders=tf.int64,
       n_node=tf.int32,
       n_edge=tf.int64)
   dtypes = dtypes.map(lambda _: None, none_fields)
   shapes = shapes.map(lambda _: None, none_fields)
   placeholders = utils_tf._build_placeholders_from_specs(
       dtypes, shapes, force_dynamic_num_graphs=force_dynamic_num_graphs)
   for k in graphs.ALL_FIELDS:
     placeholder = getattr(placeholders, k)
     if k in none_fields:
       self.assertEqual(None, placeholder)
     else:
       self.assertEqual(getattr(dtypes, k), placeholder.dtype)
       if k not in ["n_node", "n_edge", "globals"] or force_dynamic_num_graphs:
         self.assertAllEqual([None] + getattr(shapes, k)[1:],
                             placeholder.shape.as_list())
       else:
         self.assertAllEqual([num_graphs] + getattr(shapes, k)[1:],
                             placeholder.shape.as_list())
Example #21
def _build_placeholders_from_specs(dtypes,
                                   shapes,
                                   force_dynamic_num_graphs=True):
    """Creates a `graphs.GraphsTuple` of placeholders with `dtypes` and `shapes`.

  The dtypes and shapes arguments are instances of `graphs.GraphsTuple` that
  contain dtypes and shapes, or `None` values for the fields for which no
  placeholder should be created. The leading dimensions of the nodes and edges
  are dynamic because the numbers of nodes and edges can vary.
  If `force_dynamic_num_graphs` is True, then the number of graphs is assumed
  to be dynamic and all fields' leading dimensions are set to `None`.
  If `force_dynamic_num_graphs` is False, then `N_NODE`, `N_EDGE` and `GLOBALS`
  leading dimensions are statically defined.

  Args:
    dtypes: A `graphs.GraphsTuple` that contains `tf.dtype`s or `None`s.
    shapes: A `graphs.GraphsTuple` that contains `list`s of integers,
      `tf.TensorShape`s, or `None`s.
    force_dynamic_num_graphs: A `bool` that forces the batch dimension to be
      dynamic. Defaults to `True`.

  Returns:
    A `graphs.GraphsTuple` containing placeholders.

  Raises:
    ValueError: The `None` fields in `dtypes` and `shapes` do not match.
  """
    dct = {}
    for field in ALL_FIELDS:
        dtype = getattr(dtypes, field)
        shape = getattr(shapes, field)
        if dtype is None or shape is None:
            if not (shape is None and dtype is None):
                raise ValueError(
                    "only one of dtype and shape are None for field {}".format(
                        field))
            dct[field] = None
        elif not shape:
            raise ValueError("Shapes must have at least rank 1")
        else:
            shape = list(shape)
            if field not in [N_NODE, N_EDGE, GLOBALS
                             ] or force_dynamic_num_graphs:
                shape[0] = None
            dct[field] = tf.compat.v1.placeholder(dtype,
                                                  shape=shape,
                                                  name=field)

    return graphs.GraphsTuple(**dct)
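In the TF1-style workflow, the resulting placeholders are fed with numpy GraphsTuples, for example via graph_nets' utils_tf.get_feed_dict; a sketch under that assumption (dtypes, shapes, numpy_graphs_tuple and model_outputs stand in for your own objects):

placeholders = _build_placeholders_from_specs(dtypes, shapes)
feed_dict = utils_tf.get_feed_dict(placeholders, numpy_graphs_tuple)
with tf.compat.v1.Session() as sess:
    outputs = sess.run(model_outputs, feed_dict=feed_dict)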
Example #22
 def get_placeholder_and_feature(self, batch):
     del batch  # We do not need batch
     placeholders = {}
     sample_dict = {}
     for key, feat in self.features.items():
         # Due to how graphs are concatenated we only need batch dimension for
         # globals, n_node, and n_edge
         batch = False
         if key in ['globals', 'n_node', 'n_edge']:
             batch = True
         ph, val = feat.get_placeholder_and_feature(batch=batch)
         placeholders.update(ph)
         sample_dict[key] = val
     sample = graphs.GraphsTuple(**sample_dict)
     return placeholders, sample
Example #23
def make_graph_from_static_structure(
    positions,
    types,
    box,
    edge_threshold):
  """
  Returns graph representing the structure of the collective.

  Each particle is represented by a node in the graph. The particle motion direction is  stored as a node feature.
  Two particles at a distance less than the threshold are connected by an edge.
  The relative distance vector is stored as an edge feature.

  Args:
    positions: particle positions with shape [n_particles, 2].
    types: particle motion direction with shape [n_particles].
    box: dimensions of the box that contains the particles with shape [2].
    edge_threshold: particles at distance less than threshold are connected by a
  """
  # Calculate pairwise relative distances between particles
  cross_positions = positions[tf.newaxis, :, :] - positions[:, tf.newaxis, :]
  # Enforces periodic boundary conditions.
  box_ = box[tf.newaxis, tf.newaxis, :]
  cross_positions += tf.cast(cross_positions < -box_ / 2., tf.float32) * box_
  cross_positions -= tf.cast(cross_positions > box_ / 2., tf.float32) * box_
  # Calculates adjacency matrix in a sparse format (indices), based on the given
  # distances and threshold.
  distances = tf.norm(cross_positions, axis=-1)
  indices = tf.where(distances < edge_threshold)
 
  # Defines graph.
  nodes = types[:, tf.newaxis]
  senders = indices[:, 0]
  receivers = indices[:, 1]
  edges = tf.gather_nd(cross_positions, indices)
  va = [[0.03]]  # one graph, one global feature
  return graphs.GraphsTuple(
      nodes=tf.cast(nodes, tf.float32),
      n_node=tf.reshape(tf.shape(nodes)[0], [1]),
      edges=tf.cast(edges, tf.float32),
      n_edge=tf.reshape(tf.shape(edges)[0], [1]),
      globals=tf.cast(va, dtype=tf.float32),

      receivers=tf.cast(receivers, tf.int32),
      senders=tf.cast(senders, tf.int32)
      )
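An illustrative call with three particles in a 10 x 10 periodic box, with values chosen so only the first two particles fall within the threshold:

positions = tf.constant([[0., 0.], [1., 0.], [5., 5.]], tf.float32)
types = tf.constant([0., 1., 0.], tf.float32)
box = tf.constant([10., 10.], tf.float32)
graph = make_graph_from_static_structure(positions, types, box,
                                         edge_threshold=2.)
print(graph.n_node, graph.n_edge)  # [3] [5]; self-pairs also pass the threshold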
Example #24
  def test_received_edges_normalizer(self, logits,
                                     expected_normalized, normalizer):
    graph = graphs.GraphsTuple(
        nodes=None,
        edges=logits,
        globals=None,
        receivers=tf.constant(self.RECEIVERS, dtype=tf.int32),
        senders=tf.constant(self.SENDERS, dtype=tf.int32),
        n_node=tf.constant(self.N_NODE, dtype=tf.int32),
        n_edge=tf.constant(self.N_EDGE, dtype=tf.int32),
    )
    actual_normalized_edges = modules._received_edges_normalizer(
        graph, normalizer)

    with self.test_session() as sess:
      actual_normalized_edges_output = sess.run(actual_normalized_edges)

    self.assertAllClose(expected_normalized, actual_normalized_edges_output)
Example #25
File: graph.py Project: rkunnawa/root_gnn
def get_signature(graphs_tuple_sample):
    from graph_nets import graphs

    graphs_tuple_description_fields = {}

    for field_name in graphs.ALL_FIELDS:
        per_replica_sample = getattr(graphs_tuple_sample, field_name)
        def spec_from_value(v):
            shape = list(v.shape)
            dtype = v.dtype
            if shape:
                shape[1] = None
            return tf.TensorSpec(shape=shape, dtype=dtype)

        per_replica_spec = tf.distribute.values.PerReplicaSpec(
            *(spec_from_value(v) for v in per_replica_sample.values)
        )

        graphs_tuple_description_fields[field_name] = per_replica_spec
    return graphs.GraphsTuple(**graphs_tuple_description_fields)
Example #26
File: graph.py Project: rkunnawa/root_gnn
def _concat_batch_dim(G):
    """
    G is a GraphsTuple of Tensors with an additional leading batch dimension.
    Concatenate the per-batch graphs along the batch axis.
    """
    input_graphs = []
    for ibatch in [0, 1]:
        data_dict = {
            "nodes": G.nodes[ibatch],
            "edges": G.edges[ibatch],
            "receivers": G.receivers[ibatch],
            'senders': G.senders[ibatch],
            'globals': G.globals[ibatch],
            'n_node': G.n_node[ibatch],
            'n_edge': G.n_edge[ibatch],
        }
        input_graphs.append(graphs.GraphsTuple(**data_dict))
    print("{} graphs".format(len(input_graphs)))
    return utils_tf.concat(input_graphs, axis=0)
Example #27
    def __call__(self, graph, graph_state, graph_mask, edge_kw={}, node_kw={}, global_kw={}):
        num_of_times = graph.nodes.shape[1]
        for t in range(num_of_times):
            # Slice out time step t; `replace` keeps the receivers, senders,
            # n_node and n_edge of `graph`, so graph_t is a complete GraphsTuple.
            graph_t = graph.replace(
                nodes=tf.squeeze(graph.nodes[:, t, :]),
                edges=tf.squeeze(graph.edges[:, t, :]),
                globals=tf.squeeze(graph.globals[:, t, :]))
            graph_mask_t = GraphMasks(
                tf.squeeze(graph_mask.nodes[:, t, :]),
                tf.squeeze(graph_mask.edges[:, t, :]),
                tf.squeeze(graph_mask.globals[:, t, :]))

            graph_output, edge_state = self._edge_block(
                graph_t, graph_state.edges, graph_mask_t.edges, **edge_kw)
            graph_output, node_state = self._node_block(
                graph_output, graph_state.nodes, graph_mask_t.nodes, **node_kw)
            graph_output, global_state = self._global_block(
                graph_output, graph_state.globals, graph_mask_t.globals, **global_kw)

            graph_state = GraphStates(edge_state, node_state, global_state)
        return graph_output, graph_state
Example #28
  def populate_test_data(self, max_size):
    """Populates the class fields with data used for the tests.

    This creates a batch of graphs with number of nodes from 0 to `max_size`,
    number of edges ranging from 0 to `max_size`, plus an empty graph with no
    nodes and no edges (so that the total number of graphs is
    1 + max_size * (max_size + 1)).

    The nodes states, edges states and global states of the graphs are
    created to have different types and shapes.

    Those graphs are stored both as dictionaries (in `self.graphs_dicts_in`,
    without `n_node` and `n_edge` information, and in `self.graphs_dicts_out`
    with these two fields filled), and a corresponding numpy
    `graphs.GraphsTuple` is stored in `self.reference_graph`.

    Args:
      max_size: The maximum number of nodes and edges (inclusive).
    """
    filt = lambda x: (x[0] > 0) or (x[1] == 0)
    n_node, n_edge = zip(*list(
        filter(filt, itertools.product(
            range(max_size + 1), range(max_size + 1)))))

    graphs_dicts = []
    nodes = []
    edges = []
    receivers = []
    senders = []
    globals_ = []

    def _make_default_state(shape, dtype):
      return np.arange(np.prod(shape)).reshape(shape).astype(dtype)

    for i, (n_node_, n_edge_) in enumerate(zip(n_node, n_edge)):
      n = _make_default_state([n_node_, 7, 11], "f4") + i * 100.
      e = _make_default_state([n_edge_, 13, 14], np.float64) + i * 100. + 1000.
      r = _make_default_state([n_edge_], np.int32) % n_node[i]
      s = (_make_default_state([n_edge_], np.int32) + 1) % n_node[i]
      g = _make_default_state([5, 3], "f4") - i * 100. - 1000.

      nodes.append(n)
      edges.append(e)
      receivers.append(r)
      senders.append(s)
      globals_.append(g)
      graphs_dict = dict(nodes=n, edges=e, receivers=r, senders=s, globals=g)
      graphs_dicts.append(graphs_dict)

    # Graphs dicts without n_node / n_edge (to be used as inputs).
    self.graphs_dicts_in = graphs_dicts
    # Graphs dicts with n_node / n_edge (to be checked against outputs).
    self.graphs_dicts_out = []
    for dict_ in self.graphs_dicts_in:
      completed_dict = dict_.copy()
      completed_dict["n_node"] = completed_dict["nodes"].shape[0]
      completed_dict["n_edge"] = completed_dict["edges"].shape[0]
      self.graphs_dicts_out.append(completed_dict)

    # pylint: disable=protected-access
    offset = utils_np._compute_stacked_offsets(n_node, n_edge)
    # pylint: enable=protected-access
    self.reference_graph = graphs.GraphsTuple(**dict(
        nodes=np.concatenate(nodes, axis=0),
        edges=np.concatenate(edges, axis=0),
        receivers=np.concatenate(receivers, axis=0) + offset,
        senders=np.concatenate(senders, axis=0) + offset,
        globals=np.stack(globals_),
        n_node=np.array(n_node),
        n_edge=np.array(n_edge)))
Example #29
  def test_correct_signature(
      self,
      dynamic_num_nodes,
      dynamic_num_edges,
      dynamic_num_graphs,
      batched,
      replace_globals_with_constant):
    """Tests that the correct spec is created when using different options."""

    if batched:
      input_data_dicts = [self.graphs_dicts[1], self.graphs_dicts[2]]
    else:
      input_data_dicts = [self.graphs_dicts[1]]

    graph = utils_np.data_dicts_to_graphs_tuple(input_data_dicts)
    num_graphs = len(input_data_dicts)
    num_edges = sum(graph.n_edge).item()
    num_nodes = sum(graph.n_node).item()

    # Manually setting edges and globals fields to give some variety in
    # testing situations.
    # Making edges have rank 1.
    graph = graph.replace(edges=np.zeros(num_edges))

    # Make a constant field.
    if replace_globals_with_constant:
      graph = graph.replace(globals=np.array(0.0, dtype=np.float32))

    spec_signature = utils_tf.specs_from_graphs_tuple(
        graph, dynamic_num_graphs, dynamic_num_nodes, dynamic_num_edges)

    # Captures if nodes/edges will be dynamic either due to dynamic nodes/edges
    # or dynamic graphs.
    dynamic_nodes_or_graphs = dynamic_num_nodes or dynamic_num_graphs
    dynamic_edges_or_graphs = dynamic_num_edges or dynamic_num_graphs

    num_edges = None if dynamic_edges_or_graphs else num_edges
    num_nodes = None if dynamic_nodes_or_graphs else num_nodes
    num_graphs = None if dynamic_num_graphs else num_graphs

    if replace_globals_with_constant:
      expected_globals_shape = []
    else:
      expected_globals_shape = [num_graphs,] + test_utils.GLOBALS_DIMS

    expected_answer = graphs.GraphsTuple(
        nodes=tf.TensorSpec(
            shape=[num_nodes,] + test_utils.NODES_DIMS,
            dtype=tf.float32),
        edges=tf.TensorSpec(
            shape=[num_edges],  # Edges were manually replaced to have dim 1.
            dtype=tf.float64),
        n_node=tf.TensorSpec(
            shape=[num_graphs],
            dtype=tf.int32),
        n_edge=tf.TensorSpec(
            shape=[num_graphs],
            dtype=tf.int32),
        globals=tf.TensorSpec(
            shape=expected_globals_shape,
            dtype=tf.float32),
        receivers=tf.TensorSpec(
            shape=[num_edges],
            dtype=tf.int32),
        senders=tf.TensorSpec(
            shape=[num_edges],
            dtype=tf.int32),
        )

    with self.subTest(name="Correct Type."):
      self.assertIsInstance(spec_signature, graphs.GraphsTuple)

    with self.subTest(name="Correct Signature."):
      self.assertAllEqual(spec_signature, expected_answer)
Example #30
def get_graph(input_graphs, index, name="get_graph"):
    """Indexes into a graph.

  Given a `graphs.GraphsTuple` containing `Tensor`s and an index (either
  an `int` or a `slice`), index into the nodes, edges and globals to extract
  the graphs specified by the slice, and return them in another instance of a
  `graphs.GraphsTuple` containing `Tensor`s.

  Args:
    input_graphs: A `graphs.GraphsTuple` containing `Tensor`s.
    index: An `int` or a `slice`, to index into `graph`. `index` should be
      compatible with the number of graphs in `graphs`.
    name: (string, optional) A name for the operation.

  Returns:
    A `graphs.GraphsTuple` containing `Tensor`s, made of the extracted
      graph(s).

  Raises:
    TypeError: if `index` is not an `int` or a `slice`.
  """
    def safe_slice_none(value, slice_):
        if value is None:
            return value
        return value[slice_]

    if isinstance(index, int):
        graph_slice = slice(index, index + 1)
    elif isinstance(index, slice):
        graph_slice = index
    else:
        raise TypeError("unsupported type: %s" % type(index))

    start_slice = slice(0, graph_slice.start)

    with tf.name_scope(name):
        start_node_index = tf.reduce_sum(input_graphs.n_node[start_slice],
                                         name="start_node_index")
        start_edge_index = tf.reduce_sum(input_graphs.n_edge[start_slice],
                                         name="start_edge_index")
        end_node_index = start_node_index + tf.reduce_sum(
            input_graphs.n_node[graph_slice], name="end_node_index")
        end_edge_index = start_edge_index + tf.reduce_sum(
            input_graphs.n_edge[graph_slice], name="end_edge_index")
        nodes_slice = slice(start_node_index, end_node_index)
        edges_slice = slice(start_edge_index, end_edge_index)

        sliced_graphs_dict = {}

        for field in set(GRAPH_NUMBER_FIELDS) | {"globals"}:
            sliced_graphs_dict[field] = safe_slice_none(
                getattr(input_graphs, field), graph_slice)

        field = "nodes"
        sliced_graphs_dict[field] = safe_slice_none(
            getattr(input_graphs, field), nodes_slice)

        for field in {"edges", "senders", "receivers"}:
            sliced_graphs_dict[field] = safe_slice_none(
                getattr(input_graphs, field), edges_slice)
            if (field in {"senders", "receivers"}
                    and sliced_graphs_dict[field] is not None):
                sliced_graphs_dict[
                    field] = sliced_graphs_dict[field] - start_node_index

        return graphs.GraphsTuple(**sliced_graphs_dict)
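Typical use of this Tensor variant (a sketch; `batched_graphs` is an assumed batched GraphsTuple of Tensors):

second = get_graph(batched_graphs, 1)               # just the graph at index 1
first_two = get_graph(batched_graphs, slice(0, 2))  # graphs 0 and 1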