def init(self, params):
  shape = params.shape
  slots = []
  if self._factored and len(shape) >= 2:
    v_row = np.zeros(shape[:-1], dtype=np.float32)
    v_col = np.zeros(shape[:-2] + shape[-1:], dtype=np.float32)
    slots.extend([v_row, v_col])
  else:
    v = np.zeros_like(params)
    slots.append(v)
  if self._do_momentum:
    m = np.zeros_like(params)
    slots.append(m)
  return slots
def init(self, x):
  shape = x.shape
  state = []
  if self._factored and len(shape) >= 2:
    v_row = np.zeros(shape[:-1], dtype=np.float32)
    v_col = np.zeros(shape[:-2] + shape[-1:], dtype=np.float32)
    state.extend([v_row, v_col])
  else:
    v = np.zeros_like(x)
    state.append(v)
  if self._beta1:
    m = np.zeros_like(x)
    state.append(m)
  return state
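# A hedged side note on the two factored `init` variants above (this demo is
# illustrative and not part of the original code): for a 2-D weight matrix the
# factored branch stores one per-row and one per-column second-moment vector
# instead of a full matrix of statistics, so slot memory drops from
# d_in * d_out to d_in + d_out. A minimal standalone sketch in plain NumPy:
def _demo_factored_slot_shapes():
  import numpy as onp
  shape = (1024, 512)
  v_row = onp.zeros(shape[:-1], dtype=onp.float32)               # shape (1024,)
  v_col = onp.zeros(shape[:-2] + shape[-1:], dtype=onp.float32)  # shape (512,)
  # 1024 + 512 accumulator entries instead of 1024 * 512.
  assert v_row.size + v_col.size < onp.prod(shape)
  return v_row.shape, v_col.shape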
def backward(self, inputs, output, ct, params=(), state=(), rng=None,
             **kwargs):
  del output, params, state
  _, (qk_ct, v_ct) = self.batch_call_and_or_grad(
      inputs[0], inputs[2], return_output=False, ct=ct, rng=rng)
  inputs_ct = (qk_ct, np.zeros_like(inputs[1]), v_ct)
  return inputs_ct, ()
def DotProductAttention(query, key, value, mask, dropout, mode, rng):
  """Core dot-product self-attention.

  Args:
    query: array of representations
    key: array of representations
    value: array of representations
    mask: attention mask that gates attention
    dropout: float: dropout rate
    mode: 'eval' or 'train': whether to use dropout
    rng: JAX PRNGKey: subkey for disposable use

  Returns:
    Self-attention for the q, k, v arrays.
  """
  depth = np.shape(query)[-1]
  dots = np.matmul(query, np.swapaxes(key, -1, -2)) / np.sqrt(depth)
  if mask is not None:
    # TODO(kitaev): workaround for https://github.com/google/jax/issues/850
    # We must ensure that both mask and the -1e9 constant have a data
    # dependency on the input. Broadcasted copies of these use a lot of
    # memory, so they should be computed at runtime (rather than being global
    # constants).
    if backend.get_name() == 'jax':
      mask = jax.lax.tie_in(dots, mask)
    dots = np.where(mask, dots, np.full_like(dots, -1e9))
  # Softmax.
  dots = np.exp(dots - backend.logsumexp(dots, axis=-1, keepdims=True))
  # Check dropout for None before comparing, so a None rate cleanly skips
  # dropout instead of raising a TypeError.
  if dropout is not None and dropout >= 1.0:
    raise ValueError('Dropout rates must be lower than 1.')
  if dropout is not None and dropout > 0.0 and mode == 'train':
    keep = backend.random.bernoulli(rng, 1.0 - dropout, dots.shape)
    dots = np.where(keep, dots / (1.0 - dropout), np.zeros_like(dots))
  out = np.matmul(dots, value)
  return out
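# A hedged usage sketch for DotProductAttention (the shapes and the causal
# mask built here are illustrative assumptions, not taken from the original
# code; it also assumes the active backend is jax). With batch size 2,
# sequence length 8, and depth 16:
def _demo_dot_product_attention():
  import jax
  import jax.numpy as jnp
  rng = jax.random.PRNGKey(0)
  q_rng, k_rng, v_rng, drop_rng = jax.random.split(rng, 4)
  q = jax.random.normal(q_rng, (2, 8, 16))
  k = jax.random.normal(k_rng, (2, 8, 16))
  v = jax.random.normal(v_rng, (2, 8, 16))
  # Lower-triangular causal mask: position i may attend to positions <= i.
  mask = jnp.tril(jnp.ones((1, 8, 8), dtype=bool))
  return DotProductAttention(q, k, v, mask, dropout=0.1, mode='train',
                             rng=drop_rng)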
def _update_diagonal(self, step, g, x, m, v):
  v[0] += g * g
  preconditioner = np.where(v[0] > 0, 1.0 / np.sqrt(v[0]),
                            np.zeros_like(v[0]))
  preconditioned_g = preconditioner * g
  m = (1 - self._momentum) * preconditioned_g + self._momentum * m
  x = x - self.step_size(step) * m
  return x, (m, v)
def _update_diagonal(self, grads, params, m, v, opt_params):
  (learning_rate, momentum) = opt_params
  v[0] += grads * grads
  preconditioner = np.where(v[0] > 0, 1.0 / np.sqrt(v[0]),
                            np.zeros_like(v[0]))
  preconditioned_grads = preconditioner * grads
  m = (1 - momentum) * preconditioned_grads + momentum * m
  params = params - (learning_rate * m).astype(params.dtype)
  return params, (m, v)
def Dropout(x, params, rate=0.0, mode='train', rng=None, **kwargs):
  """Layer construction function for a dropout layer with given rate."""
  del params, kwargs
  if rng is None:
    msg = ('Dropout layer requires apply_fn to be called with a rng keyword '
           'argument. That is, instead of `Dropout(params, inputs)`, call '
           'it like `Dropout(params, inputs, rng=key)`.')
    raise ValueError(msg)
  if rate >= 1.0:
    raise ValueError('Dropout rate (%f) must be lower than 1.' % rate)
  if mode == 'train' and rate > 0.0:
    keep = backend.random.bernoulli(rng, 1.0 - rate, x.shape)
    return np.where(keep, x / (1.0 - rate), np.zeros_like(x))
  else:
    return x
def call(self, x, params, state, rng=None, **unused_kwargs):
  """Execute dropout."""
  del params
  rate = self._initial_rate
  if isinstance(state, dict) and self._name in state:
    rate = state[self._name]
  if rng is None:
    msg = ('Dropout layer requires apply_fn to be called with a rng keyword '
           'argument. That is, instead of `Dropout(params, inputs)`, call '
           'it like `Dropout(params, inputs, rng=key)`.')
    raise ValueError(msg)
  if self._mode != 'train':
    return x, state
  keep = backend.random.bernoulli(rng, 1.0 - rate, x.shape)
  return np.where(keep, x / (1.0 - rate), np.zeros_like(x)), state
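# Side note on the scaling used in both dropout implementations above
# ("inverted dropout"): dividing kept values by (1 - rate) keeps the expected
# activation unchanged, so no rescaling is needed at eval time. A minimal
# standalone sketch (illustrative only):
def _demo_inverted_dropout_expectation():
  import jax
  import jax.numpy as jnp
  rng = jax.random.PRNGKey(0)
  x = jnp.ones((100000,))
  rate = 0.3
  keep = jax.random.bernoulli(rng, 1.0 - rate, x.shape)
  y = jnp.where(keep, x / (1.0 - rate), jnp.zeros_like(x))
  # E[y] ~= E[x] = 1.0 up to sampling noise.
  return jnp.mean(y)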
def ShiftRight(x, **unused_kwargs):
  """Layer to shift the tensor to the right by padding on axis 1."""
  if not isinstance(x, (list, tuple)):  # non-chunked inputs
    pad_widths = [(0, 0), (1, 0)]
    padded = np.pad(x, pad_widths, mode='constant')
    return padded[:, :-1]
  # Handling chunked inputs. Recall that the list of chunks represents a big
  # sequence (the concatenation of the chunks). We want to shift that
  # sequence, so we put a 0 at the beginning of the first chunk and use the
  # last element of each chunk as the new first element of the next chunk,
  # and so on.
  padded = []
  last_value = np.zeros_like(x[0][:, -1])
  for chunk in x:
    padded_chunk = np.concatenate([last_value[:, np.newaxis], chunk], axis=1)
    last_value = chunk[:, -1]
    padded.append(padded_chunk[:, :-1])
  return padded
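# A hedged sketch of what ShiftRight does on a small unchunked input (the
# values are illustrative): each sequence is shifted one step to the right
# along axis 1, with a zero inserted at the front and the last element
# dropped.
def _demo_shift_right():
  import jax.numpy as jnp
  x = jnp.array([[1, 2, 3],
                 [4, 5, 6]])
  shifted = ShiftRight(x)
  # shifted == [[0, 1, 2],
  #             [0, 4, 5]]
  return shifted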
def _update_sketched(self, grads, params, m, v, opt_params):
  """Update for higher-rank parameters."""
  (learning_rate, momentum) = opt_params
  shape = params.shape
  rank = len(shape)
  reshaped_accumulators = [np.reshape(v[i], self._expanded_shape(shape, i))
                           for i in range(rank)]
  current_accumulator = self._minimum(reshaped_accumulators)
  current_accumulator += grads * grads
  accumulator_inv_sqrt = np.where(current_accumulator > 0.0,
                                  1.0 / np.sqrt(current_accumulator),
                                  np.zeros_like(current_accumulator))
  preconditioned_gradient = grads * accumulator_inv_sqrt
  m = (1.0 - momentum) * preconditioned_gradient + momentum * m
  params = params - (learning_rate * m).astype(params.dtype)
  for i in range(len(v)):
    axes = list(range(int(i))) + list(range(int(i) + 1, rank))
    dim_accumulator = np.amax(current_accumulator, axis=axes)
    v[i] = dim_accumulator
  return params, (m, v)
def _update_sketched(self, step, g, x, m, v):
  """Update for higher-rank parameters."""
  shape = x.shape
  rank = len(shape)
  reshaped_accumulators = [
      np.reshape(v[i], self._expanded_shape(shape, i)) for i in range(rank)
  ]
  current_accumulator = self._minimum(reshaped_accumulators)
  current_accumulator += g * g
  accumulator_inv_sqrt = np.where(current_accumulator > 0.0,
                                  1.0 / np.sqrt(current_accumulator),
                                  np.zeros_like(current_accumulator))
  preconditioned_gradient = g * accumulator_inv_sqrt
  m = (1.0 - self._momentum) * preconditioned_gradient + self._momentum * m
  x = x - self.step_size(step) * m
  for i in range(len(v)):
    axes = list(range(int(i))) + list(range(int(i) + 1, rank))
    dim_accumulator = np.amax(current_accumulator, axis=axes)
    v[i] = dim_accumulator
  return x, (m, v)
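# The two `_update_sketched` variants above implement an SM3-style sketched
# second-moment estimate: per-axis max accumulators whose elementwise minimum
# upper-bounds the full accumulator. A hedged standalone sketch of the rank-2
# case under that reading (this assumes `self._expanded_shape` broadcasts
# each per-axis accumulator back to the parameter shape and `self._minimum`
# takes their elementwise minimum; the names below are illustrative):
def _demo_sm3_sketch(grads, v_rows, v_cols):
  # grads: (d0, d1); v_rows: (d0,); v_cols: (d1,).
  import jax.numpy as jnp
  # Broadcast the per-axis accumulators to the parameter shape and take the
  # elementwise minimum as the current second-moment estimate.
  current = jnp.minimum(v_rows[:, None], v_cols[None, :]) + grads * grads
  # Fold the updated estimate back into the per-axis max accumulators.
  new_v_rows = jnp.amax(current, axis=1)
  new_v_cols = jnp.amax(current, axis=0)
  return current, new_v_rows, new_v_cols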
def call_and_grad(self, inputs, ct, rng=None, **kwargs):
  del kwargs
  # We use the same vector as both a query and a key. For now we haven't
  # adjusted any of the surrounding code, so we still get a separate "key"
  # input that we ignore.
  qk, ignored_k, v = inputs
  seqlen = qk.shape[-2]

  # qk/v are n_batch*n_heads, seqlen, d_head
  # bins are n_batch*n_heads, seqlen
  # They specify which hash bucket the query/key/value vectors fall in.
  bins = self.hash_vectors(qk, rng=rng)

  # joint_t is n_batch*n_heads, seqlen
  joint_t = jax.lax.tie_in(qk, np.arange(seqlen))
  joint_t = np.reshape(joint_t, (1, seqlen))
  joint_t = np.broadcast_to(joint_t, qk.shape[:-1])

  assert int((self.n_bins + 1) * seqlen) < 2**31, (
      'Potential 32-bit integer overflow; please double-check the code.')
  joint_bins_and_t = seqlen * bins + joint_t

  def chunk_scalars(x):  # pylint: disable=invalid-name
    return np.reshape(x, (x.shape[0], self.n_bins, -1))

  def chunk_vectors(x):  # pylint: disable=invalid-name
    return np.reshape(x, (x.shape[0], self.n_bins, -1, x.shape[-1]))

  def unchunk_vectors(x):  # pylint: disable=invalid-name
    return np.reshape(x, (x.shape[0], -1, x.shape[-1]))

  # Sort everything by bin number, with a secondary sort by time
  # (variables starting with "s" are sorted).
  _, sjoint_t = jax.lax.sort_key_val(joint_bins_and_t, joint_t, dimension=-1)

  sqk = np.take_along_axis(qk, sjoint_t[:, :, None], axis=-2)
  sv = np.take_along_axis(v, sjoint_t[:, :, None], axis=-2)

  if ct is not None:
    so_ct = np.take_along_axis(ct, sjoint_t[:, :, None], axis=-2)

  @jax.jit
  def binned_attn(sqk, sv):  # pylint: disable=invalid-name
    """Performs attention on sorted queries/keys/values."""
    # Split off a "bin" axis so that attention only occurs within chunks.
    bq_t = bkv_t = chunk_scalars(sjoint_t)
    bqk = chunk_vectors(sqk)
    bv = chunk_vectors(sv)

    # Hashing operates on unit-length vectors. Unnormalized query vectors are
    # fine because they effectively provide a learnable temperature for the
    # attention softmax, but normalizing keys is needed so that similarity
    # for the purposes of attention correctly corresponds to hash locality.
    bq = bqk
    bk = self.make_unit_length(bqk)

    # Allow each chunk to attend within itself, and also one chunk back.
    # Chunk boundaries might occur in the middle of a sequence of items from
    # the same bin, so this increases the chances of attending to relevant
    # items.
    # TODO(kitaev): benchmark whether the XLA pad operation is noticeably
    # faster.
    bk_extra = np.concatenate([bk[:, -1:, :, :], bk[:, :-1, :, :]], axis=1)
    bk = np.concatenate([bk, bk_extra], axis=2)
    bv_extra = np.concatenate([bv[:, -1:, :, :], bv[:, :-1, :, :]], axis=1)
    bv = np.concatenate([bv, bv_extra], axis=2)
    bkv_t_extra = np.concatenate([bkv_t[:, -1:, :], bkv_t[:, :-1, :]], axis=1)
    bkv_t = np.concatenate([bkv_t, bkv_t_extra], axis=2)

    # Dot-product attention.
    dots = np.matmul(bq, np.swapaxes(bk, -1, -2)) / np.sqrt(bq.shape[-1])

    # Causal masking.
    mask = jax.lax.convert_element_type(
        jax.lax.lt(bq_t[:, :, :, None], bkv_t[:, :, None, :]),
        np.float32)
    dots = dots - 1e9 * mask

    # Mask out attention to self except when no other targets are available.
    self_mask = jax.lax.broadcasted_eye(dots.dtype, dots.shape, (2, 3))
    self_mask = jax.lax.tie_in(dots, self_mask)
    dots = dots - 32 * self_mask

    # Softmax.
    dots = np.exp(dots - backend.logsumexp(dots, axis=-1, keepdims=True))
    bo = np.matmul(dots, bv)

    so = unchunk_vectors(bo)
    return so

  @jax.jit
  def binned_attn_vjp(sqk, sv, so_ct):  # pylint: disable=invalid-name
    so, vjpfun = jax.vjp(binned_attn, sqk, sv)
    sqkv_ct = vjpfun(so_ct)
    return so, sqkv_ct

  if ct is None:
    so = binned_attn(sqk, sv)
    _, undo_sort = jax.lax.sort_key_val(sjoint_t, joint_t, dimension=-1)
    out = np.take_along_axis(so, undo_sort[:, :, None], axis=-2)
    return out, None
  else:
    # Jax can construct a backward pass automatically, but it's about 2x
    # slower than writing our own. The main reason is that the backward pass
    # of gather is in general a scatter operation, but we know we're dealing
    # with permutations so we can use gather for the backward pass too.
    so, (sqk_ct, sv_ct) = binned_attn_vjp(sqk, sv, so_ct)
    _, undo_sort = jax.lax.sort_key_val(sjoint_t, joint_t, dimension=-1)
    out = np.take_along_axis(so, undo_sort[:, :, None], axis=-2)
    qk_ct = np.take_along_axis(sqk_ct, undo_sort[:, :, None], axis=-2)
    v_ct = np.take_along_axis(sv_ct, undo_sort[:, :, None], axis=-2)
    return out, (qk_ct, np.zeros_like(ignored_k), v_ct)
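# A hedged mini-demo of the sort / undo-sort trick used above (the bucket
# values are illustrative): sorting positions by bucket yields a permutation,
# and sorting that permutation back against the identity yields its inverse,
# so a gather can be used for both directions instead of a scatter.
def _demo_sort_undo_sort():
  import jax
  import jax.numpy as jnp
  buckets = jnp.array([[2, 0, 1, 0]])
  t = jnp.broadcast_to(jnp.arange(4), (1, 4))
  seqlen = 4
  # Sort by bucket, with time as a secondary key via seqlen * bucket + t.
  _, st = jax.lax.sort_key_val(seqlen * buckets + t, t, dimension=-1)
  # Sorting the permutation against the identity recovers its inverse.
  _, undo = jax.lax.sort_key_val(st, t, dimension=-1)
  # Gathering with `st` and then with `undo` round-trips to the identity.
  roundtrip = jnp.take_along_axis(
      jnp.take_along_axis(t, st, axis=-1), undo, axis=-1)
  assert (roundtrip == t).all()
  return st, undo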
def call_and_grad(self, inputs, ct, rng=None, **kwargs):
  del kwargs
  query, key, value = inputs
  depth = np.shape(query)[-1]
  do_backprop = ct is not None
  # jax uses the term cotangent (ct) to refer to gradient signals, and
  # vector-Jacobian product (vjp) for back-propagation through a layer.

  def make_mask(N, M, k):  # pylint: disable=invalid-name
    """Constructs a slice of the causal attention mask.

    Args:
      N: number of query positions
      M: number of key positions
      k: position of the initial query element

    Returns:
      N x M mask, where 1.0 indicates that attention is not allowed.
    """
    x = np.arange(N, dtype=np.int32)
    y = np.arange(M, dtype=np.int32)
    mask = jax.lax.lt(
        (jax.lax.broadcast_in_dim(
            x, shape=(N, M), broadcast_dimensions=(0,)) + k),
        jax.lax.broadcast(y, [N]))
    mask = jax.lax.convert_element_type(mask, np.float32)
    return mask

  def forward_slice(query_slice, q_loop_idx, key, value):  # pylint: disable=invalid-name
    """Forward pass for a subset of the query vectors."""
    dots = np.matmul(
        query_slice, np.swapaxes(key, -1, -2)) / np.sqrt(depth)
    # Causal masking.
    mask = make_mask(dots.shape[-2], dots.shape[-1], q_loop_idx)
    dots = dots - 1e9 * mask
    # Softmax.
    dots = np.exp(dots - backend.logsumexp(dots, axis=-1, keepdims=True))

    if self.dropout is not None and self.dropout > 0.0:
      # Dropout is broadcast across the batch+head dimension.
      dropout_shape = (1, dots.shape[-2], dots.shape[-1])
      slice_rng = jax.random.fold_in(rng, q_loop_idx)
      keep_prob = jax.lax.tie_in(dots, 1.0 - self.dropout)
      keep = backend.random.bernoulli(slice_rng, keep_prob, dropout_shape)
      multiplier = keep.astype(dots.dtype) / jax.lax.tie_in(keep, keep_prob)
      dots = dots * multiplier

    out_slice = np.matmul(dots, value)
    return out_slice

  def forward_and_vjp_slice(query_slice, q_loop_idx, key, value, ct_slice):  # pylint: disable=invalid-name
    # Capture q_loop_idx to avoid calculating gradients wrt. it.
    def forward_slice_with_q_loop_idx(query_slice, key, value):  # pylint: disable=invalid-name
      return forward_slice(query_slice, q_loop_idx, key, value)

    output_slice, vjpfun = jax.vjp(
        forward_slice_with_q_loop_idx, query_slice, key, value)
    return output_slice, vjpfun(ct_slice)

  q_loop_idx = np.zeros((), dtype=np.int32)
  q_loop_max = query.shape[-2]
  q_loop_stride = self._loop_stride
  assert q_loop_max % q_loop_stride == 0, (
      'Stride must evenly divide the number of query elements.')

  out_accum = np.zeros_like(query)
  if do_backprop:
    query_ct_accum = np.zeros_like(query)
    key_ct_accum = np.zeros_like(key)
    value_ct_accum = np.zeros_like(value)
    init_vals = (q_loop_idx, out_accum,
                 query_ct_accum, key_ct_accum, value_ct_accum)
  else:
    init_vals = (q_loop_idx, out_accum)

  def cond_fun(vals):  # pylint: disable=invalid-name
    q_loop_idx = vals[0]
    return jax.lax.lt(q_loop_idx, q_loop_max)

  def body_fun(vals):  # pylint: disable=invalid-name
    """Compute a slice of the attention mechanism."""
    if do_backprop:
      (q_loop_idx, out_accum,
       query_ct_accum, key_ct_accum, value_ct_accum) = vals
    else:
      q_loop_idx, out_accum = vals

    query_slice = jax.lax.dynamic_slice_in_dim(
        query, q_loop_idx, q_loop_stride, axis=-2)

    if do_backprop:
      ct_slice = jax.lax.dynamic_slice_in_dim(
          ct, q_loop_idx, q_loop_stride, axis=-2)
      out_slice, partial_ct = forward_and_vjp_slice(
          query_slice, q_loop_idx, key, value, ct_slice)
      query_ct_accum = jax.lax.dynamic_update_slice_in_dim(
          query_ct_accum, partial_ct[0], q_loop_idx, axis=-2)
      key_ct_accum = key_ct_accum + partial_ct[1]
      value_ct_accum = value_ct_accum + partial_ct[2]
    else:
      out_slice = forward_slice(query_slice, q_loop_idx, key, value)

    out_accum = jax.lax.dynamic_update_slice_in_dim(
        out_accum, out_slice, q_loop_idx, axis=-2)
    q_loop_idx = q_loop_idx + q_loop_stride

    if do_backprop:
      return (q_loop_idx, out_accum,
              query_ct_accum, key_ct_accum, value_ct_accum)
    else:
      return (q_loop_idx, out_accum)

  final_vals = jax.lax.while_loop(cond_fun, body_fun, init_vals)

  if not do_backprop:
    return final_vals[1], None
  else:
    return final_vals[1], final_vals[2:]
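# A hedged note on `make_mask` above: for a slice of N query rows starting at
# global position k against M key positions, entry (i, j) is 1.0 exactly when
# key position j is strictly in the future of query position i + k, i.e.
# where attention must be blocked. A standalone re-derivation with
# illustrative defaults (not the original closure):
def _demo_causal_mask_slice(N=2, M=4, k=1):  # pylint: disable=invalid-name
  import jax.numpy as jnp
  i = jnp.arange(N)[:, None] + k  # global query positions
  j = jnp.arange(M)[None, :]      # key positions
  # For the defaults: [[0., 0., 1., 1.],
  #                    [0., 0., 0., 1.]]
  return (i < j).astype(jnp.float32)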
def batch_call_and_or_grad(self, qk, v, ct=None, return_output=True,
                           rng=None):
  assert return_output or ct is not None, 'No work to perform!'
  # pylint: disable=protected-access
  stash_buckets = (return_output and ct is None
                   and base.Layer._STASH_IN is not None)
  if return_output and ct is not None and base.Layer._STASH_OUT is not None:
    buckets = base.Layer._STASH_OUT.pop(self)
  else:
    buckets = None
  # pylint: enable=protected-access

  # The approach here is to perform attention for one batch element and head
  # at a time. Note that there is absolutely no interaction across examples
  # or heads: this layer has no parameters, and hashing patterns are also
  # different across examples/heads. As a result, batching doesn't give any
  # performance gains except in the case of accelerator under-utilization.
  # We assume that hash-based attention will be applied primarily to long
  # sequences, where unbatched attention for a single head has sufficient
  # computation to fill up the accelerator.

  batch_loop_idx = np.zeros((), dtype=np.int32)
  batch_loop_max = qk.shape[0]

  init_vals = (batch_loop_idx,)
  if return_output:
    out_accum = np.zeros_like(qk)
    init_vals = init_vals + (out_accum,)
  if stash_buckets:
    buckets_accum = np.zeros(
        [qk.shape[0], self.n_hashes * qk.shape[1]], dtype=np.int32)
    init_vals = init_vals + (buckets_accum,)
  if ct is not None:
    qk_ct_accum = np.zeros_like(qk)
    v_ct_accum = np.zeros_like(v)
    init_vals = init_vals + (qk_ct_accum, v_ct_accum)

  def cond_fun(vals):
    batch_loop_idx = vals[0]
    return jax.lax.lt(batch_loop_idx, batch_loop_max)

  def body_fun(vals):
    """Performs attention for a single batch element and head."""
    batch_loop_idx = vals[0]
    if self._prng is None:
      hash_rng = jax.random.fold_in(rng, batch_loop_idx)
    else:
      # TODO(kitaev): Maybe use the same RNG across examples (but not heads)?
      hash_rng = jax.random.fold_in(self._prng, batch_loop_idx)
    qk_slice = jax.lax.dynamic_index_in_dim(
        qk, batch_loop_idx, axis=0, keepdims=False)
    v_slice = jax.lax.dynamic_index_in_dim(
        v, batch_loop_idx, axis=0, keepdims=False)

    if buckets is None:
      buckets_slice = self.hash_vectors(qk_slice, rng=hash_rng)
    else:
      buckets_slice = jax.lax.dynamic_index_in_dim(
          buckets, batch_loop_idx, axis=0, keepdims=False)

    if ct is None:
      out_slice = self.single_call(
          qk_slice, v_slice, buckets_slice, hash_rng=hash_rng)
    else:
      def _do_single_call(qk_slice, v_slice):
        return self.single_call(
            qk_slice, v_slice, buckets_slice, hash_rng=hash_rng)
      ct_slice = jax.lax.dynamic_index_in_dim(
          ct, batch_loop_idx, axis=0, keepdims=False)
      out_slice, vjpfun = jax.vjp(_do_single_call, qk_slice, v_slice)
      qk_ct_slice, v_ct_slice = vjpfun(ct_slice)

    new_vals = (batch_loop_idx + 1,)
    if return_output:
      out_accum = vals[1]
      out_accum = jax.lax.dynamic_update_index_in_dim(
          out_accum, out_slice, batch_loop_idx, axis=0)
      new_vals = new_vals + (out_accum,)
    if stash_buckets:
      buckets_accum = vals[2]
      buckets_accum = jax.lax.dynamic_update_index_in_dim(
          buckets_accum, buckets_slice, batch_loop_idx, axis=0)
      new_vals = new_vals + (buckets_accum,)
    if ct is not None:
      qk_ct_accum, v_ct_accum = vals[-2:]
      qk_ct_accum = jax.lax.dynamic_update_index_in_dim(
          qk_ct_accum, qk_ct_slice, batch_loop_idx, axis=0)
      v_ct_accum = jax.lax.dynamic_update_index_in_dim(
          v_ct_accum, v_ct_slice, batch_loop_idx, axis=0)
      new_vals = new_vals + (qk_ct_accum, v_ct_accum)

    return new_vals

  final_vals = jax.lax.while_loop(cond_fun, body_fun, init_vals)

  if return_output:
    out = final_vals[1]
  else:
    out = None

  if stash_buckets:
    base.Layer._STASH_IN[self] = final_vals[2]  # pylint: disable=protected-access

  if ct is not None:
    input_ct = final_vals[-2:]
  else:
    input_ct = None

  return out, input_ct
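# The batching above relies on jax.lax.while_loop combined with
# dynamic_index_in_dim / dynamic_update_index_in_dim to process one example at
# a time under XLA-compatible control flow. A minimal hedged sketch of the
# same loop-and-accumulate pattern (function and variable names here are
# illustrative, and `xs` is assumed to be a float array):
def _demo_while_loop_accumulate(xs):
  import jax
  import jax.numpy as jnp

  def cond_fun(vals):
    i, _ = vals
    return jax.lax.lt(i, xs.shape[0])

  def body_fun(vals):
    i, acc = vals
    # Index out one row, process it, and write the result back in place.
    x = jax.lax.dynamic_index_in_dim(xs, i, axis=0, keepdims=False)
    acc = jax.lax.dynamic_update_index_in_dim(acc, x + x, i, axis=0)
    return i + 1, acc

  _, out = jax.lax.while_loop(
      cond_fun, body_fun,
      (jnp.zeros((), dtype=jnp.int32), jnp.zeros_like(xs)))
  return out  # == 2 * xs, computed one row at a time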
def init(self, x):
  vs = [np.zeros(sz, dtype=x.dtype) for sz in x.shape]
  return (np.zeros_like(x), vs)
def init(self, x):
  m = np.zeros_like(x)
  v = np.zeros_like(x)
  return m, v
def ParametricRelu(x, a=1., **unused_kwargs):
  return np.maximum(a * x, np.zeros_like(x))
def forward_and_vjp(self, inputs, ct, params=(), **kwargs):
  # This is the core of the memory-efficient attention implementation, where
  # we use the jax.lax.while_loop primitive to compute attention for a small
  # set of query positions at a time. Note how in the backwards pass, we
  # compute both the forward direction (to recover the previous layer's
  # activations) and the backward direction simultaneously. This allows us
  # to only use a single loop, where the inner portion of the loop does a
  # slice of the forward+backward joint computation. Unfortunately we have
  # had to introduce a large number of wrapper classes (including
  # ReversibleAttentionHalfResidual and ApplyAttentionWrapper) for the sole
  # purpose of connecting this implementation of forward_and_vjp with the
  # core backprop implementation.
  query, key, value = inputs
  depth = np.shape(query)[-1]
  do_backprop = ct is not None

  def make_mask(N, M, k):  # pylint: disable=invalid-name
    x = np.arange(N, dtype=np.int32)
    y = np.arange(M, dtype=np.int32)
    mask = jax.lax.lt(
        (jax.lax.broadcast_in_dim(
            x, shape=(N, M), broadcast_dimensions=(0,)) + k),
        jax.lax.broadcast(y, [N]))
    mask = jax.lax.convert_element_type(mask, np.float32)
    return mask

  def forward_slice(query_slice, q_loop_idx, key, value):
    """Forward pass for a subset of the query vectors."""
    dots = np.matmul(
        query_slice, np.swapaxes(key, -1, -2)) / np.sqrt(depth)
    # Causal masking.
    mask = make_mask(dots.shape[-2], dots.shape[-1], q_loop_idx)
    dots = dots - 1e9 * mask
    # Softmax.
    dots = np.exp(dots - dots.max(axis=-1, keepdims=True))
    dots = dots / dots.sum(axis=-1, keepdims=True)
    out_slice = np.matmul(dots, value)
    return out_slice

  def forward_and_vjp_slice(query_slice, q_loop_idx, key, value, ct_slice):
    output_slice, vjpfun = jax.vjp(
        forward_slice, query_slice, q_loop_idx, key, value)
    return output_slice, vjpfun(ct_slice)

  q_loop_idx = np.zeros((), dtype=np.int32)
  q_loop_max = query.shape[2]
  q_loop_stride = self._loop_stride
  assert q_loop_max % q_loop_stride == 0, (
      'Stride must evenly divide the number of query elements.')

  out_accum = np.zeros_like(query)
  if do_backprop:
    query_ct_accum = np.zeros_like(query)
    key_ct_accum = np.zeros_like(key)
    value_ct_accum = np.zeros_like(value)
    init_vals = (q_loop_idx, out_accum,
                 query_ct_accum, key_ct_accum, value_ct_accum)
  else:
    init_vals = (q_loop_idx, out_accum)

  def cond_fun(vals):
    q_loop_idx = vals[0]
    return jax.lax.lt(q_loop_idx, q_loop_max)

  def body_fun(vals):
    """Compute a slice of the attention mechanism."""
    if do_backprop:
      (q_loop_idx, out_accum,
       query_ct_accum, key_ct_accum, value_ct_accum) = vals
    else:
      q_loop_idx, out_accum = vals

    query_slice = jax.lax.dynamic_slice_in_dim(
        query, q_loop_idx, q_loop_stride, axis=2)

    if do_backprop:
      ct_slice = jax.lax.dynamic_slice_in_dim(
          ct, q_loop_idx, q_loop_stride, axis=2)
      out_slice, partial_ct = forward_and_vjp_slice(
          query_slice, q_loop_idx, key, value, ct_slice)
      query_ct_accum = jax.lax.dynamic_update_slice_in_dim(
          query_ct_accum, partial_ct[0], q_loop_idx, axis=2)
      # Ignore partial_ct[1], which is wrt the loop idx.
      key_ct_accum = key_ct_accum + partial_ct[2]
      value_ct_accum = value_ct_accum + partial_ct[3]
    else:
      out_slice = forward_slice(query_slice, q_loop_idx, key, value)

    out_accum = jax.lax.dynamic_update_slice_in_dim(
        out_accum, out_slice, q_loop_idx, axis=2)
    q_loop_idx = q_loop_idx + q_loop_stride

    if do_backprop:
      return (q_loop_idx, out_accum,
              query_ct_accum, key_ct_accum, value_ct_accum)
    else:
      return (q_loop_idx, out_accum)

  final_vals = jax.lax.while_loop(cond_fun, body_fun, init_vals)

  if not do_backprop:
    return final_vals[1], None
  else:
    return final_vals[1], final_vals[2:]
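# The loop above trades peak memory for recomputation: attention is evaluated
# q_loop_stride query rows at a time, and the backward pass re-runs the same
# slices. A hedged standalone check that slicing along the query axis leaves
# (unmasked) attention unchanged, which is what makes the chunked loop valid
# (shapes are illustrative; the causal mask is omitted for brevity):
def _demo_query_slicing_equivalence():
  import jax
  import jax.numpy as jnp
  rng = jax.random.PRNGKey(0)
  q_rng, k_rng, v_rng = jax.random.split(rng, 3)
  q = jax.random.normal(q_rng, (8, 4))
  k = jax.random.normal(k_rng, (8, 4))
  v = jax.random.normal(v_rng, (8, 4))

  def attend(q_block):
    dots = jax.nn.softmax(q_block @ k.T / jnp.sqrt(4.0), axis=-1)
    return dots @ v

  full = attend(q)
  # Two query rows at a time, then reassemble: softmax is per query row, so
  # the result matches the full computation.
  sliced = jnp.concatenate(
      [attend(q[i:i + 2]) for i in range(0, 8, 2)], axis=0)
  assert jnp.allclose(full, sliced, atol=1e-5)
  return full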
def forward_and_backward(self, inputs, ct, rng=None, **kwargs):
  del kwargs
  output, (qk_ct, v_ct) = self.batch_call_and_or_grad(
      inputs[0], inputs[2], ct=ct, rng=rng)
  return output, (qk_ct, np.zeros_like(inputs[1]), v_ct)
def init(self, params):
  m = np.zeros_like(params)
  v = np.zeros_like(params)
  return m, v
def init(self, params):
  return np.zeros_like(params)
def drop_for_hash(self, x, rng):
  rate = self._drop_for_hash_rate
  if self._mode == 'train' and rate > 0.0:
    keep = backend.random.bernoulli(rng, 1.0 - rate, x.shape)
    return np.where(keep, x / (1.0 - rate), np.zeros_like(x))
  return x
def init(self, x):
  return np.zeros_like(x)
def init(self, params):
  vs = [np.zeros(sz, dtype=params.dtype) for sz in params.shape]
  return (np.zeros_like(params), vs)
def Relu(x, **unused_kwargs):
  return np.maximum(x, np.zeros_like(x))