def beam_search_body_fn(state, input_ids_length=1):
    """Beam search state update fn."""
    # 1. Forward current tokens
    # Collect the current position slice along length to feed the fast
    # autoregressive decoder model. Flatten the beam dimension into batch
    # dimension for feeding into the model.
    input_token = flatten_beam_dim(
        lax.dynamic_slice(
            state.running_sequences,
            (0, 0, state.cur_len - input_ids_length),
            (batch_size, num_beams, input_ids_length),
        )
    )
    model_outputs = model(input_token, params=params, **state.model_kwargs)

    # Unflatten the beam dimension in the logits and attention cache arrays.
    logits = unflatten_beam_dim(model_outputs.logits[:, -1], batch_size, num_beams)
    cache = jax.tree_map(
        lambda tensor: unflatten_beam_dim(tensor, batch_size, num_beams),
        model_outputs.past_key_values,
    )

    # adapt logits for FlaxMarianMTModel
    logits = self._adapt_logits_for_beam_search(logits)

    # 2. Compute log probs
    # Get log probabilities from logits, process logits with processors
    # (*e.g.* min_length, ...), and add new logprobs to the existing running
    # logprobs scores.
    log_probs = jax.nn.log_softmax(logits)
    log_probs = logits_processor(
        flatten_beam_dim(state.running_sequences), flatten_beam_dim(log_probs), state.cur_len
    )
    log_probs = unflatten_beam_dim(log_probs, batch_size, num_beams)
    log_probs = log_probs + jnp.expand_dims(state.running_scores, axis=2)
    vocab_size = log_probs.shape[2]
    log_probs = log_probs.reshape((batch_size, num_beams * vocab_size))

    # 3. Retrieve top-K
    # Each item in batch has num_beams * vocab_size candidate sequences.
    # For each item, get the top 2*K candidates with the highest log-
    # probabilities. We gather the top 2*K beams here so that even if the best
    # K sequences reach EOS simultaneously, we have another K sequences
    # remaining to continue the live beam search.
    beams_to_keep = 2 * num_beams
    # Gather the top 2*K scores from _all_ beams.
    topk_log_probs, topk_indices = lax.top_k(log_probs, k=beams_to_keep)
    # Recover the beam index by floor division.
    topk_beam_indices = topk_indices // vocab_size
    # Gather the 2*K top beams.
    topk_running_sequences = gather_beams(
        state.running_sequences, topk_beam_indices, batch_size, beams_to_keep
    )
    # Recover the token id by modulo division and expand the id array for broadcasting.
    topk_ids = jnp.expand_dims(topk_indices % vocab_size, axis=2)
    # Update sequences for the 2*K top-k new sequences.
    topk_sequences = lax.dynamic_update_slice(topk_running_sequences, topk_ids, (0, 0, state.cur_len))

    # 4. Check which sequences have ended
    # Did any of these sequences reach an end marker? To prevent these just
    # finished sequences from being added to the current set of active beam
    # search sequences, set their log probs to a very large negative value.
    did_topk_just_finished = topk_sequences[:, :, state.cur_len] == eos_token_id
    running_topk_log_probs = topk_log_probs + did_topk_just_finished * np.array(-1.0e7)

    # 5. Get running sequences scores for the next iteration
    # Determine the top K beam indices (from the top 2*K beams) from the log
    # probs and gather the top K beams (from the top 2*K beams).
    next_topk_indices = jnp.flip(lax.top_k(running_topk_log_probs, k=num_beams)[1], axis=1)
    next_running_sequences, next_running_scores = gather_beams(
        [topk_sequences, running_topk_log_probs], next_topk_indices, batch_size, num_beams
    )

    # 6. Process topk logits
    # Further process log probs:
    # - add length penalty
    # - make sure no scores can be added anymore if beam is full
    # - make sure still running sequences cannot be chosen as finalized beam
    topk_log_probs = topk_log_probs / (state.cur_len**length_penalty)
    beams_in_batch_are_full = (
        jnp.broadcast_to(state.is_sent_finished.all(axis=-1, keepdims=True), did_topk_just_finished.shape)
        & early_stopping
    )
    add_penalty = ~did_topk_just_finished | beams_in_batch_are_full
    topk_log_probs += add_penalty * np.array(-1.0e7)

    # 7. Get scores, sequences, is_sent_finished for the next iteration
    # Combine sequences, scores, and flags along the beam dimension, compare
    # new finished sequence scores to existing finished scores, and select the
    # best from the merged set of beams.
    merged_sequences = jnp.concatenate([state.sequences, topk_sequences], axis=1)
    merged_scores = jnp.concatenate([state.scores, topk_log_probs], axis=1)
    merged_is_sent_finished = jnp.concatenate([state.is_sent_finished, did_topk_just_finished], axis=1)
    topk_merged_indices = jnp.flip(lax.top_k(merged_scores, k=num_beams)[1], axis=1)
    next_sequences, next_scores, next_is_sent_finished = gather_beams(
        [merged_sequences, merged_scores, merged_is_sent_finished], topk_merged_indices, batch_size, num_beams
    )

    # 8. Update model kwargs
    # Determine the top K beam indices from the original set of all beams and,
    # with these, gather the top K beam-associated caches.
    next_running_indices = gather_beams(topk_beam_indices, next_topk_indices, batch_size, num_beams)
    next_cache = gather_beams(cache, next_running_indices, batch_size, num_beams)
    model_outputs["past_key_values"] = jax.tree_map(lambda x: flatten_beam_dim(x), next_cache)
    next_model_kwargs = self.update_inputs_for_generation(model_outputs, state.model_kwargs)

    return BeamSearchState(
        cur_len=state.cur_len + 1,
        running_scores=next_running_scores,
        running_sequences=next_running_sequences,
        scores=next_scores,
        sequences=next_sequences,
        is_sent_finished=next_is_sent_finished,
        model_kwargs=next_model_kwargs,
    )
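# The loop body above leans on three beam-bookkeeping helpers that are defined
# elsewhere in the module. A minimal sketch of their expected behavior (shape
# plumbing only; the actual implementations may differ in detail):

import jax
import jax.numpy as jnp


def flatten_beam_dim(tensor):
    """Merge the beam dimension into batch: [batch, beams, ...] -> [batch * beams, ...]."""
    return tensor.reshape((tensor.shape[0] * tensor.shape[1],) + tensor.shape[2:])


def unflatten_beam_dim(tensor, batch_size, num_beams):
    """Split the beam dimension back out: [batch * beams, ...] -> [batch, beams, ...]."""
    return tensor.reshape((batch_size, num_beams) + tensor.shape[1:])


def gather_beams(nested, beam_indices, batch_size, new_num_beams):
    """Gather the beams at beam_indices ([batch, new_beams]) from every
    [batch, old_beams, ...] leaf of the pytree `nested`."""
    batch_indices = jnp.reshape(
        jnp.arange(batch_size * new_num_beams) // new_num_beams,
        (batch_size, new_num_beams),
    )
    return jax.tree_map(lambda x: x[batch_indices, beam_indices], nested)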
def flip(self: TensorType, axis: Optional[AxisAxes] = None) -> TensorType:
    return type(self)(np.flip(self.raw, axis=axis))
def flip(self, a, axis=None):
    return jnp.flip(a, axis)
def sort(tensor, axis, descending=False):
    if descending:
        return np.flip(np.sort(tensor, axis=axis), axis=axis)
    return np.sort(tensor, axis=axis)
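# A quick sanity check of the descending path (a sketch, using the `sort`
# wrapper defined above): descending order is just the flipped ascending sort.
import numpy as np

t = np.array([[3, 1, 2], [0, 5, 4]])
print(sort(t, axis=1, descending=True))
# [[3 2 1]
#  [5 4 0]]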
def flip(x, axis=None):
    if isinstance(x, JaxArray):
        x = x.value
    return JaxArray(jnp.flip(x, axis=axis))
def flip(x, axis=None, _=None):
    if isinstance(axis, (list, tuple)):
        raise Exception('Jax does not support flip() across multiple indices')
    return _jnp.flip(x, axis)
def get_index(x):
    x = 2 * (grid.lower - x.flatten()) / grid.size + 1
    idx = (grid.shape * np.arccos(x)) // np.pi
    idx = np.nan_to_num(idx, nan=grid.shape)
    return (*np.flip(np.uint32(idx)),)
def rewards_to_go(rewards, mask, gamma=0.99):
    r"""Computes rewards to go.

    The reward to go is the discounted reward that we have yet to collect,
    going forward from this point, i.e.:

    r2g_t = \sum_{l=0}^{\infty} (\gamma^{l} * reward_{t+l})

    Args:
      rewards: np.ndarray of shape (B, T) of rewards.
      mask: np.ndarray of shape (B, T) of mask for the rewards.
      gamma: float, discount factor.

    Returns:
      rewards to go, np.ndarray of shape (B, T).
    """
    B, T = rewards.shape  # pylint: disable=invalid-name,unused-variable

    masked_rewards = rewards * mask  # (B, T)

    # The lax.scan version of this is slow, but we still show it here for
    # completeness.
    #   rewards_rev = np.flip(masked_rewards, axis=1)  # (B, T) flipped on time.
    #   rrt = np.transpose(rewards_rev)  # (T, B) transpose to scan over time.
    #
    #   def discounting_add(carry, reward):
    #     x = reward + (gamma * carry)
    #     return x, x
    #
    #   _, ys = lax.scan(discounting_add,
    #                    np.zeros_like(rrt[0], dtype=np.float32),
    #                    rrt.astype(np.float32))
    #
    #   # ys is (T, B) and T is in reverse order.
    #   return np.flip(np.transpose(ys), axis=1)

    # We use the following recurrence relation, derived from the equation above:
    #
    #   r2g[t+1] = (r2g[t] - r[t]) / gamma
    #
    # This means we would need to calculate r2g[0] first, then r2g[1], and so on.
    #
    # **However**, this leads to overflows for long sequences: r2g[t] - r[t] > 0
    # and gamma < 1.0, so the division keeps increasing.
    #
    # So we just run the recurrence in reverse, i.e.
    #
    #   r2g[t] = r[t] + (gamma * r2g[t+1])
    #
    # This is much better, but might lose precision, since the (small) rewards
    # at earlier time-steps get added to a (possibly very) large sum.

    # Compute r2g_{T-1} at the start and then compute backwards in time.
    r2gs = [masked_rewards[:, -1]]

    # Go from T-2 down to 0.
    for t in reversed(range(T - 1)):
        r2gs.append(masked_rewards[:, t] + (gamma * r2gs[-1]))

    # The list should have length T.
    assert T == len(r2gs)

    # First we stack them in the correct way to make it (B, T), but these are
    # still from newest (T-1) to oldest (0), so then we flip it on the time axis.
    return np.flip(np.stack(r2gs, axis=1), axis=1)
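# A tiny worked check of the recurrence r2g[t] = r[t] + gamma * r2g[t+1]
# against the direct definition (hypothetical numbers, gamma = 0.5, with `np`
# as imported by this module):
rewards = np.array([[1.0, 2.0, 4.0]])
mask = np.ones_like(rewards)
# By hand: r2g[2] = 4, r2g[1] = 2 + 0.5 * 4 = 4, r2g[0] = 1 + 0.5 * 4 = 3.
print(rewards_to_go(rewards, mask, gamma=0.5))  # [[3. 4. 4.]]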
def reverse_line(self, line, b):
    return jax.lax.cond(b == 1, lambda z: z, lambda z: jnp.flip(z, 0), line)
def get_index(x):
    h = grid.size / grid.shape
    idx = (x.flatten() - grid.lower) // h
    idx = idx % grid.shape
    return (*np.flip(np.uint32(idx)),)
def not_0_or_2():
    perm = list(range(m.ndim))
    perm[ax1], perm[ax2] = perm[ax2], perm[ax1]
    return jnp.where(k == 1,
                     jnp.transpose(jnp.flip(m, ax2), perm),
                     jnp.flip(jnp.transpose(m, perm), ax2))
def not_0():
    return jnp.where(k == 2, jnp.flip(jnp.flip(m, ax1), ax2), not_0_or_2())
def _reverse(tensor, axis, name=None):  # pylint: disable=unused-argument
    if np.array(axis).ndim == 0:
        return np.flip(tensor, axis)
    for ax in axis:
        tensor = np.flip(tensor, ax)
    return tensor
def argsort(input, dim, descending):
    _sorted = jnp.argsort(input, axis=dim)
    if descending:
        _sorted = jnp.flip(_sorted, axis=dim)
    return _sorted
def get_index(x):
    h = grid.size / grid.shape
    idx = (x.flatten() - grid.lower) // h
    idx = np.where((idx < 0) | (idx > grid.shape), grid.shape, idx)
    return (*np.flip(np.uint32(idx)),)
def testFlip(self, shape, dtype, axis, rng):
    args_maker = self._GetArgsMaker(rng, [shape], [dtype])
    lnp_op = lambda x: lnp.flip(x, axis)
    onp_op = lambda x: onp.flip(x, axis)
    self._CheckAgainstNumpy(onp_op, lnp_op, args_maker, check_dtypes=True)
    self._CompileAndCheck(lnp_op, args_maker, check_dtypes=True)
def _flip_and_average_on_center(array):
    """Flips and averages the array on the center."""
    return (array + jnp.flip(array)) / 2
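# A quick demonstration (a sketch): averaging an array with its reverse yields
# a result that is symmetric about the center.
import jax.numpy as jnp

print(_flip_and_average_on_center(jnp.array([0., 1., 4., 9.])))
# [4.5 2.5 2.5 4.5]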
def tei_array(geom, basis):
    """
    Build the two-electron integral array from a jax.numpy array of the
    cartesian geometry in Bohr, and a basis dictionary as defined by
    basis_utils.build_basis_set. We have to loop over primitives rather than
    shells because JAX needs intermediates to be consistent sizes in order to
    compile.
    """
    # Smush primitive data together into vectors
    coeffs, exps, atoms, ams, indices, dims = flatten_basis_data(basis)
    nbf = get_nbf(basis)
    max_am = jnp.max(ams)
    max_am_idx = max_am * 4 + 1
    # TODO add exception raise if angular momentum is too high
    B_vals = jnp.zeros(4 * max_am + 1)
    nprim = coeffs.shape[0]
    # Obtain all possible primitive quartet index combinations
    primitive_quartets = cartesian_product(jnp.arange(nprim), jnp.arange(nprim), jnp.arange(nprim), jnp.arange(nprim))

    # print("Number of basis functions: ", nbf)
    # print("Number of primitive quartets: ", primitive_quartets.shape[0])

    # TODO Experimental: precompute quantities and look them up inside the loop
    # Compute all possible Gaussian products for this basis set
    aa_plus_bb = jnp.broadcast_to(exps, (nprim, nprim)) + jnp.transpose(jnp.broadcast_to(exps, (nprim, nprim)), (1, 0))
    aa_times_A = jnp.einsum('i,ij->ij', exps, geom[atoms])
    aaxA_plus_bbxB = aa_times_A[:, None, :] + aa_times_A[None, :, :]
    gaussian_products = jnp.einsum('ijk,ij->ijk', aaxA_plus_bbxB, 1 / aa_plus_bb)

    # Compute all rab2 (rcd2), every possible jnp.dot(A-B, A-B)
    natom = geom.shape[0]
    tmpA = jnp.broadcast_to(geom, (natom, natom, 3))
    AminusB = tmpA - jnp.transpose(tmpA, (1, 0, 2))
    AmBdot = jnp.einsum('ijk,ijk->ij', AminusB, AminusB)  # shape: (natom, natom)

    # Compute all differences between Gaussian product centers and all atom centers
    tmpP = jnp.tile(gaussian_products, natom).reshape(nprim, nprim, natom, 3)
    PminusA = tmpP - jnp.broadcast_to(geom, tmpP.shape)

    # Compute all powers (up to max_am) of differences between Gaussian product
    # centers and atom centers.
    # Shape: (nprim, nprim, natom, 3, max_am+1). In the loop, index PA_pow as [p1, p2, atoms[p1], :, :]
    PminusA_pow = jnp.power(jnp.transpose(jnp.broadcast_to(PminusA, (max_am + 1, nprim, nprim, natom, 3)), (1, 2, 3, 4, 0)), jnp.arange(max_am + 1))

    with loops.Scope() as s:
        s.G = jnp.zeros((nbf, nbf, nbf, nbf))
        s.a = 0  # center A angular momentum iterator
        s.b = 0  # center B angular momentum iterator
        s.c = 0  # center C angular momentum iterator
        s.d = 0  # center D angular momentum iterator

        # Loop over primitive quartets, compute the integral, add to the appropriate index in G
        for prim_quar in s.range(primitive_quartets.shape[0]):
            # Load in primitive indices, coeffs, exponents, centers, angular
            # momentum index, and leading placement index in the TEI array
            p1, p2, p3, p4 = primitive_quartets[prim_quar]
            coef = coeffs[p1] * coeffs[p2] * coeffs[p3] * coeffs[p4]
            aa, bb, cc, dd = exps[p1], exps[p2], exps[p3], exps[p4]
            ld1, ld2, ld3, ld4 = am_leading_indices[ams[p1]], am_leading_indices[ams[p2]], am_leading_indices[ams[p3]], am_leading_indices[ams[p4]]
            idx1, idx2, idx3, idx4 = indices[p1], indices[p2], indices[p3], indices[p4]
            # A, B, C, D = geom[atoms[p1]], geom[atoms[p2]], geom[atoms[p3]], geom[atoms[p4]]

            # Compute common intermediates before looping over AM distributions.
            # Avoids redundant recomputation/reassignment for all classes other than (ss|ss).
            # AB = A - B
            # CD = C - D
            # rab2 = jnp.dot(AB, AB)
            # rcd2 = jnp.dot(CD, CD)
            # P = (aa * A + bb * B) / gamma1
            # Q = (cc * C + dd * D) / gamma2
            gamma1 = aa + bb
            gamma2 = cc + dd
            # TODO
            P = gaussian_products[p1, p2]
            Q = gaussian_products[p3, p4]
            rab2 = AmBdot[atoms[p1], atoms[p2]]
            rcd2 = AmBdot[atoms[p3], atoms[p4]]
            # PA = PminusA[p1, p2, atoms[p1]]
            # PB = PminusA[p1, p2, atoms[p2]]
            # QC = PminusA[p3, p4, atoms[p3]]
            # QD = PminusA[p3, p4, atoms[p4]]
            # TODO
            PQ = P - Q
            rpq2 = jnp.dot(PQ, PQ)
            delta = 0.25 * (1 / gamma1 + 1 / gamma2)
            boys_arg = 0.25 * rpq2 / delta
            boys_eval = boys(jnp.arange(max_am_idx), boys_arg)

            # Need all powers of Pi-Ai, Pi-Bi, Qi-Ci, Qi-Di (i = x, y, z) up to
            # max_am, and of Qi-Pi up to max_am_idx.
            # Note: this computes unnecessary quantities for lower angular
            # momentum, but avoids repeated computation of the same quantities
            # in loops for higher angular momentum.
            # PA_pow = jnp.power(jnp.broadcast_to(P - A, (max_am + 1, 3)).T, jnp.arange(max_am + 1))
            # PB_pow = jnp.power(jnp.broadcast_to(P - B, (max_am + 1, 3)).T, jnp.arange(max_am + 1))
            # QC_pow = jnp.power(jnp.broadcast_to(Q - C, (max_am + 1, 3)).T, jnp.arange(max_am + 1))
            # QD_pow = jnp.power(jnp.broadcast_to(Q - D, (max_am + 1, 3)).T, jnp.arange(max_am + 1))
            PA_pow = PminusA_pow[p1, p2, atoms[p1], :, :]
            PB_pow = PminusA_pow[p1, p2, atoms[p2], :, :]
            QC_pow = PminusA_pow[p3, p4, atoms[p3], :, :]
            QD_pow = PminusA_pow[p3, p4, atoms[p4], :, :]
            QP_pow = jnp.power(jnp.broadcast_to(Q - P, (max_am_idx, 3)).T, jnp.arange(max_am_idx))

            # Gamma powers are negative, up to -(l1+l2).
            # Build the arrays such that a given negative index returns the
            # corresponding negative power.
            g1_pow = jnp.power(4 * gamma1, -jnp.roll(jnp.flip(jnp.arange(2 * max_am + 1)), 1))
            g2_pow = jnp.power(4 * gamma2, -jnp.roll(jnp.flip(jnp.arange(2 * max_am + 1)), 1))
            oodelta_pow = jnp.power(1 / delta, jnp.arange(max_am_idx))  # l1 + l2 + l3 + l4 + 1

            prefactor = 34.986836655249726 / (gamma1 * gamma2 * jnp.sqrt(gamma1 + gamma2)) \
                        * jnp.exp(-aa * bb * rab2 / gamma1 + -cc * dd * rcd2 / gamma2) * coef

            # TODO is there symmetry here?
            s.a = 0
            for _ in s.while_range(lambda: s.a < dims[p1]):
                s.b = 0
                for _ in s.while_range(lambda: s.b < dims[p2]):
                    s.c = 0
                    for _ in s.while_range(lambda: s.c < dims[p3]):
                        s.d = 0
                        for _ in s.while_range(lambda: s.d < dims[p4]):
                            # Collect angular momentum and index in G
                            la, ma, na = angular_momentum_combinations[s.a + ld1]
                            lb, mb, nb = angular_momentum_combinations[s.b + ld2]
                            lc, mc, nc = angular_momentum_combinations[s.c + ld3]
                            ld, md, nd = angular_momentum_combinations[s.d + ld4]
                            i = idx1 + s.a
                            j = idx2 + s.b
                            k = idx3 + s.c
                            l = idx4 + s.d
                            # Compute the primitive quartet tei and add to the appropriate index in G
                            Bx = B_array(la, lb, lc, ld, PA_pow[0], PB_pow[0], QC_pow[0], QD_pow[0], QP_pow[0], g1_pow, g2_pow, oodelta_pow, B_vals)
                            By = B_array(ma, mb, mc, md, PA_pow[1], PB_pow[1], QC_pow[1], QD_pow[1], QP_pow[1], g1_pow, g2_pow, oodelta_pow, B_vals)
                            Bz = B_array(na, nb, nc, nd, PA_pow[2], PB_pow[2], QC_pow[2], QD_pow[2], QP_pow[2], g1_pow, g2_pow, oodelta_pow, B_vals)

                            # Contract the B arrays against the Boys function
                            # values over all (I, J, K) combinations.
                            with loops.Scope() as S:
                                S.primitive = 0.
                                S.I = 0
                                S.J = 0
                                S.K = 0
                                for _ in S.while_range(lambda: S.I < la + lb + lc + ld + 1):
                                    S.J = 0
                                    for _ in S.while_range(lambda: S.J < ma + mb + mc + md + 1):
                                        S.K = 0
                                        for _ in S.while_range(lambda: S.K < na + nb + nc + nd + 1):
                                            S.primitive += Bx[S.I] * By[S.J] * Bz[S.K] * boys_eval[S.I + S.J + S.K]
                                            S.K += 1
                                        S.J += 1
                                    S.I += 1
                            tei = prefactor * S.primitive
                            s.G = jax.ops.index_add(s.G, jax.ops.index[i, j, k, l], tei)
                            s.d += 1
                        s.c += 1
                    s.b += 1
                s.a += 1
        return s.G
def beam_search_loop_body_fn(state):
    """Beam search loop state update function."""
    # Collect the current position slice along length to feed the fast
    # autoregressive decoder model. Flatten the beam dimension into batch
    # dimension for feeding into the model.
    # --> [batch * beam, 1]
    flat_ids = flatten_beam_dim(
        lax.dynamic_slice(state.live_seqs, (0, 0, state.cur_index),
                          (batch_size, beam_size, 1)))
    # Flatten beam dimension into batch to be compatible with the model.
    # {[batch, beam, ...], ...} --> {[batch * beam, ...], ...}
    flat_cache = jax.tree_map(flatten_beam_dim, state.cache)

    # Call fast-decoder model on current tokens to get next-position logits.
    # --> [batch * beam, vocab]
    flat_logits, new_flat_cache = tokens_to_logits(flat_ids, flat_cache)

    # Unflatten beam dimension.
    # [batch * beam, vocab] --> [batch, beam, vocab]
    logits = unflatten_beam_dim(flat_logits, batch_size, beam_size)
    # Unflatten beam dimension in attention cache arrays.
    # {[batch * beam, ...], ...} --> {[batch, beam, ...], ...}
    new_cache = jax.tree_map(
        lambda x: unflatten_beam_dim(x, batch_size, beam_size), new_flat_cache)

    # Gather log probabilities from logits.
    candidate_log_probs = jax.nn.log_softmax(logits)
    # Add new logprobs to existing prefix logprobs.
    # --> [batch, beam, vocab]
    log_probs = (candidate_log_probs +
                 jnp.expand_dims(state.live_logprobs, axis=2))

    # We'll need the vocab size, gather it from the log probability dimension.
    vocab_size = log_probs.shape[2]

    # Each item in batch has beam_size * vocab_size candidate sequences.
    # For each item, get the top 2*k candidates with the highest log-
    # probabilities. We gather the top 2*K beams here so that even if the best
    # K sequences reach EOS simultaneously, we have another K sequences
    # remaining to continue the live beam search.
    beams_to_keep = 2 * beam_size
    # Flatten beam and vocab dimensions.
    flat_log_probs = log_probs.reshape((batch_size, beam_size * vocab_size))
    # Gather the top 2*K scores from _all_ beams.
    # --> [batch, 2*beams], [batch, 2*beams]
    topk_log_probs, topk_indices = lax.top_k(flat_log_probs, k=beams_to_keep)
    # Recover the beam index by floor division.
    topk_beam_indices = topk_indices // vocab_size
    # Gather 2*k top beams.
    # --> [batch, 2*beams, length]
    topk_seq = gather_beams(state.live_seqs, topk_beam_indices, batch_size,
                            beams_to_keep)

    # Append the most probable 2*K token IDs to the top 2*K sequences.
    # Recover the token id by modulo division and expand the id array for broadcasting.
    # --> [batch, 2*beams, 1]
    topk_ids = jnp.expand_dims(topk_indices % vocab_size, axis=2)
    # Update sequences for the 2*K top-k new sequences.
    # --> [batch, 2*beams, length]
    topk_seq = lax.dynamic_update_slice(topk_seq, topk_ids,
                                        (0, 0, state.cur_index + 1))

    # Update LIVE (in-progress) sequences:
    # Did any of these sequences reach an end marker?
    # --> [batch, 2*beams]
    newly_finished = (topk_seq[:, :, state.cur_index + 1] == end_marker)
    # To prevent these newly finished sequences from being added to the LIVE
    # set of active beam search sequences, set their log probs to a very large
    # negative value.
    new_log_probs = topk_log_probs + newly_finished * NEG_INF
    # Determine the top k beam indices (from top 2*k beams) from log probs.
    # --> [batch, beams]
    _, new_topk_indices = lax.top_k(new_log_probs, k=beam_size)
    new_topk_indices = jnp.flip(new_topk_indices, axis=1)
    # Gather the top k beams (from top 2*k beams).
    # --> [batch, beams, length], [batch, beams]
    top_alive_seq, top_alive_log_probs = gather_beams(
        [topk_seq, new_log_probs], new_topk_indices, batch_size, beam_size)

    # Determine the top k beam indices from the original set of all beams.
    # --> [batch, beams]
    top_alive_indices = gather_beams(topk_beam_indices, new_topk_indices,
                                     batch_size, beam_size)
    # With these, gather the top k beam-associated caches.
    # --> {[batch, beams, ...], ...}
    top_alive_cache = gather_beams(new_cache, top_alive_indices, batch_size,
                                   beam_size)

    # Update FINISHED (reached end of sentence) sequences:
    # Calculate new seq scores from log probabilities.
    new_scores = topk_log_probs / brevity_penalty(alpha, state.cur_index + 1)
    # Mask out the still unfinished sequences by adding a large negative value.
    # --> [batch, 2*beams]
    new_scores += (~newly_finished) * NEG_INF

    # Combine sequences, scores, and flags along the beam dimension, compare
    # new finished sequence scores to existing finished scores, and select the
    # best from the new set of beams.
    finished_seqs = jnp.concatenate(  # --> [batch, 3*beams, length]
        [state.finished_seqs, topk_seq], axis=1)
    finished_scores = jnp.concatenate(  # --> [batch, 3*beams]
        [state.finished_scores, new_scores], axis=1)
    finished_flags = jnp.concatenate(  # --> [batch, 3*beams]
        [state.finished_flags, newly_finished], axis=1)
    # --> [batch, beams, length], [batch, beams], [batch, beams]
    top_finished_seq, top_finished_scores, top_finished_flags = (
        gather_topk_beams([finished_seqs, finished_scores, finished_flags],
                          finished_scores, batch_size, beam_size))

    return BeamState(cur_index=state.cur_index + 1,
                     live_logprobs=top_alive_log_probs,
                     finished_scores=top_finished_scores,
                     live_seqs=top_alive_seq,
                     finished_seqs=top_finished_seq,
                     finished_flags=top_finished_flags,
                     cache=top_alive_cache)
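# The score normalization above divides by brevity_penalty(alpha, length),
# defined elsewhere in the module. A sketch of the assumed form (the
# GNMT-style length normalization used by the Flax WMT decoder), together
# with the assumed NEG_INF masking constant:
import jax.numpy as jnp


def brevity_penalty(alpha, length):
    """GNMT-style length normalization: ((5 + length) / 6) ** alpha."""
    return jnp.power(((5.0 + length) / 6.0), alpha)


# A large negative constant rather than a true -inf, which would produce NaNs
# when multiplied by a zero mask.
NEG_INF = -1.0e7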
if __name__ == '__main__':  # Test
    key = random.PRNGKey(42)
    from glm_py import GLMPy

    N = 2
    M = 100
    dh = 2
    ds = 8
    p = {'N': N, 'M': M, 'dh': dh, 'ds': ds, 'dt': 0.1, 'n': 0, 'N_lim': N, 'M_lim': M}

    w = random.normal(key, shape=(N, N)) * 0.001
    h = random.normal(key, shape=(N, dh)) * 0.001
    k = random.normal(key, shape=(N, ds)) * 0.001
    b = random.normal(key, shape=(N, 1)) * 0.001

    theta = {'h': np.flip(h, axis=1), 'w': w, 'b': b, 'k': k}
    model = GLMJax(p, theta)

    sN = 8
    data = onp.random.randn(sN, 2)  # onp.zeros((8, 50))
    stim = onp.random.randn(ds, 2)
    print(model.ll(data, stim))

    def gen_ref():
        ow = onp.asarray(w)[:sN, :sN]
        oh = onp.asarray(h)[:sN, ...]
        ok = onp.asarray(k)[:sN, ...]
        ob = onp.asarray(b)[:sN, ...]
        p = {'numNeurons': sN, 'hist_dim': dh, 'numSamples': M, 'dt': 0.1, 'stim_dim': ds}
def __call__(
    self,
    input_values,
    attention_mask=None,
    mask_time_indices=None,
    deterministic=True,
    output_attentions=None,
    output_hidden_states=None,
    return_dict=None,
):
    """
    Returns:

    Example::

        >>> from transformers import Wav2Vec2Processor, FlaxWav2Vec2Model
        >>> from datasets import load_dataset
        >>> import soundfile as sf

        >>> processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-base-960h")
        >>> model = FlaxWav2Vec2Model.from_pretrained("facebook/wav2vec2-base-960h")

        >>> def map_to_array(batch):
        ...     speech, _ = sf.read(batch["file"])
        ...     batch["speech"] = speech
        ...     return batch

        >>> ds = load_dataset("patrickvonplaten/librispeech_asr_dummy", "clean", split="validation")
        >>> ds = ds.map(map_to_array)

        >>> input_values = processor(ds["speech"][0], return_tensors="np").input_values  # Batch size 1
        >>> hidden_states = model(input_values).last_hidden_state
    """
    extract_features = self.feature_extractor(input_values)

    if attention_mask is not None:
        # compute real output lengths according to the convolution formula
        output_lengths = self._get_feat_extract_output_lengths(attention_mask.sum(-1).astype("i4"))

        attention_mask = jnp.zeros(extract_features.shape[:2], dtype=self.dtype)

        # these two operations make sure that all values
        # before the output lengths indices are attended to
        attention_mask = jax.ops.index_update(
            attention_mask, jax.ops.index[jnp.arange(attention_mask.shape[0]), output_lengths - 1], 1
        )
        attention_mask = jnp.flip(jnp.flip(attention_mask, -1).cumsum(-1), -1).astype("bool")

    hidden_states, extract_features = self.feature_projection(extract_features, deterministic=deterministic)

    if mask_time_indices is not None:
        # apply SpecAugment along time axis with given indices
        hidden_states = jnp.where(
            jnp.broadcast_to(mask_time_indices[:, :, None], hidden_states.shape),
            jnp.broadcast_to(self.masked_spec_embed[None, None, :], hidden_states.shape),
            hidden_states,
        )

    encoder_outputs = self.encoder(
        hidden_states,
        attention_mask=attention_mask,
        deterministic=deterministic,
        output_attentions=output_attentions,
        output_hidden_states=output_hidden_states,
        return_dict=return_dict,
    )

    hidden_states = encoder_outputs[0]

    if not return_dict:
        return (hidden_states, extract_features) + encoder_outputs[1:]

    return FlaxWav2Vec2BaseModelOutput(
        last_hidden_state=hidden_states,
        extract_features=extract_features,
        hidden_states=encoder_outputs.hidden_states,
        attentions=encoder_outputs.attentions,
    )
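# The flip-cumsum-flip trick above turns a single 1 placed at position
# (length - 1) into a mask of ones everywhere up to and including that
# position. A small standalone illustration (hypothetical length of 4 in a
# padded row of 6, using the functional `.at[].set()` update API):
import jax.numpy as jnp

mask = jnp.zeros((1, 6)).at[0, 3].set(1)  # 1 at index length - 1 = 3
mask = jnp.flip(jnp.flip(mask, -1).cumsum(-1), -1).astype("bool")
print(mask)  # [[ True  True  True  True False False]]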