def f(log_probs, advantages, old_log_probs, mask):
    """Masked advantage-weighted regression loss with optional reweighting.

    ``reweight``, ``sampled_all_discrete``, ``beta`` and ``w_max`` are read
    from the enclosing scope.

    Args:
        log_probs: log-probabilities of the actions under the current policy.
        advantages: advantage estimates, turned into weights by ``awr_weights``.
        old_log_probs: log-probabilities of the same actions recorded when
            they were sampled.
        mask: per-element weights; rescaled with ``*=`` when either flag is set.

    Returns:
        The negated sum of ``log_probs * weights * mask`` divided by the sum
        of the (possibly rescaled) mask.
    """
    if reweight:
        # Use new policy weights for sampled actions instead.
        importance_ratio = jnp.exp(
            math.stop_gradient(log_probs) - old_log_probs)
        mask *= importance_ratio
    if sampled_all_discrete:
        # Actions were sampled uniformly; weight them.
        old_probs = jnp.exp(old_log_probs)
        mask *= old_probs
    clipped_weights = jnp.minimum(awr_weights(advantages, beta), w_max)
    weighted_log_probs = log_probs * clipped_weights * mask
    return -jnp.sum(weighted_log_probs) / jnp.sum(mask)
def PPOJointLoss(x, **unused_kwargs):
    """Joint PPO loss: clipped policy objective, value loss, entropy term.

    Args:
        x: tuple ``(dist_inputs, values, returns, actions, old_log_probs,
            mask)``; ``mask`` is discarded.
        **unused_kwargs: ignored.

    Returns:
        ``-mean(ppo_objective) + l2_value_loss - entropy_loss``.
    """
    dist_inputs, values, returns, actions, old_log_probs, mask = x
    del mask  # TODO(lukaszkaiser): make PPO work with Transformer
    raw_new_log_probs = self._policy_dist.log_prob(dist_inputs, actions)

    advantages = returns - values
    l2_value_loss = jnp.sum(advantages**2) * self._value_loss_coeff

    # Both log-prob tensors carry an unwanted trailing dimension; drop it.
    old_log_probs = jnp.array(
        old_log_probs.squeeze(axis=-1), dtype=jnp.float32)
    new_log_probs = jnp.array(raw_new_log_probs.squeeze(axis=-1))

    # Ratio of new to old probabilities, via exponentiation of the
    # log-probability difference.
    probs_ratio = jnp.exp(new_log_probs - old_log_probs)
    lower_bound = 1 - self._epsilon
    upper_bound = 1 + self._epsilon
    surrogate = probs_ratio * advantages
    clipped_surrogate = (
        jnp.clip(probs_ratio, lower_bound, upper_bound) * advantages)
    ppo_objective = jnp.minimum(surrogate, clipped_surrogate)

    entropy_loss = (
        self._policy_dist.entropy(new_log_probs) * self._entropy_coeff)
    return -ppo_objective.mean() + l2_value_loss - entropy_loss
def DotProductAttention(query, key, value, mask, dropout, mode, rng):
    """Core dot product self-attention.

    Args:
        query: array of representations
        key: array of representations
        value: array of representations
        mask: attention-mask, gates attention; `None` disables masking
        dropout: float or None: dropout rate; `None` or 0 disables dropout
        mode: 'eval' or 'train': whether to use dropout
        rng: JAX PRNGKey: subkey for disposable use

    Returns:
        Self attention for q, k, v arrays.

    Raises:
        ValueError: if `dropout` is a number greater than or equal to 1.
    """
    depth = np.shape(query)[-1]
    dots = np.matmul(query, np.swapaxes(key, -1, -2)) / np.sqrt(depth)
    if mask is not None:
        # TODO(kitaev): workaround for https://github.com/google/jax/issues/850
        # We must ensure that both mask and the -1e9 constant have a data
        # dependency on the input. Broadcasted copies of these use a lot of
        # memory, so they should be computed at runtime (rather than being
        # global constants).
        if math.backend_name() == 'jax':
            mask = jax.lax.tie_in(dots, mask)
        # JAX's `full_like` already ties in -1e9 to dots.
        dots = np.where(mask, dots, np.full_like(dots, -1e9))
    # Softmax.
    dots = np.exp(dots - math.logsumexp(dots, axis=-1, keepdims=True))
    # Guard against None before comparing: `None >= 1.0` raises TypeError,
    # which would make the `dropout is not None` check below unreachable.
    if dropout is not None and dropout >= 1.0:
        raise ValueError('Dropout rates must be lower than 1.')
    if dropout is not None and dropout > 0.0 and mode == 'train':
        keep = math.random.bernoulli(rng, 1.0 - dropout, dots.shape)
        # Rescale kept entries so the expected value is unchanged.
        dots = np.where(keep, dots / (1.0 - dropout), np.zeros_like(dots))
    out = np.matmul(dots, value)
    return out
