def _run_value_model(self, observations, dist_inputs):
  """Runs the value network, optionally on actions (Q-value mode).

  When `self._q_value` is False, the value model is called on
  `(observations,)` and no actions or log-probabilities are produced. When it
  is True, actions are either sampled from `self._policy_dist` (with
  `self._q_value_n_samples` copies of `dist_inputs`) or, if
  `self._sample_all_discrete_actions` is set, every action in
  `range(self._vocab_size)` is enumerated. The value model is then called on
  `(obs, actions)`.

  Args:
    observations: Observation batch. Its first two dimensions form the leading
      shape of the default `dist_inputs` (presumably batch and time -- TODO
      confirm).
    dist_inputs: Policy-distribution parameters, or None. If None, a zero
      tensor of shape `observations.shape[:2] + (self._policy_dist.n_inputs,)`
      is used.

  Returns:
    A tuple `(values, actions, log_probs)`:
      values: value-model outputs multiplied by `self._value_network_scale`,
        with the trailing size-1 axis squeezed away.
      actions: the actions fed to the value model, or None when not in
        Q-value mode.
      log_probs: `self._policy_dist.log_prob(dist_inputs, actions)` computed on
        the axis-swapped inputs, or None when not in Q-value mode.
  """
  if dist_inputs is None:
    # No policy outputs were given: use all-zero distribution parameters.
    dist_inputs = jnp.zeros(
        observations.shape[:2] + (self._policy_dist.n_inputs, ))

  actions = None
  if self._q_value:
    if self._sample_all_discrete_actions:
      # Since we want to sample all actions, start by creating their list.
      act = np.arange(self._vocab_size)
      # Now act is a vector [0, ..., vocab_size-1], but we'll need to tile it.
      # Add extra dimensions so it's the same dimensionality as dist_inputs.
      act = jnp.reshape(act, [-1] + [1] * (len(dist_inputs.shape) - 1))
      # Now act is [vocab_size, 1, ..., 1], dimensionality of dist_inputs.
    # Prepend a samples axis of size self._q_value_n_samples.
    dist_inputs = jnp.broadcast_to(
        dist_inputs, (self._q_value_n_samples, ) + dist_inputs.shape)
    if self._sample_all_discrete_actions:
      # Tile act over the remaining (non-distribution) axes.
      # NOTE(review): act's leading axis (vocab_size) must broadcast against
      # the samples axis, so presumably self._q_value_n_samples ==
      # self._vocab_size in this mode -- confirm.
      actions = act + jnp.zeros(dist_inputs.shape[:-1], dtype=jnp.int32)
      # Same axis swap as applied to dist_inputs just below.
      actions = jnp.swapaxes(actions, 0, 1)
    # Swapping the n_samples and batch_size axes, so the input is split
    # between accelerators along the batch_size axis.
    dist_inputs = jnp.swapaxes(dist_inputs, 0, 1)
    if not self._sample_all_discrete_actions:
      actions = self._policy_dist.sample(dist_inputs)
    log_probs = self._policy_dist.log_prob(dist_inputs, actions)
    obs = observations
    # Insert a size-1 axis at position 1 (the samples axis after the swap),
    # presumably so obs broadcasts against the per-sample actions -- confirm.
    obs = jnp.reshape(obs, [obs.shape[0], 1] + list(obs.shape[1:]))
    inputs = (obs, actions)
  else:
    log_probs = None
    inputs = (observations, )

  # Prepare weights/state for all devices via tl.for_n_devices (presumably
  # replication -- confirm) before calling the jitted value model.
  n_devices = fastmath.device_count()
  weights = tl.for_n_devices(self._value_eval_model.weights, n_devices)
  state = tl.for_n_devices(self._value_eval_model.state, n_devices)
  rng = self._value_eval_model.rng
  values, _ = self._value_eval_jit(inputs, weights, state, rng)
  values *= self._value_network_scale
  values = jnp.squeeze(values, axis=-1)  # Remove the singleton depth dim.
  return (values, actions, log_probs)
def forward(self, x):
  """Executes this layer as part of a forward pass through the model.

  The input is contracted with a per-position weight tensor via the einsum
  '...lp,lkpd->...lkd'. This yields one result per position `l` and per kernel
  offset `k`. The `self._kernel_size` offset slices are then each shifted
  along the position axis by a different amount and summed.

  Args:
    x: Tensor of same shape and dtype as the input signature used to
      initialize this layer.

  Returns:
    Tensor of same shape and dtype as the input, except the final dimension is
    the layer's `filters` value, and the second to last dimension is shrunk by
    `kernel_size - 1` if 'VALID' padding is used with kernel_size bigger than
    one.

  Raises:
    ValueError: if `self._use_bias` is set but `self.weights` is not a tuple
      or list. Also raised when `self._kernel_size > 1` and `self._padding` is
      not one of 'WRAP', 'SAME' or 'VALID'.
  """
  if self._use_bias:
    if not isinstance(self.weights, (tuple, list)):
      raise ValueError(f'Weights should be a (w, b) tuple or list; '
                       f'instead got: {self.weights}')
    w, b = self.weights
  else:
    w = self.weights

  # One result per position l and kernel offset k, before any shifting.
  linear_results_before_shifting = jnp.einsum(
      '...lp,lkpd->...lkd', x, w)
  # TODO(jaszczur): this could be run after padding for better efficiency

  if self._kernel_size == 1:
    # With kernel size 1 we don't have to split or shift anything.
    linear_result = jnp.squeeze(linear_results_before_shifting, axis=-2)
  else:
    # We computed a result for every "pixel", but each direction from the
    # receptive field (there are 'self._kernel_size' such directions) must be
    # shifted by a different amount. The easiest way to do it is to split
    # the tensor to 'self._kernel_size' smaller tensors, shift each one
    # appropriately, and then sum them together.
    split_shifting_linear_results = jnp.split(
        linear_results_before_shifting, self._kernel_size, axis=-2)

    for i in range(self._kernel_size):
      # Each tensor has to be shifted a different amount.
      # The list comprehensions below reuse the name `i`, but Python 3
      # comprehensions have their own scope (and their iterable is evaluated
      # in this scope), so the loop's `i` is unaffected.
      # NOTE(review): with k = kernel_size, for 'WRAP'/'SAME' output position
      # j takes slice i from input position j - (k-1) + (k-1)//2 + i, while
      # for 'VALID' it takes it from input position j + (k-1) - i. The
      # offset order is therefore reversed between VALID and SAME/WRAP --
      # confirm this is intended.
      if self._padding == 'WRAP':
        # We can shift by padding and cutting. With 'wrap' padding we
        # essentially have a torus.
        padding = [(0, 0) for i in split_shifting_linear_results[i].shape]
        # Pad the position axis (-3): k-1-i on the left, i on the right.
        padding[-3] = ((self._kernel_size - 1) - i, i)
        split_shifting_linear_results[i] = jnp.pad(
            split_shifting_linear_results[i], padding, mode='wrap')
        # Cut floor((k-1)/2) from the left and ceil((k-1)/2) from the right,
        # removing the k-1 padded positions and restoring the length.
        split_shifting_linear_results[
            i] = split_shifting_linear_results[i][
                ..., (self._kernel_size - 1) // 2:-(self._kernel_size - 1) // 2,
                :, :]
      elif self._padding == 'SAME':
        # We can shift by padding and cutting.
        # Same as 'WRAP', but jnp.pad's default (constant) mode is used.
        padding = [(0, 0) for i in split_shifting_linear_results[i].shape]
        padding[-3] = ((self._kernel_size - 1) - i, i)
        split_shifting_linear_results[i] = jnp.pad(
            split_shifting_linear_results[i], padding)
        split_shifting_linear_results[
            i] = split_shifting_linear_results[i][
                ..., (self._kernel_size - 1) // 2:-(self._kernel_size - 1) // 2,
                :, :]
        # TODO(jaszczur): improve efficiency by not padding things to cut
      elif self._padding == 'VALID':
        # We don't need to shift - just cut the leftmost and rightmost values.
        # Each slice keeps length L - (k-1) along the position axis.
        cut_left = (self._kernel_size - 1) - i
        cut_right = split_shifting_linear_results[i].shape[-3] - i
        split_shifting_linear_results[
            i] = split_shifting_linear_results[i][
                ..., cut_left:cut_right, :, :]
      else:
        raise ValueError(f'Invalid padding {self._padding}')
    # After shifting.
    shifted_linear_results = jnp.concatenate(
        split_shifting_linear_results, axis=-2)
    # Sum over the kernel-offset axis.
    linear_result = jnp.sum(shifted_linear_results, axis=-2)

  if self._use_bias:
    return linear_result + b
  else:
    return linear_result
