コード例 #1
0
  def test_unused_field_can_be_none(
      self, use_edges, use_nodes, use_globals, none_field):
    """Verifies an EdgeBlock still runs when a field it does not need is None.

    The reference edge-model input is rebuilt by hand from whichever inputs
    are enabled (edges, receiver and sender nodes, globals), concatenated
    along the last axis. The block's output edges are expected to equal that
    input multiplied by `self._scale`.

    Args:
      use_edges: Whether the block consumes the edge features.
      use_nodes: Whether the block consumes both receiver and sender nodes.
      use_globals: Whether the block consumes the globals.
      none_field: Name of the field passed to `self._get_input_graph`
        (presumably the one left as None - see that helper).
    """
    input_graph = self._get_input_graph([none_field])
    block = blocks.EdgeBlock(
        edge_model_fn=self._edge_model_fn,
        use_edges=use_edges,
        use_receiver_nodes=use_nodes,
        use_sender_nodes=use_nodes,
        use_globals=use_globals)
    output_graph = block(input_graph)

    # Mirror the inputs the block is configured to hand to its edge model.
    pieces = []
    if use_edges:
      pieces.append(input_graph.edges)
    if use_nodes:
      pieces.extend([
          blocks.broadcast_receiver_nodes_to_edges(input_graph),
          blocks.broadcast_sender_nodes_to_edges(input_graph),
      ])
    if use_globals:
      pieces.append(blocks.broadcast_globals_to_edges(input_graph))
    reference_inputs = tf.concat(pieces, axis=-1)

    # Nodes and globals are expected to be handed back unchanged.
    self.assertEqual(input_graph.nodes, output_graph.nodes)
    self.assertEqual(input_graph.globals, output_graph.globals)

    with self.test_session() as sess:
      actual_edges, reference_inputs_out = sess.run(
          (output_graph.edges, reference_inputs))

    self.assertNDArrayNear(
        reference_inputs_out * self._scale, actual_edges, err=1e-4)
コード例 #2
0
    def _build(self, graph):
        """Advances the spring-mass system by one simulation step.

        Args:
          graph: A graphs.GraphsTuple having, for some integers N, E, G:
              - edges: Ex2 tf.Tensor of [spring_constant, rest_length] for
                each edge.
              - nodes: Nx5 tf.Tensor of [x, y, v_x, v_y, is_fixed] features
                for each node.
              - globals: Gx2 tf.Tensor containing the gravitational constant.

        Returns:
          A graphs.GraphsTuple of the same shape as `graph`, but where:
              - edges: Holds the force [f_x, f_y] acting on each edge.
              - nodes: Holds positions and velocities after applying one step
                of Euler integration.
        """
        # Per-edge copies of the features of each edge's two endpoints.
        receivers = blocks.broadcast_receiver_nodes_to_edges(graph)
        senders = blocks.broadcast_sender_nodes_to_edges(graph)

        # Width-1 slices keep the trailing axis on both spring parameters.
        spring_constant = graph.edges[..., 0:1]
        rest_length = graph.edges[..., 1:2]
        edge_forces = hookes_law(
            receivers, senders, spring_constant, rest_length)

        # The aggregator reads the forces from the edges, so store them first.
        graph = graph.replace(edges=edge_forces)
        node_forces = self._aggregator(graph)
        gravity = blocks.broadcast_globals_to_nodes(graph)

        next_nodes = euler_integration(
            graph.nodes, node_forces + gravity, self._step_size)
        return graph.replace(nodes=next_nodes)