def Softmax5Branches(x_list, **unused_kwargs):
    """Combine five embedded queries using softmax-normalized weights.

    The input is laid out as ``w_1 ... w_5 q_1 ... q_5``: five weight tensors
    followed by five query tensors.

    Args:
        x_list: sequence holding the five weights followed by the five
            embeddings.
        **unused_kwargs: ignored.

    Returns:
        The weighted average of ``q_1 ... q_5`` according to ``softmax(w)``.
    """
    n_branches = 5
    logits = x_list[:n_branches]

    # Subtract the element-wise maximum before exponentiating so that the
    # largest exponent is exp(0).
    running_max = logits[0]
    for logit in logits:
        running_max = np.maximum(running_max, logit)

    exponentials = [np.exp(logit - running_max) for logit in logits]
    normalizer = sum(exponentials)
    branch_weights = [e / normalizer for e in exponentials]

    res = sum(
        x_list[n_branches + idx] * weight
        for idx, weight in enumerate(branch_weights)
    )
    return res
def _aggregate_values(self, values, aggregate_max, act_log_probs):
    """Reduce per-sample values along axis 1 when Q-values are in use.

    Args:
        values: array of values; reduced over axis 1 only if
            ``self._q_value`` is truthy.
        aggregate_max: if true, reduce with a maximum.
        act_log_probs: action log-probabilities; used as weights (after
            exponentiation) when ``self._sample_all_discrete_actions`` is set.

    Returns:
        The (possibly reduced) values converted with ``np.array``.
    """
    if not self._q_value:
        return np.array(values)  # Move the values to CPU.

    if aggregate_max:
        aggregated = jnp.max(values, axis=1)
    elif self._sample_all_discrete_actions:
        action_probs = jnp.exp(act_log_probs)
        aggregated = jnp.sum(values * action_probs, axis=1)
    else:
        aggregated = jnp.mean(values, axis=1)
    return np.array(aggregated)  # Move the values to CPU.
def _calc_adv_weights(self, adv, valid_mask):
    """Turn advantages into clipped exponential weights plus statistics.

    Args:
        adv: advantage estimates.
        valid_mask: index (e.g. boolean mask) selecting the entries that
            contribute to the reported statistics.

    Returns:
        Tuple ``(weights, mean, min, max)``. ``weights`` is
        ``exp(adv / temperature)`` capped at ``self._weight_clip``; the
        statistics are computed on the selected weights *before* capping.
    """
    unclipped = jnp.exp(adv / self._temperature)
    selected = unclipped[valid_mask]
    summary = (jnp.mean(selected), jnp.min(selected), jnp.max(selected))
    clipped = jnp.minimum(unclipped, self._weight_clip)
    return (clipped,) + summary
def ProbsRatio(dist_inputs, actions, old_log_probs, log_prob_fun):
    """Probability ratio from the PPO algorithm.

    Args:
        dist_inputs: inputs of the action distribution, forwarded to
            ``NewLogProbs``.
        actions: the actions whose probabilities are compared.
        old_log_probs: log-probabilities recorded at sampling time; the
            trailing dimension is squeezed away.
        log_prob_fun: log-probability function forwarded to ``NewLogProbs``.

    Returns:
        ``exp(new_log_probs - old_log_probs)``.
    """
    # Drop the undesirable extra trailing dimension of the old log probs.
    squeezed_old = jnp.array(
        old_log_probs.squeeze(axis=-1), dtype=jnp.float32)
    current = NewLogProbs(dist_inputs, actions, log_prob_fun)
    # Ratio of new to old probabilities, via exponentiation of the
    # log-probability difference.
    return jnp.exp(current - squeezed_old)
def AWRLoss(x, **unused_kwargs):  # pylint: disable=invalid-name
    """Advantage Weighted Regression loss with an added L2 value loss.

    Args:
        x: tuple ``(logps, values, returns, actions)``.
        **unused_kwargs: ignored.

    Returns:
        The weight-normalized negative log-likelihood of the actions plus
        the scaled sum of squared value errors.
    """
    logps, values, returns, actions = x
    advantage = returns - values
    l2_value_loss = jnp.sum(advantage**2) * self._value_loss_coeff

    # Exponentiated advantages, capped at w_max.
    capped_weights = jnp.minimum(
        jnp.exp(advantage / self._beta), self._w_max)
    neg_log_probs = -1.0 * self._policy_dist.log_prob(logps, actions)
    weighted_total = jnp.sum(neg_log_probs * capped_weights)
    policy_loss = weighted_total / jnp.sum(capped_weights)
    return policy_loss + l2_value_loss
def Softmax(axis=-1):
    """Returns a layer that applies softmax along one tensor axis.

    Softmax maps a group of values to non-negative values that sum to 1
    along the chosen axis, so they can be read as probabilities. It is
    computed as ``exp(x - logsumexp(x))``.

    Args:
        axis: Axis along which values are grouped for computing softmax.
    """
    return Fn(
        'Softmax',
        lambda logits: jnp.exp(
            logits - math.logsumexp(logits, axis, keepdims=True)),
    )
