def test_typo_ellipsis_fail(self):
  layer = tl.AssertShape('b..c,2')
  x = [np.ones((2, 3, 4, 5)), np.zeros((2))]
  with self.assertRaises(tl.LayerError):
    layer(x)

def test_prefix_ellipsis_matching_suffix_ellipsis_pass(self):
  layer = tl.AssertShape('bb...,...bb')
  x = [np.ones((2, 2, 5, 6)), np.zeros((5, 6, 2, 2))]
  y = layer(x)
  self.assertEqual(y, x)

def test_short_middle_ellipsis_fail(self):
  layer = tl.AssertShape('b...c,2')
  x = [np.ones((2)), np.zeros((2))]
  with self.assertRaises(tl.LayerError):
    layer(x)

def test_rank_fail(self):
  layer = tl.AssertShape('aba,ab')
  x = [np.ones((10, 5, 10)), np.ones((5, 10, 4))]
  with self.assertRaises(tl.LayerError):
    layer(x)

def test_simple_pass(self):
  layer = tl.AssertShape('aba,ba')
  x = [np.ones((10, 5, 10)), np.zeros((5, 10))]
  y = layer(x)
  self.assertEqual(y, x)

def test_prefix_and_suffix_ellipsis_fail(self):
  layer = tl.AssertShape('...c...,2')
  x = [np.ones((2, 3, 4, 5)), np.zeros((2))]
  with self.assertRaises(tl.LayerError):
    layer(x)

def test_ellipses_matching_dims_fail(self):
  layer = tl.AssertShape('...2,...8')
  x = [np.ones((1, 2, 3, 9)), np.zeros((1, 3, 3, 8))]
  with self.assertRaises(tl.LayerError):
    layer(x)

def test_three_args_pass(self):
  layer = tl.AssertShape('a,b,a')
  x = [np.ones((5,)), np.zeros((2)), np.zeros((5))]
  y = layer(x)
  self.assertEqual(y, x)

def test_multiple_matching_dims_pass(self):
  layer = tl.AssertShape('a,b,a,ab')
  x = [np.ones((5,)), np.zeros((2)), np.zeros((5)), np.zeros((5, 2))]
  y = layer(x)
  self.assertEqual(y, x)

def test_square_matrix_pass(self):
  layer = tl.AssertShape('aa')
  x = np.ones((3, 3))
  y = layer(x)
  self.assertEqual(y.tolist(), x.tolist())

def test_vector_scalar_pass(self):
  layer = tl.AssertShape('a,')
  x = [np.ones((5,)), np.zeros(())]
  y = layer(x)
  self.assertEqual(y, x)

def test_scalar_pass(self):
  layer = tl.AssertShape('')
  x = np.ones(())
  y = layer(x)
  self.assertEqual(y.tolist(), x.tolist())

def test_single_arg_pass(self):
  layer = tl.AssertShape('a')
  x = np.ones((5,))
  y = layer(x)
  self.assertEqual(y.tolist(), x.tolist())

def test_same_shapes_pass(self):
  layer = tl.AssertShape('aba,ba')
  x = [np.ones((5, 5, 5)), np.zeros((5, 5))]
  y = layer(x)
  self.assertEqual(y, x)

def test_ellipsis_matching_ellipsis_fail(self):
  layer = tl.AssertShape('...a,...b')
  x = [np.ones((1, 2, 3, 7)), np.zeros((1, 2, 8))]
  with self.assertRaises(tl.LayerError):
    layer(x)

def test_numeric_dims_pass(self):
  layer = tl.AssertShape('23,1,93')
  x = [np.ones((2, 3)), np.zeros((1)), np.zeros((9, 3))]
  y = layer(x)
  self.assertEqual(y, x)

def test_ellipsis_numeric_pass(self):
  layer = tl.AssertShape('...22,...3')
  x = [np.ones((1, 2, 3, 2, 2)), np.zeros((1, 2, 3, 3))]
  y = layer(x)
  self.assertEqual(y, x)

def test_numeric_dims_fail(self):
  layer = tl.AssertShape('24,1,93')
  x = [np.ones((2, 3)), np.zeros((1)), np.zeros((9, 3))]
  with self.assertRaises(tl.LayerError):
    layer(x)

def test_ellipsis_too_few_dims_fail(self):
  layer = tl.AssertShape('...abc,2')
  x = [np.ones((4, 5)), np.zeros((2))]
  with self.assertRaises(tl.LayerError):
    layer(x)

def test_ellipsis_prefix_pass(self):
  layer = tl.AssertShape('...bc,abc')
  x = [np.ones((5, 5, 2, 3)), np.zeros((1, 2, 3))]
  y = layer(x)
  self.assertEqual(y, x)

def test_dims_matching_fail(self):
  layer = tl.AssertShape('aba,ab')
  x = [np.ones((10, 5, 10)), np.ones((5, 8))]
  with self.assertRaises(tl.LayerError):
    layer(x)

def test_ellipsis_matching_ellipsis_pass(self):
  layer = tl.AssertShape('...bc,...bc')
  x = [np.ones((1, 2, 3)), np.zeros((1, 2, 3))]
  y = layer(x)
  self.assertEqual(y, x)

def test_square_matrix_fail(self):
  layer = tl.AssertShape('aa')
  x = np.ones((10, 5))
  with self.assertRaises(tl.LayerError):
    layer(x)
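
# A standalone sketch of the AssertShape spec mini-language exercised by the
# tests above: lowercase letters name dimensions that must agree wherever they
# repeat, digits pin individual dimensions to fixed sizes, '...' matches a run
# of leading/trailing dimensions, and commas separate the specs for successive
# inputs. The layer spec and shapes below are illustrative, not taken from the
# tests.
def _assert_shape_usage_sketch():
  layer = tl.AssertShape('bsd,bs')      # e.g. activations plus a mask
  activations = np.ones((8, 128, 512))  # b=8, s=128, d=512
  mask = np.zeros((8, 128))             # b=8, s=128
  # Passes inputs through unchanged; raises tl.LayerError on a mismatch.
  return layer([activations, mask])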

def SkippingTransformerLM(vocab_size,
                          d_model=512,
                          d_ff=2048,
                          n_layers=6,
                          n_heads=8,
                          dropout=0.1,
                          max_len=2048,
                          mode='train',
                          ff_activation=tl.Relu,
                          skip_fraction=0.4):
  """Returns a Skipping Transformer language model.

  The input to the model is a tensor of tokens. (This model uses only the
  decoder part of the overall Transformer.)

  Args:
    vocab_size: int: vocab size
    d_model: int: depth of embedding
    d_ff: int: depth of feed-forward layer
    n_layers: int: number of encoder/decoder layers
    n_heads: int: number of attention heads
    dropout: float: dropout rate (how much to drop out)
    max_len: int: maximum symbol length for positional encoding
    mode: str: 'train', 'eval' or 'predict', predict mode is for fast inference
    ff_activation: the non-linearity in feed-forward layer
    skip_fraction: fraction of times to skip some layers

  Returns:
    A Transformer language model as a layer that maps from a tensor of tokens
    to activations over a vocab set.
  """
  embedder = [
      tl.Embedding(vocab_size, d_model),
      tl.Dropout(rate=dropout, mode=mode),
      tl.PositionalEncoding(max_len=max_len, mode=mode),
  ]

  @assert_shape('...sd,->...sd,')
  def ConditionedBlock(current_layer_num):
    return tl.Serial(
        # stack: embedding, n_layers_to_keep
        tl.Select([1, 0, 1]),  # n_layers_to_keep, embedding, n_layers_to_keep
        tl.Cond(
            # if n_layers_to_keep > current_layer_num
            LargerThan(float(current_layer_num)),
            # then: run the decoder block
            tl.Serial(transformer._DecoderBlock(  # pylint: disable=g-complex-comprehension,protected-access
                d_model, d_ff, n_heads, dropout, [], mode, ff_activation)),
            # else: run a no-op
            tl.Serial())
        # stack: embedding, n_layers_to_keep
    )

  if mode == 'train':
    if skip_fraction == 0.0:
      minimum_layers = float(n_layers)
      maximum_layers = float(n_layers)
    else:
      minimum_layers = 0.0
      maximum_layers = float(n_layers) / skip_fraction
  else:
    minimum_layers = maximum_layers = float(n_layers)

  return tl.Serial(
      tl.ShiftRight(mode=mode),
      embedder,
      # stack: embedding
      tl.RandomUniform(minimum_layers, maximum_layers, sync=True),
      # stack: n_layers_to_keep, embedding
      tl.Swap(),
      # stack: embedding, n_layers_to_keep
      [ConditionedBlock(i) for i in range(n_layers)],
      # stack: embedding, n_layers_to_keep
      tl.AssertShape('...sd,'),
      tl.Select([0], n_in=2),
      # stack: embedding
      tl.AssertShape('...sd'),
      tl.LayerNorm(),
      tl.Dense(vocab_size),
  )
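
# A minimal usage sketch, assuming the standard Trax init/call flow
# (trax.shapes.ShapeDtype to describe the input signature). The vocab size,
# layer count, and token shapes below are illustrative, not from the source.
def _skipping_transformer_lm_usage_sketch():
  import numpy as np
  from trax import shapes
  model = SkippingTransformerLM(vocab_size=256, n_layers=2, mode='train')
  # Initialize weights from a (batch, seq_len) int32 token signature.
  model.init(shapes.ShapeDtype((1, 64), dtype=np.int32))
  tokens = np.zeros((1, 64), dtype=np.int32)
  # Output has shape (1, 64, 256): per-token activations over the vocab.
  return model(tokens)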