def _update_diagonal(self, g, w, m, v1, v2, opt_params):
  """Performs one update step for a rank-1 (diagonal) parameter.

  Args:
    g: Gradient tensor (same shape as `w`).
    w: Weight tensor to update.
    m: Momentum slot (used only when `self._has_momentum`).
    v1: One-element list; `v1[0]` is the main second-moment accumulator.
    v2: One-element list; `v2[0]` is the grafting accumulator (used only
      when `self._graft`).
    opt_params: Dict with 'learning_rate', 'second_moment_averaging',
      'weight_decay' and (when momentum is enabled) 'momentum'.

  Returns:
    Pair of (updated weights, updated (m, v1, v2) slots).
  """
  learning_rate = opt_params['learning_rate']
  beta2 = opt_params['second_moment_averaging']
  weight_decay = opt_params['weight_decay']

  # When beta2 == 1 the blend below makes the effective (1 - beta2) factor
  # equal 1.0, so the accumulator degenerates to a plain running sum of
  # squared gradients instead of being zeroed out.
  is_beta2_1 = (beta2 == 1).astype(g.dtype)
  one_minus_beta2_except1 = is_beta2_1 + (1.0 - beta2) * (1.0 - is_beta2_1)
  v1[0] = beta2 * v1[0] + one_minus_beta2_except1 * g * g
  # Inverse-sqrt preconditioner; zero where nothing has accumulated yet.
  preconditioner = jnp.where(v1[0] > 0,
                             1.0 / (jnp.sqrt(v1[0]) + 1e-16),
                             jnp.zeros_like(v1[0]))
  pg = preconditioner * g
  if self._graft:
    # Grafting: rescale the preconditioned gradient to the norm that a plain
    # sum-of-squares (Adagrad-style) preconditioner would have produced.
    v2[0] += g * g
    preconditioner_graft = jnp.where(
        v2[0] > 0,
        1.0 / (jnp.sqrt(v2[0]) + 1e-16),
        jnp.zeros_like(v2[0]))
    pg_graft = preconditioner_graft * g
    pg_norm = jnp.linalg.norm(pg)
    pg_graft_norm = jnp.linalg.norm(pg_graft)
    pg = pg * (pg_graft_norm/(pg_norm + 1e-16))
  # Decoupled weight decay is added to the (possibly grafted) update.
  pg = pg + w * weight_decay
  if self._has_momentum:
    m, update = self._momentum_update(pg, m, opt_params['momentum'])
  else:
    update = pg
  w = w - (update * learning_rate).astype(w.dtype)
  return w, (m, v1, v2)
def _update_sketched(self, g, w, m, v1, v2, opt_params):
  """Update for higher-rank parameters.

  Keeps one accumulator per tensor dimension (a "sketch"); the effective
  per-element accumulator is the elementwise minimum of the broadcasted
  per-dimension accumulators, which are refreshed with a max-reduction at
  the end of the step.
  """
  learning_rate = opt_params['learning_rate']
  momentum = opt_params['momentum']
  beta2 = opt_params['second_moment_averaging']
  weight_decay = opt_params['weight_decay']
  shape = w.shape
  rank = len(shape)
  # Broadcast each per-dimension accumulator back to full rank.
  reshaped_accumulators = [
      jnp.reshape(v1[i], self._expanded_shape(shape, i)) for i in range(rank)
  ]
  acc = self._minimum(reshaped_accumulators)
  # When beta2 == 1 this blend makes the effective (1 - beta2) factor 1.0,
  # i.e. a plain running sum of squared gradients.
  is_beta2_1 = (beta2 == 1).astype(g.dtype)
  one_minus_beta2_except1 = is_beta2_1 + (1.0 - beta2) * (1.0 - is_beta2_1)
  acc = beta2 * acc + one_minus_beta2_except1 * g * g
  preconditioner = jnp.where(acc > 0.0,
                             1.0 / (jnp.sqrt(acc) + 1e-16),
                             jnp.zeros_like(acc))
  pg = g * preconditioner
  if self._graft:
    # Grafting: rescale to the norm of a plain sum-of-squares update.
    v2_acc = self._minimum([
        jnp.reshape(v2[i], self._expanded_shape(shape, i))
        for i in range(rank)
    ])
    v2_acc = v2_acc + g * g
    preconditioner_graft = jnp.where(v2_acc > 0.0,
                                     1.0 / (jnp.sqrt(v2_acc) + 1e-16),
                                     jnp.zeros_like(v2_acc))
    pg_graft = preconditioner_graft * g
    pg_norm = jnp.linalg.norm(pg)
    pg_graft_norm = jnp.linalg.norm(pg_graft)
    pg = pg * (pg_graft_norm / (pg_norm + 1e-16))
  pg = pg + w * weight_decay
  if self._has_momentum:
    m, update = self._momentum_update(pg, m, momentum)
  else:
    update = pg
  w = w - (learning_rate * update).astype(w.dtype)
  # Refresh each per-dimension sketch with a max over all other axes.
  for i in range(len(v1)):
    axes = list(range(int(i))) + list(range(int(i) + 1, rank))
    dim_accumulator = jnp.amax(acc, axis=axes)
    v1[i] = dim_accumulator
  if self._graft:
    for i in range(len(v2)):
      axes = list(range(int(i))) + list(range(int(i) + 1, rank))
      dim_accumulator = jnp.amax(v2_acc, axis=axes)
      v2[i] = dim_accumulator
  return w, (m, v1, v2)
def init(self, weights):
  """Creates optimizer slots for `weights`.

  Returns factored row/column second-moment accumulators when factoring is
  enabled and the tensor has rank >= 2, otherwise one full-size accumulator;
  a momentum slot is appended when momentum is enabled.
  """
  w_shape = weights.shape
  use_factoring = self._factored and len(w_shape) >= 2
  if use_factoring:
    # Factored second moments: one accumulator per row, one per column.
    slots = [
        jnp.zeros(w_shape[:-1], dtype=jnp.float32),
        jnp.zeros(w_shape[:-2] + w_shape[-1:], dtype=jnp.float32),
    ]
  else:
    slots = [jnp.zeros_like(weights)]
  if self._do_momentum:
    slots = slots + [jnp.zeros_like(weights)]
  return slots