def fn(dist_inputs, actions, q_values, act_log_probs, mask):
    """Compute advantages as Q-values minus a per-state baseline.

    ``self`` and ``preprocess`` are read from the enclosing scope.

    Args:
        dist_inputs: unused.
        actions: unused.
        q_values: Q-value estimates; axes 0 and 1 are swapped before use.
        act_log_probs: action log-probabilities; axes 0 and 1 are swapped.
        mask: unused.

    Returns:
        ``q_values - values``, optionally passed through
        ``self._preprocess_advantages``.
    """
    del dist_inputs, actions, mask
    swapped_q = jnp.swapaxes(q_values, 0, 1)
    swapped_log_probs = jnp.swapaxes(act_log_probs, 0, 1)

    if self._sample_all_discrete_actions:
        # Expectation of Q under the action probabilities.
        baseline = jnp.sum(swapped_q * jnp.exp(swapped_log_probs), axis=0)
    else:
        baseline = jnp.mean(swapped_q, axis=0)

    # Broadcasting values over n_samples
    advantages = swapped_q - baseline
    if not preprocess:
        return advantages
    return self._preprocess_advantages(advantages)
def f(new_log_probs, advantages, old_log_probs, mask):
    """Masked clipped-surrogate PPO loss.

    ``epsilon`` (the clipping range) is read from the enclosing scope.

    Args:
        new_log_probs: log-probabilities under the current policy.
        advantages: advantage estimates.
        old_log_probs: log-probabilities recorded at sampling time, with an
            extra trailing dimension that is squeezed away.
        mask: per-element weights; the loss is normalized by their sum.

    Returns:
        Negated masked sum of the PPO objective divided by the mask sum.
    """
    # Drop the undesirable extra trailing dimension of the old log probs.
    squeezed_old = old_log_probs.squeeze(axis=-1)
    # Ratio of new to old probabilities, via exponentiation of the
    # log-probability difference.
    ratio = jnp.exp(new_log_probs - squeezed_old)
    clipped_ratio = jnp.clip(ratio, 1 - epsilon, 1 + epsilon)
    surrogate = jnp.minimum(ratio * advantages, clipped_ratio * advantages)
    masked_total = np.sum(surrogate * mask)
    return -masked_total / np.sum(mask)
def ProbsRatio(dist_inputs, actions, old_log_probs, log_prob_fun):
    """Probability Ratio from the PPO algorithm.

    Args:
        dist_inputs: inputs of the action distribution, forwarded to
            ``NewLogProbs``.
        actions: the actions whose probabilities are compared.
        old_log_probs: log-probabilities recorded at sampling time; must have
            the same shape as the newly computed log-probabilities.
        log_prob_fun: log-probability function forwarded to ``NewLogProbs``.

    Returns:
        ``exp(new_log_probs - old_log_probs)``.

    Raises:
        AssertionError: if the new and old log-probability shapes differ.
    """
    # Example shapes noted by the original author:
    # dist_inputs of the shape float32[128,1,18]
    # actions of the shape int32[128,1]
    # and old_log_probs of the shape float32[128,1]
    new_log_probs = NewLogProbs(dist_inputs, actions, log_prob_fun)
    # The two message fragments are concatenated, so the first one needs a
    # trailing space to keep the words apart.
    assert new_log_probs.shape == old_log_probs.shape, (
        f'new_log_probs.shape was {new_log_probs.shape} and '
        f'old_log_probs.shape was {old_log_probs.shape}')
    # The ratio between new_probs and old_probs expressed
    # using log_probs and exponentiation
    probs_ratio = jnp.exp(new_log_probs - old_log_probs)
    return probs_ratio
def PPOLoss(x, epsilon, **unused_kwargs):
    """Definition of the Proximal Policy Optimization loss.

    Args:
        x: tuple ``(new_log_probs, advantages, old_log_probs, mask)``;
            ``old_log_probs`` has an extra trailing dimension.
        epsilon: clipping range for the probability ratio.
        **unused_kwargs: ignored.

    Returns:
        Negated masked sum of the clipped surrogate objective divided by the
        mask sum.
    """
    new_log_probs, advantages, raw_old_log_probs, mask = x
    # Drop the undesirable extra trailing dimension of the old log probs.
    old_log_probs = raw_old_log_probs.squeeze(axis=-1)
    # Ratio of new to old probabilities, via exponentiation of the
    # log-probability difference.
    ratio = jnp.exp(new_log_probs - old_log_probs)
    clipped_ratio = jnp.clip(ratio, 1 - epsilon, 1 + epsilon)
    surrogate = jnp.minimum(ratio * advantages, clipped_ratio * advantages)
    masked_total = np.sum(surrogate * mask)
    return -masked_total / np.sum(mask)
