def ActionInjector(mode):
    """Builds the layers that merge encoded actions into the body output.

    Args:
        mode: Mode string forwarded to the trailing `models.MLP`.

    Returns:
        A `tl.Serial` consuming (body output, actions) when `inject_actions`
        is truthy, otherwise an empty list.
    """
    if not inject_actions:
        return []

    # Discrete actions go through an embedding table; any other kind of
    # action goes through a dense projection.
    if is_discrete:
        action_encoder = tl.Embedding(vocab_size, inject_actions_dim)
    else:
        action_encoder = tl.Dense(inject_actions_dim)
    encoders = tl.Parallel(tl.Dense(inject_actions_dim), action_encoder)

    if multiplicative_action_injection:
        tanh_gate = tl.Fn('TanhMulGate', lambda x, a: x * jnp.tanh(a))
        action_injector = tl.Serial(
            tanh_gate,
            tl.LayerNorm(),  # compensate for reduced variance
        )
    else:
        action_injector = tl.Add()

    hidden_widths = (inject_actions_dim,) * inject_actions_n_layers
    mlp = models.MLP(
        layer_widths=hidden_widths,
        out_activation=True,
        flatten=False,
        mode=mode,
    )
    # Input: (body output, actions).
    return tl.Serial(encoders, action_injector, mlp)
def ResidualZero(*layers, shortcut=None):
    """Wraps a series of layers with a ReZero-style residual connection.

    A classical residual connection computes
    `(shortcut) + (output of layers)`. ResidualZero instead computes
    `(shortcut) + alpha * (output of layers)`, where `alpha` is a learnable
    scalar initialized with zero.

    Args:
        *layers: One or more layers, to be applied in series.
        shortcut: If None (the usual case), the Residual layer computes the
            element-wise sum of the stack-top input with the output of the
            layer series. If specified, the `shortcut` layer applies to a
            copy of the inputs and (elementwise) adds its output to the
            output from the main layer series.

    Returns:
        A layer representing a residual connection paired with a layer
        series.
    """
    flat_layers = _ensure_flat(layers)
    if len(flat_layers) == 1:
        body = flat_layers[0]
    else:
        body = tl.Serial(flat_layers)

    # The zero-initialised weight is what gets multiplied with the body's
    # output (the `alpha` of the docstring).
    alpha = tl.Weights(lambda shape, rng: jnp.zeros(shape, dtype=jnp.float32))

    # TODO(jaszczur): perhaps change inner Serial to Branch?
    scaled_body = tl.Serial(body, alpha, tl.Multiply())
    return tl.Serial(
        tl.Branch(shortcut, scaled_body),
        tl.Add(),  # pylint: disable=no-value-for-parameter
    )
def EinsumDense(d_input, d_output, use_bias):
    """Returns a reimplementation of Dense layer, using einsum.

    While this is an equivalent of a Dense layer, it seems to be faster when
    used in decoding if used with bias (see decoding_timing_test.py ).
    This layer can be removed when we understand better the reason for the
    difference in decoding speed.

    Args:
        d_input: Dimensionality of the input tensor.
        d_output: Dimensionality of the output tensor.
        use_bias: Whether to use bias.
    """

    def _apply_kernel(kernel, embeds):
        # Contract the last axis of `embeds` with the `d` axis of `kernel`.
        return jnp.einsum('xd,...d->...x', kernel, embeds)

    kernel_weights = tl.Weights(
        init.GlorotUniformInitializer(), [d_output, d_input])
    stack = [kernel_weights, tl.Fn('EinsumDense', _apply_kernel)]

    if use_bias:
        bias_weights = tl.Weights(
            init.RandomNormalInitializer(1e-6), [d_output])
        stack += [bias_weights, tl.Add()]

    return tl.Serial(stack)
def __init__(self, pre_attention, attention, post_attention):
    """Assembles the forward and reverse layer chains around `attention`.

    Args:
        pre_attention: Layer(s) applied to the top stack element before
            attention.
        attention: Attention layer; must expose `forward_and_backward`.
        post_attention: Layer(s) applied to the top stack element after
            attention.
    """
    # (x1_or_y1, x2) -> (x2, x1_or_y1, x2)
    duplicate_second = tl.Parallel([], tl.Dup())
    bring_copy_to_top = tl.Swap()
    apply_pre_attention = tl.Parallel(pre_attention, [], [])
    self.pre_attention = tl.Serial(
        duplicate_second, bring_copy_to_top, apply_pre_attention)

    assert hasattr(attention, 'forward_and_backward')
    self.attention = ApplyAttentionWrapper(attention)
    self.post_attention = tl.Parallel(post_attention, [], [])

    # The forward and reverse passes share the first three stages and only
    # differ in the final combine step (Add vs. SubtractTop).
    shared_stages = [self.pre_attention, self.attention, self.post_attention]

    forward_layers = shared_stages + [tl.Parallel(tl.Add(), [])]
    super(ReversibleAttentionHalfResidual, self).__init__(forward_layers)

    self.subtract_top = tl.Parallel(tl.SubtractTop(), [])
    self.reverse_layers = shared_stages + [self.subtract_top]
