def init(self, params):
  shape = params.shape
  slots = []
  if self._factored and len(shape) >= 2:
    v_row = np.zeros(shape[:-1], dtype=np.float32)
    v_col = np.zeros(shape[:-2] + shape[-1:], dtype=np.float32)
    slots.extend([v_row, v_col])
  else:
    v = np.zeros_like(params)
    slots.append(v)
  if self._do_momentum:
    m = np.zeros_like(params)
    slots.append(m)
  return slots
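# A minimal standalone sketch (not part of the class above; names are
# illustrative) of the slot shapes the factored branch produces. For a
# (512, 1024) weight matrix, the factored second-moment statistics are a
# 512-element row accumulator and a 1024-element column accumulator instead
# of a full 512 x 1024 accumulator.
import numpy as onp

def _example_factored_slots(shape=(512, 1024), do_momentum=True):
  v_row = onp.zeros(shape[:-1], dtype=onp.float32)              # shape (512,)
  v_col = onp.zeros(shape[:-2] + shape[-1:], dtype=onp.float32)  # shape (1024,)
  slots = [v_row, v_col]
  if do_momentum:
    slots.append(onp.zeros(shape, dtype=onp.float32))  # full-shape momentum
  return slots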
def DotProductAttention(query, key, value, mask, dropout, mode, rng):
  """Core dot product self-attention.

  Args:
    query: array of representations
    key: array of representations
    value: array of representations
    mask: attention-mask, gates attention
    dropout: float: dropout rate
    mode: 'eval' or 'train': whether to use dropout
    rng: JAX PRNGKey: subkey for disposable use

  Returns:
    Self attention for q, k, v arrays.
  """
  depth = np.shape(query)[-1]
  dots = np.matmul(query, np.swapaxes(key, -1, -2)) / np.sqrt(depth)
  if mask is not None:
    # TODO(kitaev): workaround for https://github.com/google/jax/issues/850
    # We must ensure that both mask and the -1e9 constant have a data
    # dependency on the input. Broadcasted copies of these use a lot of
    # memory, so they should be computed at runtime (rather than being global
    # constants).
    if backend.get_name() == 'jax':
      mask = jax.lax.tie_in(dots, mask)
    # JAX's `full_like` already ties in -1e9 to dots.
    dots = np.where(mask, dots, np.full_like(dots, -1e9))
  # Softmax.
  dots = np.exp(dots - backend.logsumexp(dots, axis=-1, keepdims=True))
  if dropout is not None and dropout >= 1.0:
    raise ValueError('Dropout rates must be lower than 1.')
  if dropout is not None and dropout > 0.0 and mode == 'train':
    keep = backend.random.bernoulli(rng, 1.0 - dropout, dots.shape)
    dots = np.where(keep, dots / (1.0 - dropout), np.zeros_like(dots))
  out = np.matmul(dots, value)
  return out
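# A minimal plain-numpy reference for the computation above (dropout and the
# JAX tie_in workaround omitted); shapes and names here are illustrative only,
# not the library's API.
import numpy as onp

def _reference_dot_product_attention(query, key, value, mask=None):
  depth = query.shape[-1]
  dots = onp.matmul(query, onp.swapaxes(key, -1, -2)) / onp.sqrt(depth)
  if mask is not None:
    dots = onp.where(mask, dots, -1e9)  # block disallowed positions
  dots = onp.exp(dots - onp.max(dots, axis=-1, keepdims=True))
  dots /= onp.sum(dots, axis=-1, keepdims=True)  # softmax over key positions
  return onp.matmul(dots, value)

# Example: batched causal attention over 4 positions with depth 8.
q = k = v = onp.ones((2, 4, 8), dtype=onp.float32)
causal = onp.tril(onp.ones((4, 4), dtype=bool))  # True where attention is allowed
out = _reference_dot_product_attention(q, k, v, mask=causal)
assert out.shape == (2, 4, 8)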
def _update_diagonal(self, grads, params, m, v, opt_params):
  learning_rate = opt_params['learning_rate']
  momentum = opt_params['momentum']
  v[0] += grads * grads
  preconditioner = np.where(v[0] > 0, 1.0 / np.sqrt(v[0]),
                            np.zeros_like(v[0]))
  preconditioned_grads = preconditioner * grads
  m = (1 - momentum) * preconditioned_grads + momentum * m
  params = params - (learning_rate * m).astype(params.dtype)
  return params, (m, v)
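# A tiny standalone numeric sketch of the diagonal update above (hypothetical
# names): the squared-gradient accumulator gives an AdaGrad-style step of
# roughly learning_rate * grad / sqrt(sum of squared grads), smoothed by
# momentum.
import numpy as onp

def _example_diagonal_step(params, grads, m, v, learning_rate=0.1,
                           momentum=0.9):
  v = v + grads * grads
  preconditioner = onp.where(v > 0, 1.0 / onp.sqrt(v), 0.0)
  m = (1 - momentum) * (preconditioner * grads) + momentum * m
  return params - learning_rate * m, m, v

p, m, v = onp.array([1.0, 2.0]), onp.zeros(2), onp.zeros(2)
p, m, v = _example_diagonal_step(p, onp.array([0.5, -0.25]), m, v)
# On the first step the preconditioner is 1/|grad|, so the preconditioned
# gradient is sign(grad) and p becomes [0.99, 2.01].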
def forward_and_backward(self, inputs, ct, state=base.EMPTY_STATE,
                         new_state=base.EMPTY_STATE, rng=None, **kwargs):
  del kwargs
  output, _, (qk_ct, v_ct) = self.batch_call_and_or_grad(
      inputs[0], inputs[2], ct=ct, new_state=new_state, rng=rng)
  return output, (qk_ct, np.zeros_like(inputs[1]), v_ct)
def forward_with_state(self, x, weights=base.EMPTY_WEIGHTS,
                       state=base.EMPTY_STATE, rng=None, **kwargs):
  """Execute dropout."""
  del kwargs
  if self._mode != 'train':
    return x, state
  rate = self._initial_rate
  if isinstance(state, dict) and self._name in state:
    rate = state[self._name]
  if rng is None:
    msg = ('Dropout layer requires apply_fn to be called with a rng keyword '
           'argument. That is, instead of `Dropout(weights, inputs)`, call '
           'it like `Dropout(weights, inputs, rng=key)`.')
    raise ValueError(msg)
  keep = backend.random.bernoulli(rng, 1.0 - rate, x.shape)
  return np.where(keep, x / (1.0 - rate), np.zeros_like(x)), state
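# A minimal sketch (plain numpy, hypothetical names) of the inverted-dropout
# scaling used above: kept activations are divided by the keep probability so
# the expected value of the output matches the input.
import numpy as onp

def _example_inverted_dropout(x, rate, rng):
  keep = rng.random(x.shape) >= rate
  return onp.where(keep, x / (1.0 - rate), 0.0)

rng = onp.random.default_rng(0)
x = onp.ones((10000,), dtype=onp.float32)
y = _example_inverted_dropout(x, rate=0.3, rng=rng)
# mean(y) stays close to 1.0 because survivors are scaled by 1 / 0.7.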
def backward(self, inputs, output, ct, weights=base.EMPTY_WEIGHTS,
             state=base.EMPTY_STATE, new_state=base.EMPTY_STATE, rng=None,
             **kwargs):
  del output, weights, state
  _, _, (qk_ct, v_ct) = self.batch_call_and_or_grad(
      inputs[0], inputs[2], return_output=False, ct=ct, new_state=new_state,
      rng=rng)
  inputs_ct = (qk_ct, np.zeros_like(inputs[1]), v_ct)
  return inputs_ct, ()
def _update_sketched(self, grads, params, m, v, opt_params):
  """Update for higher-rank parameters."""
  learning_rate = opt_params['learning_rate']
  momentum = opt_params['momentum']
  shape = params.shape
  rank = len(shape)
  reshaped_accumulators = [np.reshape(v[i], self._expanded_shape(shape, i))
                           for i in range(rank)]
  current_accumulator = self._minimum(reshaped_accumulators)
  current_accumulator += grads * grads
  accumulator_inv_sqrt = np.where(current_accumulator > 0.0,
                                  1.0 / np.sqrt(current_accumulator),
                                  np.zeros_like(current_accumulator))
  preconditioned_gradient = grads * accumulator_inv_sqrt
  m = (1.0 - momentum) * preconditioned_gradient + momentum * m
  params = params - (learning_rate * m).astype(params.dtype)
  for i in range(len(v)):
    axes = list(range(int(i))) + list(range(int(i) + 1, rank))
    dim_accumulator = np.amax(current_accumulator, axis=axes)
    v[i] = dim_accumulator
  return params, (m, v)
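# A small standalone sketch (illustrative names) of the sketched accumulator
# above for a rank-2 parameter: one accumulator per dimension, combined by
# broadcasting and taking the elementwise minimum, then summarized back per
# dimension with a max.
import numpy as onp

def _example_sketched_accumulator(grads, v_rows, v_cols):
  # v_rows has shape (n_rows,), v_cols has shape (n_cols,).
  current = onp.minimum(v_rows[:, None], v_cols[None, :]) + grads * grads
  # Each per-dimension accumulator keeps the max over the other dimension.
  new_v_rows = onp.amax(current, axis=1)
  new_v_cols = onp.amax(current, axis=0)
  return current, new_v_rows, new_v_cols

g = onp.array([[1.0, 2.0], [3.0, 4.0]])
current, vr, vc = _example_sketched_accumulator(g, onp.zeros(2), onp.zeros(2))
# current == g**2; vr == [4., 16.]; vc == [9., 16.]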
def forward_and_backward(self, inputs, ct, state=base.EMPTY_STATE,
                         new_state=base.EMPTY_STATE, rng=None, **kwargs):
  del state, new_state, kwargs
  query, key, value = inputs
  depth = np.shape(query)[-1]
  do_backprop = ct is not None
  # jax uses the term cotangent (ct) to refer to gradient signals, and
  # vector-Jacobian product (vjp) for back-propagation through a layer.

  def make_mask(N, M, k):  # pylint: disable=invalid-name
    """Constructs a slice of the causal attention mask.

    Args:
      N: number of query positions
      M: number of key positions
      k: position of the initial query element

    Returns:
      N x M mask, where 1.0 indicates that attention is not allowed.
    """
    x = jax.lax.tie_in(k, np.arange(N, dtype=np.int32))
    y = jax.lax.tie_in(k, np.arange(M, dtype=np.int32))
    mask = jax.lax.lt(
        (jax.lax.broadcast_in_dim(
            x, shape=(N, M), broadcast_dimensions=(0,)) + k),
        jax.lax.broadcast(y, [N]))
    mask = jax.lax.convert_element_type(mask, np.float32)
    return mask

  def make_self_mask(N, M, k):  # pylint: disable=invalid-name
    """Masks out elements attending to self.

    Args:
      N: number of query positions
      M: number of key positions
      k: position of the initial query element

    Returns:
      N x M mask, where 1.0 indicates that attention is not allowed.
    """
    x = jax.lax.tie_in(k, np.arange(N, dtype=np.int32))
    y = jax.lax.tie_in(k, np.arange(M, dtype=np.int32))
    mask = jax.lax.eq(
        (jax.lax.broadcast_in_dim(
            x, shape=(N, M), broadcast_dimensions=(0,)) + k),
        jax.lax.broadcast(y, [N]))
    mask = jax.lax.convert_element_type(mask, np.float32)
    return mask

  def forward_slice(query_slice, q_loop_idx, key, value):  # pylint: disable=invalid-name
    """Forward pass for a subset of the query vectors."""
    if self._share_qk:
      key = self.make_unit_length(key)

    dots = np.matmul(
        query_slice, np.swapaxes(key, -1, -2)) / np.sqrt(depth)

    # Causal masking.
    mask = make_mask(dots.shape[-2], dots.shape[-1], q_loop_idx)
    dots = dots - 1e9 * mask

    # Mask out attention to self except when no other targets are available.
    if self._share_qk:
      self_mask = make_self_mask(dots.shape[-2], dots.shape[-1], q_loop_idx)
      dots = dots - 1e5 * self_mask

    # Softmax.
    dots = np.exp(dots - backend.logsumexp(dots, axis=-1, keepdims=True))

    if self.dropout is not None and self.dropout > 0.0:
      # Dropout is broadcast across the batch+head dimension.
      dropout_shape = (1, dots.shape[-2], dots.shape[-1])
      slice_rng = jax.random.fold_in(rng, q_loop_idx)
      keep_prob = jax.lax.tie_in(dots, 1.0 - self.dropout)
      keep = backend.random.bernoulli(slice_rng, keep_prob, dropout_shape)
      multiplier = keep.astype(dots.dtype) / jax.lax.tie_in(keep, keep_prob)
      dots = dots * multiplier

    if self._hard_k > 0:
      top_k = np.sort(dots)[..., -self._hard_k]  # Get the top-kth weight.
      top_k = jax.lax.stop_gradient(top_k)
      dots -= top_k[..., np.newaxis]  # Lower weights become non-positive.
      dots = np.maximum(dots, 0)  # Zero out everything below the top k.
      dots_sum = np.sum(dots, axis=-1, keepdims=True)
      dots /= dots_sum  # Re-normalize to a distribution over the top k.

    out_slice = np.matmul(dots, value)
    return out_slice

  def forward_and_vjp_slice(query_slice, q_loop_idx, key, value, ct_slice):  # pylint: disable=invalid-name
    # Capture q_loop_idx to avoid calculating gradients wrt. it.
    def forward_slice_with_q_loop_idx(query_slice, key, value):  # pylint: disable=invalid-name
      return forward_slice(query_slice, q_loop_idx, key, value)

    output_slice, vjpfun = jax.vjp(
        forward_slice_with_q_loop_idx, query_slice, key, value)
    return output_slice, vjpfun(ct_slice)

  q_loop_idx = np.zeros((), dtype=np.int32)
  q_loop_max = query.shape[-2]
  q_loop_stride = self._loop_stride
  if q_loop_max == 1:  # For abstract runs with unknown shapes.
    q_loop_stride = 1
  assert q_loop_max % q_loop_stride == 0, (
      'Stride must evenly divide the number of query elements.')

  out_accum = np.zeros_like(query)
  if do_backprop:
    query_ct_accum = np.zeros_like(query)
    key_ct_accum = np.zeros_like(key)
    value_ct_accum = np.zeros_like(value)
    init_vals = (q_loop_idx, out_accum,
                 query_ct_accum, key_ct_accum, value_ct_accum)
  else:
    init_vals = (q_loop_idx, out_accum)

  def cond_fun(vals):  # pylint: disable=invalid-name
    q_loop_idx = vals[0]
    return jax.lax.lt(q_loop_idx, q_loop_max)

  def body_fun(vals):  # pylint: disable=invalid-name
    """Compute a slice of the attention mechanism."""
    if do_backprop:
      (q_loop_idx, out_accum,
       query_ct_accum, key_ct_accum, value_ct_accum) = vals
    else:
      q_loop_idx, out_accum = vals

    query_slice = jax.lax.dynamic_slice_in_dim(
        query, q_loop_idx, q_loop_stride, axis=-2)

    if do_backprop:
      ct_slice = jax.lax.dynamic_slice_in_dim(
          ct, q_loop_idx, q_loop_stride, axis=-2)
      out_slice, partial_ct = forward_and_vjp_slice(
          query_slice, q_loop_idx, key, value, ct_slice)
      query_ct_accum = jax.lax.dynamic_update_slice_in_dim(
          query_ct_accum, partial_ct[0], q_loop_idx, axis=-2)
      key_ct_accum = key_ct_accum + partial_ct[1]
      value_ct_accum = value_ct_accum + partial_ct[2]
    else:
      out_slice = forward_slice(query_slice, q_loop_idx, key, value)

    out_accum = jax.lax.dynamic_update_slice_in_dim(
        out_accum, out_slice, q_loop_idx, axis=-2)
    q_loop_idx = q_loop_idx + q_loop_stride

    if do_backprop:
      return (q_loop_idx, out_accum,
              query_ct_accum, key_ct_accum, value_ct_accum)
    else:
      return (q_loop_idx, out_accum)

  final_vals = jax.lax.while_loop(cond_fun, body_fun, init_vals)

  if not do_backprop:
    return final_vals[1], None
  else:
    return final_vals[1], final_vals[2:]
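# An illustrative plain-numpy sketch of the query-chunking strategy above
# (hypothetical names, no jax.lax while_loop): attention is computed for a
# stride of query positions at a time, with the causal mask offset by the
# chunk's starting index so each query only sees keys at or before it.
import numpy as onp

def _example_chunked_causal_attention(q, k, v, stride):
  n = q.shape[0]
  out = onp.zeros_like(q)
  for start in range(0, n, stride):
    q_slice = q[start:start + stride]
    dots = q_slice @ k.T / onp.sqrt(q.shape[-1])
    # Query position (start + i) may only attend to key positions <= start + i.
    rows = onp.arange(stride)[:, None] + start
    cols = onp.arange(n)[None, :]
    dots = onp.where(cols <= rows, dots, -1e9)
    dots = onp.exp(dots - onp.max(dots, axis=-1, keepdims=True))
    dots /= onp.sum(dots, axis=-1, keepdims=True)
    out[start:start + stride] = dots @ v
  return out

q = k = v = onp.ones((8, 4), dtype=onp.float32)
assert _example_chunked_causal_attention(q, k, v, stride=2).shape == (8, 4)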
def drop_for_hash(self, x, rng):
  rate = self._drop_for_hash_rate
  if self._mode == 'train' and rate > 0.0:
    keep = backend.random.bernoulli(rng, 1.0 - rate, x.shape)
    return np.where(keep, x / (1.0 - rate), np.zeros_like(x))
  return x
def batch_call_and_or_grad(self, qk, v, ct=None, return_output=True,
                           new_state=None, return_state=False, rng=None):
  assert return_output or ct is not None, 'No work to perform!'
  if new_state is not None and new_state is not base.EMPTY_STATE:
    buckets = new_state
  else:
    buckets = None

  # The approach here is to perform attention for one batch element and head
  # at a time. Note that there is absolutely no interaction across examples or
  # heads: this layer has no parameters, and hashing patterns are also
  # different across examples/heads. As a result, batching doesn't give any
  # performance gains except in the case of accelerator under-utilization. We
  # assume that hash-based attention will be applied primarily to long
  # sequences, where unbatched attention for a single head has sufficient
  # computation to fill up the accelerator.

  batch_loop_idx = np.zeros((), dtype=np.int32)
  batch_loop_max = qk.shape[0]

  init_vals = (batch_loop_idx,)
  if return_output:
    out_accum = np.zeros_like(qk)
    init_vals = init_vals + (out_accum,)
  if return_state:
    buckets_accum = np.zeros(
        [qk.shape[0], self.n_hashes * qk.shape[1]], dtype=np.int32)
    init_vals = init_vals + (buckets_accum,)
  if ct is not None:
    qk_ct_accum = np.zeros_like(qk)
    v_ct_accum = np.zeros_like(v)
    init_vals = init_vals + (qk_ct_accum, v_ct_accum)

  def cond_fun(vals):
    batch_loop_idx = vals[0]
    return jax.lax.lt(batch_loop_idx, batch_loop_max)

  def body_fun(vals):
    """Performs attention for a single batch element and head."""
    batch_loop_idx = vals[0]
    if self._prng is None:
      hash_slice_rng = jax.random.fold_in(rng, batch_loop_idx)
      hash_rng, slice_rng = backend.random.split(hash_slice_rng)
    else:
      # TODO(kitaev): Maybe use the same RNG across examples (but not heads)?
      hash_rng = jax.random.fold_in(self._prng, batch_loop_idx)
      slice_rng = jax.random.fold_in(rng, batch_loop_idx)
    qk_slice = jax.lax.dynamic_index_in_dim(
        qk, batch_loop_idx, axis=0, keepdims=False)
    v_slice = jax.lax.dynamic_index_in_dim(
        v, batch_loop_idx, axis=0, keepdims=False)

    if buckets is None:
      buckets_slice = self.hash_vectors(qk_slice, rng=hash_rng)
    else:
      buckets_slice = jax.lax.dynamic_index_in_dim(
          buckets, batch_loop_idx, axis=0, keepdims=False)

    if ct is None:
      out_slice = self.single_call(
          qk_slice, v_slice, buckets_slice, rng=slice_rng)
    else:
      def _do_single_call(qk_slice, v_slice):
        return self.single_call(
            qk_slice, v_slice, buckets_slice, rng=slice_rng)
      ct_slice = jax.lax.dynamic_index_in_dim(
          ct, batch_loop_idx, axis=0, keepdims=False)
      out_slice, vjpfun = jax.vjp(_do_single_call, qk_slice, v_slice)
      qk_ct_slice, v_ct_slice = vjpfun(ct_slice)

    new_vals = (batch_loop_idx + 1,)
    if return_output:
      out_accum = vals[1]
      out_accum = jax.lax.dynamic_update_index_in_dim(
          out_accum, out_slice, batch_loop_idx, axis=0)
      new_vals = new_vals + (out_accum,)
    if return_state:
      buckets_accum = vals[2]
      buckets_accum = jax.lax.dynamic_update_index_in_dim(
          buckets_accum, buckets_slice, batch_loop_idx, axis=0)
      new_vals = new_vals + (buckets_accum,)
    if ct is not None:
      qk_ct_accum, v_ct_accum = vals[-2:]
      qk_ct_accum = jax.lax.dynamic_update_index_in_dim(
          qk_ct_accum, qk_ct_slice, batch_loop_idx, axis=0)
      v_ct_accum = jax.lax.dynamic_update_index_in_dim(
          v_ct_accum, v_ct_slice, batch_loop_idx, axis=0)
      new_vals = new_vals + (qk_ct_accum, v_ct_accum)

    return new_vals

  final_vals = jax.lax.while_loop(cond_fun, body_fun, init_vals)

  if return_output:
    out = final_vals[1]
  else:
    out = None

  if return_state:
    state = final_vals[2]
  else:
    state = None

  if ct is not None:
    input_ct = final_vals[-2:]
  else:
    input_ct = None

  return out, state, input_ct
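# An illustrative Python-level sketch (hypothetical names, no jax.lax loop) of
# the per-example dispatch above: each batch element is processed on its own
# and results are written back into a full-batch accumulator, mirroring how
# the while_loop updates out_accum one index at a time.
import numpy as onp

def _example_per_example_loop(qk, v, single_call):
  out = onp.zeros_like(qk)
  for i in range(qk.shape[0]):
    out[i] = single_call(qk[i], v[i])
  return out

qk = onp.ones((3, 16, 8), dtype=onp.float32)
v = onp.ones((3, 16, 8), dtype=onp.float32)
out = _example_per_example_loop(qk, v, lambda q_i, v_i: q_i + v_i)
assert out.shape == (3, 16, 8)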
def ParametricRelu(x, a=1., **unused_kwargs):
  return np.maximum(a * x, np.zeros_like(x))
def Relu(x, **unused_kwargs):
  return np.maximum(x, np.zeros_like(x))
def init(self, params):
  vs = [np.zeros(sz, dtype=params.dtype) for sz in params.shape]
  return (np.zeros_like(params), vs)
def init(self, params):
  m = np.zeros_like(params)
  v = np.zeros_like(params)
  return m, v
def init(self, params):
  return np.zeros_like(params)