def ProbsRatioMean(x, **unused_kwargs):
    """Mean probability ratio from the PPO algorithm.

    Args:
        x: 5-tuple whose first, fourth and fifth items are ``dist_inputs``,
            ``actions`` and ``old_log_probs``; the other two are ignored.
        **unused_kwargs: ignored.

    Returns:
        Mean of ``exp(new_log_probs - old_log_probs)``.
    """
    dist_inputs, _, _, actions, old_log_probs = x
    raw_new_log_probs = self._policy_dist.log_prob(dist_inputs, actions)
    # Both log-prob tensors carry an unwanted trailing dimension; drop it.
    old_log_probs = jnp.array(
        old_log_probs.squeeze(axis=-1), dtype=jnp.float32)
    new_log_probs = jnp.array(raw_new_log_probs.squeeze(axis=-1))
    # Ratio of new to old probabilities, via exponentiation of the
    # log-probability difference.
    ratio = jnp.exp(new_log_probs - old_log_probs)
    return jnp.mean(ratio)
def f(new_log_probs, advantages, old_log_probs, mask):
    """Masked clipped PPO loss minus an entropy term, with shape validation.

    ``self`` (providing ``_epsilon``, ``_entropy_coeff`` and ``_policy_dist``)
    is read from the enclosing scope.

    Args:
        new_log_probs: log-probabilities under the current policy.
        advantages: advantage estimates, same shape as ``new_log_probs``.
        old_log_probs: log-probabilities recorded at sampling time, same
            shape as ``new_log_probs``.
        mask: per-element weights, same shape as ``new_log_probs``.

    Returns:
        ``ppo_loss - entropy_loss``.

    Raises:
        ValueError: if any of the shape consistency checks fails.
    """
    # Example shapes noted by the original author:
    # new_log_probs of the shape float32[128,1]
    # advantages of the shape int32[128,1]
    # old_log_probs of the shape int32[128,1]
    # mask of the shape int32[128,1]
    if new_log_probs.shape != advantages.shape:
        raise ValueError('New log-probs and advantages shapes '
                         'should be the same, %s != %s' %
                         (new_log_probs.shape, advantages.shape))
    if new_log_probs.shape != old_log_probs.shape:
        raise ValueError('New log-probs and old log-probs shapes '
                         'should be the same, %s != %s' %
                         (new_log_probs.shape, old_log_probs.shape))
    if new_log_probs.shape != mask.shape:
        raise ValueError('New log-probs and mask shapes should be the same'
                         ', %s != %s' % (new_log_probs.shape, mask.shape))

    # The ratio between new_probs and old_probs expressed
    # using log_probs and exponentiation
    probs_ratio = jnp.exp(new_log_probs - old_log_probs)
    if advantages.shape != probs_ratio.shape:
        # The message now names the tensors that are actually compared.
        raise ValueError('advantages and probs_ratio shapes '
                         'should be the same, %s != %s' %
                         (advantages.shape, probs_ratio.shape))
    unclipped_objective = probs_ratio * advantages
    clipped_objective = jnp.clip(
        probs_ratio, 1 - self._epsilon, 1 + self._epsilon) * advantages

    # Compare the two objectives with each other, as the message states.
    if unclipped_objective.shape != clipped_objective.shape:
        raise ValueError('unclipped_objective and clipped_objective shapes '
                         'should be the same, %s != %s' %
                         (unclipped_objective.shape, clipped_objective.shape))

    ppo_objective = jnp.minimum(unclipped_objective, clipped_objective)
    if ppo_objective.shape != mask.shape:
        raise ValueError('ppo_objective and mask shapes '
                         'should be the same, %s != %s' %
                         (ppo_objective.shape, mask.shape))

    ppo_loss = -jnp.sum(ppo_objective * mask) / jnp.sum(mask)
    entropy_vec = (
        self._policy_dist.entropy(new_log_probs) * self._entropy_coeff)
    entropy_loss = jnp.mean(entropy_vec)
    combined_loss = ppo_loss - entropy_loss
    return combined_loss
def DotProductAttention(queries, keys, values, mask, dropout, mode, rng):
    """Computes new activations via masked attention-weighted sum of values.

    This function is the core of the attention mechanism. It:
      - computes per-head attention weights from per-head `(queries, keys)`,
      - applies `mask` to screen out positions that come from padding tokens,
      - optionally applies dropout to attention weights, and
      - uses attention weights to combine per-head `values` vectors.

    Args:
        queries: Per-head activations representing attention queries.
        keys: Per-head activations representing attention keys.
        values: Per-head activations to be combined by computed attention
            weights.
        mask: Mask that distinguishes positions with real content vs. padding;
            `None` disables masking.
        dropout: Probabilistic rate for dropout applied to attention
            activations (based on query-key pairs) before dotting them with
            values. `None` or 0 disables dropout.
        mode: Either 'train' or 'eval'. Dropout applies only in 'train' mode.
        rng: Single-use random number generator (JAX PRNG key).

    Returns:
        Per-head activations resulting from masked per-head attention-weighted
        sum of per-head values.

    Raises:
        ValueError: if `dropout` is a number greater than or equal to 1.
    """
    d_feature = queries.shape[-1]
    dots = (
        jnp.matmul(queries, jnp.swapaxes(keys, -1, -2)) / jnp.sqrt(d_feature))
    if mask is not None:
        # TODO(kitaev): workaround for https://github.com/google/jax/issues/850
        # We must ensure that both mask and the -1e9 constant have a data
        # dependency on the input. Broadcasted copies of these use a lot of
        # memory, so they should be computed at runtime (rather than being
        # global constants).
        if math.backend_name() == 'jax':
            mask = jax.lax.tie_in(dots, mask)
        # JAX's `full_like` already ties in -1e9 to dots.
        dots = jnp.where(mask, dots, jnp.full_like(dots, -1e9))
    # Softmax.
    dots = jnp.exp(dots - math.logsumexp(dots, axis=-1, keepdims=True))
    # Guard against None before comparing: `None >= 1.0` raises TypeError,
    # which would make the `dropout is not None` check below unreachable.
    if dropout is not None and dropout >= 1.0:
        raise ValueError('Dropout rates must be lower than 1.')
    if dropout is not None and dropout > 0.0 and mode == 'train':
        keep = math.random.bernoulli(rng, 1.0 - dropout, dots.shape)
        # Rescale kept entries so the expected value is unchanged.
        dots = jnp.where(keep, dots / (1.0 - dropout), jnp.zeros_like(dots))
    out = jnp.matmul(dots, values)
    return out