def f(preds, values, returns, actions, mask):
  """Combined loss: AWR actor loss plus scaled L2 value loss."""
  # Advantage = return - baseline value; stop_gradient keeps the critic's
  # value estimate out of the actor's gradient.
  advantages = jnp.squeeze(returns - stop_gradient(values), axis=-1)
  logps = self._policy_dist.log_prob(preds, actions)
  awr_loss = actor_critic.AWRLoss(beta=self._beta, w_max=self._w_max)(
      (logps, advantages, jnp.zeros_like(logps), mask))
  l2_value_loss = jnp.mean((returns - values)**2) * self._value_loss_coeff
  return awr_loss + l2_value_loss
def forward(self, x):
  """Adds mixed-base positional embeddings (with random offsets) to `x`.

  Positions are decomposed digit-by-digit in each base from `self._bases`;
  each digit scales a learned embedding. In 'train' mode the start position
  is randomized and whole bases are sometimes dropped out.
  """
  rng = self.rng
  batch_size, length = x.shape[0], x.shape[1]
  # Largest position representable with _n_digits digits in the smallest base.
  max_pos = min(self._bases)**self._n_digits
  rng1, rng2, rng3 = fastmath.random.split(rng, 3)
  assert length < max_pos, 'length (%d) >= max_pos (%d)' % (length, max_pos)
  positions = jnp.arange(0, length)[None, :]
  if self._mode == 'train':
    # In 1% of training cases still start from 0 to be exactly as in eval.
    start_from_nonzero = fastmath.random.randint(
        rng1, (batch_size,), 0, self._start_from_zero_one_in)
    start_from_nonzero = jnp.minimum(1, start_from_nonzero)
    random_start = fastmath.random.randint(
        rng2, (batch_size,), 0, max_pos-length)
    random_start *= start_from_nonzero
    positions += random_start[:, None]
  res = []
  for bn, base in enumerate(self._bases):
    pos_embeddings = []
    cur_positions = positions
    for i in range(self._n_digits):
      # Extract the i-th digit of the position in this base.
      cur_indices = jnp.mod(cur_positions, base)
      cur_positions = cur_positions // base
      s = self.weights[bn][i]
      pos_embeddings.append(cur_indices.astype(jnp.float32)[:, :, None] * s)
    embeddings = jnp.concatenate(pos_embeddings, axis=-1)
    if self._mode == 'train':
      # Randomly zero out this base's embedding for some batch elements.
      base_dropout = fastmath.random.randint(
          rng3, (batch_size,), 0, self._base_dropout_one_in)
      base_dropout = jnp.minimum(1, base_dropout).astype(jnp.float32)
      embeddings *= base_dropout[:, None, None]
    res.append(embeddings)
  # Adding zeros_like(x) broadcasts the summed embeddings to x's shape/dtype.
  res = sum(res) + jnp.zeros_like(x)
  return x + res
def PPOObjective(dist_inputs, values, returns, dones, rewards, actions, old_log_probs, log_prob_fun, epsilon, normalize_advantages):
  """PPO Objective.

  Computes the clipped-surrogate PPO objective (per timestep, unreduced):
  min(ratio * advantage, clip(ratio, 1-epsilon, 1+epsilon) * advantage).
  """
  # dist_inputs of the shape float32[128,1,18]
  # values of the shape float32[128,1,1]
  # returns of the shape float32[128,1,1]
  # dones of the shape float32[128,1,1]
  # rewards of the shape int32[128,1,1]
  # actions of the shape int32[128,1]
  # and old_log_probs of the shape float32[128,1]
  returns = returns.squeeze(axis=2)
  values = values.squeeze(axis=2)
  dones = dones.squeeze(axis=2)
  rewards = rewards.squeeze(axis=2)
  assert rewards.shape == dones.shape, (
      f'rewards.shape was {rewards.shape} and dones.shape was {dones.shape}')
  assert dones.shape == values.shape, (
      f'dones.shape was {dones.shape} and values.shape was {values.shape}')
  assert returns.shape == values.shape, (
      f'returns.shape was {returns.shape} and values.shape was {values.shape}'
  )
  assert returns.shape == old_log_probs.shape, (
      f'returns.shape was {returns.shape} and'
      f'old_log_probs.shape was {old_log_probs.shape}')
  probs_ratio = ProbsRatio(dist_inputs, actions, old_log_probs, log_prob_fun)
  assert probs_ratio.shape == old_log_probs.shape, (
      f'probs_ratio.shape was {probs_ratio.shape} and'
      f'old_log_probs.shape was {old_log_probs.shape}')
  # jaxified versions of
  # returns[dones] = rewards[dones]
  # values[dones] = 0
  returns = jnp.where(dones, rewards, returns)
  values = jnp.where(dones, jnp.zeros_like(values), values)
  advantages = returns - values
  if normalize_advantages:
    advantages = advantages - jnp.mean(advantages)
    advantages /= jnp.std(advantages) + 1e-8
  assert old_log_probs.shape == advantages.shape, (
      f'old_log_probs.shape was {old_log_probs.shape} and advantages.shape was '
      f'{advantages.shape}')
  unclipped_objective = UnclippedObjective(probs_ratio, advantages)
  assert unclipped_objective.shape == advantages.shape, (
      f'old_log_probs.shape was {old_log_probs.shape} and'
      f'unclipped_objective.shape was {unclipped_objective.shape}')
  clipped_objective = ClippedObjective(probs_ratio, advantages, epsilon)
  assert clipped_objective.shape == advantages.shape, (
      f'clipped_objective.shape was {clipped_objective.shape} and'
      f'advantages.shape was {advantages.shape}')
  # Pessimistic (elementwise-minimum) combination of the two surrogates.
  ppo_objective = jnp.minimum(unclipped_objective, clipped_objective)
  assert ppo_objective.shape == advantages.shape, (
      f'ppo_objective.shape was {ppo_objective.shape} and'
      f'advantages.shape was {advantages.shape}')
  return ppo_objective