def test_add_div(self):
    """Checks the outputs of a Branch made of Add and DivideBy(0.5)."""
    add_and_divide = tl.Branch(tl.Add(), DivideBy(0.5))
    first = np.array([1, 2, 3])
    second = np.array([10, 20, 30])

    outputs = add_and_divide([first, second])

    expected = [[11, 22, 33], [2, 4, 6]]
    self.assertEqual(as_list(outputs), expected)
def test_run_reversible_same_as_default_extended(self):
    """Runs the reversible trainer, check results are the same as default.

    The same model is trained for three steps twice: first with
    `optimizers.Trainer`, then (after re-initialising with the same rng)
    with `optimizers.ReversibleSerialTrainer`. The weights of selected
    layers after each run are then compared.
    """
    inputs_batch = np.arange(8).reshape((2, 4))
    targets_batch = 2 * inputs_batch
    # Batch layout: (inputs, targets, all-ones array shaped like targets).
    labeled_batch = (inputs_batch, targets_batch, np.ones_like(targets_batch))
    # We want to test rng propagation too, so adding some dropout layers.
    first_layer = tl.Serial(tl.Embedding(9, 4), tl.Dropout(0.5), tl.Dup())
    rev_layers1 = [
        tl.ReversibleHalfResidual(tl.Dense(4), tl.Dropout(0.2)),
        tl.ReversibleSwap(),
        tl.ReversibleHalfResidual(tl.Dropout(0.5), tl.Dense(4)),
        tl.ReversibleSwap()
    ]
    mid_layer = tl.Serial(tl.Add(), tl.Dense(4), tl.Dup())
    rev_layers2 = [
        tl.ReversibleHalfResidual(tl.Dense(4), tl.Dropout(0.3)),
        tl.ReversibleSwap()
    ]
    loss_layer = tl.Serial(tl.Concatenate(), tl.Dense(19), tl.Dropout(0.3),
                           tl.LogSoftmax(), tl.CrossEntropyLoss())
    # The full model is the plain concatenation of all the pieces above.
    model = tl.Serial([first_layer] + rev_layers1 + [mid_layer] +
                      rev_layers2 + [loss_layer])
    rng_init = fastmath.random.get_prng(12)
    model.init(labeled_batch, rng=rng_init)
    optimizer_fn = optimizers.Adam  # to test slots

    # Make 3 steps with the original trainer.
    optimizer = optimizer_fn()
    optimizer.tree_init(model.weights)
    trainer = optimizers.Trainer(model, optimizer)
    # Fixed per-step rngs, reused below for the reversible trainer.
    rng_step1 = fastmath.random.get_prng(7)
    rng_step2 = fastmath.random.get_prng(8)
    rng_step3 = fastmath.random.get_prng(9)
    trainer.one_step(labeled_batch, rng_step1)
    trainer.one_step(labeled_batch, rng_step2, learning_rate=0.02)
    trainer.one_step(labeled_batch, rng_step3, learning_rate=0.03)
    # Snapshot the reference weights before the model is re-initialised.
    first_layer_weights1 = first_layer.weights
    rev_layer12_weights1 = rev_layers1[2].weights
    mid_layer_weights1 = mid_layer.weights
    rev_layer20_weights1 = rev_layers2[0].weights
    loss_layer_weights1 = loss_layer.weights

    # Now make 3 steps with reversible trainer.
    # Re-initialise with the same rng as the first run (presumably so both
    # runs start from identical weights -- confirm against `init`).
    model.init(labeled_batch, rng=rng_init)
    # The trainer receives (standard layers, reversible layers) pairs
    # followed by the loss layer and the optimizer constructor.
    trainer = optimizers.ReversibleSerialTrainer(
        [(first_layer.sublayers, rev_layers1),
         (mid_layer.sublayers, rev_layers2)],
        loss_layer, optimizer_fn)
    trainer.one_step(labeled_batch, rng_step1)
    trainer.one_step(labeled_batch, rng_step2, learning_rate=0.02)
    trainer.one_step(labeled_batch, rng_step3, learning_rate=0.03)

    # Check that weights end up the same.
    self._assert_all_equal(loss_layer_weights1, loss_layer.weights)
    self._assert_all_equal(rev_layer20_weights1, rev_layers2[0].weights)
    self._assert_all_equal(mid_layer_weights1, mid_layer.weights)
    self._assert_all_equal(rev_layer12_weights1, rev_layers1[2].weights)
    self._assert_all_equal(first_layer_weights1, first_layer.weights)