def LossInput(dist_inputs, actions, q_values, act_log_probs, mask):  # pylint: disable=invalid-name
    """Calculates action log probabilities and preprocessed advantages.

    ``self`` is read from the enclosing scope.

    Args:
        dist_inputs: inputs of the action distribution; broadcast over the
            sample dimension.
        actions: sampled actions.
        q_values: Q-value estimates for the sampled actions.
        act_log_probs: log-probabilities of the sampled actions at sampling
            time.
        mask: per-element weights.

    Returns:
        Tuple ``(log_probs, advantages, act_log_probs, mask)`` where the last
        two have had axes 0 and 1 swapped.
    """
    # (batch_size, n_samples, ...) -> (n_samples, batch_size, ...)
    q_values, mask, actions, act_log_probs = (
        jnp.swapaxes(tensor, 0, 1)
        for tensor in (q_values, mask, actions, act_log_probs)
    )

    # TODO(pkozakowski,lukaszkaiser): Try max here, or reweighting?
    if self._sample_all_discrete_actions:
        baseline = jnp.sum(q_values * jnp.exp(act_log_probs), axis=0)
    else:
        baseline = jnp.mean(q_values, axis=0)
    # Broadcasting values over n_samples
    advantages = self._preprocess_advantages(q_values - baseline)

    # Broadcast inputs and calculate log-probs
    broadcast_shape = (self._q_value_n_samples,) + dist_inputs.shape
    tiled_inputs = jnp.broadcast_to(dist_inputs, broadcast_shape)
    log_probs = self._policy_dist.log_prob(tiled_inputs, actions)
    return (log_probs, advantages, act_log_probs, mask)
def SRU(n_units, activation=None, rescale=False, highway_bias=0):
    """SRU (Simple Recurrent Unit) layer as in https://arxiv.org/abs/1709.02755.

    As defined in the paper:
      (1) y_t = W x_t (+ B optionally, which we do)
      (2) f_t = sigmoid(Wf x_t + bf)
      (3) r_t = sigmoid(Wr x_t + br)
      (4) c_t = f_t * c_{t-1} + (1 - f_t) * y_t
      (5) h_t = r_t * activation(c_t) + (1 - r_t) * x_t * alpha

    The input is assumed to have shape [batch, length, depth], with the
    recurrence running over the length dimension. This builds a single layer;
    the paper recommends stacking at least 2, except inside a Transformer.

    Args:
        n_units: output depth of the SRU layer.
        activation: Optional activation function.
        rescale: To offset the problem of the gradient vanishing in h_t as a
            result of light recurrence and highway computation for deeper
            layers, a scaling correction alpha is applied as follows:
            (1 + exp(highway_bias) * 2)**0.5
            ref: https://arxiv.org/abs/1709.02755, page 4, section 3.2
            Initialization.
        highway_bias: initial bias of highway gates

    Returns:
        The SRU layer.
    """
    # Trailing comments track the data stack after each stage; the stack
    # starts out as just `x`.
    # pylint: disable=no-value-for-parameter
    stages = [
        cb.Branch(core.Dense(3 * n_units), []),  # r_f_y, x
        cb.Split(n_items=3),  # r, f, y, x
        cb.Parallel(core.Sigmoid(), core.Sigmoid()),  # r, f, y, x
        base.Fn(lambda r, f, y: (y * (1.0 - f), f, r)),  # y * (1 - f), f, r, x
        cb.Parallel([], [], cb.Branch(MakeZeroState(), [])),
        cb.Scan(InnerSRUCell(), axis=1),
        cb.Select([0], n_in=2),  # act(c), r, x
        activation or [],
        # Highway connection, equation (5), with optional rescaling alpha.
        base.Fn(
            lambda c, r, x: c * r + x * (1 - r) * (
                (1 + np.exp(highway_bias) * 2)**0.5 if rescale else 1)),
    ]
    return cb.Serial(*stages)
