Ejemplo n.º 1
0
 def test_typo_ellipsis_fail(self):
     """A typo'd ellipsis ('..' instead of '...') must reject the input."""
     check = tl.AssertShape('b..c,2')
     inputs = [np.ones((2, 3, 4, 5)), np.zeros((2))]
     with self.assertRaises(tl.LayerError):
         check(inputs)
Ejemplo n.º 2
0
 def test_prefix_ellipsis_matching_sufix_ellipsis_pass(self):
     """A trailing and a leading ellipsis may bind different dim groups."""
     inputs = [np.ones((2, 2, 5, 6)), np.zeros((5, 6, 2, 2))]
     outputs = tl.AssertShape('bb...,...bb')(inputs)
     self.assertEqual(outputs, inputs)
Ejemplo n.º 3
0
 def test_short_middle_ellipsis_fail(self):
     """A rank-1 input cannot satisfy 'b...c', which needs at least 2 dims."""
     check = tl.AssertShape('b...c,2')
     inputs = [np.ones((2)), np.zeros((2))]
     with self.assertRaises(tl.LayerError):
         check(inputs)
Ejemplo n.º 4
0
 def test_rank_fail(self):
     """A rank-3 tensor must not match the rank-2 spec 'ab'."""
     check = tl.AssertShape('aba,ab')
     inputs = [np.ones((10, 5, 10)), np.ones((5, 10, 4))]
     with self.assertRaises(tl.LayerError):
         check(inputs)
Ejemplo n.º 5
0
 def test_simple_pass(self):
     """Shared letters 'a' and 'b' bind consistently across both specs."""
     inputs = [np.ones((10, 5, 10)), np.zeros((5, 10))]
     outputs = tl.AssertShape('aba,ba')(inputs)
     self.assertEqual(outputs, inputs)
Ejemplo n.º 6
0
 def test_prefix_and_sufix_ellipsis_fail(self):
     """Two ellipses in a single spec ('...c...') are not allowed."""
     check = tl.AssertShape('...c...,2')
     inputs = [np.ones((2, 3, 4, 5)), np.zeros((2))]
     with self.assertRaises(tl.LayerError):
         check(inputs)
Ejemplo n.º 7
0
 def test_ellipses_matching_dims_fail(self):
     """Numeric trailing dims must match even when preceded by ellipses."""
     check = tl.AssertShape('...2,...8')
     inputs = [np.ones((1, 2, 3, 9)), np.zeros((1, 3, 3, 8))]
     with self.assertRaises(tl.LayerError):
         check(inputs)
Ejemplo n.º 8
0
 def test_three_args_pass(self):
     """Three inputs: the first and third must share the same 'a' dim."""
     inputs = [np.ones((5,)), np.zeros((2,)), np.zeros((5,))]
     outputs = tl.AssertShape('a,b,a')(inputs)
     self.assertEqual(outputs, inputs)
Ejemplo n.º 9
0
 def test_multiple_matching_dims_pass(self):
     """Letters bound by earlier inputs constrain later inputs too."""
     inputs = [
         np.ones((5,)),
         np.zeros((2,)),
         np.zeros((5,)),
         np.zeros((5, 2)),
     ]
     outputs = tl.AssertShape('a,b,a,ab')(inputs)
     self.assertEqual(outputs, inputs)
Ejemplo n.º 10
0
 def test_square_matrix_pass(self):
     """Spec 'aa' accepts a square matrix and passes it through unchanged."""
     matrix = np.ones((3, 3))
     checked = tl.AssertShape('aa')(matrix)
     self.assertEqual(checked.tolist(), matrix.tolist())
Ejemplo n.º 11
0
 def test_vector_scalar_pass(self):
     """An empty spec segment after the comma matches a rank-0 scalar."""
     inputs = [np.ones((5,)), np.zeros(())]
     outputs = tl.AssertShape('a,')(inputs)
     self.assertEqual(outputs, inputs)
