Ejemplo n.º 1
0
  def test_name_scope(self, name, expected_name):
    """Runs `utils_tf.repeat` inside an op-prefix assertion context.

    Args:
      name: Optional name forwarded to `utils_tf.repeat`; when falsy, no
        `name` keyword is passed at all.
      expected_name: Prefix (without the trailing "/") handed to
        `test_utils.assert_new_op_prefixes`; when falsy, `name` is used.
    """
    call_kwargs = {}
    if name:
      call_kwargs["name"] = name
    scope_root = expected_name or name

    values = tf.zeros([3, 2, 4])
    repeats = tf.constant([2, 3])
    with test_utils.assert_new_op_prefixes(self, scope_root + "/"):
      utils_tf.repeat(values, repeats, axis=1, **call_kwargs)
Ejemplo n.º 2
0
def broadcast_globals_to_edges(graph,
                               name="broadcast_globals_to_edges",
                               num_edges_hint=None):
    """Repeats each graph's global features once per edge of that graph.

    Args:
      graph: A `graphs.GraphsTuple` containing `Tensor`s, with globals
        features of shape `[n_graphs] + global_shape`, and `N_EDGE` field of
        shape `[n_graphs]`.
      name: (string, optional) A name for the operation.
      num_edges_hint: Integer indicating the total number of edges, if known.
        It is forwarded to `utils_tf.repeat` as `sum_repeats_hint`.

    Returns:
      A tensor of shape `[n_edges] + global_shape`, where
      `n_edges = sum(graph.n_edge)`. Element i equals `globals[j]`, with j the
      index of the graph that edge i belongs to, i.e. the j satisfying
      `sum_{k < j} graphs.n_edge[k] <= i < sum_{k <= j} graphs.n_edge[k]`.

    Raises:
      ValueError: If either `graph.globals` or `graph.n_edge` is `None`.
    """
    _validate_broadcasted_graph(graph, GLOBALS, N_EDGE)
    with tf.name_scope(name):
        per_edge_globals = utils_tf.repeat(
            graph.globals,
            graph.n_edge,
            axis=0,
            sum_repeats_hint=num_edges_hint,
        )
    return per_edge_globals
Ejemplo n.º 3
0
 def __call__(self, graph):
     """Reduces the edge features of `graph` with `self._reducer`.

     A per-edge index is built by repeating `range(num_graphs)` according to
     `graph.n_edge`; the reducer receives the edges, that index and the
     number of graphs.

     Args:
       graph: Graph structure validated to provide the `EDGES` field.

     Returns:
       Whatever `self._reducer(graph.edges, indices, num_graphs)` returns.
     """
     _validate_graph(
         graph,
         (EDGES, ),
         additional_message="when aggregating from edges.",
     )
     n_graphs = utils_tf.get_num_graphs(graph)
     # Entry i of `edge_graph_ids` is the index of the graph owning edge i.
     edge_graph_ids = utils_tf.repeat(tf.range(n_graphs), graph.n_edge, axis=0)
     return self._reducer(graph.edges, edge_graph_ids, n_graphs)
Ejemplo n.º 4
0
 def _build(self, graph):
     """Reduces the node features of `graph` with `self._reducer`.

     A per-node index is built by repeating `range(num_graphs)` according to
     `graph.n_node`; the reducer receives the nodes, that index and the
     number of graphs.

     Args:
       graph: Graph structure validated to provide the `NODES` field.

     Returns:
       Whatever `self._reducer(graph.nodes, indices, num_graphs)` returns.
     """
     _validate_graph(
         graph,
         (NODES, ),
         additional_message="when aggregating from nodes.",
     )
     n_graphs = utils_tf.get_num_graphs(graph)
     # Entry i of `node_graph_ids` is the index of the graph owning node i.
     node_graph_ids = utils_tf.repeat(tf.range(n_graphs), graph.n_node, axis=0)
     return self._reducer(graph.nodes, node_graph_ids, n_graphs)
Ejemplo n.º 5
0
 def test_repeat(self):
   """Compares `utils_tf.repeat` with `np.repeat` on a 3x2x4 array, axis 1."""
   axis = 1
   repeats = [2, 3]
   source = np.arange(24).reshape(3, 2, 4)
   actual = utils_tf.repeat(tf.constant(source), repeats, axis=axis)
   expected = np.repeat(source, repeats, axis=axis)
   self.assertAllEqual(expected, actual)
Ejemplo n.º 6
0
 def test_repeat(self, shape, repeats, axis):
     """Compares `utils_tf.repeat` with `np.repeat` for the given case.

     Args:
       shape: Shape of the `np.arange`-filled input array.
       repeats: Repeat counts; passed as-is to `np.repeat` and as an int32
         constant tensor to `utils_tf.repeat`.
       axis: Axis along which to repeat.
     """
     source = np.arange(np.prod(shape)).reshape(*shape)
     expected = np.repeat(source, repeats, axis=axis)
     source_tensor = tf.constant(source)
     repeats_tensor = tf.constant(repeats, dtype=tf.int32)
     actual = utils_tf.repeat(source_tensor, repeats_tensor, axis=axis)
     self.assertAllEqual(expected, actual)
