def _ApplyGAT(graph):
  """Applies a Graph Attention layer."""
  nodes, edges, receivers, senders, _, _, _ = graph
  # Statically-known total node count (equivalent to summing n_node).
  try:
    total_num_nodes = nodes.shape[0]
  except IndexError:
    raise IndexError('GAT requires node features')

  # Embed the node features before computing attention.
  nodes = attention_query_fn(nodes)

  # Attention logits are a function of the embedded features at both
  # endpoints of every edge, plus the edge features themselves.
  sender_feats = nodes[senders]
  receiver_feats = nodes[receivers]
  logits = attention_logit_fn(sender_feats, receiver_feats, edges)

  # Normalize the logits per receiving node.
  attn_weights = utils.segment_softmax(
      logits, segment_ids=receivers, num_segments=total_num_nodes)

  # Scale each message by its attention weight, then aggregate the
  # weighted messages onto their receiver nodes.
  aggregated = utils.segment_sum(
      sender_feats * attn_weights, receivers, num_segments=total_num_nodes)

  # Final per-node transformation of the aggregated messages.
  nodes = node_update_fn(aggregated)
  return graph._replace(nodes=nodes)
def test_segment_softmax(self):
  """Checks segment_softmax against precomputed values, jitted and not."""
  inputs = jnp.arange(9)
  segment_ids = jnp.array([0, 1, 2, 0, 4, 0, 1, 1, 0])
  num_segments = 6
  expected = np.array([
      3.1741429e-04, 1.8088353e-03, 1.0000000e+00, 6.3754367e-03,
      1.0000000e+00, 4.7108460e-02, 2.6845494e-01, 7.2973621e-01,
      9.4619870e-01
  ])

  with self.subTest('nojit'):
    actual = utils.segment_softmax(inputs, segment_ids, num_segments)
    self.assertAllClose(actual, expected, check_dtypes=True)
    # num_segments is optional; omitting it must give the same result.
    actual = utils.segment_softmax(inputs, segment_ids)
    self.assertAllClose(actual, expected, check_dtypes=True)

  with self.subTest('jit'):
    jitted = jax.jit(utils.segment_softmax, static_argnums=2)
    actual = jitted(inputs, segment_ids, num_segments)
    self.assertAllClose(actual, expected, check_dtypes=True)