Ejemplo n.º 12
0
 def test_scalar_pass(self):
     """The empty spec matches a single rank-0 scalar input."""
     scalar = np.ones(())
     checked = tl.AssertShape('')(scalar)
     self.assertEqual(checked.tolist(), scalar.tolist())
Ejemplo n.º 13
0
 def test_single_arg_pass(self):
     """A one-letter spec matches any rank-1 tensor."""
     vector = np.ones((5,))
     checked = tl.AssertShape('a')(vector)
     self.assertEqual(checked.tolist(), vector.tolist())
Ejemplo n.º 14
0
 def test_same_shapes_pass(self):
     """Distinct letters may still bind to equal dim sizes."""
     inputs = [np.ones((5, 5, 5)), np.zeros((5, 5))]
     outputs = tl.AssertShape('aba,ba')(inputs)
     self.assertEqual(outputs, inputs)
Ejemplo n.º 15
0
 def test_ellipsis_matching_ellipsis_fail(self):
     """Both '...' groups must bind to the same leading dims."""
     check = tl.AssertShape('...a,...b')
     inputs = [np.ones((1, 2, 3, 7)), np.zeros((1, 2, 8))]
     with self.assertRaises(tl.LayerError):
         check(inputs)
Ejemplo n.º 16
0
 def test_numeric_dims_pass(self):
     """Digits in the spec pin exact dim sizes."""
     inputs = [np.ones((2, 3)), np.zeros((1,)), np.zeros((9, 3))]
     outputs = tl.AssertShape('23,1,93')(inputs)
     self.assertEqual(outputs, inputs)
Ejemplo n.º 17
0
 def test_ellipsis_numeric_pass(self):
     """Numeric suffixes after '...' check only the trailing dims."""
     inputs = [np.ones((1, 2, 3, 2, 2)), np.zeros((1, 2, 3, 3))]
     outputs = tl.AssertShape('...22,...3')(inputs)
     self.assertEqual(outputs, inputs)
Ejemplo n.º 18
0
 def test_numeric_dims_fail(self):
     """A wrong numeric dim ('24' vs actual (2, 3)) must be rejected."""
     check = tl.AssertShape('24,1,93')
     inputs = [np.ones((2, 3)), np.zeros((1)), np.zeros((9, 3))]
     with self.assertRaises(tl.LayerError):
         check(inputs)
Ejemplo n.º 19
0
 def test_ellipsis_too_few_dims_fail(self):
     """'...abc' requires at least 3 dims; a rank-2 input must fail."""
     check = tl.AssertShape('...abc,2')
     inputs = [np.ones((4, 5)), np.zeros((2))]
     with self.assertRaises(tl.LayerError):
         check(inputs)
Ejemplo n.º 20
0
 def test_ellipsis_prefix_pass(self):
     """A '...' prefix absorbs extra leading dims; 'bc' binds the tail."""
     inputs = [np.ones((5, 5, 2, 3)), np.zeros((1, 2, 3))]
     outputs = tl.AssertShape('...bc,abc')(inputs)
     self.assertEqual(outputs, inputs)
Ejemplo n.º 21
0
 def test_dims_matching_fail(self):
     """Letter 'b' bound to 10 in the first input must not rebind to 8."""
     check = tl.AssertShape('aba,ab')
     inputs = [np.ones((10, 5, 10)), np.ones((5, 8))]
     with self.assertRaises(tl.LayerError):
         check(inputs)
Ejemplo n.º 22
0
 def test_ellipsis_matching_ellipsis_pass(self):
     """Two '...' groups binding identical leading dims are accepted."""
     inputs = [np.ones((1, 2, 3)), np.zeros((1, 2, 3))]
     outputs = tl.AssertShape('...bc,...bc')(inputs)
     self.assertEqual(outputs, inputs)
Ejemplo n.º 23
0
 def test_square_matrix_fail(self):
     """Spec 'aa' must reject a non-square matrix."""
     check = tl.AssertShape('aa')
     matrix = np.ones((10, 5))
     with self.assertRaises(tl.LayerError):
         check(matrix)
