Example #1
0
    def test_decoder(self):
        """Checks Decoder output shapes for full and single-step cached decode.

        First runs a full forward pass over an 8-step target sequence, then a
        one-token decode at position 2 with a pre-allocated cache, verifying
        that the cache slots at that position were actually written (non-zero).
        """
        batch_size = 4
        seq_len = 8
        max_decode_len = 10
        config = t5.T5TransformerParams(
            num_layers=2,
            d_model=4,
            d_kv=3,
            num_heads=4,
            d_ff=16,
            vocab_size=10,
            vocab_embeddings_initializer=tf.keras.initializers.Ones(),
            relative_embeddings_initializer=tf.keras.initializers.Ones())
        decoder = t5.Decoder(config)
        targets = tf.zeros((batch_size, seq_len), dtype=tf.int32)
        encoded = tf.zeros((batch_size, seq_len, config.d_model),
                           dtype=tf.float32)
        logits, cache = decoder(targets, encoded)
        self.assertEqual(logits.shape, (batch_size, seq_len, config.vocab_size))

        # One cache entry per decoder layer (num_layers == 2).
        cache = {}
        cache[0] = _create_cache(batch_size, max_decode_len, config.num_heads,
                                 config.d_kv)
        cache[1] = _create_cache(batch_size, max_decode_len, config.num_heads,
                                 config.d_kv)
        targets = tf.zeros((batch_size, 1), dtype=tf.int32)
        logits, cache = decoder(targets,
                                encoded,
                                decode_position=2,
                                cache=cache,
                                decode=True,
                                max_decode_len=max_decode_len)
        self.assertEqual(logits.shape, (batch_size, 1, config.vocab_size))
        # The cached keys/values at decode position 2 must have been populated.
        for entry in cache.values():
            for tensor in entry.values():
                self.assertNotAllEqual(tensor.numpy()[:, 2, :, :], 0.0)
Example #2
0
    def test_transformer_with_dense(self, ffn_activations,
                                    logits_via_embedding, expect_num_variables,
                                    layer_sharing, dtype):
        """Checks T5Transformer with mixed token and dense encoder inputs.

        Verifies the expected trainable-variable count, runs a packed
        (segment-id) forward pass with both token and dense encoder inputs,
        performs a single-step cached decode, and asserts that trainable
        variables stay float32 regardless of the compute dtype.
        """
        batch_size = 2
        max_decode_len = 10
        config = t5.T5TransformerParams(
            num_layers=1,
            d_model=8,
            d_kv=4,
            num_heads=4,
            d_ff=32,
            vocab_size=10,
            shared_embedding=True,
            layer_sharing=layer_sharing,
            ffn_activations=ffn_activations,
            logits_via_embedding=logits_via_embedding)
        transformer = t5.T5Transformer(config, compute_dtype=dtype)
        self.assertLen(transformer.trainable_variables, expect_num_variables)
        # Two packed examples per row; 0 is padding in the first row.
        inputs = tf.convert_to_tensor(
            np.array([[2, 2, 1, 3, 1, 0], [3, 3, 1, 2, 2, 1]]))
        segments = tf.convert_to_tensor(
            np.array([[1, 1, 1, 2, 2, 0], [1, 1, 1, 2, 2, 2]]))

        # Dense (embedding-space) encoder inputs: (batch, dense_len, d_model).
        dense_inputs = tf.convert_to_tensor(np.random.randn(batch_size, 2, 8),
                                            dtype=dtype)
        dense_segments = tf.convert_to_tensor(np.array([[1, 2], [1, 2]]))
        outputs = transformer(encoder_input_tokens=inputs,
                              encoder_dense_inputs=dense_inputs,
                              decoder_input_tokens=inputs,
                              decoder_target_tokens=inputs,
                              encoder_segment_ids=segments,
                              encoder_dense_segment_ids=dense_segments,
                              decoder_segment_ids=segments)
        # One cache entry for the single decoder layer (num_layers == 1).
        cache = {}
        cache[0] = _create_cache(batch_size,
                                 max_decode_len,
                                 config.num_heads,
                                 config.d_kv,
                                 dtype=dtype)
        outputs = transformer.decode(encoder_input_tokens=inputs,
                                     encoder_dense_inputs=dense_inputs,
                                     encoded=outputs["encoded"],
                                     decoder_target_tokens=tf.ones(
                                         (batch_size, 1), dtype=tf.int32),
                                     decode_position=1,
                                     decode=True,
                                     max_decode_len=max_decode_len,
                                     cache=cache)
        self.assertEqual(outputs["logits"].shape,
                         (batch_size, 1, config.vocab_size))
        # Variables must be stored in float32 even under a lower compute dtype.
        for v in transformer.trainable_variables:
            self.assertEqual(v.dtype, tf.float32)
Example #3
0
 def test_encoder_only_dense(self, dtype):
     """Runs the Encoder on dense-only inputs and checks the output shape."""
     params = t5.T5TransformerParams(
         num_layers=2,
         d_model=4,
         d_kv=3,
         num_heads=4,
         d_ff=16,
         vocab_size=10,
         vocab_embeddings_initializer=tf.keras.initializers.Ones(),
         relative_embeddings_initializer=tf.keras.initializers.Ones())
     encoder = t5.Encoder(params, compute_dtype=dtype)
     # No token inputs at all — only dense (embedding-space) inputs.
     dense_inputs = tf.ones((4, 2, 4), dtype=dtype)
     outputs = encoder(dense_inputs=dense_inputs)
     self.assertEqual(outputs.shape, (4, 2, params.d_model))