def ReformerNoEncDecAttention(input_vocab_size,
                              output_vocab_size=None,
                              d_model=512,
                              d_ff=2048,
                              d_attention_key=64,
                              d_attention_value=64,
                              n_encoder_layers=6,
                              n_decoder_layers=6,
                              n_heads=8,
                              dropout=0.1,
                              max_len=2048,
                              encoder_attention_type=tl.SelfAttention,
                              encoder_decoder_attention_type=tl.SelfAttention,
                              axial_pos_shape=(),
                              d_axial_pos_embs=None,
                              ff_activation=tl.Relu,
                              ff_use_sru=0,
                              ff_chunk_size=0,
                              ff_dropout=None,
                              mode='train'):
  """Reversible transformer encoder-decoder model.

  This model expects an input pair: source, target.

  At the moment, this model supports dot-product attention only. For the
  attention types in the Reformer paper, see ReformerLM.

  There is no separate encoder-decoder attention. Instead, the encoder output
  is concatenated with the embedded target, and the decoder blocks run over
  the concatenation.

  Args:
    input_vocab_size: int: vocab size of the source.
    output_vocab_size: int (optional): vocab size of the target. If None, the
      source and target are assumed to have the same vocab.
    d_model: int: depth of embedding
    d_ff: int: depth of feed-forward layer
    d_attention_key: int: depth of key vector for each attention head (used by
      the decoder blocks)
    d_attention_value: int: depth of value vector for each attention head
      (used by the decoder blocks)
    n_encoder_layers: int: number of encoder layers
    n_decoder_layers: int: number of decoder layers
    n_heads: int: number of attention heads
    dropout: float: dropout rate (how much to drop out)
    max_len: int: maximum symbol length for positional encoding
    encoder_attention_type: class: attention class to use, such as
      SelfAttention
    encoder_decoder_attention_type: class, or a non-empty tuple/list of
      classes, used by the decoder blocks. A sequence is cycled through layer
      by layer, so its length must divide n_decoder_layers.
    axial_pos_shape: tuple of ints: input shape to use for the axial position
      encoding. If unset, axial position encoding is disabled.
    d_axial_pos_embs: tuple of ints: depth of position embedding for each axis.
      Tuple length must match axial_pos_shape, and values must sum to d_model.
      Required when axial_pos_shape is set.
    ff_activation: the non-linearity in feed-forward layer
    ff_use_sru: int; if > 0, we use this many SRU layers instead of
      feed-forward (decoder blocks only)
    ff_chunk_size: int; if > 0, chunk feed-forward into this-sized chunks
      (decoder blocks only)
    ff_dropout: float: (optional) separate dropout rate at feed-forward
      nonlinearity. This is called relu_dropout in T2T. Only passed to the
      encoder blocks.
    mode: str: 'train', 'eval' or 'predict'. In 'predict' mode the encoder
      runs in 'eval' mode and is wrapped in tl.Cache.

  Returns:
    A Reformer model as a layer that maps from a source, target pair to
    log-probabilities over the output vocab (the last layer is
    tl.LogSoftmax).

  Raises:
    ValueError: if axial_pos_shape is set but d_axial_pos_embs is None. Also
      raised if encoder_decoder_attention_type is an empty sequence, or its
      length does not divide n_decoder_layers.
  """
  # The current API for custom gradients assumes that a layer must be
  # differentiable wrt all of its inputs, but the Transformer puts bool-dtype
  # masks on the stack. This causes jax to error, even though the so-called
  # "gradient" wrt the masks is never actually computed.
  # TODO(kitaev): remove this hack.
  if fastmath.backend_name() == 'jax':
    jax.api._check_inexact_input_vjp = lambda x: None  # pylint: disable=protected-access

  def PositionalEncoder(vocab_size, mode):  # tokens --> vectors
    if not axial_pos_shape:
      positional_encoding = tl.PositionalEncoding(
          max_len=max_len, dropout=dropout, mode=mode)
    else:
      # Explicit check instead of `assert`, which is stripped under -O.
      if d_axial_pos_embs is None:
        raise ValueError(
            'd_axial_pos_embs must be set when axial_pos_shape is given.')
      positional_encoding = tl.AxialPositionalEncoding(
          shape=axial_pos_shape, d_embs=d_axial_pos_embs,
          dropout_broadcast_dims=tuple(range(1, len(axial_pos_shape) + 1)),
          dropout=dropout, mode=mode)

    return [
        tl.Embedding(vocab_size, d_model),
        tl.Dropout(rate=dropout, shared_axes=[-2], mode=mode),
        positional_encoding,
    ]

  # TODO(kitaev): The regular trax Transformer shares vocab embeddings and
  # position embeddings between the encoder and decoder if output_vocab_size is
  # None. This isn't supported here because (a) Trax shares weights by sharing
  # layer instances, but we need two separate instances to have mode == 'eval'
  # for the encoder but mode == 'predict' for the decoder; and (b) tl.Cache does
  # not work if its sublayers participate in any weight sharing.

  # Mode 'predict' means that the decoder should be run one token at a time.
  # The encoder only ever runs over full sequences, which is why it's switched
  # to 'eval' mode instead.
  in_encoder = PositionalEncoder(
      input_vocab_size, mode='eval' if mode == 'predict' else mode)
  if output_vocab_size is None:
    output_vocab_size = input_vocab_size
  out_encoder = PositionalEncoder(output_vocab_size, mode)

  # pylint: disable=g-complex-comprehension
  encoder_blocks = [
      EncoderBlock(
          d_model, d_ff, n_heads, encoder_attention_type, dropout,
          ff_activation, ff_dropout, mode)
      for _ in range(n_encoder_layers)]
  # pylint: enable=g-complex-comprehension

  encoder = tl.Serial([                # tok_e mask_e tok_e tok_d tok_d
      in_encoder,                      # vec_e mask_e tok_e tok_d tok_d
      tl.Dup(),                        # vec_e1 vec_e2 mask_e tok_e tok_d tok_d
      tl.ReversibleSerial(encoder_blocks),
      tl.Fn('XYAvg', lambda x, y: (x + y) / 2.0),
      tl.LayerNorm(),
  ])
  if mode == 'predict':
    encoder = tl.Cache(encoder)

  if isinstance(encoder_decoder_attention_type, (tuple, list)):
    # Explicit check instead of `assert`; also rejects an empty sequence,
    # which would otherwise fail with an opaque ZeroDivisionError.
    n_attention_types = len(encoder_decoder_attention_type)
    if n_attention_types == 0 or n_decoder_layers % n_attention_types != 0:
      raise ValueError(
          f'n_decoder_layers ({n_decoder_layers}) must be a multiple of the '
          f'number of encoder_decoder_attention_type entries '
          f'({n_attention_types}), which must be positive.')
  else:
    encoder_decoder_attention_type = [encoder_decoder_attention_type]

  decoder_blocks = []
  for layer_idx in range(n_decoder_layers):
    # Cycle through the configured attention types layer by layer.
    layer_attention_type = encoder_decoder_attention_type[
        layer_idx % len(encoder_decoder_attention_type)]
    decoder_block = DecoderBlock(
        d_model, d_ff, d_attention_key, d_attention_value, n_heads,
        attention_type=layer_attention_type,
        dropout=dropout,
        ff_activation=ff_activation,
        ff_use_sru=ff_use_sru,
        ff_chunk_size=ff_chunk_size,
        mode=mode)
    decoder_blocks.append(decoder_block)

  # Assemble and return the model.
  return tl.Serial(
      # Input: encoder_side_tokens, decoder_side_tokens
      # Copy decoder tokens for use in loss.
      tl.Select([0, 0, 1, 1]),           # tok_e tok_e tok_d tok_d
      tl.Branch([], [tl.PaddingMask(),
                     tl.Fn('Squeeze',
                           lambda x: jnp.squeeze(x, (1, 2)),
                           n_out=1)]),   # tok_e mask_e tok_e tok_d tok_d

      # Encode.
      encoder,                           # vec_e mask_e tok_e tok_d tok_d

      # Decode.
      tl.Select([3, 0, 1, 2]),           # tok_d vec_e mask_e tok_e tok_d
      tl.ShiftRight(mode=mode),          # stok_d vec_e mask_e tok_e tok_d
      tl.Branch(
          [], _MaskOfRightShiftedArray()
      ),                                 # stok_d mask_d vec_e mask_e tok_e tok_d
      out_encoder,                       # svec_d mask_d vec_e mask_e tok_e tok_d

      # Concat encoder and decoder, given their masks.
      tl.Select([2, 0, 3, 1]),           # vec_e svec_d mask_e mask_d tok_e tok_d
      _ConcatWithPadding(),              # vec_ed tok_e tok_d

      # Run (encoder and) decoder blocks.
      tl.Dup(),                          # vec_ed1 vec_ed2 tok_e tok_d
      tl.ReversibleSerial(decoder_blocks),  # vec_ed1 vec_ed2 tok_e tok_d
      tl.Fn('XYAvg',
            lambda x, y: (x + y) / 2.0),  # vec_ed tok_e tok_d
      tl.LayerNorm(),                    # vec_ed tok_e tok_d

      # Separate out the encoder part from the concatenated vector.
      tl.Select([0, 1, 2, 2]),           # vec_ed tok_e tok_d tok_d
      _StripFromConcatenateWithPadding(),  # vec_d tok_d

      # Map to output vocab.
      tl.Dense(output_vocab_size),       # vec_d tok_d
      tl.LogSoftmax(),                   # vec_d tok_d
  )