コード例 #3
0
  def test_output_values(
      self, use_edges, use_receiver_nodes, use_sender_nodes, use_globals):
    """Checks EdgeBlock output edges against a hand-built reference.

    The reference is the concatenation (last axis) of every enabled input
    broadcast to the edges, multiplied by `self._scale`.

    Args:
      use_edges: Whether the block consumes the edge features.
      use_receiver_nodes: Whether the block consumes the receiver nodes.
      use_sender_nodes: Whether the block consumes the sender nodes.
      use_globals: Whether the block consumes the globals.
    """
    input_graph = self._get_input_graph()
    block = blocks.EdgeBlock(
        edge_model_fn=self._edge_model_fn,
        use_edges=use_edges,
        use_receiver_nodes=use_receiver_nodes,
        use_sender_nodes=use_sender_nodes,
        use_globals=use_globals)
    output_graph = block(input_graph)

    # (enabled, extractor) pairs, in the order the inputs are concatenated.
    candidates = (
        (use_edges, lambda g: g.edges),
        (use_receiver_nodes, blocks.broadcast_receiver_nodes_to_edges),
        (use_sender_nodes, blocks.broadcast_sender_nodes_to_edges),
        (use_globals, blocks.broadcast_globals_to_edges),
    )
    pieces = [extract(input_graph) for enabled, extract in candidates
              if enabled]
    reference_inputs = tf.concat(pieces, axis=-1)

    # Nodes and globals are expected to be handed back unchanged.
    self.assertEqual(input_graph.nodes, output_graph.nodes)
    self.assertEqual(input_graph.globals, output_graph.globals)

    with self.test_session() as sess:
      output_graph_out, reference_inputs_out = sess.run(
          (output_graph, reference_inputs))

    self.assertNDArrayNear(
        reference_inputs_out * self._scale, output_graph_out.edges, err=1e-4)
コード例 #4
0
ファイル: modules.py プロジェクト: bellmanequation/LSC
    def _build(self, node_values, node_keys, node_queries, attention_graph):
        """Connects the multi-head self-attention module.

        Attention follows the connectivity of `attention_graph`: for every
        edge, the receiver's query is matched against the sender's key, the
        resulting weights are normalized over the edges received by each node,
        and the weighted sender values are summed into the receiver.

        Args:
          node_values: Tensor containing the values associated to each of the
            nodes. The expected shape is
            [total_num_nodes, num_heads, value_size].
          node_keys: Tensor containing the key associated to each of the
            nodes. The expected shape is
            [total_num_nodes, num_heads, key_size].
          node_queries: Tensor containing the query associated to each of the
            nodes. The expected shape is
            [total_num_nodes, num_heads, query_size]. The query size must be
            equal to the key size, since the two are multiplied elementwise.
          attention_graph: Graph providing the senders and receivers fields. A
            node only attends to the senders of the edges it receives.

        Returns:
          `attention_graph` with its nodes replaced by the aggregated attended
          values, of shape [total_num_nodes, num_heads, value_size].

        Raises:
          ValueError: if the input graph does not have edges (per the original
            contract; the check is not performed in this method itself).
        """

        def _senders_on_edges(features):
            # Copy each sender node's `features` onto its outgoing edges.
            return blocks.broadcast_sender_nodes_to_edges(
                attention_graph.replace(nodes=features))

        keys_on_edges = _senders_on_edges(node_keys)
        values_on_edges = _senders_on_edges(node_values)

        # Each edge also carries the query of the node that receives it.
        queries_on_edges = blocks.broadcast_receiver_nodes_to_edges(
            attention_graph.replace(nodes=node_queries))

        # Per-edge, per-head dot product of key and query.
        logits = tf.reduce_sum(keys_on_edges * queries_on_edges, axis=-1)
        weights = _received_edges_normalizer(
            attention_graph.replace(edges=logits),
            normalizer=self._normalizer)

        # The trailing axis lets the weights broadcast over the value features.
        weighted_values = values_on_edges * weights[..., tf.newaxis]

        # Sum the weighted values arriving at each receiver node.
        aggregator = blocks.ReceivedEdgesToNodesAggregator(
            reducer=tf.unsorted_segment_sum)
        attended = aggregator(attention_graph.replace(edges=weighted_values))

        return attention_graph.replace(nodes=attended)