def init(self, w):
  """Builds the (momentum, v1s, v2s) optimizer slots for weights `w`.

  Momentum and grafting slots are empty lists when the feature is disabled;
  v1s (and v2s, if grafting) hold one accumulator per tensor dimension.
  """
  momentum = jnp.zeros_like(w) if self._has_momentum else []

  def _per_dim_accumulators():
    # One 1-d accumulator per dimension of the weight tensor.
    return [jnp.zeros(sz, dtype=w.dtype) for sz in w.shape]

  v1s = _per_dim_accumulators()
  v2s = _per_dim_accumulators() if self._graft else []
  return (momentum, v1s, v2s)
def _UpdateRow(x): # row_e - (L1, H), row_d - (L2, H), row_mask_e - (L1,) row_e, row_d, row_mask_e = x # final_row - (L1+L2, H) final_row = jnp.concatenate([row_e, jnp.zeros_like(row_d)], axis=0) # Find the last real token/vector of the encoder. e_idx = jnp.sum(row_mask_e, dtype=jnp.int32) # Starting after that index, update with the decoder row. return jax.lax.dynamic_update_slice(final_row, row_d, (e_idx, 0))
def DotProductAttention(queries, keys, values, pos_emb, context_bias, location_bias, mask, separate_cls, dropout, mode, rng):
  """Computes new activations via masked attention-weighted sum of values.

  This function is the core of the attention mechanism. It:
    - computes per-head attention weights from per-head `queries` and `keys`,
    - applies `mask` to screen out positions that come from padding tokens,
    - optionally applies dropout to attention weights, and
    - uses attention weights to combine per-head `values` vectors.

  Args:
    queries: Per-head activations representing attention queries.
    keys: Per-head activations representing attention keys.
    values: Per-head activations to be combined by computed attention weights.
    pos_emb: Per-head activations representing positional embeddings.
    context_bias: Global context bias from Transformer XL's attention.
    location_bias: Global location bias from Transformer XL's attention.
    mask: Mask that distinguishes positions with real content vs. padding.
    separate_cls: True/False if we separate_cls in calculations.
    dropout: Probabilistic rate for dropout applied to attention strengths
      (based on query-key pairs) before applying them to values; may be None.
    mode: One of `'train'`, `'eval'`, or `'predict'`.
    rng: Single-use random number generator (JAX PRNG key).

  Returns:
    Per-head activations resulting from masked per-head attention-weighted
    sum of per-head values.

  Raises:
    ValueError: If `dropout` is not None and not less than 1.
  """
  d_feature = queries.shape[-1]
  keys_len, queries_len = keys.shape[-2], queries.shape[-2]
  funnel_factor, is_upsampling = calc_funnel_ratio(keys_len, queries_len)

  # Transformer-XL style content (ac) and location (bd) attention terms.
  ac = jnp.einsum('bnid,bnjd->bnij', queries + context_bias, keys)
  bd = jnp.einsum('bnid,jnd->bnij', queries + location_bias, pos_emb)
  if mode != 'predict':
    bd = _fast_matrix_shift(bd, funnel_factor, is_upsampling)
  if separate_cls:
    # Masking out location part of attention for cls token
    bd = bd.at[:, :, :, 0].set(0)
    bd = bd.at[:, :, 0, :].set(0)
  dots = (ac + bd) / jnp.sqrt(d_feature)
  if mask is not None:
    dots = jnp.where(mask, dots, jnp.full_like(dots, -1e9))
  # Softmax.
  dots = jnp.exp(dots - fastmath.logsumexp(dots, axis=-1, keepdims=True))
  # Guard the None case first: comparing None >= 1.0 would raise a TypeError,
  # while dropout=None is explicitly allowed by the branch below.
  if dropout is not None and dropout >= 1.0:
    raise ValueError('Dropout rates must be lower than 1.')
  if dropout is not None and dropout > 0.0 and mode == 'train':
    keep = fastmath.random.bernoulli(rng, 1.0 - dropout, dots.shape)
    dots = jnp.where(keep, dots / (1.0 - dropout), jnp.zeros_like(dots))
  out = jnp.matmul(dots, values)
  out = out.astype(jnp.float32)
  dots = dots.astype(jnp.float32)
  return out, dots
def _update_diagonal(self, grads, weights, m, v, opt_params): learning_rate = opt_params['learning_rate'] momentum = opt_params['momentum'] v[0] += grads * grads preconditioner = jnp.where(v[0] > 0, 1.0 / jnp.sqrt(v[0]), jnp.zeros_like(v[0])) preconditioned_grads = preconditioner * grads m = (1 - momentum) * preconditioned_grads + momentum * m weights = weights - (learning_rate * m).astype(weights.dtype) return weights, (m, v)
def _UpdateRow(x):
  """Writes the decoder row directly after the last real encoder token.

  `x` is a tuple (row_e, row_d, row_mask_e) of shapes (L1, H), (L2, H) and
  (L1,); the result has shape (L1+L2, H): encoder content, decoder content,
  then zero padding.
  """
  encoder_row, decoder_row, encoder_mask = x
  # Pad the encoder row with zeros so the decoder content can fit at the end.
  padded = jnp.concatenate([encoder_row, jnp.zeros_like(decoder_row)], axis=0)
  # Number of real encoder tokens == index just past the last real one.
  insert_at = jnp.sum(encoder_mask, dtype=jnp.int32)
  zero = jnp.array(0, dtype=insert_at.dtype)  # avoid int32/int64 mismatch
  return fastmath.dynamic_update_slice(padded, decoder_row, (insert_at, zero))
