Ejemplo n.º 1
0
 def test_transformer_forward_shape(self):
   """Check the output shape of a small Transformer on integer inputs.

   Builds a 2-layer, 2-head Transformer over a 16-token vocabulary and wraps
   it in `tl.Serial`. It then runs `tl.check_shape_agreement` on a pair of
   integer inputs, each of shape (3, 5). The resulting shape must be the
   input shape with `vocab_size` appended, i.e. (3, 5, 16).
   """
   vocab_size = 16
   batch_shape = (3, 5)
   # The model takes two inputs of identical shape (presumably inputs and
   # targets -- TODO confirm against the Transformer signature).
   input_shape = (batch_shape, batch_shape)

   serial_model = tl.Serial(transformer.Transformer(
       vocab_size, d_feature=32, d_feedforward=64, n_layers=2, n_heads=2))
   output_shape = tl.check_shape_agreement(
       serial_model, input_shape, integer_inputs=True)

   self.assertEqual(batch_shape + (vocab_size,), output_shape)
Ejemplo n.º 2
0
 def _test_transformer_forward_shape(self, input_vocab_size,
                                     output_vocab_size):
   """Check a Transformer's output shape for given vocabulary sizes.

   Args:
     input_vocab_size: first positional argument to `transformer.Transformer`.
     output_vocab_size: second positional argument to
       `transformer.Transformer`; may be None.

   A 2-layer, 2-head model is fed two integer inputs, each of shape (3, 5).
   The output shape must equal (3, 5, V). V is `output_vocab_size` when it
   is given; otherwise V falls back to `input_vocab_size`.
   """
   batch_shape = (3, 5)
   # With no explicit output vocabulary, the expected trailing dimension
   # falls back to the input vocabulary size.
   if output_vocab_size is None:
     target_vocab_size = input_vocab_size
   else:
     target_vocab_size = output_vocab_size

   model = transformer.Transformer(
       input_vocab_size, output_vocab_size,
       d_feature=32, d_feedforward=64, n_layers=2, n_heads=2)
   actual_shape = tl.check_shape_agreement(
       model, (batch_shape, batch_shape), integer_inputs=True)

   self.assertEqual(batch_shape + (target_vocab_size,), actual_shape)