def Reformer(input_vocab_size,
             output_vocab_size=None,
             d_model=512,
             d_ff=2048,
             n_encoder_layers=6,
             n_decoder_layers=6,
             n_heads=8,
             dropout=0.1,
             max_len=2048,
             ff_activation=tl.Relu,
             ff_dropout=None,
             mode='train'):
  """Reversible transformer encoder-decoder model.

  This model expects an input pair: source, target.

  At the moment, this model supports dot-product attention only. For the
  attention types in the Reformer paper, see ReformerLM.

  Args:
    input_vocab_size: int: vocab size of the source.
    output_vocab_size: int (optional): vocab size of the target. If None, the
      source and target are assumed to have the same vocab.
    d_model: int: depth of embedding
    d_ff: int: depth of feed-forward layer
    n_encoder_layers: int: number of encoder layers
    n_decoder_layers: int: number of decoder layers
    n_heads: int: number of attention heads
    dropout: float: dropout rate (how much to drop out)
    max_len: int: maximum symbol length for positional encoding
    ff_activation: the non-linearity in feed-forward layer
    ff_dropout: float: (optional) separate dropout rate at feed-forward
      nonlinearity. This is called relu_dropout in T2T.
    mode: str: 'train', 'eval' or 'predict'. In 'predict' mode the encoder
      runs in 'eval' mode and is wrapped in tl.Cache.

  Returns:
    A Reformer model as a layer that maps from a source, target pair to
    log-probabilities over the output vocab (the last layer is
    tl.LogSoftmax).
  """
  # The current API for custom gradients assumes that a layer must be
  # differentiable wrt all of its inputs, but the Transformer puts bool-dtype
  # masks on the stack. This causes jax to error, even though the so-called
  # "gradient" wrt the masks is never actually computed.
  # TODO(kitaev): remove this hack.
  if fastmath.backend_name() == 'jax':
    jax.api._check_inexact_input_vjp = lambda x: None  # pylint: disable=protected-access

  def PositionalEncoder(vocab_size, mode):  # tokens --> vectors
    # TODO(kitaev): axial positional encoding is better for very long sequences.
    positional_encoding = tl.PositionalEncoding(
        max_len=max_len, dropout=dropout, mode=mode)
    return [
        tl.Embedding(vocab_size, d_model),
        tl.Dropout(rate=dropout, shared_axes=[-2], mode=mode),
        positional_encoding,
    ]

  # TODO(kitaev): The regular trax Transformer shares vocab embeddings and
  # position embeddings between the encoder and decoder if output_vocab_size is
  # None. This isn't supported here because (a) Trax shares weights by sharing
  # layer instances, but we need two separate instances to have mode == 'eval'
  # for the encoder but mode == 'predict' for the decoder; and (b) tl.Cache does
  # not work if its sublayers participate in any weight sharing.

  # Mode 'predict' means that the decoder should be run one token at a time.
  # The encoder only ever runs over full sequences, which is why it's switched
  # to 'eval' mode instead.
  in_encoder = PositionalEncoder(
      input_vocab_size, mode='eval' if mode == 'predict' else mode)
  if output_vocab_size is None:
    output_vocab_size = input_vocab_size
  out_encoder = PositionalEncoder(output_vocab_size, mode)

  # pylint: disable=g-complex-comprehension
  encoder_blocks = [
      EncoderBlock(
          d_model, d_ff, n_heads, tl.SelfAttention, dropout, ff_activation,
          ff_dropout, mode)
      for _ in range(n_encoder_layers)]
  # pylint: enable=g-complex-comprehension

  encoder = tl.Serial([
      in_encoder,
      tl.Dup(),
      tl.ReversibleSerial(encoder_blocks),
      tl.Fn('XYAvg', lambda x, y: (x + y) / 2.0),
      tl.LayerNorm(),
  ])
  if mode == 'predict':
    encoder = tl.Cache(encoder)

  encoder_decoder_blocks = [
      EncoderDecoderBlock(
          d_model, d_ff, n_heads, dropout, ff_activation, ff_dropout, mode)
      for _ in range(n_decoder_layers)]

  # Assemble and return the model.
  return tl.Serial(
      # Input: encoder_side_tokens, decoder_side_tokens
      # Copy decoder tokens for use in loss.
      tl.Select([0, 1, 1]),               # tok_e tok_d tok_d
      tl.Branch([], [tl.PaddingMask(),
                     tl.Fn('Squeeze',
                           lambda x: jnp.squeeze(x, (1, 2)),
                           n_out=1)]),    # tok_e mask tok_d .....

      # Encode.
      encoder,                            # vec_e mask tok_d .....

      # Decode.
      tl.Select([2, 0, 1]),               # tok_d vec_e mask .....
      tl.ShiftRight(mode=mode),           # stok_d vec_e mask .....
      out_encoder,                        # vec_d vec_e mask .....
      tl.Dup(),                           # vec_d1 vec_d2 vec_e mask .....
      tl.ReversibleSerial(encoder_decoder_blocks),
      tl.Fn('XYAvg',
            lambda x, y: (x + y) / 2.0),  # vec_d vec_e mask .....
      tl.LayerNorm(),                     # vec_d vec_e mask .....

      # Map to output vocab.
      tl.Select([0], n_in=3),             # vec_d .....
      tl.Dense(output_vocab_size),        # vec_d .....
      tl.LogSoftmax(),                    # vec_d .....
  )
def BERT(
    d_model=768,
    vocab_size=30522,
    max_len=512,
    type_vocab_size=2,
    n_heads=12,
    d_ff=3072,
    n_layers=12,
    head=None,
    init_checkpoint=None,
    mode='eval',
):
  """Builds a BERT model (the default hyperparameters match bert-base-uncased).

  The model takes three inputs. The first is the token ids, the second is fed
  to the type embedding, and the third ('idx') is dropped. A padding mask
  built from the token ids travels alongside the activations through every
  encoder block and is discarded at the end.

  Args:
    d_model: depth of the embeddings and of every encoder block.
    vocab_size: size of the word embedding vocabulary.
    max_len: maximum length for the positional encoding.
    type_vocab_size: size of the type embedding vocabulary.
    n_heads: number of attention heads. Each head uses query/key and value
      depth d_model // n_heads.
    d_ff: hidden depth of each feed-forward sublayer.
    n_layers: number of encoder blocks.
    head: optional zero-argument callable. When given, its result is placed
      after the BERT layer inside a tl.Serial.
    init_checkpoint: checkpoint handed to PretrainedBERT. It is used only when
      mode == 'train'; otherwise None is passed instead.
    mode: layer mode; 'eval' by default.

  Returns:
    A PretrainedBERT layer, or tl.Serial(bert, head()) when `head` is given.
    Without a head, the outputs are a tanh-activated Dense projection of
    sequence position 0 and the full encoder output.
  """
  eps = 1e-12
  head_depth = d_model // n_heads

  token_embedding = tl.Embedding(vocab_size, d_model)
  segment_embedding = tl.Embedding(type_vocab_size, d_model)
  positional_encoding = tl.PositionalEncoding(max_len, mode=mode)

  embedding_layers = [
      tl.Select([0, 1, 0], n_in=3),  # Drops 'idx' input.
      tl.Parallel(
          token_embedding,
          segment_embedding,
          [
              tl.PaddingMask(),
              tl.Fn('Squeeze', lambda x: np.squeeze(x, (1, 2)), n_out=1),
          ],
      ),
      tl.Add(),
      positional_encoding,
      tl.LayerNorm(epsilon=eps),
  ]

  def _encoder_block():
    """Returns the flat layer list of one post-LayerNorm transformer block."""
    attention = tl.SelfAttention(
        n_heads=n_heads, d_qk=head_depth, d_v=head_depth, bias=True,
        masked=True, mode=mode)
    mlp = [tl.Dense(d_ff), tl.Gelu(), tl.Dense(d_model)]
    return [
        tl.Select([0, 1, 1]),  # Keep a spare copy of the mask.
        tl.Residual(attention, AddBias()),  # pylint: disable=no-value-for-parameter
        tl.LayerNorm(epsilon=eps),
        tl.Residual(*mlp),
        tl.LayerNorm(epsilon=eps),
    ]

  encoder_layers = []
  for _ in range(n_layers):
    encoder_layers.extend(_encoder_block())
  encoder_layers.append(tl.Select([0], n_in=2))  # Discard the mask.

  pooler_layers = [
      tl.Fn('', lambda x: (x[:, 0, :], x), n_out=2),
      tl.Dense(d_model),
      tl.Tanh(),
  ]

  checkpoint = None
  if mode == 'train':
    checkpoint = init_checkpoint
  bert = PretrainedBERT(embedding_layers + encoder_layers + pooler_layers,
                        init_checkpoint=checkpoint)
  return bert if head is None else tl.Serial(bert, head())