def Relu():
  r"""Returns a layer that computes the Rectified Linear Unit (ReLU) function.

  .. math::
      f(x) = \left\{ \begin{array}{cl} 0 & \text{if}\ x \leq 0, \\
              x & \text{otherwise}. \end{array} \right.
  """
  def _relu(x):
    # Zero for non-positive inputs, identity otherwise.
    return jnp.where(x <= 0, jnp.zeros_like(x), x)
  return Fn('Relu', _relu)
def _per_head_attention(queries, keys, values, mask, dropout, mode, rng):
  """Computes new per-head activations via scaled dot-product attention.

  This function is the core of the attention mechanism. Given per-head
  ``queries`` (Q), ``keys`` (K), ``values`` (V), and ``mask``, it:

    - computes the scaled dot product of each Q-K pair;
    - applies ``mask`` to screen out positions that come from padding tokens
      (indicated by 0 value);
    - [in ``'train'`` mode] applies dropout to Q-K dot products;
    - computes Q-K attention strengths using a per-query softmax of the Q-K dot
      products; and
    - for each query position, combines V vectors according to the Q-K
      attention strengths.

  Args:
    queries: Per-head activations representing attention queries.
    keys: Per-head activations representing attention keys.
    values: Per-head activations to be combined by computed attention
      strengths.
    mask: Mask that distinguishes positions with real content vs. padding.
    dropout: Probababilistic rate for attention dropout, which overrides
      (sets to zero) some attention strengths derived from query-key
      matching. As a result, on a given forward pass, some value vectors
      don't contribute to the output, analogous to how regular dropout can
      cause some node activations to be ignored. Applies only in ``'train'``
      mode. May be None (treated as no dropout).
    mode: One of ``'train'``, ``'eval'``, or ``'predict'``.
    rng: Single-use random number generator (JAX PRNG key).

  Returns:
    Tuple of (activations, attn_strengths), where activations are new
    per-head activation vectors and attn_strengths is a matrix of per-head
    attention strengths.

  Raises:
    ValueError: If ``dropout`` is not None and not less than 1.
  """
  # Guard the None case first: `None >= 1.0` would raise a TypeError, while
  # dropout=None is explicitly allowed by the dropout branch below.
  if dropout is not None and dropout >= 1.0:
    raise ValueError(f'Dropout rate ({dropout}) must be lower than 1.')

  d_feature = queries.shape[-1]
  dots = jnp.matmul(queries, jnp.swapaxes(keys, -1, -2)) / jnp.sqrt(d_feature)
  if mask is not None:
    dots = jnp.where(mask, dots, jnp.full_like(dots, -1e9))
  # Per-query softmax over keys.
  attn_strengths = (
      jnp.exp(dots - fastmath.logsumexp(dots, axis=-1, keepdims=True)))
  if dropout is not None and dropout > 0.0 and mode == 'train':
    keep = fastmath.random.bernoulli(rng, 1.0 - dropout, attn_strengths.shape)
    attn_strengths = jnp.where(keep, attn_strengths / (1.0 - dropout),
                               jnp.zeros_like(attn_strengths))
  activations = jnp.matmul(attn_strengths, values).astype(jnp.float32)
  attn_strengths = attn_strengths.astype(jnp.float32)
  return activations, attn_strengths
def A2CObjective(dist_inputs, values, returns, dones, rewards, actions, mask, log_prob_fun, normalize_advantages): """Definition of the Advantage Actor Critic (A2C) loss.""" # dist_inputs of the shape float32[128,1,18] # values of the shape float32[128,1,1] # returns of the shape float32[128,1,1] # dones of the shape int32[128,1,1] # actions of the shape int32[128,1] # and mask of the shape float32[128,1] # We have to squeeze values and returns, because we # are planning to compute (return - values) * new_log_probs * mask # and all of them should be of the same dimension values = values.squeeze(axis=2) returns = returns.squeeze(axis=2) dones = dones.squeeze(axis=2) rewards = rewards.squeeze(axis=2) assert rewards.shape == dones.shape, ( f'rewards.shape was {rewards.shape} and dones.shape was {dones.shape}') assert dones.shape == values.shape, ( f'dones.shape was {dones.shape} and values.shape was {values.shape}') assert returns.shape == values.shape, ( f'returns.shape was {returns.shape} and values.shape was {values.shape}' ) assert values.shape == mask.shape, ( f'values.shape was {values.shape} and mask.shape was {mask.shape}') assert returns.shape[0] == dist_inputs.shape[0], ( f'returns.shape[0] was {returns.shape[0]} and dist_inputs.shape[0] was ' f'{dist_inputs.shape[0]}') new_log_probs = NewLogProbs(dist_inputs, actions, log_prob_fun) assert new_log_probs.shape == mask.shape, ( f'new_log_probs.shape was {new_log_probs.shape} and mask.shape was ' f'{mask.shape}') # jaxified versions of # returns[dones] = rewards[dones] # values[dones] = 0 returns = jnp.where(dones, rewards, returns) values = jnp.where(dones, jnp.zeros_like(values), values) advantages = returns - values if normalize_advantages: advantages = advantages - jnp.mean(advantages) advantages /= jnp.std(advantages) + 1e-8 assert new_log_probs.shape == advantages.shape, ( f'new_log_probs.shape was {new_log_probs.shape} and advantages.shape was ' f'{advantages.shape}') # One of the motivation to the 
squeezes and assertions is to # avoid [128,1] * [128,1,1] * [128] multiplications in the definition # of the a2c objective - we insist on the same shapes a2c_objective = -jnp.sum(new_log_probs * advantages * mask) / jnp.sum(mask) return a2c_objective
def batches_stream(self):
  """Use the RLTask self._task to create inputs to the value model."""
  for np_trajectory in self._task.trajectory_batch_stream(
      self._batch_size,
      max_slice_length=self._max_slice_length,
      epochs=[-1]):
    # Insert an extra depth dimension, so the target shape is consistent with
    # the network output shape.
    yield (np_trajectory.observations,        # Inputs to the value model.
           np_trajectory.returns[:, :, None],
           np_trajectory.dones[:, :, None],
           np_trajectory.rewards[:, :, None],
           np_trajectory.actions,
           # Zeros with the mask's shape — presumably a placeholder slot the
           # consumer ignores; TODO(review): confirm against the loss inputs.
           jnp.zeros_like(np_trajectory.mask),
           np_trajectory.mask)