def BERTPretrainingLoss():
    """Returns a layer summing two weighted category cross-entropy losses.

    Each loss selects three of the six inputs (indices 0, 2, 3 for the
    first and 1, 4, 5 for the second); the two results are added. The
    `nsp`/`mlm` names presumably refer to BERT's next-sentence-prediction
    and masked-language-model objectives.
    """

    def _weighted_loss_over(indices):
        # Pick this loss' inputs out of the six on the stack, then score
        # them.
        return [
            tl.Select(indices, n_in=6),
            tl.WeightedCategoryCrossEntropy(),
        ]

    nsp_loss = _weighted_loss_over([0, 2, 3])
    mlm_loss = _weighted_loss_over([1, 4, 5])
    both_losses = tl.Branch(nsp_loss, mlm_loss)
    return tl.Serial(both_losses, tl.Add())
def __init__(self, residual_layers):
    """Assembles the forward and reverse chains around `residual_layers`.

    Args:
        residual_layers: Layer(s) applied to a copy of the second stack
            element to produce the residual.
    """
    residual_branch = tl.Serial(
        # (x1_or_y1, x2) -> (x2, x1_or_y1, x2)
        tl.Parallel([], tl.Dup()),
        tl.Swap(),
        tl.Parallel(residual_layers, [], []),
    )
    self.compute_residual = residual_branch

    # Forward pass: add the residual onto the element below it.
    add_top = tl.Parallel(tl.Add(), [])
    super(ReversibleHalfResidual, self).__init__([residual_branch, add_top])

    # Reverse pass: same residual computation, combined with SubtractTop.
    self.subtract_top = tl.Parallel(tl.SubtractTop(), [])
    self.reverse_layers = [residual_branch, self.subtract_top]
def MultiplicativeSparseDense(sparsity, d_input, d_output=None, use_bias=True,
                              use_bfloat16=False):
    """Returns a replacement of Dense layer which uses less parameters.

    The layer uses number of modules equal to `sparsity`. It multiplies each
    dimension of the input tensor by a scalar specific to each dimension and
    each module separately; then it applies Dense(d_output/sparsity) to each
    module. Compared to standard dense layer, MultiplicativeSparseDense uses
    less parameters while still being able to express many interesting
    functions (for example a permutation).

    Args:
        sparsity: The sparsity of the layer; the output vector is divided
            into this number of modules.
        d_input: Dimensionality of input tensor.
        d_output: Dimensionality of output tensor; by default equal to
            d_input.
        use_bias: Whether to use bias.
        use_bfloat16: Whether to use bfloat16 for weights.

    Returns:
        A `tl.Serial` layer combining the weights, the fused einsum, the
        axis merge and (optionally) the bias addition.

    Raises:
        AssertionError: If `d_output` is not divisible by `sparsity`.
    """
    # Honour the documented default; without this, the divisibility check
    # below fails with a TypeError when `d_output` is left as None.
    if d_output is None:
        d_output = d_input

    assert d_output % sparsity == 0
    d_module = d_output // sparsity

    layers = [
        # Weight below is used for per-head preprocessing of an embedding.
        tl.Weights(init.RandomNormalInitializer(stddev=0.5),
                   shape=[sparsity, d_input], use_bfloat16=use_bfloat16),
        # Weight below is dense kernel, shared across heads.
        tl.Weights(init.GlorotUniformInitializer(), [d_input, d_module],
                   use_bfloat16=use_bfloat16),
        # To save memory the per-head preprocessing and multiplying by the
        # kernel is done in the same einsum.
        tl.Fn(
            'AttentionEinsum',
            (lambda kernel, multiplier, embeds:  # pylint: disable=g-long-lambda
             jnp.einsum('dx,hd,...d->...hx', kernel, multiplier, embeds))),
        MergeLastTwoAxes(),
    ]
    if use_bias:
        layers.extend([
            # Weight below is bias after dense, per-head.
            tl.Weights(init.RandomNormalInitializer(1e-6), [d_output],
                       use_bfloat16=use_bfloat16),
            tl.Add(),
        ])
    return tl.Serial(layers)