def Reformer2(input_vocab_size,
              output_vocab_size=None,
              d_model=512,
              d_ff=2048,
              d_attention_key=None,
              d_attention_value=None,
              n_encoder_layers=6,
              n_decoder_layers=6,
              n_heads=8,
              dropout=0.1,
              max_len=2048,
              encoder_attention_type=tl.SelfAttention,
              encoder_decoder_attention_type=tl.SelfAttention,
              axial_pos_shape='fixed-base',
              d_axial_pos_embs=None,
              ff_activation=tl.Relu,
              ff_use_sru=0,
              ff_chunk_size=0,
              ff_dropout=None,
              ff_sparsity=0,
              loss_sparsity_type='mult',
              loss_sparsity=0,
              loss_d_lowrank=0,
              loss_sparsity_prob=None,
              attention_chunk_size=0,
              n_layers_forget=0,
              n_decoder_attention_layers=2,
              use_bfloat16=False,
              reversible_encoder=False,
              mode='train'):
  """Reversible transformer encoder-decoder model.

  This model expects an input pair: source, target.

  At the moment, this model supports dot-product attention only. For the
  attention types in the Reformer paper, see ReformerLM.

  Args:
    input_vocab_size: int: vocab size of the source.
    output_vocab_size: int (optional): vocab size of the target. If None, the
      source and target are assumed to have the same vocab.
    d_model: int: depth of embedding
    d_ff: int: depth of feed-forward layer
    d_attention_key: int: depth of key vector for each attention head. If
      None, defaults to d_model // n_heads.
    d_attention_value: int: depth of value vector for each attention head. If
      None, defaults to d_model // n_heads.
    n_encoder_layers: int: number of encoder layers
    n_decoder_layers: int: number of decoder layers
    n_heads: int: number of attention heads
    dropout: float: dropout rate (how much to drop out)
    max_len: int: maximum symbol length for positional encoding
    encoder_attention_type: class: attention class to use, such as
      SelfAttention
    encoder_decoder_attention_type: class, or a non-empty tuple/list of
      classes, used by the decoder blocks. A sequence is cycled through layer
      by layer, so its length must divide n_decoder_layers.
    axial_pos_shape: positional-encoding selector, forwarded unchanged to
      ct.EmbeddingAndPositionalEncodings, which decides the accepted values.
      Per the original docs it is a tuple of ints giving the input shape for
      axial position encoding, and leaving it unset disables axial encoding.
      Note that the default here is the string 'fixed-base'.
    d_axial_pos_embs: tuple of ints: depth of position embedding for each axis.
      Tuple length must match axial_pos_shape, and values must sum to d_model.
    ff_activation: the non-linearity in feed-forward layer
    ff_use_sru: int; if > 0, we use this many SRU layers instead of feed-forward
    ff_chunk_size: int; if > 0, chunk feed-forward into this-sized chunks
    ff_dropout: float: (optional) separate dropout rate at feed-forward
      nonlinearity. This is called relu_dropout in T2T.
    ff_sparsity: int, if > 0 use sparse feed-forward block with this sparsity
    loss_sparsity_type: str, type of sparsity to use in loss layer. See
      SparseDenseWithOptions for options. None if no sparsity should be used.
    loss_sparsity: int, the sparsity for loss layer (if used)
    loss_d_lowrank: int, the dimensions for intermediate layer (if used)
    loss_sparsity_prob: float, the probability for sparse version of loss to be
      used. If None, only sparse version is used.
    attention_chunk_size: int, if > 0 run attention chunked at this size
    n_layers_forget: how often to have a forgetting block between layers
    n_decoder_attention_layers: how many attention layers in a decoder block
    use_bfloat16: whether to use bfloat16 for weights (default: False)
    reversible_encoder: whether to be reversible through the encoder
    mode: str: 'train', 'eval' or 'predict'. In 'predict' mode the encoder is
      wrapped in tl.Cache.

  Returns:
    A Reformer model as a layer that maps from a source, target pair to
    activations over a vocab set.

  Raises:
    ValueError: if a head depth must be derived (d_attention_key or
      d_attention_value is None) and n_heads does not divide d_model. Also
      raised if encoder_decoder_attention_type is an empty sequence, or its
      length does not divide n_decoder_layers.
  """
  # Set default dimensions for attention head key and value sizes.
  if d_attention_key is None or d_attention_value is None:
    if d_model % n_heads != 0:
      raise ValueError(
          f'n_heads ({n_heads}) must divide d_model ({d_model})')
    d_head = d_model // n_heads
    if d_attention_key is None:
      d_attention_key = d_head
    if d_attention_value is None:
      d_attention_value = d_head

  # Vector embeddings.
  in_encoder, out_encoder, output_vocab_size = (
      ct.EmbeddingAndPositionalEncodings(
          input_vocab_size,
          d_model,
          mode,
          dropout,
          [-2],  # dropout_shared_axes
          max_len,
          output_vocab_size=output_vocab_size,
          axial_pos_shape=axial_pos_shape,
          d_axial_pos_embs=d_axial_pos_embs,
          use_bfloat16=use_bfloat16)
  )

  # pylint: disable=g-complex-comprehension
  encoder_blocks = [
      EncoderBlock(
          d_model, d_ff, n_heads, encoder_attention_type,
          dropout=dropout,
          ff_activation=ff_activation,
          ff_dropout=ff_dropout,
          ff_use_sru=ff_use_sru,
          ff_chunk_size=ff_chunk_size,
          ff_sparsity=ff_sparsity,
          attention_chunk_size=attention_chunk_size,
          use_bfloat16=use_bfloat16,
          mode=mode)
      for _ in range(n_encoder_layers)]
  # pylint: enable=g-complex-comprehension

  encoder = [                             # vec_e mask_e tok_e vec_d tok_d
      tl.ReversibleSelect([0, 0]),        # vec_e1 vec_e2 mask_e tok_e vec_d tok_d
      _ReversibleSerialForget(encoder_blocks, d_model, n_layers_forget)
  ]
  if not reversible_encoder:
    encoder += [
        tl.Fn('XYAvg', lambda x, y: (x + y) / 2.0),
        tl.Dense(d_model, use_bfloat16=use_bfloat16),
        tl.LayerNorm(),
    ]
  encoder = tl.Serial(encoder)
  if mode == 'predict':
    encoder = tl.Cache(encoder)

  if isinstance(encoder_decoder_attention_type, (tuple, list)):
    # Explicit check instead of `assert`, which is stripped under -O; also
    # rejects an empty sequence (otherwise an opaque ZeroDivisionError).
    n_attention_types = len(encoder_decoder_attention_type)
    if n_attention_types == 0 or n_decoder_layers % n_attention_types != 0:
      raise ValueError(
          f'n_decoder_layers ({n_decoder_layers}) must be a multiple of the '
          f'number of encoder_decoder_attention_type entries '
          f'({n_attention_types}), which must be positive.')
  else:
    encoder_decoder_attention_type = [encoder_decoder_attention_type]

  decoder_blocks = []
  for layer_idx in range(n_decoder_layers):
    # Cycle through the configured attention types layer by layer.
    layer_attention_type = encoder_decoder_attention_type[
        layer_idx % len(encoder_decoder_attention_type)]
    decoder_block = DecoderBlock(
        d_model, d_ff, d_attention_key, d_attention_value, n_heads,
        attention_type=layer_attention_type,
        dropout=dropout,
        ff_activation=ff_activation,
        ff_dropout=ff_dropout,
        ff_use_sru=ff_use_sru,
        ff_chunk_size=ff_chunk_size,
        ff_sparsity=ff_sparsity,
        attention_chunk_size=attention_chunk_size,
        n_attention_layers=n_decoder_attention_layers,
        use_bfloat16=use_bfloat16,
        mode=mode)
    decoder_blocks.append(decoder_block)

  dense_loss_layer = tl.SparseDenseWithOptions(
      output_vocab_size,
      d_input=d_model,
      sparsity_type=loss_sparsity_type,
      sparsity=loss_sparsity,
      d_lowrank=loss_d_lowrank,
      prob_sparse=loss_sparsity_prob,
      use_bfloat16=use_bfloat16,
      mode=mode)

  # Layers to merge encoder and decoder, see below for details.
  if reversible_encoder:
    encdec_layers = [
        tl.ReversibleSelect([0, 1, 4, 2, 3]),  # vec_e1 vec_e2 vec_d mask_e tok_e tok_d
        t2.ConcatWithPadding2(mode=mode),      # vec_ed vec_ed tok_e tok_d
    ]
  else:
    encdec_layers = [
        tl.ReversibleSelect([0, 3, 1, 2]),     # vec_e vec_d mask_e tok_e tok_d
        t2.ConcatWithPadding(mode=mode),       # vec_ed tok_e tok_d
        tl.ReversibleSelect([0, 0]),           # vec_ed vec_ed tok_e tok_d
    ]

  # Assemble and return the model.
  return tl.Serial(
      # Input: encoder_side_tokens, decoder_side_tokens
      # Copy decoder tokens for use in loss.
      tl.Select([0, 0, 0, 1, 1]),                # tok_e tok_e tok_e tok_d tok_d

      # Embed in and out tokens; done together as weights may be shared.
      tl.Parallel(in_encoder, [], [],            # vec_e tok_e tok_e vec_d tok_d
                  [tl.ShiftRight(mode=mode), out_encoder]),

      tl.Parallel([], [tl.PaddingMask(),
                       tl.Fn('Squeeze',
                             lambda x: jnp.squeeze(x, (1, 2)),
                             n_out=1)]),         # vec_e mask_e tok_e vec_d tok_d

      # Encode.
      encoder,                                   # vec_e mask_e tok_e vec_d tok_d

      # Concat encoder and decoder, given encoder mask.
      encdec_layers,

      # Run decoder blocks.
      _ReversibleSerialForget(decoder_blocks, d_model,
                              n_layers_forget),  # vec_ed1 vec_ed2 tok_e tok_d
      tl.Fn('XYAvg',
            lambda x, y: (x + y) / 2.0),         # vec_ed tok_e tok_d
      tl.LayerNorm(),                            # vec_ed tok_e tok_d

      # Separate out the encoder part from the concatenated vector.
      tl.Select([0, 1, 2, 2]),                   # vec_ed tok_e tok_d tok_d
      t2.StripFromConcatenateWithPadding(mode=mode),  # vec_d tok_d

      # Map to output vocab.
      dense_loss_layer,                          # vec_d tok_d
  )