def ParametricRelu(a=1.):
  r"""Returns a layer that computes a ReLU function with the given slope.

  .. math::
      f(x) = \left\{ \begin{array}{cl} 0 & \text{if}\ x \leq 0, \\
              ax & \text{otherwise}. \end{array} \right.

  Args:
    a: Slope of line for positive inputs.
  """
  def _parametric_relu(x):
    return jnp.maximum(a * x, jnp.zeros_like(x))
  return Fn('ParametricRelu', _parametric_relu)
def favor_denominator_bwd(qkp, r_ct):
  """Backward pass (custom vjp) for the FAVOR denominator prefix-sum scan.

  Runs the forward scan in reverse, rebuilding each intermediate prefix sum
  `p` (by subtracting `k` back out) while accumulating cotangents for the
  query and key features.
  """
  precision, qs, ks, p = qkp
  def body(carry, qkx):
    p, p_ct = carry
    q, k, x_ct = qkx
    q_ct = jnp.einsum('...,...m->...m', x_ct, p, precision=precision)
    p_ct += jnp.einsum('...,...m->...m', x_ct, q, precision=precision)
    k_ct = p_ct
    # Undo the forward accumulation to recover the previous prefix sum.
    p -= k
    return (p, p_ct), (q_ct, k_ct)
  _, (qs_ct, ks_ct) = fastmath.scan(
      body, (p, jnp.zeros_like(p)), (qs, ks, r_ct), reverse=True)
  # None cotangents for the non-differentiable (precision, p) residuals.
  return (None, None, qs_ct, ks_ct)
def favor_numerator_bwd(pqkv, w_ct):
  """Backward pass (custom vjp) for the FAVOR numerator prefix-sum scan.

  Runs the forward scan in reverse, rebuilding each intermediate prefix sum
  `p` (by subtracting the k-v outer product back out) while accumulating
  cotangents for queries, keys and values.
  """
  precision, p, qs, ks, vs = pqkv
  def body(carry, qkv_xct):
    p, p_ct = carry
    q, k, v, x_ct = qkv_xct
    q_ct = jnp.einsum('...d,...md->...m', x_ct, p, precision=precision)
    p_ct += jnp.einsum('...d,...m->...md', x_ct, q, precision=precision)
    k_ct = jnp.einsum('...md,...d->...m', p_ct, v, precision=precision)
    v_ct = jnp.einsum('...md,...m->...d', p_ct, k, precision=precision)
    # Undo the forward accumulation to recover the previous prefix sum.
    p -= jnp.einsum('...m,...d->...md', k, v, precision=precision)
    return (p, p_ct), (q_ct, k_ct, v_ct)
  _, (qs_ct, ks_ct, vs_ct) = fastmath.scan(
      body, (p, jnp.zeros_like(p)), (qs, ks, vs, w_ct), reverse=True)
  # None cotangents for the non-differentiable (precision, p) residuals.
  return (None, None, qs_ct, ks_ct, vs_ct)
def favor_denominator_bwd(init_prefix_sum_value, precision, qkp, r_ct):
  """Backward pass for the FAVOR denominator scan (numpy-backend variant).

  Reverses the forward scan, rebuilding each prefix sum `p` while
  accumulating cotangents for queries and keys.
  """
  del init_prefix_sum_value
  def body(carry, qkx):
    p, p_ct = carry
    q, k, x_ct = qkx
    q_ct = np.einsum('...,...m->...m', x_ct, p, precision=precision)
    p_ct += np.einsum('...,...m->...m', x_ct, q, precision=precision)
    k_ct = p_ct
    # Undo the forward accumulation to recover the previous prefix sum.
    p -= k
    return (p, p_ct), (q_ct, k_ct)
  qs, ks, p = qkp
  _, (qs_ct, ks_ct) = fastmath.scan(body, (p, np.zeros_like(p)),
                                    (qs, ks, r_ct), reverse=True)
  return (qs_ct, ks_ct)
def DotProductAttention(queries, keys, values, mask, dropout, mode, rng):
  """Computes new activations via masked attention-weighted sum of values.

  This function is the core of the attention mechanism. It:
    - computes per-head attention weights from per-head `queries` and `keys`,
    - applies `mask` to screen out positions that come from padding tokens,
    - optionally applies dropout to attention weights, and
    - uses attention weights to combine per-head `values` vectors.

  Args:
    queries: Per-head activations representing attention queries.
    keys: Per-head activations representing attention keys.
    values: Per-head activations to be combined by computed attention weights.
    mask: Mask that distinguishes positions with real content vs. padding.
    dropout: Probababilistic rate for dropout applied to attention strengths
      (based on query-key pairs) before applying them to values; may be None.
    mode: One of `'train'`, `'eval'`, or `'predict'`.
    rng: Single-use random number generator (JAX PRNG key).

  Returns:
    Per-head activations resulting from masked per-head attention-weighted
    sum of per-head values.

  Raises:
    ValueError: If `dropout` is not None and not less than 1.
  """
  d_feature = queries.shape[-1]
  dots = jnp.matmul(queries, jnp.swapaxes(keys, -1, -2)) / jnp.sqrt(d_feature)
  if mask is not None:
    # TODO(kitaev): workaround for https://github.com/google/jax/issues/850
    # We must ensure that both mask and the -1e9 constant have a data dependency
    # on the input. Broadcasted copies of these use a lot of memory, so they
    # should be computed at runtime (rather than being global constants).
    if fastmath.is_backend(fastmath.Backend.JAX):
      mask = jax.lax.tie_in(dots, mask)
    # JAX's `full_like` already ties in -1e9 to dots.
    dots = jnp.where(mask, dots, jnp.full_like(dots, -1e9))
  # Softmax.
  dots = jnp.exp(dots - fastmath.logsumexp(dots, axis=-1, keepdims=True))
  # Guard the None case first: `None >= 1.0` would raise a TypeError, while
  # dropout=None is explicitly allowed by the branch below.
  if dropout is not None and dropout >= 1.0:
    raise ValueError('Dropout rates must be lower than 1.')
  if dropout is not None and dropout > 0.0 and mode == 'train':
    keep = fastmath.random.bernoulli(rng, 1.0 - dropout, dots.shape)
    dots = jnp.where(keep, dots / (1.0 - dropout), jnp.zeros_like(dots))
  out = jnp.matmul(dots, values)
  out = out.astype(jnp.float32)
  dots = dots.astype(jnp.float32)
  return out, dots