Ejemplo n.º 7
0
 def test_repeat(self):
   """Evaluates `utils_tf.repeat` in a session and compares with `np.repeat`."""
   axis = 1
   repeats = [2, 3]
   source = np.arange(24).reshape(3, 2, 4)
   repeat_op = utils_tf.repeat(tf.constant(source), repeats, axis=axis)
   with self.test_session() as sess:
     actual = sess.run(repeat_op)
   expected = np.repeat(source, repeats, axis=axis)
   self.assertAllEqual(expected, actual)
Ejemplo n.º 8
0
 def __call__(self, graphs, **kwargs):
     """Updates edge features with each transformer, then normalises them.

     The globals are repeated per `graphs.n_edge`, and sender/receiver node
     features are gathered for every edge. Each entry of
     `self._transformers` is then applied in order, feeding its output back
     in as the edge features. Finally `self._link_decision` is applied and
     the result goes through `unsorted_segment_softmax`, segmented by sender
     index with `sum(n_node)` segments.

     Args:
       graphs: Graph structure exposing `globals`, `nodes`, `edges`,
         `senders`, `receivers`, `n_node`, `n_edge` and `replace`.
       **kwargs: Forwarded to every transformer and to `self._link_decision`.

     Returns:
       `graphs` with its `edges` field replaced by the softmax output.
     """
     per_edge_globals = utils_tf.repeat(graphs.globals, graphs.n_edge)
     from_nodes = tf.gather(graphs.nodes, graphs.senders)
     to_nodes = tf.gather(graphs.nodes, graphs.receivers)
     sender_ids = graphs.senders
     node_counts = graphs.n_node

     current_edges = graphs.edges
     for layer in self._transformers:
         current_edges = layer(
             per_edge_globals,
             from_nodes,
             current_edges,
             to_nodes,
             sender_ids,
             node_counts,
             **kwargs
         )

     logits = self._link_decision(current_edges, **kwargs)
     total_nodes = tf.reduce_sum(node_counts)
     normalised = unsorted_segment_softmax(logits, sender_ids, total_nodes)
     return graphs.replace(edges=normalised)
Ejemplo n.º 9
0
def broadcast_globals_to_nodes_eager(graph):
    """Repeats each graph's global features once per node of that graph.

    Args:
      graph: A `graphs.GraphsTuple` containing `Tensor`s, with globals
        features of shape `[n_graphs] + global_shape`, and `N_NODE` field of
        shape `[n_graphs]`.

    Returns:
      A tensor of shape `[n_nodes] + global_shape`, where
      `n_nodes = sum(graph.n_node)`. Element i equals `globals[j]`, with j the
      index of the graph that node i belongs to, i.e. the j satisfying
      `sum_{k < j} graphs.n_node[k] <= i < sum_{k <= j} graphs.n_node[k]`.

    Raises:
      ValueError: If either `graph.globals` or `graph.n_node` is `None`.
    """
    _validate_broadcasted_graph(graph, GLOBALS, N_NODE)
    per_node_globals = utils_tf.repeat(
        graph.globals,
        graph.n_node,
        axis=0,
    )
    return per_node_globals
Ejemplo n.º 10
0
 def map_(self, state, action):
     """Feeds `state` joined with `action` through `self.network`.

     `state` is repeated `tf.shape(action)[0]` times via `utils_tf.repeat`
     and concatenated with `action` along axis 1.
     NOTE(review): presumably `state` holds a single row so the repeated
     tensor lines up with `action` row-for-row; confirm against callers.
     """
     n_actions = tf.shape(action)[0]
     tiled_state = utils_tf.repeat(state, [n_actions])
     network_input = tf.concat([tiled_state, action], axis=1)
     return self.network(network_input)
Ejemplo n.º 11
0
def _compute_stacked_offsets(sizes, repeats):
    """Builds repeated start offsets from a sequence of sizes.

    The offsets are the cumulative sum of `[0] + sizes[:-1]` (as int32), so
    offset k is the total of the sizes preceding position k. Each offset is
    then repeated according to `repeats` via `utils_tf.repeat`.

    Args:
      sizes: Sliceable sequence or tensor of sizes; the last entry is unused.
      repeats: Repeat counts passed to `utils_tf.repeat`.

    Returns:
      The result of `utils_tf.repeat` applied to the offsets.
    """
    leading_sizes = tf.convert_to_tensor(sizes[:-1])
    leading_sizes = tf.cast(leading_sizes, tf.int32)
    start_offsets = tf.cumsum(tf.concat([[0], leading_sizes], 0))
    return utils_tf.repeat(start_offsets, repeats)
Ejemplo n.º 12
0
 def map_(self, state, action):
     """Runs `self.network` on `state` concatenated with `action`.

     `utils_tf.repeat` repeats `state` by the size of `action`'s first
     dimension, and the result is concatenated with `action` on axis 1.
     NOTE(review): assumes the repeated `state` and `action` agree on their
     first dimension; verify with the caller.
     """
     repeat_counts = [tf.shape(action)[0]]
     repeated_state = utils_tf.repeat(state, repeat_counts)
     return self.network(
         tf.concat([repeated_state, action], axis=1)
     )