def test_decoder(self):
    """Checks Decoder logits shapes for full and single-position decoding.

    First runs a teacher-forced pass over a whole target sequence, then an
    incremental decode step at position 2 with an externally built cache,
    verifying that the cache slot for that position gets written.
    """
    max_decode_len = 10
    batch_size = 4
    config = t5.T5TransformerParams(
        num_layers=2,
        d_model=4,
        d_kv=3,
        num_heads=4,
        d_ff=16,
        vocab_size=10,
        vocab_embeddings_initializer=tf.keras.initializers.Ones(),
        relative_embeddings_initializer=tf.keras.initializers.Ones())
    decoder = t5.Decoder(config)
    targets = tf.zeros((batch_size, 8), dtype=tf.int32)
    encoded = tf.zeros((batch_size, 8, config.d_model), dtype=tf.float32)
    # Full (teacher-forced) pass; the returned cache is not used here.
    logits, _ = decoder(targets, encoded)
    self.assertEqual(logits.shape, (batch_size, 8, config.vocab_size))
    # Build one zero-initialized cache entry per decoder layer.
    cache = {
        layer: _create_cache(batch_size, max_decode_len, config.num_heads,
                             config.d_kv)
        for layer in range(config.num_layers)
    }
    targets = tf.zeros((batch_size, 1), dtype=tf.int32)
    logits, cache = decoder(
        targets,
        encoded,
        decode_position=2,
        cache=cache,
        decode=True,
        max_decode_len=max_decode_len)
    self.assertEqual(logits.shape, (batch_size, 1, config.vocab_size))
    # The cache slot at decode_position=2 must no longer be all zeros.
    for entry in cache.values():
        for tensor in entry.values():
            self.assertNotAllEqual(tensor.numpy()[:, 2, :, :], 0.0)
def test_transformer_with_dense(self, ffn_activations, logits_via_embedding,
                                expect_num_variables, layer_sharing, dtype):
    """Exercises T5Transformer with mixed token and dense encoder inputs.

    Verifies the expected trainable-variable count, runs a forward pass with
    packed segments, performs a single-step incremental decode, and checks
    that trainable variables stay float32 regardless of the compute dtype.
    """
    max_decode_len = 10
    config = t5.T5TransformerParams(
        num_layers=1,
        d_model=8,
        d_kv=4,
        num_heads=4,
        d_ff=32,
        vocab_size=10,
        shared_embedding=True,
        layer_sharing=layer_sharing,
        ffn_activations=ffn_activations,
        logits_via_embedding=logits_via_embedding)
    transformer = t5.T5Transformer(config, compute_dtype=dtype)
    self.assertLen(transformer.trainable_variables, expect_num_variables)
    inputs = tf.convert_to_tensor(
        np.array([[2, 2, 1, 3, 1, 0], [3, 3, 1, 2, 2, 1]]))
    segments = tf.convert_to_tensor(
        np.array([[1, 1, 1, 2, 2, 0], [1, 1, 1, 2, 2, 2]]))
    dense_inputs = tf.convert_to_tensor(np.random.randn(2, 2, 8), dtype=dtype)
    dense_segments = tf.convert_to_tensor(np.array([[1, 2], [1, 2]]))
    outputs = transformer(
        encoder_input_tokens=inputs,
        encoder_dense_inputs=dense_inputs,
        decoder_input_tokens=inputs,
        decoder_target_tokens=inputs,
        encoder_segment_ids=segments,
        encoder_dense_segment_ids=dense_segments,
        decoder_segment_ids=segments)
    # Single-step incremental decode against the encoded representation.
    batch_size = 2
    cache = {
        0: _create_cache(batch_size, max_decode_len, config.num_heads,
                         config.d_kv, dtype=dtype)
    }
    outputs = transformer.decode(
        encoder_input_tokens=inputs,
        encoder_dense_inputs=dense_inputs,
        encoded=outputs["encoded"],
        decoder_target_tokens=tf.ones((batch_size, 1), dtype=tf.int32),
        decode_position=1,
        decode=True,
        max_decode_len=max_decode_len,
        cache=cache)
    self.assertEqual(outputs["logits"].shape,
                     (batch_size, 1, config.vocab_size))
    # Variables must be kept in float32 even under a lower compute dtype.
    # (Removed a leftover debug print of each variable's name and shape.)
    for v in transformer.trainable_variables:
        self.assertEqual(v.dtype, tf.float32)
def test_encoder_only_dense(self, dtype):
    """Runs the Encoder on dense inputs alone and checks the output shape."""
    config = t5.T5TransformerParams(
        num_layers=2,
        d_model=4,
        d_kv=3,
        num_heads=4,
        d_ff=16,
        vocab_size=10,
        vocab_embeddings_initializer=tf.keras.initializers.Ones(),
        relative_embeddings_initializer=tf.keras.initializers.Ones())
    encoder = t5.Encoder(config, compute_dtype=dtype)
    # No token ids here: only dense (already-embedded) encoder inputs.
    dense_inputs = tf.ones((4, 2, 4), dtype=dtype)
    encoded = encoder(dense_inputs=dense_inputs)
    self.assertEqual(encoded.shape, (4, 2, config.d_model))