def DotProductAttention(query, key, value, mask, dropout, mode, rng):
    """Core dot product self-attention.

    Args:
      query: array of representations
      key: array of representations
      value: array of representations
      mask: attention-mask, gates attention; `None` disables masking
      dropout: float dropout rate, or `None` to disable dropout
      mode: 'eval' or 'train': whether to use dropout
      rng: JAX PRNGKey: subkey for disposable use

    Returns:
      Self attention for q, k, v arrays.

    Raises:
      ValueError: if `dropout` is 1.0 or greater.
    """
    depth = np.shape(query)[-1]
    dots = np.matmul(query, np.swapaxes(key, -1, -2)) / np.sqrt(depth)
    if mask is not None:
        # TODO(kitaev): workaround for https://github.com/google/jax/issues/850
        # We must ensure that both mask and the -1e9 constant have a data
        # dependency on the input. Broadcasted copies of these use a lot of
        # memory, so they should be computed at runtime (rather than being
        # global constants).
        if math.backend_name() == 'jax':
            mask = jax.lax.tie_in(dots, mask)
        # JAX's `full_like` already ties in -1e9 to dots.
        dots = np.where(mask, dots, np.full_like(dots, -1e9))
    # Softmax.
    dots = np.exp(dots - math.logsumexp(dots, axis=-1, keepdims=True))
    # `None` must be ruled out before any numeric comparison: in Python 3
    # `None >= 1.0` raises TypeError instead of evaluating to False.
    if dropout is not None:
        if dropout >= 1.0:
            raise ValueError('Dropout rates must be lower than 1.')
        if dropout > 0.0 and mode == 'train':
            keep = math.random.bernoulli(rng, 1.0 - dropout, dots.shape)
            dots = np.where(keep, dots / (1.0 - dropout), np.zeros_like(dots))
    out = np.matmul(dots, value)
    return out
def PPOObjective(dist_inputs, values, returns, dones, rewards,
                 actions, old_log_probs, log_prob_fun, epsilon,
                 normalize_advantages):
    """Computes the element-wise PPO objective.

    The result is the element-wise minimum of the unclipped and the clipped
    surrogate objectives, both built from the probability ratio and the
    advantages `returns - values`.

    Args:
      dist_inputs: Inputs to the action distribution; forwarded to `ProbsRatio`.
      values: Value predictions with a trailing axis that is squeezed away.
      returns: Returns with a trailing axis that is squeezed away.
      dones: Episode-termination flags with a trailing axis that is squeezed.
      rewards: Rewards with a trailing axis that is squeezed away.
      actions: Actions taken; forwarded to `ProbsRatio`.
      old_log_probs: Log-probabilities of `actions` under the old policy.
      log_prob_fun: Function forwarded to `ProbsRatio`.
      epsilon: Clipping parameter forwarded to `ClippedObjective`.
      normalize_advantages: If true, advantages are centered on their mean and
        divided by `std + 1e-8`.

    Returns:
      Array with the same shape as the advantages.
    """
    # Example shapes seen by this function:
    # dist_inputs of the shape float32[128,1,18]
    # values of the shape float32[128,1,1]
    # returns of the shape float32[128,1,1]
    # dones of the shape float32[128,1,1]
    # rewards of the shape int32[128,1,1]
    # actions of the shape int32[128,1]
    # and old_log_probs of the shape float32[128,1]
    returns = returns.squeeze(axis=2)
    values = values.squeeze(axis=2)
    dones = dones.squeeze(axis=2)
    rewards = rewards.squeeze(axis=2)
    assert rewards.shape == dones.shape, (
        f'rewards.shape was {rewards.shape} and dones.shape was {dones.shape}')
    assert dones.shape == values.shape, (
        f'dones.shape was {dones.shape} and values.shape was {values.shape}')
    assert returns.shape == values.shape, (
        f'returns.shape was {returns.shape} and values.shape was '
        f'{values.shape}')
    assert returns.shape == old_log_probs.shape, (
        f'returns.shape was {returns.shape} and '
        f'old_log_probs.shape was {old_log_probs.shape}')

    probs_ratio = ProbsRatio(dist_inputs, actions, old_log_probs, log_prob_fun)
    assert probs_ratio.shape == old_log_probs.shape, (
        f'probs_ratio.shape was {probs_ratio.shape} and '
        f'old_log_probs.shape was {old_log_probs.shape}')

    # jaxified versions of
    # returns[dones] = rewards[dones]
    # values[dones] = 0
    returns = jnp.where(dones, rewards, returns)
    values = jnp.where(dones, jnp.zeros_like(values), values)
    advantages = returns - values
    if normalize_advantages:
        advantages = advantages - jnp.mean(advantages)
        advantages /= jnp.std(advantages) + 1e-8
    assert old_log_probs.shape == advantages.shape, (
        f'old_log_probs.shape was {old_log_probs.shape} and '
        f'advantages.shape was {advantages.shape}')

    unclipped_objective = UnclippedObjective(probs_ratio, advantages)
    # Report the two shapes that are actually being compared.
    assert unclipped_objective.shape == advantages.shape, (
        f'unclipped_objective.shape was {unclipped_objective.shape} and '
        f'advantages.shape was {advantages.shape}')

    clipped_objective = ClippedObjective(probs_ratio, advantages, epsilon)
    assert clipped_objective.shape == advantages.shape, (
        f'clipped_objective.shape was {clipped_objective.shape} and '
        f'advantages.shape was {advantages.shape}')

    ppo_objective = jnp.minimum(unclipped_objective, clipped_objective)
    assert ppo_objective.shape == advantages.shape, (
        f'ppo_objective.shape was {ppo_objective.shape} and '
        f'advantages.shape was {advantages.shape}')
    return ppo_objective