def Softmax(axis=-1):
    """Layer that applies softmax along the given axis.

    The layer exponentiates its input after subtracting the log-sum-exp
    taken along ``axis``, which normalizes the result along that axis.

    Args:
        axis: Axis along which to normalize.
    """
    return Fn(
        'Softmax',
        lambda logits: jnp.exp(
            logits - math.logsumexp(logits, axis, keepdims=True)),
    )
def Exp():
    """Returns a layer that applies the element-wise exponential."""
    # The lambda is kept on purpose rather than passing `jnp.exp` directly.
    return Fn(
        'Exp',
        lambda x: jnp.exp(x),  # pylint: disable=unnecessary-lambda
    )
def forward_unbatched(self, x, *, weights, state, update_state):
    """Hash-bucketed shared-QK self-attention for a single example.

    Queries and values are projected from `x`, positions are sorted by
    (bucket, position), chunked attention is run over the sorted sequence via
    `attend`, and the results are un-sorted. With several hash rounds the
    per-round outputs are combined with softmax weights over the rounds.

    Args:
        x: input activations; `x.shape[0]` is used as the sequence length.
        weights: tuple `(w_q, w_v, w_o)` of projection matrices.
        state: tuple `(buckets, rng)`.
        update_state: if true, buckets are recomputed from a freshly folded
            rng and a new state is returned; otherwise the stored buckets are
            reused.

    Returns:
        Tuple `(out, state)`.
    """
    w_q, w_v, w_o = weights
    q = np.matmul(x, w_q)
    v = np.matmul(x, w_v)
    if update_state:
        _, old_rng = state
        rng = jax.random.fold_in(old_rng, 0)
        hash_rng = jax.random.fold_in(rng, 1)
        buckets = self.hash_vectors(q, hash_rng)
        state = (buckets, rng)
    else:
        buckets, rng = state
    # Derive a separate key for attention dropout; the key stored in `state`
    # is left untouched.
    rng = jax.random.fold_in(rng, 2)
    seqlen = x.shape[0]
    # One bucket id per (hash round, position) pair.
    assert int(buckets.shape[0]) == self.n_hashes * seqlen
    ticker = jax.lax.tie_in(x, np.arange(self.n_hashes * seqlen))
    # Combined sort key: bucket id first, position within the sequence second.
    buckets_and_t = seqlen * buckets + (ticker % seqlen)
    buckets_and_t = jax.lax.stop_gradient(buckets_and_t)
    # Hash-based sort ("s" at the start of variable names means "sorted")
    sbuckets_and_t, sticker = jax.lax.sort_key_val(buckets_and_t, ticker, dimension=-1)
    # Sorting the original indices by their sorted order yields the inverse
    # permutation.
    _, undo_sort = jax.lax.sort_key_val(sticker, ticker, dimension=-1)
    sbuckets_and_t = jax.lax.stop_gradient(sbuckets_and_t)
    sticker = jax.lax.stop_gradient(sticker)
    undo_sort = jax.lax.stop_gradient(undo_sort)
    # Position within the sequence for each sorted element.
    st = (sticker % seqlen)
    sq = np.take(q, st, axis=0)
    sv = np.take(v, st, axis=0)
    mask_fn = functools.partial(mask_self_attention, causal=self.causal, exclude_self=True)
    q_info = st
    # k=None makes `attend` share queries and keys.
    so, slogits = attend(
        sq, k=None, v=sv,
        q_chunk_len=self.chunk_len,
        n_chunks_before=self.n_chunks_before,
        n_chunks_after=self.n_chunks_after,
        mask_fn=mask_fn, q_info=q_info,
        dropout=self.attention_dropout, rng=rng,
    )

    def unsort_for_output_impl(so, slogits):
        # Restore the original ordering of outputs and logits.
        o = np.take(so, undo_sort, axis=0)
        # Sorting is considerably faster than gather, but first we need to get the
        # XLA compiler to abandon the idea of fusing this sort with the input sort
        # (which introduces a computation cycle and leads to a crash).
        # TODO(kitaev): remove "sticker_" variable if XLA is fixed.
        sticker_ = sticker + jax.lax.convert_element_type(
            slogits[0] > 0, sticker.dtype)
        _, logits = jax.lax.sort_key_val(sticker_, slogits, dimension=-1)
        return o, logits

    def unsort_for_output_vjp(so, slogits):
        """Custom gradient for unsort_for_output."""
        so = jax.lax.stop_gradient(so)
        slogits = jax.lax.stop_gradient(slogits)
        o, logits = unsort_for_output_impl(so, slogits)

        def vjpfun(o_logits_grads):
            # The backward pass re-applies the forward sort to the gradients.
            so_grad = np.take(o_logits_grads[0], sticker, axis=0)
            # TODO(kitaev): this exists to match the forward pass, but I'm not sure
            # if it's actually required.
            buckets_and_t_ = buckets_and_t + jax.lax.convert_element_type(
                o_logits_grads[1][0] > 0, buckets_and_t.dtype)
            _, slogits_grad = jax.lax.sort_key_val(buckets_and_t_, o_logits_grads[1], dimension=-1)
            return (so_grad, slogits_grad)

        return (o, logits), vjpfun

    unsort_for_output = jax.custom_transforms(unsort_for_output_impl)
    jax.defvjp_all(unsort_for_output, unsort_for_output_vjp)
    # NOTE(review): this calls `unsort_for_output_impl` directly, so the
    # custom VJP registered on `unsort_for_output` just above is not used
    # here -- confirm whether that is intentional.
    o, logits = unsort_for_output_impl(so, slogits)
    if self.n_hashes > 1:
        # Combine hash rounds with softmax weights computed across rounds.
        o = np.reshape(o, (self.n_hashes, seqlen, o.shape[-1]))
        logits = np.reshape(logits, (self.n_hashes, seqlen, 1))
        probs = np.exp(logits - logsumexp(logits, axis=0, keepdims=True))
        o = np.sum(o * probs, axis=0)
    assert o.shape == (seqlen, w_v.shape[-1])
    out = np.matmul(o, w_o)
    return out, state