def loss(id_to_mask=None, has_weights=False):
    """Cross-entropy loss as scalar compatible with Trax masking.

    Args:
        id_to_mask: Forwarded to both `CrossEntropyLoss` and `L2Loss`.
        has_weights: Forwarded to both `CrossEntropyLoss` and `L2Loss`.

    Returns:
        A `layers.Serial` whose final stage multiplies the summed loss by
        0.0.
    """
    # Swap from (pred-obs, pred-reward, target-obs, target-reward)
    # to (pred-obs, target-obs, pred-reward, target-reward).
    regroup = layers.Parallel([], layers.Swap())

    # Cross-entropy loss for obs, L2 loss on reward.
    obs_loss = layers.CrossEntropyLoss(id_to_mask, has_weights)
    reward_loss = layers.L2Loss(id_to_mask, has_weights)
    per_stream_losses = layers.Parallel(obs_loss, reward_loss)

    return layers.Serial(
        regroup,
        per_stream_losses,
        layers.Add(),  # Add both losses.
        layers.Fn(lambda x: x * 0.0),  # Zero out in this test.
    )
def loss(mask_id=None, has_weights=False):
    """Cross-entropy loss as scalar compatible with Trax masking.

    Args:
        mask_id: Forwarded to both `CrossEntropyLossScalar` and
            `L2LossScalar`.
        has_weights: Forwarded to both `CrossEntropyLossScalar` and
            `L2LossScalar`.

    Returns:
        A `layers.Serial` whose final stage multiplies the summed loss by
        the constant 0.0.
    """
    # Swap from (pred-obs, pred-reward, target-obs, target-reward)
    # to (pred-obs, target-obs, pred-reward, target-reward).
    regroup = layers.Parallel([], layers.Swap())

    # Cross-entropy loss for obs, L2 loss on reward.
    obs_loss = layers.CrossEntropyLossScalar(mask_id, has_weights)
    reward_loss = layers.L2LossScalar(mask_id, has_weights)
    per_stream_losses = layers.Parallel(obs_loss, reward_loss)

    return layers.Serial(
        regroup,
        per_stream_losses,
        layers.Add(),  # Add both losses.
        layers.MulConstant(constant=0.0),  # Zero out in this test.
    )
def loss():
    """Cross-entropy loss as scalar compatible with Trax masking.

    Returns:
        A `layers.Serial` that regroups predictions and targets, attaches
        all-ones arrays after each target, computes the two losses, sums
        them and finally multiplies the sum by 0.0.
    """
    ones = layers.Fn(lambda x: math.numpy.ones_like(x))  # pylint: disable=unnecessary-lambda

    pipeline = [
        # Swap from (pred-obs, pred-reward, target-obs, target-reward)
        # to (pred-obs, target-obs, pred-reward, target-reward).
        layers.Parallel([], layers.Swap()),
        # Duplicate target-obs and target-reward and make 1 to add weights.
        layers.Parallel([], layers.Branch([], ones)),
        layers.Parallel([], [], [], [], layers.Branch([], ones)),
        # Cross-entropy loss for obs, L2 loss on reward.
        layers.Parallel(layers.CrossEntropyLoss(), layers.L2Loss()),
        # Add both losses.
        layers.Add(),
        # Zero out in this test.
        layers.Fn(lambda x: x * 0.0),
    ]
    return layers.Serial(*pipeline)
def __init__(self, residual_layers):
    """Assembles the forward and reverse chains around `residual_layers`.

    Args:
        residual_layers: Layer(s) applied to a copy of the second stack
            element to produce the residual; any outputs beyond the first
            two of `compute_residual` are passed through untouched.
    """
    residual_branch = tl.Serial(
        # x1_or_y1, x2, ...
        tl.Select([1, 0, 1]),  # x2, x1_or_y1, x2, ...
        tl.Parallel([], [], residual_layers),  # x2, x1_or_y1, residual, ...
        tl.Select([2, 1, 0]),  # residual, x1_or_y1, x2, ...
    )
    self.compute_residual = residual_branch

    # Everything below the top two outputs is carried along unchanged.
    self.n_preserve = residual_branch.n_out - 2
    passthroughs = [[] for _ in range(self.n_preserve)]

    combine_forward = tl.Parallel(tl.Add(), *passthroughs)
    super(ReversibleHalfResidual, self).__init__(
        [residual_branch, combine_forward])

    self.subtract_top = tl.Parallel(tl.SubtractTop(), *passthroughs)
    self.reverse_layers = [residual_branch, self.subtract_top]