コード例 #5
0
    def _build(self, attended_graph):
        """Applies the multi-head attention layer to the graph's nodes.

        For each head, both endpoints of every edge are projected with the
        head's weight `self.W[k]`, scored with `self.attentions[k]`, and passed
        through `exp(leaky_relu(.))`. Each score is divided by the sum of the
        scores of all edges sharing its sender plus the sum of the scores of
        all edges sharing its receiver. The resulting coefficients weight the
        projected sender features, which are summed per sender node. The
        per-head outputs are averaged over `self.heads`.

        :param attended_graph: the graph to attend to
        :return: `attended_graph` with its nodes replaced by the head average
        """
        sender_feats = blocks.broadcast_sender_nodes_to_edges(attended_graph)
        receiver_feats = blocks.broadcast_receiver_nodes_to_edges(
            attended_graph)
        # Axis 1 indexes the edge endpoint: 0 = sender, 1 = receiver.
        edge_pairs = tf.stack([sender_feats, receiver_feats], axis=1)
        num_nodes = tf.shape(attended_graph.nodes)[0]

        head_total = None
        for head in range(self.heads):
            weight = self.W[head]

            # Project both endpoints and concatenate them for the scorer.
            projected_pairs = tf.map_fn(
                lambda pair: tf.concat([
                    tf.tensordot(weight, pair[0], axes=1),
                    tf.tensordot(weight, pair[1], axes=1),
                ], axis=0),
                edge_pairs)
            scores = tf.exp(
                tf.nn.leaky_relu(self.attentions[head](projected_pairs)))

            # Score totals per node, by outgoing and by incoming edges.
            sum_by_sender = tf.math.unsorted_segment_sum(
                scores, attended_graph.senders, num_segments=num_nodes)
            sum_by_receiver = tf.math.unsorted_segment_sum(
                scores, attended_graph.receivers, num_segments=num_nodes)
            denominators = tf.add(
                tf.gather(sum_by_sender, attended_graph.senders),
                tf.gather(sum_by_receiver, attended_graph.receivers))

            # Edge-by-edge division of each score by its denominator.
            ratio_inputs = tf.stack([scores, denominators], axis=1)
            coefficients = tf.map_fn(
                lambda pair: tf.divide(pair[0], pair[1]), ratio_inputs)

            # Weight the projected sender features and sum them per sender.
            projected_senders = tf.map_fn(
                lambda feat: tf.tensordot(weight, feat, axes=1), sender_feats)
            weighted = tf.multiply(projected_senders, coefficients)
            head_out = tf.math.unsorted_segment_sum(
                weighted, attended_graph.senders, num_segments=num_nodes)

            if head_total is None:
                head_total = head_out
            else:
                head_total = tf.add(head_total, head_out)

        averaged = tf.divide(head_total, self.heads)
        return attended_graph.replace(nodes=averaged)
コード例 #6
0
ファイル: hand_GNN.py プロジェクト: Chenzhoujia/Graph_hand
def set_rest_lengths(graph):
  """Computes and sets rest lengths for the springs in a physical system.

  The rest length is taken to be the Euclidean distance between the [x, y]
  positions of each edge's receiver and sender nodes.

  Args:
    graph: a graphs.GraphsTuple having, for some integers N, E:
        - nodes: Nx5 Tensor of [x, y, _, _, _] for each node.
        - edges: Ex2 Tensor of [spring_constant, _] for each edge.

  Returns:
    The input graph, but with [spring_constant, rest_length] for each edge.
  """
  receiver_nodes = blocks.broadcast_receiver_nodes_to_edges(graph)
  sender_nodes = blocks.broadcast_sender_nodes_to_edges(graph)
  # `keepdims` replaces the deprecated `keep_dims` alias, which TF 2.x no
  # longer accepts. Keeping the reduced axis gives a trailing axis of size 1
  # so the result can be concatenated with the spring constant below.
  rest_length = tf.norm(
      receiver_nodes[..., :2] - sender_nodes[..., :2], axis=-1, keepdims=True)
  # Keep column 0 (spring constant) and replace the second column.
  return graph.replace(
      edges=tf.concat([graph.edges[..., :1], rest_length], axis=-1))
コード例 #7
0
    def _build(self, graph):
        """Updates the edges from edge features and all-reduced node features.

        The receiver and sender node features are broadcast onto the edges,
        summed across replicas, concatenated with the edge features along the
        last axis, and passed through `self._edge_model`.

        Args:
          graph: Graph whose edges, nodes and connectivity are read.

        Returns:
          `graph` with its edges replaced by the edge model's output.
        """
        receiver_feats = blocks.broadcast_receiver_nodes_to_edges(graph)
        sender_feats = blocks.broadcast_sender_nodes_to_edges(graph)

        # Aggregate across replicas.
        # NOTE(review): this relies on being called in a replica context;
        # confirm `get_replica_context()` cannot return None at the call site.
        ctx = tf.distribute.get_replica_context()
        receiver_feats = ctx.all_reduce("sum", receiver_feats)
        sender_feats = ctx.all_reduce("sum", sender_feats)

        edge_inputs = tf.concat(
            [graph.edges, receiver_feats, sender_feats], axis=-1)
        return graph.replace(edges=self._edge_model(edge_inputs))
