def log_gaussian_pdf(x, mu, sigma):  # pylint: disable=invalid-name
  """Compute log N(x | mu, sigma)."""
  a = mu.shape[-1] * jnp.log(2 * jnp.pi)
  _, b = jnp.linalg.slogdet(sigma)
  y = jnp.linalg.solve(sigma, x - mu)
  y = jnp.expand_dims(y, axis=-1)
  xm = jnp.expand_dims(x - mu, axis=-2)
  c = jnp.matmul(xm, y)
  c = jnp.squeeze(jnp.squeeze(c, axis=-1), axis=-1)
  return -0.5 * (a + b + c)
def log_gaussian_diag_pdf(x, mu, diag_sigma):  # pylint: disable=invalid-name
  """Compute log N(x | mu, diag(diag_sigma))."""
  a = mu.shape[-1] * jnp.log(2 * jnp.pi)
  b = jnp.sum(jnp.log(diag_sigma), axis=-1)
  # Note: the subtraction must bind before the division; `x - mu / diag_sigma`
  # would divide only mu and give the wrong quadratic term.
  y = (x - mu) / diag_sigma
  y = jnp.expand_dims(y, axis=-1)
  xm = jnp.expand_dims(x - mu, axis=-2)
  c = jnp.matmul(xm, y)
  c = jnp.squeeze(jnp.squeeze(c, axis=-1), axis=-1)
  return -0.5 * (a + b + c)
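# Hedged consistency check (illustrative values, not part of the library API):
# for a diagonal covariance, log_gaussian_diag_pdf should agree with
# log_gaussian_pdf evaluated on the full matrix jnp.diag(diag_sigma).
def _example_check_diag_vs_full_gaussian():
  x = jnp.array([0.5, -1.0])
  mu = jnp.zeros(2)
  diag_sigma = jnp.array([2.0, 0.5])
  full = log_gaussian_pdf(x, mu, jnp.diag(diag_sigma))
  diag = log_gaussian_diag_pdf(x, mu, diag_sigma)
  assert jnp.allclose(full, diag, atol=1e-5)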
def actor_loss(actions, advantage_weights, log_probab_actions_new, state=None):
  """Actor loss."""
  lp = np.squeeze(log_probab_actions_new)
  b = len(lp)
  log_probs = np.squeeze(lp[np.arange(b)[np.newaxis, :], actions])
  return -1.0 * np.mean(log_probs * advantage_weights), state
def _run_value_model(self, observations, dist_inputs):
  if dist_inputs is None:
    dist_inputs = jnp.zeros(
        observations.shape[:2] + (self._policy_dist.n_inputs,))
  actions = None
  if self._q_value:
    dist_inputs = jnp.broadcast_to(
        dist_inputs, (self._q_value_n_samples,) + dist_inputs.shape)
    # Swapping the n_samples and batch_size axes, so the input is split
    # between accelerators along the batch_size axis.
    dist_inputs = jnp.swapaxes(dist_inputs, 0, 1)
    actions = self._policy_dist.sample(dist_inputs)
    log_probs = self._policy_dist.log_prob(dist_inputs, actions)
    obs = observations
    obs = jnp.reshape(obs, [obs.shape[0], 1] + list(obs.shape[1:]))
    inputs = (obs, actions)
  else:
    log_probs = None
    inputs = (observations,)

  n_devices = math.device_count()
  weights = tl.for_n_devices(self._value_eval_model.weights, n_devices)
  state = tl.for_n_devices(self._value_eval_model.state, n_devices)
  rng = self._value_eval_model.rng
  values, _ = self._value_eval_jit(inputs, weights, state, rng)
  values *= self._value_network_scale
  values = jnp.squeeze(values, axis=-1)  # Remove the singleton depth dim.
  return (values, actions, log_probs)
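# Minimal shape sketch (illustrative sizes, not part of the agent class above):
# broadcasting the per-example distribution inputs to n_samples copies and then
# swapping the first two axes keeps the batch dimension leading, so splitting
# the input across accelerators still happens along the batch axis rather than
# along the sample axis.
def _example_broadcast_and_swap(n_samples=8):
  dist_inputs = jnp.zeros((4, 10, 6))  # (batch, time, n_inputs)
  tiled = jnp.broadcast_to(dist_inputs, (n_samples,) + dist_inputs.shape)
  tiled = jnp.swapaxes(tiled, 0, 1)    # (batch, n_samples, time, n_inputs)
  assert tiled.shape == (4, n_samples, 10, 6)
  return tiled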
def f(preds, values, returns, actions, mask):
  """Joint AWR loss plus L2 value loss."""
  advantages = jnp.squeeze(returns - stop_gradient(values), axis=-1)
  logps = self._policy_dist.log_prob(preds, actions)
  awr_loss = actor_critic.AWRLoss(beta=self._beta, w_max=self._w_max)(
      (logps, advantages, jnp.zeros_like(logps), mask))
  l2_value_loss = jnp.mean((returns - values)**2) * self._value_loss_coeff
  return awr_loss + l2_value_loss
def AWRJointLoss(x, **unused_kwargs):  # pylint: disable=invalid-name
  """Joint AWR loss plus L2 value loss."""
  preds, values, returns, actions, mask = x
  advantages = jnp.squeeze(returns - values, axis=-1)
  logps = self._policy_dist.log_prob(preds, actions)
  awr_loss = actor_critic.AWRLoss(beta=self._beta, w_max=self._w_max)(
      (logps, advantages, jnp.zeros_like(logps), mask))
  l2_value_loss = jnp.mean((returns - values)**2) * self._value_loss_coeff
  return awr_loss + l2_value_loss
def actor_loss(actions, advantage_weights, log_probab_actions_new, state=None):
  """Actor loss."""
  # log_probab_actions_new's shape is (AB, 1, #C, #A), where AB is the actor
  # batch size.
  lp = jnp.squeeze(log_probab_actions_new, axis=1)
  AB, NC = actions.shape  # pylint: disable=invalid-name
  log_probs = lp[jnp.arange(AB)[:, None], jnp.arange(NC)[None, :], actions]

  # TODO(afrozm): Clarify this.
  # log_probs are shaped (AB, #C), however advantage_weights are (AB,).
  return -1.0 * jnp.mean(log_probs * advantage_weights[:, None]), state
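# Hedged illustration (toy shapes, not part of the agent code): the fancy
# indexing in actor_loss gathers, for every (batch, control) pair, the
# log-probability of the action that was actually taken.
def _example_gather_log_probs():
  lp = jnp.arange(2 * 3 * 4, dtype=jnp.float32).reshape((2, 3, 4))  # (AB, #C, #A)
  actions = jnp.array([[0, 1, 2], [3, 0, 1]])                       # (AB, #C)
  gathered = lp[jnp.arange(2)[:, None], jnp.arange(3)[None, :], actions]
  # gathered[i, j] == lp[i, j, actions[i, j]]; shape (AB, #C).
  assert gathered.shape == (2, 3)
  return gathered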
def dataset_to_stream(dataset, input_name):
  """Takes a tf.Dataset and creates a numpy stream of ready batches."""
  for example in math.dataset_as_numpy(dataset):
    features = example[0]
    inp, out = features[input_name], example[1]
    mask = features['mask'] if 'mask' in features else None
    # All input-pipeline processing should be on CPU.
    with tf.device('cpu:0'):
      # Some accelerators don't handle uint8 well; cast to int. Note that the
      # check must be on the dtype: isinstance(inp, np.uint8) is False for
      # arrays, so it would silently skip the cast.
      if inp.dtype == np.uint8:
        inp = inp.astype(np.int32)
      if out.dtype == np.uint8:
        out = out.astype(np.int32)
      if len(out.shape) > 1 and out.shape[-1] == 1:
        out = np.squeeze(out, axis=-1)
    yield (inp, out) if mask is None else (inp, out, mask)
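# Hedged usage sketch (all names and shapes below are illustrative, not from
# the library): wrap a batched tf.data.Dataset whose elements are
# (feature-dict, label) pairs and iterate the resulting numpy stream.
#
#   ds = tf.data.Dataset.from_tensor_slices(
#       ({'image': np.zeros((32, 28, 28), dtype=np.uint8)},
#        np.zeros((32, 1), dtype=np.int64))).batch(8)
#   for inp, out in dataset_to_stream(ds, 'image'):
#     # inp: (8, 28, 28), cast to int32; out: (8,) after squeezing the last dim.
#     break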
def _run_value_model(self, observations, dist_inputs):
  if dist_inputs is None:
    dist_inputs = jnp.zeros(
        observations.shape[:2] + (self._policy_dist.n_inputs,))
  actions = None
  if self._q_value:
    if self._sample_all_discrete_actions:
      # Since we want to sample all actions, start by creating their list.
      act = np.arange(self._vocab_size)
      # Now act is a vector [0, ..., vocab_size-1], but we'll need to tile it.
      # Add extra dimensions so it's the same dimensionality as dist_inputs.
      act = jnp.reshape(act, [-1] + [1] * (len(dist_inputs.shape) - 1))
      # Now act is [vocab_size, 1, ..., 1], dimensionality of dist_inputs.
    dist_inputs = jnp.broadcast_to(
        dist_inputs, (self._q_value_n_samples,) + dist_inputs.shape)
    if self._sample_all_discrete_actions:
      actions = act + jnp.zeros(dist_inputs.shape[:-1], dtype=jnp.int32)
      actions = jnp.swapaxes(actions, 0, 1)
    # Swapping the n_samples and batch_size axes, so the input is split
    # between accelerators along the batch_size axis.
    dist_inputs = jnp.swapaxes(dist_inputs, 0, 1)
    if not self._sample_all_discrete_actions:
      actions = self._policy_dist.sample(dist_inputs)
    log_probs = self._policy_dist.log_prob(dist_inputs, actions)
    obs = observations
    obs = jnp.reshape(obs, [obs.shape[0], 1] + list(obs.shape[1:]))
    inputs = (obs, actions)
  else:
    log_probs = None
    inputs = (observations,)

  n_devices = math.device_count()
  weights = tl.for_n_devices(self._value_eval_model.weights, n_devices)
  state = tl.for_n_devices(self._value_eval_model.state, n_devices)
  rng = self._value_eval_model.rng
  values, _ = self._value_eval_jit(inputs, weights, state, rng)
  values *= self._value_network_scale
  values = jnp.squeeze(values, axis=-1)  # Remove the singleton depth dim.
  return (values, actions, log_probs)
def BERT(d_model=768,
         vocab_size=30522,
         max_len=512,
         type_vocab_size=2,
         n_heads=12,
         d_ff=3072,
         n_layers=12,
         head=None,
         init_checkpoint=None,
         mode='eval'):
  """BERT (default hparams are for bert-base-uncased)."""
  layer_norm_eps = 1e-12
  d_head = d_model // n_heads

  word_embeddings = tl.Embedding(d_model, vocab_size)
  type_embeddings = tl.Embedding(d_model, type_vocab_size)
  position_embeddings = tl.PositionalEncoding(max_len, mode=mode)
  embeddings = [
      tl.Select([0, 1, 0], n_in=3),  # Drops 'idx' input.
      tl.Parallel(
          word_embeddings,
          type_embeddings,
          [tl.PaddingMask(),
           tl.Fn('Squeeze', lambda x: np.squeeze(x, (1, 2)), n_out=1)]
      ),
      tl.Add(),
      position_embeddings,
      tl.LayerNorm(epsilon=layer_norm_eps),
  ]

  encoder = []
  for _ in range(n_layers):
    attn = tl.SelfAttention(n_heads=n_heads, d_qk=d_head, d_v=d_head,
                            bias=True, masked=True, mode=mode)
    feed_forward = [
        tl.Dense(d_ff),
        tl.Gelu(),
        tl.Dense(d_model),
    ]
    encoder += [
        tl.Select([0, 1, 1]),  # Save a copy of the mask
        tl.Residual(attn, AddBias()),  # pylint: disable=no-value-for-parameter
        tl.LayerNorm(epsilon=layer_norm_eps),
        tl.Residual(*feed_forward),
        tl.LayerNorm(epsilon=layer_norm_eps),
    ]
  encoder += [tl.Select([0], n_in=2)]  # Drop the mask

  pooler = [
      tl.Fn('', lambda x: (x[:, 0, :], x), n_out=2),
      tl.Dense(d_model),
      tl.Tanh(),
  ]

  init_checkpoint = init_checkpoint if mode == 'train' else None
  bert = PretrainedBERT(
      embeddings + encoder + pooler, init_checkpoint=init_checkpoint)

  if head is not None:
    bert = tl.Serial(bert, head())

  return bert
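# Hedged example (illustrative only; the head below and the number of classes
# are assumptions, not part of the model code above): `head` is expected to be
# a zero-argument callable returning a layer that consumes the
# (pooled, sequence) pair left on the stack by the pooler. A minimal two-class
# classification head could look like this:
#
#   def classifier_head():
#     return tl.Serial(
#         tl.Select([0], n_in=2),  # Keep only the pooled [CLS] vector.
#         tl.Dense(2),
#         tl.LogSoftmax(),
#     )
#
#   model = BERT(head=classifier_head, mode='train')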
def Reformer(input_vocab_size,
             output_vocab_size=None,
             d_model=512,
             d_ff=2048,
             n_encoder_layers=6,
             n_decoder_layers=6,
             n_heads=8,
             dropout=0.1,
             max_len=2048,
             ff_activation=tl.Relu,
             ff_dropout=None,
             mode='train'):
  """Reversible transformer encoder-decoder model.

  This model expects an input pair: source, target.

  At the moment, this model supports dot-product attention only. For the
  attention types in the Reformer paper, see ReformerLM.

  Args:
    input_vocab_size: int: vocab size of the source.
    output_vocab_size: int (optional): vocab size of the target. If None, the
      source and target are assumed to have the same vocab.
    d_model: int: depth of embedding
    d_ff: int: depth of feed-forward layer
    n_encoder_layers: int: number of encoder layers
    n_decoder_layers: int: number of decoder layers
    n_heads: int: number of attention heads
    dropout: float: dropout rate (how much to drop out)
    max_len: int: maximum symbol length for positional encoding
    ff_activation: the non-linearity in feed-forward layer
    ff_dropout: float: (optional) separate dropout rate at feed-forward
      nonlinearity. This is called relu_dropout in T2T.
    mode: str: 'train' or 'eval'

  Returns:
    A Reformer model as a layer that maps from a source, target pair to
    activations over a vocab set.
  """
  # The current API for custom gradients assumes that a layer must be
  # differentiable wrt all of its inputs, but the Transformer puts bool-dtype
  # masks on the stack. This causes jax to error, even though the so-called
  # "gradient" wrt the masks is never actually computed.
  # TODO(kitaev): remove this hack.
  jax.api._check_inexact_input_vjp = lambda x: None  # pylint: disable=protected-access

  def PositionalEncoder(vocab_size, mode):  # tokens --> vectors
    # TODO(kitaev): axial positional encoding is better for very long sequences.
    positional_encoding = tl.PositionalEncoding(
        max_len=max_len, dropout=dropout, mode=mode)
    return [
        tl.Embedding(d_model, vocab_size),
        BroadcastedDropout(rate=dropout, mode=mode),
        positional_encoding,
    ]

  # TODO(kitaev): The regular trax Transformer shares vocab embeddings and
  # position embeddings between the encoder and decoder if output_vocab_size is
  # None. This isn't supported here because (a) Trax shares weights by sharing
  # layer instances, but we need two separate instances to have mode == 'eval'
  # for the encoder but mode == 'predict' for the decoder; and (b) tl.Cache does
  # not work if its sublayers participate in any weight sharing.

  # Mode 'predict' means that the decoder should be run one token at a time.
  # The encoder only ever runs over full sequences, which is why it's switched
  # to 'eval' mode instead.
  in_encoder = PositionalEncoder(
      input_vocab_size, mode='eval' if mode == 'predict' else mode)
  if output_vocab_size is None:
    output_vocab_size = input_vocab_size
  out_encoder = PositionalEncoder(output_vocab_size, mode)

  encoder_blocks = [
      EncoderBlock(
          d_model, d_ff, n_heads, dropout, ff_activation, ff_dropout, mode)
      for _ in range(n_encoder_layers)]

  encoder = tl.Serial([
      in_encoder,
      tl.Dup(),
      tl.ReversibleSerial(encoder_blocks),
      tl.Fn(lambda x, y: (x + y) / 2.0),
      tl.LayerNorm(),
  ])
  if mode == 'predict':
    encoder = tl.Cache(encoder)

  encoder_decoder_blocks = [
      EncoderDecoderBlock(
          d_model, d_ff, n_heads, dropout, ff_activation, ff_dropout, mode)
      for _ in range(n_decoder_layers)]

  # Assemble and return the model.
  return tl.Serial(
      # Input: encoder_side_tokens, decoder_side_tokens
      # Copy decoder tokens for use in loss.
      tl.Select([0, 1, 1]),               # tok_e tok_d tok_d
      tl.Branch([],
                [tl.PaddingMask(),
                 tl.Fn(lambda x: np.squeeze(x, (1, 2)),
                       n_out=1)]),        # tok_e mask tok_d .....

      # Encode.
      encoder,                            # vec_e mask tok_d .....

      # Decode.
      tl.Select([2, 0, 1]),               # tok_d vec_e mask .....
      tl.ShiftRight(mode=mode),           # tok_d vec_e mask .....
      out_encoder,                        # vec_d vec_e mask .....
      tl.Dup(),                           # vec_d1 vec_d2 vec_e mask .....
      tl.ReversibleSerial(encoder_decoder_blocks),
      tl.Fn(lambda x, y: (x + y) / 2.0),  # vec_d vec_e mask .....
      tl.LayerNorm(),                     # vec_d vec_e mask .....

      # Map to output vocab.
      tl.Select([0], n_in=3),             # vec_d .....
      tl.Dense(output_vocab_size),        # vec_d .....
      tl.LogSoftmax(),                    # vec_d .....
  )
def Reformer(input_vocab_size,
             output_vocab_size=None,
             d_model=512,
             d_ff=2048,
             n_encoder_layers=6,
             n_decoder_layers=6,
             n_heads=8,
             dropout=0.1,
             max_len=2048,
             ff_activation=tl.Relu,
             mode='train'):
  """Reversible transformer encoder-decoder model.

  This model expects an input pair: source, target.

  At the moment, this model supports dot-product attention only. For the
  attention types in the Reformer paper, see ReformerLM.

  Args:
    input_vocab_size: int: vocab size of the source.
    output_vocab_size: int (optional): vocab size of the target. If None, the
      source and target are assumed to have the same vocab.
    d_model: int: depth of embedding
    d_ff: int: depth of feed-forward layer
    n_encoder_layers: int: number of encoder layers
    n_decoder_layers: int: number of decoder layers
    n_heads: int: number of attention heads
    dropout: float: dropout rate (how much to drop out)
    max_len: int: maximum symbol length for positional encoding
    ff_activation: the non-linearity in feed-forward layer
    mode: str: 'train' or 'eval'

  Returns:
    A Reformer model as a layer that maps from a source, target pair to
    activations over a vocab set.
  """
  # The current API for custom gradients assumes that a layer must be
  # differentiable wrt all of its inputs, but the Transformer puts bool-dtype
  # masks on the stack. This causes jax to error, even though the so-called
  # "gradient" wrt the masks is never actually computed.
  # TODO(kitaev): remove this hack.
  jax.api._check_inexact_input_vjp = lambda x: None  # pylint: disable=protected-access

  def PositionalEncoder(vocab_size):  # tokens --> vectors
    # TODO(kitaev): axial positional encoding is better for very long sequences.
    # TODO(kitaev): dropout=0.0 for tl.PositionalEncoding matches trax
    # Transformer, but may not be the right option in general.
    positional_encoding = tl.PositionalEncoding(
        max_len=max_len, dropout=0.0, mode=mode)
    return [
        tl.Embedding(d_model, vocab_size),
        # TODO(kitaev): BroadcastedDropout?
        tl.Dropout(rate=dropout, mode=mode),
        positional_encoding,
    ]

  in_encoder = PositionalEncoder(input_vocab_size)
  out_encoder = (in_encoder if output_vocab_size is None
                 else PositionalEncoder(output_vocab_size))
  if output_vocab_size is None:
    output_vocab_size = input_vocab_size

  encoder_blocks = [
      EncoderBlock(
          d_model, d_ff, n_heads, dropout, ff_activation, mode)
      for _ in range(n_encoder_layers)]

  encoder_decoder_blocks = [
      EncoderDecoderBlock(
          d_model, d_ff, n_heads, dropout, ff_activation, mode)
      for _ in range(n_decoder_layers)]

  # Assemble and return the model.
  return tl.Serial(
      # Input: encoder_side_tokens, decoder_side_tokens
      # Copy decoder tokens for use in loss.
      tl.Select([0, 1, 1]),                 # tok_e tok_d tok_d

      # Encode.
      tl.Branch(
          in_encoder,
          [tl.PaddingMask(),
           tl.Fn(lambda x: np.squeeze(x, (1, 2)), n_out=1)]
      ),                                    # vec_e mask tok_d .....
      tl.Dup(),                             # vec_e1 vec_e2 mask tok_d .....
      tl.ReversibleSerial(encoder_blocks),  # vec_e1 vec_e2 mask tok_d .....
      # The two sets of activations need to be reduced to one, in this case by
      # averaging them. Note that ReformerLM concatenates instead. Various
      # options (concat, average, add, keep only one, etc.) seem to perform
      # similarly. We don't concatenate here because we want exact parameter
      # parity with the standard Transformer.
      tl.Fn(lambda x, y: (x + y) / 2.0),    # vec_e mask tok_d .....
      tl.LayerNorm(),                       # vec_e mask tok_d .....

      # Decode.
      tl.Select([2, 0, 1]),                 # tok_d vec_e mask .....
      tl.ShiftRight(),                      # tok_d vec_e mask .....
      out_encoder,                          # vec_d vec_e mask .....
      tl.Dup(),                             # vec_d1 vec_d2 vec_e mask .....
      tl.ReversibleSerial(encoder_decoder_blocks),
      tl.Fn(lambda x, y: (x + y) / 2.0),    # vec_d vec_e mask .....
      tl.LayerNorm(),                       # vec_d vec_e mask .....

      # Map to output vocab.
      tl.Select([0], n_in=3),               # vec_d .....
      tl.Dense(output_vocab_size),          # vec_d .....
      tl.LogSoftmax(),                      # vec_d .....
  )
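# Hedged usage sketch (the hyperparameters below are illustrative, not
# defaults from the library): the returned layer maps a
# (source_tokens, target_tokens) pair to per-position log-probabilities over
# the output vocabulary, with the target tokens also passed through for use in
# the loss.
#
#   model = Reformer(input_vocab_size=320, d_model=32, d_ff=64,
#                    n_encoder_layers=2, n_decoder_layers=2, n_heads=2,
#                    max_len=128, mode='train')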
def ReformerNoEncDecAttention(input_vocab_size,
                              output_vocab_size=None,
                              d_model=512,
                              d_ff=2048,
                              d_attention_key=64,
                              d_attention_value=64,
                              n_encoder_layers=6,
                              n_decoder_layers=6,
                              n_heads=8,
                              dropout=0.1,
                              max_len=2048,
                              encoder_attention_type=tl.SelfAttention,
                              encoder_decoder_attention_type=tl.SelfAttention,
                              axial_pos_shape=(),
                              d_axial_pos_embs=None,
                              ff_activation=tl.Relu,
                              ff_use_sru=0,
                              ff_chunk_size=0,
                              ff_dropout=None,
                              mode='train'):
  """Reversible transformer encoder-decoder model.

  This model expects an input pair: source, target.

  At the moment, this model supports dot-product attention only. For the
  attention types in the Reformer paper, see ReformerLM.

  Args:
    input_vocab_size: int: vocab size of the source.
    output_vocab_size: int (optional): vocab size of the target. If None, the
      source and target are assumed to have the same vocab.
    d_model: int: depth of embedding
    d_ff: int: depth of feed-forward layer
    d_attention_key: int: depth of key vector for each attention head
    d_attention_value: int: depth of value vector for each attention head
    n_encoder_layers: int: number of encoder layers
    n_decoder_layers: int: number of decoder layers
    n_heads: int: number of attention heads
    dropout: float: dropout rate (how much to drop out)
    max_len: int: maximum symbol length for positional encoding
    encoder_attention_type: class: attention class to use, such as
      SelfAttention
    encoder_decoder_attention_type: class: attention class to use, such as
      SelfAttention
    axial_pos_shape: tuple of ints: input shape to use for the axial position
      encoding. If unset, axial position encoding is disabled.
    d_axial_pos_embs: tuple of ints: depth of position embedding for each axis.
      Tuple length must match axial_pos_shape, and values must sum to d_model.
    ff_activation: the non-linearity in feed-forward layer
    ff_use_sru: int; if > 0, we use this many SRU layers instead of feed-forward
    ff_chunk_size: int; if > 0, chunk feed-forward into this-sized chunks
    ff_dropout: float: (optional) separate dropout rate at feed-forward
      nonlinearity. This is called relu_dropout in T2T.
    mode: str: 'train' or 'eval'

  Returns:
    A Reformer model as a layer that maps from a source, target pair to
    activations over a vocab set.
  """
  # The current API for custom gradients assumes that a layer must be
  # differentiable wrt all of its inputs, but the Transformer puts bool-dtype
  # masks on the stack. This causes jax to error, even though the so-called
  # "gradient" wrt the masks is never actually computed.
  # TODO(kitaev): remove this hack.
  jax.api._check_inexact_input_vjp = lambda x: None  # pylint: disable=protected-access

  def PositionalEncoder(vocab_size, mode):  # tokens --> vectors
    if not axial_pos_shape:
      positional_encoding = tl.PositionalEncoding(
          max_len=max_len, dropout=dropout, mode=mode)
    else:
      assert d_axial_pos_embs is not None
      positional_encoding = tl.AxialPositionalEncoding(
          shape=axial_pos_shape, d_embs=d_axial_pos_embs,
          dropout_broadcast_dims=tuple(range(1, len(axial_pos_shape) + 1)),
          dropout=dropout, mode=mode)

    return [
        tl.Embedding(d_model, vocab_size),
        BroadcastedDropout(rate=dropout, mode=mode),
        positional_encoding,
    ]

  # TODO(kitaev): The regular trax Transformer shares vocab embeddings and
  # position embeddings between the encoder and decoder if output_vocab_size is
  # None. This isn't supported here because (a) Trax shares weights by sharing
  # layer instances, but we need two separate instances to have mode == 'eval'
  # for the encoder but mode == 'predict' for the decoder; and (b) tl.Cache does
  # not work if its sublayers participate in any weight sharing.

  # Mode 'predict' means that the decoder should be run one token at a time.
  # The encoder only ever runs over full sequences, which is why it's switched
  # to 'eval' mode instead.
  in_encoder = PositionalEncoder(
      input_vocab_size, mode='eval' if mode == 'predict' else mode)
  if output_vocab_size is None:
    output_vocab_size = input_vocab_size
  out_encoder = PositionalEncoder(output_vocab_size, mode)

  # pylint: disable=g-complex-comprehension
  encoder_blocks = [
      EncoderBlock(
          d_model, d_ff, n_heads, encoder_attention_type, dropout,
          ff_activation, ff_dropout, mode)
      for _ in range(n_encoder_layers)]
  # pylint: enable=g-complex-comprehension

  encoder = tl.Serial([            # tok_e mask_e tok_e tok_d tok_d
      in_encoder,                  # vec_e mask_e tok_e tok_d tok_d
      tl.Dup(),                    # vec_e1 vec_e2 mask_e tok_e tok_d tok_d
      tl.ReversibleSerial(encoder_blocks),
      tl.Fn('XYAvg', lambda x, y: (x + y) / 2.0),
      tl.LayerNorm(),
  ])
  if mode == 'predict':
    encoder = tl.Cache(encoder)

  decoder_blocks = []

  if isinstance(encoder_decoder_attention_type, (tuple, list)):
    assert n_decoder_layers % len(encoder_decoder_attention_type) == 0
  else:
    encoder_decoder_attention_type = [encoder_decoder_attention_type]
  for layer_idx in range(n_decoder_layers):
    layer_attention_type = encoder_decoder_attention_type[
        layer_idx % len(encoder_decoder_attention_type)]
    decoder_block = DecoderBlock(
        d_model, d_ff, d_attention_key, d_attention_value, n_heads,
        attention_type=layer_attention_type,
        dropout=dropout,
        ff_activation=ff_activation,
        ff_use_sru=ff_use_sru,
        ff_chunk_size=ff_chunk_size,
        mode=mode)
    decoder_blocks.append(decoder_block)

  # Assemble and return the model.
  return tl.Serial(
      # Input: encoder_side_tokens, decoder_side_tokens
      # Copy decoder tokens for use in loss.
      tl.Select([0, 0, 1, 1]),                # tok_e tok_e tok_d tok_d
      tl.Branch([],
                [tl.PaddingMask(),
                 tl.Fn('Squeeze', lambda x: np.squeeze(x, (1, 2)),
                       n_out=1)]),            # tok_e mask_e tok_e tok_d tok_d

      # Encode.
      encoder,                                # vec_e mask_e tok_e tok_d tok_d

      # Decode.
      tl.Select([3, 0, 1, 2]),                # tok_d vec_e mask_e tok_e tok_d
      tl.ShiftRight(mode=mode),               # stok_d vec_e mask_e tok_e tok_d
      tl.Branch(
          [],
          _MaskOfRightShiftedArray()
      ),                                      # stok_d mask_d vec_e mask_e tok_e tok_d
      out_encoder,                            # svec_d mask_d vec_e mask_e tok_e tok_d

      # Concat encoder and decoder, given their masks.
      tl.Select([2, 0, 3, 1]),                # vec_e svec_d mask_e mask_d tok_e tok_d
      _ConcatWithPadding(),                   # vec_ed tok_e tok_d

      # Run (encoder and) decoder blocks.
      tl.Dup(),                               # vec_ed1 vec_ed2 tok_e tok_d
      tl.ReversibleSerial(decoder_blocks),    # vec_ed1 vec_ed2 tok_e tok_d
      tl.Fn('XYAvg',
            lambda x, y: (x + y) / 2.0),      # vec_ed tok_e tok_d
      tl.LayerNorm(),                         # vec_ed tok_e tok_d

      # Separate out the encoder part from the concatenated vector.
      tl.Select([0, 1, 2, 2]),                # vec_ed tok_e tok_d tok_d
      _StripFromConcatenateWithPadding(),     # vec_d tok_d

      # Map to output vocab.
      tl.Dense(output_vocab_size),            # vec_d tok_d
      tl.LogSoftmax(),                        # vec_d tok_d
  )
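# Hedged note (illustrative, not from the code above): because
# encoder_decoder_attention_type may be a single class or a list, the decoder
# layers can alternate attention types; the list length must divide
# n_decoder_layers. For example, assuming tl.LSHSelfAttention is available in
# this version of the layers library:
#
#   model = ReformerNoEncDecAttention(
#       input_vocab_size=320,
#       n_decoder_layers=4,
#       encoder_decoder_attention_type=[tl.SelfAttention, tl.LSHSelfAttention],
#   )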