def _ApplyGAT(graph):
  """Applies a Graph Attention layer."""
  nodes, edges, receivers, senders, _, _, _ = graph
  # Equivalent to the sum of n_node, but statically known.
  try:
    sum_n_node = nodes.shape[0]
  except IndexError:
    raise IndexError('GAT requires node features')

  # First pass nodes through the node updater.
  nodes = attention_query_fn(nodes)
  # pylint: disable=g-long-lambda
  # We compute the softmax logits using a function that takes the
  # embedded sender and receiver attributes.
  sent_attributes = nodes[senders]
  received_attributes = nodes[receivers]
  softmax_logits = attention_logit_fn(sent_attributes, received_attributes,
                                      edges)

  # Compute the softmax weights on the entire tree.
  weights = utils.segment_softmax(softmax_logits, segment_ids=receivers,
                                  num_segments=sum_n_node)
  # Apply weights
  messages = sent_attributes * weights
  # Aggregate messages to nodes.
  nodes = utils.segment_sum(messages, receivers, num_segments=sum_n_node)

  # Apply an update function to the aggregated messages.
  nodes = node_update_fn(nodes)
  return graph._replace(nodes=nodes)
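# A minimal usage sketch (illustrative, not library code): assuming
# `_ApplyGAT` is the function returned by jraph's public `GAT` constructor,
# a toy invocation might look like the following. The query/logit/update
# functions here are placeholder assumptions, not the library's defaults.
import jax.numpy as jnp
import jraph

toy_graph = jraph.GraphsTuple(
    nodes=jnp.ones((3, 2)),      # 3 nodes with 2 features each
    edges=None,                  # this GAT variant only needs node features
    senders=jnp.array([0, 1]),
    receivers=jnp.array([1, 2]),
    n_node=jnp.array([3]),
    n_edge=jnp.array([2]),
    globals=None)

gat = jraph.GAT(
    attention_query_fn=lambda n: n,  # identity embedding
    attention_logit_fn=lambda s, r, e: jnp.sum(s * r, axis=-1, keepdims=True),
    node_update_fn=lambda n: n)
updated = gat(toy_graph)             # updated.nodes has shape (3, 2)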
def _ApplyGCN(graph):
  """Applies a Graph Convolution layer."""
  nodes, _, receivers, senders, _, _, _ = graph

  # First pass nodes through the node updater.
  nodes = update_node_fn(nodes)
  # Equivalent to jnp.sum(n_node), but jittable
  total_num_nodes = tree.tree_leaves(nodes)[0].shape[0]
  if add_self_edges:
    # We add self edges to the senders and receivers so that each node
    # includes itself in aggregation.
    # In principle, a `GraphsTuple` should partition by n_edge, but in
    # this case it is not required since a GCN is agnostic to whether
    # the `GraphsTuple` is a batch of graphs or a single large graph.
    conv_receivers = jnp.concatenate(
        (receivers, jnp.arange(total_num_nodes)), axis=0)
    conv_senders = jnp.concatenate(
        (senders, jnp.arange(total_num_nodes)), axis=0)
  else:
    conv_senders = senders
    conv_receivers = receivers

  # pylint: disable=g-long-lambda
  if symmetric_normalization:
    # Calculate the normalization values.
    count_edges = lambda x: utils.segment_sum(
        jnp.ones_like(conv_senders), x, total_num_nodes)
    sender_degree = count_edges(conv_senders)
    receiver_degree = count_edges(conv_receivers)

    # Pre normalize by sqrt sender degree.
    # Avoid dividing by 0 by taking maximum of (degree, 1).
    nodes = tree.tree_map(
        lambda x: x * jax.lax.rsqrt(jnp.maximum(sender_degree, 1.0))[:, None],
        nodes,
    )
    # Aggregate the pre normalized nodes.
    nodes = tree.tree_map(
        lambda x: aggregate_nodes_fn(x[conv_senders], conv_receivers,
                                     total_num_nodes), nodes)
    # Post normalize by sqrt receiver degree.
    # Avoid dividing by 0 by taking maximum of (degree, 1).
    nodes = tree.tree_map(
        lambda x: (x *
                   jax.lax.rsqrt(jnp.maximum(receiver_degree, 1.0))[:, None]),
        nodes,
    )
  else:
    nodes = tree.tree_map(
        lambda x: aggregate_nodes_fn(x[conv_senders], conv_receivers,
                                     total_num_nodes), nodes)
  # pylint: enable=g-long-lambda
  return graph._replace(nodes=nodes)
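# A toy invocation (illustrative only), assuming `_ApplyGCN` is the function
# returned by jraph's public `GraphConvolution` constructor. With
# `symmetric_normalization=True` this applies the D^-1/2 A D^-1/2 propagation
# of Kipf & Welling; the identity `update_node_fn` stands in for a real MLP.
import jax.numpy as jnp
import jraph

toy_graph = jraph.GraphsTuple(
    nodes=jnp.ones((3, 2)),
    edges=None,
    senders=jnp.array([0, 1]),
    receivers=jnp.array([1, 2]),
    n_node=jnp.array([3]),
    n_edge=jnp.array([2]),
    globals=None)

gcn = jraph.GraphConvolution(
    update_node_fn=lambda n: n,
    add_self_edges=True,           # each node also aggregates itself
    symmetric_normalization=True)
updated = gcn(toy_graph)           # updated.nodes has shape (3, 2)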
def _get_interaction_network(graphs_tuple):
  update_node_fn = lambda n, r: jnp.concatenate((n, r), axis=-1)
  update_edge_fn = lambda e, s, r: jnp.concatenate((e, s, r), axis=-1)
  out = models.InteractionNetwork(update_edge_fn, update_node_fn)(graphs_tuple)
  nodes, edges, receivers, senders, _, _, _ = graphs_tuple
  expected_edges = jnp.concatenate((edges, nodes[senders], nodes[receivers]),
                                   axis=-1)
  aggregated_nodes = utils.segment_sum(
      expected_edges, receivers, num_segments=len(graphs_tuple.nodes))
  expected_nodes = jnp.concatenate((nodes, aggregated_nodes), axis=-1)
  expected_out = graphs_tuple._replace(
      edges=expected_edges, nodes=expected_nodes)
  return out, expected_out
def _get_relation_network(graphs_tuple):
  edge_fn = lambda s, r: jnp.concatenate((s, r), axis=-1)
  global_fn = lambda e: e * 2
  out = models.RelationNetwork(edge_fn, global_fn)(graphs_tuple)
  expected_edges = jnp.concatenate(
      (graphs_tuple.nodes[graphs_tuple.senders],
       graphs_tuple.nodes[graphs_tuple.receivers]), axis=-1)
  num_graphs = len(graphs_tuple.n_edge)
  edge_gr_idx = jnp.repeat(
      jnp.arange(num_graphs), graphs_tuple.n_edge,
      total_repeat_length=graphs_tuple.edges.shape[0])
  aggregated_edges = utils.segment_sum(
      expected_edges, edge_gr_idx, num_segments=num_graphs)
  expected_out = graphs_tuple._replace(
      edges=expected_edges, globals=aggregated_edges * 2)
  return out, expected_out
def _get_deep_sets(graphs_tuple):
  node_fn = lambda n, g: jnp.concatenate((n, g), axis=-1)
  global_fn = lambda n: n * 2
  out = models.DeepSets(node_fn, global_fn)(graphs_tuple)
  num_graphs = len(graphs_tuple.n_node)
  num_nodes = len(graphs_tuple.nodes)
  broadcasted_globals = jnp.repeat(
      graphs_tuple.globals, graphs_tuple.n_node,
      total_repeat_length=num_nodes, axis=0)
  expected_nodes = jnp.concatenate(
      (graphs_tuple.nodes, broadcasted_globals), axis=-1)
  node_gr_idx = jnp.repeat(
      jnp.arange(num_graphs), graphs_tuple.n_node,
      total_repeat_length=num_nodes)
  expected_out = graphs_tuple._replace(
      nodes=expected_nodes,
      globals=utils.segment_sum(
          expected_nodes, node_gr_idx, num_segments=num_graphs) * 2)
  return out, expected_out
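# A hypothetical driver for the three helpers above (the toy graph and the
# field-wise check are assumptions for illustration): build a small
# GraphsTuple, run each helper, and compare the model output against the
# hand-computed expectation.
import jax.numpy as jnp
import jraph
import numpy as np

graphs_tuple = jraph.GraphsTuple(
    nodes=jnp.arange(6, dtype=jnp.float32).reshape(3, 2),
    edges=jnp.ones((2, 1)),
    senders=jnp.array([0, 1]),
    receivers=jnp.array([1, 2]),
    n_node=jnp.array([3]),
    n_edge=jnp.array([2]),
    globals=jnp.ones((1, 2)))

for helper in (_get_interaction_network, _get_relation_network,
               _get_deep_sets):
  out, expected_out = helper(graphs_tuple)
  # GraphsTuple is a NamedTuple, so zip pairs up corresponding fields.
  for actual, expected in zip(out, expected_out):
    np.testing.assert_allclose(np.asarray(actual), np.asarray(expected))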
def sharded_segment_sum(data, indices, num_segments, axis_index_groups):
  """Segment sum over data on multiple devices."""
  device_segment_sum = utils.segment_sum(data, indices, num_segments)
  return jax.lax.psum(device_segment_sum, axis_name='i',
                      axis_index_groups=axis_index_groups)
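# A hypothetical driver (names and shapes are assumptions): shard the data
# and segment indices across local devices along the leading axis, then let
# `jax.lax.psum` combine the per-device partial sums. With
# `axis_index_groups=None`, the sum runs over all devices in the mapped axis,
# whose name 'i' must match the `axis_name` used inside the function.
import jax
import jax.numpy as jnp

num_devices = jax.local_device_count()
data = jnp.arange(num_devices * 4, dtype=jnp.float32).reshape(num_devices, 4)
indices = jnp.zeros((num_devices, 4), dtype=jnp.int32)  # all into segment 0

result = jax.pmap(
    lambda d, i: sharded_segment_sum(d, i, num_segments=2,
                                     axis_index_groups=None),
    axis_name='i')(data, indices)
# Every device now holds an identical copy of the global (2,) segment sum.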
def test_segment_sum(self):
  result = utils.segment_sum(
      jnp.arange(9), jnp.array([0, 1, 2, 0, 4, 0, 1, 1, 0]), 6)
  self.assertAllClose(
      result, jnp.array([16, 14, 2, 0, 4, 0]), check_dtypes=False)
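# Worked arithmetic behind the expectation above: segment 0 collects the data
# at positions 0, 3, 5, 8 (0 + 3 + 5 + 8 = 16), segment 1 positions 1, 6, 7
# (1 + 6 + 7 = 14), segment 2 position 2, segment 4 position 4, and segments
# 3 and 5 receive nothing, so they sum to 0. For reference, the same check
# passes against JAX's own `jax.ops.segment_sum` (a comparison, not a claim
# about how `utils.segment_sum` is implemented):
import jax
import jax.numpy as jnp

print(jax.ops.segment_sum(jnp.arange(9),
                          jnp.array([0, 1, 2, 0, 4, 0, 1, 1, 0]),
                          num_segments=6))
# [16 14  2  0  4  0]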