def Reformer(input_vocab_size,
             output_vocab_size=None,
             d_model=512,
             d_ff=2048,
             n_encoder_layers=6,
             n_decoder_layers=6,
             n_heads=8,
             dropout=0.1,
             max_len=2048,
             ff_activation=tl.Relu,
             ff_dropout=None,
             mode='train',
             axial_pos_shape=None,
             d_axial_pos_embs=None,
             ff_use_sru=0,
             ff_chunk_size=0,
             ff_sparsity=0):
  """Reversible transformer encoder-decoder model.

  This model expects an input pair: source, target.

  At the moment, this model supports dot-product attention only. For the
  attention types in the Reformer paper, see ReformerLM.

  Args:
    input_vocab_size: int: vocab size of the source.
    output_vocab_size: int (optional): vocab size of the target. If None, the
      source and target are assumed to have the same vocab.
    d_model: int: depth of embedding
    d_ff: int: depth of feed-forward layer
    n_encoder_layers: int: number of encoder layers
    n_decoder_layers: int: number of decoder layers
    n_heads: int: number of attention heads
    dropout: float: dropout rate (how much to drop out)
    max_len: int: maximum symbol length for positional encoding
    ff_activation: the non-linearity in feed-forward layer
    ff_dropout: float: (optional) separate dropout rate at feed-forward
      nonlinearity. This is called relu_dropout in T2T.
    mode: str: 'train', 'eval' or 'predict'. In 'predict' mode the encoder is
      wrapped in tl.Cache.
    axial_pos_shape: tuple of ints: input shape to use for the axial position
      encoding. If unset, axial position encoding is disabled.
    d_axial_pos_embs: tuple of ints: depth of position embedding for each axis.
      Tuple length must match axial_pos_shape, and values must sum to d_model.
    ff_use_sru: int; if > 0, we use this many SRU layers instead of feed-forward
    ff_chunk_size: int; if > 0, chunk feed-forward into this-sized chunks
    ff_sparsity: int, if > 0 use sparse feed-forward block with this sparsity

  Returns:
    A Reformer model as a layer that maps from a source, target pair to
    activations over a vocab set. These are the outputs of a final Dense
    layer; no LogSoftmax is applied.
  """
  in_encoder, out_encoder, output_vocab_size = (
      ct.EmbeddingAndPositionalEncodings(
          input_vocab_size,
          d_model,
          mode,
          dropout,
          [-2],  # dropout_shared_axes
          max_len,
          output_vocab_size=output_vocab_size,
          axial_pos_shape=axial_pos_shape,
          d_axial_pos_embs=d_axial_pos_embs)
  )

  # pylint: disable=g-complex-comprehension
  encoder_blocks = [
      EncoderBlock(
          d_model, d_ff, n_heads, tl.SelfAttention, dropout, ff_activation,
          ff_dropout, mode=mode, ff_use_sru=ff_use_sru,
          ff_chunk_size=ff_chunk_size, ff_sparsity=ff_sparsity)
      for _ in range(n_encoder_layers)]
  # pylint: enable=g-complex-comprehension

  encoder = tl.Serial([
      in_encoder,
      tl.Dup(),
      tl.ReversibleSerial(encoder_blocks),
      tl.Fn('XYAvg', lambda x, y: (x + y) / 2.0),
      tl.LayerNorm(),
  ])
  if mode == 'predict':
    encoder = tl.Cache(encoder)

  # pylint: disable=g-complex-comprehension
  encoder_decoder_blocks = [
      EncoderDecoderBlock(
          d_model, d_ff, n_heads, dropout, ff_activation, ff_dropout, mode,
          ff_use_sru=ff_use_sru, ff_chunk_size=ff_chunk_size,
          ff_sparsity=ff_sparsity)
      for _ in range(n_decoder_layers)]
  # pylint: enable=g-complex-comprehension

  # Assemble and return the model.
  return tl.Serial(
      # Input: encoder_side_tokens, decoder_side_tokens
      # Copy decoder tokens for use in loss.
      tl.Select([0, 1, 1]),               # tok_e tok_d tok_d
      tl.Branch([], [tl.PaddingMask(),
                     tl.Fn('Squeeze',
                           lambda x: jnp.squeeze(x, (1, 2)),
                           n_out=1)]),    # tok_e mask tok_d .....

      # Encode.
      encoder,                            # vec_e mask tok_d .....

      # Decode.
      tl.Select([2, 0, 1]),               # tok_d vec_e mask .....
      tl.ShiftRight(mode=mode),           # stok_d vec_e mask .....
      out_encoder,                        # vec_d vec_e mask .....
      tl.Dup(),                           # vec_d1 vec_d2 vec_e mask .....
      tl.ReversibleSerial(encoder_decoder_blocks),
      tl.Fn('XYAvg',
            lambda x, y: (x + y) / 2.0),  # vec_d vec_e mask .....
      tl.LayerNorm(),                     # vec_d vec_e mask .....

      # Map to output vocab.
      tl.Select([0], n_in=3),             # vec_d .....
      tl.Dense(output_vocab_size),        # vec_d .....
  )
def _RemoveAxes12():
  """Builds a layer that squeezes away the size-1 axes 1 and 2 of its input.

  For example, an input of shape (a, 1, 1, b) comes out with shape (a, b).
  """
  def _squeeze_axes_1_and_2(x):
    return jnp.squeeze(x, (1, 2))

  return tl.Fn('RemoveAxes12', _squeeze_axes_1_and_2)