コード例 #8
0
    def _build(self, node_values, node_keys, node_queries, attention_graph):
        """Computes multi-head attention along the edges of `attention_graph`.

        For every edge, the receiver's query is matched against the sender's
        key; the weights are normalized over the edges received by each node
        and used to sum the sender values into the receiver.

        Args:
          node_values: Per-node values; per the shape comments below, expected
            as [total_num_nodes, num_heads, value_size].
          node_keys: Per-node keys, expected as
            [total_num_nodes, num_heads, key_size].
          node_queries: Per-node queries, with the same shape as `node_keys`
            (keys and queries are multiplied elementwise).
          attention_graph: Graph providing the senders and receivers fields.

        Returns:
          `attention_graph` with its nodes replaced by the aggregated attended
          values.
        """
        # Sender nodes put their keys and values in the edges.
        # [total_num_edges, num_heads, key_size]
        sender_keys = blocks.broadcast_sender_nodes_to_edges(
            attention_graph.replace(nodes=node_keys))

        # [total_num_edges, num_heads, value_size]
        sender_values = blocks.broadcast_sender_nodes_to_edges(
            attention_graph.replace(nodes=node_values))

        # Receiver nodes put their queries in the edges.
        # [total_num_edges, num_heads, key_size]
        receiver_queries = blocks.broadcast_receiver_nodes_to_edges(
            attention_graph.replace(nodes=node_queries))

        # Attention weight for each edge: per-edge, per-head dot product of
        # key and query. The two tensors already share a layout, so they are
        # multiplied directly; a `tf.transpose` without `perm` would reverse
        # all axes and break the edge/head alignment.
        # [total_num_edges, num_heads]
        attention_weights_logits = tf.reduce_sum(
            sender_keys * receiver_queries, axis=-1)
        normalized_attention_weights = _received_edges_normalizer(
            attention_graph.replace(edges=attention_weights_logits),
            normalizer=self._normalizer)

        # Attending to sender values according to the weights; the trailing
        # axis lets the weights broadcast over the value features.
        # [total_num_edges, num_heads, value_size]
        attented_edges = sender_values * normalized_attention_weights[...,
                                                                      None]

        # Summing all of the attended values received by each node.
        # [total_num_nodes, num_heads, value_size]
        received_edges_aggregator = blocks.ReceivedEdgesToNodesAggregator(
            reducer=tf.unsorted_segment_sum)
        aggregated_attended_values = received_edges_aggregator(
            attention_graph.replace(edges=attented_edges))

        return attention_graph.replace(nodes=aggregated_attended_values)