def attend(
    q, k=None, v=None,
    q_chunk_len=None, kv_chunk_len=None,
    n_chunks_before=0, n_chunks_after=0,
    mask_fn=None, q_info=None, kv_info=None,
    dropout=0.0, rng=None,
):
    """Dot-product attention, with optional chunking and/or masking.

    Args:
        q: Query vectors, shape [q_len, d_qk]
        k: Key vectors, shape [kv_len, d_qk]; or None, in which case the
            (chunked) queries are reused as keys after length normalization.
        v: Value vectors, shape [kv_len, d_v]; required.
        q_chunk_len: Set to non-zero to enable chunking for query vectors
        kv_chunk_len: Set to non-zero to enable chunking for key/value vectors
        n_chunks_before: Number of adjacent previous chunks to attend to
        n_chunks_after: Number of adjacent subsequent chunks to attend to
        mask_fn: Optional callable invoked as
            `mask_fn(dots, q_info[..., :, None], kv_info[..., None, :])`;
            its return value replaces `dots`.
        q_info: Query-associated metadata for masking; defaults to query
            positions.
        kv_info: Key-associated metadata for masking; defaults to key/value
            positions (or to `q_info` when queries and keys are shared).
        dropout: Dropout rate
        rng: RNG for dropout; required when `dropout > 0`.

    Returns:
        A tuple (output, dots_logsumexp). The output has shape [q_len, d_v],
        and dots_logsumexp has shape [q_len]. dots_logsumexp is the
        log-sum-exp of the (masked) attention logits, i.e. the softmax
        normalizer, which is useful for combining multiple rounds of
        attention (as in LSH attention).
    """
    assert v is not None
    share_qk = (k is None)
    if q_info is None:
        q_info = np.arange(q.shape[-2])
    if kv_info is None and not share_qk:
        kv_info = np.arange(v.shape[-2])
    # Split q/k/v into chunks along the time axis, if desired.
    if q_chunk_len is not None:
        q = np.reshape(q, (-1, q_chunk_len, q.shape[-1]))
        q_info = np.reshape(q_info, (-1, q_chunk_len))
    if share_qk:
        # Keys mirror the (already chunked) queries, including their metadata.
        assert kv_chunk_len is None or kv_chunk_len == q_chunk_len
        k = q
        kv_chunk_len = q_chunk_len
        kv_info = q_info
    elif kv_chunk_len is not None:
        k = np.reshape(k, (-1, kv_chunk_len, k.shape[-1]))
        kv_info = np.reshape(kv_info, (-1, kv_chunk_len))
    if kv_chunk_len is not None:
        v = np.reshape(v, (-1, kv_chunk_len, v.shape[-1]))
    if share_qk:
        k = length_normalized(k)
    # Scale keys by 1/sqrt(d_qk) before the dot product.
    k = k / np.sqrt(k.shape[-1])
    # Optionally include adjacent chunks.
    if q_chunk_len is not None or kv_chunk_len is not None:
        # Chunking must be enabled for both sides or for neither.
        assert q_chunk_len is not None and kv_chunk_len is not None
    else:
        assert n_chunks_before == 0 and n_chunks_after == 0
    k = look_adjacent(k, n_chunks_before, n_chunks_after)
    v = look_adjacent(v, n_chunks_before, n_chunks_after)
    kv_info = look_adjacent(kv_info, n_chunks_before, n_chunks_after)
    # Dot-product attention.
    dots = np.matmul(q, np.swapaxes(k, -1, -2))
    # Masking
    if mask_fn is not None:
        dots = mask_fn(dots, q_info[..., :, None], kv_info[..., None, :])
    # Softmax.
    dots_logsumexp = logsumexp(dots, axis=-1, keepdims=True)
    dots = np.exp(dots - dots_logsumexp)
    if dropout > 0.0:
        assert rng is not None
        # Dropout is broadcast across the bin dimension
        dropout_shape = (dots.shape[-2], dots.shape[-1])
        # TODO(kitaev): verify that tie-in is safe to remove (in light of jax fix)
        keep_prob = jax.lax.tie_in(dots, 1.0 - dropout)
        keep = jax.random.bernoulli(rng, keep_prob, dropout_shape)
        # Dividing by keep_prob rescales the kept entries.
        multiplier = keep.astype(dots.dtype) / jax.lax.tie_in(keep, keep_prob)
        dots = dots * multiplier
    # The softmax normalizer (dots_logsumexp) is used by multi-round LSH attn.
    out = np.matmul(dots, v)
    # Flatten any chunk dimension back into a single time axis.
    out = np.reshape(out, (-1, out.shape[-1]))
    dots_logsumexp = np.reshape(dots_logsumexp, (-1, ))
    return out, dots_logsumexp