def Reformer2(input_vocab_size,
              output_vocab_size=None,
              d_model=512,
              d_ff=2048,
              d_attention_key=None,
              d_attention_value=None,
              n_encoder_layers=6,
              n_decoder_layers=6,
              n_heads=8,
              dropout=0.1,
              max_len=2048,
              encoder_attention_type=tl.SelfAttention,
              encoder_decoder_attention_type=tl.SelfAttention,
              pos_type='fixed-base',
              pos_axial_shape=(),
              pos_d_axial_embs=None,
              pos_start_from_zero_prob=1.0,
              pos_max_offset_to_add=0,
              ff_activation=tl.Relu,
              ff_use_sru=0,
              ff_chunk_size=0,
              ff_dropout=None,
              ff_sparsity=0,
              loss_sparsity_type='mult',
              loss_sparsity=0,
              loss_d_lowrank=0,
              loss_sparsity_prob=None,
              attention_chunk_size=0,
              n_layers_forget=0,
              forget_dense=True,
              n_decoder_attention_layers=2,
              use_bfloat16=False,
              reversible_encoder=False,
              use_two_swaps_per_encoder_block=True,
              center_layernorm=True,
              half_before_layer=None,
              double_after_layer=None,
              mode='train'):
  """Reversible transformer encoder-decoder model.

  This model expects an input pair: source, target.

  At the moment, this model supports dot-product attention only. For the
  attention types in the Reformer paper, see ReformerLM.

  Args:
    input_vocab_size: int: vocab size of the source.
    output_vocab_size: int (optional): vocab size of the target. If None, the
      source and target are assumed to have the same vocab.
    d_model: int: depth of embedding
    d_ff: int: depth of feed-forward layer
    d_attention_key: int: depth of key vector for each attention head
      (defaults to d_model // n_heads)
    d_attention_value: int: depth of value vector for each attention head
      (defaults to d_model // n_heads)
    n_encoder_layers: int: number of encoder layers
    n_decoder_layers: int: number of decoder layers
    n_heads: int: number of attention heads; must divide d_model / 2
    dropout: float: dropout rate (how much to drop out)
    max_len: int: maximum symbol length for positional encoding
    encoder_attention_type: class: attention class to use, such as SelfAttention
    encoder_decoder_attention_type: class: attention class to use, such as
      SelfAttention. A tuple/list of classes is also accepted; it is cycled
      over the decoder layers, and n_decoder_layers must be a multiple of its
      length.
    pos_type: string, the type of positional embeddings to use.
    pos_axial_shape: tuple of ints: input shape to use for the axial position
      encoding. If unset, axial position encoding is disabled.
    pos_d_axial_embs: tuple of ints: depth of position embedding for each axis.
      Tuple length must match pos_axial_shape, and values must sum to d_model.
    pos_start_from_zero_prob: how often to start from 0 during training,
      (if 1.0, we always start from position 0, if less, we randomize).
    pos_max_offset_to_add: maximum offset to add to positions during training
      when randomizing; this offset plus input length must still be less than
      max_len for all training examples.
    ff_activation: the non-linearity in feed-forward layer
    ff_use_sru: int; if > 0, we use this many SRU layers instead of feed-forward
    ff_chunk_size: int; if > 0, chunk feed-forward into this-sized chunks
    ff_dropout: float: (optional) separate dropout rate at feed-forward
      nonlinearity. This is called relu_dropout in T2T.
    ff_sparsity: int, if > 0 use sparse feed-forward block with this sparsity
    loss_sparsity_type: str, type of sparsity to used in loss layer. See
      SparseDenseWithOptions for options. None if no sparsity should be used.
    loss_sparsity: int, the sparsity for loss layer (if used)
    loss_d_lowrank: int, the dimensions for intermediate layer (if used)
    loss_sparsity_prob: float, the probability for sparse version of loss to be
      used. If None, only sparse version is used.
    attention_chunk_size: int, if > 0 run attention chunked at this size
    n_layers_forget: how often to have a forgetting block between layers
    forget_dense: whether to use Dense or no-op (Serial) as a forget layer.
    n_decoder_attention_layers: how many attention layers in a decoder block
    use_bfloat16: whether to use bfloat16 for weights (default: False)
    reversible_encoder: whether to be reversible through the encoder
    use_two_swaps_per_encoder_block: whether to allow even number of swaps in
      the encoder
    center_layernorm: whether to use centering in LayerNorm (default) or if
      to skip it, which is known as RMS normalization.
    half_before_layer: int, if set, the embeddings, the encoder, and decoder
      layers with index below this value use half of d_model and d_ff (and
      decoder layers half of the attention key/value depths). A
      tl.ReversibleConcatenatePair is inserted after decoder layer
      `half_before_layer - 1`, and later decoder layers use the full sizes.
    double_after_layer: int, if set, decoder layers with index above this
      value use double d_model, d_ff and attention key/value depths. A
      tl.ReversibleConcatenatePair is inserted after decoder layer
      `double_after_layer`, and the output layer reads the doubled depth.
    mode: str: 'train', 'eval' or 'predict'

  Returns:
    A Reformer model as a layer that maps from a (source, target) token pair
    to activations over the output vocab set; the target tokens are passed
    through alongside the activations for use in the loss.

  Raises:
    ValueError: if n_heads does not divide d_model / 2.
  """
  # Set default dimensions for attention head key and value sizes.
  if (d_model / 2) % n_heads != 0:
    raise ValueError(
        f'n_heads ({n_heads}) must divide d_model/2 ({d_model/2})')
  if d_attention_key is None:
    d_attention_key = d_model // n_heads
  if d_attention_value is None:
    d_attention_value = d_model // n_heads

  # Set values of d_model, d_ff and d_qkv for the first stage.
  # Floor division keeps every dimension an int: true division would yield
  # floats (e.g. 256.0), which are not valid layer sizes / array shapes.
  d_model1, d_ff1 = d_model, d_ff
  d_attention_key1, d_attention_value1 = d_attention_key, d_attention_value
  if half_before_layer:
    d_model1, d_ff1 = d_model // 2, d_ff // 2
    d_attention_key1 = d_attention_key // 2
    d_attention_value1 = d_attention_value // 2

  # Set values of d_model, d_ff and d_qkv for the final stage.
  d_model2, d_ff2 = d_model, d_ff
  d_attention_key2, d_attention_value2 = d_attention_key, d_attention_value
  if double_after_layer:
    d_model2, d_ff2 = d_model * 2, d_ff * 2
    d_attention_key2 = d_attention_key * 2
    d_attention_value2 = d_attention_value * 2

  # Vector embeddings.
  in_encoder, out_encoder, output_vocab_size = (
      ct.EmbeddingAndPositionalEncodings(
          input_vocab_size,
          d_model1,
          mode,
          dropout,
          [-2],  # dropout_shared_axes
          max_len,
          output_vocab_size=output_vocab_size,
          pos_type=pos_type,
          pos_axial_shape=pos_axial_shape,
          pos_d_axial_embs=pos_d_axial_embs,
          pos_start_from_zero_prob=pos_start_from_zero_prob,
          pos_max_offset_to_add=pos_max_offset_to_add,
          use_bfloat16=use_bfloat16)
  )

  # pylint: disable=g-complex-comprehension
  encoder_blocks = [
      EncoderBlock(
          d_model1, d_ff1, n_heads, encoder_attention_type,
          dropout=dropout,
          ff_activation=ff_activation,
          ff_dropout=ff_dropout,
          ff_use_sru=ff_use_sru,
          ff_chunk_size=ff_chunk_size,
          ff_sparsity=ff_sparsity,
          attention_chunk_size=attention_chunk_size,
          center_layernorm=center_layernorm,
          use_bfloat16=use_bfloat16,
          use_two_swaps_per_block=use_two_swaps_per_encoder_block,
          mode=mode)
      for _ in range(n_encoder_layers)]
  # pylint: enable=g-complex-comprehension

  encoder = [                                   # vec_e mask_e tok_e tok_d tok_d
      tl.ReversibleSelect([0, 0]),      # vec_e1 vec_e2 mask_e tok_e tok_d tok_d
      _ReversibleSerialForget(encoder_blocks, d_model1, n_layers_forget,
                              forget_dense)
  ]
  if not reversible_encoder:
    encoder += [
        tl.Fn('XYAvg', lambda x, y: (x + y) / 2.0),
        tl.Dense(d_model1, use_bfloat16=use_bfloat16),
        tl.LayerNorm(),
    ]
  encoder = tl.Serial(encoder)
  if mode == 'predict':
    encoder = tl.Cache(encoder)

  decoder_blocks = []

  if isinstance(encoder_decoder_attention_type, (tuple, list)):
    assert n_decoder_layers % len(encoder_decoder_attention_type) == 0
  else:
    encoder_decoder_attention_type = [encoder_decoder_attention_type]
  for layer_idx in range(n_decoder_layers):
    layer_attention_type = encoder_decoder_attention_type[
        layer_idx % len(encoder_decoder_attention_type)]
    # Grow d_model, d_ff, and d_qkv if requested.
    d_m, d_f, d_k, d_v = d_model1, d_ff1, d_attention_key1, d_attention_value1
    if half_before_layer and layer_idx >= half_before_layer:
      d_m, d_f, d_k, d_v = d_model, d_ff, d_attention_key, d_attention_value
    if double_after_layer and layer_idx > double_after_layer:
      d_m, d_f, d_k, d_v = d_model2, d_ff2, d_attention_key2, d_attention_value2
    decoder_block = DecoderBlock(
        d_m, d_f, d_k, d_v, n_heads,
        attention_type=layer_attention_type,
        dropout=dropout,
        ff_activation=ff_activation,
        ff_dropout=ff_dropout,
        ff_use_sru=ff_use_sru,
        ff_chunk_size=ff_chunk_size,
        ff_sparsity=ff_sparsity,
        attention_chunk_size=attention_chunk_size,
        n_attention_layers=n_decoder_attention_layers,
        center_layernorm=center_layernorm,
        use_bfloat16=use_bfloat16,
        mode=mode)
    decoder_blocks.append(decoder_block)
    # Switch stage widths at the configured boundaries.
    if half_before_layer and layer_idx == half_before_layer - 1:
      decoder_blocks.append(tl.ReversibleConcatenatePair())
    if double_after_layer and layer_idx == double_after_layer:
      decoder_blocks.append(tl.ReversibleConcatenatePair())

  dense_loss_layer = tl.SparseDenseWithOptions(
      output_vocab_size,
      d_input=d_model2,
      sparsity_type=loss_sparsity_type,
      sparsity=loss_sparsity,
      d_lowrank=loss_d_lowrank,
      prob_sparse=loss_sparsity_prob,
      use_bfloat16=use_bfloat16,
      mode=mode)

  # Layers to merge encoder and decoder, see below for details.
  if reversible_encoder:
    encdec_layers = [
        tl.ReversibleSelect([0, 1, 4, 2, 3]),  # vec_e1 vec_e2 vec_d mask_e tok_e tok_d
        t2.ConcatWithPadding2(mode=mode),      # vec_ed vec_ed tok_e tok_d
    ]
  else:
    encdec_layers = [
        tl.ReversibleSelect([0, 3, 1, 2]),     # vec_e vec_d mask_e tok_e tok_d
        t2.ConcatWithPadding(mode=mode),       # vec_ed tok_e tok_d
        tl.ReversibleSelect([0, 0]),           # vec_ed vec_ed tok_e tok_d
    ]

  # Assemble and return the model.
  return tl.Serial(
      # Input: encoder_side_tokens, decoder_side_tokens
      # Copy decoder tokens for use in loss.
      tl.Select([0, 0, 0, 1, 1]),              # tok_e tok_e tok_e tok_d tok_d

      # Embed in and out tokens; done together as weights may be shared.
      tl.Parallel(in_encoder, [], [],          # vec_e tok_e tok_e vec_d tok_d
                  [tl.ShiftRight(mode=mode), out_encoder]),

      # Predict mode doesn't work with padding in encoder. Raising an exception
      # in jitted function isn't possible, so the second next best thing is
      # to convert every embedding to NaNs, so the user will not get subtly
      # wrong results, but clearly wrong results.
      (_ConvertToNaNsOnAnyZero() if mode == 'predict' else []),

      tl.Parallel([], [tl.PaddingMask(),
                       tl.Fn('Squeeze',
                             lambda x: jnp.squeeze(x, (1, 2)), n_out=1)]),
      #                                        # vec_e mask_e tok_e vec_d tok_d

      # Encode.
      encoder,                                 # vec_e mask_e tok_e vec_d tok_d

      # Concat encoder and decoder, given encoder mask.
      encdec_layers,

      # Run decoder blocks.
      _ReversibleSerialForget(decoder_blocks, d_model2, n_layers_forget,
                              forget_dense),   # vec_ed1 vec_ed2 tok_e tok_d
      tl.Fn('XYAvg', lambda x, y: (x + y) / 2.0),  # vec_ed tok_e tok_d
      tl.LayerNorm(),                          # vec_ed tok_e tok_d

      # Separate out the encoder part from the concatenated vector.
      tl.Select([0, 1, 2, 2]),                 # vec_ed tok_e tok_d tok_d
      t2.StripFromConcatenateWithPadding(mode=mode),  # vec_d tok_d

      # Map to output vocab.
      dense_loss_layer,                        # vec_d tok_d
  )