コード例 #9
0
    def _build(self, graph_features):
        """Connects the multi-head self-attention module.

        Keys and values are projected from the edge features (after
        `self._edge_block` has updated them); queries are projected from the
        node features.

        The self-attention is only computed according to the connectivity of
        the input graph: the attention weights are normalized over the edges
        received by each node, and the weighted values are summed into that
        receiver node.

        Args:
          graph_features: Graph whose nodes, edges, `n_node` and connectivity
            (senders and receivers) are read.

        Returns:
          The output of `self._global_block` applied to `graph_features` with
          its edges replaced by the edge-block output and its nodes replaced
          by the attended values after the heads have been concatenated and
          passed through `self._node_model`.

        Raises:
          ValueError: if the input graph does not have edges (per the original
            contract; the check is not performed in this method itself).
        """
        # TODO(arc): Figure out how to incorporate edge information into
        # attention updates.
        edges = self._edge_block(graph_features).edges
        num_heads = self.num_heads
        key_size = self.key_size
        value_size = self.value_size

        # [total_num_nodes, d] => [total_num_nodes, key_size * num_heads]
        q = self._attention_node_projection_model(graph_features.nodes)

        # The leading dimension is the total node count over all graphs.
        q = tf.reshape(
            q, [tf.reduce_sum(graph_features.n_node), num_heads, key_size])

        # [total_num_edges, (key_size + value_size) * num_heads]
        # project edge features to get key, values
        kv = self._attention_edge_projection_model(edges)
        kv = tf.reshape(kv, [-1, num_heads, key_size + value_size])
        # k => [total_num_edges, num_heads, key_size]
        # v => [total_num_edges, num_heads, value_size]
        k, v = tf.split(kv, [key_size, value_size], -1)

        sender_keys = k
        sender_values = v
        # Receiver nodes put their queries in the edges.
        # [total_num_edges, num_heads, key_size]
        receiver_queries = blocks.broadcast_receiver_nodes_to_edges(
            graph_features.replace(nodes=q))

        # Attention weight for each edge, computed by a learned model on the
        # concatenated key and query.
        # [total_num_edges, num_heads, 1]
        attention_weights_logits = snt.BatchApply(
            self._query_key_product_model)(tf.concat(
                [sender_keys, receiver_queries], axis=-1))
        # [total_num_edges, num_heads]
        attention_weights_logits = tf.squeeze(attention_weights_logits, -1)

        # compute softmax weights
        # [total_num_edges, num_heads]
        normalized_attention_weights = _received_edges_normalizer(
            graph_features.replace(edges=attention_weights_logits),
            normalizer=_unsorted_segment_softmax)

        # Attending to sender values according to the weights.
        # [total_num_edges, num_heads, value_size]
        attented_edges = sender_values * normalized_attention_weights[...,
                                                                      None]

        received_edges_aggregator = blocks.ReceivedEdgesToNodesAggregator(
            reducer=tf.unsorted_segment_sum)
        # Summing all of the attended values received by each node.
        # [total_num_nodes, num_heads, value_size]
        aggregated_attended_values = received_edges_aggregator(
            graph_features.replace(edges=attented_edges))

        # concatenate all the heads and project to required dimension.
        # cast to [total_num_nodes, num_heads * value_size]
        aggregated_attended_values = tf.reshape(aggregated_attended_values,
                                                [-1, num_heads * value_size])
        # Output width is whatever `self._node_model` produces (presumably
        # the node embedding size - confirm against its construction).
        aggregated_attended_values = self._node_model(
            aggregated_attended_values)

        return self._global_block(
            graph_features.replace(nodes=aggregated_attended_values,
                                   edges=edges))
コード例 #10
0
# Trainable variables of the network built earlier in the script.
tvars = graph_network.trainable_variables
print('')

# ---------------------------------------------------------------------------
# Broadcast: each statement builds a copy of `graphs_tuple` in which one field
# is replaced by the result of the named `blocks.broadcast_*` function.

graphs_tuple = utils_tf.data_dicts_to_graphs_tuple([data_dict_0])
updated_broadcast_globals_to_nodes = graphs_tuple.replace(
    nodes=blocks.broadcast_globals_to_nodes(graphs_tuple))
updated_broadcast_globals_to_edges = graphs_tuple.replace(
    edges=blocks.broadcast_globals_to_edges(graphs_tuple))
updated_broadcast_sender_nodes_to_edges = graphs_tuple.replace(
    edges=blocks.broadcast_sender_nodes_to_edges(graphs_tuple))
updated_broadcast_receiver_nodes_to_edges = graphs_tuple.replace(
    edges=blocks.broadcast_receiver_nodes_to_edges(graphs_tuple))

# ---------------------------------------------------------------------------
# Aggregate: each statement builds a copy of `graphs_tuple` in which one field
# is replaced by the output of the named `blocks.*Aggregator`.

# Rebuilt from the same data dict as in the broadcast section above.
graphs_tuple = utils_tf.data_dicts_to_graphs_tuple([data_dict_0])

reducer = tf.math.unsorted_segment_sum  # sum reducer shared by all aggregators
updated_edges_to_globals = graphs_tuple.replace(
    globals=blocks.EdgesToGlobalsAggregator(reducer=reducer)(graphs_tuple))
updated_nodes_to_globals = graphs_tuple.replace(
    globals=blocks.NodesToGlobalsAggregator(reducer=reducer)(graphs_tuple))
updated_sent_edges_to_nodes = graphs_tuple.replace(
    nodes=blocks.SentEdgesToNodesAggregator(reducer=reducer)(graphs_tuple))
updated_received_edges_to_nodes = graphs_tuple.replace(
    nodes=blocks.ReceivedEdgesToNodesAggregator(reducer=reducer)(graphs_tuple))