def ActionInjector(mode):
    """Builds the layers that add dense-encoded actions to the body output.

    Args:
        mode: Mode string forwarded to the trailing `models.PureMLP`.

    Returns:
        A `tl.Serial` consuming (body output, actions) when `inject_actions`
        is truthy, otherwise an empty list.
    """
    if not inject_actions:
        return []

    # Project both inputs to the same width so they can be summed.
    body_encoder = tl.Dense(inject_actions_dim)
    action_encoder = tl.Dense(inject_actions_dim)

    hidden_widths = (inject_actions_dim,) * inject_actions_n_layers
    # Input: (body output, actions).
    return tl.Serial(
        tl.Parallel(body_encoder, action_encoder),
        tl.Add(),
        models.PureMLP(
            layer_widths=hidden_widths,
            out_activation=True,
            flatten=False,
            mode=mode,
        ),
    )
def MultiplicativeModularSparseDense(sparsity, d_feature):
    """Returns a replacement of Dense layer which uses less parameters.

    The layer uses number of modules equal to `sparsity`. It is a combination
    of multiplicative dense and locally connected dense layers.

    Args:
        sparsity: The sparsity of the layer; the output vector is divided
            into this number of modules.
        d_feature: Dimensionality of input and output tensor.
    """
    assert d_feature % sparsity == 0
    d_module = d_feature // sparsity

    def _fused_einsum(kmod, kmult, multiplier, embeds):
        # To save memory the per-head preprocessing and multiplying by
        # kernels is done in a single einsum.
        return jnp.einsum(
            'hxo,dx,hd,...d->...ho', kmod, kmult, multiplier, embeds)

    modular_kernel_init = functools.partial(
        init.GlorotUniformInitializer(), nonreceptive_dims=[0])

    sublayers = [
        # Weight below is used for per-head preprocessing of an embedding.
        tl.Weights(init.RandomNormalInitializer(stddev=0.5),
                   shape=[sparsity, d_feature]),
        # Weight below is a kernel of multiplicative dense, shared across
        # heads.
        tl.Weights(init.GlorotUniformInitializer(), [d_feature, d_module]),
        # Weight below is a kernel of modular dense.
        tl.Weights(modular_kernel_init, [sparsity, d_module, d_module]),
        tl.Fn('SparseDenseEinsum', _fused_einsum),
        MergeLastTwoAxes(),
        # Weight below is bias after dense, per-head.
        tl.Weights(init.RandomNormalInitializer(1e-6), [d_feature]),
        tl.Add(),
    ]
    return tl.Serial(*sublayers)
def ActionInjector(mode):
    """Builds the layers that add encoded actions to the body output.

    Args:
        mode: Mode string forwarded to the trailing `models.PureMLP`.

    Returns:
        A `tl.Serial` consuming (body output, actions) when `inject_actions`
        is truthy, otherwise an empty list.
    """
    if not inject_actions:
        return []

    body_encoder = tl.Dense(inject_actions_dim)
    # Discrete actions go through an embedding; otherwise a dense
    # projection.
    if is_discrete:
        action_encoder = tl.Embedding(
            inject_actions_dim, vocab_size=vocab_size)
    else:
        action_encoder = tl.Dense(inject_actions_dim)
    encode_layer = tl.Parallel(body_encoder, action_encoder)

    hidden_widths = (inject_actions_dim,) * inject_actions_n_layers
    mlp = models.PureMLP(
        layer_widths=hidden_widths,
        out_activation=True,
        flatten=False,
        mode=mode,
    )
    # Input: (body output, actions).
    return tl.Serial(encode_layer, tl.Add(), mlp)