Ejemplo n.º 24
0
def SkippingTransformerLM(vocab_size,
                          d_model=512,
                          d_ff=2048,
                          n_layers=6,
                          n_heads=8,
                          dropout=0.1,
                          max_len=2048,
                          mode='train',
                          ff_activation=tl.Relu,
                          skip_fraction=0.4):
    """Returns a Skipping Transformer language model.

    The input to the model is a tensor of tokens. (This model uses only the
    decoder part of the overall Transformer.) During training a random
    threshold ``n_layers_to_keep`` is drawn uniformly (via ``tl.RandomUniform``
    with ``sync=True``) and decoder block ``i`` executes only when
    ``n_layers_to_keep > i``; skipped blocks act as identity. Outside
    training (or when ``skip_fraction == 0``) the threshold is fixed at
    ``n_layers`` so every block always runs.

    Args:
      vocab_size: int: vocab size
      d_model: int: depth of embedding
      d_ff: int: depth of feed-forward layer
      n_layers: int: number of encoder/decoder layers
      n_heads: int: number of attention heads
      dropout: float: dropout rate (how much to drop out)
      max_len: int: maximum symbol length for positional encoding
      mode: str: 'train', 'eval' or 'predict', predict mode is for fast
        inference
      ff_activation: the non-linearity in feed-forward layer
      skip_fraction: fraction of times to skip some layers

    Returns:
      A Transformer language model as a layer that maps from a tensor of tokens
      to activations over a vocab set.
    """
    # Token embedding + dropout + positional encoding, shared front-end.
    embedder = [
        tl.Embedding(vocab_size, d_model),
        tl.Dropout(rate=dropout, mode=mode),
        tl.PositionalEncoding(max_len=max_len, mode=mode),
    ]

    # Maps (activations, n_layers_to_keep) -> (activations, n_layers_to_keep);
    # runs decoder block `current_layer_num` only when the sampled threshold
    # exceeds that layer's index, otherwise passes activations through.
    @assert_shape('...sd,->...sd,')
    def ConditionedBlock(current_layer_num):
        return tl.Serial(
            # stack: embedding, n_layers_to_keep
            tl.Select([1, 0,
                       1]),  # n_layers_to_keep, embedding, n_layers_to_keep
            tl.Cond(
                # if n_layers_to_keep > current_layer_num
                LargerThan(float(current_layer_num)),
                # then: run block
                tl.Serial(
                    transformer._DecoderBlock(  # pylint: disable=g-complex-comprehension,protected-access
                        d_model, d_ff, n_heads, dropout, [], mode,
                        ff_activation)),
                # else: run noop
                tl.Serial())
            # stack: embedding, n_layers_to_keep
        )

    # Bounds for the uniform sample of n_layers_to_keep. The skip_fraction==0
    # branch avoids the division by zero below; in eval/predict the bounds
    # collapse to n_layers so all blocks deterministically run.
    if mode == 'train':
        if skip_fraction == 0.0:
            minimum_layers = float(n_layers)
            maximum_layers = float(n_layers)
        else:
            # Larger skip_fraction shrinks the sampling range toward 0, so
            # blocks are skipped more often. NOTE(review): the exact expected
            # skip rate also depends on n_layers — confirm against the paper.
            minimum_layers = 0.0
            maximum_layers = float(n_layers) / skip_fraction
    else:
        minimum_layers = maximum_layers = float(n_layers)

    return tl.Serial(
        tl.ShiftRight(mode=mode),  # teacher forcing: predict token t from <t.
        embedder,
        # stack: embedding
        tl.RandomUniform(minimum_layers, maximum_layers, sync=True),
        # stack: n_layers_to_keep, embedding
        tl.Swap(),
        # stack: embedding, n_layers_to_keep
        [ConditionedBlock(i) for i in range(n_layers)],
        # stack: embedding, n_layers_to_keep
        tl.AssertShape('...sd,'),
        tl.Select([0], n_in=2),  # drop the threshold; stack: embedding
        tl.AssertShape('...sd'),
        tl.LayerNorm(),
        tl.Dense(vocab_size),
    )