def test_name_scope(self, name, expected_name):
  kwargs = {"name": name} if name else {}
  expected_name = expected_name if expected_name else name
  t = tf.zeros([3, 2, 4])
  indices = tf.constant([2, 3])
  with test_utils.assert_new_op_prefixes(self, expected_name + "/"):
    utils_tf.repeat(t, indices, axis=1, **kwargs)
def broadcast_globals_to_edges(graph,
                               name="broadcast_globals_to_edges",
                               num_edges_hint=None):
  """Broadcasts the global features to the edges of a graph.

  Args:
    graph: A `graphs.GraphsTuple` containing `Tensor`s, with globals features
      of shape `[n_graphs] + global_shape`, and `N_EDGE` field of shape
      `[n_graphs]`.
    name: (string, optional) A name for the operation.
    num_edges_hint: Integer indicating the total number of edges, if known.

  Returns:
    A tensor of shape `[n_edges] + global_shape`, where
    `n_edges = sum(graph.n_edge)`. The i-th element of this tensor is given by
    `globals[j]`, where j is the index of the graph the i-th edge belongs to
    (i.e. is such that
    `sum_{k < j} graphs.n_edge[k] <= i < sum_{k <= j} graphs.n_edge[k]`).

  Raises:
    ValueError: If either `graph.globals` or `graph.n_edge` is `None`.
  """
  _validate_broadcasted_graph(graph, GLOBALS, N_EDGE)
  with tf.name_scope(name):
    return utils_tf.repeat(
        graph.globals, graph.n_edge, axis=0, sum_repeats_hint=num_edges_hint)
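A minimal usage sketch (not from the original source) of the broadcast described above, assuming the graph_nets package is installed and TF eager execution is active; the feature values and graph sizes below are made up for illustration:

import tensorflow as tf
from graph_nets import blocks, graphs

# Hypothetical two-graph batch: graph 0 has 2 nodes / 2 edges,
# graph 1 has 2 nodes / 3 edges.
example = graphs.GraphsTuple(
    nodes=tf.zeros([4, 1]),
    edges=tf.zeros([5, 1]),
    globals=tf.constant([[10.], [20.]]),   # one global feature vector per graph
    senders=tf.constant([0, 1, 2, 2, 3]),
    receivers=tf.constant([1, 0, 3, 2, 2]),
    n_node=tf.constant([2, 2]),
    n_edge=tf.constant([2, 3]),
)

# Each graph's global is repeated once per edge of that graph:
# [[10.], [10.], [20.], [20.], [20.]]
print(blocks.broadcast_globals_to_edges(example))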
def __call__(self, graph):
  _validate_graph(
      graph, (EDGES,), additional_message="when aggregating from edges.")
  num_graphs = utils_tf.get_num_graphs(graph)
  graph_index = tf.range(num_graphs)
  indices = utils_tf.repeat(graph_index, graph.n_edge, axis=0)
  return self._reducer(graph.edges, indices, num_graphs)
def _build(self, graph):
  _validate_graph(
      graph, (NODES,), additional_message="when aggregating from nodes.")
  num_graphs = utils_tf.get_num_graphs(graph)
  graph_index = tf.range(num_graphs)
  indices = utils_tf.repeat(graph_index, graph.n_node, axis=0)
  return self._reducer(graph.nodes, indices, num_graphs)
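A small sketch (not from the original source) of how these aggregators build their segment indices, using tf.repeat as a stand-in for utils_tf.repeat and tf.math.unsorted_segment_sum as an example reducer; the numbers are illustrative only:

import tensorflow as tf

# Hypothetical batch with n_node = [2, 3]:
# nodes 0-1 belong to graph 0, nodes 2-4 to graph 1.
n_node = tf.constant([2, 3])
num_graphs = 2
graph_index = tf.range(num_graphs)        # [0, 1]
indices = tf.repeat(graph_index, n_node)  # [0, 0, 1, 1, 1]

nodes = tf.constant([[1.], [2.], [3.], [4.], [5.]])
# A sum reducer then aggregates node features per graph: [[3.], [12.]]
per_graph = tf.math.unsorted_segment_sum(nodes, indices, num_graphs)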
def test_repeat(self):
  t = np.arange(24).reshape(3, 2, 4)
  tensor = tf.constant(t)
  repeats = [2, 3]
  axis = 1
  expected = np.repeat(t, repeats, axis=axis)
  actual = utils_tf.repeat(tensor, repeats, axis=axis)
  self.assertAllEqual(expected, actual)
def test_repeat(self, shape, repeats, axis):
  num_elements = np.prod(shape)
  t = np.arange(num_elements).reshape(*shape)
  expected = np.repeat(t, repeats, axis=axis)
  tensor = tf.constant(t)
  repeats = tf.constant(repeats, dtype=tf.int32)
  actual = utils_tf.repeat(tensor, repeats, axis=axis)
  self.assertAllEqual(expected, actual)
def test_repeat(self):
  t = np.arange(24).reshape(3, 2, 4)
  tensor = tf.constant(t)
  repeats = [2, 3]
  axis = 1
  expected = np.repeat(t, repeats, axis=axis)
  op = utils_tf.repeat(tensor, repeats, axis=axis)
  with self.test_session() as sess:
    actual = sess.run(op)
  self.assertAllEqual(expected, actual)
def __call__(self, graphs, **kwargs):
  destination = utils_tf.repeat(graphs.globals, graphs.n_edge)
  sender_features = tf.gather(graphs.nodes, graphs.senders)
  receiver_features = tf.gather(graphs.nodes, graphs.receivers)
  edge_features = graphs.edges
  senders = graphs.senders
  n_node = graphs.n_node
  for transformer in self._transformers:
    edge_features = transformer(
        destination,
        sender_features,
        edge_features,
        receiver_features,
        senders,
        n_node,
        **kwargs,
    )
  out_edges = unsorted_segment_softmax(
      self._link_decision(edge_features, **kwargs),
      senders,
      tf.reduce_sum(n_node))
  return graphs.replace(edges=out_edges)
def broadcast_globals_to_nodes_eager(graph):
  """Broadcasts the global features to the nodes of a graph.

  Args:
    graph: A `graphs.GraphsTuple` containing `Tensor`s, with globals features
      of shape `[n_graphs] + global_shape`, and `N_NODE` field of shape
      `[n_graphs]`.

  Returns:
    A tensor of shape `[n_nodes] + global_shape`, where
    `n_nodes = sum(graph.n_node)`. The i-th element of this tensor is given by
    `globals[j]`, where j is the index of the graph the i-th node belongs to
    (i.e. is such that
    `sum_{k < j} graphs.n_node[k] <= i < sum_{k <= j} graphs.n_node[k]`).

  Raises:
    ValueError: If either `graph.globals` or `graph.n_node` is `None`.
  """
  _validate_broadcasted_graph(graph, GLOBALS, N_NODE)
  return utils_tf.repeat(graph.globals, graph.n_node, axis=0)
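A tiny sketch (not from the original source) of the node-side broadcast described above, with tf.repeat standing in for utils_tf.repeat and made-up feature values:

import tensorflow as tf

global_features = tf.constant([[10.], [20.]])  # one global per graph
n_node = tf.constant([2, 3])                   # graph 0 has 2 nodes, graph 1 has 3

# Each global is repeated once per node of its graph:
# [[10.], [10.], [20.], [20.], [20.]]
broadcast = tf.repeat(global_features, n_node, axis=0)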
def map_(self, state, action):
  st = utils_tf.repeat(state, [tf.shape(action)[0]])
  return self.network(tf.concat([st, action], axis=1))
def _compute_stacked_offsets(sizes, repeats):
  sizes = tf.cast(tf.convert_to_tensor(sizes[:-1]), tf.int32)
  offset_values = tf.cumsum(tf.concat([[0], sizes], 0))
  return utils_tf.repeat(offset_values, repeats)
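For intuition, a worked sketch (not from the original source) of the offsets this helper produces, again using tf.repeat in place of utils_tf.repeat; in graph batching code such offsets are typically added to per-graph indices (e.g. senders/receivers) so they point into the stacked node tensor:

import tensorflow as tf

sizes = tf.constant([2, 3, 4])                               # e.g. nodes per graph
repeats = tf.constant([2, 3, 4])                             # e.g. edges per graph
offset_values = tf.cumsum(tf.concat([[0], sizes[:-1]], 0))   # [0, 2, 5]
offsets = tf.repeat(offset_values, repeats)                  # [0, 0, 2, 2, 2, 5, 5, 5, 5]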