def test_encdec_block(self):
  batch_size = 2
  from_seq_length = 5
  to_seq_length = 3
  d_model = 4
  l = t5.EncDecoderBlock(d_model=4, d_kv=3, num_heads=2, d_ff=8, name="foo")
  pos_embed = t5.RelativePositionEmbedding(
      num_heads=2,
      bidirectional=True,
      embeddings_initializer=tf.keras.initializers.Ones(),
      name="bar")
  # Cross-attention mask: every decoder (from) position may attend to every
  # encoder (to) position, since both inputs are all-ones.
  encoder_decoder_mask = t5.make_attention_mask(
      tf.ones((batch_size, from_seq_length)),
      tf.ones((batch_size, to_seq_length)))
  # Relative position bias for the decoder self-attention sublayer.
  position_bias = pos_embed(from_seq_length, from_seq_length)
  inputs = tf.ones((batch_size, from_seq_length, d_model), dtype=tf.float32)
  encoder_hidden_states = tf.ones(
      (batch_size, to_seq_length, d_model), dtype=tf.float32)
  outputs = l(
      inputs,
      encoder_hidden_states,
      encoder_decoder_mask=encoder_decoder_mask,
      position_bias=position_bias)
  # The block's primary output keeps the decoder-side shape.
  self.assertEqual(outputs[0].shape, (batch_size, from_seq_length, d_model))
def _step_fn(inputs):
  # Note: `l` here is assumed to be a cross-attention layer from the
  # enclosing test; it is called with hidden_states/kv/attention_mask,
  # unlike the EncDecoderBlock above.
  query, kv = inputs
  mask = t5.make_attention_mask(
      tf.ones((batch_size, from_seq_length)),
      tf.ones((batch_size, to_seq_length)))
  return l(hidden_states=query, kv=kv, attention_mask=mask)
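# A minimal sketch of how _step_fn might be driven: presumably the enclosing
# test compiles it with tf.function and runs it under a
# tf.distribute.Strategy. The names `strategy`, `query`, and `kv` are
# hypothetical here, not part of the original snippet.
@tf.function
def step(inputs):
  return strategy.run(_step_fn, args=(inputs,))

query = tf.ones((batch_size, from_seq_length, d_model), dtype=tf.float32)
kv = tf.ones((batch_size, to_seq_length, d_model), dtype=tf.float32)
outputs = step((query, kv))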