def RandomLayer(layer_a, layer_b, prob_a):
  """Runs `layer_a` with probability `prob_a`, otherwise runs `layer_b`."""
  condition = tl.Serial(
      tl.RandomUniform(),
      tl.Fn('SmallerThan', lambda x: x < prob_a)
  )
  return tl.Cond(condition, layer_a, layer_b)
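# Usage sketch (not from the original source): wraps two branches with the
# same signature so `tl.Cond` can pick between them. Here a Relu is applied
# with probability 0.8 and the input is passed through unchanged otherwise;
# `tl` is assumed to be `trax.layers`, as in the rest of this file.
def _example_random_relu():  # hypothetical helper, for illustration only
  return RandomLayer(tl.Relu(), tl.Serial(), prob_a=0.8)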
def test_simple_range(self):
  layer = tl.RandomUniform(1., 2., shape=(1000,))
  layer.init(())
  y = layer(())
  self.assertEqual(y.shape, (1000,))
  self.assertBetween(min(y.tolist()), 1., 2.)
  self.assertBetween(max(y.tolist()), 1., 2.)
  self.assertBetween(1.5, min(y.tolist()), max(y.tolist()))
def ConditionedBlock(current_layer_num):
  return tl.Serial(
      # stack: embedding
      tl.RandomUniform(0., 1, sync=True),
      # stack: random_uniform, embedding
      tl.Cond(
          # if random_uniform > skip_fraction
          LargerThan(skip_fraction[current_layer_num] if mode == 'train'
                     else 0.0),
          # then: run block
          tl.Serial(transformer._DecoderBlock(  # pylint: disable=g-complex-comprehension,protected-access
              d_model, d_ff, n_heads, dropout, [], mode, ff_activation)),
          # else: run noop
          tl.Serial()
      )
      # stack: embedding
  )
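# `LargerThan` is called above but not defined in this excerpt. A minimal
# sketch consistent with its usage (a layer that tests whether the value on
# top of the stack exceeds `val`); the actual definition may differ slightly.
def LargerThan(val):
  """Checks whether the input is larger than `val`."""
  return tl.Fn('LargerThan', lambda x: x > val)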
def test_shape(self):
  layer = tl.RandomUniform(shape=(5, 10, 3))
  layer.init(())
  y = layer(())
  self.assertEqual(y.shape, (5, 10, 3))
def test_simple(self):
  layer = tl.RandomUniform()
  layer.init(())
  y = layer(())
  self.assertEqual(y.shape, ())
  self.assertBetween(y, 0.0, 1.0)
def LayerDropSkippingTransformerLM(vocab_size,
                                   d_model=512,
                                   d_ff=2048,
                                   n_layers=6,
                                   n_heads=8,
                                   dropout=0.1,
                                   max_len=2048,
                                   mode='train',
                                   ff_activation=tl.Relu,
                                   skip_fraction=0.4):
  """Returns a Skipping Transformer language model.

  The input to the model is a tensor of tokens. (This model uses only the
  decoder part of the overall Transformer.)

  Args:
    vocab_size: int: vocab size
    d_model: int: depth of embedding
    d_ff: int: depth of feed-forward layer
    n_layers: int: number of encoder/decoder layers
    n_heads: int: number of attention heads
    dropout: float: dropout rate (how much to drop out)
    max_len: int: maximum symbol length for positional encoding
    mode: str: 'train', 'eval' or 'predict', predict mode is for fast inference
    ff_activation: the non-linearity in feed-forward layer
    skip_fraction: fraction of times to skip some layers

  Returns:
    A Transformer language model as a layer that maps from a tensor of tokens
    to activations over a vocab set.
  """
  embedder = [
      tl.Embedding(vocab_size, d_model),
      tl.Dropout(rate=dropout, mode=mode),
      tl.PositionalEncoding(max_len=max_len, mode=mode),
  ]

  def ConditionedBlock(current_layer_num):
    return tl.Serial(
        # stack: embedding, n_layers_to_keep
        tl.Select([1, 0, 1]),  # n_layers_to_keep, embedding, n_layers_to_keep
        tl.Cond(
            # if n_layers_to_keep > current_layer_num
            LargerThan(float(current_layer_num)),
            # then: run block
            tl.Serial(transformer._DecoderBlock(  # pylint: disable=g-complex-comprehension,protected-access
                d_model, d_ff, n_heads, dropout, [], mode, ff_activation)),
            # else: run noop
            tl.Serial()
        )
        # stack: embedding, n_layers_to_keep
    )

  if mode == 'train':
    minimum_layers = 0.0
    maximum_layers = float(n_layers) / skip_fraction
  else:
    minimum_layers = maximum_layers = float(n_layers)

  return tl.Serial(
      tl.ShiftRight(mode=mode),
      embedder,
      # stack: embedding
      tl.RandomUniform(minimum_layers, maximum_layers, sync=True),
      # stack: n_layers_to_keep, embedding
      tl.Swap(),
      # stack: embedding, n_layers_to_keep
      [ConditionedBlock(i) for i in range(n_layers)],
      # stack: embedding, n_layers_to_keep
      tl.Select([0], n_in=2),  # stack: embedding
      tl.LayerNorm(),
      tl.Dense(vocab_size),
      tl.LogSoftmax(),
  )
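# Usage sketch (assumption, not from the original source): builds a tiny
# configuration and runs one forward pass on dummy token ids. The dimensions,
# the numpy input, and the `shapes.signature` call are illustrative choices.
def _example_skipping_lm_forward():  # hypothetical helper, for illustration
  import numpy as np
  from trax import shapes
  model = LayerDropSkippingTransformerLM(
      vocab_size=128, d_model=32, d_ff=64, n_layers=2, n_heads=2,
      max_len=32, mode='train', skip_fraction=0.4)
  tokens = np.ones((1, 32), dtype=np.int32)  # (batch, sequence_length)
  model.init(shapes.signature(tokens))
  return model(tokens)  # (1, 32, 128) log-probabilities over the vocab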
def EveryOtherLayerDropTransformerLM(vocab_size,
                                     d_model=512,
                                     d_ff=2048,
                                     n_layers=6,
                                     n_heads=8,
                                     dropout=0.1,
                                     max_len=2048,
                                     mode='train',
                                     ff_activation=tl.Relu,
                                     skip_mode='even',
                                     skip_fraction=0.5,
                                     eval_skip_fraction=0.0):
  """Returns an "EveryOther" LayerDrop Transformer language model.

  During each training step it either runs all layers, or skips a subset of
  layers. This subset is the same every time, and it is specified by
  "skip_mode".
  The input to the model is a tensor of tokens. (This model uses only the
  decoder part of the overall Transformer.)

  Args:
    vocab_size: int: vocab size
    d_model: int: depth of embedding
    d_ff: int: depth of feed-forward layer
    n_layers: int: number of encoder/decoder layers
    n_heads: int: number of attention heads
    dropout: float: dropout rate (how much to drop out)
    max_len: int: maximum symbol length for positional encoding
    mode: str: 'train', 'eval' or 'predict', predict mode is for fast inference
    ff_activation: the non-linearity in feed-forward layer
    skip_mode: which layers to skip when skipping: even/odd/1half/2half.
    skip_fraction: fraction of times to skip layers
    eval_skip_fraction: fraction of times to skip layers during eval

  Returns:
    A Transformer language model as a layer that maps from a tensor of tokens
    to activations over a vocab set.
  """
  embedder = [
      tl.Embedding(vocab_size, d_model),
      tl.Dropout(rate=dropout, mode=mode),
      tl.PositionalEncoding(max_len=max_len, mode=mode),
  ]

  if mode != 'train':
    skip_fraction = eval_skip_fraction

  skip_mode_funs = {  # which layers should be skipped?
      'even': (lambda num: num % 2 == 0),  # 0th layer is even
      'odd': (lambda num: num % 2 == 1),
      '1half': (lambda num: num < (n_layers / 2)),
      '2half': (lambda num: num >= (n_layers / 2)),
  }

  skip_mode_fun = skip_mode_funs[skip_mode]

  @assert_shape('...sd,->...sd,')
  def ConditionedBlock(current_layer_num):
    return tl.Serial(
        # stack: embedding, n_layers_to_keep
        tl.Select([1, 0, 1]),  # n_layers_to_keep, embedding, n_layers_to_keep
        tl.Cond(
            # if random() > skip_fraction OR layer not in skip_mode ...
            LargerThan(skip_fraction if skip_mode_fun(current_layer_num)
                       else 0.0),
            # then: run block
            tl.Serial(transformer._DecoderBlock(  # pylint: disable=g-complex-comprehension,protected-access
                d_model, d_ff, n_heads, dropout, [], mode, ff_activation))
            # else: noop (implicit)
        )
        # stack: embedding, n_layers_to_keep
    )

  return tl.Serial(
      tl.ShiftRight(mode=mode),
      embedder,
      # stack: embedding
      tl.RandomUniform(0., 1., sync=True),
      # stack: n_layers_to_keep, embedding
      tl.Swap(),
      # stack: embedding, n_layers_to_keep
      [ConditionedBlock(i) for i in range(n_layers)],
      # stack: embedding, n_layers_to_keep
      tl.Select([0], n_in=2),  # stack: embedding
      tl.LayerNorm(),
      tl.Dense(vocab_size),
  )
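# Usage sketch (assumption, not from the original source): a configuration
# that can skip the even-numbered blocks. Because the comparison threshold is
# `skip_fraction` for layers selected by `skip_mode` and 0.0 for the rest,
# even blocks are bypassed whenever the shared random draw does not exceed
# 0.5, while odd blocks always run; since a single sync'd draw is threaded
# through the stack, all selected blocks are skipped (or kept) together.
def _example_every_other_lm():  # hypothetical helper, for illustration only
  return EveryOtherLayerDropTransformerLM(
      vocab_size=128, d_model=32, d_ff=64, n_layers=4, n_heads=2,
      max_len=32, mode='train', skip_mode='even', skip_fraction=0.5)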