def A2CObjective(dist_inputs, values, returns, dones, rewards,
                 actions, mask, log_prob_fun, normalize_advantages):
    """Definition of the Advantage Actor Critic (A2C) loss.

    Args:
      dist_inputs: Inputs to the action distribution; forwarded to
        `NewLogProbs`.
      values: Value predictions with a trailing axis that is squeezed away.
      returns: Returns with a trailing axis that is squeezed away.
      dones: Episode-termination flags with a trailing axis that is squeezed.
      rewards: Rewards with a trailing axis that is squeezed away.
      actions: Actions taken; forwarded to `NewLogProbs`.
      mask: Per-element weights; must match the squeezed shape of `values`.
      log_prob_fun: Function forwarded to `NewLogProbs`.
      normalize_advantages: If true, advantages are centered on their mean and
        divided by `std + 1e-8`.

    Returns:
      The negated mask-weighted sum of `new_log_probs * advantages`, divided
      by the sum of `mask`.
    """
    # Example shapes seen by this function:
    # dist_inputs of the shape float32[128,1,18]
    # values of the shape float32[128,1,1]
    # returns of the shape float32[128,1,1]
    # dones of the shape int32[128,1,1]
    # actions of the shape int32[128,1]
    # and mask of the shape float32[128,1]
    #
    # Everything entering (returns - values) * new_log_probs * mask has to
    # share one shape, hence the trailing axis is dropped up front.
    values, returns, dones, rewards = (
        tensor.squeeze(axis=2) for tensor in (values, returns, dones, rewards))

    assert rewards.shape == dones.shape, (
        f'rewards.shape was {rewards.shape} '
        f'and dones.shape was {dones.shape}')
    assert dones.shape == values.shape, (
        f'dones.shape was {dones.shape} '
        f'and values.shape was {values.shape}')
    assert returns.shape == values.shape, (
        f'returns.shape was {returns.shape} '
        f'and values.shape was {values.shape}')
    assert values.shape == mask.shape, (
        f'values.shape was {values.shape} '
        f'and mask.shape was {mask.shape}')
    assert returns.shape[0] == dist_inputs.shape[0], (
        f'returns.shape[0] was {returns.shape[0]} '
        f'and dist_inputs.shape[0] was {dist_inputs.shape[0]}')

    new_log_probs = NewLogProbs(dist_inputs, actions, log_prob_fun)
    assert new_log_probs.shape == mask.shape, (
        f'new_log_probs.shape was {new_log_probs.shape} '
        f'and mask.shape was {mask.shape}')

    # jaxified versions of
    # returns[dones] = rewards[dones]
    # values[dones] = 0
    returns = jnp.where(dones, rewards, returns)
    values = jnp.where(dones, jnp.zeros_like(values), values)

    advantages = returns - values
    if normalize_advantages:
        centered = advantages - jnp.mean(advantages)
        advantages = centered / (jnp.std(centered) + 1e-8)
    assert new_log_probs.shape == advantages.shape, (
        f'new_log_probs.shape was {new_log_probs.shape} '
        f'and advantages.shape was {advantages.shape}')

    # The squeezes and assertions above exist partly to rule out accidental
    # [128,1] * [128,1,1] * [128] broadcasting in the product below.
    weighted_log_probs = new_log_probs * advantages * mask
    return -jnp.sum(weighted_log_probs) / jnp.sum(mask)
