def step(node_embeddings, edge_embeddings):
        """Run one transformer-style update over the node embeddings.

        Applies edge-conditioned node self-attention followed by a dense
        layer with ReLU, merging each sub-layer's output back into the
        node embeddings with a residual layer-norm update.

        Args:
          node_embeddings: per-node embedding array; last axis is the
            embedding width (other axes assumed batch/node — TODO confirm).
          edge_embeddings: edge features consumed by the attention layer.

        Returns:
          Updated node embeddings with the same trailing width.
        """
        # Self-attention over nodes; `mask` and `graph_layers` come from
        # the enclosing scope. Output width matches the input embeddings.
        attention_result = graph_layers.NodeSelfAttention(
            edge_embeddings=edge_embeddings,
            node_embeddings=node_embeddings,
            mask=mask,
            out_dim=node_embeddings.shape[-1],
            name="attend")
        node_embeddings = graph_layers.residual_layer_norm_update(
            node_embeddings, attention_result, name="attend_ln")

        # Position-wise feedforward sub-layer (Dense + ReLU), then merge.
        dense_result = flax.deprecated.nn.Dense(
            node_embeddings,
            features=node_embeddings.shape[-1],
            name="fc_dense")
        dense_result = jax.nn.relu(dense_result)
        node_embeddings = graph_layers.residual_layer_norm_update(
            node_embeddings, dense_result, name="fc_ln")

        return node_embeddings
    def step(node_embeddings, edge_embeddings):
        """Run one NRI-style message-passing update over the nodes.

        Aggregates edge messages into per-node activations, pushes them
        through a ReLU MLP whose final layer restores the embedding
        width, and then either merges them into the node embeddings via
        residual layer norm or returns them as the new embeddings.

        Args:
          node_embeddings: per-node embedding array; last axis is the
            embedding width.
          edge_embeddings: edge features consumed by the NRI edge layer.

        Returns:
          Updated node embeddings with the same trailing width.
        """
        # Edge-to-node message passing; `mask` comes from enclosing scope.
        messages = graph_layers.NRIEdgeLayer(
            edge_embeddings=edge_embeddings,
            node_embeddings=node_embeddings,
            mask=mask,
            message_passing=True,
            name="nri_message_passing")

        # MLP over the aggregated messages. The appended final width
        # projects back to the node-embedding size; `mlp_etov_dims` is a
        # closure variable giving the hidden widths.
        layer_widths = list(mlp_etov_dims) + [node_embeddings.shape[-1]]
        for layer_index, width in enumerate(layer_widths):
            messages = flax.deprecated.nn.Dense(messages,
                                                features=width,
                                                name=f"fc{layer_index}")
            messages = jax.nn.relu(messages)

        # Closure flag chooses between a residual merge and outright
        # replacement of the node embeddings.
        if with_residual_layer_norm:
            return graph_layers.residual_layer_norm_update(
                node_embeddings, messages, name="layer_norm_update")
        return messages