def Reformer2(input_vocab_size,
              output_vocab_size=None,
              d_model=512,
              d_ff=2048,
              d_attention_key=None,
              d_attention_value=None,
              n_encoder_layers=6,
              n_decoder_layers=6,
              n_heads=8,
              dropout=0.1,
              max_len=2048,
              encoder_attention_type=tl.SelfAttention,
              encoder_decoder_attention_type=tl.SelfAttention,
              axial_pos_shape='fixed-base',
              d_axial_pos_embs=None,
              ff_activation=tl.Relu,
              ff_use_sru=0,
              ff_chunk_size=0,
              ff_dropout=None,
              ff_sparsity=0,
              n_layers_forget=0,
              mode='train'):
  """Builds a reversible transformer encoder-decoder model.

  The model consumes an input pair (source tokens, target tokens). Encoder and
  encoder-decoder attention classes are configurable; for the attention types
  in the Reformer paper, see ReformerLM.

  Args:
    input_vocab_size: int: vocab size of the source.
    output_vocab_size: int (optional): vocab size of the target. If None, the
      source and target share one vocab and one embedding layer.
    d_model: int: depth of embedding.
    d_ff: int: depth of feed-forward layer.
    d_attention_key: int: per-head key depth; defaults to d_model // n_heads.
    d_attention_value: int: per-head value depth; defaults to
      d_model // n_heads.
    n_encoder_layers: int: number of encoder layers.
    n_decoder_layers: int: number of decoder layers.
    n_heads: int: number of attention heads.
    dropout: float: dropout rate (how much to drop out).
    max_len: int: maximum symbol length for positional encoding.
    encoder_attention_type: class: attention class for encoder blocks, such as
      SelfAttention.
    encoder_decoder_attention_type: class, or a tuple/list of classes cycled
      over the decoder layers (n_decoder_layers must be a multiple of its
      length): attention class for decoder blocks, such as SelfAttention.
    axial_pos_shape: forwarded to PositionalEncoding together with
      d_axial_pos_embs. Documented as a tuple of ints giving the axial
      position-encoding shape, but the default is the string 'fixed-base' —
      presumably a special value PositionalEncoding recognizes; confirm there.
    d_axial_pos_embs: tuple of ints: depth of position embedding for each axis.
      Tuple length must match axial_pos_shape, and values must sum to d_model.
    ff_activation: the non-linearity in feed-forward layer.
    ff_use_sru: int; if > 0, we use this many SRU layers instead of feed-forward.
    ff_chunk_size: int; if > 0, chunk feed-forward into this-sized chunks.
    ff_dropout: float: (optional) separate dropout rate at feed-forward
      nonlinearity. This is called relu_dropout in T2T.
    ff_sparsity: int, if > 0 use sparse feed-forward block with this sparsity.
    n_layers_forget: how often to have a forgetting block between layers.
    mode: str: 'train', 'eval' or 'predict'.

  Returns:
    A layer mapping a (source, target) token pair to log-probabilities over
    the output vocab, with the target tokens passed through for the loss.

  Raises:
    ValueError: if a per-head depth must be defaulted and n_heads does not
      divide d_model.
  """
  # NOTE(review): this module also defines an earlier `Reformer2`; being
  # later, this definition is the one bound at import time — confirm intent.

  def _default_head_depth(depth):
    # Explicit depths win; otherwise split d_model evenly across the heads.
    if depth is not None:
      return depth
    if d_model % n_heads != 0:
      raise ValueError(
          f'n_heads ({n_heads}) must divide d_model ({d_model})')
    return d_model // n_heads

  d_attention_key = _default_head_depth(d_attention_key)
  d_attention_value = _default_head_depth(d_attention_value)

  def _embedding_layers(vocab_size):  # tokens --> vectors
    return [
        tl.Embedding(vocab_size, d_model),
        tl.Dropout(rate=dropout, shared_axes=[-2], mode=mode),
    ]

  share_embeddings = output_vocab_size is None
  source_embedding = _embedding_layers(input_vocab_size)
  if share_embeddings:
    # Reuse the very same layer objects so their weights are shared.
    target_embedding = source_embedding
  else:
    target_embedding = _embedding_layers(output_vocab_size)

  def _positional(pos_mode):
    return PositionalEncoding(pos_mode, dropout, max_len, axial_pos_shape,
                              d_axial_pos_embs)

  # In 'predict' mode the decoder runs one token at a time, but the encoder
  # always sees whole sequences, so it is built in 'eval' mode instead.
  encoder_mode = 'eval' if mode == 'predict' else mode
  in_encoder = source_embedding + [_positional(encoder_mode)]
  out_encoder = target_embedding + [_positional(mode)]
  if share_embeddings:
    output_vocab_size = input_vocab_size

  encoder_blocks = []
  for _ in range(n_encoder_layers):
    encoder_blocks.append(
        EncoderBlock(d_model, d_ff, n_heads, encoder_attention_type,
                     dropout=dropout,
                     ff_activation=ff_activation,
                     ff_dropout=ff_dropout,
                     ff_use_sru=ff_use_sru,
                     ff_chunk_size=ff_chunk_size,
                     ff_sparsity=ff_sparsity,
                     mode=mode))

  # Stack entering the encoder: vec_e mask_e tok_e tok_d tok_d
  encoder = tl.Serial([
      tl.Dup(),  # vec_e1 vec_e2 mask_e tok_e tok_d tok_d
      _ReversibleSerialForget(encoder_blocks, d_model, n_layers_forget),
      tl.Fn('XYAvg', lambda x, y: (x + y) / 2.0),
      tl.Dense(d_model),
      tl.LayerNorm(),
  ])
  if mode == 'predict':
    encoder = tl.Cache(encoder)

  if isinstance(encoder_decoder_attention_type, (tuple, list)):
    assert n_decoder_layers % len(encoder_decoder_attention_type) == 0
    attention_types = encoder_decoder_attention_type
  else:
    attention_types = [encoder_decoder_attention_type]
  n_attention_types = len(attention_types)

  # pylint: disable=g-complex-comprehension
  decoder_blocks = [
      DecoderBlock(d_model, d_ff, d_attention_key, d_attention_value, n_heads,
                   attention_type=attention_types[idx % n_attention_types],
                   dropout=dropout,
                   ff_activation=ff_activation,
                   ff_dropout=ff_dropout,
                   ff_use_sru=ff_use_sru,
                   ff_chunk_size=ff_chunk_size,
                   ff_sparsity=ff_sparsity,
                   mode=mode)
      for idx in range(n_decoder_layers)
  ]
  # pylint: enable=g-complex-comprehension

  stages = [
      # Inputs: encoder-side tokens, decoder-side tokens; the decoder tokens
      # are duplicated so a copy survives for the loss.
      tl.Select([0, 0, 0, 1, 1]),              # tok_e tok_e tok_e tok_d tok_d
      # Embed both streams in one layer, since weights may be shared.
      tl.Parallel(in_encoder, [], [],
                  [tl.ShiftRight(mode=mode), out_encoder]),
      #                                        # vec_e tok_e tok_e vec_d tok_d
      tl.Parallel([], [
          tl.PaddingMask(),
          tl.Fn('Squeeze', lambda x: jnp.squeeze(x, (1, 2)), n_out=1)
      ]),                                      # vec_e mask_e tok_e vec_d tok_d
      encoder,                                 # vec_e mask_e tok_e vec_d tok_d
      tl.Select([3, 0, 1, 2]),                 # vec_d vec_e mask_e tok_e tok_d
      tl.Select([1, 0]),                       # vec_e vec_d mask_e tok_e tok_d
      t2.ConcatWithPadding(mode=mode),         # vec_ed tok_e tok_d
      tl.Dup(),                                # vec_ed1 vec_ed2 tok_e tok_d
      _ReversibleSerialForget(
          decoder_blocks, d_model, n_layers_forget),  # vec_ed1 vec_ed2 tok_e tok_d
      tl.Fn('XYAvg', lambda x, y: (x + y) / 2.0),     # vec_ed tok_e tok_d
      tl.LayerNorm(),                          # vec_ed tok_e tok_d
      # Drop the encoder part of the concatenated vector.
      tl.Select([0, 1, 2, 2]),                 # vec_ed tok_e tok_d tok_d
      t2.StripFromConcatenateWithPadding(mode=mode),  # vec_d tok_d
      tl.Dense(output_vocab_size),             # vec_d tok_d
      tl.LogSoftmax(),                         # vec_d tok_d
  ]
  return tl.Serial(*stages)