def Softmax(x, axis=-1, **unused_kwargs):
    """Apply softmax to x: exponentiate and normalize along the given axis.

    Args:
        x: input array.
        axis: axis along which to normalize.
        **unused_kwargs: ignored.

    Returns:
        ``exp(x - logsumexp(x))`` with the log-sum-exp taken along ``axis``.
    """
    log_normalizer = math.logsumexp(x, axis, keepdims=True)
    return np.exp(x - log_normalizer)
def Exp(x, **unused_kwargs):
    """Return the element-wise exponential of ``x``; extra kwargs are ignored."""
    exponentiated = np.exp(x)
    return exponentiated
def f(log_probs, advantages, old_log_probs, mask):
    """Masked advantage-weighted regression loss with optional reweighting.

    ``reweight``, ``beta`` and ``w_max`` are read from the enclosing scope.

    Args:
        log_probs: log-probabilities of the actions under the current policy.
        advantages: advantage estimates, turned into weights by ``awr_weights``.
        old_log_probs: log-probabilities of the same actions recorded when
            they were sampled.
        mask: per-element weights; rescaled with ``*=`` when ``reweight`` is
            set.

    Returns:
        The negated sum of ``log_probs * weights * mask`` divided by the sum
        of the (possibly rescaled) mask.
    """
    if reweight:
        # Use new policy weights for sampled actions instead.
        importance_ratio = jnp.exp(
            math.stop_gradient(log_probs) - old_log_probs)
        mask *= importance_ratio
    clipped_weights = jnp.minimum(awr_weights(advantages, beta), w_max)
    weighted_log_probs = log_probs * clipped_weights * mask
    return -jnp.sum(weighted_log_probs) / jnp.sum(mask)
def awr_weights(advantages, beta):
    """Return the unclipped AWR weights ``exp(advantages / beta)``.

    Args:
        advantages: advantage estimates.
        beta: divisor (temperature) applied before exponentiation.
    """
    scaled_advantages = advantages / beta
    return jnp.exp(scaled_advantages)
def entropy(self, log_probs):
    """Entropy ``-sum(p * log p)`` over the last axis of ``log_probs``.

    Args:
        log_probs: log-probabilities; probabilities are recovered with
            ``exp``.

    Returns:
        The entropy, reduced over the last axis.
    """
    weighted_log_probs = jnp.exp(log_probs) * log_probs
    return -jnp.sum(weighted_log_probs, axis=-1)
def Exp():
    """Returns a layer computing the element-wise exponential of a tensor."""
    # The lambda is kept on purpose rather than passing `jnp.exp` directly.
    return Fn(
        'Exp',
        lambda x: jnp.exp(x),  # pylint: disable=unnecessary-lambda
    )
def entropy(self, log_probs):
    """Differential entropy of a Gaussian with standard deviation `self._std`.

    For a univariate normal distribution the entropy is
    `0.5 * log(2 * pi * e * std**2) = log(std) + 0.5 * log(2 * pi * e)`.
    The previous implementation used `exp(std)` where `log(std)` belongs.

    NOTE(review): assumes `self._std` holds the standard deviation itself
    (not its logarithm) -- confirm against where it is set.

    Args:
        log_probs: unused; the entropy depends only on `self._std`.

    Returns:
        The entropy value.
    """
    del log_probs  # would be helpful if self._std was learnable
    return jnp.log(self._std) + .5 * jnp.log(2.0 * jnp.pi * jnp.e)
def AWRLoss(x, beta, w_max, **unused_kwargs):
    """Definition of the Advantage Weighted Regression (AWR) loss.

    Args:
        x: 3-tuple ``(log_probs, advantages, _)``; the third item is ignored.
        beta: divisor applied to the advantages before exponentiation.
        w_max: upper bound on the exponentiated weights.
        **unused_kwargs: ignored.

    Returns:
        The negated mean of ``log_probs`` times the clipped weights.
    """
    log_probs, advantages, _ = x
    unclipped = jnp.exp(advantages / beta)
    clipped = jnp.minimum(unclipped, w_max)
    weighted_log_probs = log_probs * clipped
    return -weighted_log_probs.mean()