def BERT(d_model=768,
         vocab_size=30522,
         max_len=512,
         type_vocab_size=2,
         n_heads=12,
         d_ff=3072,
         n_layers=12,
         head=None,
         init_checkpoint=None,
         mode='eval',
         ):
    """BERT (default hparams are for bert-base-uncased).

    Args:
        d_model: Width of the embeddings and of each encoder layer output.
        vocab_size: Size argument for the word embedding.
        max_len: Maximum length passed to the positional encoding.
        type_vocab_size: Size argument for the type embedding.
        n_heads: Number of attention heads; also sets the per-head width
            `d_model // n_heads`.
        d_ff: Width of the hidden feed-forward layer.
        n_layers: Number of encoder layers.
        head: Optional zero-argument callable; its result is appended after
            the model.
        init_checkpoint: Checkpoint handed to `PretrainedBERT`; ignored
            (replaced by None) unless `mode == 'train'`.
        mode: Mode string forwarded to the positional encoding and the
            attention layers.

    Returns:
        The `PretrainedBERT` model, wrapped in a `tl.Serial` with `head()`
        when a head is given.
    """
    layer_norm_eps = 1e-12
    d_head = d_model // n_heads

    def _layer_norm():
        return tl.LayerNorm(epsilon=layer_norm_eps)

    def _encoder_layer():
        # One attention sub-block plus one feed-forward sub-block, each
        # wrapped in a residual connection and followed by a layer norm.
        attn = tl.SelfAttention(n_heads=n_heads, d_qk=d_head, d_v=d_head,
                                bias=True, masked=True, mode=mode)
        feed_forward = [
            tl.Dense(d_ff),
            tl.Gelu(),
            tl.Dense(d_model),
        ]
        return [
            tl.Select([0, 1, 1]),  # Save a copy of the mask
            tl.Residual(attn, AddBias()),  # pylint: disable=no-value-for-parameter
            _layer_norm(),
            tl.Residual(*feed_forward),
            _layer_norm(),
        ]

    word_embeddings = tl.Embedding(d_model, vocab_size)
    type_embeddings = tl.Embedding(d_model, type_vocab_size)
    position_embeddings = tl.PositionalEncoding(max_len, mode=mode)
    mask_branch = [
        tl.PaddingMask(),
        tl.Fn('Squeeze', lambda x: np.squeeze(x, (1, 2)), n_out=1),
    ]
    embeddings = [
        tl.Select([0, 1, 0], n_in=3),  # Drops 'idx' input.
        tl.Parallel(word_embeddings, type_embeddings, mask_branch),
        tl.Add(),
        position_embeddings,
        _layer_norm(),
    ]

    encoder = []
    for _ in range(n_layers):
        encoder.extend(_encoder_layer())
    encoder.append(tl.Select([0], n_in=2))  # Drop the mask

    # The pooler emits the position-0 slice alongside the full sequence and
    # then transforms the slice.
    pooler = [
        tl.Fn('', lambda x: (x[:, 0, :], x), n_out=2),
        tl.Dense(d_model),
        tl.Tanh(),
    ]

    if mode != 'train':
        init_checkpoint = None
    bert = PretrainedBERT(
        embeddings + encoder + pooler, init_checkpoint=init_checkpoint)

    if head is None:
        return bert
    return tl.Serial(bert, head())
def test_default_name(self):
    """The string form of a Branch layer contains the word 'Branch'."""
    branch = tl.Branch(tl.Add(), DivideBy(0.5))
    description = str(branch)
    self.assertIn('Branch', description)
def test_printing_sublayers(self):
    """The string form of a Branch lists each sublayer on its own line."""
    branch = tl.Branch(tl.Add(), tl.Add())
    actual = str(branch)
    expected_result = 'Branch_in2_out2[\n Add_in2\n Add_in2\n]'
    self.assertEqual(expected_result, actual)