def Reformer(input_vocab_size,
             output_vocab_size=None,
             d_model=512,
             d_ff=2048,
             n_encoder_layers=6,
             n_decoder_layers=6,
             n_heads=8,
             dropout=0.1,
             max_len=2048,
             ff_activation=tl.Relu,
             ff_dropout=None,
             mode='train'):
  """Builds a reversible transformer encoder-decoder model.

  The model consumes an input pair (source tokens, target tokens): the first
  element feeds the encoder and the second the decoder. Encoder blocks use
  tl.SelfAttention; for the attention types in the Reformer paper, see
  ReformerLM.

  Args:
    input_vocab_size: int: vocab size of the source.
    output_vocab_size: int (optional): vocab size of the target. If None, the
      target vocab size is taken to be input_vocab_size (the embedding layers
      are still built separately).
    d_model: int: depth of embedding.
    d_ff: int: depth of feed-forward layer.
    n_encoder_layers: int: number of encoder layers.
    n_decoder_layers: int: number of decoder layers.
    n_heads: int: number of attention heads.
    dropout: float: dropout rate (how much to drop out).
    max_len: int: maximum symbol length for positional encoding.
    ff_activation: the non-linearity in feed-forward layer.
    ff_dropout: float: (optional) separate dropout rate at feed-forward
      nonlinearity. This is called relu_dropout in T2T.
    mode: str: 'train', 'eval' or 'predict'.

  Returns:
    A layer mapping a (source, target) token pair to log-probabilities over
    the output vocab, with the target tokens passed through for the loss.
  """
  def _token_encoder(vocab_size, enc_mode):  # tokens --> vectors
    # TODO(kitaev): axial positional encoding is better for very long sequences.
    positions = tl.PositionalEncoding(max_len=max_len, dropout=dropout,
                                      mode=enc_mode)
    embed = tl.Embedding(vocab_size, d_model)
    drop = tl.Dropout(rate=dropout, shared_axes=[-2], mode=enc_mode)
    return [embed, drop, positions]

  # In 'predict' mode the decoder runs one token at a time, but the encoder
  # always sees whole sequences, so it is built in 'eval' mode instead.
  if mode == 'predict':
    in_encoder = _token_encoder(input_vocab_size, 'eval')
  else:
    in_encoder = _token_encoder(input_vocab_size, mode)
  target_vocab_size = (input_vocab_size if output_vocab_size is None
                       else output_vocab_size)
  out_encoder = _token_encoder(target_vocab_size, mode)

  encoder_blocks = []
  for _ in range(n_encoder_layers):
    encoder_blocks.append(
        EncoderBlock(d_model, d_ff, n_heads, tl.SelfAttention, dropout,
                     ff_activation, ff_dropout, mode=mode))

  encoder = tl.Serial([
      in_encoder,
      tl.Dup(),
      tl.ReversibleSerial(encoder_blocks),
      tl.Fn('XYAvg', lambda x, y: (x + y) / 2.0),
      tl.LayerNorm(),
  ])
  if mode == 'predict':
    encoder = tl.Cache(encoder)

  encoder_decoder_blocks = []
  for _ in range(n_decoder_layers):
    encoder_decoder_blocks.append(
        EncoderDecoderBlock(d_model, d_ff, n_heads, dropout, ff_activation,
                            ff_dropout, mode))

  stages = [
      # Inputs: encoder-side tokens, decoder-side tokens; the decoder tokens
      # are duplicated so a copy survives for the loss.
      tl.Select([0, 1, 1]),                    # tok_e tok_d tok_d
      tl.Branch([], [
          tl.PaddingMask(),
          tl.Fn('Squeeze', lambda x: jnp.squeeze(x, (1, 2)), n_out=1)
      ]),                                      # tok_e mask tok_d .....
      encoder,                                 # vec_e mask tok_d .....
      tl.Select([2, 0, 1]),                    # tok_d vec_e mask .....
      tl.ShiftRight(mode=mode),                # tok_d vec_e mask .....
      out_encoder,                             # vec_d vec_e mask .....
      tl.Dup(),                                # vec_d1 vec_d2 vec_e mask .....
      tl.ReversibleSerial(encoder_decoder_blocks),
      tl.Fn('XYAvg', lambda x, y: (x + y) / 2.0),  # vec_d vec_e mask .....
      tl.LayerNorm(),                          # vec_d vec_e mask .....
      # Keep only the decoder activations, then map to the output vocab.
      tl.Select([0], n_in=3),                  # vec_d .....
      tl.Dense(target_vocab_size),             # vec_d .....
      tl.LogSoftmax(),                         # vec_d .....
  ]
  return tl.Serial(*stages)
def _unshard_fn(x):
  """All-gathers `x` over the 'batch' axis and re-joins the shards.

  `groups`, `n_shards` and `_axis_to_shard_heuristic` are free names taken
  from the surrounding scope (not visible here). The gathered array is split
  into `n_shards` pieces along axis 0; each piece must have a leading axis of
  size 1, which is squeezed away. The pieces are then concatenated along the
  axis that `_axis_to_shard_heuristic` picks from a single shard's shape.
  """
  gathered = jax.lax.all_gather(x, 'batch', axis_index_groups=groups)
  shards = [
      jnp.squeeze(piece, axis=0)
      for piece in jnp.split(gathered, n_shards, axis=0)
  ]
  join_axis = _axis_to_shard_heuristic(shards[0].shape)
  return jnp.concatenate(shards, axis=join_axis)