def DotProductAttention(queries, keys, values, mask, dropout, mode, rng):
    """Computes new activations via masked attention-weighted sum of values.

    This function is the core of the attention mechanism. It:
      - computes per-head attention weights from per-head `(queries, keys)`,
      - applies `mask` to screen out positions that come from padding tokens,
      - optionally applies dropout to attention weights, and
      - uses attention weights to combine per-head `values` vectors.

    Args:
      queries: Per-head activations representing attention queries.
      keys: Per-head activations representing attention keys.
      values: Per-head activations to be combined by computed attention
        weights.
      mask: Mask that distinguishes positions with real content vs. padding;
        `None` disables masking.
      dropout: Probabilistic rate for dropout applied to attention activations
        (based on query-key pairs) before dotting them with values; `None`
        disables dropout.
      mode: Either 'train' or 'eval'. Dropout applies only in 'train' mode.
      rng: Single-use random number generator (JAX PRNG key).

    Returns:
      Per-head activations resulting from masked per-head attention-weighted
      sum of per-head values.

    Raises:
      ValueError: if `dropout` is 1.0 or greater.
    """
    d_feature = queries.shape[-1]
    dots = jnp.matmul(queries, jnp.swapaxes(keys, -1, -2)) / jnp.sqrt(d_feature)
    if mask is not None:
        # TODO(kitaev): workaround for https://github.com/google/jax/issues/850
        # We must ensure that both mask and the -1e9 constant have a data
        # dependency on the input. Broadcasted copies of these use a lot of
        # memory, so they should be computed at runtime (rather than being
        # global constants).
        if math.backend_name() == 'jax':
            mask = jax.lax.tie_in(dots, mask)
        # JAX's `full_like` already ties in -1e9 to dots.
        dots = jnp.where(mask, dots, jnp.full_like(dots, -1e9))
    # Softmax.
    dots = jnp.exp(dots - math.logsumexp(dots, axis=-1, keepdims=True))
    # `None` must be ruled out before any numeric comparison: in Python 3
    # `None >= 1.0` raises TypeError instead of evaluating to False.
    if dropout is not None:
        if dropout >= 1.0:
            raise ValueError('Dropout rates must be lower than 1.')
        if dropout > 0.0 and mode == 'train':
            keep = math.random.bernoulli(rng, 1.0 - dropout, dots.shape)
            dots = jnp.where(keep, dots / (1.0 - dropout), jnp.zeros_like(dots))
    out = jnp.matmul(dots, values)
    return out
def Relu():
    r"""Returns a layer that computes the Rectified Linear Unit (ReLU) function.

    .. math::
        f(x) = \left\{ \begin{array}{cl}
            0 & \text{if}\ x \leq 0, \\
            x & \text{otherwise}.
        \end{array} \right.
    """
    def _relu(x):
        # Non-positive entries are replaced by zeros of matching shape/dtype.
        non_positive = x <= 0
        return np.where(non_positive, np.zeros_like(x), x)

    return Fn('Relu', _relu)
def _update_diagonal(self, grads, weights, m, v, opt_params): learning_rate = opt_params['learning_rate'] momentum = opt_params['momentum'] v[0] += grads * grads preconditioner = np.where(v[0] > 0, 1.0 / np.sqrt(v[0]), np.zeros_like(v[0])) preconditioned_grads = preconditioner * grads m = (1 - momentum) * preconditioned_grads + momentum * m weights = weights - (learning_rate * m).astype(weights.dtype) return weights, (m, v)
def forward(self, x, weights):
    """Execute dropout.

    Outside of 'train' mode the input is returned untouched. In 'train' mode
    a keep-mask is drawn with `math.random.bernoulli` using probability
    `1 - rate`; kept entries are scaled by `1 / (1 - rate)` and the rest are
    zeroed.

    Args:
      x: Input activations.
      weights: Unused by this method.

    Returns:
      `x` itself when not training, otherwise the dropped-out activations.
    """
    if self._mode != 'train':
        return x
    state, rng = self.state, self.rng
    rate = self._initial_rate
    # A rate stored in the state dict under this layer's name takes precedence.
    has_override = isinstance(state, dict) and self._name in state
    if has_override:
        rate = state[self._name]
    keep_prob = 1.0 - rate
    keep_mask = math.random.bernoulli(rng, keep_prob, x.shape)
    return jnp.where(keep_mask, x / keep_prob, jnp.zeros_like(x))
