def run_steps(node_embeddings):
    """Runs propagation and update steps for the configured model type."""
    if model_type == "ggnn":
        final_embeddings = ggnn_steps(node_embeddings, edge_embeddings)
    elif model_type == "transformer":
        # Mask restricting attention to graph neighbors.
        neighbor_mask = graph_layers.edge_mask(
            edges=example.edges,
            num_nodes=graph_metadata.num_nodes,
            num_edge_types=num_edge_types,
            forward_edge_type_indices=forward_edge_type_indices,
            reverse_edge_type_indices=reverse_edge_type_indices)
        # Allow nodes to attend to themselves.
        neighbor_mask = jnp.maximum(neighbor_mask,
                                    jnp.eye(graph_metadata.num_nodes))
        final_embeddings = transformer_steps(
            node_embeddings,
            edge_embeddings,
            neighbor_mask,
            num_real_nodes_per_graph=example.graph_metadata.num_nodes)
    elif model_type == "nri_encoder":
        final_embeddings = nri_steps(
            node_embeddings,
            edge_embeddings,
            num_real_nodes_per_graph=example.graph_metadata.num_nodes)
    else:
        raise ValueError(f"Unknown model_type: {model_type}")
    return final_embeddings
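
# Note: `run_steps` closes over `model_type`, `edge_embeddings`, `example`,
# `graph_metadata`, `num_edge_types`, the edge type index lists, and the
# `ggnn_steps` / `transformer_steps` / `nri_steps` helpers, which are assumed
# to be defined in the enclosing scope (not shown in this excerpt).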
    def test_edge_nonzeros(self, which):
        (edges, forward_edge_type_indices,
         reverse_edge_type_indices) = self._setup_edges()
        if which == "edge_mask":
            nonzeros = graph_layers.edge_mask(
                edges,
                num_nodes=4,
                num_edge_types=3,
                forward_edge_type_indices=forward_edge_type_indices,
                reverse_edge_type_indices=reverse_edge_type_indices)
        elif which == "LearnableEdgeEmbeddings":
            embeddings, _ = graph_layers.LearnableEdgeEmbeddings.init(
                jax.random.PRNGKey(0),
                edges=edges,
                num_nodes=4,
                num_edge_types=3,
                forward_edge_type_indices=forward_edge_type_indices,
                reverse_edge_type_indices=reverse_edge_type_indices,
                embedding_dim=EDGE_EMBEDDING_DIM)
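            # An edge position counts as present if its embedding vector is
            # nonzero anywhere along the embedding dimension.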
            nonzeros = jnp.any(embeddings != 0, axis=-1).astype(jnp.float32)

        # The three configured edges connect nodes 0-1, 1-2, and 0-2; between
        # the forward and reverse edge types, both directions of each pair
        # should be marked.
        expected = np.zeros([4, 4], np.float32)
        expected[2, 0] = 1
        expected[0, 1] = 1
        expected[1, 2] = 1
        expected[0, 2] = 1
        expected[1, 0] = 1
        expected[2, 1] = 1
        np.testing.assert_allclose(nonzeros, expected)
    def apply(
        self,
        graph_context,
        node_embeddings,
        edge_embeddings,
        forward_edge_types=gin.REQUIRED,
        reverse_edge_types=gin.REQUIRED,
        walk_length_log2=gin.REQUIRED,
    ):
        """Modifies edge embeddings using a uniform random walk.

        Uses an efficient repeated-squaring technique to compute the absorbing
        distribution.

        Args:
          graph_context: Input graph for this example.
          node_embeddings: Current node embeddings, as <float32[num_nodes,
            node_embedding_dim]>.
          edge_embeddings: Current edge embeddings, as <float32[num_nodes,
            num_nodes, edge_embedding_dim]>.
          forward_edge_types: Edge types to use in the forward direction. As a
            list of lists to allow configuring groups of edges in config files;
            this will be flattened before use.
          reverse_edge_types: Edge types to use in the reverse direction. Note
            that reversed edge types are given a separate embedding from
            forward edge types; undirected edges should be represented by
            adding two edges in opposite directions and then only using
            `forward_edge_types`. Also a list of lists, as above.
          walk_length_log2: Base-2 logarithm of the maximum walk length; this
            determines how many times we square the transition matrix (each
            squaring doubles the walk length).

        Returns:
          New node and edge embeddings. Node embeddings are not modified. Edge
          embeddings are modified by adding a new edge type (either embedded or
          concatenated, depending on graph_context.edges_are_embedded).
        """
        num_nodes = node_embeddings.shape[0]

        # pylint: disable=g-complex-comprehension
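        # Convert the configured groups of edge type names into flat lists of
        # edge type indices.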
        forward_edge_type_indices = [
            graph_context.edge_types_to_indices[type_str]
            for group in forward_edge_types for type_str in group
        ]
        reverse_edge_type_indices = [
            graph_context.edge_types_to_indices[type_str]
            for group in reverse_edge_types for type_str in group
        ]
        # pylint: enable=g-complex-comprehension

        adjacency = graph_layers.edge_mask(
            edges=graph_context.bundle.edges,
            num_nodes=num_nodes,
            num_edge_types=len(graph_context.edge_types_to_indices),
            forward_edge_type_indices=forward_edge_type_indices,
            reverse_edge_type_indices=reverse_edge_type_indices)
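        # Add self-loops so every node has at least one outgoing transition and
        # the row normalization below is well defined.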
        adjacency = jnp.maximum(adjacency, jnp.eye(num_nodes))

        # Learnable probability of absorbing the walk at each step,
        # parameterized by its logit and initialized to 0.1.
        absorbing_logit = self.param(
            "absorbing_logit",
            shape=(),
            initializer=lambda *_: jax.scipy.special.logit(0.1))

        absorbing_prob = jax.nn.sigmoid(absorbing_logit)
        nonabsorbing_prob = jax.nn.sigmoid(-absorbing_logit)
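        # Transition matrix of the random walk: from each node, move to a
        # uniformly chosen neighbor (or stay, via the self-loop) with total
        # probability `nonabsorbing_prob`.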
        walk_matrix = nonabsorbing_prob * adjacency / jnp.sum(
            adjacency, axis=1, keepdims=True)

        # Repeated squaring maintains the invariant that after n steps the
        # scan carry holds (A^(2^n), I + A + ... + A^(2^n - 1)):
        #   A, I
        #   A^2, A + I
        #   (A^2)^2 = A^4, A^2(A + I) + (A + I) = A^3 + A^2 + A + I
        #   ...

        def step(state, _):
            nth_power, nth_partial_sum = state
            return (nth_power @ nth_power,
                    nth_power @ nth_partial_sum + nth_partial_sum), None

        (_, partial_sum), _ = jax.lax.scan(step,
                                           (walk_matrix, jnp.eye(num_nodes)),
                                           None,
                                           length=walk_length_log2)
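        # `partial_sum` now contains contributions from all walks of length
        # strictly less than 2**walk_length_log2.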

        # Truncated absorption distribution: the probability that a walk from
        # one node is absorbed at another within the maximum walk length.
        approx_visits = absorbing_prob * partial_sum
        # Convert to logits, apply a learned scale and shift, and squash back
        # into (0, 1) to obtain per-pair edge weights.
        logits = model_util.safe_logit(approx_visits)
        logits = model_util.ScaleAndShift(logits)
        edge_weights = jax.nn.sigmoid(logits)

        return (node_embeddings,
                _add_edges(edge_embeddings, edge_weights[:, :, None],
                           graph_context.edges_are_embedded))
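

# A minimal, self-contained sketch (hypothetical; not part of the modules
# above) that checks the repeated-squaring recurrence used in `apply`: starting
# from the pair (W, I), k scan steps yield (W^(2^k), I + W + ... + W^(2^k - 1)).
import jax
import jax.numpy as jnp
import numpy as np


def _check_repeated_squaring(walk_matrix, walk_length_log2):
    """Returns I + W + ... + W^(2^walk_length_log2 - 1) via repeated squaring."""

    def step(state, _):
        nth_power, nth_partial_sum = state
        return (nth_power @ nth_power,
                nth_power @ nth_partial_sum + nth_partial_sum), None

    (_, partial_sum), _ = jax.lax.scan(
        step, (walk_matrix, jnp.eye(walk_matrix.shape[0])), None,
        length=walk_length_log2)
    return partial_sum


if __name__ == "__main__":
    # Compare against the direct (non-squared) sum for a small random matrix.
    w = jnp.asarray(np.random.RandomState(0).rand(3, 3) / 3)
    direct = sum(jnp.linalg.matrix_power(w, i) for i in range(2**3))
    np.testing.assert_allclose(
        _check_repeated_squaring(w, walk_length_log2=3), direct,
        rtol=1e-5, atol=1e-6)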