def favor_numerator_bwd(init_prefix_sum_value, precision, pqkv, w_ct):
  """Backward pass for the FAVOR numerator scan (numpy-backend variant).

  Reverses the forward scan, rebuilding each prefix sum `p` while
  accumulating cotangents for queries, keys and values.
  """
  del init_prefix_sum_value
  def body(carry, qkv_xct):
    p, p_ct = carry
    q, k, v, x_ct = qkv_xct
    q_ct = np.einsum('...d,...md->...m', x_ct, p, precision=precision)
    p_ct += np.einsum('...d,...m->...md', x_ct, q, precision=precision)
    k_ct = np.einsum('...md,...d->...m', p_ct, v, precision=precision)
    v_ct = np.einsum('...md,...m->...d', p_ct, k, precision=precision)
    # Undo the forward accumulation to recover the previous prefix sum.
    p -= np.einsum('...m,...d->...md', k, v, precision=precision)
    return (p, p_ct), (q_ct, k_ct, v_ct)
  p, qs, ks, vs = pqkv
  _, (qs_ct, ks_ct, vs_ct) = fastmath.scan(body, (p, np.zeros_like(p)),
                                           (qs, ks, vs, w_ct), reverse=True)
  return qs_ct, ks_ct, vs_ct
def DotProductAttention(queries, keys, values, mask, dropout, mode, rng):
  """Computes new activations via masked attention-weighted sum of values.

  This function is the core of the attention mechanism. It:
    - computes per-head attention weights from per-head ``queries`` and
      ``keys``,
    - applies ``mask`` to screen out positions that come from padding tokens,
    - optionally applies dropout to attention weights, and
    - uses attention weights to combine per-head ``values`` vectors.

  Args:
    queries: Per-head activations representing attention queries.
    keys: Per-head activations representing attention keys.
    values: Per-head activations to be combined by computed attention weights.
    mask: Mask that distinguishes positions with real content vs. padding.
    dropout: Probababilistic rate for attention dropout, which overrides
      (sets to zero) some attention strengths derived from query-key
      matching. As a result, on a given forward pass, some value vectors
      don't contribute to the output, analogous to how regular dropout can
      cause some node activations to be ignored. May be None.
    mode: One of ``'train'``, ``'eval'``, or ``'predict'``.
    rng: Single-use random number generator (JAX PRNG key).

  Returns:
    Per-head activations resulting from masked per-head attention-weighted
    sum of per-head values.

  Raises:
    ValueError: If ``dropout`` is not None and not less than 1.
  """
  d_feature = queries.shape[-1]
  dots = jnp.matmul(queries, jnp.swapaxes(keys, -1, -2)) / jnp.sqrt(d_feature)
  if mask is not None:
    dots = jnp.where(mask, dots, jnp.full_like(dots, -1e9))
  # Softmax.
  dots = jnp.exp(dots - fastmath.logsumexp(dots, axis=-1, keepdims=True))
  # Guard the None case first: `None >= 1.0` would raise a TypeError, while
  # dropout=None is explicitly allowed by the branch below.
  if dropout is not None and dropout >= 1.0:
    raise ValueError('Dropout rates must be lower than 1.')
  if dropout is not None and dropout > 0.0 and mode == 'train':
    keep = fastmath.random.bernoulli(rng, 1.0 - dropout, dots.shape)
    dots = jnp.where(keep, dots / (1.0 - dropout), jnp.zeros_like(dots))
  out = jnp.matmul(dots, values)
  out = out.astype(jnp.float32)
  dots = dots.astype(jnp.float32)
  return out, dots
def _calc_attn_scores(q, k):
  """Computes (dropped-out) attention weights for queries `q` and keys `k`.

  Closure over `context_bias`, `location_bias`, `pos_emb`, `mask`,
  `d_feature`, `dropout`, `mode` and `rng` from the enclosing scope.

  Raises:
    ValueError: If `dropout` is not None and not less than 1.
  """
  # Transformer-XL style content (ac) and location (bd) attention terms.
  ac = jnp.einsum('bnid,bnjd->bnij', q + context_bias, k)
  bd = jnp.einsum('bnid,jnd->bnij', q + location_bias, pos_emb)
  if mode != 'predict':
    bd = _fast_matrix_shift(bd)
  dots = (ac + bd) / jnp.sqrt(d_feature)
  dots = jnp.where(mask, dots, jnp.full_like(dots, -1e9))
  # Softmax.
  dots = jnp.exp(dots - fastmath.logsumexp(dots, axis=-1, keepdims=True))
  # Guard the None case first: `None >= 1.0` would raise a TypeError, while
  # dropout=None is explicitly allowed by the branch below.
  if dropout is not None and dropout >= 1.0:
    raise ValueError('Dropout rates must be lower than 1.')
  if dropout is not None and dropout > 0.0 and mode == 'train':
    keep = fastmath.random.bernoulli(rng, 1.0 - dropout, dots.shape)
    dots = jnp.where(keep, dots / (1.0 - dropout), jnp.zeros_like(dots))
  return dots