def LatentTransformer(input_vocab_size,
                      output_vocab_size=None,
                      d_model=512,
                      d_ff=2048,
                      n_encoder_layers=6,
                      n_decoder_layers=6,
                      n_heads=8,
                      dropout=0.1,
                      dropout_shared_axes=None,
                      max_len=2048,
                      mode='train',
                      ff_activation=tl.Relu,
                      axial_pos_shape=None,
                      d_axial_pos_embs=None):
    """Returns a Transformer-based model built from encode/decode stages.

    The returned `tl.Serial` runs two inputs through `compress_seq`,
    combines the two results with `latent_transition`, feeds that result to
    `pred_valid`, and applies `decode_seq` three times.

    NOTE(review): per the inline stack comments the model takes five inputs
    (tok_s, tok_a, tok_s1, r, v) and leaves nine values on the stack
    (pred_s1, tok_s1, pred_s, tok_s, pred_a, tok_a, pred_v, v, r). An
    earlier version of this docstring described a (target, source) input
    pair -- confirm which contract callers rely on.

    Args:
        input_vocab_size: int: vocab size of the source.
        output_vocab_size: int (optional): vocab size of the target. If
            None, the source and target are assumed to have the same vocab.
        d_model: int: depth of embedding
        d_ff: int: depth of feed-forward layer
        n_encoder_layers: int: number of encoder layers
        n_decoder_layers: int: number of decoder layers
        n_heads: int: number of attention heads
        dropout: float: dropout rate (how much to drop out)
        dropout_shared_axes: axes on which to share dropout mask
        max_len: int: maximum symbol length for positional encoding
        mode: str: 'train' or 'eval'; the value 'predict' additionally
            wraps the encoder in `tl.Cache`.
        ff_activation: the non-linearity in feed-forward layer
        axial_pos_shape: tuple of ints: input shape to use for the axial
            position encoding. If unset, axial position encoding is
            disabled.
        d_axial_pos_embs: tuple of ints: depth of position embedding for
            each axis. Tuple length must match axial_pos_shape, and values
            must sum to d_model.

    Returns:
        A `tl.Serial` layer; see the stack comments on the final `return`
        for the order of its outputs.
    """
    # Embedding + positional-encoding stacks for the encoder input and the
    # decoder input; the helper also returns the output vocab size to use.
    in_encoder, out_encoder, output_vocab_size = (
        ct.EmbeddingAndPositionalEncodings(input_vocab_size,
                                           d_model,
                                           mode,
                                           dropout,
                                           dropout_shared_axes,
                                           max_len,
                                           output_vocab_size=output_vocab_size,
                                           axial_pos_shape=axial_pos_shape,
                                           d_axial_pos_embs=d_axial_pos_embs))

    encoder_blocks = [
        _EncoderBlock(d_model, d_ff, n_heads, dropout, dropout_shared_axes,
                      mode, ff_activation)
        for i in range(n_encoder_layers)
    ]

    encoder = tl.Serial(in_encoder, encoder_blocks, tl.LayerNorm())
    if mode == 'predict':
        encoder = tl.Cache(encoder)

    decoder_blocks = [
        _DecoderBlock(d_model, d_ff, n_heads, dropout, dropout_shared_axes,
                      mode, ff_activation)
        for i in range(n_decoder_layers)
    ]

    # Turns one token sequence into a single vector (see stack comments).
    compress_seq = tl.Serial(
        # input:                          #  tok
        tl.Branch([], tl.PaddingMask()),  #  tok mask
        encoder,                          #  vec mask
        PickFirst(),                      #  vec_f mask
        tl.Select([0], n_in=2))           #  vec_f

    # Two inputs are each passed through Dense+Relu, summed, and refined by
    # a residual feed-forward block.
    latent_transition = tl.Serial(
        tl.Parallel([tl.Dense(d_model), tl.Relu()],
                    [tl.Dense(d_model), tl.Relu()]),
        tl.Add(),
        tl.Residual(
            tl.LayerNorm(),
            tl.Dense(d_model),
            tl.Relu(),
            tl.Dropout(rate=dropout, mode=mode),
            tl.Dense(d_model),
        ))

    pred_valid = tl.Serial(tl.Dense(2), Squeeze(1))

    embed_tgt = tl.Serial(
        # Input                           #  tok_d
        DropLast(mode=mode),              #  stok_d
        out_encoder,                      #  svec_d
    )

    decode_seq = tl.Serial(
        # Input:                                  #  vec_e  tok_d
        tl.Select([1, 0, 1]),                     #  tok_d  vec_e tok_d
        tl.Parallel(embed_tgt, [], DropFirst()),  # svec_d  vec_e tok_d'
        ConcatDeEntoEnDe(),                       # vec_ed tok_d'
        # Decoder blocks with causal attention
        decoder_blocks,                           # vec_ed tok_d'
        tl.LayerNorm(),                           # vec_ed tok_d'
        DropFirst(),                              #  vec_d tok_d'
        # Map to output vocab.
        tl.Dense(output_vocab_size),              # pred_d tok_d'
    )

    # compress_seq: n_in 1 n_out 1: add mask, encode, pick last hidden
    # NOTE(review): the code above uses `PickFirst()`, not a "last" pick --
    # confirm which one the comment on the previous line should say.
    # latent_transition: n_in 2 n_out 1: s, a -> s_1
    # pred_valid: n_in 1 n_out 1: s_1 -> pred_v
    # decode_seq: n_in 2 n_out 2: copy target, shift right, decode, output
    # The same `compress_seq` / `decode_seq` objects are reused at several
    # positions below (presumably sharing weights -- TODO confirm).

    return tl.Serial(
        #       0      1      2      3      4     5      6 7 8
        # Input:                    # tok_s  tok_a  tok_s1 r v
        tl.Select([0, 1, 2, 0, 1, 3, 4]
                 ),  # tok_s  tok_a  tok_s1 tok_s tok_a r v

        # Encode.
        tl.Parallel(
            compress_seq,
            compress_seq),  # vec_s  vec_a  tok_s1 tok_s tok_a r v
        tl.Branch(latent_transition, [], tl.Select(
            [1], n_in=2)),  # vec_s1 vec_s  vec_a  tok_s1 tok_s tok_a r v
        tl.Branch(pred_valid,
                  []),  # pred_v vec_s1 vec_s vec_a tok_s1 tok_s tok_a r v
        # Decode.
        tl.Select([1, 4, 2, 5, 3, 6, 0, 8, 7]),
        # vec_s1 tok_s1 vec_s tok_s vec_a tok_a pred_v v r
        tl.Parallel(decode_seq, decode_seq, decode_seq
                   ),  # pred_s1 tok_s1 pred_s tok_s pred_a tok_a pred_v v r
    )
