# Example 1
    def _ApplyGAT(graph):
        """Runs one Graph Attention layer over `graph`.

        Only the node features are updated; edge and global features are
        passed through unchanged.
        """
        nodes, edges, receivers, senders, _, _, _ = graph
        # Static node count (equivalent to summing n_node, but known at
        # trace time so it can be used as `num_segments` below).
        try:
            total_nodes = nodes.shape[0]
        except IndexError:
            raise IndexError('GAT requires node features')

        # Embed the nodes before computing attention.
        nodes = attention_query_fn(nodes)

        # Gather the embedded endpoints of every edge and turn them into
        # unnormalized attention logits.
        sent_attributes = nodes[senders]
        received_attributes = nodes[receivers]
        softmax_logits = attention_logit_fn(
            sent_attributes, received_attributes, edges)

        # Normalize the logits per receiving node, weight the sender
        # features, and sum the weighted messages into each receiver.
        weights = utils.segment_softmax(
            softmax_logits, segment_ids=receivers, num_segments=total_nodes)
        aggregated = utils.segment_sum(
            sent_attributes * weights, receivers, num_segments=total_nodes)

        # Final transformation of the aggregated messages.
        return graph._replace(nodes=node_update_fn(aggregated))
# Example 2
    def _ApplyGCN(graph):
        """Runs one Graph Convolution layer over `graph`.

        Only the node features are updated; edge and global features are
        passed through unchanged.
        """
        nodes, _, receivers, senders, _, _, _ = graph

        # Transform the node features first.
        nodes = update_node_fn(nodes)
        # Static node count: a jittable stand-in for jnp.sum(n_node).
        total_num_nodes = tree.tree_leaves(nodes)[0].shape[0]

        if add_self_edges:
            # Append one self edge per node so that every node includes its
            # own features in the aggregation. A GCN is agnostic to whether
            # the GraphsTuple is a batch of graphs or one large graph, so no
            # per-graph (n_edge) partitioning is required here.
            self_loops = jnp.arange(total_num_nodes)
            conv_receivers = jnp.concatenate((receivers, self_loops), axis=0)
            conv_senders = jnp.concatenate((senders, self_loops), axis=0)
        else:
            conv_senders = senders
            conv_receivers = receivers

        # pylint: disable=g-long-lambda
        if symmetric_normalization:
            # Node degrees under the (possibly self-loop-augmented) edge set.
            degree_of = lambda idx: utils.segment_sum(
                jnp.ones_like(conv_senders), idx, total_num_nodes)
            sender_degree = degree_of(conv_senders)
            receiver_degree = degree_of(conv_receivers)

            # 1/sqrt(degree) scaling factors; degrees are clamped at 1 so
            # isolated nodes do not cause a division by zero.
            inv_sqrt_send = jax.lax.rsqrt(jnp.maximum(sender_degree, 1.0))
            inv_sqrt_recv = jax.lax.rsqrt(jnp.maximum(receiver_degree, 1.0))

            # Pre-normalize, aggregate, then post-normalize (D^-1/2 A D^-1/2).
            nodes = tree.tree_map(lambda x: x * inv_sqrt_send[:, None], nodes)
            nodes = tree.tree_map(
                lambda x: aggregate_nodes_fn(x[conv_senders], conv_receivers,
                                             total_num_nodes), nodes)
            nodes = tree.tree_map(lambda x: x * inv_sqrt_recv[:, None], nodes)
        else:
            nodes = tree.tree_map(
                lambda x: aggregate_nodes_fn(x[conv_senders], conv_receivers,
                                             total_num_nodes), nodes)
        # pylint: enable=g-long-lambda
        return graph._replace(nodes=nodes)
