# Example 1
def reweight_cl(weights, ngals, cl_in):
    """Combine input spectra into re-weighted bins for every probe pair.

    For each probe ``i`` the rows of ``weights[i] * ngals[i]`` are normalised
    to sum to one (non-finite results are clamped by ``nan_to_num``).  For each
    probe pair ``i1 <= i2`` the block ``cl_in[i2][i1]`` is contracted with the
    two normalised weight matrices.  For ``i1 == i2`` only the upper triangle
    (``j2 >= j1``) of the resulting bin-by-bin matrix is kept.

    Args:
        weights: array-like with a ``shape`` attribute; ``weights[i]`` is a 2-D
            array (new bins x input bins) for probe ``i``.
        ngals: sequence where ``ngals[i]`` broadcasts against ``weights[i]``.
            NOTE(review): presumably per-input-bin galaxy counts - confirm.
        cl_in: nested sequence; ``cl_in[i2][i1]`` (for ``i2 >= i1``) is a 4-D
            array indexed as ``[s, p, q, k]`` with ``p``/``q`` running over the
            input bins of probes ``i1``/``i2``.

    Returns:
        A jnp array with all selected spectra concatenated along axis 1.
    """
    nprobe = weights.shape[0]
    nzbin = np.array([len(W) for W in weights])
    # Probe i1 owns nzbin[i1] rows, each paired with the (nprobe - i1) probes
    # i2 >= i1.  The previous formula (nzbin * (1 + arange)) only matched this
    # count when every probe had the same number of bins.
    nout = int(np.sum(nzbin * (nprobe - np.arange(nprobe))))
    cl_out = [None] * nout

    # The normalised weights depend only on the probe, so build them once.
    w = []
    for i in range(nprobe):
        W = weights[i] * ngals[i]
        W = W / jnp.sum(W, axis=1, keepdims=True)
        w.append(jnp.nan_to_num(W, posinf=1e30, neginf=-1e30, copy=False))

    offset = 0
    for i1 in range(nprobe):
        nrow = int(nzbin[i1])
        rowstep = nprobe - i1
        for i2 in range(i1, nprobe):
            cl = jnp.einsum('ip,spqk,jq->sijk', w[i1], cl_in[i2][i1], w[i2])
            for j in range(nrow):
                # Auto-correlations are symmetric: keep only j2 >= j.
                start = j if i1 == i2 else 0
                cl_out[offset + j * rowstep + i2 - i1] = cl[:, j, start:]
        offset += nrow * rowstep
    return jnp.concatenate(cl_out, axis=1)