def LeakyRelu(a=0.01):
    r"""Returns a ReLU-like layer with linear nonzero outputs for negative inputs.

    .. math::
        f(x) = \left\{ \begin{array}{cl}
            ax & \text{if}\ x \leq 0, \\
            x  & \text{otherwise}.
        \end{array} \right.

    Args:
      a: Slope of line for negative inputs.
    """
    def _leaky_relu(x):
        # Non-negative inputs pass through; the rest are scaled by `a`.
        non_negative = x >= 0
        return np.where(non_negative, x, a * x)

    return Fn('LeakyRelu', _leaky_relu)
def forward_with_state(self, x, weights, state, rng):
    """Execute dropout.

    Args:
      x: Input activations.
      weights: Unused by this method.
      state: Layer state; when it is a dict holding this layer's name, the
        stored value is used as the dropout rate.
      rng: Random key passed to `math.random.bernoulli`; required in 'train'
        mode.

    Returns:
      Tuple `(output, state)`. Outside of 'train' mode `output` is `x`.

    Raises:
      ValueError: if called in 'train' mode with `rng=None`.
    """
    if self._mode != 'train':
        return x, state
    rate = self._initial_rate
    has_override = isinstance(state, dict) and self._name in state
    if has_override:
        rate = state[self._name]
    if rng is None:
        raise ValueError(
            'Dropout layer requires apply_fn to be called with a rng keyword '
            'argument. That is, instead of `Dropout(weights, inputs)`, call '
            'it like `Dropout(weights, inputs, rng=key)`.')
    keep_prob = 1.0 - rate
    keep_mask = math.random.bernoulli(rng, keep_prob, x.shape)
    dropped = jnp.where(keep_mask, x / keep_prob, jnp.zeros_like(x))
    return dropped, state
def Selu(alpha=1.6732632423543772848170429916717,
         lmbda=1.0507009873554804934193349852946):
    r"""Returns an `Elu`-like layer with an additional scaling/slope parameter.

    .. math::
        f(x) = \left\{ \begin{array}{cl}
            \lambda \cdot \alpha \cdot (e^x - 1) & \text{if}\ x \leq 0, \\
            \lambda \cdot x                      & \text{otherwise}.
        \end{array} \right.

    Args:
      alpha: Coefficient multiplying the exponential, for negative inputs.
      lmbda: Coefficient scaling the whole function.
    """
    def _selu(x):
        positive = x > 0
        # Feed zeros to `expm1` wherever the exponential branch is discarded.
        # Otherwise large positive inputs overflow to inf in the unused
        # branch, which can turn gradients through `where` into NaN.
        safe_x = np.where(positive, np.zeros_like(x), x)
        return lmbda * np.where(positive, x, alpha * np.expm1(safe_x))

    return Fn('Selu', _selu)
def Elu(a=1.):
    r"""Returns a ReLU-like layer with exponential outputs for negative inputs.

    .. math::
        f(x) = \left\{ \begin{array}{cl}
            a \cdot (e^x - 1) & \text{if}\ x \leq 0, \\
            x                 & \text{otherwise}.
        \end{array} \right.

    (Asymptotically, :math:`f(x)\rightarrow -a` as :math:`x\rightarrow - \infty`.)

    Args:
      a: Coefficient multiplying the exponential, for negative inputs.
    """
    def _elu(x):
        positive = x > 0
        # Feed zeros to `expm1` wherever the exponential branch is discarded.
        # Otherwise large positive inputs overflow to inf in the unused
        # branch, which can turn gradients through `where` into NaN.
        safe_x = np.where(positive, np.zeros_like(x), x)
        return np.where(positive, x, a * np.expm1(safe_x))

    return Fn('Elu', _elu)