def _update_sketched(self, grads, weights, m, v, opt_params):
  """Update for higher-rank parameters.

  Keeps one accumulator per tensor dimension; the effective per-element
  accumulator is the elementwise minimum of the broadcasted per-dimension
  accumulators, refreshed with a max-reduction after the step.
  """
  learning_rate = opt_params['learning_rate']
  momentum = opt_params['momentum']
  shape = weights.shape
  rank = len(shape)
  # Broadcast each per-dimension accumulator back to full rank.
  reshaped_accumulators = [jnp.reshape(v[i], self._expanded_shape(shape, i))
                           for i in range(rank)]
  current_accumulator = self._minimum(reshaped_accumulators)
  current_accumulator += grads * grads
  # Inverse-sqrt preconditioner; zero where nothing has accumulated.
  accumulator_inv_sqrt = jnp.where(current_accumulator > 0.0,
                                   1.0 / jnp.sqrt(current_accumulator),
                                   jnp.zeros_like(current_accumulator))
  preconditioned_gradient = grads * accumulator_inv_sqrt
  m = (1.0 - momentum) * preconditioned_gradient + momentum * m
  weights = weights - (learning_rate * m).astype(weights.dtype)
  # Refresh each per-dimension sketch with a max over all other axes.
  for i in range(len(v)):
    axes = list(range(int(i))) + list(range(int(i) + 1, rank))
    dim_accumulator = jnp.amax(current_accumulator, axis=axes)
    v[i] = dim_accumulator
  return weights, (m, v)
def threefry_2x32_prange(key, lo: int = 0, hi: int = 2):
  """Splits a key into a stream of random keys.

  This uses the little-endian counter mode.

  Args:
    key: uint32[2] the key to split
    lo: the range to start extracting from
    hi: the range to stop extracting from

  Returns:
    keys: uint32[hi - lo, 2] the split keys
  """
  if key.shape != (2,) or key.dtype != jnp.uint32:
    raise ValueError('key must be uint32[2]')
  if hi >= 2**32:
    # You shouldn't really be using more than half the key size anyways.
    raise NotImplementedError('only 32-bit sizes are supported')
  # Build a little-endian 64-bit counter: low word counts, high word is zero.
  low_words = jnp.arange(lo, hi, dtype=jnp.uint32)
  high_words = jnp.zeros_like(low_words)
  counters = jnp.stack([low_words, high_words], axis=-1)
  return threefry_2x32_prf(key, counters)
def forward(self, x):
  """Executes this layer as part of a forward pass through the model.

  Args:
    x: Tensor of same shape and dtype as the input signature used to
      initialize this layer.

  Returns:
    Tensor of same shape and dtype as the input.
  """
  m1, w1, w2, b2 = self.weights
  x_shape = x.shape
  x = jnp.reshape(x, [-1, x_shape[-1]])  # Easier to operate on flattened x.
  # Q: check if we need bias and/or put relu after the m1 dot?
  mask_logits = jnp.dot(x, m1)
  # Softmax.
  mask_logsumexp = fastmath.logsumexp(mask_logits, axis=-1, keepdims=True)
  log_mask = mask_logits - mask_logsumexp
  mask = jnp.exp(log_mask)
  # Gumbel-softmax with straight-through discretization.
  # TODO(lukaszkaiser, chowdhery): Extract this block and share
  rng1, rng2 = fastmath.random.split(self.rng, 2)
  # Gumbel noise: -log(-log(U)) with U uniform in (0, 1).
  u = fastmath.random.uniform(rng1, mask.shape, jnp.float32, 1e-6, 1.0 - 1e-6)
  g = -jnp.log(-jnp.log(u))
  selected_experts = jnp.argmax(log_mask + g * self._temperature, axis=-1)
  if self._mode == 'train':
    # Tricks from Section 2.1 in https://arxiv.org/abs/1801.09797
    quant_mask = tl.one_hot(selected_experts, self._num_experts)
    quant_mask = fastmath.stop_gradient(quant_mask)
    quant_mask += mask - fastmath.stop_gradient(mask)  # straight-through
    # We will sometimes (50% of the batches) use the soft-mask instead of
    # the quantized mask to improve training stability (see the paper above).
    # Q: is selecting 50% of batches the best? Other %? Mixed in-batch?
    select = fastmath.random.uniform(rng2, (), jnp.float32, -1.0, 1.0)
    quant_mask = jnp.where(select > 0.0, quant_mask, mask)
  else:
    quant_mask = tl.one_hot(selected_experts, self._num_experts)
  quant_mask = jnp.reshape(quant_mask, [-1, self._num_experts, 1])
  quant_mask_shape = quant_mask.shape
  batch_size = quant_mask.shape[0]

  if self._mode == 'predict' and batch_size == 1:
    # This implementation mimicks inference for batch_size 1: slice out only
    # the selected expert's block of weights instead of a full matmul.
    start_idx = selected_experts[0] * self._n_elements_in_block
    # w1 is [d_model, d_ff], w is [d_model, n_elements_in_block]
    w = fastmath.dynamic_slice(
        w1, [0, start_idx], [w1.shape[0], self._n_elements_in_block])
    mid = jnp.dot(x, w)
    relu = jnp.where(mid <= 0, jnp.zeros_like(mid), mid)
    # w2 is [d_ff, d_model], v is [n_elements_in_block, d_model]
    v = fastmath.dynamic_slice(
        w2, [start_idx, 0], [self._n_elements_in_block, w2.shape[-1]])
    v = jnp.reshape(v, [self._n_elements_in_block, -1])
    res = jnp.dot(relu, v) + b2
  else:
    # Full matmul with the (one-hot or soft) mask expanded over each
    # expert's block of d_ff units.
    expanded_mask = jnp.broadcast_to(
        quant_mask,
        (quant_mask_shape[0], quant_mask.shape[1], self._n_elements_in_block))
    expanded_mask = jnp.reshape(expanded_mask, (-1, self._d_ff))
    mid = jnp.dot(x, w1) * expanded_mask  # [joint_batch, d_ff]
    relu = jnp.where(mid <= 0, jnp.zeros_like(mid), mid)
    res = jnp.dot(relu, w2) + b2

  return jnp.reshape(res, x_shape)  # un-flatten if needed