# Example 3
def _get_interaction_network(graphs_tuple):
    """Returns (model output, hand-computed expectation) for InteractionNetwork."""
    update_node_fn = lambda n, r: jnp.concatenate((n, r), axis=-1)
    update_edge_fn = lambda e, s, r: jnp.concatenate((e, s, r), axis=-1)
    out = models.InteractionNetwork(update_edge_fn,
                                    update_node_fn)(graphs_tuple)

    nodes, edges, receivers, senders, _, _, _ = graphs_tuple
    # Expected edges: original edge features plus both endpoint features.
    expected_edges = jnp.concatenate(
        (edges, nodes[senders], nodes[receivers]), axis=-1)
    # Expected nodes: original features plus the sum of incoming updated edges.
    incoming = utils.segment_sum(
        expected_edges, receivers, num_segments=len(graphs_tuple.nodes))
    expected_nodes = jnp.concatenate((nodes, incoming), axis=-1)
    expected_out = graphs_tuple._replace(
        edges=expected_edges, nodes=expected_nodes)
    return out, expected_out
# Example 4
def _get_relation_network(graphs_tuple):
    """Returns (model output, hand-computed expectation) for RelationNetwork."""
    edge_fn = lambda s, r: jnp.concatenate((s, r), axis=-1)
    global_fn = lambda e: e * 2
    out = models.RelationNetwork(edge_fn, global_fn)(graphs_tuple)

    # Expected edges: concatenation of each edge's endpoint node features.
    expected_edges = jnp.concatenate(
        (graphs_tuple.nodes[graphs_tuple.senders],
         graphs_tuple.nodes[graphs_tuple.receivers]),
        axis=-1)
    # Map every edge to the index of the graph it belongs to, then sum the
    # expected edges per graph to form the expected globals.
    num_graphs = len(graphs_tuple.n_edge)
    edge_gr_idx = jnp.repeat(
        jnp.arange(num_graphs),
        graphs_tuple.n_edge,
        total_repeat_length=graphs_tuple.edges.shape[0])
    summed_edges = utils.segment_sum(
        expected_edges, edge_gr_idx, num_segments=num_graphs)
    expected_out = graphs_tuple._replace(
        edges=expected_edges, globals=summed_edges * 2)
    return out, expected_out
# Example 5
def _get_deep_sets(graphs_tuple):
    """Returns (model output, hand-computed expectation) for DeepSets."""
    node_fn = lambda n, g: jnp.concatenate((n, g), axis=-1)
    global_fn = lambda n: n * 2
    out = models.DeepSets(node_fn, global_fn)(graphs_tuple)

    num_graphs = len(graphs_tuple.n_node)
    num_nodes = len(graphs_tuple.nodes)
    # Copy each graph's globals onto every one of its nodes.
    broadcasted_globals = jnp.repeat(
        graphs_tuple.globals,
        graphs_tuple.n_node,
        total_repeat_length=num_nodes,
        axis=0)
    expected_nodes = jnp.concatenate(
        (graphs_tuple.nodes, broadcasted_globals), axis=-1)
    # Per-node graph index used to pool the nodes back into their graphs.
    node_gr_idx = jnp.repeat(
        jnp.arange(num_graphs),
        graphs_tuple.n_node,
        total_repeat_length=num_nodes)
    pooled = utils.segment_sum(
        expected_nodes, node_gr_idx, num_segments=num_graphs)
    expected_out = graphs_tuple._replace(
        nodes=expected_nodes, globals=pooled * 2)
    return out, expected_out
# Example 6
def sharded_segment_sum(data, indices, num_segments, axis_index_groups):
    """Segment sum over data sharded across multiple devices.

    Each device sums its local shard, then the per-device partial sums are
    combined with a cross-device `psum` along axis name 'i'.
    """
    local_sum = utils.segment_sum(data, indices, num_segments)
    return jax.lax.psum(
        local_sum, axis_name='i', axis_index_groups=axis_index_groups)
# Example 7
 def test_segment_sum(self):
     """segment_sum groups arange(9) values by their segment ids."""
     segment_ids = jnp.array([0, 1, 2, 0, 4, 0, 1, 1, 0])
     result = utils.segment_sum(jnp.arange(9), segment_ids, 6)
     # Segments 3 and 5 receive no entries and therefore stay zero.
     self.assertAllClose(result,
                         jnp.array([16, 14, 2, 0, 4, 0]),
                         check_dtypes=False)