def forward(self, x):
    """Executes this layer as part of a forward pass through the model.

    Outside of 'train' mode the input is returned untouched. In 'train' mode
    a keep-mask is drawn with `math.random.bernoulli` using probability
    `1 - rate`; kept entries are scaled by `1 / (1 - rate)` and the rest are
    zeroed.

    Args:
      x: Tensor of activations.

    Returns:
      Tensor of same shape and dtype as the input.
    """
    if self._mode != 'train':
        return x
    state, rng = self.state, self.rng
    rate = self._initial_rate
    # A rate stored in the state dict under this layer's name takes precedence.
    has_override = isinstance(state, dict) and self._name in state
    if has_override:
        rate = state[self._name]
    keep_prob = 1.0 - rate
    keep_mask = math.random.bernoulli(rng, keep_prob, x.shape)
    return jnp.where(keep_mask, x / keep_prob, jnp.zeros_like(x))
def tree_update(self, step, grad_tree, weight_tree, slots, opt_params):
    """Assembles node-local weight and slot updates for the full layer tree.

    Args:
      step: Current step number, forwarded to `_update_and_check`.
      grad_tree: Gradients, flattened with `_tree_flatten`.
      weight_tree: Weights, flattened with `_tree_flatten` and used as the
        template for `_tree_unflatten`.
      slots: Optimizer slots, one per flattened weight.
      opt_params: Optimizer hyperparameters, forwarded to `_update_and_check`.

    Returns:
      Tuple `(new_weights, slots)`; the new slots are also stored on
      `self.slots`.
    """
    flat_grads = _tree_flatten(grad_tree)
    clip_at = self._clip_grad_norm
    if clip_at is not None:
        # Global L2 norm across every gradient leaf.
        total_norm = np.sqrt(sum(np.vdot(leaf, leaf) for leaf in flat_grads))

        def _clip(g):
            return np.where(total_norm < clip_at, g, g * (clip_at / total_norm))

        flat_grads = [_clip(g) for g in flat_grads]

    flat_weights = _tree_flatten(weight_tree)
    per_leaf_results = [
        self._update_and_check(step, g, w, s, opt_params)
        for g, w, s in zip(flat_grads, flat_weights, slots)
    ]
    # Transpose [(weight, slot), ...] into (weights, slots).
    flat_new_weights, self.slots = zip(*per_leaf_results)
    new_weights, _ = _tree_unflatten(flat_new_weights, weight_tree)
    return new_weights, self.slots
def forward_unbatched(self, x, mask=None, *, weights, state, update_state):
    """Projects `x` to queries/keys/values and runs `attend` on them.

    Args:
      x: Input activations.
      mask: Boolean array (True means "is valid token"); must be given exactly
        when `self.masked` is true.
      weights: `(w_q, w_v, w_o)` when queries and keys are shared, otherwise
        `(w_q, w_k, w_v, w_o)`.
      state: Layer state, returned unchanged.
      update_state: Ignored.

    Returns:
      Tuple `(out, state)` where `out` is the attention output multiplied by
      `w_o`.
    """
    del update_state
    if self.share_qk:
        w_q, w_v, w_o = weights
        k = None
    else:
        w_q, w_k, w_v, w_o = weights
        k = np.matmul(x, w_k)
    q = np.matmul(x, w_q)
    v = np.matmul(x, w_v)

    mask_fn = functools.partial(
        mask_self_attention,
        causal=self.causal,
        exclude_self=self.share_qk,
        masked=self.masked,
    )
    positions = jax.lax.tie_in(x, np.arange(q.shape[-2]))
    q_info = positions
    kv_info = positions

    assert (mask is not None) == self.masked
    if self.masked:
        # mask is a boolean array (True means "is valid token"); entries where
        # it is False get their position info multiplied by -1.
        ones_like_mask = jax.lax.tie_in(x, np.ones_like(mask, dtype=np.int32))
        sign = np.where(mask, ones_like_mask, -ones_like_mask)
        kv_info = kv_info * sign

    o, _ = attend(
        q, k, v,
        q_chunk_len=self.chunk_len,
        kv_chunk_len=self.chunk_len,
        n_chunks_before=self.n_chunks_before,
        n_chunks_after=self.n_chunks_after,
        mask_fn=mask_fn,
        q_info=q_info,
        kv_info=kv_info,
        dropout=self.attention_dropout,
        rng=None,  # TODO(kitaev): support RNG
    )
    return np.matmul(o, w_o), state
