def mock_model(example):
    del example
    return model_util.safe_logit(
        jnp.array([
            [0.2, 0.0, 0.1, 0.0, 0.0],
            [0.0, 0.0, 1.0, 0.0, 0.0],
            [0.0, 0.0, 0.3, 0.7, 0.0],
            [0.9, 0.0, 0.0, 0.0, 0.0],
            [0.0, 0.0, 0.0, 0.0, 0.0],
        ]))
def automaton_model(example,
                    graph_metadata,
                    edge_types_to_indices,
                    variant_edge_types=(),
                    platt_scale=False,
                    with_backtrack=True):
    """Automaton-based module for edge supervision task.

  Args:
    example: Example to run the automaton on.
    graph_metadata: Statically-known metadata about the graph size. If
      encoded_graph is padded, this should reflect the padded size, not the
      original size.
    edge_types_to_indices: Mapping from edge type names to edge type indices.
    variant_edge_types: Edge types to use as variants. Assumes without checking
      that the given variants are mutually exclusive (at most one edge of one of
      these types exists between any pair of nodes).
    platt_scale: Whether to scale and shift the logits produced by the
      automaton. This can be viewed as a form of Platt scaling applied to the
      automaton logits. If True, this allows the model's output probabilities to
      sum to more than 1, so that it can express one-to-many relations.
    with_backtrack: Whether the automaton can restart the search as an action.

  Returns:
    <float32[num_nodes, num_nodes]> matrix of binary logits for a weighted
    adjacency matrix corresponding to the predicted output edges.
  """
    if variant_edge_types:
        variant_edge_type_indices = [
            edge_types_to_indices[type_str] for type_str in variant_edge_types
        ]
        num_edge_types = len(edge_types_to_indices)
        variant_weights = variants_from_edges(example, graph_metadata,
                                              variant_edge_type_indices,
                                              num_edge_types)
    else:
        variant_weights = None

    absorbing_probs = automaton_layer.FiniteStateGraphAutomaton(
        encoded_graph=example.automaton_graph,
        variant_weights=variant_weights,
        static_metadata=graph_metadata,
        dynamic_metadata=example.graph_metadata,
        builder=automaton_builder.AutomatonBuilder(
            py_ast_graphs.SCHEMA, with_backtrack=with_backtrack),
        num_out_edges=1,
        share_states_across_edges=True).squeeze(axis=0)

    logits = model_util.safe_logit(absorbing_probs)

    if platt_scale:
        logits = model_util.ScaleAndShift(logits)

    return logits
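# Hypothetical usage sketch (the names below are illustrative placeholders, not
# taken from the original tests): a padded example plus its statically-known
# metadata yields pairwise logits, which sigmoid turns into a weighted
# adjacency matrix over predicted edges.
#
#     logits = automaton_model(
#         example=padded_example,
#         graph_metadata=padded_graph_metadata,
#         edge_types_to_indices={"extra_edge": 0, "target_edge": 1},
#         variant_edge_types=("extra_edge",),
#         platt_scale=True)
#     predicted_adjacency = jax.nn.sigmoid(logits)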
def mock_model_def(example):
    del example
    side_outputs.SideOutput(
        -jnp.arange(5).astype("float32").reshape((1, 5)),
        name="one_sample_log_prob_per_edge_per_node")
    side_outputs.SideOutput(0.3, name="one_sample_reward_baseline")

    return model_util.safe_logit(
        jnp.array([
            [0.0, 0.0, 0.0, 0.0, 0.0],
            [0.0, 0.0, 1.0, 0.0, 0.0],
            [0.0, 0.0, 0.0, 1.0, 0.0],
            [0.0, 0.0, 0.0, 0.0, 0.0],
            [0.0, 0.0, 0.0, 0.0, 0.0],
        ]))
Example #4
def test_safe_logit(self):
    probs = jnp.array([0, 1e-20, 1e-3, 0.9, 1])
    logits = model_util.safe_logit(probs)
    self.assertTrue(np.all(np.isfinite(logits)))
    np.testing.assert_allclose(logits[1:3],
                               jax.scipy.special.logit(probs[1:3]))
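# The test above pins down safe_logit's contract: outputs stay finite even for
# probabilities of exactly 0 or 1, and match jax.scipy.special.logit exactly
# away from the endpoints. A minimal sketch consistent with that contract
# (clamping only at the float32 limits is an assumption, not necessarily what
# model_util.safe_logit actually does):
import jax
import jax.scipy.special
import jax.numpy as jnp


def sketch_safe_logit(probs):
    # Clamp only the extreme endpoints so interior probabilities (e.g. 1e-20
    # or 1e-3 in the test above) keep their exact logit value.
    finfo = jnp.finfo(jnp.float32)
    clamped = jnp.clip(probs, finfo.tiny, 1.0 - finfo.eps)
    return jax.scipy.special.logit(clamped)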
def sample_loss_fn(
    model,
    padded_example_and_rng,
    target_edge_index,
    num_edge_types,
    baseline_weight=0.001,
    num_rollouts=1,
    leave_one_out_baseline=False,
):
    """Compute a single-sample version of the loss.

  Used for running sample-based baseline.

  Args:
    model: Model to run on the example.
    padded_example_and_rng: Example to extract targets from, with RNG.
    target_edge_index: Index of the target edge type.
    num_edge_types: How many edge types there are.
    baseline_weight: How much weight to give to learned baseline loss term.
    num_rollouts: How many rollouts to use.
    leave_one_out_baseline: Whether to use leave-one-out baseline instead of
      learned baseline.

  Returns:
    Tuple (output_logits, targets, valid_mask, loss, metrics_dict).
  """
    padded_example, rng = padded_example_and_rng

    @functools.partial(jax.vmap,
                       out_axes=(0, None, None, None, 0, 0, None, None))
    def go(rng):
        (output_logits, targets, valid_mask, num_nodes,
         captured) = extract_outputs_and_targets(model, (padded_example, rng),
                                                 target_edge_index,
                                                 num_edge_types)

        valid_nodes = (jnp.arange(output_logits.shape[-1]) < num_nodes).astype(
            jnp.float32)
        log_prob, = [
            v for k, v in captured.items()
            if k.endswith("one_sample_log_prob_per_edge_per_node")
        ]
        log_prob = log_prob.squeeze(0) * valid_nodes
        learned_baseline, = [
            v for k, v in captured.items()
            if k.endswith("one_sample_reward_baseline")
        ]
        # Compute reward: +1 for doing the correct thing, 0 for doing anything else.
        output_probs = jax.nn.sigmoid(output_logits) * valid_mask
        num_targets_per_node = jnp.sum(targets.astype(jnp.int32), axis=-1)
        no_targets = (num_targets_per_node == 0)
        fail_prob = 1 - jnp.sum(output_probs, axis=-1)
        reward_per_node = (
            jnp.sum(output_probs * targets.astype(jnp.float32), axis=-1) +
            fail_prob * no_targets.astype(jnp.float32)) * valid_nodes

        return (output_logits, targets, valid_mask, num_nodes, reward_per_node,
                log_prob, learned_baseline, valid_nodes)

    (output_logits, targets, valid_mask, num_nodes, reward_per_node, log_prob,
     learned_baseline, valid_nodes) = go(jax.random.split(rng, num_rollouts))

    if leave_one_out_baseline:
        baseline = (jnp.sum(reward_per_node, axis=0, keepdims=True) -
                    reward_per_node) / (num_rollouts - 1)
    else:
        baseline = learned_baseline

    shifted_reward = (reward_per_node - baseline) * valid_nodes
    # REINFORCE: scale shifted reward by log probs
    reinforce_virtual_loss = jnp.sum(
        -log_prob * jax.lax.stop_gradient(shifted_reward)) / (num_nodes *
                                                              num_rollouts)

    if not leave_one_out_baseline:
        # Penalty term that trains the learned baseline to track the reward,
        # keeping the shifted reward close to 0.
        baseline_penalty = baseline_weight * jnp.sum(
            jnp.square(shifted_reward)) / (num_nodes * num_rollouts)
    else:
        baseline_penalty = jnp.zeros([])

    mean_logits = model_util.safe_logit(
        jnp.mean(jax.scipy.special.expit(output_logits), axis=0))

    virtual_loss = reinforce_virtual_loss + baseline_penalty
    return mean_logits, targets, valid_mask, virtual_loss, {
        "reward": jnp.sum(reward_per_node) / (num_nodes * num_rollouts),
        "shifted_reward": jnp.sum(shifted_reward) / (num_nodes * num_rollouts),
        "policy_log_prob": jnp.sum(log_prob) / (num_nodes * num_rollouts),
        "learned_baseline": learned_baseline,
        "baseline_penalty": baseline_penalty,
        "reinforce_term": reinforce_virtual_loss,
    }
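# Small synthetic illustration (not part of the code above) of the
# leave-one-out baseline used when `leave_one_out_baseline=True`: each
# rollout's baseline is the mean reward of the other rollouts, so the shifted
# reward sums to zero across rollouts without needing a learned parameter.
import jax.numpy as jnp

reward_per_node = jnp.array([[1.0], [0.0], [0.5]])  # 3 rollouts, 1 node
num_rollouts = reward_per_node.shape[0]
baseline = (jnp.sum(reward_per_node, axis=0, keepdims=True) -
            reward_per_node) / (num_rollouts - 1)
# baseline == [[0.25], [0.75], [0.5]]: rollout 0 is compared against the mean
# reward of rollouts 1 and 2, and so on.
shifted_reward = reward_per_node - baseline  # [[0.75], [-0.75], [0.0]]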
Example #6
    def apply(
        self,
        encoded_graph,
        dynamic_metadata,
        static_metadata,
        builder,
        num_out_edges=1,
        num_intermediate_states=0,
        variant_weights=None,
        steps=None,
        share_states_across_edges=True,
        backtrack_fails_prob=0.001,
        initialization_noise_factor=0.01,
        legacy_initialize=True,
        initialize_smoothing=0.001,
        use_gate_parameterization=False,
        gate_noise=0.2,
        logit_scaling="learned",
        dynamic_scaling_absolute_shift=1.0,
        dynamic_scaling_relative_shift=1.0,
        estimator_type="solver",
        sampling_max_possible_transitions=None,
    ):
        """Apply the graph finite-state automaton layer.

    Args:
      encoded_graph: Graph structure to run the automaton on.
      dynamic_metadata: Metadata about the actual unpadded graph. Unused.
      static_metadata: Statically-known metadata about the graph size. If
        encoded_graph is padded, this should reflect the padded size, not the
        original size.
      builder: Builder object that encoded the graph. Determines the shape of
        the automaton parameters.
      num_out_edges: How many distinct "edge types" to produce. Each of these
        corresponds to a different automaton start state.
      num_intermediate_states: How many intermediate states to include. These
        states are only used inside the automaton, and don't get exposed as
        start states for an edge type. If share_states_across_edges is True, all
        of the edge types will share the same set of intermediate states. If
        share_states_across_edges is False, each edge type will have its OWN set
        of `num_intermediate_states` states.
      variant_weights: Optional <float32[num_nodes, num_nodes, num_variants]> or
        <float32[num_nodes, num_nodes, num_out_edges, num_variants]> array that
        has nonnegative elements that sum to 1 along the last axis. If provided,
        variant_weights[i, j, (e,) v] specifies how much policy variant v should
        be used (for edge type e) when starting from node i and arriving at
        intermediate node j. Variants correspond to the start-node-conditioned
        observations described in Appendix C.2.
      steps: How many steps to use when solving the automaton. If not provided,
        uses the number of nodes in the graph. Smaller numbers of steps may make
        it impossible to reach certain nodes in the graph; larger numbers may
        allow additional backtracking and state changes at the expense of
        additional compute time.
      share_states_across_edges: Whether the different edge types share the same
        set of states. If True, any state can transition to any other state
        (even to a different start state). If False, every edge type gets a
        separate set of `1 + num_intermediate_states` states that are not
        shared; in other words, every output edge uses a distinct finite state
        machine.
      backtrack_fails_prob: Backtracking decay factor; determines how often the
        automaton halts when it tries to take the BACKTRACK action. This
        ensures numerical stability and counteracts noise when the automaton
        attempts to backtrack with close to 100% probability.
      initialization_noise_factor: How much noise to use when initializing the
        automaton policy (see AutomatonBuilder.initialize_routing_params).
      legacy_initialize: Whether to use legacy initialization, which sets the
        log-space softmax weights as Dirichlet random samples (instead of
        setting the softmax output distribution as Dirichlet random samples).
        Defaults to True so that we can reload old configs; new runs should use
        False.
      initialize_smoothing: Controls how much we smooth the initial parameters
        toward a uniform distribution. For small values, this is effectively a
        lower bound of the probability we take each action. If zero, we can
        sample arbitrarily small starting probabilities from the Dirichlet
        distribution. (Specifically, we adjust the sampled probabilities as
        `p_init(x) = (p_sampled(x) + c)/(1 + n*c)` where `n` is the number of
        possible actions and `c = initialize_smoothing`.)
      use_gate_parameterization: Whether to use gate parameterization instead of
        default parameterization. If so, the other initialization args are
        ignored.
      gate_noise: Logistic noise for gate parameterization.
      logit_scaling: One of "none", "learned", or "dynamic".
      dynamic_scaling_absolute_shift: For dynamic scaling, how much extra to
        shift logits, in an absolute sense, after shifting for mean magnitude.
      dynamic_scaling_relative_shift: For dynamic scaling, how much extra to
        shift logits, relative to standard deviation of magnitude, after
        shifting for mean magnitude.
      estimator_type: One of "solver", "one_sample".
      sampling_max_possible_transitions: Number of possible outgoing transitions
        for any given start node. Used to accelerate the sampling process when
        estimator is "one_sample".

    Returns:
      <float32[num_out_edges, num_nodes, num_nodes]> weighted adjacency
      matrix for `num_out_edges` new edge types.
    """
        del dynamic_metadata
        num_nodes = static_metadata.num_nodes
        steps = steps if steps is not None else num_nodes
        assert logit_scaling in ("none", "learned", "dynamic")

        if variant_weights is None:
            num_variants = 1
            variant_out_edge_axis = None
            variant_weights = jnp.ones((num_nodes, num_nodes, num_variants))
        elif variant_weights.ndim == 3:
            num_variants = variant_weights.shape[-1]
            variant_out_edge_axis = None
            if variant_weights.shape != (num_nodes, num_nodes, num_variants):
                raise ValueError(
                    f"variant_weights shape {variant_weights.shape} doesn't match "
                    f"expected shape ({num_nodes}, {num_nodes}, <anything>)")
        elif variant_weights.ndim == 4:
            num_variants = variant_weights.shape[-1]
            variant_out_edge_axis = 2
            if variant_weights.shape != (num_nodes, num_nodes, num_out_edges,
                                         num_variants):
                raise ValueError(
                    f"variant_weights shape {variant_weights.shape} doesn't match "
                    "expected shape"
                    f"({num_nodes}, {num_nodes}, {num_out_edges}, <anything>)")
        else:
            raise ValueError(
                f"Invalid variant_weights shape {variant_weights.shape};"
                " expected 3 or 4 axes")

        if share_states_across_edges:
            num_fsm_states = num_out_edges + num_intermediate_states

            # Initialize or retrieve the automaton parameters; these parameters are
            # shared across all edge types.
            if use_gate_parameterization:

                def shared_routing_initializer(rng_key, unused_shape):
                    return builder.initialize_routing_gates(
                        key=rng_key,
                        num_fsm_states=num_fsm_states,
                        num_variants=num_variants,
                        logistic_noise=gate_noise)

                routing_gates = self.param(
                    "routing_gate_logits_shared",
                    shape=None,
                    initializer=shared_routing_initializer)
                routing_gates = side_outputs.encourage_discrete_logits(
                    routing_gates,
                    distribution_type="binary",
                    name="routing_gate_logits_shared")
                routing_params = builder.routing_gates_to_probs(routing_gates)

            else:

                def shared_routing_initializer(rng_key, unused_shape):
                    routing_probs = builder.initialize_routing_params(
                        key=rng_key,
                        num_fsm_states=num_fsm_states,
                        num_variants=num_variants,
                        noise_factor=initialization_noise_factor)
                    if legacy_initialize:
                        return routing_probs
                    return jax.tree_map(
                        lambda x: jnp.log(x + initialize_smoothing),
                        routing_probs)

                log_routing_params = self.param(
                    "log_routing_params_shared",
                    shape=None,
                    initializer=shared_routing_initializer)
                routing_params = builder.routing_softmax(log_routing_params)

            # Don't precompute constants if we are tracing an XLA computation; wait
            # until we know a value for our parameters by adding a fake data
            # dependence.
            trigger = jax.tree_leaves(routing_params)[0]
            variant_weights = jax.lax.tie_in(trigger, variant_weights)

            # Build the automaton on the provided graph.
            transition_matrix = builder.build_transition_matrix(
                routing_params, encoded_graph, static_metadata)

            # Each edge type is a start state.
            if num_intermediate_states > 0:
                start_machine_states = jnp.concatenate([
                    jax.lax.tie_in(trigger, jnp.eye(num_out_edges)),
                    jax.lax.tie_in(
                        trigger,
                        jnp.zeros((num_out_edges, num_intermediate_states)))
                ], 1)
            else:
                start_machine_states = jax.lax.tie_in(trigger,
                                                      jnp.eye(num_out_edges))

            start_machine_states = jnp.broadcast_to(
                start_machine_states,
                (num_nodes, num_out_edges, num_fsm_states))

            # Solve for absorbing distribution for each of the starting states by
            # vmapping across the dimensions that depend on the start state.
            if estimator_type == "solver":
                absorbing_solution = jax_util.vmap_with_kwargs(
                    automaton_builder.all_nodes_absorbing_solve,
                    variant_weights_axis=variant_out_edge_axis,
                    start_machine_states_axis=1)(
                        builder=builder,
                        transition_matrix=transition_matrix,
                        variant_weights=variant_weights,
                        start_machine_states=start_machine_states,
                        steps=steps,
                        backtrack_fails_prob=backtrack_fails_prob)
            elif estimator_type == "one_sample":
                assert sampling_max_possible_transitions is not None
                rollout_each_node_fn = jax_util.vmap_with_kwargs(
                    automaton_sampling.roll_out_transitions,
                    variant_weights_axis=0,
                    start_machine_state_axis=0,
                    node_index_axis=0,
                    rng_axis=0)
                rollout_each_edgetype_fn = jax_util.vmap_with_kwargs(
                    rollout_each_node_fn,
                    variant_weights_axis=variant_out_edge_axis,
                    start_machine_state_axis=1,
                    rng_axis=0)
                all_states = rollout_each_edgetype_fn(
                    builder=builder,
                    transition_matrix=transition_matrix,
                    variant_weights=variant_weights,
                    start_machine_state=start_machine_states,
                    node_index=jnp.arange(num_nodes),
                    steps=steps,
                    max_possible_transitions=sampling_max_possible_transitions,
                    rng=jax.random.split(flax.nn.make_rng(),
                                         num_out_edges * num_nodes).reshape(
                                             [num_out_edges, num_nodes, -1]))

                def set_absorbing(final_node, succeeded):
                    return jnp.zeros([num_nodes]).at[final_node].set(succeeded)

                # absorbing_solution is [num_out_edges, num_nodes, num_nodes]
                absorbing_solution = jax.vmap(jax.vmap(set_absorbing))(
                    all_states.final_node, all_states.succeeded)

                side_outputs.SideOutput(
                    all_states.log_prob,
                    name="one_sample_log_prob_per_edge_per_node")

                # Somewhat of a hack: associate the log prob with its own learned
                # baseline as a side output, so it can be trained alongside the rest
                # of the model, but don't do anything with it until the loss function.
                one_sample_reward_baseline = self.param(
                    "one_sample_reward_baseline",
                    shape=(),
                    initializer=jax.nn.initializers.zeros)
                side_outputs.SideOutput(one_sample_reward_baseline,
                                        name="one_sample_reward_baseline")
            else:
                raise ValueError(f"Invalid estimator {estimator_type}")

            # Rescale the logits.
            logits = model_util.safe_logit(absorbing_solution)
            if logit_scaling == "learned":
                # Learned scaling and shifting.
                logits = model_util.ScaleAndShift(logits)
            elif logit_scaling == "dynamic":
                # Dynamic scaling only implemented with gates.
                assert use_gate_parameterization
                # First, quantify how discrete the gates are. Conceptually, we want
                # to quantify how far away from zero the logits are, in a differentiable
                # way. To make it smooth, use logsumexp:
                relevant_gates = [
                    routing_gates.move_gates, routing_gates.accept_gates
                ]
                soft_abs_logits = [
                    jnp.logaddexp(g, -g) for g in relevant_gates
                ]
                # Take a mean and standard deviation over these logits to summarize.
                logit_mean = (sum(jnp.sum(x) for x in soft_abs_logits) /
                              sum(x.size for x in soft_abs_logits))
                logit_var = (sum(
                    jnp.sum(jnp.square(x - logit_mean))
                    for x in soft_abs_logits) / sum(x.size
                                                    for x in soft_abs_logits))
                logit_std = jnp.sqrt(logit_var)
                side_outputs.SideOutput(logit_mean, name="gate_logit_abs_mean")
                side_outputs.SideOutput(logit_std, name="gate_logit_abs_std")
                # Now, use these to choose an adjustment factor. Intuitively, the "off"
                # gates should be centered around (-logit_mean). So anything that
                # gets sufficiently more mass than that should be "on". We consider
                # two notions of "sufficiently": either relative to the variance in
                # logits, or absolute.
                shift_threshold = (-logit_mean +
                                   dynamic_scaling_absolute_shift +
                                   dynamic_scaling_relative_shift * logit_std)
                side_outputs.SideOutput(shift_threshold,
                                        name="shift_threshold")
                # Adjust so that values at the shift threshold are mapped to edges of
                # weight 0.5
                logits = logits - shift_threshold

        else:
            num_fsm_states = 1 + num_intermediate_states

            if estimator_type != "solver":
                raise NotImplementedError(
                    "Sampling estimators not implemented for unshared states.")

            # Different automaton parameters for each start state
            if use_gate_parameterization:

                def unshared_routing_initializer(rng_key, unused_shape):
                    key_per_edge = jax.random.split(rng_key, num_out_edges)
                    return jax_util.vmap_with_kwargs(
                        builder.initialize_routing_gates,
                        key_axis=0)(key=key_per_edge,
                                    num_fsm_states=num_fsm_states,
                                    num_variants=num_variants,
                                    logistic_noise=gate_noise)

                routing_gates = self.param(
                    "routing_gate_logits_unshared",
                    shape=None,
                    initializer=unshared_routing_initializer)
                routing_gates = side_outputs.encourage_discrete_logits(
                    routing_gates,
                    distribution_type="binary",
                    name="routing_gate_logits_unshared")
                stacked_routing_params = jax.vmap(
                    builder.routing_gates_to_probs)(routing_gates)

            else:

                def unshared_routing_initializer(rng_key, unused_shape):
                    key_per_edge = jax.random.split(rng_key, num_out_edges)
                    routing_probs = jax_util.vmap_with_kwargs(
                        builder.initialize_routing_params,
                        key_axis=0)(key=key_per_edge,
                                    num_fsm_states=num_fsm_states,
                                    num_variants=num_variants,
                                    noise_factor=initialization_noise_factor)
                    if legacy_initialize:
                        return routing_probs
                    return jax.tree_map(
                        lambda x: jnp.log(x + initialize_smoothing),
                        routing_probs)

                log_routing_params = self.param(
                    "log_routing_params_unshared",
                    shape=None,
                    initializer=unshared_routing_initializer)
                stacked_routing_params = jax.vmap(
                    builder.routing_softmax)(log_routing_params)

            def solve_one(one_edge_routing_params, one_edge_variant_weights):
                """Run one of the edge-specific automata."""
                # Build the automaton on the provided graph.
                transition_matrix = builder.build_transition_matrix(
                    one_edge_routing_params, encoded_graph, static_metadata)

                # Start state is always state 0.
                start_machine_states = jnp.broadcast_to(
                    (jnp.arange(num_fsm_states) == 0)[None, :],
                    (num_nodes, num_fsm_states))

                return automaton_builder.all_nodes_absorbing_solve(
                    builder=builder,
                    transition_matrix=transition_matrix,
                    variant_weights=one_edge_variant_weights,
                    start_machine_states=start_machine_states,
                    steps=steps,
                    backtrack_fails_prob=backtrack_fails_prob)

            absorbing_solution = jax.vmap(
                solve_one,
                in_axes=(0, variant_out_edge_axis),
            )(stacked_routing_params, variant_weights)

            # Rescale the logits.
            logits = model_util.safe_logit(absorbing_solution)
            if logit_scaling == "learned":
                # Learned scaling and shifting.
                logits = model_util.ScaleAndShift(logits)
            elif logit_scaling == "dynamic":
                raise NotImplementedError(
                    "Dynamic scaling not implemented for unshared")

        logits = side_outputs.encourage_discrete_logits(
            logits, distribution_type="binary", name="edge_logits")
        result = jax.nn.sigmoid(logits)
        return result
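# Hypothetical sketch (synthetic shapes, not from the original code) of a valid
# `variant_weights` argument for the layer above: a one-hot choice of policy
# variant per (start node, current node) pair satisfies the docstring's
# requirement that the weights be nonnegative and sum to 1 along the last axis.
import jax
import jax.numpy as jnp

num_nodes, num_variants = 3, 2
variant_ids = jnp.array([[0, 1, 0],
                         [1, 1, 0],
                         [0, 0, 1]])  # which variant applies at (i, j)
variant_weights = jax.nn.one_hot(variant_ids, num_variants)
assert variant_weights.shape == (num_nodes, num_nodes, num_variants)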
    def apply(
        self,
        graph_context,
        node_embeddings,
        edge_embeddings,
        forward_edge_types=gin.REQUIRED,
        reverse_edge_types=gin.REQUIRED,
        walk_length_log2=gin.REQUIRED,
    ):
        """Modifies edge embeddings using a uniform random walk.

    Uses an efficient repeated-squaring technique to compute the absorbing
    distribution.

    Args:
      graph_context: Input graph for this example.
      node_embeddings: Current node embeddings, as <float32[num_nodes,
        node_embedding_dim]>
      edge_embeddings: Current edge embeddings, as <float32[num_nodes,
        num_nodes, edge_embedding_dim]>
      forward_edge_types: Edge types to use in the forward direction. As a list
        of lists to allow configuring groups of edges in config files; this will
        be flattened before use.
      reverse_edge_types: Edge types to use in the reverse direction. Note that
        reversed edge types are given a separate embedding from forward edge
        types; undirected edges should be represented by adding two edges in
        opposite directions and then only using `forward_edge_types`. Also a
        list of lists, as above.
      walk_length_log2: Base-2 logarithm of maximum walk length; this determines
        how many times we will square the transition matrix (doubling the walk
        length).

    Returns:
      New node and edge embeddings. Node embeddings will not be modified. Edge
      embeddings will be modified by adding a new edge type (either embedded or
      concatenated based on graph_context.edges_are_embedded).
    """
        num_nodes = node_embeddings.shape[0]

        # pylint: disable=g-complex-comprehension
        forward_edge_type_indices = [
            graph_context.edge_types_to_indices[type_str]
            for group in forward_edge_types for type_str in group
        ]
        reverse_edge_type_indices = [
            graph_context.edge_types_to_indices[type_str]
            for group in reverse_edge_types for type_str in group
        ]
        # pylint: enable=g-complex-comprehension

        adjacency = graph_layers.edge_mask(
            edges=graph_context.bundle.edges,
            num_nodes=num_nodes,
            num_edge_types=len(graph_context.edge_types_to_indices),
            forward_edge_type_indices=forward_edge_type_indices,
            reverse_edge_type_indices=reverse_edge_type_indices)
        adjacency = jnp.maximum(adjacency, jnp.eye(num_nodes))

        absorbing_logit = self.param(
            "absorbing_logit",
            shape=(),
            initializer=lambda *_: jax.scipy.special.logit(0.1))

        absorbing_prob = jax.nn.sigmoid(absorbing_logit)
        nonabsorbing_prob = jax.nn.sigmoid(-absorbing_logit)
        walk_matrix = nonabsorbing_prob * adjacency / jnp.sum(
            adjacency, axis=1, keepdims=True)

        # A, I
        # A^2, A + I
        # (A^2)^2 = A^4, (A + I)A^2 + (A + I) = A^3 + A^2 + A + I
        # ...

        def step(state, _):
            nth_power, nth_partial_sum = state
            return (nth_power @ nth_power,
                    nth_power @ nth_partial_sum + nth_partial_sum), None

        (_, partial_sum), _ = jax.lax.scan(step,
                                           (walk_matrix, jnp.eye(num_nodes)),
                                           None,
                                           length=walk_length_log2)

        approx_visits = absorbing_prob * partial_sum
        logits = model_util.safe_logit(approx_visits)
        logits = model_util.ScaleAndShift(logits)
        edge_weights = jax.nn.sigmoid(logits)

        return (node_embeddings,
                _add_edges(edge_embeddings, edge_weights[:, :, None],
                           graph_context.edges_are_embedded))
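# Standalone check (synthetic, not part of the original module) of the
# repeated-squaring recurrence used in `step` above: after k scan iterations
# the carry holds (W^(2^k), I + W + ... + W^(2^k - 1)), so `walk_length_log2`
# squarings account for all walks of length up to 2**walk_length_log2 - 1.
import jax
import jax.numpy as jnp

walk_matrix = jnp.array([[0.0, 0.5],
                         [0.5, 0.0]])
walk_length_log2 = 3


def _step(state, _):
    nth_power, nth_partial_sum = state
    return (nth_power @ nth_power,
            nth_power @ nth_partial_sum + nth_partial_sum), None


(_, partial_sum), _ = jax.lax.scan(
    _step, (walk_matrix, jnp.eye(2)), None, length=walk_length_log2)

# Compare against the explicit sum of the first 2**walk_length_log2 powers.
explicit = sum(jnp.linalg.matrix_power(walk_matrix, k)
               for k in range(2**walk_length_log2))
assert jnp.allclose(partial_sum, explicit)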