Example #1
def reweight_cl(weights, ngals, cl_in):
    """
    """
    # assert len(weights) == len(ngals)
    nprobe = weights.shape[0]
    offset = 0
    w = [None] * nprobe
    nzbin = np.array([len(W) for W in weights])
    nout = np.sum(nzbin * (1 + np.arange(nprobe)))
    cl_out = [None] * nout
    for i1 in range(nprobe):
        nrow = len(weights[i1])
        rowstep = nprobe - i1
        for i2 in range(i1, nprobe):
            #assert weights[i2].shape[1] == len(ngals[i2])
            W = weights[i2] * ngals[i2]
            W /= jnp.sum(W, axis=1, keepdims=True)
            w[i2] = jnp.nan_to_num(W, posinf=1e30, neginf=-1e30, copy=False)
            cl = jnp.einsum('ip,spqk,jq->sijk', w[i1], cl_in[i2][i1], w[i2])
            for j in range(nrow):
                start = j if i1 == i2 else 0
                cl_out[offset + j * rowstep + i2 - i1] = cl[:, j, start:]
        offset += nrow * rowstep
    return jnp.concatenate(cl_out, axis=1)
Example #2
def extract_outputs_and_targets(
    model,
    padded_example_and_rng,
    target_edge_index,
    num_edge_types,
):
    """Extract model outputs and targets for an example.

  Args:
    model: Model to run on the example.
    padded_example_and_rng: Example to extract targets from, with RNG.
    target_edge_index: Index of the target edge type.
    num_edge_types: How many edge types there are.

  Returns:
    Tuple (output_logits, targets, valid_mask, num_nodes, captured)
  """
    padded_example, rng = padded_example_and_rng
    # Run the model.
    with side_outputs.collect_side_outputs() as captured:
        with flax.nn.stochastic(rng):
            output_logits = model(padded_example)
    # Extract targets.
    targets = padded_example.edges.apply_add(in_array=(
        jnp.arange(num_edge_types) == target_edge_index).astype("int32"),
                                             out_array=jnp.zeros(
                                                 output_logits.shape,
                                                 dtype="int32")).astype("bool")
    targets = preprocess_targets(targets)
    # Compute valid mask for outputs and targets.
    max_num_nodes = output_logits.shape[0]
    num_nodes = padded_example.graph_metadata.num_nodes
    valid_nodes = jnp.arange(max_num_nodes) < num_nodes
    valid_nodes_float = valid_nodes.astype("float32")
    valid_mask = jnp.einsum("i,j->ij", valid_nodes_float, valid_nodes_float)
    return output_logits, targets, valid_mask, num_nodes, captured
Example #3
    def __call__(self, inputs):
        """Applies layer to input.

    Args:
      inputs: jnp.ndarray of shape [ens_size * batch_size, ..., input_dim].

    Returns:
      jnp.ndarray of shape [ens_size * batch_size, ..., features].
    """
        dtype = self.dtype or inputs.dtype
        inputs = jnp.asarray(inputs, dtype)
        input_dim = inputs.shape[-1]

        kernel = self.param('kernel', self.kernel_init,
                            (input_dim, self.features), dtype)
        alpha = self.param('fast_weight_alpha', self.alpha_init,
                           (self.ens_size, input_dim), dtype)
        gamma = self.param('fast_weight_gamma', self.gamma_init,
                           (self.ens_size, self.features), dtype)

        inputs_shape = inputs.shape
        inputs = jnp.reshape(inputs, (self.ens_size, -1) + inputs_shape[1:])
        outputs = jnp.einsum('E...C,EC,CD,ED->E...D', inputs, alpha, kernel,
                             gamma)

        if self.use_ensemble_bias:
            bias = self.param('bias', self.bias_init,
                              (self.ens_size, self.features), dtype)
            bias_shape = (self.ens_size, ) + (1, ) * (outputs.ndim - 2) + (
                self.features, )
            outputs = outputs + jnp.reshape(bias, bias_shape)

        if self.activation is not None:
            outputs = self.activation(outputs)  # pylint: disable=not-callable

        return jnp.reshape(outputs, inputs_shape[:-1] + (self.features, ))
Example #4
    def test_einsum_kpmurphy_example(self):
        # code from an email with @murphyk
        N, C, D, K, T = 2, 3, 4, 5, 6
        r = self.rng()
        S = r.randn(N, T, K)
        W = r.randn(K, D)
        V = r.randn(D, C)
        L = np.zeros((N, C))
        for n in range(N):
            for c in range(C):
                s = 0
                for d in range(D):
                    for k in range(K):
                        for t in range(T):
                            s += S[n, t, k] * W[k, d] * V[d, c]
                L[n, c] = s

        path = jnp.einsum_path('ntk,kd,dc->nc', S, W, V, optimize='optimal')[0]
        rtol = 1e-2 if jtu.device_under_test() == "tpu" else None
        self.assertAllClose(L,
                            jnp.einsum('ntk,kd,dc->nc', S, W, V,
                                       optimize=path),
                            check_dtypes=False,
                            rtol=rtol)
Example #5
    def __filter_step(self, state, obs_t):
        nsamples = self.nsamples
        indices = jnp.arange(nsamples)
        zt_rvs, key_t = state

        key_t, key_reindex, key_next = random.split(key_t, 3)
        # 1. Draw new points from the dynamic model
        zt_rvs = random.multivariate_normal(key_t, self.fz(zt_rvs), self.Q(zt_rvs))

        # 2. Calculate unnormalised weights
        xt_rvs = self.fx(zt_rvs)
        weights_t = stats.multivariate_normal.pdf(obs_t, xt_rvs, self.R(zt_rvs, obs_t))

        # 3. Resampling
        pi = random.choice(key_reindex, indices,
                           p=weights_t, shape=(nsamples,))
        zt_rvs = zt_rvs[pi, ...]
        weights_t = jnp.ones(nsamples) / nsamples

        # 4. Compute latent-state estimate,
        #    Set next covariance state matrix
        mu_t = jnp.einsum("im,i->m", zt_rvs, weights_t)

        return (zt_rvs, key_next), mu_t
Example #6
    def __call__(self, inputs: jnp.ndarray) -> jnp.ndarray:
        """Connects Module.

    Args:
      inputs: Tensor of shape [..., num_channel]

    Returns:
      output of shape [..., num_output]
    """
        n_channels = int(inputs.shape[-1])

        weight_shape = [n_channels, self.num_output]
        if self.initializer == 'linear':
            weight_init = hk.initializers.VarianceScaling(mode='fan_in',
                                                          scale=1.)
        elif self.initializer == 'relu':
            weight_init = hk.initializers.VarianceScaling(mode='fan_in',
                                                          scale=2.)
        elif self.initializer == 'zeros':
            weight_init = hk.initializers.Constant(0.0)

        weights = hk.get_parameter('weights', weight_shape, inputs.dtype,
                                   weight_init)

        # this is equivalent to einsum('...c,cd->...d', inputs, weights)
        # but turns out to be slightly faster
        inputs = jnp.swapaxes(inputs, -1, -2)
        output = jnp.einsum('...cb,cd->...db', inputs, weights)
        output = jnp.swapaxes(output, -1, -2)

        if self.use_bias:
            bias = hk.get_parameter('bias', [self.num_output], inputs.dtype,
                                    hk.initializers.Constant(self.bias_init))
            output += bias

        return output
Example #7
    def __call__(self, row, col, features):
        attn = hk.Linear(self.num_heads * self.attention_size)(features)
        attn = jnp.reshape(attn, (-1, self.num_heads, self.attention_size))
        attn = jax.nn.sigmoid(jnp.einsum("eha,eha->eh", attn[row], attn[col]))
        epsilon = jax.nn.sigmoid(
            hk.get_parameter(
                "epsilon_sig_inv",
                shape=(self.num_heads, ),
                dtype=features.dtype,
                init=lambda shape, dtype: jnp.full(shape, -2.0, dtype=dtype),
            ))
        features = hk.Linear(self.num_heads * self.out_size)(features)
        features = jnp.reshape(features, (-1, self.num_heads, self.out_size))

        def propagate(features, attn, eps, row, col):
            adj = COO((attn, row, col), shape=(features.shape[0], ) * 2)
            propagator = _get_propagator(adj, eps, self.tol)
            return propagator @ features
            # shifted_lap = _get_shifted_laplacian(adj, eps)
            # return jax.scipy.sparse.linalg.cg(lambda x: shifted_lap @ x, features)

        # out_ax = 0
        # out = jax.vmap(propagate, in_axes=(1, 1, 0, None, None), out_axes=out_ax)(
        #     features, attn, epsilon, row, col
        # )
        def unstack(x, axis):
            return [x.take(i, axis=axis) for i in range(x.shape[axis])]

        features = unstack(features, axis=1)
        attn = unstack(attn, axis=1)
        eps = unstack(epsilon, axis=0)
        out = [
            propagate(f, a, e, row, col)
            for (f, a, e) in zip(features, attn, eps)
        ]
        return sum(out) / len(out)
Example #8
    def __call__(
        self,
        hidden_states,
        attention_mask=None,
        deterministic: bool = True,
        output_attentions: bool = False,
    ):
        query = self.q_proj(hidden_states)
        key = self.k_proj(hidden_states)
        value = self.v_proj(hidden_states)

        query = self._split_heads(query)
        key = self._split_heads(key)
        value = self._split_heads(value)

        causal_attention_mask = None
        if self.causal:
            query_length, key_length = query.shape[1], key.shape[1]
            causal_attention_mask = self.causal_mask[:, :,
                                                     key_length - query_length:
                                                     key_length, :key_length]

        if attention_mask is not None and causal_attention_mask is not None:
            attention_mask = jnp.expand_dims(attention_mask, axis=(-3, -2))
            attention_mask = combine_masks(attention_mask,
                                           causal_attention_mask,
                                           dtype="i4")
        elif causal_attention_mask is not None:
            attention_mask = causal_attention_mask
        elif attention_mask is not None:
            attention_mask = jnp.expand_dims(attention_mask, axis=(-3, -2))

        if attention_mask is not None:
            attention_bias = lax.select(
                attention_mask > 0,
                jnp.full(attention_mask.shape, 0.0).astype(self.dtype),
                jnp.full(attention_mask.shape, -1e4).astype(self.dtype),
            )
        else:
            attention_bias = None

        dropout_rng = None
        if not deterministic and self.dropout > 0.0:
            dropout_rng = self.make_rng("dropout")

        attn_weights = dot_product_attention_weights(
            query,
            key,
            bias=attention_bias,
            dropout_rng=dropout_rng,
            dropout_rate=self.dropout,
            deterministic=deterministic,
            dtype=self.dtype,
            precision=None,
        )

        attn_output = jnp.einsum("...hqk,...khd->...qhd", attn_weights, value)
        attn_output = self._merge_heads(attn_output)
        attn_output = self.out_proj(attn_output)

        outputs = (attn_output,
                   attn_weights) if output_attentions else (attn_output, )
        return outputs
Example #9
    def __call__(
        self,
        hidden_states,
        attention_mask,
        position_ids,
        deterministic: bool = True,
        init_cache: bool = False,
        output_attentions: bool = False,
    ):

        query = self.q_proj(hidden_states)
        key = self.k_proj(hidden_states)
        value = self.v_proj(hidden_states)

        query = self._split_heads(query)
        key = self._split_heads(key)
        value = self._split_heads(value)

        sincos = jnp.take(self.embed_positions, position_ids, axis=0)
        sincos = jnp.split(sincos, 2, axis=-1)
        if self.rotary_dim is not None:
            k_rot = key[:, :, :, :self.rotary_dim]
            k_pass = key[:, :, :, self.rotary_dim:]

            q_rot = query[:, :, :, :self.rotary_dim]
            q_pass = query[:, :, :, self.rotary_dim:]

            k_rot = apply_rotary_pos_emb(k_rot, sincos)
            q_rot = apply_rotary_pos_emb(q_rot, sincos)

            key = jnp.concatenate([k_rot, k_pass], axis=-1)
            query = jnp.concatenate([q_rot, q_pass], axis=-1)
        else:
            key = apply_rotary_pos_emb(key, sincos)
            query = apply_rotary_pos_emb(query, sincos)

        query_length, key_length = query.shape[1], key.shape[1]

        if self.has_variable("cache", "cached_key"):
            mask_shift = self.variables["cache"]["cache_index"]
            max_decoder_length = self.variables["cache"]["cached_key"].shape[1]
            causal_mask = lax.dynamic_slice(
                self.causal_mask, (0, 0, mask_shift, 0),
                (1, 1, query_length, max_decoder_length))
        else:
            causal_mask = self.causal_mask[:, :, :query_length, :key_length]

        batch_size = hidden_states.shape[0]
        causal_mask = jnp.broadcast_to(causal_mask,
                                       (batch_size, ) + causal_mask.shape[1:])

        attention_mask = jnp.broadcast_to(
            jnp.expand_dims(attention_mask, axis=(-3, -2)), causal_mask.shape)
        attention_mask = combine_masks(attention_mask, causal_mask)

        dropout_rng = None
        if not deterministic and self.config.attn_pdrop > 0.0:
            dropout_rng = self.make_rng("dropout")

        # During fast autoregressive decoding, we feed one position at a time,
        # and cache the keys and values step by step.
        if self.has_variable("cache", "cached_key") or init_cache:
            key, value, attention_mask = self._concatenate_to_cache(
                key, value, query, attention_mask)

        # transform boolean mask into float mask
        attention_bias = lax.select(
            attention_mask > 0,
            jnp.full(attention_mask.shape, 0.0).astype(self.dtype),
            jnp.full(attention_mask.shape, -1e9).astype(self.dtype),
        )

        # usual dot product attention
        attn_weights = dot_product_attention_weights(
            query,
            key,
            bias=attention_bias,
            dropout_rng=dropout_rng,
            dropout_rate=self.config.attn_pdrop,
            deterministic=deterministic,
            dtype=self.dtype,
            precision=None,
        )

        attn_output = jnp.einsum("...hqk,...khd->...qhd", attn_weights, value)
        attn_output = self._merge_heads(attn_output)
        attn_output = self.out_proj(attn_output)
        attn_output = self.resid_dropout(attn_output,
                                         deterministic=deterministic)

        outputs = (attn_output,
                   attn_weights) if output_attentions else (attn_output, )
        return outputs
Example #10
 def _check(self, s, *ops):
     a = onp.einsum(s, *ops)
     b = np.einsum(s, *ops)
     self.assertAllClose(a, b, atol=1e-4, rtol=1e-4, check_dtypes=True)
Example #11
def bulk_modulus(elastic_tensor):
  return np.einsum('iijj->', elastic_tensor) / elastic_tensor.shape[0] ** 2
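A quick sanity check (added here, not part of the original snippet): for an isotropic elastic tensor built from assumed Lamé parameters, the contraction reproduces the textbook bulk modulus K = lambda + 2*mu/3.

import numpy as onp

lam, mu = 1.2, 0.8  # assumed Lamé parameters, arbitrary test values
delta = onp.eye(3)
# Isotropic elasticity: C_ijkl = lam*d_ij*d_kl + mu*(d_ik*d_jl + d_il*d_jk)
C = (lam * onp.einsum('ij,kl->ijkl', delta, delta)
     + mu * (onp.einsum('ik,jl->ijkl', delta, delta)
             + onp.einsum('il,jk->ijkl', delta, delta)))
print(bulk_modulus(C))  # ~1.733 = lam + 2*mu/3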
Example #12
def cosine_similarity(a: Array, b: Array) -> Array:
  """Computes batched cosine similarity between two 2D arrays."""
  a_norm = jnp.linalg.norm(a, axis=-1)
  b_norm = jnp.linalg.norm(b, axis=-1)
  dot = jnp.einsum('bd,bd->b', a, b)
  return dot / (_SMALL + a_norm * b_norm)
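A minimal usage sketch (added for illustration). Array and _SMALL come from the surrounding module in the original source; assume something like Array = jnp.ndarray and _SMALL = 1e-9 are defined before the function.

import jax.numpy as jnp

a = jnp.array([[1.0, 0.0],
               [1.0, 1.0]])
b = jnp.array([[1.0, 0.0],
               [-1.0, 1.0]])
print(cosine_similarity(a, b))  # ~[1., 0.]: parallel rows, then orthogonal rows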
Example #13
init_state = (mu_t, tau_t)
xs = (Phi, y)

adf_loop = partial(adf_step, q=q, lbound=lbound, ubound=ubound)
(mu_t, tau_t), (mu_t_hist, tau_t_hist) = jax.lax.scan(adf_loop, init_state, xs)

# ** Estimating posterior predictive distribution **
xmin, ymin = X.min(axis=0) - 0.1
xmax, ymax = X.max(axis=0) + 0.1
step = 0.1
Xspace = jnp.mgrid[xmin:xmax:step, ymin:ymax:step]
_, nx, ny = Xspace.shape
Phispace = jnp.concatenate([jnp.ones((1, nx, ny)), Xspace])

# MCMC posterior predictive distribution
Z_mcmc = sigmoid(jnp.einsum("mij,sm->sij", Phispace, chains))
Z_mcmc = Z_mcmc.mean(axis=0)
# Laplace posterior predictive distribution
key = random.PRNGKey(314)
laplace_samples = random.multivariate_normal(key, w_map, SN, (n_samples, ))
Z_laplace = sigmoid(jnp.einsum("mij,sm->sij", Phispace, laplace_samples))
Z_laplace = Z_laplace.mean(axis=0)
# ADF posterior predictive distribution
adf_samples = random.multivariate_normal(key, mu_t, jnp.diag(tau_t),
                                         (n_samples, ))
Z_adf = sigmoid(jnp.einsum("mij,sm->sij", Phispace, adf_samples))
Z_adf = Z_adf.mean(axis=0)

# ** Plotting predictive distribution **
plt.rcParams["axes.spines.right"] = False
plt.rcParams["axes.spines.top"] = False
Example #14
    def __call__(self, input_qkv):
        cfg = self.config
        assert cfg.max_len % cfg.max_seg_len == 0
        bsize = input_qkv.shape[0]
        features = self.out_features or input_qkv.shape[-1]
        query, key, value, head_dim = get_qkv(cfg, input_qkv)

        num_seg = cfg.max_len // cfg.max_seg_len
        cur_query = query.reshape(
            [-1, cfg.max_seg_len, query.shape[-2], query.shape[-1]])
        merged_query = jnp.max(cur_query, axis=1,
                               keepdims=True) * jnp.sqrt(head_dim)
        cur_key = key.reshape(
            [-1, cfg.max_seg_len, key.shape[-2], key.shape[-1]])
        cur_value = value.reshape(
            [-1, cfg.max_seg_len, value.shape[-2], value.shape[-1]])
        dropout_rng = None
        if not cfg.deterministic and cfg.attention_dropout_rate > 0.:
            dropout_rng = self.make_rng('dropout')
        s = dot_product_attention(merged_query,
                                  cur_key,
                                  cur_value,
                                  dropout_rng=dropout_rng,
                                  dropout_rate=cfg.attention_dropout_rate,
                                  broadcast_dropout=False,
                                  deterministic=cfg.deterministic,
                                  dtype=cfg.dtype)
        span_val = jnp.reshape(s, [bsize, -1, s.shape[-2], s.shape[-1]])
        span_key = jnp.max(cur_key, axis=1, keepdims=True)
        # (bsize, n_seg, n_head, dim_per_head)
        span_key = jnp.reshape(
            span_key, [bsize, -1, span_key.shape[-2], span_key.shape[-1]])

        local_mask = make_causal_mask(cur_query,
                                      length_axis=1).transpose([0, 2, 1, 3])
        local_bias = lax.select(
            local_mask > 0,
            jnp.full(local_mask.shape, 0.).astype(cfg.dtype),
            jnp.full(local_mask.shape, -1e10).astype(cfg.dtype))
        # (bsize * n_seg, seg_len, n_head, seg_len)
        local_logits = jnp.einsum('...qhd,...khd->...qhk', cur_query,
                                  cur_key) + local_bias
        local_logits = jnp.reshape(local_logits,
                                   [bsize, -1, cfg.num_heads, cfg.max_seg_len])
        idx = jnp.broadcast_to(jnp.arange(span_key.shape[1], dtype=jnp.int32),
                               span_key.shape[:2])
        prev_mask = nn.make_attention_mask(idx,
                                           idx,
                                           jnp.greater,
                                           extra_batch_dims=0,
                                           dtype=jnp.float32).transpose(
                                               [0, 2, 1, 3])
        prev_mask = jnp.repeat(prev_mask, cfg.max_seg_len, axis=-3)
        prev_bias = lax.select(
            prev_mask > 0,
            jnp.full(prev_mask.shape, 0.).astype(cfg.dtype),
            jnp.full(prev_mask.shape, -1e10).astype(cfg.dtype))
        # (bsize, max_len, n_head, num_segs)
        prev_logits = jnp.einsum('...qhd,...khd->...qhk', query,
                                 span_key) + prev_bias
        joint_logits = jnp.concatenate((local_logits, prev_logits), axis=-1)
        # (bsize x max_len,  n_head, seg_len + num_segs)
        attn_weights = jax.nn.softmax(joint_logits).astype(cfg.dtype)
        local_att, prev_att = jnp.split(attn_weights, [cfg.max_seg_len],
                                        axis=-1)
        local_att = local_att.reshape(
            [bsize * num_seg, cfg.max_seg_len, cfg.num_heads, cfg.max_seg_len])
        local_merged = jnp.einsum('...qhk,...khd->...qhd', local_att,
                                  cur_value)
        prev_merged = jnp.einsum('...qhk,...khd->...qhd', prev_att, span_val)
        joint_merged = jnp.reshape(local_merged,
                                   prev_merged.shape) + prev_merged
        x = DenseGeneral(features=features,
                         axis=(-2, -1),
                         kernel_init=cfg.kernel_init,
                         bias_init=cfg.bias_init,
                         use_bias=False,
                         dtype=cfg.dtype)(joint_merged)
        return x
Example #15
    def __call__(self, input_qkv):
        cfg = self.config
        log_len = log_2_ceil(cfg.max_len - 1)
        bsize = input_qkv.shape[0]
        features = self.out_features or input_qkv.shape[-1]
        query, key, value, head_dim = get_qkv(cfg, input_qkv)

        joint_logits = []
        list_vals = []
        for l in range(log_len):
            ctx_len = 2**l
            last_pos = cfg.max_len - cfg.max_len % ctx_len
            num_ctx = cfg.max_len // ctx_len

            if l == 0:
                span_key = jnp.reshape(key, [-1, 1, cfg.num_heads, head_dim])
                span_val = value.reshape(span_key.shape)
                self_logits = jnp.expand_dims(jnp.sum(query * key, axis=-1),
                                              -1)
                joint_logits.append(self_logits)
            else:
                left_query = query[:, :last_pos, :, :].reshape(
                    [-1, ctx_len, cfg.num_heads, head_dim])
                span_query = jnp.max(left_query, axis=1, keepdims=True)
                left_key = key[:, :last_pos, :, :].reshape(left_query.shape)
                left_val = value[:, :last_pos, :, :].reshape(left_query.shape)
                span_val = dot_product_attention(
                    span_query * jnp.sqrt(head_dim),
                    left_key,
                    left_val,
                    dropout_rng=self.get_dropout_png(cfg),
                    dropout_rate=cfg.attention_dropout_rate,
                    broadcast_dropout=False,
                    deterministic=cfg.deterministic,
                    dtype=cfg.dtype)
                span_key = jnp.max(left_key, axis=1, keepdims=True)
            rolled_q = jnp.roll(query, -ctx_len,
                                axis=1)[:, :last_pos, :, :].reshape(
                                    [-1, ctx_len, cfg.num_heads, head_dim])

            rolled_mask = jnp.concatenate(
                [(jnp.arange(cfg.max_len - ctx_len) // ctx_len) % 2,
                 jnp.ones(last_pos + ctx_len - cfg.max_len, dtype=jnp.int32)],
                axis=0)
            rolled_mask = jnp.reshape(rolled_mask, [1, -1, 1, 1])
            rolled_logits = jnp.einsum('...qhd,...khd->...qhk', rolled_q,
                                       span_key)
            # bsize, last_pos, h, 1
            rolled_logits = jnp.reshape(
                rolled_logits, [bsize, -1, cfg.num_heads, 1
                                ]) + rolled_mask.astype(rolled_q.dtype) * -1e9
            orig_logits = jnp.pad(rolled_logits, [(0, 0),
                                                  (0, cfg.max_len - last_pos),
                                                  (0, 0), (0, 0)],
                                  constant_values=-1e9)
            orig_logits = jnp.roll(orig_logits, ctx_len, axis=1)
            joint_logits.append(orig_logits)
            list_vals.append(span_val)
        joint_logits = jnp.concatenate(joint_logits, axis=-1)
        attn_weights = jax.nn.softmax(joint_logits).astype(cfg.dtype)
        local_weights = jnp.split(attn_weights, log_len + 1, axis=-1)
        local_weighted_sums = []
        joint_merged = local_weights[0] * value
        for l in range(log_len):
            ctx_len = 2**l
            last_pos = cfg.max_len - cfg.max_len % ctx_len
            num_ctx = cfg.max_len // ctx_len

            rolled_w = jnp.roll(local_weights[l + 1], -ctx_len,
                                axis=1)[:, :last_pos, :, :].reshape(
                                    bsize * num_ctx, ctx_len, cfg.num_heads, 1)
            rolled_v = jnp.reshape(rolled_w * list_vals[l],
                                   [bsize, -1, cfg.num_heads, head_dim])
            rolled_v = jnp.pad(rolled_v, [(0, 0), (0, cfg.max_len - last_pos),
                                          (0, 0), (0, 0)])
            orig_v = jnp.roll(rolled_v, ctx_len, axis=1)
            joint_merged = joint_merged + orig_v
        x = DenseGeneral(features=features,
                         axis=(-2, -1),
                         kernel_init=cfg.kernel_init,
                         bias_init=cfg.bias_init,
                         use_bias=False,
                         dtype=cfg.dtype)(joint_merged)
        return x
Example #16
def roll_out_transitions(
    builder,
    transition_matrix,
    variant_weights, start_machine_state,
    node_index, steps, rng,
    max_possible_transitions):
  """Roll out transitions for a transition matrix.

  Args:
    builder: The automaton builder associated with this Markov chain.
    transition_matrix: The per-variant transition matrix for this Markov chain.
    variant_weights: Weights to assign to each routing variant for each node, as
      a <float[num_nodes, num_variants]> array (which should sum to 1 across the
      last axis).
    start_machine_state: Initial machine state distribution for the starting
      nodes, as a <float[num_fsm_states]> array (which should sum to 1).
    node_index: Initial node index where the solve should start, as a int32
      scalar.
    steps: How many steps to unroll.
    rng: Random number generator used to sample.
    max_possible_transitions: Max transitions to truncate to.

  Returns:
    Final RollOutState.
  """
  (variants, nodes, fsm_states, in_tagged_nodes,
   _) = transition_matrix.initial_to_in_tagged.shape
  itn_states = in_tagged_nodes * fsm_states

  def add_special_dests(in_tagged_dests):
    return jnp.concatenate(
        [in_tagged_dests,
         jnp.arange(itn_states, itn_states + 3)])

  # Collapse our transition matrix into a more useful form:
  # - Combine nodes and FSM states
  # - Extract a subset of possible destinations
  # - Combine special and move actions into a single probability table
  k = min(max_possible_transitions, in_tagged_nodes) * fsm_states
  initial_to_in_tagged_combo = transition_matrix.initial_to_in_tagged.reshape(
      (variants, nodes, fsm_states, itn_states))
  _, initial_to_in_tagged_dests = jax.lax.top_k(
      jnp.sum(initial_to_in_tagged_combo, axis=(0, 2)), k=k)
  initial_to_in_tagged_probs = jax.vmap(
      lambda c, i: c[:, :, i], in_axes=(1, 0),
      out_axes=1)(initial_to_in_tagged_combo, initial_to_in_tagged_dests)
  initial_probs = jnp.concatenate(
      [initial_to_in_tagged_probs, transition_matrix.initial_to_special], -1)
  initial_dests = jax.vmap(add_special_dests)(initial_to_in_tagged_dests)

  in_tagged_to_in_tagged_combo = transition_matrix.in_tagged_to_in_tagged.reshape(
      (variants, itn_states, itn_states))
  _, in_tagged_to_in_tagged_dests = jax.lax.top_k(
      jnp.sum(in_tagged_to_in_tagged_combo, axis=0), k=k)
  in_tagged_to_in_tagged_probs = jax.vmap(
      lambda c, i: c[:, i], in_axes=(1, 0),
      out_axes=1)(in_tagged_to_in_tagged_combo, in_tagged_to_in_tagged_dests)
  in_tagged_probs = jnp.concatenate([
      in_tagged_to_in_tagged_probs,
      transition_matrix.in_tagged_to_special.reshape((variants, itn_states, 3))
  ], -1)
  in_tagged_dests = jax.vmap(add_special_dests)(in_tagged_to_in_tagged_dests)

  start_node_variant_weights = variant_weights[node_index]
  initial_probs_from_here = jnp.einsum("v,s,vsj->j", start_node_variant_weights,
                                       start_machine_state,
                                       initial_probs[:, node_index, :, :])
  initial_dests_from_here = initial_dests[node_index]

  per_itn_variants = variant_weights[transition_matrix.in_tagged_node_indices]

  # Set up the initial state
  initial_state = RollOutState(rng=rng)

  @jax.remat
  def scan_body(state, ignored_input):
    assert ignored_input is None
    rng, key = jax.random.split(state.rng)

    # If we are in the initial state, sample the initial transition.
    def at_initial_info():
      return initial_probs_from_here, initial_dests_from_here

    # If we are in a normal state, sample the next action.
    def at_normal_info():
      cur_variant_weights = per_itn_variants[state.itn_state_index //
                                             fsm_states]
      next_step_probs = jnp.einsum("v,vj->j", cur_variant_weights,
                                   in_tagged_probs[:, state.itn_state_index, :])

      return next_step_probs, in_tagged_dests[state.itn_state_index, :]

    # Figure out which to do, and sample from the appropriate probabilities
    step_probs, step_dests = jax.tree_multimap(
        functools.partial(jnp.where, state.at_initial), at_initial_info(),
        at_normal_info())

    next_idx = jax.random.categorical(key, jnp.log(step_probs))
    log_prob = jnp.log(step_probs[next_idx])
    dest = step_dests[next_idx]
    did_special = dest >= itn_states

    state_after_move = RollOutState(
        at_initial=False,
        succeeded=False,
        failed=False,
        special_from_initial=False,
        itn_state_index=dest,
        final_node=None,
        log_prob=state.log_prob + log_prob,
        rng=rng)

    special_idx = dest - itn_states
    state_after_special = RollOutState(
        at_initial=(special_idx == builder.special_actions.index(
            automaton_builder.SpecialActions.BACKTRACK)),
        succeeded=(special_idx == builder.special_actions.index(
            automaton_builder.SpecialActions.FINISH)),
        failed=(special_idx == builder.special_actions.index(
            automaton_builder.SpecialActions.FAIL)),
        special_from_initial=state.at_initial,
        itn_state_index=state.itn_state_index,
        final_node=None,
        log_prob=state.log_prob + log_prob,
        rng=rng)

    # Choose the right branch to take
    def choose(move, special, done):
      return jnp.where(state.succeeded | state.failed, done,
                       jnp.where(did_special, special, move))

    new_state = jax.tree_multimap(choose, state_after_move, state_after_special,
                                  state)

    return new_state, None

  final_state, _ = jax.lax.scan(scan_body, initial_state, None, length=steps)
  final_node = jnp.where(
      final_state.special_from_initial, node_index,
      transition_matrix.in_tagged_node_indices[final_state.itn_state_index //
                                               fsm_states])
  final_state = dataclasses.replace(final_state, final_node=final_node)
  return final_state
Example #17
 def cell_fn(theta, state0, inputs_t):
     del state0
     y = jnp.einsum('x,xy->y', inputs_t.x, theta.proj)
     return NestedMap(y=y)
Example #18
 def batched_dot(a, b):
     if a.shape[0] != b.shape[0]:
         raise TypeError("Shapes must match in the 0-th dimension")
     if a.ndim == 2 or b.ndim == 2:
         return jnp.einsum("n...j,nj...->n...", a, b)
     return jnp.einsum("nij,njk->nik", a, b)
Example #19
                  poly_axes=[0, 0]),

    _make_harness("dynamic_slice", "",
                  # x:shape: (b, 4)
                  lambda x: lax.dynamic_slice(x, (0, 1), (x.shape[0], 2)),
                  [RandArg((3, 4), _f32)],
                  poly_axes=[0]),

    _make_harness("dynamic_update_slice", "",
                  # x:shape: (b, 4)
                  lambda x: lax.dynamic_update_slice(x, x, (0, 0)),
                  [RandArg((3, 4), _f32)],
                  poly_axes=[0]),

    _make_harness("einsum", "0",
                  lambda x: jnp.einsum("...i->...", x),
                  [RandArg((3, 4), _f32)],
                  poly_axes=[0]),

    _make_harness("einsum", "1",
                  lambda x, y: jnp.einsum("...ij,...jk->...ik", x, y),
                  [RandArg((3, 4, 5), _f32), RandArg((3, 5, 6), _f32)],
                  poly_axes=[0, 0]),

    _make_harness("einsum", "2",
                  lambda x, y: jnp.einsum("...ij,jk->...ik", x, y),
                  [RandArg((3, 4, 5), _f32), RandArg((5, 6), _f32)],
                  poly_axes=[0, None]),

    _make_harness("einsum", "3",
                  # Reduced dimension is polymorphic
Example #20
def oei_arrays(geom, basis, charges):
    """
    Build one electron integral arrays (overlap, kinetic, and potential integrals)
    """
    coeffs, exps, atoms, ams, indices, dims = flatten_basis_data(basis)
    nbf = get_nbf(basis)
    nprim = coeffs.shape[0]
    max_am = jnp.max(ams)
    A_vals = jnp.zeros(2 * max_am + 1)

    # Save various AM distributions for indexing
    # Obtain all possible primitive duet (pair) index combinations
    primitive_duets = cartesian_product(jnp.arange(nprim), jnp.arange(nprim))

    with loops.Scope() as s:
        s.oei = jnp.zeros((3, nbf, nbf))
        s.a = 0  # center A angular momentum iterator
        s.b = 0  # center B angular momentum iterator

        for prim_duet in s.range(primitive_duets.shape[0]):
            p1, p2 = primitive_duets[prim_duet]
            coef = coeffs[p1] * coeffs[p2]
            aa, bb = exps[p1], exps[p2]
            atom1, atom2 = atoms[p1], atoms[p2]
            am1, am2 = ams[p1], ams[p2]
            A, B = geom[atom1], geom[atom2]
            ld1, ld2 = am_leading_indices[am1], am_leading_indices[am2]

            gamma = aa + bb
            prefactor = jnp.exp(-aa * bb * jnp.dot(A - B, A - B) / gamma)
            P = (aa * A + bb * B) / gamma
            # Maximum angular momentum: hard coded
            #max_am = 3 # f function support
            # Precompute all powers up to 2+max_am of Pi-Ai, Pi-Bi.
            # We need 2+max_am since kinetic requires incrementing angular momentum by +2
            PA_pow = jnp.power(
                jnp.broadcast_to(P - A, (max_am + 3, 3)).T,
                jnp.arange(max_am + 3))
            PB_pow = jnp.power(
                jnp.broadcast_to(P - B, (max_am + 3, 3)).T,
                jnp.arange(max_am + 3))

            # For potential integrals, we need the difference between
            # the gaussian product center P and ALL atoms in the molecule,
            # and then take all possible powers up to 2*max_am.
            # We pre-collect this into a 3d array, and then just pull out what we need via indexing in the loops, so they need not be recomputed.
            # The resulting array has dimensions (atom, cartesian component, power) so index (0, 1, 3) would return (Py - atom0_y)^3
            P_minus_geom = jnp.broadcast_to(P, geom.shape) - geom
            Pgeom_pow = jnp.power(
                jnp.transpose(
                    jnp.broadcast_to(
                        P_minus_geom,
                        (2 * max_am + 1, geom.shape[0], geom.shape[1])),
                    (1, 2, 0)), jnp.arange(2 * max_am + 1))
            # All possible jnp.dot(P-atom,P-atom)
            rcp2 = jnp.einsum('ij,ij->i', P_minus_geom, P_minus_geom)
            # All needed (and unneeded, for am < max_am) boys function evaluations
            boys_arg = jnp.broadcast_to(rcp2 * gamma,
                                        (2 * max_am + 1, geom.shape[0]))
            boys_nu = jnp.tile(jnp.arange(2 * max_am + 1),
                               (geom.shape[0], 1)).T
            boys_eval = boys(boys_nu, boys_arg)

            s.a = 0
            for _ in s.while_range(lambda: s.a < dims[p1]):
                s.b = 0
                for _ in s.while_range(lambda: s.b < dims[p2]):
                    # Gather angular momentum and index
                    la, ma, na = angular_momentum_combinations[s.a + ld1]
                    lb, mb, nb = angular_momentum_combinations[s.b + ld2]
                    # To only create unique indices, need to have separate indices arrays for i and j.
                    i = indices[p1] + s.a
                    j = indices[p2] + s.b
                    # Compute one electron integrals and add to appropriate index
                    overlap_int = overlap(la, ma, na, lb, mb, nb, aa, bb,
                                          PA_pow, PB_pow, prefactor) * coef
                    kinetic_int = kinetic(la, ma, na, lb, mb, nb, aa, bb,
                                          PA_pow, PB_pow, prefactor) * coef
                    potential_int = potential(la, ma, na, lb, mb, nb, aa, bb,
                                              PA_pow, PB_pow, Pgeom_pow,
                                              boys_eval, prefactor, charges,
                                              A_vals) * coef
                    s.oei = jax.ops.index_add(
                        s.oei, ([0, 1, 2], [i, i, i], [j, j, j]),
                        (overlap_int, kinetic_int, potential_int))

                    s.b += 1
                s.a += 1
    S, T, V = s.oei[0], s.oei[1], s.oei[2]
    return S, T, V
Example #21
def raw2outputs(raw, z_vals, rays_d, raw_noise_std, white_bkgd, rng=None):
    """Transforms model's predictions to semantically meaningful values.
    Args:
        raw: (num_rays, num_samples || num_importance, 4) prediction from model
        z_vals: (num_rays, num_samples || num_importance) integration time
        rays_d: (num_rays, 3) direction of each ray
        raw_noise_std: std of noise added for regularization
        white_bkgd: whether to use the alpha channel for white background
        rng: random key
    Returns:
        acc_map: (num_rays) sum of weights along each ray
        depth_map: (num_rays) estimated distance to object
        disp_map: (num_rays) disparity map (inverse of depth map)
        rgb_map: (num_rays, 3) estimated RGB color of a ray
        weights: (num_rays, num_samples || num_importance) weights assigned to each sampled color
    """

    # compute 'distance' (in time) between each integration time along a ray
    dists = z_vals[..., 1:] - z_vals[..., :-1]

    # the 'distance' from the last integration time is infinity
    dists = jnp.concatenate(
        [dists, jnp.broadcast_to([1e10], dists[..., :1].shape)], axis=-1)
    dists = dists.astype(z_vals.dtype)  # [num_rays, num_samples]

    # multiply each distance by the norm of its corresponding direction ray
    # to convert to real world distance (accounts for non-unit directions)
    dists = dists * jnp.linalg.norm(rays_d[..., None, :], axis=-1)

    # extract RGB of each sample position along each ray
    rgb = nn.sigmoid(raw[..., :3])  # [num_rays, num_samples, 3]

    # add noise to predictions for density; can be used to
    # regularize network during training (prevents floater artifacts)
    noise = 0.0
    if raw_noise_std > 0.0 and rng is not None:
        noise = random.normal(rng, raw[..., 3].shape) * raw_noise_std

    # predict density of each sample along each ray (alpha channel)
    # higher values imply higher likelihood of being absorbed at this point
    # (this value is strictly between [0, 1])
    alpha = 1.0 - jnp.exp(-nn.relu(raw[..., 3] + noise) * dists)

    # compute weight for RGB of each sample along each ray
    # cumprod() is used to express the idea of the ray not having reflected up to this sample yet
    # weights = alpha * tf.math.cumprod(1.0 - alpha + 1e-10, axis=-1, exclusive=True)
    alpha_ = jnp.clip(1.0 - alpha, 1e-5, 1.0)
    weights = jnp.concatenate(
        [jnp.ones_like(alpha_[..., :1]), alpha_[..., :-1]], -1)
    weights = alpha * jnp.cumprod(weights, -1)  # [num_rays, num_samples]

    # compute weighted color of each sample along each ray
    rgb_map = jnp.einsum("ij,ijk->ik", weights, rgb)  # [num_rays, 3]

    # estimated depth map is expected distance
    depth_map = jnp.einsum("ij,ij->i", weights, z_vals)  # [num_rays]

    # sum of weights along each ray (this value is in [0, 1] up to numerical error)
    acc_map = jnp.einsum("ij->i", weights)  # [num_rays]

    # disparity map is inverse depth
    i_depth = depth_map / jnp.clip(acc_map, 1e-5)
    disp_map = 1.0 / jnp.clip(i_depth, 1e-5)

    # to composite onto a white background, use the accumulated alpha map
    if white_bkgd:
        rgb_map += 1.0 - acc_map[..., None]

    return {
        "rgb": rgb_map.astype(jnp.float32),
        "disp": disp_map.astype(jnp.float32),
        "acc": acc_map.astype(jnp.float32),
        "depth": depth_map.astype(jnp.float32),
    }, weights
Example #22
 def func(x):
     return np.einsum("ij,ij...->...", aux, x)
Example #23
def restricted_hartree_fock(geom, basis_name, xyz_path, nuclear_charges, charge, options, deriv_order=0, return_aux_data=False):
    # Load keyword options
    maxit = options['maxit']
    damping = options['damping']
    damp_factor = options['damp_factor']
    spectral_shift = options['spectral_shift']
    convergence = 1e-10

    nelectrons = int(jnp.sum(nuclear_charges)) - charge
    ndocc = nelectrons // 2

    # If we are doing MP2 or CCSD after, might as well use jit-compiled JK-build, since HF will not be memory bottleneck
    if return_aux_data:
        jk_build = jax.jit(jax.vmap(jax.vmap(lambda x,y: jnp.tensordot(x, y, axes=[(0,1),(0,1)]), in_axes=(0,None)), in_axes=(0,None)))
    else: 
        jk_build = jax.vmap(jax.vmap(lambda x,y: jnp.tensordot(x, y, axes=[(0,1),(0,1)]), in_axes=(0,None)), in_axes=(0,None))

    # Canonical orthogonalization via cholesky decomposition
    S, T, V, G = compute_integrals(geom, basis_name, xyz_path, nuclear_charges, charge, deriv_order, options)
    A = cholesky_orthogonalization(S)

    nbf = S.shape[0]

    # For slightly shifting eigenspectrum of transformed Fock for degenerate eigenvalues 
    # (JAX cannot differentiate degenerate eigenvalue eigh) 
    if spectral_shift:
        # Shifting eigenspectrum requires lower convergence.
        convergence = 1e-8 
        fudge = jnp.asarray(np.linspace(0, 1, nbf)) * convergence
        shift = jnp.diag(fudge)
    else:
        shift = jnp.zeros_like(S)

    H = T + V
    Enuc = nuclear_repulsion(geom.reshape(-1,3),nuclear_charges)
    D = jnp.zeros_like(H)
    
    def rhf_iter(F,D):
        E_scf = jnp.einsum('pq,pq->', F + H, D) + Enuc
        Fp = jnp.dot(A.T, jnp.dot(F, A))
        Fp = Fp + shift 
        eps, C2 = jnp.linalg.eigh(Fp)
        C = jnp.dot(A,C2)
        Cocc = C[:, :ndocc]
        D = jnp.dot(Cocc, Cocc.T)
        return E_scf, D, C, eps

    iteration = 0
    E_scf = 1.0
    E_old = 0.0
    Dold = jnp.zeros_like(D)
    dRMS = 1.0

    # Converge according to energy and DIIS residual to ensure eigenvalues and eigenvectors are maximally converged.
    # This is crucial for numerical stability for higher order derivatives of correlated methods.
    while ((abs(E_scf - E_old) > convergence) or (dRMS > convergence)):
        E_old = E_scf * 1
        if damping:
            if iteration < 10:
                D = Dold * damp_factor + D * damp_factor
                Dold = D * 1
        # Build JK matrix: 2 * J - K
        JK = 2 * jk_build(G, D)
        JK -= jk_build(G.transpose((0,2,1,3)), D)
        # Build Fock
        F = H + JK
        # Update convergence error
        if iteration > 1:
            diis_e = jnp.einsum('ij,jk,kl->il', F, D, S) - jnp.einsum('ij,jk,kl->il', S, D, F)
            diis_e = A.dot(diis_e).dot(A)
            dRMS = jnp.mean(diis_e**2)**0.5
        # Compute energy, transform Fock and diagonalize, get new density
        E_scf, D, C, eps = rhf_iter(F,D)
        iteration += 1
        if iteration == maxit:
            break
    print(iteration, " RHF iterations performed")

    # If many orbitals are degenerate, warn that higher order derivatives may be unstable 
    tmp = jnp.round(eps,6)
    ndegen_orbs =  tmp.shape[0] - jnp.unique(tmp).shape[0] 
    if (ndegen_orbs / nbf) > 0.20:
        print("Hartree-Fock warning: More than 20% of orbitals have degeneracies. Higher order derivatives may be unstable due to eigendecomposition AD rule")
    if not return_aux_data:
        return E_scf
    else:
        return E_scf, C, eps, G
Example #24
def fixed_pos_embedding(seq, dim):
    inv_freq = 1.0 / (10000**(np.arange(0, dim, 2) / dim))

    sinusoid_inp = np.einsum("i , j -> i j", np.arange(seq), inv_freq)

    return np.sin(sinusoid_inp), np.cos(sinusoid_inp)
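A brief shape check (added for illustration); these sin/cos tables are the kind of inputs that rotary-embedding code such as apply_rotary_pos_emb in Example #9 consumes.

import numpy as np

sin, cos = fixed_pos_embedding(seq=8, dim=16)
print(sin.shape, cos.shape)  # (8, 8) each: one phase per (position, frequency) pair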
Example #25
 def f(scale):
   scaled_mat = scale * psd_mat
   chol = jnp.linalg.cholesky(scaled_mat)
   return -0.5 * jnp.sum((jnp.einsum('ij,j->i', chol, vec))**2)
Example #26
Qt = jnp.eye(2) * 0.001
# Observed noise
Rt = jnp.eye(2) * 0.01

mu0 = jnp.array([1, 1])
Sigma0 = jnp.eye(2)

key = random.PRNGKey(314)
kf = lds.ContinuousKalmanFilter(A, C, Qt, Rt, x0, Sigma0)
sample_state, sample_obs, jump = kf.sample(key, x0, T, nsamples)
mu_hist, V_hist, *_ = kf.filter(sample_obs, jump, dt)

step = 0.1
vmin, vmax = -1.5, 1.5 + step
X = np.mgrid[-1:1.5:step, vmin:vmax:step][::-1]
X_dot = jnp.einsum("ij,jxy->ixy", A, X)

fig, ax = plt.subplots()
ax.plot(*sample_state.T, label="state space")
ax.scatter(*sample_obs.T, marker="+", c="tab:green", s=60, label="observations")
ax.scatter(*sample_state[0], c="black", zorder=3)
field = ax.streamplot(*X, *X_dot, density=1.1, color="#ccccccaa")
ax.legend()
plt.axis("equal")
ax.set_title("State Space")
pml.savefig("kf-circle-state.pdf")

fig, ax = plt.subplots()
ax.plot(*mu_hist.T, c="tab:orange", label="Filtered")
ax.scatter(*sample_obs.T, marker="+", s=60, c="tab:green", label="observations")
ax.scatter(*mu_hist[0], c="black", zorder=3)
Example #27
import numpy as np

# Setup mirrors Example #4 (dimensions and random factors are arbitrary).
N, C, D, K, T = 2, 3, 4, 5, 6
S = np.random.randn(N, T, K)
W = np.random.randn(K, D)
V = np.random.randn(D, C)
L = np.zeros((N, C))
for n in range(N):
    for c in range(C):
        s = 0
        for d in range(D):
            for k in range(K):
                for t in range(T):
                    s += S[n, t, k] * W[k, d] * V[d, c]
        L[n, c] = s
assert np.allclose(L, np.einsum('ntk,kd,dc->nc', S, W, V))

path = np.einsum_path('ntk,kd,dc->nc', S, W, V, optimize='optimal')[0]
assert np.allclose(L, np.einsum('ntk,kd,dc->nc', S, W, V, optimize=path))

import jax.numpy as jnp
path = jnp.einsum_path('ntk,kd,dc->nc', S, W, V, optimize='optimal')[0]
assert np.allclose(L, jnp.einsum('ntk,kd,dc->nc', S, W, V, optimize=path))

# Use full student network from Koller and Friedman
str = 'c,dc,gdi,si,lg,jls,hgj->'
K = 5
cptC = np.random.randn(K)
cptD = np.random.randn(K, K)
cptG = np.random.randn(K, K, K)
cptS = np.random.randn(K, K)
cptL = np.random.randn(K, K)
cptJ = np.random.randn(K, K, K)
cptH = np.random.randn(K, K, K)
cpts = [cptC, cptD, cptG, cptS, cptL, cptJ, cptH]
path_info = np.einsum_path(str, *cpts, optimize='optimal')
print(path_info[0])  # ['einsum_path', (0, 1), (0, 5), (0, 4), (0, 3), (0, 2), (0, 1)]
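The optimized path can then be fed back into np.einsum to run the contraction itself; a short follow-up sketch (added here):

# 'str' is the subscript string and 'cpts' the factor list defined above.
marginal = np.einsum(str, *cpts, optimize=path_info[0])
print(marginal)  # a single scalar: the fully contracted network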
Example #28
 def body(p, qkv):
   (q, k, v) = qkv
   p += jnp.einsum('...m,...d->...md', k, v, precision=precision)
   X_slice = jnp.einsum('...m,...md->...d', q, p, precision=precision)
   return p, X_slice
Example #29
def dot_product_attention(query: Array,
                          key: Array,
                          value: Array,
                          bias: Optional[Array] = None,
                          broadcast_dropout: bool = True,
                          dropout_rng: Optional[PRNGKey] = None,
                          dropout_rate: float = 0.,
                          deterministic: bool = False,
                          dtype: Dtype = jnp.float32,
                          precision: Optional[lax.Precision] = None):
    """Computes dot-product attention given query, key, and value.

  This is the core function for applying attention based on
  https://arxiv.org/abs/1706.03762. It calculates the attention weights given
  query and key and combines the values using the attention weights.

  Note: query, key, value needn't have any batch dimensions.

  Args:
    query: queries for calculating attention with shape of
      `[batch..., q_length, num_heads, qk_depth_per_head]`.
    key: keys for calculating attention with shape of
      `[batch..., kv_length, num_heads, qk_depth_per_head]`.
    value: values to be used in attention with shape of
      `[batch..., kv_length, num_heads, v_depth_per_head]`.
    bias: bias for the attention weights. This should be broadcastable to the
      shape: `[batch..., num_heads, q_length, kv_length]`
      This can be used for incorporating causal masks, padding masks,
      proximity bias, etc.
    broadcast_dropout: bool: use a broadcasted dropout along batch dims.
    dropout_rng: JAX PRNGKey: to be used for dropout
    dropout_rate: dropout rate
    deterministic: bool, deterministic or not (to apply dropout)
    dtype: the dtype of the computation (default: float32)
    precision: numerical precision of the computation see `jax.lax.Precision`
      for details.

  Returns:
    Output of shape `[batch..., length, num_heads, v_depth_per_head]`.
  """
    assert key.ndim == query.ndim == value.ndim, "q, k, v must have same rank."
    assert query.shape[:-3] == key.shape[:-3] == value.shape[:-3], (
        "q, k, v batch dims must match.")
    assert query.shape[-2] == key.shape[-2] == value.shape[-2], (
        "q, k, v num_heads must match.")
    assert key.shape[-3] == value.shape[-3], "k, v lengths must match."
    assert query.shape[-1] == key.shape[-1], "q, k depths must match."

    # calculate attention matrix
    depth = query.shape[-1]
    query = query / jnp.sqrt(depth).astype(dtype)
    # attn weight shape is (batch..., num_heads, q_length, kv_length)
    attn_weights = jnp.einsum('...qhd,...khd->...hqk',
                              query,
                              key,
                              precision=precision)

    # apply attention bias: masking, dropout, proximity bias, etc.
    if bias is not None:
        attn_weights = attn_weights + bias

    # normalize the attention weights
    attn_weights = jax.nn.softmax(attn_weights).astype(dtype)

    # apply attention dropout
    if not deterministic and dropout_rate > 0.:
        keep_prob = 1.0 - dropout_rate
        if broadcast_dropout:
            # dropout is broadcast across the batch + head dimensions
            dropout_shape = tuple([1] *
                                  (key.ndim - 2)) + attn_weights.shape[-2:]
            keep = random.bernoulli(dropout_rng, keep_prob, dropout_shape)
        else:
            keep = random.bernoulli(dropout_rng, keep_prob, attn_weights.shape)
        multiplier = (keep.astype(attn_weights.dtype) /
                      jnp.asarray(keep_prob, dtype=dtype))
        attn_weights = attn_weights * multiplier

    # return weighted sum over values for each query position
    return jnp.einsum('...hqk,...khd->...qhd',
                      attn_weights,
                      value,
                      precision=precision)
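A minimal call sketch (added for illustration), assuming the imports used by the function above are in scope:

import jax.numpy as jnp
from jax import random

kq, kk, kv = random.split(random.PRNGKey(0), 3)
# [batch, q_length/kv_length, num_heads, depth_per_head]
q = random.normal(kq, (2, 5, 4, 8))
k = random.normal(kk, (2, 7, 4, 8))
v = random.normal(kv, (2, 7, 4, 16))

out = dot_product_attention(q, k, v, deterministic=True)
print(out.shape)  # (2, 5, 4, 16): one v_depth_per_head vector per query and head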
Example #30
 def body(p, qk):
   q, k = qk
   p += k
   x = jnp.einsum('...m,...m->...', q, p, precision=precision)
   return p, x