def tree_update(self, step, grad_tree, weight_tree, slots, opt_params):
    """Assembles node-local weight and slot updates for the full layer tree.

    Args:
      step: Current step number, forwarded to `_update_and_check`.
      grad_tree: Gradients, flattened with `_tree_flatten`.
      weight_tree: Weights, flattened with `_tree_flatten` and used as the
        template for `_tree_unflatten`.
      slots: Optimizer slots, one per flattened weight.
      opt_params: Optimizer hyperparameters, forwarded to `_update_and_check`.

    Returns:
      Tuple `(new_weights, slots, metrics)`; `metrics` holds the gradient norm
      (measured before clipping) under 'gradients_l2' and the norm of the
      incoming weights under 'weights_l2'. The new slots are also stored on
      `self.slots`.
    """
    flat_grads = _tree_flatten(grad_tree)
    grads_norm = self._l2_norm(flat_grads)
    clip_at = self._clip_grad_norm
    if clip_at is not None:

        def _clip(g):
            return np.where(grads_norm < clip_at, g, g * (clip_at / grads_norm))

        flat_grads = [_clip(g) for g in flat_grads]

    flat_weights = _tree_flatten(weight_tree)
    weights_norm = self._l2_norm(flat_weights)
    per_leaf_results = [
        self._update_and_check(step, g, w, s, opt_params)
        for g, w, s in zip(flat_grads, flat_weights, slots)
    ]
    # Transpose [(weight, slot), ...] into (weights, slots).
    flat_new_weights, self.slots = zip(*per_leaf_results)
    new_weights, _ = _tree_unflatten(flat_new_weights, weight_tree)
    return new_weights, self.slots, {
        'gradients_l2': grads_norm,
        'weights_l2': weights_norm,
    }
def _update_sketched(self, grads, weights, m, v, opt_params): """Update for higher-rank parameters.""" learning_rate = opt_params['learning_rate'] momentum = opt_params['momentum'] shape = weights.shape rank = len(shape) reshaped_accumulators = [np.reshape(v[i], self._expanded_shape(shape, i)) for i in range(rank)] current_accumulator = self._minimum(reshaped_accumulators) current_accumulator += grads * grads accumulator_inv_sqrt = np.where(current_accumulator > 0.0, 1.0 / np.sqrt(current_accumulator), np.zeros_like(current_accumulator)) preconditioned_gradient = grads * accumulator_inv_sqrt m = (1.0 - momentum) * preconditioned_gradient + momentum * m weights = weights - (learning_rate * m).astype(weights.dtype) for i in range(len(v)): axes = list(range(int(i))) + list(range(int(i) + 1, rank)) dim_accumulator = np.amax(current_accumulator, axis=axes) v[i] = dim_accumulator return weights, (m, v)
def tree_update(self, step, grad_tree, weight_tree, slots, opt_params):
    """Assembles node-local weight and slot updates for the full layer tree.

    Args:
      step: Current step number in the training process.
      grad_tree: Gradients for the entire model, in a tree that matches the
        model's layer structure.
      weight_tree: Current weights for the entire model, in a tree that
        matches the model's layer structure.
      slots: Optimizer slots.
      opt_params: Optimizer hyperparameters (e.g. learning rate, momentum).

    Returns:
      Tuple `(weights, slots, metrics)`, where `weights` are the
      optimizer-updated weights for the whole model (in a tree matching the
      model's layer structure), `slots` are the updated optimizer slot values
      (also stored on `self.slots`), and `metrics` is a dict holding the
      gradient norm (measured before clipping) under 'gradients_l2' and the
      norm of the incoming weights under 'weights_l2'.
    """
    flat_grads = math.tree_flatten(grad_tree)
    grads_norm = self._l2_norm(flat_grads)
    clip_at = self._clip_grad_norm
    if clip_at is not None:

        def _clip(g):
            return np.where(grads_norm < clip_at, g, g * (clip_at / grads_norm))

        flat_grads = [_clip(g) for g in flat_grads]

    flat_weights = math.tree_flatten(weight_tree)
    weights_norm = self._l2_norm(flat_weights)
    per_leaf_results = [
        self._update_and_check(step, g, w, s, opt_params)
        for g, w, s in zip(flat_grads, flat_weights, slots)
    ]
    # Transpose [(weight, slot), ...] into (weights, slots).
    flat_new_weights, self.slots = zip(*per_leaf_results)
    new_weights, _ = math.tree_unflatten(flat_new_weights, weight_tree)
    return new_weights, self.slots, {
        'gradients_l2': grads_norm,
        'weights_l2': weights_norm,
    }
