예제 #1
0
    def _ApplyGAT(graph):
        """Applies a Graph Attention layer.

        Node features are first embedded with `attention_query_fn`. Every edge
        then carries its sender's embedding as a message. Messages arriving at
        the same receiver are weighted by a softmax over that receiver's
        incoming edges (logits from `attention_logit_fn`) and summed. Finally,
        the aggregated features are passed through `node_update_fn`.

        Args:
          graph: a namedtuple-like graph whose first four fields are `nodes`,
            `edges`, `receivers` and `senders`; the remaining three fields are
            ignored here. `nodes` must be an array with a leading node axis,
            which `senders` and `receivers` index into.

        Returns:
          `graph` with its `nodes` field replaced by the updated node features.
          All other fields are passed through unchanged.

        Raises:
          IndexError: if `graph` carries no node features, i.e. `nodes` is
            None or a 0-d array without a leading node axis.
        """
        nodes, edges, receivers, senders, _, _, _ = graph
        # Equivalent to the sum of n_node, but statically known.
        try:
            sum_n_node = nodes.shape[0]
        except (AttributeError, IndexError) as err:
            # AttributeError: `nodes` is None (no `.shape`).
            # IndexError: `nodes` is a 0-d array (empty shape).
            # Both mean there are no per-node features to attend over. Keep
            # raising IndexError so existing callers' handlers still apply.
            raise IndexError('GAT requires node features') from err

        # First pass nodes through the node updater.
        nodes = attention_query_fn(nodes)
        # We compute the softmax logits using a function that takes the
        # embedded sender and receiver attributes.
        sent_attributes = nodes[senders]
        received_attributes = nodes[receivers]
        softmax_logits = attention_logit_fn(sent_attributes,
                                            received_attributes, edges)

        # Normalize the logits with a softmax over each receiver's incoming
        # edges (segments are keyed by receiver index).
        weights = utils.segment_softmax(softmax_logits,
                                        segment_ids=receivers,
                                        num_segments=sum_n_node)
        # Apply weights.
        # NOTE(review): `weights` must broadcast against `sent_attributes`;
        # presumably attention_logit_fn returns one logit per edge with a
        # trailing singleton axis. Confirm against its contract.
        messages = sent_attributes * weights
        # Aggregate messages to nodes.
        nodes = utils.segment_sum(messages, receivers, num_segments=sum_n_node)

        # Apply an update function to the aggregated messages.
        nodes = node_update_fn(nodes)
        return graph._replace(nodes=nodes)
예제 #2
0
파일: utils_test.py 프로젝트: BwRy/jraph
 def test_segment_softmax(self):
     """Compares `utils.segment_softmax` with precomputed reference values.

     Covers eager calls (with and without an explicit segment count) and a
     jitted call in which the segment count is passed as a static argument.
     """
     logits = jnp.arange(9)
     ids = jnp.array([0, 1, 2, 0, 4, 0, 1, 1, 0])
     n_segments = 6
     # Each entry is exp(x) normalized over the entries that share its
     # segment id. Ids 2 and 4 occur only once, so those entries are 1.
     want = np.array([
         3.1741429e-04, 1.8088353e-03, 1.0000000e+00, 6.3754367e-03,
         1.0000000e+00, 4.7108460e-02, 2.6845494e-01, 7.2973621e-01,
         9.4619870e-01
     ])
     with self.subTest('nojit'):
         # First with num_segments supplied, then with it omitted.
         for extra_args in ((n_segments,), ()):
             got = utils.segment_softmax(logits, ids, *extra_args)
             self.assertAllClose(got, want, check_dtypes=True)
     with self.subTest('jit'):
         # num_segments (positional index 2) is marked static for jit.
         jitted_softmax = jax.jit(utils.segment_softmax, static_argnums=2)
         got = jitted_softmax(logits, ids, n_segments)
         self.assertAllClose(got, want, check_dtypes=True)