def init(self, weights):
  """Returns zeroed (momentum, per-dimension accumulators) slots."""
  per_dim = []
  for dim_size in weights.shape:
    # One 1-d accumulator per dimension of the weight tensor.
    per_dim.append(jnp.zeros(dim_size, dtype=weights.dtype))
  return (jnp.zeros_like(weights), per_dim)
def relu(x):
  """Rectified linear unit: zero for x <= 0, identity otherwise."""
  zeros = jnp.zeros_like(x)
  return jnp.where(x <= 0, zeros, x)
def forward(self, x):
  """Executes this layer as part of a forward pass through the model.

  Args:
    x: Tensor of same shape and dtype as the input signature used to
      initialize this layer.

  Returns:
    Tensor of same shape and dtype as the input.
  """
  m1, m2, mb, w1, w2, b2 = self.weights
  if self._mode != 'predict':
    # Flatten the per-block weight layout into plain [d_model, d_ff] /
    # [d_ff, d_model] matrices for the full-matmul branches below.
    w1 = jnp.reshape(w1.T, (-1, self._d_ff))
    w2 = jnp.reshape(w2, (self._d_ff, -1))
  x_shape = x.shape
  x = jnp.reshape(x, [-1, x_shape[-1]])  # Easier to operate on flattened x.

  # Q: should we add bias and/or put relu after the low-rank m1 dot?
  # Low-rank (m1 @ m2) mask logits, one softmax per d1 block over d2 options.
  mask_logits = jnp.dot(jnp.dot(x, m1), m2) + mb
  mask_logits = jnp.reshape(mask_logits, [-1, self._d1, self._d2])
  # Softmax.
  mask_logsumexp = fastmath.logsumexp(mask_logits, axis=-1, keepdims=True)
  log_mask = mask_logits - mask_logsumexp
  mask = jnp.exp(log_mask)
  # Gumbel-softmax with straight-through discretization.
  rng1, rng2 = fastmath.random.split(self.rng, 2)
  # Gumbel noise: -log(-log(U)) with U uniform in (0, 1).
  u = fastmath.random.uniform(rng1, mask.shape, jnp.float32, 1e-6, 1.0 - 1e-6)
  g = -jnp.log(-jnp.log(u))
  quant_mask = jnp.argmax(log_mask + g * self._temperature, axis=-1)
  if self._mode == 'train':
    # Tricks from Section 2.1 in https://arxiv.org/abs/1801.09797
    quant_mask = tl.one_hot(quant_mask, self._n_elements_in_block)
    quant_mask = fastmath.stop_gradient(quant_mask)
    quant_mask += mask - fastmath.stop_gradient(mask)  # straight-through
    # We will sometimes (quant_prob of the batches) use the soft-mask instead
    # of the quantized mask to improve training stability (see paper above).
    select = fastmath.random.uniform(rng2, (), jnp.float32, 0.0, 1.0)
    quant_mask = jnp.where(select < self._quant_prob, quant_mask, mask)
    quant_mask = jnp.reshape(quant_mask, [-1, self._d_ff])

  if self._mode == 'train':
    # In training, run full matmul to get benefits from the above tricks.
    mid = jnp.dot(x, w1) * quant_mask  # [joint_batch, d_ff]
    relu = jnp.where(mid <= 0, jnp.zeros_like(mid), mid)
    res = jnp.dot(relu, w2) + b2
  elif self._mode == 'predict':
    # w1 = jnp.reshape(w1.T, (self._d1, self._d2, -1))
    # w2 = jnp.reshape(w2, (self._d1, self._d2, -1))
    # This implementation mimicks inference. It's not efficient for large
    # size of joint_batch, but at inference that will be 1 most of the time.
    # Shapes:
    # quant_mask is [joint_batch, self._d1]
    # w1 is [d_model, self._d1, self._d2]
    # we'll index w1 with advanced numpy indexing, first range over
    # self._d1 times the batch size, second range being quant_mask
    batch_size = quant_mask.shape[0]
    idx1 = jnp.array([jnp.arange(self._d1)] * batch_size)
    # flatten indices and select from w1
    idx1 = jnp.reshape(idx1, [-1])
    idx2 = jnp.reshape(quant_mask, [-1])
    w = w1[idx1, idx2, :]  # now we have per-element weights with batch dim
    w = jnp.reshape(w, [batch_size, self._d1, -1])
    mid = jnp.einsum('ai,aji->aj', x, w)
    relu = jnp.where(mid <= 0, jnp.zeros_like(mid), mid)
    # w2 is [self._d1, self._d2, d_model]
    v = w2[idx1, idx2, :]
    v = jnp.reshape(v, [batch_size, self._d1, -1])
    res = jnp.einsum('ai,aij->aj', relu, v) + b2
  else:
    # Eval: hard one-hot mask without the straight-through estimator.
    quant_mask = tl.one_hot(quant_mask, self._n_elements_in_block)
    quant_mask = jnp.reshape(quant_mask, [-1, self._d_ff])
    mid = jnp.dot(x, w1) * quant_mask  # [joint_batch, d_ff]
    relu = jnp.where(mid <= 0, jnp.zeros_like(mid), mid)
    res = jnp.dot(relu, w2) + b2

  return jnp.reshape(res, x_shape)  # un-flatten if needed
def backward(self, inputs, output, grad, weights, state, new_state, rng):
  """Blocks gradient flow: zero gradients for inputs, none for weights."""
  zero_grad = jnp.zeros_like(grad)
  return (zero_grad, ())