def reweight_cl(weights, ngals, cl_in):
    """Reweight angular power spectra by per-bin galaxy counts.

    For each probe, the bin weights are multiplied by the corresponding
    galaxy counts and normalised to sum to unity over bins; the resulting
    weights are applied to both sides of each (probe, probe) block of
    ``cl_in``. The upper-triangular blocks are flattened row by row and
    concatenated along the bin axis.
    """
    # assert len(weights) == len(ngals)
    nprobe = weights.shape[0]
    offset = 0
    w = [None] * nprobe
    nzbin = np.array([len(W) for W in weights])
    nout = np.sum(nzbin * (1 + np.arange(nprobe)))
    cl_out = [None] * nout
    for i1 in range(nprobe):
        nrow = len(weights[i1])
        rowstep = nprobe - i1
        for i2 in range(i1, nprobe):
            # assert weights[i2].shape[1] == len(ngals[i2])
            W = weights[i2] * ngals[i2]
            W /= jnp.sum(W, axis=1, keepdims=True)
            w[i2] = jnp.nan_to_num(W, posinf=1e30, neginf=-1e30, copy=False)
            cl = jnp.einsum('ip,spqk,jq->sijk', w[i1], cl_in[i2][i1], w[i2])
            for j in range(nrow):
                start = j if i1 == i2 else 0
                cl_out[offset + j * rowstep + i2 - i1] = cl[:, j, start:]
        offset += nrow * rowstep
    return jnp.concatenate(cl_out, axis=1)
def extract_outputs_and_targets( model, padded_example_and_rng, target_edge_index, num_edge_types, ): """Extract model outputs and targets for an example. Args: model: Model to run on the example. padded_example_and_rng: Example to extract targets from, with RNG. target_edge_index: Index of the target edge type. num_edge_types: How many edge types there are. Returns: Tuple (output_logits, targets, valid_mask, num_nodes, captured) """ padded_example, rng = padded_example_and_rng # Run the model. with side_outputs.collect_side_outputs() as captured: with flax.nn.stochastic(rng): output_logits = model(padded_example) # Extract targets. targets = padded_example.edges.apply_add(in_array=( jnp.arange(num_edge_types) == target_edge_index).astype("int32"), out_array=jnp.zeros( output_logits.shape, dtype="int32")).astype("bool") targets = preprocess_targets(targets) # Compute valid mask for outputs and targets. max_num_nodes = output_logits.shape[0] num_nodes = padded_example.graph_metadata.num_nodes valid_nodes = jnp.arange(max_num_nodes) < num_nodes valid_nodes_float = valid_nodes.astype("float32") valid_mask = jnp.einsum("i,j->ij", valid_nodes_float, valid_nodes_float) return output_logits, targets, valid_mask, num_nodes, captured
def __call__(self, inputs): """Applies layer to input. Args: inputs: jnp.ndarray of shape [ens_size * batch_size, ..., input_dim]. Returns: jnp.ndarray of shape [ens_size * batch_size, ..., features]. """ dtype = self.dtype or inputs.dtype inputs = jnp.asarray(inputs, dtype) input_dim = inputs.shape[-1] kernel = self.param('kernel', self.kernel_init, (input_dim, self.features), dtype) alpha = self.param('fast_weight_alpha', self.alpha_init, (self.ens_size, input_dim), dtype) gamma = self.param('fast_weight_gamma', self.gamma_init, (self.ens_size, self.features), dtype) inputs_shape = inputs.shape inputs = jnp.reshape(inputs, (self.ens_size, -1) + inputs_shape[1:]) outputs = jnp.einsum('E...C,EC,CD,ED->E...D', inputs, alpha, kernel, gamma) if self.use_ensemble_bias: bias = self.param('bias', self.bias_init, (self.ens_size, self.features), dtype) bias_shape = (self.ens_size, ) + (1, ) * (outputs.ndim - 2) + ( self.features, ) outputs = outputs + jnp.reshape(bias, bias_shape) if self.activation is not None: outputs = self.activation(outputs) # pylint: disable=not-callable return jnp.reshape(outputs, inputs_shape[:-1] + (self.features, ))
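# Standalone sanity check (not part of the layer above; shapes and values are
# hypothetical): the fused einsum 'E...C,EC,CD,ED->E...D' should match scaling
# the inputs by alpha, applying the shared kernel, then scaling by gamma.
import jax
import jax.numpy as jnp

ens_size, batch, input_dim, features = 4, 2, 3, 5
k1, k2, k3, k4 = jax.random.split(jax.random.PRNGKey(0), 4)
x = jax.random.normal(k1, (ens_size, batch, input_dim))
alpha = jax.random.normal(k2, (ens_size, input_dim))
kernel = jax.random.normal(k3, (input_dim, features))
gamma = jax.random.normal(k4, (ens_size, features))

fused = jnp.einsum('E...C,EC,CD,ED->E...D', x, alpha, kernel, gamma)
stepwise = ((x * alpha[:, None, :]) @ kernel) * gamma[:, None, :]
assert jnp.allclose(fused, stepwise, atol=1e-5)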
def test_einsum_kpmurphy_example(self): # code from an email with @murphyk N, C, D, K, T = 2, 3, 4, 5, 6 r = self.rng() S = r.randn(N, T, K) W = r.randn(K, D) V = r.randn(D, C) L = np.zeros((N, C)) for n in range(N): for c in range(C): s = 0 for d in range(D): for k in range(K): for t in range(T): s += S[n, t, k] * W[k, d] * V[d, c] L[n, c] = s path = jnp.einsum_path('ntk,kd,dc->nc', S, W, V, optimize='optimal')[0] rtol = 1e-2 if jtu.device_under_test() == "tpu" else None self.assertAllClose(L, jnp.einsum('ntk,kd,dc->nc', S, W, V, optimize=path), check_dtypes=False, rtol=rtol)
def __filter_step(self, state, obs_t):
    nsamples = self.nsamples
    indices = jnp.arange(nsamples)
    zt_rvs, key_t = state
    key_t, key_reindex, key_next = random.split(key_t, 3)

    # 1. Draw new points from the dynamic model
    zt_rvs = random.multivariate_normal(key_t, self.fz(zt_rvs), self.Q(zt_rvs))

    # 2. Calculate unnormalised weights
    xt_rvs = self.fx(zt_rvs)
    weights_t = stats.multivariate_normal.pdf(obs_t, xt_rvs, self.R(zt_rvs, obs_t))

    # 3. Resampling
    pi = random.choice(key_reindex, indices, p=weights_t, shape=(nsamples,))
    zt_rvs = zt_rvs[pi, ...]
    weights_t = jnp.ones(nsamples) / nsamples

    # 4. Compute the latent-state estimate as the weighted sample mean
    #    (weights are uniform after the resampling step)
    mu_t = jnp.einsum("im,i->m", zt_rvs, weights_t)

    return (zt_rvs, key_next), mu_t
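# Minimal standalone sketch (hypothetical values): step 4 above estimates the
# latent state as a weighted average over particles; with the uniform
# post-resampling weights this einsum reduces to a plain sample mean.
import jax
import jax.numpy as jnp

particles = jax.random.normal(jax.random.PRNGKey(0), (100, 3))  # (nsamples, dim)
weights = jnp.ones(100) / 100
mu = jnp.einsum("im,i->m", particles, weights)
assert jnp.allclose(mu, particles.mean(axis=0), atol=1e-6)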
def __call__(self, inputs: jnp.ndarray) -> jnp.ndarray: """Connects Module. Args: inputs: Tensor of shape [..., num_channel] Returns: output of shape [..., num_output] """ n_channels = int(inputs.shape[-1]) weight_shape = [n_channels, self.num_output] if self.initializer == 'linear': weight_init = hk.initializers.VarianceScaling(mode='fan_in', scale=1.) elif self.initializer == 'relu': weight_init = hk.initializers.VarianceScaling(mode='fan_in', scale=2.) elif self.initializer == 'zeros': weight_init = hk.initializers.Constant(0.0) weights = hk.get_parameter('weights', weight_shape, inputs.dtype, weight_init) # this is equivalent to einsum('...c,cd->...d', inputs, weights) # but turns out to be slightly faster inputs = jnp.swapaxes(inputs, -1, -2) output = jnp.einsum('...cb,cd->...db', inputs, weights) output = jnp.swapaxes(output, -1, -2) if self.use_bias: bias = hk.get_parameter('bias', [self.num_output], inputs.dtype, hk.initializers.Constant(self.bias_init)) output += bias return output
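# Standalone check (hypothetical shapes) of the equivalence claimed in the
# comment above: swapping the last two axes and contracting with
# '...cb,cd->...db' matches the direct '...c,cd->...d' contraction.
import jax
import jax.numpy as jnp

x = jax.random.normal(jax.random.PRNGKey(0), (2, 7, 11))  # [..., num_channel]
w = jax.random.normal(jax.random.PRNGKey(1), (11, 5))     # [num_channel, num_output]

direct = jnp.einsum('...c,cd->...d', x, w)
swapped = jnp.swapaxes(
    jnp.einsum('...cb,cd->...db', jnp.swapaxes(x, -1, -2), w), -1, -2)
assert jnp.allclose(direct, swapped, atol=1e-5)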
def __call__(self, row, col, features): attn = hk.Linear(self.num_heads * self.attention_size)(features) attn = jnp.reshape(attn, (-1, self.num_heads, self.attention_size)) attn = jax.nn.sigmoid(jnp.einsum("eha,eha->eh", attn[row], attn[col])) epsilon = jax.nn.sigmoid( hk.get_parameter( "epsilon_sig_inv", shape=(self.num_heads, ), dtype=features.dtype, init=lambda shape, dtype: jnp.full(shape, -2.0, dtype=dtype), )) features = hk.Linear(self.num_heads * self.out_size)(features) features = jnp.reshape(features, (-1, self.num_heads, self.out_size)) def propagate(features, attn, eps, row, col): adj = COO((attn, row, col), shape=(features.shape[0], ) * 2) propagator = _get_propagator(adj, eps, self.tol) return propagator @ features # shifted_lap = _get_shifted_laplacian(adj, eps) # return jax.scipy.sparse.linalg.cg(lambda x: shifted_lap @ x, features) # out_ax = 0 # out = jax.vmap(propagate, in_axes=(1, 1, 0, None, None), out_axes=out_ax)( # features, attn, epsilon, row, col # ) def unstack(x, axis): return [x.take(i, axis=axis) for i in range(x.shape[axis])] features = unstack(features, axis=1) attn = unstack(attn, axis=1) eps = unstack(epsilon, axis=0) out = [ propagate(f, a, e, row, col) for (f, a, e) in zip(features, attn, eps) ] return sum(out) / len(out)
def __call__( self, hidden_states, attention_mask=None, deterministic: bool = True, output_attentions: bool = False, ): query = self.q_proj(hidden_states) key = self.k_proj(hidden_states) value = self.v_proj(hidden_states) query = self._split_heads(query) key = self._split_heads(key) value = self._split_heads(value) causal_attention_mask = None if self.causal: query_length, key_length = query.shape[1], key.shape[1] causal_attention_mask = self.causal_mask[:, :, key_length - query_length: key_length, :key_length] if attention_mask is not None and causal_attention_mask is not None: attention_mask = jnp.expand_dims(attention_mask, axis=(-3, -2)) attention_mask = combine_masks(attention_mask, causal_attention_mask, dtype="i4") elif causal_attention_mask is not None: attention_mask = causal_attention_mask elif attention_mask is not None: attention_mask = jnp.expand_dims(attention_mask, axis=(-3, -2)) if attention_mask is not None: attention_bias = lax.select( attention_mask > 0, jnp.full(attention_mask.shape, 0.0).astype(self.dtype), jnp.full(attention_mask.shape, -1e4).astype(self.dtype), ) else: attention_bias = None dropout_rng = None if not deterministic and self.dropout > 0.0: dropout_rng = self.make_rng("dropout") attn_weights = dot_product_attention_weights( query, key, bias=attention_bias, dropout_rng=dropout_rng, dropout_rate=self.dropout, deterministic=deterministic, dtype=self.dtype, precision=None, ) attn_output = jnp.einsum("...hqk,...khd->...qhd", attn_weights, value) attn_output = self._merge_heads(attn_output) attn_output = self.out_proj(attn_output) outputs = (attn_output, attn_weights) if output_attentions else (attn_output, ) return outputs
def __call__( self, hidden_states, attention_mask, position_ids, deterministic: bool = True, init_cache: bool = False, output_attentions: bool = False, ): query = self.q_proj(hidden_states) key = self.k_proj(hidden_states) value = self.v_proj(hidden_states) query = self._split_heads(query) key = self._split_heads(key) value = self._split_heads(value) sincos = jnp.take(self.embed_positions, position_ids, axis=0) sincos = jnp.split(sincos, 2, axis=-1) if self.rotary_dim is not None: k_rot = key[:, :, :, :self.rotary_dim] k_pass = key[:, :, :, self.rotary_dim:] q_rot = query[:, :, :, :self.rotary_dim] q_pass = query[:, :, :, self.rotary_dim:] k_rot = apply_rotary_pos_emb(k_rot, sincos) q_rot = apply_rotary_pos_emb(q_rot, sincos) key = jnp.concatenate([k_rot, k_pass], axis=-1) query = jnp.concatenate([q_rot, q_pass], axis=-1) else: key = apply_rotary_pos_emb(key, sincos) query = apply_rotary_pos_emb(query, sincos) query_length, key_length = query.shape[1], key.shape[1] if self.has_variable("cache", "cached_key"): mask_shift = self.variables["cache"]["cache_index"] max_decoder_length = self.variables["cache"]["cached_key"].shape[1] causal_mask = lax.dynamic_slice( self.causal_mask, (0, 0, mask_shift, 0), (1, 1, query_length, max_decoder_length)) else: causal_mask = self.causal_mask[:, :, :query_length, :key_length] batch_size = hidden_states.shape[0] causal_mask = jnp.broadcast_to(causal_mask, (batch_size, ) + causal_mask.shape[1:]) attention_mask = jnp.broadcast_to( jnp.expand_dims(attention_mask, axis=(-3, -2)), causal_mask.shape) attention_mask = combine_masks(attention_mask, causal_mask) dropout_rng = None if not deterministic and self.config.attn_pdrop > 0.0: dropout_rng = self.make_rng("dropout") # During fast autoregressive decoding, we feed one position at a time, # and cache the keys and values step by step. if self.has_variable("cache", "cached_key") or init_cache: key, value, attention_mask = self._concatenate_to_cache( key, value, query, attention_mask) # transform boolean mask into float mask attention_bias = lax.select( attention_mask > 0, jnp.full(attention_mask.shape, 0.0).astype(self.dtype), jnp.full(attention_mask.shape, -1e9).astype(self.dtype), ) # usual dot product attention attn_weights = dot_product_attention_weights( query, key, bias=attention_bias, dropout_rng=dropout_rng, dropout_rate=self.config.attn_pdrop, deterministic=deterministic, dtype=self.dtype, precision=None, ) attn_output = jnp.einsum("...hqk,...khd->...qhd", attn_weights, value) attn_output = self._merge_heads(attn_output) attn_output = self.out_proj(attn_output) attn_output = self.resid_dropout(attn_output, deterministic=deterministic) outputs = (attn_output, attn_weights) if output_attentions else (attn_output, ) return outputs
def _check(self, s, *ops): a = onp.einsum(s, *ops) b = np.einsum(s, *ops) self.assertAllClose(a, b, atol=1e-4, rtol=1e-4, check_dtypes=True)
def bulk_modulus(elastic_tensor): return np.einsum('iijj->', elastic_tensor) / elastic_tensor.shape[0] ** 2
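# Sanity check with a hypothetical isotropic elastic tensor: for
# C_ijkl = lam*d_ij*d_kl + mu*(d_ik*d_jl + d_il*d_jk), the 'iijj' contraction
# divided by 9 recovers the textbook bulk modulus K = lam + 2*mu/3.
import numpy as np

lam, mu = 2.0, 1.5
d = np.eye(3)
C = (lam * np.einsum('ij,kl->ijkl', d, d)
     + mu * (np.einsum('ik,jl->ijkl', d, d) + np.einsum('il,jk->ijkl', d, d)))
assert np.isclose(bulk_modulus(C), lam + 2 * mu / 3)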
def cosine_similarity(a: Array, b: Array) -> Array: """Computes batched cosine similarity between two 2D arrays.""" a_norm = jnp.linalg.norm(a, axis=-1) b_norm = jnp.linalg.norm(b, axis=-1) dot = jnp.einsum('bd,bd->b', a, b) return dot / (_SMALL + a_norm * b_norm)
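# Usage sketch (hypothetical inputs, assuming cosine_similarity above is in
# scope): batched similarity cross-checked against a per-row computation;
# agreement is up to the _SMALL stabiliser in the denominator.
import jax
import jax.numpy as jnp

a = jax.random.normal(jax.random.PRNGKey(0), (4, 8))
b = jax.random.normal(jax.random.PRNGKey(1), (4, 8))
sims = cosine_similarity(a, b)  # shape (4,)
ref = jnp.stack([jnp.dot(u, v) / (jnp.linalg.norm(u) * jnp.linalg.norm(v))
                 for u, v in zip(a, b)])
assert jnp.allclose(sims, ref, atol=1e-3)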
init_state = (mu_t, tau_t) xs = (Phi, y) adf_loop = partial(adf_step, q=q, lbound=lbound, ubound=ubound) (mu_t, tau_t), (mu_t_hist, tau_t_hist) = jax.lax.scan(adf_loop, init_state, xs) # ** Estimating posterior predictive distribution ** xmin, ymin = X.min(axis=0) - 0.1 xmax, ymax = X.max(axis=0) + 0.1 step = 0.1 Xspace = jnp.mgrid[xmin:xmax:step, ymin:ymax:step] _, nx, ny = Xspace.shape Phispace = jnp.concatenate([jnp.ones((1, nx, ny)), Xspace]) # MCMC posterior predictive distribution Z_mcmc = sigmoid(jnp.einsum("mij,sm->sij", Phispace, chains)) Z_mcmc = Z_mcmc.mean(axis=0) # Laplace posterior predictive distribution key = random.PRNGKey(314) laplace_samples = random.multivariate_normal(key, w_map, SN, (n_samples, )) Z_laplace = sigmoid(jnp.einsum("mij,sm->sij", Phispace, laplace_samples)) Z_laplace = Z_laplace.mean(axis=0) # ADF posterior predictive distribution adf_samples = random.multivariate_normal(key, mu_t, jnp.diag(tau_t), (n_samples, )) Z_adf = sigmoid(jnp.einsum("mij,sm->sij", Phispace, adf_samples)) Z_adf = Z_adf.mean(axis=0) # ** Plotting predictive distribution ** plt.rcParams["axes.spines.right"] = False plt.rcParams["axes.spines.top"] = False
def __call__(self, input_qkv):
    cfg = self.config
    assert cfg.max_len % cfg.max_seg_len == 0
    bsize = input_qkv.shape[0]
    features = self.out_features or input_qkv.shape[-1]
    query, key, value, head_dim = get_qkv(cfg, input_qkv)
    num_seg = cfg.max_len // cfg.max_seg_len
    cur_query = query.reshape(
        [-1, cfg.max_seg_len, query.shape[-2], query.shape[-1]])
    merged_query = jnp.max(cur_query, axis=1, keepdims=True) * jnp.sqrt(head_dim)
    cur_key = key.reshape([-1, cfg.max_seg_len, key.shape[-2], key.shape[-1]])
    cur_value = value.reshape(
        [-1, cfg.max_seg_len, value.shape[-2], value.shape[-1]])
    dropout_rng = None
    if not cfg.deterministic and cfg.attention_dropout_rate > 0.:
        dropout_rng = self.make_rng('dropout')
    s = dot_product_attention(merged_query,
                              cur_key,
                              cur_value,
                              dropout_rng=dropout_rng,
                              dropout_rate=cfg.attention_dropout_rate,
                              broadcast_dropout=False,
                              deterministic=cfg.deterministic,
                              dtype=cfg.dtype)
    span_val = jnp.reshape(s, [bsize, -1, s.shape[-2], s.shape[-1]])
    span_key = jnp.max(cur_key, axis=1, keepdims=True)
    # (bsize, n_seg, n_head, dim_per_head)
    span_key = jnp.reshape(span_key,
                           [bsize, -1, span_key.shape[-2], span_key.shape[-1]])
    local_mask = make_causal_mask(cur_query, length_axis=1).transpose([0, 2, 1, 3])
    local_bias = lax.select(local_mask > 0,
                            jnp.full(local_mask.shape, 0.).astype(cfg.dtype),
                            jnp.full(local_mask.shape, -1e10).astype(cfg.dtype))
    # (bsize * n_seg, seg_len, n_head, seg_len)
    local_logits = jnp.einsum('...qhd,...khd->...qhk', cur_query, cur_key) + local_bias
    local_logits = jnp.reshape(local_logits,
                               [bsize, -1, cfg.num_heads, cfg.max_seg_len])
    idx = jnp.broadcast_to(jnp.arange(span_key.shape[1], dtype=jnp.int32),
                           span_key.shape[:2])
    prev_mask = nn.make_attention_mask(idx,
                                       idx,
                                       jnp.greater,
                                       extra_batch_dims=0,
                                       dtype=jnp.float32).transpose([0, 2, 1, 3])
    prev_mask = jnp.repeat(prev_mask, cfg.max_seg_len, axis=-3)
    prev_bias = lax.select(prev_mask > 0,
                           jnp.full(prev_mask.shape, 0.).astype(cfg.dtype),
                           jnp.full(prev_mask.shape, -1e10).astype(cfg.dtype))
    # (bsize, max_len, n_head, num_segs)
    prev_logits = jnp.einsum('...qhd,...khd->...qhk', query, span_key) + prev_bias
    joint_logits = jnp.concatenate((local_logits, prev_logits), axis=-1)
    # (bsize x max_len, n_head, seg_len + num_segs)
    attn_weights = jax.nn.softmax(joint_logits).astype(cfg.dtype)
    local_att, prev_att = jnp.split(attn_weights, [cfg.max_seg_len], axis=-1)
    local_att = local_att.reshape(
        [bsize * num_seg, cfg.max_seg_len, cfg.num_heads, cfg.max_seg_len])
    local_merged = jnp.einsum('...qhk,...khd->...qhd', local_att, cur_value)
    prev_merged = jnp.einsum('...qhk,...khd->...qhd', prev_att, span_val)
    joint_merged = jnp.reshape(local_merged, prev_merged.shape) + prev_merged
    x = DenseGeneral(features=features,
                     axis=(-2, -1),
                     kernel_init=cfg.kernel_init,
                     bias_init=cfg.bias_init,
                     use_bias=False,
                     dtype=cfg.dtype)(joint_merged)
    return x
def __call__(self, input_qkv): cfg = self.config log_len = log_2_ceil(cfg.max_len - 1) bsize = input_qkv.shape[0] features = self.out_features or input_qkv.shape[-1] query, key, value, head_dim = get_qkv(cfg, input_qkv) joint_logits = [] list_vals = [] for l in range(log_len): ctx_len = 2**l last_pos = cfg.max_len - cfg.max_len % ctx_len num_ctx = cfg.max_len // ctx_len if l == 0: span_key = jnp.reshape(key, [-1, 1, cfg.num_heads, head_dim]) span_val = value.reshape(span_key.shape) self_logits = jnp.expand_dims(jnp.sum(query * key, axis=-1), -1) joint_logits.append(self_logits) else: left_query = query[:, :last_pos, :, :].reshape( [-1, ctx_len, cfg.num_heads, head_dim]) span_query = jnp.max(left_query, axis=1, keepdims=True) left_key = key[:, :last_pos, :, :].reshape(left_query.shape) left_val = value[:, :last_pos, :, :].reshape(left_query.shape) span_val = dot_product_attention( span_query * jnp.sqrt(head_dim), left_key, left_val, dropout_rng=self.get_dropout_png(cfg), dropout_rate=cfg.attention_dropout_rate, broadcast_dropout=False, deterministic=cfg.deterministic, dtype=cfg.dtype) span_key = jnp.max(left_key, axis=1, keepdims=True) rolled_q = jnp.roll(query, -ctx_len, axis=1)[:, :last_pos, :, :].reshape( [-1, ctx_len, cfg.num_heads, head_dim]) rolled_mask = jnp.concatenate( [(jnp.arange(cfg.max_len - ctx_len) // ctx_len) % 2, jnp.ones(last_pos + ctx_len - cfg.max_len, dtype=jnp.int32)], axis=0) rolled_mask = jnp.reshape(rolled_mask, [1, -1, 1, 1]) rolled_logits = jnp.einsum('...qhd,...khd->...qhk', rolled_q, span_key) # bsize, last_pos, h, 1 rolled_logits = jnp.reshape( rolled_logits, [bsize, -1, cfg.num_heads, 1 ]) + rolled_mask.astype(rolled_q.dtype) * -1e9 orig_logits = jnp.pad(rolled_logits, [(0, 0), (0, cfg.max_len - last_pos), (0, 0), (0, 0)], constant_values=-1e9) orig_logits = jnp.roll(orig_logits, ctx_len, axis=1) joint_logits.append(orig_logits) list_vals.append(span_val) joint_logits = jnp.concatenate(joint_logits, axis=-1) attn_weights = jax.nn.softmax(joint_logits).astype(cfg.dtype) local_weights = jnp.split(attn_weights, log_len + 1, axis=-1) local_weighted_sums = [] joint_merged = local_weights[0] * value for l in range(log_len): ctx_len = 2**l last_pos = cfg.max_len - cfg.max_len % ctx_len num_ctx = cfg.max_len // ctx_len rolled_w = jnp.roll(local_weights[l + 1], -ctx_len, axis=1)[:, :last_pos, :, :].reshape( bsize * num_ctx, ctx_len, cfg.num_heads, 1) rolled_v = jnp.reshape(rolled_w * list_vals[l], [bsize, -1, cfg.num_heads, head_dim]) rolled_v = jnp.pad(rolled_v, [(0, 0), (0, cfg.max_len - last_pos), (0, 0), (0, 0)]) orig_v = jnp.roll(rolled_v, ctx_len, axis=1) joint_merged = joint_merged + orig_v x = DenseGeneral(features=features, axis=(-2, -1), kernel_init=cfg.kernel_init, bias_init=cfg.bias_init, use_bias=False, dtype=cfg.dtype)(joint_merged) return x
def roll_out_transitions(builder, transition_matrix, variant_weights,
                         start_machine_state, node_index, steps, rng,
                         max_possible_transitions):
    """Roll out transitions for a transition matrix.

    Args:
      builder: The automaton builder associated with this Markov chain.
      transition_matrix: The per-variant transition matrix for this Markov
        chain.
      variant_weights: Weights to assign to each routing variant for each
        node, as a <float[num_nodes, num_variants]> array (which should sum
        to 1 across the last axis).
      start_machine_state: Initial machine state distribution for the
        starting nodes, as a <float[num_fsm_states]> array (which should sum
        to 1).
      node_index: Initial node index where the solve should start, as an
        int32 scalar.
      steps: How many steps to unroll.
      rng: Random number generator used to sample.
      max_possible_transitions: Max transitions to truncate to.

    Returns:
      Final RollOutState.
    """
    (variants, nodes, fsm_states, in_tagged_nodes,
     _) = transition_matrix.initial_to_in_tagged.shape
    itn_states = in_tagged_nodes * fsm_states

    def add_special_dests(in_tagged_dests):
        return jnp.concatenate(
            [in_tagged_dests, jnp.arange(itn_states, itn_states + 3)])

    # Collapse our transition matrix into a more useful form:
    # - Combine nodes and FSM states
    # - Extract a subset of possible destinations
    # - Combine special and move actions into a single probability table
    k = min(max_possible_transitions, in_tagged_nodes) * fsm_states
    initial_to_in_tagged_combo = transition_matrix.initial_to_in_tagged.reshape(
        (variants, nodes, fsm_states, itn_states))
    _, initial_to_in_tagged_dests = jax.lax.top_k(
        jnp.sum(initial_to_in_tagged_combo, axis=(0, 2)), k=k)
    initial_to_in_tagged_probs = jax.vmap(
        lambda c, i: c[:, :, i], in_axes=(1, 0),
        out_axes=1)(initial_to_in_tagged_combo, initial_to_in_tagged_dests)
    initial_probs = jnp.concatenate(
        [initial_to_in_tagged_probs, transition_matrix.initial_to_special], -1)
    initial_dests = jax.vmap(add_special_dests)(initial_to_in_tagged_dests)
    in_tagged_to_in_tagged_combo = transition_matrix.in_tagged_to_in_tagged.reshape(
        (variants, itn_states, itn_states))
    _, in_tagged_to_in_tagged_dests = jax.lax.top_k(
        jnp.sum(in_tagged_to_in_tagged_combo, axis=0), k=k)
    in_tagged_to_in_tagged_probs = jax.vmap(
        lambda c, i: c[:, i], in_axes=(1, 0),
        out_axes=1)(in_tagged_to_in_tagged_combo, in_tagged_to_in_tagged_dests)
    in_tagged_probs = jnp.concatenate([
        in_tagged_to_in_tagged_probs,
        transition_matrix.in_tagged_to_special.reshape((variants, itn_states, 3))
    ], -1)
    in_tagged_dests = jax.vmap(add_special_dests)(in_tagged_to_in_tagged_dests)
    start_node_variant_weights = variant_weights[node_index]
    initial_probs_from_here = jnp.einsum("v,s,vsj->j",
                                         start_node_variant_weights,
                                         start_machine_state,
                                         initial_probs[:, node_index, :, :])
    initial_dests_from_here = initial_dests[node_index]
    per_itn_variants = variant_weights[transition_matrix.in_tagged_node_indices]

    # Set up the initial state
    initial_state = RollOutState(rng=rng)

    @jax.remat
    def scan_body(state, ignored_input):
        assert ignored_input is None
        rng, key = jax.random.split(state.rng)

        # If we are in the initial state, sample the initial transition.
        def at_initial_info():
            return initial_probs_from_here, initial_dests_from_here

        # If we are in a normal state, sample the next action.
def at_normal_info(): cur_variant_weights = per_itn_variants[state.itn_state_index // fsm_states] next_step_probs = jnp.einsum("v,vj->j", cur_variant_weights, in_tagged_probs[:, state.itn_state_index, :]) return next_step_probs, in_tagged_dests[state.itn_state_index, :] # Figure out which to do, and sample from the appropriate probabilities step_probs, step_dests = jax.tree_multimap( functools.partial(jnp.where, state.at_initial), at_initial_info(), at_normal_info()) next_idx = jax.random.categorical(key, jnp.log(step_probs)) log_prob = jnp.log(step_probs[next_idx]) dest = step_dests[next_idx] did_special = dest >= itn_states state_after_move = RollOutState( at_initial=False, succeeded=False, failed=False, special_from_initial=False, itn_state_index=dest, final_node=None, log_prob=state.log_prob + log_prob, rng=rng) special_idx = dest - itn_states state_after_special = RollOutState( at_initial=(special_idx == builder.special_actions.index( automaton_builder.SpecialActions.BACKTRACK)), succeeded=(special_idx == builder.special_actions.index( automaton_builder.SpecialActions.FINISH)), failed=(special_idx == builder.special_actions.index( automaton_builder.SpecialActions.FAIL)), special_from_initial=state.at_initial, itn_state_index=state.itn_state_index, final_node=None, log_prob=state.log_prob + log_prob, rng=rng) # Choose the right branch to take def choose(move, special, done): return jnp.where(state.succeeded | state.failed, done, jnp.where(did_special, special, move)) new_state = jax.tree_multimap(choose, state_after_move, state_after_special, state) return new_state, None final_state, _ = jax.lax.scan(scan_body, initial_state, None, length=steps) final_node = jnp.where( final_state.special_from_initial, node_index, transition_matrix.in_tagged_node_indices[final_state.itn_state_index // fsm_states]) final_state = dataclasses.replace(final_state, final_node=final_node) return final_state
def cell_fn(theta, state0, inputs_t): del state0 y = jnp.einsum('x,xy->y', inputs_t.x, theta.proj) return NestedMap(y=y)
def batched_dot(a, b): if a.shape[0] != b.shape[0]: raise TypeError("Shapes must match in the 0-th dimension") if a.ndim == 2 or b.ndim == 2: return jnp.einsum("n...j,nj...->n...", a, b) return jnp.einsum("nij,njk->nik", a, b)
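# Usage sketch (hypothetical shapes, assuming batched_dot above is in scope):
# the einsum dispatch handles per-batch mat @ vec and mat @ mat alike.
import jax

A = jax.random.normal(jax.random.PRNGKey(0), (8, 3, 4))
x = jax.random.normal(jax.random.PRNGKey(1), (8, 4))
B = jax.random.normal(jax.random.PRNGKey(2), (8, 4, 5))

assert batched_dot(A, x).shape == (8, 3)     # batched matrix-vector product
assert batched_dot(A, B).shape == (8, 3, 5)  # batched matrix-matrix product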
poly_axes=[0, 0]), _make_harness("dynamic_slice", "", # x:shape: (b, 4) lambda x: lax.dynamic_slice(x, (0, 1), (x.shape[0], 2)), [RandArg((3, 4), _f32)], poly_axes=[0]), _make_harness("dynamic_update_slice", "", # x:shape: (b, 4) lambda x: lax.dynamic_update_slice(x, x, (0, 0)), [RandArg((3, 4), _f32)], poly_axes=[0]), _make_harness("einsum", "0", lambda x: jnp.einsum("...i->...", x), [RandArg((3, 4), _f32)], poly_axes=[0]), _make_harness("einsum", "1", lambda x, y: jnp.einsum("...ij,...jk->...ik", x, y), [RandArg((3, 4, 5), _f32), RandArg((3, 5, 6), _f32)], poly_axes=[0, 0]), _make_harness("einsum", "2", lambda x, y: jnp.einsum("...ij,jk->...ik", x, y), [RandArg((3, 4, 5), _f32), RandArg((5, 6), _f32)], poly_axes=[0, None]), _make_harness("einsum", "3", # Reduced dimension is polymorphic
def oei_arrays(geom, basis, charges):
    """Build one electron integral arrays (overlap, kinetic, and potential integrals)."""
    coeffs, exps, atoms, ams, indices, dims = flatten_basis_data(basis)
    nbf = get_nbf(basis)
    nprim = coeffs.shape[0]
    max_am = jnp.max(ams)
    A_vals = jnp.zeros(2 * max_am + 1)  # Save various AM distributions for indexing

    # Obtain all possible primitive duet index combinations
    primitive_duets = cartesian_product(jnp.arange(nprim), jnp.arange(nprim))

    with loops.Scope() as s:
        s.oei = jnp.zeros((3, nbf, nbf))
        s.a = 0  # center A angular momentum iterator
        s.b = 0  # center B angular momentum iterator
        for prim_duet in s.range(primitive_duets.shape[0]):
            p1, p2 = primitive_duets[prim_duet]
            coef = coeffs[p1] * coeffs[p2]
            aa, bb = exps[p1], exps[p2]
            atom1, atom2 = atoms[p1], atoms[p2]
            am1, am2 = ams[p1], ams[p2]
            A, B = geom[atom1], geom[atom2]
            ld1, ld2 = am_leading_indices[am1], am_leading_indices[am2]

            gamma = aa + bb
            prefactor = jnp.exp(-aa * bb * jnp.dot(A - B, A - B) / gamma)
            P = (aa * A + bb * B) / gamma
            # Maximum angular momentum: hard coded
            #max_am = 3 # f function support
            # Precompute all powers up to 2+max_am of Pi-Ai, Pi-Bi.
            # We need 2+max_am since kinetic requires incrementing angular momentum by +2
            PA_pow = jnp.power(
                jnp.broadcast_to(P - A, (max_am + 3, 3)).T, jnp.arange(max_am + 3))
            PB_pow = jnp.power(
                jnp.broadcast_to(P - B, (max_am + 3, 3)).T, jnp.arange(max_am + 3))

            # For potential integrals, we need the difference between
            # the gaussian product center P and ALL atoms in the molecule,
            # and then take all possible powers up to 2*max_am.
            # We pre-collect this into a 3d array, and then just pull out what
            # we need via indexing in the loops, so they need not be recomputed.
            # The resulting array has dimensions (atom, cartesian component,
            # power), so index (0, 1, 3) would return (Py - atom0_y)^3
            P_minus_geom = jnp.broadcast_to(P, geom.shape) - geom
            Pgeom_pow = jnp.power(
                jnp.transpose(
                    jnp.broadcast_to(
                        P_minus_geom,
                        (2 * max_am + 1, geom.shape[0], geom.shape[1])),
                    (1, 2, 0)), jnp.arange(2 * max_am + 1))
            # All possible jnp.dot(P-atom,P-atom)
            rcp2 = jnp.einsum('ij,ij->i', P_minus_geom, P_minus_geom)
            # All needed (and unneeded, for am < max_am) boys function evaluations
            boys_arg = jnp.broadcast_to(rcp2 * gamma, (2 * max_am + 1, geom.shape[0]))
            boys_nu = jnp.tile(jnp.arange(2 * max_am + 1), (geom.shape[0], 1)).T
            boys_eval = boys(boys_nu, boys_arg)

            s.a = 0
            for _ in s.while_range(lambda: s.a < dims[p1]):
                s.b = 0
                for _ in s.while_range(lambda: s.b < dims[p2]):
                    # Gather angular momentum and index
                    la, ma, na = angular_momentum_combinations[s.a + ld1]
                    lb, mb, nb = angular_momentum_combinations[s.b + ld2]
                    # To only create unique indices, need to have separate
                    # indices arrays for i and j.
                    i = indices[p1] + s.a
                    j = indices[p2] + s.b
                    # Compute one electron integrals and add to appropriate index
                    overlap_int = overlap(la, ma, na, lb, mb, nb, aa, bb,
                                          PA_pow, PB_pow, prefactor) * coef
                    kinetic_int = kinetic(la, ma, na, lb, mb, nb, aa, bb,
                                          PA_pow, PB_pow, prefactor) * coef
                    potential_int = potential(la, ma, na, lb, mb, nb, aa, bb,
                                              PA_pow, PB_pow, Pgeom_pow,
                                              boys_eval, prefactor, charges,
                                              A_vals) * coef
                    s.oei = jax.ops.index_add(
                        s.oei, ([0, 1, 2], [i, i, i], [j, j, j]),
                        (overlap_int, kinetic_int, potential_int))
                    s.b += 1
                s.a += 1
    S, T, V = s.oei[0], s.oei[1], s.oei[2]
    return S, T, V
def raw2outputs(raw, z_vals, rays_d, raw_noise_std, white_bkgd, rng=None):
    """Transforms model's predictions to semantically meaningful values.

    Args:
        raw: (num_rays, num_samples || num_importance, 4) prediction from model
        z_vals: (num_rays, num_samples || num_importance) integration time
        rays_d: (num_rays, 3) direction of each ray
        raw_noise_std: std of noise added for regularization
        white_bkgd: whether to use the alpha channel for white background
        rng: random key

    Returns:
        acc_map: (num_rays) sum of weights along each ray
        depth_map: (num_rays) estimated distance to object
        disp_map: (num_rays) disparity map (inverse of depth map)
        rgb_map: (num_rays, 3) estimated RGB color of a ray
        weights: (num_rays, num_samples || num_importance) weights assigned to
            each sampled color
    """
    # compute 'distance' (in time) between each integration time along a ray
    dists = z_vals[..., 1:] - z_vals[..., :-1]

    # the 'distance' from the last integration time is infinity
    dists = jnp.concatenate(
        [dists, jnp.broadcast_to([1e10], dists[..., :1].shape)], axis=-1)
    dists = dists.astype(z_vals.dtype)  # [num_rays, num_samples]

    # multiply each distance by the norm of its corresponding direction ray
    # to convert to real world distance (accounts for non-unit directions)
    dists = dists * jnp.linalg.norm(rays_d[..., None, :], axis=-1)

    # extract RGB of each sample position along each ray
    rgb = nn.sigmoid(raw[..., :3])  # [num_rays, num_samples, 3]

    # add noise to the model's density predictions; can be used to regularize
    # the network during training (prevents floater artifacts)
    noise = 0.0
    if raw_noise_std > 0.0 and rng is not None:
        noise = random.normal(rng, raw[..., 3].shape) * raw_noise_std

    # predict density of each sample along each ray (alpha channel); higher
    # values imply higher likelihood of being absorbed at this point
    # (this value is strictly between [0, 1])
    alpha = 1.0 - jnp.exp(-nn.relu(raw[..., 3] + noise) * dists)

    # compute weight for RGB of each sample along each ray
    # cumprod() is used to express the idea of the ray not having reflected
    # up to this sample yet
    # weights = alpha * tf.math.cumprod(1.0 - alpha + 1e-10, axis=-1, exclusive=True)
    alpha_ = jnp.clip(1.0 - alpha, 1e-5, 1.0)
    weights = jnp.concatenate(
        [jnp.ones_like(alpha_[..., :1]), alpha_[..., :-1]], -1)
    weights = alpha * jnp.cumprod(weights, -1)  # [num_rays, num_samples]

    # compute weighted color of each sample along each ray
    rgb_map = jnp.einsum("ij,ijk->ik", weights, rgb)  # [num_rays, 3]

    # estimated depth map is expected distance
    depth_map = jnp.einsum("ij,ij->i", weights, z_vals)  # [num_rays]

    # sum of weights along each ray (this value is in [0, 1] up to numerical error)
    acc_map = jnp.einsum("ij->i", weights)  # [num_rays]

    # disparity map is inverse depth
    i_depth = depth_map / jnp.clip(acc_map, 1e-5)
    disp_map = 1.0 / jnp.clip(i_depth, 1e-5)

    # to composite onto a white background, use the accumulated alpha map
    if white_bkgd:
        rgb_map += 1.0 - acc_map[..., None]

    return {
        "rgb": rgb_map.astype(jnp.float32),
        "disp": disp_map.astype(jnp.float32),
        "acc": acc_map.astype(jnp.float32),
        "depth": depth_map.astype(jnp.float32),
    }, weights
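# Standalone check (hypothetical alphas) of the concatenate+cumprod idiom
# above: it forms the *exclusive* cumulative product of (1 - alpha), i.e. the
# transmittance before each sample, matching the
# tf.math.cumprod(..., exclusive=True) referenced in the comment.
import jax.numpy as jnp

alpha = jnp.array([0.2, 0.5, 0.9])
alpha_ = jnp.clip(1.0 - alpha, 1e-5, 1.0)
trans = jnp.cumprod(jnp.concatenate([jnp.ones(1), alpha_[:-1]]))
assert jnp.allclose(trans, jnp.array([1.0, 0.8, 0.8 * 0.5]), atol=1e-6)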
def func(x): return np.einsum("ij,ij...->...", aux, x)
def restricted_hartree_fock(geom, basis_name, xyz_path, nuclear_charges, charge, options, deriv_order=0, return_aux_data=False): # Load keyword options maxit = options['maxit'] damping = options['damping'] damp_factor = options['damp_factor'] spectral_shift = options['spectral_shift'] convergence = 1e-10 nelectrons = int(jnp.sum(nuclear_charges)) - charge ndocc = nelectrons // 2 # If we are doing MP2 or CCSD after, might as well use jit-compiled JK-build, since HF will not be memory bottleneck if return_aux_data: jk_build = jax.jit(jax.vmap(jax.vmap(lambda x,y: jnp.tensordot(x, y, axes=[(0,1),(0,1)]), in_axes=(0,None)), in_axes=(0,None))) else: jk_build = jax.vmap(jax.vmap(lambda x,y: jnp.tensordot(x, y, axes=[(0,1),(0,1)]), in_axes=(0,None)), in_axes=(0,None)) # Canonical orthogonalization via cholesky decomposition S, T, V, G = compute_integrals(geom, basis_name, xyz_path, nuclear_charges, charge, deriv_order, options) A = cholesky_orthogonalization(S) nbf = S.shape[0] # For slightly shifting eigenspectrum of transformed Fock for degenerate eigenvalues # (JAX cannot differentiate degenerate eigenvalue eigh) if spectral_shift: # Shifting eigenspectrum requires lower convergence. convergence = 1e-8 fudge = jnp.asarray(np.linspace(0, 1, nbf)) * convergence shift = jnp.diag(fudge) else: shift = jnp.zeros_like(S) H = T + V Enuc = nuclear_repulsion(geom.reshape(-1,3),nuclear_charges) D = jnp.zeros_like(H) def rhf_iter(F,D): E_scf = jnp.einsum('pq,pq->', F + H, D) + Enuc Fp = jnp.dot(A.T, jnp.dot(F, A)) Fp = Fp + shift eps, C2 = jnp.linalg.eigh(Fp) C = jnp.dot(A,C2) Cocc = C[:, :ndocc] D = jnp.dot(Cocc, Cocc.T) return E_scf, D, C, eps iteration = 0 E_scf = 1.0 E_old = 0.0 Dold = jnp.zeros_like(D) dRMS = 1.0 # Converge according to energy and DIIS residual to ensure eigenvalues and eigenvectors are maximally converged. # This is crucial for numerical stability for higher order derivatives of correlated methods. while ((abs(E_scf - E_old) > convergence) or (dRMS > convergence)): E_old = E_scf * 1 if damping: if iteration < 10: D = Dold * damp_factor + D * damp_factor Dold = D * 1 # Build JK matrix: 2 * J - K JK = 2 * jk_build(G, D) JK -= jk_build(G.transpose((0,2,1,3)), D) # Build Fock F = H + JK # Update convergence error if iteration > 1: diis_e = jnp.einsum('ij,jk,kl->il', F, D, S) - jnp.einsum('ij,jk,kl->il', S, D, F) diis_e = A.dot(diis_e).dot(A) dRMS = jnp.mean(diis_e**2)**0.5 # Compute energy, transform Fock and diagonalize, get new density E_scf, D, C, eps = rhf_iter(F,D) iteration += 1 if iteration == maxit: break print(iteration, " RHF iterations performed") # If many orbitals are degenerate, warn that higher order derivatives may be unstable tmp = jnp.round(eps,6) ndegen_orbs = tmp.shape[0] - jnp.unique(tmp).shape[0] if (ndegen_orbs / nbf) > 0.20: print("Hartree-Fock warning: More than 20% of orbitals have degeneracies. Higher order derivatives may be unstable due to eigendecomposition AD rule") if not return_aux_data: return E_scf else: return E_scf, C, eps, G
def fixed_pos_embedding(seq, dim): inv_freq = 1.0 / (10000**(np.arange(0, dim, 2) / dim)) sinusoid_inp = np.einsum("i , j -> i j", np.arange(seq), inv_freq) return np.sin(sinusoid_inp), np.cos(sinusoid_inp)
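# Shape sketch (hypothetical sizes, assuming fixed_pos_embedding above is in
# scope): the sin/cos tables are (seq, dim // 2), one row per position and one
# column per frequency; position 0 carries no rotation.
import numpy as np

sin, cos = fixed_pos_embedding(16, 8)
assert sin.shape == cos.shape == (16, 4)
assert np.allclose(sin[0], 0.0) and np.allclose(cos[0], 1.0)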
def f(scale): scaled_mat = scale * psd_mat chol = jnp.linalg.cholesky(scaled_mat) return -0.5 * jnp.sum((jnp.einsum('ij,j->i', chol, vec))**2)
Qt = jnp.eye(2) * 0.001
# Observed noise
Rt = jnp.eye(2) * 0.01
mu0 = jnp.array([1, 1])
Sigma0 = jnp.eye(2)

key = random.PRNGKey(314)
kf = lds.ContinuousKalmanFilter(A, C, Qt, Rt, mu0, Sigma0)
sample_state, sample_obs, jump = kf.sample(key, mu0, T, nsamples)
mu_hist, V_hist, *_ = kf.filter(sample_obs, jump, dt)

step = 0.1
vmin, vmax = -1.5, 1.5 + step
X = np.mgrid[-1:1.5:step, vmin:vmax:step][::-1]
X_dot = jnp.einsum("ij,jxy->ixy", A, X)

fig, ax = plt.subplots()
ax.plot(*sample_state.T, label="state space")
ax.scatter(*sample_obs.T, marker="+", c="tab:green", s=60, label="observations")
ax.scatter(*sample_state[0], c="black", zorder=3)
field = ax.streamplot(*X, *X_dot, density=1.1, color="#ccccccaa")
ax.legend()
plt.axis("equal")
ax.set_title("State Space")
pml.savefig("kf-circle-state.pdf")

fig, ax = plt.subplots()
ax.plot(*mu_hist.T, c="tab:orange", label="Filtered")
ax.scatter(*sample_obs.T, marker="+", s=60, c="tab:green", label="observations")
ax.scatter(*mu_hist[0], c="black", zorder=3)
for n in range(N):
    for c in range(C):
        s = 0
        for d in range(D):
            for k in range(K):
                for t in range(T):
                    s += S[n, t, k] * W[k, d] * V[d, c]
        L[n, c] = s

assert np.allclose(L, np.einsum('ntk,kd,dc->nc', S, W, V))

path = np.einsum_path('ntk,kd,dc->nc', S, W, V, optimize='optimal')[0]
assert np.allclose(L, np.einsum('ntk,kd,dc->nc', S, W, V, optimize=path))

import jax.numpy as jnp
path = jnp.einsum_path('ntk,kd,dc->nc', S, W, V, optimize='optimal')[0]
assert np.allclose(L, jnp.einsum('ntk,kd,dc->nc', S, W, V, optimize=path))

# Use the full student network from Koller and Friedman
formula = 'c,dc,gdi,si,lg,jls,hgj->'
K = 5
cptC = np.random.randn(K)
cptD = np.random.randn(K, K)
cptG = np.random.randn(K, K, K)
cptS = np.random.randn(K, K)
cptL = np.random.randn(K, K)
cptJ = np.random.randn(K, K, K)
cptH = np.random.randn(K, K, K)
cpts = [cptC, cptD, cptG, cptS, cptL, cptJ, cptH]
path_info = np.einsum_path(formula, *cpts, optimize='optimal')
print(path_info[0])  # ['einsum_path', (0, 1), (0, 5), (0, 4), (0, 3), (0, 2), (0, 1)]
def body(p, qkv): (q, k, v) = qkv p += jnp.einsum('...m,...d->...md', k, v, precision=precision) X_slice = jnp.einsum('...m,...md->...d', q, p, precision=precision) return p, X_slice
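# Usage sketch (hypothetical shapes): scanning the body above over a
# time-major sequence accumulates prefix sums of the outer products k_t v_t^T,
# giving the causal (prefix-sum) form of linear attention. We assume `body`
# resolves `precision` to the value set below.
import jax
import jax.numpy as jnp
from jax import lax

precision = None
T, m, d = 6, 3, 4
ks = jax.random.split(jax.random.PRNGKey(0), 3)
q = jax.random.normal(ks[0], (T, m))
k = jax.random.normal(ks[1], (T, m))
v = jax.random.normal(ks[2], (T, d))

p0 = jnp.zeros((m, d))  # running sum of outer products
_, x = lax.scan(body, p0, (q, k, v))
assert x.shape == (T, d)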
def dot_product_attention(query: Array, key: Array, value: Array, bias: Optional[Array] = None, broadcast_dropout: bool = True, dropout_rng: Optional[PRNGKey] = None, dropout_rate: float = 0., deterministic: bool = False, dtype: Dtype = jnp.float32, precision: Optional[lax.Precision] = None): """Computes dot-product attention given query, key, and value. This is the core function for applying attention based on https://arxiv.org/abs/1706.03762. It calculates the attention weights given query and key and combines the values using the attention weights. Note: query, key, value needn't have any batch dimensions. Args: query: queries for calculating attention with shape of `[batch..., q_length, num_heads, qk_depth_per_head]`. key: keys for calculating attention with shape of `[batch..., kv_length, num_heads, qk_depth_per_head]`. value: values to be used in attention with shape of `[batch..., kv_length, num_heads, v_depth_per_head]`. bias: bias for the attention weights. This should be broadcastable to the shape: `[batch..., num_heads, q_length, kv_length]` This can be used for incorporating causal masks, padding masks, proximity bias, etc. broadcast_dropout: bool: use a broadcasted dropout along batch dims. dropout_rng: JAX PRNGKey: to be used for dropout dropout_rate: dropout rate deterministic: bool, deterministic or not (to apply dropout) dtype: the dtype of the computation (default: float32) precision: numerical precision of the computation see `jax.lax.Precision` for details. Returns: Output of shape `[batch..., length, num_heads, v_depth_per_head]`. """ assert key.ndim == query.ndim == value.ndim, "q, k, v must have same rank." assert query.shape[:-3] == key.shape[:-3] == value.shape[:-3], ( "q, k, v batch dims must match.") assert query.shape[-2] == key.shape[-2] == value.shape[-2], ( "q, k, v num_heads must match.") assert key.shape[-3] == value.shape[-3], "k, v lengths must match." assert query.shape[-1] == key.shape[-1], "q, k depths must match." # calculate attention matrix depth = query.shape[-1] query = query / jnp.sqrt(depth).astype(dtype) # attn weight shape is (batch..., num_heads, q_length, kv_length) attn_weights = jnp.einsum('...qhd,...khd->...hqk', query, key, precision=precision) # apply attention bias: masking, dropout, proximity bias, etc. if bias is not None: attn_weights = attn_weights + bias # normalize the attention weights attn_weights = jax.nn.softmax(attn_weights).astype(dtype) # apply attention dropout if not deterministic and dropout_rate > 0.: keep_prob = 1.0 - dropout_rate if broadcast_dropout: # dropout is broadcast across the batch + head dimensions dropout_shape = tuple([1] * (key.ndim - 2)) + attn_weights.shape[-2:] keep = random.bernoulli(dropout_rng, keep_prob, dropout_shape) else: keep = random.bernoulli(dropout_rng, keep_prob, attn_weights.shape) multiplier = (keep.astype(attn_weights.dtype) / jnp.asarray(keep_prob, dtype=dtype)) attn_weights = attn_weights * multiplier # return weighted sum over values for each query position return jnp.einsum('...hqk,...khd->...qhd', attn_weights, value, precision=precision)
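# Usage sketch (hypothetical shapes, assuming dot_product_attention above is
# in scope): 4-head self-attention over a batch of length-10 sequences, run
# deterministically so no dropout RNG is required.
import jax

rngs = jax.random.split(jax.random.PRNGKey(0), 3)
q = jax.random.normal(rngs[0], (2, 10, 4, 16))  # [batch, q_length, heads, qk_depth]
k = jax.random.normal(rngs[1], (2, 10, 4, 16))
v = jax.random.normal(rngs[2], (2, 10, 4, 32))  # v_depth may differ

out = dot_product_attention(q, k, v, deterministic=True)
assert out.shape == (2, 10, 4, 32)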
def body(p, qk): q, k = qk p += k x = jnp.einsum('...m,...m->...', q, p, precision=precision) return p, x