def clip_grads(grad_tree, max_norm):
    """Clip gradients stored as a pytree of arrays to maximum norm `max_norm`.

    Args:
      grad_tree: Gradients; their norm is taken with `l2_norm` and each leaf
        is visited via `layers.nested_map`.
      max_norm: Norm threshold; leaves are rescaled by `max_norm / norm` when
        the norm is not below it.

    Returns:
      The result of mapping the clipping function over `grad_tree`.
    """
    total_norm = l2_norm(grad_tree)

    def _clip(g):
        below_limit = total_norm < max_norm
        return np.where(below_limit, g, g * (max_norm / total_norm))

    return layers.nested_map(grad_tree, _clip)
def Elu(a=1.):
    """Returns a layer computing `x` for `x > 0` and `a * (exp(x) - 1)` otherwise.

    Args:
      a: Coefficient multiplying the exponential, for non-positive inputs.
    """
    def _elu(x):
        positive = x > 0
        # Feed zeros to `expm1` wherever the exponential branch is discarded.
        # Otherwise large positive inputs overflow to inf in the unused
        # branch, which can turn gradients through `where` into NaN.
        safe_x = np.where(positive, np.zeros_like(x), x)
        return np.where(positive, x, a * np.expm1(safe_x))

    return Fn('Elu', _elu)
def LeakyRelu(a=0.01):
    """Returns a layer computing `x` for `x >= 0` and `a * x` otherwise.

    Args:
      a: Slope applied to negative inputs.
    """
    def _leaky_relu(x):
        non_negative = x >= 0
        return np.where(non_negative, x, a * x)

    return Fn('LeakyRelu', _leaky_relu)
def Elu(x, a=1., **unused_kwargs):
    """Computes `x` for `x > 0` and `a * (exp(x) - 1)` otherwise.

    Args:
      x: Input array.
      a: Coefficient multiplying the exponential, for non-positive inputs.
      **unused_kwargs: Accepted and ignored.

    Returns:
      Array of activations.
    """
    positive = x > 0
    # Feed zeros to `expm1` wherever the exponential branch is discarded.
    # Otherwise large positive inputs overflow to inf in the unused branch,
    # which can turn gradients through `where` into NaN.
    safe_x = np.where(positive, np.zeros_like(x), x)
    return np.where(positive, x, a * np.expm1(safe_x))
def LeakyRelu(x, a=0.01, **unused_kwargs):
    """Computes `x` for `x >= 0` and `a * x` otherwise.

    Args:
      x: Input array.
      a: Slope applied to negative inputs.
      **unused_kwargs: Accepted and ignored.

    Returns:
      Array of activations.
    """
    non_negative = x >= 0
    scaled = a * x
    return np.where(non_negative, x, scaled)
def LeakyRelu(x, a=0.01):
    """Computes `x` for `x >= 0` and `a * x` otherwise.

    Args:
      x: Input array.
      a: Slope applied to negative inputs.

    Returns:
      Array of activations.
    """
    non_negative = x >= 0
    scaled = a * x
    return np.where(non_negative, x, scaled)
def Selu(x,
         alpha=1.6732632423543772848170429916717,
         lmbda=1.0507009873554804934193349852946):
    """Computes `lmbda * x` for `x > 0` and `lmbda * alpha * (exp(x) - 1)` otherwise.

    Args:
      x: Input array.
      alpha: Coefficient multiplying the exponential, for non-positive inputs.
      lmbda: Coefficient scaling the whole function.

    Returns:
      Array of activations.
    """
    positive = x > 0
    # Feed zeros to `expm1` wherever the exponential branch is discarded.
    # Otherwise large positive inputs overflow to inf in the unused branch,
    # which can turn gradients through `where` into NaN.
    safe_x = np.where(positive, np.zeros_like(x), x)
    return lmbda * np.where(positive, x, alpha * np.expm1(safe_x))
def Elu(x, a=1.):
    """Computes `x` for `x > 0` and `a * (exp(x) - 1)` otherwise.

    Args:
      x: Input array.
      a: Coefficient multiplying the exponential, for non-positive inputs.

    Returns:
      Array of activations.
    """
    positive = x > 0
    # Feed zeros to `expm1` wherever the exponential branch is discarded.
    # Otherwise large positive inputs overflow to inf in the unused branch,
    # which can turn gradients through `where` into NaN.
    safe_x = np.where(positive, np.zeros_like(x), x)
    return np.where(positive, x, a * np.expm1(safe_x))