# Runs the model on a padded example and returns its logits together with the
# targets, a pairwise validity mask, the node count and the captured side
# outputs.
# NOTE(review): this definition is damaged. The parameter list and the closing
# quotes of the docstring are missing, and the `apply_add(...)` call further
# down is never closed. Restore the dropped lines from the upstream source
# before relying on this function.
def extract_outputs_and_targets(
    """Extract model outputs and targets for an example.

    model: Model to run on the example.
    padded_example_and_rng: Example to extract targets from, with RNG.
    target_edge_index: Index of the target edge type.
    num_edge_types: How many edge types there are.

    Tuple (output_logits, targets, valid_mask, num_nodes, captured)
    padded_example, rng = padded_example_and_rng
    # Run the model.
    with side_outputs.collect_side_outputs() as captured:
        with flax.nn.stochastic(rng):
            output_logits = model(padded_example)
    # Extract targets.
    # NOTE(review): the call below is left unclosed; its remaining arguments
    # and closing parenthesis are missing.
    targets = padded_example.edges.apply_add(in_array=(
        jnp.arange(num_edge_types) == target_edge_index).astype("int32"),
    targets = preprocess_targets(targets)
    # Compute valid mask for outputs and targets.
    max_num_nodes = output_logits.shape[0]
    num_nodes = padded_example.graph_metadata.num_nodes
    valid_nodes = jnp.arange(max_num_nodes) < num_nodes
    valid_nodes_float = valid_nodes.astype("float32")
    # Outer product of the per-node validity vector: entry (i, j) is 1 only
    # when both node i and node j are real (non-padding) nodes.
    valid_mask = jnp.einsum("i,j->ij", valid_nodes_float, valid_nodes_float)
    return output_logits, targets, valid_mask, num_nodes, captured
# Example 3
    def __call__(self, inputs):
        """Applies layer to input.

        Args:
          inputs: jnp.ndarray of shape [ens_size * batch_size, ..., input_dim].

        Returns:
          jnp.ndarray of shape [ens_size * batch_size, ..., features].
        """
        dtype = self.dtype or inputs.dtype
        inputs = jnp.asarray(inputs, dtype)
        input_dim = inputs.shape[-1]

        # One shared kernel plus per-ensemble-member scaling vectors applied
        # on the input side (alpha) and the output side (gamma).
        kernel = self.param('kernel', self.kernel_init,
                            (input_dim, self.features), dtype)
        alpha = self.param('fast_weight_alpha', self.alpha_init,
                           (self.ens_size, input_dim), dtype)
        gamma = self.param('fast_weight_gamma', self.gamma_init,
                           (self.ens_size, self.features), dtype)

        # Split the leading axis into (ens_size, batch) so each ensemble
        # member is paired with its own alpha/gamma row.
        inputs_shape = inputs.shape
        inputs = jnp.reshape(inputs, (self.ens_size, -1) + inputs_shape[1:])
        # The four-operand spec needs an `ED` operand; `gamma` is the only
        # (ens_size, features) parameter, so it completes the truncated call.
        outputs = jnp.einsum('E...C,EC,CD,ED->E...D', inputs, alpha, kernel,
                             gamma)

        if self.use_ensemble_bias:
            bias = self.param('bias', self.bias_init,
                              (self.ens_size, self.features), dtype)
            # Insert singleton axes so the bias broadcasts over every
            # non-ensemble, non-feature dimension.
            bias_shape = (self.ens_size, ) + (1, ) * (outputs.ndim - 2) + (
                self.features, )
            outputs = outputs + jnp.reshape(bias, bias_shape)

        if self.activation is not None:
            outputs = self.activation(outputs)  # pylint: disable=not-callable

        # Merge the ensemble axis back into the leading batch axis.
        return jnp.reshape(outputs, inputs_shape[:-1] + (self.features, ))
    def test_einsum_kpmurphy_example(self):
        """Checks a three-operand einsum against an explicit nested-loop sum."""
        # code from an email with @murphyk
        N, C, D, K, T = 2, 3, 4, 5, 6
        r = self.rng()
        S = r.randn(N, T, K)
        W = r.randn(K, D)
        V = r.randn(D, C)
        # Reference result: L[n, c] = sum_{d,k,t} S[n,t,k] * W[k,d] * V[d,c].
        L = np.zeros((N, C))
        for n in range(N):
            for c in range(C):
                s = 0
                for d in range(D):
                    for k in range(K):
                        for t in range(T):
                            s += S[n, t, k] * W[k, d] * V[d, c]
                L[n, c] = s

        path = jnp.einsum_path('ntk,kd,dc->nc', S, W, V, optimize='optimal')[0]
        rtol = 1e-2 if jtu.device_under_test() == "tpu" else None
        # The reference `L` is built with NumPy while the einsum result may
        # use a different default float width, so dtypes are not compared.
        self.assertAllClose(L,
                            jnp.einsum('ntk,kd,dc->nc', S, W, V,
                                       optimize=path),
                            check_dtypes=False, rtol=rtol)
# Example 5
    def __filter_step(self, state, obs_t):
        """Advance the particle set by one propagate / weight / resample step.

        Args:
            state: tuple ``(particles, key)`` carrying the current samples and
                the PRNG key.
            obs_t: observation for this step.

        Returns:
            ``((resampled_particles, next_key), mean)`` where ``mean`` is the
            equally weighted average of the resampled particles.
        """
        particles, rng = state
        n_particles = self.nsamples

        rng_propagate, rng_resample, rng_carry = random.split(rng, 3)

        # Propagate: sample new particles from the transition model.
        particles = random.multivariate_normal(
            rng_propagate, self.fz(particles), self.Q(particles))

        # Weight: likelihood of the observation under each particle.
        predicted_obs = self.fx(particles)
        obs_cov = self.R(particles, obs_t)
        likelihoods = stats.multivariate_normal.pdf(obs_t, predicted_obs,
                                                    obs_cov)

        # Resample indices in proportion to the likelihoods; afterwards every
        # surviving particle carries the same weight.
        chosen = random.choice(rng_resample, jnp.arange(n_particles),
                               p=likelihoods, shape=(n_particles,))
        particles = particles[chosen, ...]
        uniform_weights = jnp.ones(n_particles) / n_particles

        # Latent-state estimate: weighted mean over particles.
        posterior_mean = jnp.einsum("im,i->m", particles, uniform_weights)

        return (particles, rng_carry), posterior_mean
# Example 6
    # Linear projection of the last axis of `inputs` from num_channel to
    # self.num_output, with an optional bias.
    # NOTE(review): this method is damaged. The docstring is never closed and
    # the VarianceScaling(...) and hk.get_parameter(...) calls are truncated
    # (their trailing arguments and closing parentheses are missing).
    def __call__(self, inputs: jnp.ndarray) -> jnp.ndarray:
        """Connects Module.

      inputs: Tensor of shape [..., num_channel]

      output of shape [..., num_output]
        n_channels = int(inputs.shape[-1])

        weight_shape = [n_channels, self.num_output]
        # Pick the weight initializer by name; any other value of
        # self.initializer leaves weight_init unset.
        if self.initializer == 'linear':
            weight_init = hk.initializers.VarianceScaling(mode='fan_in',
        elif self.initializer == 'relu':
            weight_init = hk.initializers.VarianceScaling(mode='fan_in',
        elif self.initializer == 'zeros':
            weight_init = hk.initializers.Constant(0.0)

        # NOTE(review): truncated call; the initializer argument is missing.
        weights = hk.get_parameter('weights', weight_shape, inputs.dtype,

        # this is equivalent to einsum('...c,cd->...d', inputs, weights)
        # but turns out to be slightly faster
        inputs = jnp.swapaxes(inputs, -1, -2)
        output = jnp.einsum('...cb,cd->...db', inputs, weights)
        output = jnp.swapaxes(output, -1, -2)

        if self.use_bias:
            # NOTE(review): truncated call; the bias initializer is missing.
            bias = hk.get_parameter('bias', [self.num_output], inputs.dtype,
            output += bias

        return output
# Example 7
    # Multi-head graph propagation: computes a per-edge, per-head attention
    # value from the features at `row` and `col`, propagates each head's
    # features separately and returns the average over heads.
    # NOTE(review): this method is damaged. The expression inside
    # `jax.nn.sigmoid(` for `epsilon` has lost its first line(s) and closing
    # parenthesis, and the `out = [` list is never closed.
    def __call__(self, row, col, features):
        attn = hk.Linear(self.num_heads * self.attention_size)(features)
        attn = jnp.reshape(attn, (-1, self.num_heads, self.attention_size))
        # Per-edge, per-head score: sigmoid of the dot product between the
        # projected features of the two endpoints.
        attn = jax.nn.sigmoid(jnp.einsum("eha,eha->eh", attn[row], attn[col]))
        # NOTE(review): truncated; the callee that takes `shape=`/`init=` is
        # missing from the next statement.
        epsilon = jax.nn.sigmoid(
                shape=(self.num_heads, ),
                init=lambda shape, dtype: jnp.full(shape, -2.0, dtype=dtype),
        features = hk.Linear(self.num_heads * self.out_size)(features)
        features = jnp.reshape(features, (-1, self.num_heads, self.out_size))

        # Builds a square sparse adjacency from the edge values and applies
        # the propagator derived from it to one head's features.
        def propagate(features, attn, eps, row, col):
            adj = COO((attn, row, col), shape=(features.shape[0], ) * 2)
            propagator = _get_propagator(adj, eps, self.tol)
            return propagator @ features
            # shifted_lap = _get_shifted_laplacian(adj, eps)
            # return x: shifted_lap @ x, features)

        # out_ax = 0
        # out = jax.vmap(propagate, in_axes=(1, 1, 0, None, None), out_axes=out_ax)(
        #     features, attn, epsilon, row, col
        # )
        # Splits `x` into a Python list of slices along `axis`.
        def unstack(x, axis):
            return [x.take(i, axis=axis) for i in range(x.shape[axis])]

        features = unstack(features, axis=1)
        attn = unstack(attn, axis=1)
        eps = unstack(epsilon, axis=0)
        # NOTE(review): the closing `]` of this list is missing.
        out = [
            propagate(f, a, e, row, col)
            for (f, a, e) in zip(features, attn, eps)
        return sum(out) / len(out)
# Example 8
    # Attention forward pass: projects hidden_states into query/key/value,
    # combines the optional padding mask with the optional causal mask, and
    # applies the attention weights to the values.
    # NOTE(review): this method is damaged. The signature is incomplete
    # (`self`, `hidden_states`, `attention_mask` and the closing `):` are not
    # present), and the `combine_masks(`, `attention_bias =` and
    # `dot_product_attention_weights(` statements are truncated. An `else:`
    # before `attention_bias = None` also appears to be missing.
    def __call__(
        deterministic: bool = True,
        output_attentions: bool = False,
        query = self.q_proj(hidden_states)
        key = self.k_proj(hidden_states)
        value = self.v_proj(hidden_states)

        query = self._split_heads(query)
        key = self._split_heads(key)
        value = self._split_heads(value)

        causal_attention_mask = None
        if self.causal:
            query_length, key_length = query.shape[1], key.shape[1]
            # Slice the precomputed mask down to the last `query_length` rows
            # and the first `key_length` columns.
            causal_attention_mask = self.causal_mask[:, :,
                                                     key_length - query_length:
                                                     key_length, :key_length]

        if attention_mask is not None and causal_attention_mask is not None:
            attention_mask = jnp.expand_dims(attention_mask, axis=(-3, -2))
            # NOTE(review): truncated call; remaining arguments are missing.
            attention_mask = combine_masks(attention_mask,
        elif causal_attention_mask is not None:
            attention_mask = causal_attention_mask
        elif attention_mask is not None:
            attention_mask = jnp.expand_dims(attention_mask, axis=(-3, -2))

        if attention_mask is not None:
            # NOTE(review): the function call that should wrap the three
            # arguments below is missing from this assignment.
            attention_bias =
                attention_mask > 0,
                jnp.full(attention_mask.shape, 0.0).astype(self.dtype),
                jnp.full(attention_mask.shape, -1e4).astype(self.dtype),
            attention_bias = None

        dropout_rng = None
        if not deterministic and self.dropout > 0.0:
            dropout_rng = self.make_rng("dropout")

        # NOTE(review): truncated call; all arguments are missing.
        attn_weights = dot_product_attention_weights(

        attn_output = jnp.einsum("...hqk,...khd->...qhd", attn_weights, value)
        attn_output = self._merge_heads(attn_output)
        attn_output = self.out_proj(attn_output)

        outputs = (attn_output,
                   attn_weights) if output_attentions else (attn_output, )
        return outputs
# Example 9
    # Causal self-attention forward pass with rotary position embeddings and
    # an optional key/value cache.
    # NOTE(review): this method is damaged. The signature is incomplete
    # (`self`, `hidden_states`, `attention_mask`, `position_ids` and the
    # closing `):` are not present). The `attention_bias =`,
    # `dot_product_attention_weights(` and `self.resid_dropout(` statements
    # are truncated. Two `else:` lines also appear to be missing: as written,
    # the rotary embedding is applied a second time to the full key/query, and
    # the dynamically sliced causal mask is immediately overwritten.
    def __call__(
        deterministic: bool = True,
        init_cache: bool = False,
        output_attentions: bool = False,

        query = self.q_proj(hidden_states)
        key = self.k_proj(hidden_states)
        value = self.v_proj(hidden_states)

        query = self._split_heads(query)
        key = self._split_heads(key)
        value = self._split_heads(value)

        # Look up the position table rows for these positions and split the
        # last axis into two halves.
        sincos = jnp.take(self.embed_positions, position_ids, axis=0)
        sincos = jnp.split(sincos, 2, axis=-1)
        if self.rotary_dim is not None:
            # Rotate only the first `rotary_dim` features; pass the rest
            # through unchanged.
            k_rot = key[:, :, :, :self.rotary_dim]
            k_pass = key[:, :, :, self.rotary_dim:]

            q_rot = query[:, :, :, :self.rotary_dim]
            q_pass = query[:, :, :, self.rotary_dim:]

            k_rot = apply_rotary_pos_emb(k_rot, sincos)
            q_rot = apply_rotary_pos_emb(q_rot, sincos)

            key = jnp.concatenate([k_rot, k_pass], axis=-1)
            query = jnp.concatenate([q_rot, q_pass], axis=-1)
            # NOTE(review): the next two lines sit inside the `if` as written;
            # they look like the body of a missing `else:` branch. Confirm.
            key = apply_rotary_pos_emb(key, sincos)
            query = apply_rotary_pos_emb(query, sincos)

        query_length, key_length = query.shape[1], key.shape[1]

        if self.has_variable("cache", "cached_key"):
            mask_shift = self.variables["cache"]["cache_index"]
            max_decoder_length = self.variables["cache"]["cached_key"].shape[1]
            causal_mask = lax.dynamic_slice(
                self.causal_mask, (0, 0, mask_shift, 0),
                (1, 1, query_length, max_decoder_length))
            # NOTE(review): this assignment overwrites the slice above; an
            # `else:` was probably dropped before it. Confirm.
            causal_mask = self.causal_mask[:, :, :query_length, :key_length]

        batch_size = hidden_states.shape[0]
        causal_mask = jnp.broadcast_to(causal_mask,
                                       (batch_size, ) + causal_mask.shape[1:])

        attention_mask = jnp.broadcast_to(
            jnp.expand_dims(attention_mask, axis=(-3, -2)), causal_mask.shape)
        attention_mask = combine_masks(attention_mask, causal_mask)

        dropout_rng = None
        if not deterministic and self.config.attn_pdrop > 0.0:
            dropout_rng = self.make_rng("dropout")

        # During fast autoregressive decoding, we feed one position at a time,
        # and cache the keys and values step by step.
        if self.has_variable("cache", "cached_key") or init_cache:
            key, value, attention_mask = self._concatenate_to_cache(
                key, value, query, attention_mask)

        # transform boolean mask into float mask
        # NOTE(review): the function call that should wrap the three arguments
        # below is missing from this assignment.
        attention_bias =
            attention_mask > 0,
            jnp.full(attention_mask.shape, 0.0).astype(self.dtype),
            jnp.full(attention_mask.shape, -1e9).astype(self.dtype),

        # usual dot product attention
        # NOTE(review): truncated call; all arguments are missing.
        attn_weights = dot_product_attention_weights(

        attn_output = jnp.einsum("...hqk,...khd->...qhd", attn_weights, value)
        attn_output = self._merge_heads(attn_output)
        attn_output = self.out_proj(attn_output)
        # NOTE(review): truncated call; remaining arguments are missing.
        attn_output = self.resid_dropout(attn_output,

        outputs = (attn_output,
                   attn_weights) if output_attentions else (attn_output, )
        return outputs
# Example 10
 def _check(self, s, *ops):
     """Assert that ``np.einsum`` and ``onp.einsum`` agree for spec ``s``.

     Both are called with the same operands and compared to within 1e-4
     (absolute and relative), including their dtypes.
     NOTE(review): presumably ``onp`` is NumPy and ``np`` the library under
     test - confirm against the file's imports.
     """
     reference = onp.einsum(s, *ops)
     candidate = np.einsum(s, *ops)
     self.assertAllClose(
         reference, candidate, atol=1e-4, rtol=1e-4, check_dtypes=True)
# Example 11
def bulk_modulus(elastic_tensor):
  """Return the sum of ``elastic_tensor[i, i, j, j]`` divided by ``dim ** 2``.

  ``dim`` is the size of the tensor's first axis.
  """
  dim = elastic_tensor.shape[0]
  double_trace = np.einsum('iijj->', elastic_tensor)
  return double_trace / dim ** 2
# Example 12
def cosine_similarity(a: Array, b: Array) -> Array:
  """Row-wise cosine similarity of two 2D arrays.

  The dot product of matching rows is divided by the product of their norms,
  with ``_SMALL`` added to the denominator to avoid division by zero.
  """
  inner = jnp.einsum('bd,bd->b', a, b)
  norm_product = jnp.linalg.norm(a, axis=-1) * jnp.linalg.norm(b, axis=-1)
  return inner / (_SMALL + norm_product)
# Example 13
# Run the ADF update over all (Phi, y) rows, carrying (mu_t, tau_t) forward
# and recording the per-step history.
init_state = (mu_t, tau_t)
xs = (Phi, y)

adf_loop = partial(adf_step, q=q, lbound=lbound, ubound=ubound)
(mu_t, tau_t), (mu_t_hist, tau_t_hist) = jax.lax.scan(adf_loop, init_state, xs)

# ** Estimating posterior predictive distribution **
# Evaluation grid covering the data range with a 0.1 margin on every side.
xmin, ymin = X.min(axis=0) - 0.1
xmax, ymax = X.max(axis=0) + 0.1
step = 0.1
Xspace = jnp.mgrid[xmin:xmax:step, ymin:ymax:step]
_, nx, ny = Xspace.shape
# Prepend a plane of ones so the first weight acts as the intercept.
Phispace = jnp.concatenate([jnp.ones((1, nx, ny)), Xspace])

# MCMC posterior predictive distribution
Z_mcmc = sigmoid(jnp.einsum("mij,sm->sij", Phispace, chains))
Z_mcmc = Z_mcmc.mean(axis=0)
# Laplace posterior predictive distribution
key = random.PRNGKey(314)
laplace_samples = random.multivariate_normal(key, w_map, SN, (n_samples, ))
Z_laplace = sigmoid(jnp.einsum("mij,sm->sij", Phispace, laplace_samples))
Z_laplace = Z_laplace.mean(axis=0)
# ADF posterior predictive distribution
# NOTE(review): `key` is reused from the Laplace draw above, so both sample
# sets come from the same PRNG key. Left unchanged to keep results identical.
adf_samples = random.multivariate_normal(key, mu_t, jnp.diag(tau_t),
                                         (n_samples, ))
Z_adf = sigmoid(jnp.einsum("mij,sm->sij", Phispace, adf_samples))
Z_adf = Z_adf.mean(axis=0)

# ** Plotting predictive distribution **
plt.rcParams["axes.spines.right"] = False
# The key was an empty string, which is not a valid rc parameter; hiding the
# top spine is the natural counterpart of the line above.
plt.rcParams["axes.spines.top"] = False
# Example 14
    # Segment-based attention: the sequence is reshaped into segments of
    # cfg.max_seg_len; every query gets logits over keys in its own segment
    # plus logits over one summary key per segment, and a single softmax is
    # taken over the concatenation.
    # NOTE(review): this method is damaged. The following statements are
    # truncated (closing arguments / wrapping call missing):
    # `dot_product_attention(`, `local_bias =`, `jnp.broadcast_to(` for `idx`,
    # `nn.make_attention_mask(`, `prev_bias =`, `jnp.split(`, the first
    # `local_merged` einsum, and `DenseGeneral(`.
    def __call__(self, input_qkv):
        cfg = self.config
        # NOTE(review): bare comparison with no effect; probably meant to be
        # an assertion that max_len is a multiple of max_seg_len.
        cfg.max_len % cfg.max_seg_len == 0
        bsize = input_qkv.shape[0]
        features = self.out_features or input_qkv.shape[-1]
        query, key, value, head_dim = get_qkv(cfg, input_qkv)

        num_seg = cfg.max_len // cfg.max_seg_len
        cur_query = query.reshape(
            [-1, cfg.max_seg_len, query.shape[-2], query.shape[-1]])
        # One summary query per segment: element-wise max over the segment's
        # positions, scaled by sqrt(head_dim).
        merged_query = jnp.max(cur_query, axis=1,
                               keepdims=True) * jnp.sqrt(head_dim)
        cur_key = key.reshape(
            [-1, cfg.max_seg_len, key.shape[-2], key.shape[-1]])
        cur_value = value.reshape(
            [-1, cfg.max_seg_len, value.shape[-2], value.shape[-1]])
        dropout_rng = None
        if not cfg.deterministic and cfg.attention_dropout_rate > 0.:
            dropout_rng = self.make_rng('dropout')
        s = dot_product_attention(merged_query,
        span_val = jnp.reshape(s, [bsize, -1, s.shape[-2], s.shape[-1]])
        span_key = jnp.max(cur_key, axis=1, keepdims=True)
        # (bsize, n_seg, n_head, dim_per_head)
        span_key = jnp.reshape(
            span_key, [bsize, -1, span_key.shape[-2], span_key.shape[-1]])

        local_mask = make_causal_mask(cur_query,
                                      length_axis=1).transpose([0, 2, 1, 3])
        local_bias =
            local_mask > 0,
            jnp.full(local_mask.shape, 0.).astype(cfg.dtype),
            jnp.full(local_mask.shape, -1e10).astype(cfg.dtype))
        # (bsize * n_seg, seg_len, n_head, seg_len)
        local_logits = jnp.einsum('...qhd,...khd->...qhk', cur_query,
                                  cur_key) + local_bias
        local_logits = jnp.reshape(local_logits,
                                   [bsize, -1, cfg.num_heads, cfg.max_seg_len])
        idx = jnp.broadcast_to(jnp.arange(span_key.shape[1], dtype=jnp.int32),
        prev_mask = nn.make_attention_mask(idx,
                                               [0, 2, 1, 3])
        # Repeat each segment-level mask row once per position in a segment.
        prev_mask = jnp.repeat(prev_mask, cfg.max_seg_len, axis=-3)
        prev_bias =
            prev_mask > 0,
            jnp.full(prev_mask.shape, 0.).astype(cfg.dtype),
            jnp.full(prev_mask.shape, -1e10).astype(cfg.dtype))
        # (bsize, max_len, n_head, num_segs)
        prev_logits = jnp.einsum('...qhd,...khd->...qhk', query,
                                 span_key) + prev_bias
        joint_logits = jnp.concatenate((local_logits, prev_logits), axis=-1)
        # (bsize x max_len,  n_head, seg_len + num_segs)
        attn_weights = jax.nn.softmax(joint_logits).astype(cfg.dtype)
        local_att, prev_att = jnp.split(attn_weights, [cfg.max_seg_len],
        local_att = local_att.reshape(
            [bsize * num_seg, cfg.max_seg_len, cfg.num_heads, cfg.max_seg_len])
        local_merged = jnp.einsum('...qhk,...khd->...qhd', local_att,
        prev_merged = jnp.einsum('...qhk,...khd->...qhd', prev_att, span_val)
        joint_merged = jnp.reshape(local_merged,
                                   prev_merged.shape) + prev_merged
        x = DenseGeneral(features=features,
                         axis=(-2, -1),
        return x
# Example 15
    # Multi-scale attention over contexts of length 2**l for l in
    # range(log_len): a first loop builds per-scale logits, one softmax is
    # taken over their concatenation, and a second loop accumulates the
    # weighted values per scale.
    # NOTE(review): this method is damaged. In the visible code `joint_logits`
    # and `list_vals` are created as empty lists and never appended to, so the
    # lines that filled them were dropped. The `if l == 0:` block has lost its
    # `else:`. Several calls are truncated: `jnp.expand_dims(`,
    # `dot_product_attention(`, `jnp.concatenate(` for `rolled_mask`, the
    # `rolled_logits` einsum, `jnp.pad(` for `orig_logits`, and
    # `DenseGeneral(`.
    def __call__(self, input_qkv):
        cfg = self.config
        log_len = log_2_ceil(cfg.max_len - 1)
        bsize = input_qkv.shape[0]
        features = self.out_features or input_qkv.shape[-1]
        query, key, value, head_dim = get_qkv(cfg, input_qkv)

        joint_logits = []
        list_vals = []
        for l in range(log_len):
            ctx_len = 2**l
            # Largest multiple of ctx_len that fits in max_len.
            last_pos = cfg.max_len - cfg.max_len % ctx_len
            num_ctx = cfg.max_len // ctx_len

            if l == 0:
                span_key = jnp.reshape(key, [-1, 1, cfg.num_heads, head_dim])
                span_val = value.reshape(span_key.shape)
                self_logits = jnp.expand_dims(jnp.sum(query * key, axis=-1),
                left_query = query[:, :last_pos, :, :].reshape(
                    [-1, ctx_len, cfg.num_heads, head_dim])
                span_query = jnp.max(left_query, axis=1, keepdims=True)
                left_key = key[:, :last_pos, :, :].reshape(left_query.shape)
                left_val = value[:, :last_pos, :, :].reshape(left_query.shape)
                span_val = dot_product_attention(
                    span_query * jnp.sqrt(head_dim),
                span_key = jnp.max(left_key, axis=1, keepdims=True)
            # Shift queries left by ctx_len along the sequence axis, then
            # group them into contexts of ctx_len positions.
            rolled_q = jnp.roll(query, -ctx_len,
                                axis=1)[:, :last_pos, :, :].reshape(
                                    [-1, ctx_len, cfg.num_heads, head_dim])

            rolled_mask = jnp.concatenate(
                [(jnp.arange(cfg.max_len - ctx_len) // ctx_len) % 2,
                 jnp.ones(last_pos + ctx_len - cfg.max_len, dtype=jnp.int32)],
            rolled_mask = jnp.reshape(rolled_mask, [1, -1, 1, 1])
            rolled_logits = jnp.einsum('...qhd,...khd->...qhk', rolled_q,
            # bsize, last_pos, h, 1
            # Positions where rolled_mask is 1 receive a -1e9 additive bias.
            rolled_logits = jnp.reshape(
                rolled_logits, [bsize, -1, cfg.num_heads, 1
                                ]) + rolled_mask.astype(rolled_q.dtype) * -1e9
            orig_logits = jnp.pad(rolled_logits, [(0, 0),
                                                  (0, cfg.max_len - last_pos),
                                                  (0, 0), (0, 0)],
            # Undo the earlier left shift.
            orig_logits = jnp.roll(orig_logits, ctx_len, axis=1)
        joint_logits = jnp.concatenate(joint_logits, axis=-1)
        attn_weights = jax.nn.softmax(joint_logits).astype(cfg.dtype)
        # One weight slice per scale plus one extra slice (index 0), which is
        # applied directly to `value` below.
        local_weights = jnp.split(attn_weights, log_len + 1, axis=-1)
        local_weighted_sums = []
        joint_merged = local_weights[0] * value
        for l in range(log_len):
            ctx_len = 2**l
            last_pos = cfg.max_len - cfg.max_len % ctx_len
            num_ctx = cfg.max_len // ctx_len

            rolled_w = jnp.roll(local_weights[l + 1], -ctx_len,
                                axis=1)[:, :last_pos, :, :].reshape(
                                    bsize * num_ctx, ctx_len, cfg.num_heads, 1)
            rolled_v = jnp.reshape(rolled_w * list_vals[l],
                                   [bsize, -1, cfg.num_heads, head_dim])
            rolled_v = jnp.pad(rolled_v, [(0, 0), (0, cfg.max_len - last_pos),
                                          (0, 0), (0, 0)])
            orig_v = jnp.roll(rolled_v, ctx_len, axis=1)
            joint_merged = joint_merged + orig_v
        x = DenseGeneral(features=features,
                         axis=(-2, -1),
        return x
# Samples a trajectory of up to `steps` transitions from a per-variant
# transition matrix with `jax.lax.scan`, and returns the final RollOutState
# with `final_node` filled in.
# NOTE(review): this definition is damaged. The parameter list is incomplete
# (`builder`, `transition_matrix` and `max_possible_transitions` are used in
# the body but absent from the visible signature, which is also never closed),
# and the docstring has no closing quotes. Many statements are truncated:
# `add_special_dests`, `in_tagged_probs`, the `initial_probs_from_here`
# einsum, `cur_variant_weights`, both `jax.tree_multimap(` calls, both
# `RollOutState(` constructions and `final_node`. Restore the dropped lines
# from the upstream source before use.
def roll_out_transitions(
    variant_weights, start_machine_state,
    node_index, steps, rng,
  """Roll out transitions for a transition matrix.

    builder: The automaton builder associated with this Markov chain.
    transition_matrix: The per-variant transition matrix for this Markov chain.
    variant_weights: Weights to assign to each routing variant for each node, as
      a <float[num_nodes, num_variants]> array (which should sum to 1 across the
      last axis).
    start_machine_state: Initial machine state distribution for the starting
      nodes, as a <float[num_fsm_states]> array (which should sum to 1).
    node_index: Initial node index where the solve should start, as a int32
    steps: How many steps to unroll.
    rng: Random number generator used to sample.
    max_possible_transitions: Max transitions to truncate to.

    Final RollOutState.
  (variants, nodes, fsm_states, in_tagged_nodes,
   _) = transition_matrix.initial_to_in_tagged.shape
  itn_states = in_tagged_nodes * fsm_states

  # Appends the three "special" destination indices (itn_states ..
  # itn_states + 2) after the in-tagged destinations.
  # NOTE(review): the first element of the concatenated list is missing.
  def add_special_dests(in_tagged_dests):
    return jnp.concatenate(
         jnp.arange(itn_states, itn_states + 3)])

  # Collapse our transition matrix into a more useful form:
  # - Combine nodes and FSM states
  # - Extract a subset of possible destinations
  # - Combine special and move actions into a single probability table
  k = min(max_possible_transitions, in_tagged_nodes) * fsm_states
  initial_to_in_tagged_combo = transition_matrix.initial_to_in_tagged.reshape(
      (variants, nodes, fsm_states, itn_states))
  # For each node keep the k destinations with the largest probability mass
  # summed over variants and FSM states.
  _, initial_to_in_tagged_dests = jax.lax.top_k(
      jnp.sum(initial_to_in_tagged_combo, axis=(0, 2)), k=k)
  initial_to_in_tagged_probs = jax.vmap(
      lambda c, i: c[:, :, i], in_axes=(1, 0),
      out_axes=1)(initial_to_in_tagged_combo, initial_to_in_tagged_dests)
  initial_probs = jnp.concatenate(
      [initial_to_in_tagged_probs, transition_matrix.initial_to_special], -1)
  initial_dests = jax.vmap(add_special_dests)(initial_to_in_tagged_dests)

  in_tagged_to_in_tagged_combo = transition_matrix.in_tagged_to_in_tagged.reshape(
      (variants, itn_states, itn_states))
  _, in_tagged_to_in_tagged_dests = jax.lax.top_k(
      jnp.sum(in_tagged_to_in_tagged_combo, axis=0), k=k)
  in_tagged_to_in_tagged_probs = jax.vmap(
      lambda c, i: c[:, i], in_axes=(1, 0),
      out_axes=1)(in_tagged_to_in_tagged_combo, in_tagged_to_in_tagged_dests)
  # NOTE(review): `in_tagged_to_in_tagged_probs` is not in this list as
  # written; a line appears to have been dropped. Confirm.
  in_tagged_probs = jnp.concatenate([
      transition_matrix.in_tagged_to_special.reshape((variants, itn_states, 3))
  ], -1)
  in_tagged_dests = jax.vmap(add_special_dests)(in_tagged_to_in_tagged_dests)

  start_node_variant_weights = variant_weights[node_index]
  # NOTE(review): the spec names three operands but only two are visible;
  # the `s` operand is missing.
  initial_probs_from_here = jnp.einsum("v,s,vsj->j", start_node_variant_weights,
                                       initial_probs[:, node_index, :, :])
  initial_dests_from_here = initial_dests[node_index]

  per_itn_variants = variant_weights[transition_matrix.in_tagged_node_indices]

  # Set up the initial state
  initial_state = RollOutState(rng=rng)

  # One sampling step; the scan passes no per-step input.
  def scan_body(state, ignored_input):
    assert ignored_input is None
    rng, key = jax.random.split(state.rng)

    # If we are in the initial state, sample the initial transition.
    def at_initial_info():
      return initial_probs_from_here, initial_dests_from_here

    # If we are in a normal state, sample the next action.
    def at_normal_info():
      cur_variant_weights = per_itn_variants[state.itn_state_index //
      next_step_probs = jnp.einsum("v,vj->j", cur_variant_weights,
                                   in_tagged_probs[:, state.itn_state_index, :])

      return next_step_probs, in_tagged_dests[state.itn_state_index, :]

    # Figure out which to do, and sample from the appropriate probabilities
    step_probs, step_dests = jax.tree_multimap(
        functools.partial(jnp.where, state.at_initial), at_initial_info(),

    next_idx = jax.random.categorical(key, jnp.log(step_probs))
    log_prob = jnp.log(step_probs[next_idx])
    dest = step_dests[next_idx]
    # Destinations at or beyond itn_states are the special actions.
    did_special = dest >= itn_states

    state_after_move = RollOutState(
        log_prob=state.log_prob + log_prob,

    special_idx = dest - itn_states
    state_after_special = RollOutState(
        at_initial=(special_idx == builder.special_actions.index(
        succeeded=(special_idx == builder.special_actions.index(
        failed=(special_idx == builder.special_actions.index(
        log_prob=state.log_prob + log_prob,

    # Choose the right branch to take
    def choose(move, special, done):
      return jnp.where(state.succeeded | state.failed, done,
                       jnp.where(did_special, special, move))

    new_state = jax.tree_multimap(choose, state_after_move, state_after_special,

    return new_state, None

  final_state, _ = jax.lax.scan(scan_body, initial_state, None, length=steps)
  final_node = jnp.where(
      final_state.special_from_initial, node_index,
      transition_matrix.in_tagged_node_indices[final_state.itn_state_index //
  final_state = dataclasses.replace(final_state, final_node=final_node)
  return final_state
# Example 17
 def cell_fn(theta, state0, inputs_t):
     """Project ``inputs_t.x`` through ``theta.proj``; the state is unused.

     Returns a ``NestedMap`` whose ``y`` field is the vector-matrix product.
     """
     del state0  # This cell ignores the previous state.
     projected = jnp.einsum('x,xy->y', inputs_t.x, theta.proj)
     return NestedMap(y=projected)
# Example 18
 def batched_dot(a, b):
     """Dot product of ``a`` and ``b`` batched over their leading axis.

     When either operand is 2-D, the last axis of ``a`` is contracted with
     the second axis of ``b``; otherwise a batched matrix product is taken.

     Raises:
         TypeError: if the leading dimensions differ.
     """
     if a.shape[0] != b.shape[0]:
         raise TypeError("Shapes must match in the 0-th dimension")
     neither_is_2d = a.ndim != 2 and b.ndim != 2
     spec = "nij,njk->nik" if neither_is_2d else "n...j,nj...->n..."
     return jnp.einsum(spec, a, b)
# Example 19
                  poly_axes=[0, 0]),

    _make_harness("dynamic_slice", "",
                  # x:shape: (b, 4)
                  lambda x: lax.dynamic_slice(x, (0, 1), (x.shape[0], 2)),
                  [RandArg((3, 4), _f32)],

    _make_harness("dynamic_update_slice", "",
                  # x:shape: (b, 4)
                  lambda x: lax.dynamic_update_slice(x, x, (0, 0)),
                  [RandArg((3, 4), _f32)],

    _make_harness("einsum", "0",
                  lambda x: jnp.einsum("...i->...", x),
                  [RandArg((3, 4), _f32)],

    _make_harness("einsum", "1",
                  lambda x, y: jnp.einsum("...ij,...jk->...ik", x, y),
                  [RandArg((3, 4, 5), _f32), RandArg((3, 5, 6), _f32)],
                  poly_axes=[0, 0]),

    _make_harness("einsum", "2",
                  lambda x, y: jnp.einsum("...ij,jk->...ik", x, y),
                  [RandArg((3, 4, 5), _f32), RandArg((5, 6), _f32)],
                  poly_axes=[0, None]),

    _make_harness("einsum", "3",
                  # Reduced dimension is polymorphic
# Example 20
# Loops over every ordered pair of primitives, accumulates overlap, kinetic
# and potential integrals into a (3, nbf, nbf) array and returns the three
# slices as (S, T, V).
# NOTE(review): this definition is damaged. The line directly below `def` is
# docstring text that has lost its quotes, and the `prefactor` and `Pgeom_pow`
# expressions have lost part of their text (unbalanced parentheses). Restore
# them from the upstream source before use.
def oei_arrays(geom, basis, charges):
    Build one electron integral arrays (overlap, kinetic, and potential integrals)
    coeffs, exps, atoms, ams, indices, dims = flatten_basis_data(basis)
    nbf = get_nbf(basis)
    nprim = coeffs.shape[0]
    max_am = jnp.max(ams)
    A_vals = jnp.zeros(2 * max_am + 1)

    # Save various AM distributions for indexing
    # Obtain all possible primitive quartet index combinations
    primitive_duets = cartesian_product(jnp.arange(nprim), jnp.arange(nprim))

    with loops.Scope() as s:
        s.oei = jnp.zeros((3, nbf, nbf))
        s.a = 0  # center A angular momentum iterator
        s.b = 0  # center B angular momentum iterator

        for prim_duet in s.range(primitive_duets.shape[0]):
            p1, p2 = primitive_duets[prim_duet]
            coef = coeffs[p1] * coeffs[p2]
            aa, bb = exps[p1], exps[p2]
            atom1, atom2 = atoms[p1], atoms[p2]
            am1, am2 = ams[p1], ams[p2]
            A, B = geom[atom1], geom[atom2]
            ld1, ld2 = am_leading_indices[am1], am_leading_indices[am2]

            gamma = aa + bb
            # NOTE(review): garbled expression; part of the text between
            # `bb *` and `- B, A - B)` is missing.
            prefactor = jnp.exp(-aa * bb * - B, A - B) / gamma)
            # Exponent-weighted average of the two centers.
            P = (aa * A + bb * B) / gamma
            # Maximum angular momentum: hard coded
            #max_am = 3 # f function support
            # Precompute all powers up to 2+max_am of Pi-Ai, Pi-Bi.
            # We need 2+max_am since kinetic requires incrementing angular momentum by +2
            PA_pow = jnp.power(
                jnp.broadcast_to(P - A, (max_am + 3, 3)).T,
                jnp.arange(max_am + 3))
            PB_pow = jnp.power(
                jnp.broadcast_to(P - B, (max_am + 3, 3)).T,
                jnp.arange(max_am + 3))

            # For potential integrals, we need the difference between
            # the gaussian product center P and ALL atoms in the molecule,
            # and then take all possible powers up to 2*max_am.
            # We pre-collect this into a 3d array, and then just pull out what we need via indexing in the loops, so they need not be recomputed.
            # The resulting array has dimensions (atom, cartesian component, power) so index (0, 1, 3) would return (Py - atom0_y)^3
            P_minus_geom = jnp.broadcast_to(P, geom.shape) - geom
            # NOTE(review): garbled expression; the opening of the first
            # argument to jnp.power is missing.
            Pgeom_pow = jnp.power(
                        (2 * max_am + 1, geom.shape[0], geom.shape[1])),
                    (1, 2, 0)), jnp.arange(2 * max_am + 1))
            # All possible,P-atom)
            # Row-wise squared norm of P_minus_geom (one value per atom).
            rcp2 = jnp.einsum('ij,ij->i', P_minus_geom, P_minus_geom)
            # All needed (and unneeded, for am < max_am) boys function evaluations
            boys_arg = jnp.broadcast_to(rcp2 * gamma,
                                        (2 * max_am + 1, geom.shape[0]))
            boys_nu = jnp.tile(jnp.arange(2 * max_am + 1),
                               (geom.shape[0], 1)).T
            boys_eval = boys(boys_nu, boys_arg)

            s.a = 0
            for _ in s.while_range(lambda: s.a < dims[p1]):
                s.b = 0
                for _ in s.while_range(lambda: s.b < dims[p2]):
                    # Gather angular momentum and index
                    la, ma, na = angular_momentum_combinations[s.a + ld1]
                    lb, mb, nb = angular_momentum_combinations[s.b + ld2]
                    # To only create unique indices, need to have separate indices arrays for i and j.
                    i = indices[p1] + s.a
                    j = indices[p2] + s.b
                    # Compute one electron integrals and add to appropriate index
                    overlap_int = overlap(la, ma, na, lb, mb, nb, aa, bb,
                                          PA_pow, PB_pow, prefactor) * coef
                    kinetic_int = kinetic(la, ma, na, lb, mb, nb, aa, bb,
                                          PA_pow, PB_pow, prefactor) * coef
                    potential_int = potential(la, ma, na, lb, mb, nb, aa, bb,
                                              PA_pow, PB_pow, Pgeom_pow,
                                              boys_eval, prefactor, charges,
                                              A_vals) * coef
                    # Add the three integrals at [0, i, j], [1, i, j] and
                    # [2, i, j] in a single call.
                    s.oei = jax.ops.index_add(
                        s.oei, ([0, 1, 2], [i, i, i], [j, j, j]),
                        (overlap_int, kinetic_int, potential_int))

                    s.b += 1
                s.a += 1
    S, T, V = s.oei[0], s.oei[1], s.oei[2]
    return S, T, V
# Example 21
def raw2outputs(raw, z_vals, rays_d, raw_noise_std, white_bkgd, rng=None):
    """Transforms model's predictions to semantically meaningful values.

    Args:
        raw: (num_rays, num_samples || num_importance, 4) prediction from model
        z_vals: (num_rays, num_samples || num_importance) integration time
        rays_d: (num_rays, 3) direction of each ray
        raw_noise_std: std of noise added for regularization
        white_bkgd: whether to use the alpha channel for white background
        rng: random key; density noise is only added when this is not None
            and ``raw_noise_std > 0``

    Returns:
        A tuple ``(outputs, weights)``. ``outputs`` is a dict of float32
        arrays with the keys:
            "rgb": (num_rays, 3) estimated RGB color of a ray
            "disp": (num_rays) disparity map (inverse of depth map)
            "acc": (num_rays) sum of weights along each ray
            "depth": (num_rays) estimated distance to object
        ``weights`` is (num_rays, num_samples || num_importance), the weights
        assigned to each sampled color.
    """
    # compute 'distance' (in time) between each integration time along a ray
    dists = z_vals[..., 1:] - z_vals[..., :-1]

    # the 'distance' from the last integration time is infinity; the pad is
    # shaped from z_vals so that a ray with a single sample still gets one
    # entry (slicing the possibly-empty `dists` would give a zero-length pad)
    last_dist = jnp.full(z_vals[..., :1].shape, 1e10, dtype=dists.dtype)
    dists = jnp.concatenate([dists, last_dist], axis=-1)  # [num_rays, num_samples]

    # multiply each distance by the norm of its corresponding direction ray
    # to convert to real world distance (accounts for non-unit directions)
    dists = dists * jnp.linalg.norm(rays_d[..., None, :], axis=-1)

    # extract RGB of each sample position along each ray
    # (this value is strictly between [0, 1])
    rgb = nn.sigmoid(raw[..., :3])  # [num_rays, num_samples, 3]

    # add noise to predictions for density, can be used to
    # regularize network during training (prevents floater artifacts)
    noise = 0.0
    if raw_noise_std > 0.0 and rng is not None:
        noise = random.normal(rng, raw[..., 3].shape) * raw_noise_std

    # predict density of each sample along each ray (alpha channel)
    # higher values imply higher likelihood of being absorbed at this point
    alpha = 1.0 - jnp.exp(-nn.relu(raw[..., 3] + noise) * dists)

    # compute weight for RGB of each sample along each ray
    # an exclusive cumprod of (1 - alpha) expresses the idea of the ray not
    # having been absorbed before reaching this sample
    pass_through = jnp.clip(1.0 - alpha, 1e-5, 1.0)
    exclusive_terms = jnp.concatenate(
        [jnp.ones_like(pass_through[..., :1]), pass_through[..., :-1]], -1)
    weights = alpha * jnp.cumprod(exclusive_terms, -1)  # [num_rays, num_samples]

    # computed weighted color of each sample along each ray
    rgb_map = jnp.einsum("ij,ijk->ik", weights, rgb)  # [num_rays, 3]

    # estimated depth map is expected distance
    depth_map = jnp.einsum("ij,ij->i", weights, z_vals)  # [num_rays]

    # sum of weights along each ray (this value is in [0, 1] up to numerical error)
    acc_map = jnp.einsum("ij->i", weights)  # [num_rays]

    # disparity map is inverse depth; both divisions are floored at 1e-5 so
    # empty rays (acc == 0) do not divide by zero
    mean_depth = depth_map / jnp.maximum(acc_map, 1e-5)
    disp_map = 1.0 / jnp.maximum(mean_depth, 1e-5)

    # to composite onto a white background, use the accumulated alpha map
    if white_bkgd:
        rgb_map = rgb_map + (1.0 - acc_map[..., None])

    return {
        "rgb": rgb_map.astype(jnp.float32),
        "disp": disp_map.astype(jnp.float32),
        "acc": acc_map.astype(jnp.float32),
        "depth": depth_map.astype(jnp.float32),
    }, weights
Esempio n. 22
 def func(x):
     """Sum ``aux[i, j] * x[i, j, ...]`` over i and j, keeping trailing axes of ``x``.

     ``aux`` is taken from the enclosing scope (not visible in this excerpt).
     """
     contracted = np.einsum("ij,ij...->...", aux, x)
     return contracted
Esempio n. 23
def restricted_hartree_fock(geom, basis_name, xyz_path, nuclear_charges, charge, options, deriv_order=0, return_aux_data=False):
    """Run a restricted Hartree-Fock SCF loop and return the converged energy.

    Args:
        geom: Nuclear coordinates; reshaped to (-1, 3) for the nuclear
            repulsion term and forwarded to ``compute_integrals``.
        basis_name, xyz_path, deriv_order: Forwarded unchanged to
            ``compute_integrals``.
        nuclear_charges: Array of nuclear charges; its sum minus ``charge``
            gives the electron count.
        charge: Total molecular charge.
        options: Mapping providing 'maxit', 'damping', 'damp_factor' and
            'spectral_shift'.
        return_aux_data: When True, also return orbital data and the
            two-electron integrals, and jit-compile the JK build.

    Returns:
        ``E_scf`` when ``return_aux_data`` is False, otherwise the tuple
        ``(E_scf, C, eps, G)`` (energy, orbital coefficients, orbital
        energies, two-electron integrals).
    """
    # Load keyword options
    maxit = options['maxit']
    damping = options['damping']
    damp_factor = options['damp_factor']
    spectral_shift = options['spectral_shift']
    convergence = 1e-10

    nelectrons = int(jnp.sum(nuclear_charges)) - charge
    ndocc = nelectrons // 2

    def _contract(x, y):
        # Contract the two leading axes of a G slice with the density matrix.
        return jnp.tensordot(x, y, axes=[(0, 1), (0, 1)])

    jk_build = jax.vmap(jax.vmap(_contract, in_axes=(0, None)), in_axes=(0, None))
    # If we are doing MP2 or CCSD after, might as well use jit-compiled JK-build, since HF will not be memory bottleneck
    if return_aux_data:
        jk_build = jax.jit(jk_build)

    # Canonical orthogonalization via cholesky decomposition
    S, T, V, G = compute_integrals(geom, basis_name, xyz_path, nuclear_charges, charge, deriv_order, options)
    A = cholesky_orthogonalization(S)

    nbf = S.shape[0]

    # For slightly shifting eigenspectrum of transformed Fock for degenerate eigenvalues
    # (JAX cannot differentiate degenerate eigenvalue eigh)
    if spectral_shift:
        # Shifting eigenspectrum requires lower convergence.
        convergence = 1e-8
        fudge = jnp.asarray(np.linspace(0, 1, nbf)) * convergence
        shift = jnp.diag(fudge)
    else:
        shift = jnp.zeros_like(S)

    H = T + V
    Enuc = nuclear_repulsion(geom.reshape(-1, 3), nuclear_charges)
    D = jnp.zeros_like(H)

    def rhf_iter(F, D):
        # Energy for the incoming (F, D) pair, then a new density from the
        # orthogonalized, optionally shifted, Fock matrix.
        E_scf = jnp.einsum('pq,pq->', F + H, D) + Enuc
        Fp = jnp.dot(A.T, jnp.dot(F, A))
        Fp = Fp + shift
        eps, C2 = jnp.linalg.eigh(Fp)
        C = jnp.dot(A, C2)
        Cocc = C[:, :ndocc]
        D = jnp.dot(Cocc, Cocc.T)
        return E_scf, D, C, eps

    iteration = 0
    E_scf = 1.0
    E_old = 0.0
    Dold = jnp.zeros_like(D)
    dRMS = 1.0

    # Converge according to energy and DIIS residual to ensure eigenvalues and eigenvectors are maximally converged.
    # This is crucial for numerical stability for higher order derivatives of correlated methods.
    while (abs(E_scf - E_old) > convergence) or (dRMS > convergence):
        E_old = E_scf * 1
        if damping and iteration < 10:
            # Convex mix of previous and current density so the mixing
            # weights sum to one; identical to the former
            # `damp_factor * (Dold + D)` when damp_factor == 0.5.
            D = Dold * damp_factor + D * (1.0 - damp_factor)
            Dold = D * 1
        # Build JK matrix: 2 * J - K
        JK = 2 * jk_build(G, D)
        JK -= jk_build(G.transpose((0, 2, 1, 3)), D)
        # Build Fock
        F = H + JK
        # Update convergence error
        if iteration > 1:
            diis_e = jnp.einsum('ij,jk,kl->il', F, D, S) - jnp.einsum('ij,jk,kl->il', S, D, F)
            diis_e = A.dot(diis_e).dot(A)
            dRMS = jnp.mean(diis_e**2)**0.5
        # Compute energy, transform Fock and diagonalize, get new density
        E_scf, D, C, eps = rhf_iter(F, D)
        iteration += 1
        if iteration >= maxit:
            break
    print(iteration, " RHF iterations performed")

    # If many orbitals are degenerate, warn that higher order derivatives may be unstable
    tmp = jnp.round(eps, 6)
    ndegen_orbs = tmp.shape[0] - jnp.unique(tmp).shape[0]
    if (ndegen_orbs / nbf) > 0.20:
        print("Hartree-Fock warning: More than 20% of orbitals have degeneracies. Higher order derivatives may be unstable due to eigendecomposition AD rule")
    if not return_aux_data:
        return E_scf
    return E_scf, C, eps, G
Esempio n. 24
def fixed_pos_embedding(seq, dim):
    """Return ``(sin, cos)`` tables of fixed sinusoidal position angles.

    Each table has one row per position ``0..seq-1`` and one column per
    even index in ``range(0, dim, 2)``; the angle is
    ``position / 10000 ** (even_index / dim)``.
    """
    exponents = np.arange(0, dim, 2) / dim
    inv_freq = 1.0 / (10000**exponents)
    positions = np.arange(seq)

    # Outer product position x inverse frequency.
    angles = np.einsum("i , j -> i j", positions, inv_freq)

    sin_table = np.sin(angles)
    cos_table = np.cos(angles)
    return sin_table, cos_table
Esempio n. 25
 def f(scale):
   """Return -0.5 * ||chol(scale * psd_mat) @ vec||^2.

   ``psd_mat`` and ``vec`` come from the enclosing scope (not visible here).
   """
   factor = jnp.linalg.cholesky(scale * psd_mat)
   projected = jnp.einsum('ij,j->i', factor, vec)
   return -0.5 * jnp.sum(projected**2)
# NOTE(review): A, C, x0, T, nsamples and dt are used below but are not
# defined in this excerpt; presumably they are set earlier in the script.

# 2x2 covariance passed as the third filter argument — presumably the
# state (process) noise; confirm against lds.ContinuousKalmanFilter.
Qt = jnp.eye(2) * 0.001
# Observed noise
Rt = jnp.eye(2) * 0.01

# NOTE(review): mu0 is assigned here but the visible code passes x0 to the
# filter and sampler instead — confirm which initial mean is intended.
mu0 = jnp.array([1, 1])
Sigma0 = jnp.eye(2)

# Build the filter, draw a sample trajectory with observations, then run
# the filter over those observations.
key = random.PRNGKey(314)
kf = lds.ContinuousKalmanFilter(A, C, Qt, Rt, x0, Sigma0)
sample_state, sample_obs, jump = kf.sample(key, x0, T, nsamples)
mu_hist, V_hist, *_ = kf.filter(sample_obs, jump, dt)

# Regular 2D grid (the [::-1] swaps the two coordinate arrays) and the
# linear map A applied at every grid point, used for the stream plot.
step = 0.1
vmin, vmax = -1.5, 1.5 + step
X = np.mgrid[-1:1.5:step, vmin:vmax:step][::-1]
X_dot = jnp.einsum("ij,jxy->ixy", A, X)

# Figure 1: sampled states and observations over the stream plot of A.
fig, ax = plt.subplots()
ax.plot(*sample_state.T, label="state space")
ax.scatter(*sample_obs.T, marker="+", c="tab:green", s=60, label="observations")
ax.scatter(*sample_state[0], c="black", zorder=3)  # mark the first state
field = ax.streamplot(*X, *X_dot, density=1.1, color="#ccccccaa")
ax.set_title("State Space")

# Figure 2: filtered means against the same observations.
fig, ax = plt.subplots()
ax.plot(*mu_hist.T, c="tab:orange", label="Filtered")
ax.scatter(*sample_obs.T, marker="+", s=60, c="tab:green", label="observations")
ax.scatter(*mu_hist[0], c="black", zorder=3)  # mark the first filtered mean
Esempio n. 27
# Reference implementation with explicit loops:
#   L[n, c] = sum over t, k, d of S[n, t, k] * W[k, d] * V[d, c]
# NOTE(review): N, C, D, K, T, S, W, V and L are not defined in this excerpt;
# presumably they are set up earlier in the script.
for n in range(N):
    for c in range(C):
        s = 0
        for d in range(D):
            for k in range(K):
                for t in range(T):
                    s += S[n, t, k] * W[k, d] * V[d, c]
        L[n, c] = s
assert np.allclose(L, np.einsum('ntk,kd,dc->nc', S, W, V))

# Same contraction, reusing a precomputed optimal contraction path.
path = np.einsum_path('ntk,kd,dc->nc', S, W, V, optimize='optimal')[0]
assert np.allclose(L, np.einsum('ntk,kd,dc->nc', S, W, V, optimize=path))

# Same check through the JAX implementation of einsum.
import jax.numpy as jnp
path = jnp.einsum_path('ntk,kd,dc->nc', S, W, V, optimize='optimal')[0]
assert np.allclose(L, jnp.einsum('ntk,kd,dc->nc', S, W, V, optimize=path))

# Use full student network from Koller and Friedman
# (named `subscripts` rather than `str` so the builtin is not shadowed)
subscripts = 'c,dc,gdi,si,lg,jls,hgj->'
K = 5
cptC = np.random.randn(K)
cptD = np.random.randn(K, K)
cptG = np.random.randn(K, K, K)
cptS = np.random.randn(K, K)
cptL = np.random.randn(K, K)
cptJ = np.random.randn(K, K, K)
cptH = np.random.randn(K, K, K)
cpts = [cptC, cptD, cptG, cptS, cptL, cptJ, cptH]
path_info = np.einsum_path(subscripts, *cpts, optimize='optimal')
print(path_info[0])  # ['einsum_path', (0, 1), (0, 5), (0, 4), (0, 3), (0, 2), (0, 1)]
 def body(p, qkv):
   """Accumulate the k/v outer product into ``p`` and contract ``q`` with it.

   ``p`` carries the running sum of ``k[..., m] * v[..., d]`` with shape
   ``[..., m, d]``; the second return value is ``sum_m q[..., m] * p[..., m, d]``.
   ``precision`` comes from the enclosing scope (not visible here).
   """
   (q, k, v) = qkv
   # Outer product over the feature axes m and d, added to the running sum.
   p += jnp.einsum('...m,...d->...md', k, v, precision=precision)
   # Contract the query features against the accumulated sum.
   X_slice = jnp.einsum('...m,...md->...d', q, p, precision=precision)
   return p, X_slice
Esempio n. 29
def dot_product_attention(query: Array,
                          key: Array,
                          value: Array,
                          bias: Optional[Array] = None,
                          broadcast_dropout: bool = True,
                          dropout_rng: Optional[PRNGKey] = None,
                          dropout_rate: float = 0.,
                          deterministic: bool = False,
                          dtype: Dtype = jnp.float32,
                          precision: Optional[lax.Precision] = None):
    """Computes dot-product attention given query, key, and value.

    This is the core function for applying attention based on
    https://arxiv.org/abs/1706.03762. It calculates the attention weights
    given query and key and combines the values using the attention weights.

    Note: query, key, value needn't have any batch dimensions.

    Args:
      query: queries for calculating attention with shape of
        `[batch..., q_length, num_heads, qk_depth_per_head]`.
      key: keys for calculating attention with shape of
        `[batch..., kv_length, num_heads, qk_depth_per_head]`.
      value: values to be used in attention with shape of
        `[batch..., kv_length, num_heads, v_depth_per_head]`.
      bias: bias for the attention weights. This should be broadcastable to
        the shape: `[batch..., num_heads, q_length, kv_length]`
        This can be used for incorporating causal masks, padding masks,
        proximity bias, etc.
      broadcast_dropout: bool: use a broadcasted dropout along batch dims.
      dropout_rng: JAX PRNGKey: to be used for dropout; needed whenever
        dropout is actually applied (`deterministic` False and
        `dropout_rate > 0`).
      dropout_rate: dropout rate
      deterministic: bool, deterministic or not (to apply dropout)
      dtype: the dtype of the computation (default: float32)
      precision: numerical precision of the computation see
        `jax.lax.Precision` for details.

    Returns:
      Output of shape `[batch..., q_length, num_heads, v_depth_per_head]`.
    """
    assert key.ndim == query.ndim == value.ndim, "q, k, v must have same rank."
    assert query.shape[:-3] == key.shape[:-3] == value.shape[:-3], (
        "q, k, v batch dims must match.")
    assert query.shape[-2] == key.shape[-2] == value.shape[-2], (
        "q, k, v num_heads must match.")
    assert key.shape[-3] == value.shape[-3], "k, v lengths must match."
    assert query.shape[-1] == key.shape[-1], "q, k depths must match."

    # calculate attention matrix
    depth = query.shape[-1]
    query = query / jnp.sqrt(depth).astype(dtype)
    # attn weight shape is (batch..., num_heads, q_length, kv_length)
    attn_weights = jnp.einsum('...qhd,...khd->...hqk', query, key,
                              precision=precision)

    # apply attention bias: masking, dropout, proximity bias, etc.
    if bias is not None:
        attn_weights = attn_weights + bias

    # normalize the attention weights
    attn_weights = jax.nn.softmax(attn_weights).astype(dtype)

    # apply attention dropout
    if not deterministic and dropout_rate > 0.:
        keep_prob = 1.0 - dropout_rate
        if broadcast_dropout:
            # dropout is broadcast across the batch + head dimensions
            dropout_shape = tuple([1] *
                                  (key.ndim - 2)) + attn_weights.shape[-2:]
            keep = random.bernoulli(dropout_rng, keep_prob, dropout_shape)
        else:
            keep = random.bernoulli(dropout_rng, keep_prob, attn_weights.shape)
        # rescale kept weights so the expected value is unchanged
        multiplier = (keep.astype(attn_weights.dtype) /
                      jnp.asarray(keep_prob, dtype=dtype))
        attn_weights = attn_weights * multiplier

    # return weighted sum over values for each query position
    return jnp.einsum('...hqk,...khd->...qhd', attn_weights, value,
                      precision=precision)
 def body(p, qk):
   """Add ``k`` to the running sum ``p`` and contract ``q`` with the result.

   Returns the updated ``p`` and ``sum_m q[..., m] * p[..., m]``.
   ``precision`` comes from the enclosing scope (not visible here).
   """
   query_slice, key_slice = qk
   p += key_slice
   contracted = jnp.einsum(
       '...m,...m->...', query_slice, p, precision=precision)
   return p, contracted