def create_hourglass_valley(
        rest_shorten_factors, rest_n_funnel_blocks,  # pylint: disable = invalid-name
        current_total_pooling):
    """Recursively builds one downsample / process / upsample stage.

    Each stage duplicates its input, shifts and shortens one copy by the
    first remaining shorten factor, processes it (recursing when more
    factors remain), upsamples, normalizes, and adds the result to the
    other copy.

    Args:
        rest_shorten_factors: Non-empty sequence of shorten factors; the
            first applies to this stage, the remainder to nested stages.
        rest_n_funnel_blocks: Sequence of the same length giving the number
            of decoder blocks for each stage.
        current_total_pooling: Total pooling accumulated by the enclosing
            stages; it is multiplied by this stage's factor for the blocks
            built here and for the nested stages.

    Returns:
        A list of layers implementing this stage and all nested stages.

    Raises:
        AssertionError: If `rest_shorten_factors` is empty or the two
            sequences differ in length.
    """
    assert rest_shorten_factors
    assert len(rest_shorten_factors) == len(rest_n_funnel_blocks)

    current_sf = rest_shorten_factors[0]
    current_n_layers = rest_n_funnel_blocks[0]
    inner_total_pooling = current_total_pooling * current_sf

    shortening_layer = downsampling_fn(
        current_sf,
        d_model,
        is_upsampling=False,
        d_ff=d_ff,
        n_heads=n_heads,
        dropout=dropout,
        dropout_shared_axes=dropout_shared_axes,
        mode=mode,
        ff_activation=ff_activation,
        context_bias_layer=context_bias_layer,
        location_bias_layer=location_bias_layer,
        total_pooling=current_total_pooling,
        resampling_fn=attention_downsampling_fn)

    upsampling_layer = upsampling_fn(
        current_sf,
        d_model=d_model,
        is_upsampling=True,
        d_ff=d_ff,
        n_heads=n_heads,
        dropout=dropout,
        dropout_shared_axes=dropout_shared_axes,
        mode=mode,
        ff_activation=ff_activation,
        context_bias_layer=context_bias_layer,
        location_bias_layer=location_bias_layer,
        total_pooling=current_total_pooling,
        resampling_fn=attention_upsampling_fn)

    if len(rest_shorten_factors) > 1:  # we need to go deeper again
        pre_stage_blocks = create_decoder_blocks(
            current_n_layers, inner_total_pooling, middle_attn_type)
        post_stage_blocks = create_decoder_blocks(
            current_n_layers, inner_total_pooling, middle_attn_type)
        middle = [
            pre_stage_blocks,
            *create_hourglass_valley(rest_shorten_factors[1:],
                                     rest_n_funnel_blocks[1:],
                                     inner_total_pooling),
            post_stage_blocks,
        ]
    else:
        middle = [
            create_decoder_blocks(current_n_layers, inner_total_pooling,
                                  middle_attn_type),
        ]

    # Both the nested and the innermost stage share this wrapper, so the
    # shift always receives `mode` (the innermost stage previously built
    # ShiftRight without it, unlike the nested stages).
    return [
        tl.Dup(),
        tl.ShiftRight(current_sf - 1, mode=mode),
        shortening_layer,
        *middle,
        upsampling_layer,
        tl.LayerNorm(),
        tl.Add(),
    ]