Example #1
def test_encdec_block(self):
    # Method of a tf.test.TestCase subclass. Assumes `import tensorflow as
    # tf` and that `t5` is the module providing EncDecoderBlock,
    # RelativePositionEmbedding, and make_attention_mask.
    batch_size = 2
    from_seq_length = 5  # decoder (query) length
    to_seq_length = 3    # encoder (key/value) length
    d_model = 4
    l = t5.EncDecoderBlock(d_model=d_model,
                           d_kv=3,
                           num_heads=2,
                           d_ff=8,
                           name="foo")
    pos_embed = t5.RelativePositionEmbedding(
        num_heads=2,
        bidirectional=True,
        embeddings_initializer=tf.keras.initializers.Ones(),
        name="bar")
    # Cross-attention mask: decoder queries attend to encoder keys/values.
    encoder_decoder_mask = t5.make_attention_mask(
        tf.ones((batch_size, from_seq_length)),
        tf.ones((batch_size, to_seq_length)))
    # Relative position bias for self-attention over the decoder sequence.
    position_bias = pos_embed(from_seq_length, from_seq_length)
    inputs = tf.ones((batch_size, from_seq_length, d_model),
                     dtype=tf.float32)
    encoder_hidden_states = tf.ones((batch_size, to_seq_length, d_model),
                                    dtype=tf.float32)
    outputs = l(inputs,
                encoder_hidden_states,
                encoder_decoder_mask=encoder_decoder_mask,
                position_bias=position_bias)
    # The block preserves the decoder stream's (batch, length, d_model) shape.
    self.assertEqual(outputs[0].shape,
                     (batch_size, from_seq_length, d_model))
Example #2
def _step_fn(inputs):
    # Closure over `batch_size`, `from_seq_length`, `to_seq_length`, and an
    # attention layer `l`, all defined in the enclosing test.
    query, kv = inputs
    # Queries (from_seq_length) attend to keys/values (to_seq_length); the
    # all-ones inputs mark every position as valid.
    mask = t5.make_attention_mask(
        tf.ones((batch_size, from_seq_length)),
        tf.ones((batch_size, to_seq_length)))
    return l(hidden_states=query, kv=kv, attention_mask=mask)
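For context on how a closure like _step_fn is typically driven, here is a minimal sketch that stubs in the enclosing test's variables and replaces the unshown layer `l` with a hypothetical stand-in that only mimics the expected call signature. StubAttention is not part of the `t5` module; it exists only so the step function can be traced with tf.function, the way a test might before handing a step to tf.distribute.Strategy.run. The sketch still assumes the `t5` module from the examples above is importable, since _step_fn calls t5.make_attention_mask.

import tensorflow as tf

batch_size, from_seq_length, to_seq_length, d_model = 2, 5, 3, 4

class StubAttention(tf.keras.layers.Layer):
    # Hypothetical stand-in for the cross-attention layer `l` built by the
    # enclosing test; a real layer would attend over `kv` under the mask.
    def call(self, hidden_states, kv, attention_mask):
        del kv, attention_mask  # unused by the stub
        return hidden_states

l = StubAttention()
query = tf.ones((batch_size, from_seq_length, d_model))
kv = tf.ones((batch_size, to_seq_length, d_model))

# Trace and run the step once, as a distribution-strategy test might.
outputs = tf.function(_step_fn)((query, kv))
assert outputs.shape == (batch_size, from_seq_length, d_model)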