def clipped_objective_mean(self):
  def f(dist_inputs, values, returns, dones, rewards,
        actions, old_log_probs):
    """Clipped objective from the PPO algorithm."""
    del dones, rewards
    advantages = returns - values
    probs_ratio = rl_layers.ProbsRatio(
        dist_inputs, actions, old_log_probs,
        log_prob_fun=self._policy_dist.log_prob)
    # advantages are of the shape [128,1,1]
    # and probs_ratio are of the shape [128,1]
    advantages = advantages.squeeze(axis=2)
    clipped_objective = rl_layers.ClippedObjective(
        probs_ratio, advantages, epsilon=self._epsilon)
    return jnp.mean(clipped_objective)
  return tl.Fn('ClippedObjectiveMean', f)

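# A minimal, standalone sketch of the clipped-objective math that
# rl_layers.ProbsRatio and rl_layers.ClippedObjective are used for above;
# this is not the library implementation, and the numbers are made up.
import numpy as np

def clipped_objective_sketch(new_log_probs, old_log_probs, advantages,
                             epsilon=0.2):
  # Probability ratio r = exp(log pi_new - log pi_old).
  probs_ratio = np.exp(new_log_probs - old_log_probs)
  # PPO takes the minimum of the unclipped and clipped surrogate objectives.
  return np.minimum(probs_ratio * advantages,
                    np.clip(probs_ratio, 1 - epsilon, 1 + epsilon) * advantages)

# Example: a ratio of 1.5 with epsilon=0.2 is clipped to 1.2 when the
# advantage is positive.
print(clipped_objective_sketch(np.log(1.5), np.log(1.0), 1.0))  # -> 1.2
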
def PickLastTokenInPredict(mode='train'):
  """Picks the last token logits.

  Self-descriptive layer for picking the last token logits in predict mode
  for fast inference.

  Args:
    mode: the model mode (train, predict, ...)

  Returns:
    The last token logits.
  """
  def last_token(x):  # pylint: disable=invalid-name
    if mode == 'predict':
      return x[:, -1:, :]
    return x
  return tl.Fn('Pick last token in predict', last_token)

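# A minimal usage sketch for PickLastTokenInPredict, assuming the module's
# own `tl` import is in scope; the logits tensor is random and only
# illustrates shapes.
import numpy as np

logits = np.random.uniform(size=(2, 7, 5))   # (batch, length, vocab)
train_layer = PickLastTokenInPredict(mode='train')
predict_layer = PickLastTokenInPredict(mode='predict')
print(train_layer(logits).shape)    # (2, 7, 5) - unchanged in train mode
print(predict_layer(logits).shape)  # (2, 1, 5) - only the last position
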
def ppo_objective_mean(self):
  """PPO objective mean."""
  def f(dist_inputs, values, returns, dones, rewards,
        actions, old_log_probs):
    """Clipped objective from the PPO algorithm."""
    ppo_objective = rl_layers.PPOObjective(
        dist_inputs, values, returns, dones, rewards,
        actions, old_log_probs,
        log_prob_fun=self._policy_dist.log_prob,
        epsilon=self._epsilon,
        normalize_advantages=self._normalize_advantages)
    return jnp.mean(ppo_objective)
  return tl.Fn('PPOObjectiveMean', f)

def SignificanceWeights(serializer, decay):
  """Multiplies a binary mask with a symbol significance mask."""
  def significance_weights(mask):
    # (repr,) -> (batch, length, repr)
    # significance = [0, 1, 2]
    significance = serializer.significance_map
    assert significance.shape[0] == mask.shape[2]
    # significance = batch_size * [0, 1, 2]
    significance = jnp.repeat(
        significance[np.newaxis, ...], repeats=mask.shape[0], axis=0)
    # significance = batch_size * [0, 1, 2] * mask.shape[1]
    significance = jnp.repeat(
        significance[..., jnp.newaxis], repeats=mask.shape[1], axis=2)
    # significance = batch_size * mask.shape[1] * [0, 1, 2]
    significance = jnp.swapaxes(significance, 1, 2)
    assert significance.shape == mask.shape
    sig_weights = mask * decay ** significance
    return sig_weights
  return tl.Fn('SignificanceWeights', significance_weights)

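# A toy usage sketch for SignificanceWeights. The serializer here is a
# stand-in object exposing only the `significance_map` attribute the layer
# reads; a real serializer would normally be used instead. Assumes the
# module's own `jnp`/`tl` imports are in scope.
import types
import numpy as np

toy_serializer = types.SimpleNamespace(
    significance_map=np.array([0, 1, 2]))  # one entry per repr symbol
mask = np.ones((4, 5, 3))                  # (batch, length, repr)
weights_layer = SignificanceWeights(toy_serializer, decay=0.5)
weights = weights_layer(mask)
print(weights[0, 0])  # [1.0, 0.5, 0.25] - decay ** significance per symbol
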
def value_loss(self):
  """Value loss computed using smooth L1 loss or L2 loss."""
  def f(values, actions, returns, mask):
    ind_0, ind_1 = np.indices(actions.shape)
    # We calculate the length using the shape of returns
    # and adequately remove a superfluous slice of values.
    # An analogous operation is done in value_batches_stream.
    length = returns.shape[1]
    values = values[:, :length, :]
    selected_values = values[ind_0, ind_1, actions]
    shapes.assert_same_shape(selected_values, returns)
    shapes.assert_same_shape(selected_values, mask)
    if self._smoothl1loss:
      return tl.SmoothL1Loss().forward((selected_values, returns, mask))
    else:
      return tl.L2Loss().forward((selected_values, returns, mask))
  return tl.Fn('ValueLoss', f)

def _MaskOfRightShiftedArray(n_shifts=1, mode='train'):
  """Gives the mask of an array right-shifted by n_shifts."""
  def F(x):
    # TODO(afrozm): What to do in this case?
    if mode == 'predict':
      raise ValueError('MaskOfRightShiftedArray not implemented for predict.')
    mask = x != 0
    if n_shifts == 0:
      return mask
    # Need to set the (B, n_shifts, ...) section to True.
    trues_shape = (x.shape[0], n_shifts) + mask.shape[2:]
    trues = jnp.full(trues_shape, True)
    return jnp.concatenate([trues, mask[:, n_shifts:, ...]], axis=1)
  return tl.Fn(f'MaskOfRightShiftedArray({n_shifts})', F)

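# A small usage sketch for _MaskOfRightShiftedArray, assuming the module's own
# `jnp`/`tl` imports are in scope. The token array is made up; zeros are
# treated as padding, and the first n_shifts positions of the mask are forced
# to True because a right shift moves padding there.
import numpy as np

tokens = np.array([[0, 5, 7, 0, 0],
                   [3, 0, 2, 9, 0]])
mask_layer = _MaskOfRightShiftedArray(n_shifts=1, mode='train')
print(mask_layer(tokens))
# [[ True  True  True False False]
#  [ True False  True  True False]]
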
def Deinterleave(x_size, y_size):
  """Layer that does the inverse of Interleave."""
  def deinterleave(inputs):
    reprs = inputs
    (batch_size, length) = reprs.shape[:2]
    shape_suffix = reprs.shape[2:]
    remainder_length = length % (x_size + y_size)
    if remainder_length > 0:
      remainder = reprs[:, None, -remainder_length:]
      reprs = reprs[:, :-remainder_length]
    reprs = jnp.reshape(
        reprs, (batch_size, -1, x_size + y_size) + shape_suffix)
    x_reprs = reprs[:, :, :x_size]
    y_reprs = reprs[:, :, x_size:]
    if remainder_length > 0:
      x_reprs = jnp.concatenate((x_reprs, remainder), axis=1)
    return (x_reprs, y_reprs)
  return tl.Fn('Deinterleave', deinterleave, n_out=2)

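# A toy usage sketch for Deinterleave: 2 x-symbols followed by 1 y-symbol per
# group, with a trailing remainder of x-symbols. The values are arbitrary and
# the module's own `jnp`/`tl` imports are assumed to be in scope.
import numpy as np

interleaved = np.array([[1, 2, 9, 3, 4, 8, 5, 6]])   # (batch=1, length=8)
x_reprs, y_reprs = Deinterleave(x_size=2, y_size=1)(interleaved)
print(x_reprs)  # [[[1 2] [3 4] [5 6]]]
print(y_reprs)  # [[[9] [8]]]
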
def test_forward(self):
  layer = tl.Fn(
      'SumAndMax',
      lambda x0, x1: (x0 + x1, jnp.maximum(x0, x1)),
      n_out=2)

  x0 = np.array([1, 2, 3, 4, 5])
  x1 = np.array([10, 20, 30, 40, 50])

  y0, y1 = layer((x0, x1))
  self.assertEqual(y0.tolist(), [11, 22, 33, 44, 55])
  self.assertEqual(y1.tolist(), [10, 20, 30, 40, 50])

  y2, y3 = layer.forward((x0, x1))
  self.assertEqual(y2.tolist(), [11, 22, 33, 44, 55])
  self.assertEqual(y3.tolist(), [10, 20, 30, 40, 50])

  (y4, y5), state = layer.pure_fn(
      (x0, x1), tl.EMPTY_WEIGHTS, tl.EMPTY_STATE, None)
  self.assertEqual(y4.tolist(), [11, 22, 33, 44, 55])
  self.assertEqual(y5.tolist(), [10, 20, 30, 40, 50])
  self.assertEqual(state, tl.EMPTY_STATE)

def a2c_objective_mean(self):
  """A2C objective mean."""
  def f(dist_inputs, values, returns, dones, rewards,
        actions, old_log_probs, mask):
    """A2C objective mean."""
    # TODO(henrykm): include dones, rewards
    del old_log_probs
    a2c_objective = rl_layers.A2CObjective(
        dist_inputs, values, returns, dones, rewards, actions, mask,
        log_prob_fun=self._policy_dist.log_prob,
        normalize_advantages=self._normalize_advantages)
    return jnp.mean(a2c_objective)
  return tl.Fn('A2CObjectiveMean', f, n_out=1)

def policy_loss(self, **unused_kwargs):
  """Policy loss."""
  def LossInput(dist_inputs, actions, advantages, old_dist_inputs, mask):  # pylint: disable=invalid-name
    """Calculates action log probabilities and normalizes advantages."""
    del old_dist_inputs
    advantages = self._preprocess_advantages(advantages)
    dist_inputs = jnp.broadcast_to(
        dist_inputs, (self._q_value_n_samples,) + dist_inputs.shape)
    log_probs = self._policy_dist.log_prob(dist_inputs, actions)
    # (batch_size, n_samples, ...) -> (n_samples, batch_size, ...)
    advantages = jnp.swapaxes(advantages, 0, 1)
    mask = jnp.swapaxes(mask, 0, 1)
    return (log_probs, advantages, log_probs, mask)

  return tl.Serial(
      tl.Fn('LossInput', LossInput, n_out=4),
      # Policy loss is expected to consume
      # (log_probs, advantages, old_log_probs, mask).
      AWRLoss(beta=self._beta, w_max=self._w_max),  # pylint: disable=no-value-for-parameter
  )

def policy_loss(self, **unused_kwargs):
  """Policy loss."""
  def normalize(adv):
    return ((adv - jnp.mean(adv)) /
            (jnp.std(adv) + self._advantage_normalization_epsilon))

  def LossInput(dist_inputs, actions, advantages, old_dist_inputs):  # pylint: disable=invalid-name
    """Calculates action log probabilities and normalizes advantages."""
    if self._advantage_normalization:
      advantages = normalize(advantages)
    log_probs = self._policy_dist.log_prob(dist_inputs, actions)
    old_log_probs = self._policy_dist.log_prob(old_dist_inputs, actions)
    return (log_probs, advantages, old_log_probs)

  return tl.Serial(
      tl.Fn('LossInput', LossInput, n_out=3),
      # Policy loss is expected to consume
      # (log_probs, advantages, old_log_probs, mask).
      self.policy_loss_given_log_probs,
  )

def MultiplicativeModularSparseDense(sparsity, d_feature):
  """Returns a replacement of Dense layer which uses fewer parameters.

  The layer uses a number of modules equal to `sparsity`. It is a combination
  of multiplicative dense and locally connected dense layers.

  Args:
    sparsity: The sparsity of the layer; the output vector is divided into this
        number of modules.
    d_feature: Dimensionality of input and output tensor.
  """
  assert d_feature % sparsity == 0
  d_module = d_feature // sparsity

  return tl.Serial(
      # Weight below is used for per-head preprocessing of an embedding.
      tl.Weights(init.RandomNormalInitializer(stddev=0.5),
                 shape=[sparsity, d_feature]),
      # Weight below is a kernel of multiplicative dense, shared across heads.
      tl.Weights(init.GlorotUniformInitializer(), [d_feature, d_module]),
      # Weight below is a kernel of modular dense.
      tl.Weights(
          functools.partial(init.GlorotUniformInitializer(),
                            nonreceptive_dims=[0]),
          [sparsity, d_module, d_module]),
      # To save memory the per-head preprocessing and multiplying by
      # kernels is done in a single einsum.
      tl.Fn(
          'SparseDenseEinsum',
          (lambda kmod, kmult, multiplier, embeds:  # pylint: disable=g-long-lambda
           jnp.einsum('hxo,dx,hd,...d->...ho', kmod, kmult, multiplier, embeds))),
      MergeLastTwoAxes(),
      # Weight below is bias after dense, per-head.
      tl.Weights(init.RandomNormalInitializer(1e-6), [d_feature]),
      tl.Add(),
  )

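# A rough, back-of-the-envelope parameter count based on the weight shapes
# declared above, compared with a plain Dense(d_feature) layer with bias; the
# d_feature/sparsity values below are only an illustration.
d_feature, sparsity = 512, 16
d_module = d_feature // sparsity                      # 32
dense_params = d_feature * d_feature + d_feature      # kernel + bias
sparse_params = (sparsity * d_feature                 # per-head multiplier
                 + d_feature * d_module               # shared multiplicative kernel
                 + sparsity * d_module * d_module     # modular kernels
                 + d_feature)                         # bias
print(dense_params, sparse_params)                    # 262656 41472
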
def MultiplicativeSparseDense(sparsity, d_input, d_output=None):  # pylint: disable=invalid-name
  """Returns a replacement of Dense layer which uses fewer parameters.

  The layer uses a number of modules equal to `sparsity`. It multiplies each
  dimension of the input tensor by a scalar specific to each dimension and
  each module separately; then it applies Dense(d_output/sparsity) to each
  module. Compared to a standard dense layer, MultiplicativeSparseDense uses
  fewer parameters while still being able to express many interesting
  functions (for example a permutation).

  Args:
    sparsity: The sparsity of the layer; the output vector is divided into this
        number of modules.
    d_input: Dimensionality of input tensor.
    d_output: Dimensionality of output tensor; by default equal to d_input.
  """
  d_output = d_input if d_output is None else d_output
  assert d_output % sparsity == 0
  d_module = d_output // sparsity

  return tl.Serial(
      # Weight below is used for per-head preprocessing of an embedding.
      tl.Weights(init.RandomNormalInitializer(stddev=0.5),
                 shape=[sparsity, d_input]),
      # Weight below is dense kernel, shared across heads.
      tl.Weights(init.GlorotUniformInitializer(), [d_input, d_module]),
      # To save memory the per-head preprocessing and multiplying by the
      # kernel is done in the same einsum.
      tl.Fn(
          'AttentionEinsum',
          (lambda kernel, multiplier, embeds:  # pylint: disable=g-long-lambda
           jnp.einsum('dx,hd,bld->blhx', kernel, multiplier, embeds))),
      MergeLastTwoAxes(),
      # Weight below is bias after dense, per-head.
      tl.Weights(init.RandomNormalInitializer(1e-6), [d_output]),
      tl.Add(),
  )

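# A rough parameter count for MultiplicativeSparseDense with d_input ==
# d_output, again derived from the weight shapes declared above and compared
# with a plain Dense layer; numbers are illustrative only.
d_input = d_output = 512
sparsity = 16
d_module = d_output // sparsity                       # 32
dense_params = d_input * d_output + d_output          # kernel + bias
sparse_params = (sparsity * d_input                   # per-head multiplier
                 + d_input * d_module                 # shared dense kernel
                 + d_output)                          # bias
print(dense_params, sparse_params)                    # 262656 25088
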
def _StripFromConcatenateWithPadding():
  """Strips out the leading encoder tokens from the concatenated array."""
  def _StripEncToks(vec_ed, tok_e, tok_d):
    # pylint: disable=invalid-name
    B, L, H = vec_ed.shape
    L1 = tok_e.shape[1]
    L2 = tok_d.shape[1]
    # pylint: enable=invalid-name
    if L != L1 + L2:
      raise ValueError(f'Length from encoder-decoder vectors ({L}) does not'
                       f' equal sum of lengths from encoder ({L1}) and decoder'
                       f' ({L2}).')
    if tok_e.shape != (B, L1):
      raise ValueError(f'Shape of encoder tokens, {tok_e.shape}, does not'
                       f' equal {(B, L1)}.')
    if tok_d.shape != (B, L2):
      raise ValueError(f'Shape of decoder tokens, {tok_d.shape}, does not'
                       f' equal {(B, L2)}.')

    def _UpdateRow(x):
      # row_ed: (L, H), row_e: (L1,), row_d: (L2,)
      row_ed, row_e, _ = x
      mask_e = row_e != 0
      len_e = jnp.sum(mask_e, dtype=jnp.int32)
      # In `row_ed` start where encoder tokens/vecs end, i.e. at index `len_e`,
      # and pick up an (L2, H) tensor slice from there.
      zero = jnp.array(0, dtype=len_e.dtype)  # avoid int32/int64 mismatch
      return jax.lax.dynamic_slice(row_ed, (len_e, zero), (L2, H))

    return jax.lax.map(_UpdateRow, [vec_ed, tok_e, tok_d])

  return tl.Fn('StripFromConcatenateWithPadding', _StripEncToks, n_out=1)

def _Upsampler(total_pool_size, separate_cls):
  """Returns an upsampling layer for Funnel Transformer.

  Args:
    total_pool_size: The combined pool size of previously used funnel blocks.
    separate_cls: If `True`, pooling in funnel blocks is not applied to
        embeddings of the first token (`cls` from BERT paper).
  """
  def _Upsample(short, long):
    if separate_cls:
      upsampled_short = jnp.concatenate(
          (short[:, :1, :],
           short[:, 1:, :].repeat(total_pool_size, axis=1)),
          axis=1)
      return index_add(
          long,
          (slice(None), slice(None, upsampled_short.shape[1]), slice(None)),
          upsampled_short)
    else:
      upsampled_short = short.repeat(total_pool_size, axis=1)
      return long + upsampled_short
  return tl.Fn('Upsampler', _Upsample)

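# A shape-only sketch of what _Upsample does in the separate_cls=False case:
# the pooled ("short") sequence is repeated total_pool_size times along the
# length axis and added to the residual ("long") sequence. Values are dummies.
import numpy as np

total_pool_size = 4
short = np.ones((2, 8, 16))    # (batch, pooled_length, d_model)
long = np.zeros((2, 32, 16))   # (batch, pooled_length * total_pool_size, d_model)
upsampled = long + short.repeat(total_pool_size, axis=1)
print(upsampled.shape)         # (2, 32, 16)
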
def siamese(vocab_size, d_model=128):
  """Returns a Siamese model.

  Args:
    vocab_size (int): Length of the vocabulary.
    d_model (int, optional): Depth of the model. Defaults to 128.

  Returns:
    trax.layers.combinators.Parallel: A Siamese model.
  """
  def normalize(vec):
    # Normalizes the vectors to have L2 norm 1.
    return vec / fastnp.sqrt(fastnp.sum(vec * vec, axis=-1, keepdims=True))

  s_processor = tl.Serial(
      tl.Embedding(vocab_size, d_model),  # Embedding layer
      tl.LSTM(d_model),                   # LSTM layer
      tl.Mean(axis=1),                    # Mean over the length axis
      tl.Fn('Normalize', normalize),      # Apply the normalize function
  )  # Returns one vector of shape [batch_size, d_model].

  # Run on s1_tensor and s2_tensor in parallel.
  model = tl.Parallel(s_processor, s_processor)
  return model

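# A minimal usage sketch for the siamese model, assuming `trax` is installed
# and the module's own imports are in scope; the vocab_size and token ids
# below are made up.
import numpy as np
from trax import shapes

model = siamese(vocab_size=500, d_model=128)
s1 = np.random.randint(1, 500, size=(8, 40))   # (batch, length)
s2 = np.random.randint(1, 500, size=(8, 40))
model.init(shapes.signature((s1, s2)))
v1, v2 = model((s1, s2))
print(v1.shape, v2.shape)  # (8, 128) (8, 128)
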
def _StripFromConcatenateWithPadding():
  """Strips out the leading encoder tokens from the concatenated array."""
  def _StripEncToks(vec_ed, tok_e, tok_d):
    # pylint: disable=invalid-name
    B, L, H = vec_ed.shape
    L1 = tok_e.shape[1]
    L2 = tok_d.shape[1]
    # pylint: enable=invalid-name
    assert L == L1 + L2
    assert (B, L1) == tok_e.shape
    assert (B, L2) == tok_d.shape

    def _UpdateRow(x):
      # row_ed: (L, H), row_e: (L1,), row_d: (L2,)
      row_ed, row_e, _ = x
      mask_e = row_e != 0
      len_e = jnp.sum(mask_e, dtype=jnp.int32)
      # In `row_ed` start where encoder tokens/vecs end, i.e. at index `len_e`,
      # and pick up an (L2, H) tensor slice from there.
      return jax.lax.dynamic_slice(row_ed, (len_e, 0), (L2, H))

    return jax.lax.map(_UpdateRow, [vec_ed, tok_e, tok_d])

  return tl.Fn('StripFromConcatenateWithPadding', _StripEncToks, n_out=1)

def _StripFromConcatenateWithPadding():
  """Strips out the leading encoder tokens from the concatenated array."""
  # Shapes: (L1+L2, H), (L1,) and (L2,).
  def F(vec_ed, tok_e, tok_d):
    mask_e = tok_e != 0
    # Actual length of encoder tokens, at most L1.
    len_e = jnp.sum(mask_e)
    # Padded length of decoder tokens, this is L2.
    L2 = tok_d.shape[0]  # pylint: disable=invalid-name
    # vec_ed looks like [eeedd00000]; rolling it by len_e=3 in reverse gives
    # [dd00000eee], and we then take only the first L2 elements.
    return jnp.roll(vec_ed, -len_e, axis=0)[:L2]

  # TODO(afrozm): Try to do this with sort_key_val instead of roll to get rid
  # of the vmap.
  def _F(vec_ed, tok_e, tok_d):
    return jax.vmap(F)(vec_ed, tok_e, tok_d)

  # We could have written `tl.Fn(..., jax.vmap(F), ...)` here, but Trax needs
  # the top-level function (here: jax.vmap) to not have variable or named
  # arguments, so we need a thin wrapper.
  return tl.Fn('StripFromConcatenateWithPadding', _F, n_out=1)

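# A toy illustration (plain numpy, single example, hidden dimension omitted)
# of the roll trick used in F above: encoder positions are rotated to the back
# and the first L2 entries are what remains for the decoder. Values are made up.
import numpy as np

vec_ed = np.array([11, 12, 13, 21, 22, 0, 0, 0, 0, 0])  # [eee dd 00000]
tok_e = np.array([7, 8, 9, 0, 0])                        # 3 real encoder tokens
tok_d = np.array([4, 5, 0, 0, 0])                        # L2 = 5
len_e = np.sum(tok_e != 0)                               # 3
print(np.roll(vec_ed, -len_e)[:tok_d.shape[0]])          # [21 22 0 0 0]
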
def BERT(d_model=768,
         vocab_size=30522,
         max_len=512,
         type_vocab_size=2,
         n_heads=12,
         d_ff=3072,
         n_layers=12,
         head=None,
         init_checkpoint=None,
         mode='eval'):
  """BERT (default hparams are for bert-base-uncased)."""
  layer_norm_eps = 1e-12
  d_head = d_model // n_heads

  word_embeddings = tl.Embedding(d_model, vocab_size)
  type_embeddings = tl.Embedding(d_model, type_vocab_size)
  position_embeddings = tl.PositionalEncoding(max_len, mode=mode)

  embeddings = [
      tl.Select([0, 1, 0], n_in=3),  # Drops 'idx' input.
      tl.Parallel(
          word_embeddings,
          type_embeddings,
          [tl.PaddingMask(),
           tl.Fn('Squeeze', lambda x: np.squeeze(x, (1, 2)), n_out=1)],
      ),
      tl.Add(),
      position_embeddings,
      tl.LayerNorm(epsilon=layer_norm_eps),
  ]

  encoder = []
  for _ in range(n_layers):
    attn = tl.SelfAttention(n_heads=n_heads, d_qk=d_head, d_v=d_head,
                            bias=True, masked=True, mode=mode)
    feed_forward = [
        tl.Dense(d_ff),
        tl.Gelu(),
        tl.Dense(d_model),
    ]
    encoder += [
        tl.Select([0, 1, 1]),  # Save a copy of the mask.
        tl.Residual(attn, AddBias()),  # pylint: disable=no-value-for-parameter
        tl.LayerNorm(epsilon=layer_norm_eps),
        tl.Residual(*feed_forward),
        tl.LayerNorm(epsilon=layer_norm_eps),
    ]
  encoder += [tl.Select([0], n_in=2)]  # Drop the mask.

  pooler = [
      tl.Fn('', lambda x: (x[:, 0, :], x), n_out=2),
      tl.Dense(d_model),
      tl.Tanh(),
  ]

  init_checkpoint = init_checkpoint if mode == 'train' else None
  bert = PretrainedBERT(
      embeddings + encoder + pooler, init_checkpoint=init_checkpoint)

  if head is not None:
    bert = tl.Serial(bert, head())

  return bert

def LogProb(self):  # pylint: disable=invalid-name
  """Builds a log probability layer for this distribution."""
  return tl.Fn('LogProb',
               lambda inputs, point: self.log_prob(inputs, point))  # pylint: disable=unnecessary-lambda

def joint_loss(self):
  """Joint policy and value loss."""
  def f(dist_inputs, values, returns, dones, rewards,
        actions, old_log_probs, mask):
    """Definition of the Proximal Policy Optimization loss."""
    del mask  # TODO(lukaszkaiser): make PPO work with Transformer

    # We have dist_inputs of the shape float32[128,1,18].
    assert len(dist_inputs.shape) == 3, (
        f'dist_inputs.shape was {dist_inputs.shape} '
        f'but the expected length of the tensor shape is 3')
    # values of the shape float32[128,1,1]
    # returns of the shape float32[128,1,1]
    # dones of the shape int32[128,1,1]
    # rewards of the shape float32[128,1,1]
    # and old_log_probs of the shape float32[128,1]
    assert values.shape == returns.shape, (
        f'values.shape was {values.shape} '
        f'returns.shape was {returns.shape}')
    assert values.shape == dones.shape, (
        f'values.shape was {values.shape} '
        f'dones.shape was {dones.shape}')
    assert rewards.shape == dones.shape, (
        f'rewards.shape was {rewards.shape} '
        f'dones.shape was {dones.shape}')
    assert returns.shape[0:2] == old_log_probs.shape, (
        f'returns.shape was {returns.shape} '
        f'old_log_probs.shape was {old_log_probs.shape}')

    # actions is a tensor of the shape int32[128,1] in the case of discrete
    # actions and float32[128,1,6] in the case of half-cheetah and other
    # continuous actions. actions agree with returns/values on the first two
    # coordinates, meaning batch and time.
    assert actions.shape[0:2] == returns.shape[0:2], (
        f'actions.shape was {actions.shape} and '
        f'returns.shape was {returns.shape}')

    ppo_objective = rl_layers.PPOObjective(
        dist_inputs, stop_gradient(values), returns, dones, rewards,
        actions, old_log_probs,
        log_prob_fun=self._policy_dist.log_prob,
        epsilon=self._epsilon,
        normalize_advantages=self._normalize_advantages)

    # We insist that ppo_objective is a vector of shape [128,1] ...
    assert len(ppo_objective.shape) == 2, (
        f'ppo_objective was {ppo_objective}')
    # ... which agrees with returns/values/actions on the first two coordinates.
    assert ppo_objective.shape[0:2] == values.shape[0:2], (
        f'ppo_objective.shape was {ppo_objective.shape} and '
        f'values.shape was {values.shape}')

    entropy_loss = rl_layers.EntropyLoss(
        dist_inputs,
        distribution=self._policy_dist,
        coeff=self._entropy_coeff,
    )
    assert jnp.ndim(entropy_loss) == 0, f'entropy_loss was {entropy_loss}'

    l2_value_loss = rl_layers.ValueLoss(
        values, returns, value_loss_coeff=self._value_loss_coeff)
    assert jnp.ndim(l2_value_loss) == 0, f'l2_value_loss was {l2_value_loss}'

    return -ppo_objective.mean() + l2_value_loss - entropy_loss

  return tl.Fn('PPOJointLoss', f)

def preferred_move(self):
  """Preferred move - the mean of selected moves."""
  def f(dist_inputs, values):
    del values
    return rl_layers.PreferredMove(dist_inputs, self._policy_dist.sample)
  return tl.Fn('PreferredMove', f)

def log_probs_mean(self):
  """Mean of log_probs aka dist_inputs."""
  def f(dist_inputs, values):
    del values
    return jnp.mean(dist_inputs)
  return tl.Fn('LogProbsMean', f)

def explained_variance(self):
  """Explained variance metric."""
  def f(dist_inputs, values, returns):
    del dist_inputs
    return rl_layers.ExplainedVariance(values, returns)
  return tl.Fn('ExplainedVariance', f)

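# A standalone sketch of the standard explained-variance statistic,
# 1 - Var(returns - values) / Var(returns), which rl_layers.ExplainedVariance
# is assumed to compute; the arrays below are dummies.
import numpy as np

returns = np.array([1.0, 2.0, 3.0, 4.0])
values = np.array([1.1, 1.9, 3.2, 3.8])
print(1.0 - np.var(returns - values) / np.var(returns))  # 0.98, close to 1 for a good fit
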
def value_loss(self):
  """Value loss - so far generic for all A2C."""
  def f(dist_inputs, values, returns):
    del dist_inputs
    return rl_layers.ValueLoss(values, returns, self._value_loss_coeff)
  return tl.Fn('ValueLoss', f)

def advantage_norm(self):
  """Norm of advantages."""
  def f(dist_inputs, values, returns):
    del dist_inputs
    return jnp.linalg.norm(returns - values)
  return tl.Fn('AdvantageNorm', f)

def advantage_mean(self):
  """Mean of advantages."""
  def f(dist_inputs, values, returns):
    del dist_inputs
    return jnp.mean(returns - values)
  return tl.Fn('AdvantageMean', f)

def advantage_std(self):
  return tl.Serial([
      # (dist_inputs, advantages, old_dist_inputs, mask)
      tl.Select([1]),  # Select just the advantages.
      tl.Fn('AdvantageStd', lambda x: jnp.std(x)),  # pylint: disable=unnecessary-lambda
  ])

def make_metric(aggregate_fn):  # pylint: disable=invalid-name
  def AdvantageMetric(policy_inputs, actions, advantages, mask):
    del policy_inputs, actions, mask
    return aggregate_fn(advantages)

  return tl.Fn('AdvantageMetric', AdvantageMetric)

def entropy_metric(self):
  def Entropy(policy_inputs, actions, advantages, mask):
    del actions, advantages, mask
    return jnp.mean(self._policy_dist.entropy(policy_inputs))

  return tl.Fn('Entropy', Entropy)