def DotProductAttention(query, key, value, mask, dropout, mode, rng):
  """Core dot product self-attention.

  Args:
    query: array of representations
    key: array of representations
    value: array of representations
    mask: attention-mask, gates attention
    dropout: float: dropout rate
    mode: 'eval' or 'train': whether to use dropout
    rng: JAX PRNGKey: subkey for disposable use

  Returns:
    Self attention for q, k, v arrays.
  """
  depth = np.shape(query)[-1]
  # Scale the logits by sqrt(depth) to keep the softmax well-conditioned.
  dots = np.matmul(query, np.swapaxes(key, -1, -2)) / np.sqrt(depth)
  if mask is not None:
    # TODO(kitaev): workaround for https://github.com/google/jax/issues/850
    # We must ensure that both mask and the -1e9 constant have a data
    # dependency on the input. Broadcasted copies of these use a lot of
    # memory, so they should be computed at runtime (rather than being
    # global constants).
    if backend.get_name() == 'jax':
      mask = jax.lax.tie_in(dots, mask)
    # JAX's `full_like` already ties in -1e9 to dots.
    dots = np.where(mask, dots, np.full_like(dots, -1e9))
  # Softmax.
  dots = np.exp(dots - backend.logsumexp(dots, axis=-1, keepdims=True))
  if dropout is not None and dropout >= 1.0:  # guard the None case first
    raise ValueError('Dropout rates must be lower than 1.')
  if dropout is not None and dropout > 0.0 and mode == 'train':
    keep = backend.random.bernoulli(rng, 1.0 - dropout, dots.shape)
    # Inverted dropout: rescale the kept weights so expectations match eval.
    dots = np.where(keep, dots / (1.0 - dropout), np.zeros_like(dots))
  out = np.matmul(dots, value)
  return out
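# A minimal usage sketch (an assumption, not part of the original source: it
# presumes `np` in this module is JAX's numpy backend, so raw JAX arrays can
# be passed straight in; shapes are illustrative). With mask=None and
# dropout=0.0 in 'eval' mode, this is plain scaled dot-product attention.
import jax

example_rng = jax.random.PRNGKey(0)
q = jax.random.normal(example_rng, (2, 8, 16))  # (batch, length, depth)
out = DotProductAttention(q, q, q, mask=None, dropout=0.0, mode='eval',
                          rng=example_rng)
assert out.shape == (2, 8, 16)  # one attention-weighted value per query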
def _update_diagonal(self, grads, params, m, v, opt_params):
  """Update with a diagonal (per-coordinate) gradient accumulator."""
  learning_rate = opt_params['learning_rate']
  momentum = opt_params['momentum']
  # Accumulate the second moment of the gradients, AdaGrad-style.
  v[0] += grads * grads
  # Precondition by the inverse square root of the accumulator, guarding
  # against division by zero where nothing has accumulated yet.
  preconditioner = np.where(v[0] > 0, 1.0 / np.sqrt(v[0]),
                            np.zeros_like(v[0]))
  preconditioned_grads = preconditioner * grads
  m = (1 - momentum) * preconditioned_grads + momentum * m
  params = params - (learning_rate * m).astype(params.dtype)
  return params, (m, v)
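# Standalone sketch of the preconditioning step above in plain jax.numpy
# (names here are illustrative, not part of the optimizer class). Dividing
# by the root of the running sum of squared gradients shrinks steps along
# coordinates that have already seen large gradients.
import jax.numpy as jnp

def diagonal_precondition(grads, accumulator):
  accumulator = accumulator + grads * grads
  inv_sqrt = jnp.where(accumulator > 0, 1.0 / jnp.sqrt(accumulator), 0.0)
  return grads * inv_sqrt, accumulator

g = jnp.array([3.0, 0.1])
pg, acc = diagonal_precondition(g, jnp.zeros(2))
# pg == [1., 1.]: each coordinate is normalized by its own gradient history.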
def _forward_predict(self, inputs, state, rng):
  if not self._share_qk:
    state = _fast_inference_update_state(inputs, state)
    (q, _, _) = inputs
    (ks, vs, mask, index) = state
  else:
    mask_excluding_attention_in_place = state[2]
    (q, _, v) = inputs
    k = self.make_unit_length(q)
    state = _fast_inference_update_state((q, k, v), state)
    (ks, vs, mask, index) = state
    # Only the initial position in a sequence may attend to itself.
    mask = np.where(index > 1, mask_excluding_attention_in_place, mask)

  output = attention.DotProductAttention(
      q, ks, vs, mask, dropout=self.dropout, mode=self._mode, rng=rng)

  def roll_state(state):
    """Rolls the buffers backward to make space for new data."""
    (ks, vs, mask, index) = state
    # Move the second bin into the first one's place in both buffers.
    def roll_buffer(buf):
      return jax.ops.index_update(
          buf,
          jax.ops.index[:, :self.bin_length, :],
          buf[:, self.bin_length:, :],
      )
    (ks, vs) = map(roll_buffer, (ks, vs))
    # Zero out the second bin in the mask.
    mask = jax.ops.index_update(
        mask, jax.ops.index[:, :, self.bin_length:], 0)
    # Update the index to match the rolled buffers.
    index -= self.bin_length
    return (ks, vs, mask, index)

  # Once we get to the end of the buffer, move the second bin back to make
  # space for new data: [ bin_i bin_{i+1} | ] -> [ bin_{i+1} | bin_{i+1} ],
  # where | is where index points at in the buffer.
  state = jax.lax.cond(
      pred=(index == 2 * self.bin_length),
      true_operand=state,
      true_fun=roll_state,
      false_operand=state,
      false_fun=(lambda x: x),
  )

  return (output, state)
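# Illustrative sketch of the roll above with bin_length == 2 (shapes are
# hypothetical; it uses the same jax.ops API as the code above). The second
# bin is copied over the first, so the upper half of the buffer is free for
# the next bin of keys and values.
import jax
import jax.numpy as jnp

buf = jnp.arange(8.0).reshape(1, 4, 2)  # (batch, 2 * bin_length, depth)
rolled = jax.ops.index_update(buf, jax.ops.index[:, :2, :], buf[:, 2:, :])
# rolled[0, :2] now equals buf[0, 2:]; rolled[0, 2:] holds stale data that
# the decremented index marks as free to overwrite.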
def forward_with_state(self, x, weights=base.EMPTY_WEIGHTS,
                       state=base.EMPTY_STATE, rng=None, **kwargs):
  """Execute dropout."""
  del kwargs
  if self._mode != 'train':
    return x, state
  rate = self._initial_rate
  if isinstance(state, dict) and self._name in state:
    rate = state[self._name]
  if rng is None:
    msg = ('Dropout layer requires apply_fn to be called with a rng keyword '
           'argument. That is, instead of `Dropout(weights, inputs)`, call '
           'it like `Dropout(weights, inputs, rng=key)`.')
    raise ValueError(msg)
  keep = backend.random.bernoulli(rng, 1.0 - rate, x.shape)
  # Inverted dropout: scale surviving activations by 1 / (1 - rate).
  return np.where(keep, x / (1.0 - rate), np.zeros_like(x)), state
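# Standalone sketch of the inverted-dropout rule used above, in plain JAX
# (the function name is illustrative). Kept units are scaled by
# 1 / (1 - rate) so the expected activation matches eval mode.
import jax
import jax.numpy as jnp

def inverted_dropout(x, rate, rng):
  keep = jax.random.bernoulli(rng, 1.0 - rate, x.shape)
  return jnp.where(keep, x / (1.0 - rate), 0.0)

x = jnp.ones((4, 4))
y = inverted_dropout(x, 0.5, jax.random.PRNGKey(0))
# Surviving entries equal 2.0, so E[y] == x elementwise.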
def _update_sketched(self, grads, params, m, v, opt_params):
  """Update for higher-rank parameters."""
  learning_rate = opt_params['learning_rate']
  momentum = opt_params['momentum']
  shape = params.shape
  rank = len(shape)
  # Broadcast each per-dimension accumulator back to the parameter's shape
  # and take the elementwise minimum: the covering approximation of the
  # full second-moment accumulator.
  reshaped_accumulators = [np.reshape(v[i], self._expanded_shape(shape, i))
                           for i in range(rank)]
  current_accumulator = self._minimum(reshaped_accumulators)
  current_accumulator += grads * grads
  accumulator_inv_sqrt = np.where(current_accumulator > 0.0,
                                  1.0 / np.sqrt(current_accumulator),
                                  np.zeros_like(current_accumulator))
  preconditioned_gradient = grads * accumulator_inv_sqrt
  m = (1.0 - momentum) * preconditioned_gradient + momentum * m
  params = params - (learning_rate * m).astype(params.dtype)
  # Fold the updated accumulator back into one vector per dimension by
  # taking the max over all other axes.
  for i in range(len(v)):
    axes = list(range(int(i))) + list(range(int(i) + 1, rank))
    dim_accumulator = np.amax(current_accumulator, axis=axes)
    v[i] = dim_accumulator
  return params, (m, v)
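# Sketch of the covering approximation above for a rank-2 parameter, in
# plain jax.numpy (all names illustrative). One accumulator is kept per row
# and one per column; their broadcast minimum over-approximates the full
# per-entry sum of squared gradients while storing only O(rows + cols)
# numbers instead of O(rows * cols).
import jax.numpy as jnp

g = jnp.array([[1.0, 2.0], [3.0, 4.0]])
v_row = jnp.zeros((2, 1))  # accumulator for axis 0, expanded to shape (2, 1)
v_col = jnp.zeros((1, 2))  # accumulator for axis 1, expanded to shape (1, 2)
current = jnp.minimum(v_row, v_col) + g * g
v_row = jnp.amax(current, axis=1, keepdims=True)  # max over columns
v_col = jnp.amax(current, axis=0, keepdims=True)  # max over rows
# On later steps jnp.minimum(v_row, v_col) >= the true accumulated g * g at
# every entry, so the preconditioner 1 / sqrt(current) never overshoots.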
def drop_for_hash(self, x, rng):
  """Apply dropout to the vectors used for hashing, in train mode only."""
  rate = self._drop_for_hash_rate
  if self._mode == 'train' and rate > 0.0:
    keep = backend.random.bernoulli(rng, 1.0 - rate, x.shape)
    return np.where(keep, x / (1.0 - rate), np.zeros_like(x))
  return x
def Selu(x,
         alpha=1.6732632423543772848170429916717,
         lmbda=1.0507009873554804934193349852946):
  """Scaled exponential linear unit activation."""
  return lmbda * np.where(x > 0, x, alpha * np.expm1(x))


def Elu(x, a=1., **unused_kwargs):
  """Exponential linear unit: x if positive, else a * expm1(x)."""
  return np.where(x > 0, x, a * np.expm1(x))


def LeakyRelu(x, a=0.01, **unused_kwargs):
  """Leaky ReLU: x if nonnegative, else a * x."""
  return np.where(x >= 0, x, a * x)
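# Quick numeric check of the three activations above (assumes `np` in this
# module is jax.numpy, so the functions accept JAX arrays directly).
import jax.numpy as jnp

x = jnp.array([-2.0, 0.0, 2.0])
print(Selu(x))       # ~ [-1.52  0.    2.10]
print(Elu(x))        # ~ [-0.86  0.    2.  ]
print(LeakyRelu(x))  # == [-0.02  0.    2.  ]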
def clip_grads(grad_tree, max_norm):
  """Clip gradients stored as a pytree of arrays to maximum norm `max_norm`."""
  norm = l2_norm(grad_tree)
  normalize = lambda g: np.where(norm < max_norm, g, g * (max_norm / norm))
  return layers.nested_map(grad_tree, normalize)
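# Usage sketch (values illustrative), assuming `l2_norm` above returns the
# global L2 norm across all leaves of the pytree: trees whose joint norm
# exceeds `max_norm` are rescaled onto the norm ball, others pass through
# unchanged.
import jax.numpy as jnp

grads = [jnp.array([3.0, 4.0])]  # global norm == 5.0
clipped = clip_grads(grads, max_norm=1.0)
# clipped[0] ~ [0.6, 0.8]: every leaf is scaled by max_norm / norm == 0.2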