def test_encdec_block(self):
  batch_size = 2
  from_seq_length = 5
  to_seq_length = 3
  d_model = 4
  l = t5.EncDecoderBlock(
      d_model=4, d_kv=3, num_heads=2, d_ff=8, name="foo")
  pos_embed = t5.RelativePositionEmbedding(
      num_heads=2,
      bidirectional=True,
      embeddings_initializer=tf.keras.initializers.Ones(),
      name="bar")
  encoder_decoder_mask = t5.make_attention_mask(
      tf.ones((batch_size, from_seq_length)),
      tf.ones((batch_size, to_seq_length)))
  position_bias = pos_embed(from_seq_length, from_seq_length)
  inputs = tf.ones((batch_size, from_seq_length, d_model), dtype=tf.float32)
  encoder_hidden_states = tf.ones(
      (batch_size, to_seq_length, d_model), dtype=tf.float32)
  outputs = l(
      inputs,
      encoder_hidden_states,
      encoder_decoder_mask=encoder_decoder_mask,
      position_bias=position_bias)
  self.assertEqual(outputs[0].shape,
                   (batch_size, from_seq_length, d_model))
def test_relative_position(self, dtype):
  l = t5.RelativePositionEmbedding(
      num_heads=4,
      bidirectional=False,
      embeddings_initializer=tf.keras.initializers.Ones(),
      compute_dtype=dtype,
      name="foo")
  self.assertEqual(l(4, 2).shape, (1, 4, 4, 2))
  l = t5.RelativePositionEmbedding(
      num_heads=4,
      bidirectional=True,
      embeddings_initializer=tf.keras.initializers.Ones(),
      compute_dtype=dtype,
      name="bar")
  outputs = l(4, 2)
  self.assertEqual(outputs.shape, (1, 4, 4, 2))
  self.assertEqual(outputs.dtype, dtype)
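# The decoding tests below call `_create_cache`, which is not defined in this
# excerpt. A minimal sketch, assuming the decode cache is a dict of
# zero-initialized "key"/"value" tensors of shape
# [batch_size, init_decode_length, num_heads, head_size]; the names and shapes
# are inferred from how the tests index the cache, not from the layer
# implementation itself:
def _create_cache(batch_size, init_decode_length, num_heads, head_size):
  # Pre-allocate the full decode length so positions can be written in place.
  shape = [batch_size, init_decode_length, num_heads, head_size]
  return {
      "key": tf.zeros(shape, dtype=tf.float32),
      "value": tf.zeros(shape, dtype=tf.float32),
  }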
def test_attention(self, distribution):
  num_heads, head_size = 2, 4
  from_seq_length, to_seq_length = 4, 6
  batch_size = 2
  pos_embed = t5.RelativePositionEmbedding(
      num_heads=4,
      bidirectional=False,
      embeddings_initializer=tf.keras.initializers.Ones(),
      name="pos_embed")
  position_bias = pos_embed(from_seq_length, from_seq_length)
  l = t5.MultiHeadAttention(d_model=4, d_kv=2, num_heads=4, dropout_rate=0.1)
  query = tf.convert_to_tensor(
      np.ones((batch_size, from_seq_length, 4), dtype=np.float32))
  # Self-attention: the output context keeps the query shape.
  self.assertEqual(
      l(query, position_bias=position_bias)["context"].shape, query.shape)
  kv = tf.convert_to_tensor(
      np.ones((batch_size, to_seq_length, 4), dtype=np.float32))
  position_bias = pos_embed(from_seq_length, to_seq_length)
  # Cross-attention over a separate key/value sequence.
  outputs = l(query, kv=kv, position_bias=position_bias)
  self.assertEqual(outputs["context"].shape, query.shape)

  with distribution.scope():
    l = t5.MultiHeadAttention(
        d_model=4, d_kv=head_size, num_heads=num_heads, dropout_rate=0.1)

  @tf.function
  def step(inputs):

    def _step_fn(inputs):
      # `_create_cache` builds the zero-initialized decode cache (see the
      # sketch above).
      cache = _create_cache(batch_size, from_seq_length, num_heads, head_size)
      mask = t5.make_causal_mask(tf.ones((batch_size, 1)))
      return l(
          query=inputs,
          mask=mask,
          cache=cache,
          decode_position=decode_position)

    outputs = distribution.run(_step_fn, args=(inputs,))
    return tf.nest.map_structure(distribution.experimental_local_results,
                                 outputs)

  decode_position = 2
  query = tf.convert_to_tensor(np.ones((2, 1, 4), dtype=np.float32))
  local_outputs = step(query)
  self.assertEqual(local_outputs["context"][0].shape, (2, 1, 4))
  # The cache entry at the decode position should have been written to.
  self.assertNotEqual(
      np.sum(local_outputs["cache"]["key"][0][:, decode_position,
                                              ...].numpy()), 0.0)
def test_attention_layers(self, distribution):
  num_heads, head_size = 2, 2
  from_seq_length = 4  # TPU decoding should pre-allocate the entire sequence.
  batch_size = 2
  with distribution.scope():
    pos_embed = t5.RelativePositionEmbedding(
        num_heads=head_size,
        bidirectional=False,
        embeddings_initializer=tf.keras.initializers.Ones(),
        name="pos_embed")
    l = t5.SelfAttention(
        d_model=4, d_kv=head_size, num_heads=num_heads, dropout_rate=0.1)
  decode_position = 2

  @tf.function
  def step(inputs):

    def _step_fn(inputs):
      cache = _create_cache(batch_size, from_seq_length, num_heads, head_size)
      mask = t5.make_causal_mask(tf.ones((batch_size, 1)))
      position_bias = pos_embed(from_seq_length, from_seq_length)
      return l(
          hidden_states=inputs,
          cache=cache,
          attention_mask=mask,
          decode_position=decode_position,
          position_bias=position_bias)

    outputs = distribution.run(_step_fn, args=(inputs,))
    return tf.nest.map_structure(distribution.experimental_local_results,
                                 outputs)

  query = tf.convert_to_tensor(np.ones((2, 1, 4), dtype=np.float32))
  local_outputs = step(query)
  self.assertEqual(local_outputs["layer_output"][0].shape, (2, 1, 4))
  # The cache entry at the decode position should have been written to.
  self.assertNotEqual(
      np.sum(local_outputs["cache"]["key"][0][:, decode_position, :,
                                              :].numpy()), 0.0)

  l = t5.CrossAttention(
      d_model=4, d_kv=head_size, num_heads=num_heads, dropout_rate=0.1)
  to_seq_length = 6
  query = tf.convert_to_tensor(
      np.ones((2, from_seq_length, 4), dtype=np.float32))
  kv = tf.convert_to_tensor(
      np.ones((2, to_seq_length, 4), dtype=np.float32))

  @tf.function
  def step_cross_attn(inputs):

    def _step_fn(inputs):
      query, kv = inputs
      mask = t5.make_attention_mask(
          tf.ones((batch_size, from_seq_length)),
          tf.ones((batch_size, to_seq_length)))
      return l(hidden_states=query, kv=kv, attention_mask=mask)

    outputs = distribution.run(_step_fn, args=(inputs,))
    return tf.nest.map_structure(distribution.experimental_local_results,
                                 outputs)

  local_outputs = step_cross_attn((query, kv))
  self.assertEqual(local_outputs["layer_output"][0].shape,
                   (2, from_seq_length, 4))