Example #1
 def test_encdec_block(self):
     batch_size = 2
     from_seq_length = 5
     to_seq_length = 3
     d_model = 4
     l = t5.EncDecoderBlock(d_model=d_model,
                            d_kv=3,
                            num_heads=2,
                            d_ff=8,
                            name="foo")
     pos_embed = t5.RelativePositionEmbedding(
         num_heads=2,
         bidirectional=True,
         embeddings_initializer=tf.keras.initializers.Ones(),
         name="bar")
     encoder_decoder_mask = t5.make_attention_mask(
         tf.ones((batch_size, from_seq_length)),
         tf.ones((batch_size, to_seq_length)))
     position_bias = pos_embed(from_seq_length, from_seq_length)
     inputs = tf.ones((batch_size, from_seq_length, d_model),
                      dtype=tf.float32)
     encoder_hidden_states = tf.ones((batch_size, to_seq_length, d_model),
                                     dtype=tf.float32)
     outputs = l(inputs,
                 encoder_hidden_states,
                 encoder_decoder_mask=encoder_decoder_mask,
                 position_bias=position_bias)
     self.assertEqual(outputs[0].shape,
                      (batch_size, from_seq_length, d_model))
Example #2
 def test_relative_position(self, dtype):
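     # dtype is supplied by the test's parameterization (e.g. tf.float32);
     # compute_dtype controls the dtype of the returned position bias.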
     l = t5.RelativePositionEmbedding(
         num_heads=4,
         bidirectional=False,
         embeddings_initializer=tf.keras.initializers.Ones(),
         compute_dtype=dtype,
         name="foo")
     self.assertEqual(l(4, 2).shape, (1, 4, 4, 2))
     l = t5.RelativePositionEmbedding(
         num_heads=4,
         bidirectional=True,
         embeddings_initializer=tf.keras.initializers.Ones(),
         compute_dtype=dtype,
         name="bar")
     outputs = l(4, 2)
     self.assertEqual(outputs.shape, (1, 4, 4, 2))
     self.assertEqual(outputs.dtype, dtype)
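Note: these examples are test methods lifted from a larger test file, so they rely on module-level imports and helpers that are not shown: tf, np, the t5 module itself, the _create_cache helper used in Examples #3 and #4, and the distribution strategy injected by the test harness. The sketch below is one plausible reconstruction of that shared setup; the import path and the cache layout (zero-initialized key/value tensors indexed by decode position) are assumptions inferred from how the snippets use them, not taken from the original file.

import numpy as np
import tensorflow as tf

# Assumed import path for the T5 layers exercised in these examples.
from official.nlp.modeling.models import t5


# Assumed helper: a zero-initialized key/value cache whose second axis is the
# decode position, matching the cache["key"][:, decode_position, ...] indexing
# used in Examples #3 and #4.
def _create_cache(batch_size, init_decode_length, num_heads, head_size):
    shape = [batch_size, init_decode_length, num_heads, head_size]
    return {
        "key": tf.zeros(shape, dtype=tf.float32),
        "value": tf.zeros(shape, dtype=tf.float32),
    }


# In place of the test's parameterized strategies, the default (no-op) strategy
# also provides scope(), run() and experimental_local_results().
distribution = tf.distribute.get_strategy()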
Example #3
    def test_attention(self, distribution):
        num_heads, head_size = 2, 4
        from_seq_length, to_seq_length = 4, 6
        batch_size = 2
        pos_embed = t5.RelativePositionEmbedding(
            num_heads=4,
            bidirectional=False,
            embeddings_initializer=tf.keras.initializers.Ones(),
            name="pos_embed")
        position_bias = pos_embed(from_seq_length, from_seq_length)
        l = t5.MultiHeadAttention(d_model=4,
                                  d_kv=2,
                                  num_heads=4,
                                  dropout_rate=0.1)
        query = tf.convert_to_tensor(
            np.ones((batch_size, from_seq_length, 4), dtype=np.float32))
        self.assertEqual(
            l(query, position_bias=position_bias)["context"].shape,
            query.shape)
        kv = tf.convert_to_tensor(
            np.ones((batch_size, to_seq_length, 4), dtype=np.float32))
        position_bias = pos_embed(from_seq_length, to_seq_length)
        outputs = l(query, kv=kv, position_bias=position_bias)
        self.assertEqual(outputs["context"].shape, query.shape)

        with distribution.scope():
            l = t5.MultiHeadAttention(d_model=4,
                                      d_kv=head_size,
                                      num_heads=num_heads,
                                      dropout_rate=0.1)

            @tf.function
            def step(inputs):
                def _step_fn(inputs):
                    cache = _create_cache(batch_size, from_seq_length,
                                          num_heads, head_size)
                    mask = t5.make_causal_mask(tf.ones((batch_size, 1)))
                    return l(query=inputs,
                             mask=mask,
                             cache=cache,
                             decode_position=decode_position)

                outputs = distribution.run(_step_fn, args=(inputs, ))
                return tf.nest.map_structure(
                    distribution.experimental_local_results, outputs)

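            # decode_position is read inside _step_fn via closure; it is
            # assigned here, before step() is first traced and called below.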
            decode_position = 2
            query = tf.convert_to_tensor(np.ones((2, 1, 4), dtype=np.float32))
            local_outputs = step(query)
            self.assertEqual(local_outputs["context"][0].shape, (2, 1, 4))
            self.assertNotEqual(
                np.sum(local_outputs["cache"]["key"][0][:, decode_position,
                                                        ...].numpy()), 0.0)
Example #4
    def test_attention_layers(self, distribution):
        num_heads, head_size = 2, 2
        from_seq_length = 4
        # TPU decoding should pre-allocate the entire sequence.
        batch_size = 2
        with distribution.scope():
            pos_embed = t5.RelativePositionEmbedding(
                num_heads=head_size,
                bidirectional=False,
                embeddings_initializer=tf.keras.initializers.Ones(),
                name="pos_embed")
            l = t5.SelfAttention(d_model=4,
                                 d_kv=head_size,
                                 num_heads=num_heads,
                                 dropout_rate=0.1)
            decode_position = 2

            @tf.function
            def step(inputs):
                def _step_fn(inputs):
                    cache = _create_cache(batch_size, from_seq_length,
                                          num_heads, head_size)
                    mask = t5.make_causal_mask(tf.ones((batch_size, 1)))
                    position_bias = pos_embed(from_seq_length, from_seq_length)
                    return l(hidden_states=inputs,
                             cache=cache,
                             attention_mask=mask,
                             decode_position=decode_position,
                             position_bias=position_bias)

                outputs = distribution.run(_step_fn, args=(inputs, ))
                return tf.nest.map_structure(
                    distribution.experimental_local_results, outputs)

            query = tf.convert_to_tensor(np.ones((2, 1, 4), dtype=np.float32))
            local_outputs = step(query)
            self.assertEqual(local_outputs["layer_output"][0].shape, (2, 1, 4))
            self.assertNotEqual(
                np.sum(local_outputs["cache"]["key"][0]
                       [:, decode_position, :, :].numpy()), 0.0)

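            # Cross-attention: decoder queries of length from_seq_length attend
            # to encoder key/values of length to_seq_length.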
            l = t5.CrossAttention(d_model=4,
                                  d_kv=head_size,
                                  num_heads=num_heads,
                                  dropout_rate=0.1)
            to_seq_length = 6
            query = tf.convert_to_tensor(
                np.ones((2, from_seq_length, 4), dtype=np.float32))
            kv = tf.convert_to_tensor(
                np.ones((2, to_seq_length, 4), dtype=np.float32))

            @tf.function
            def step_cross_attn(inputs):
                def _step_fn(inputs):
                    query, kv = inputs
                    mask = t5.make_attention_mask(
                        tf.ones((batch_size, from_seq_length)),
                        tf.ones((batch_size, to_seq_length)))
                    return l(hidden_states=query, kv=kv, attention_mask=mask)

                outputs = distribution.run(_step_fn, args=(inputs, ))
                return tf.nest.map_structure(
                    distribution.experimental_local_results, outputs)

            local_outputs = step_cross_attn((query, kv))
            self.assertEqual(local_outputs["layer_output"][0].shape,
                             (2, from_seq_length, 4))
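Examples #3 and #4 each run a single cached decode step at a fixed decode_position. The sketch below strings the same call pattern into a step-by-step loop outside any distribution strategy, reusing the hypothetical _create_cache helper sketched earlier. It only mirrors the argument names shown in Example #4 and is meant as an illustration of the cache/decode_position mechanics, not as the library's decoding API.

num_heads, head_size, d_model = 2, 2, 4
batch_size, seq_length = 2, 4

layer = t5.SelfAttention(d_model=d_model,
                         d_kv=head_size,
                         num_heads=num_heads,
                         dropout_rate=0.0)
pos_embed = t5.RelativePositionEmbedding(
    num_heads=num_heads,
    bidirectional=False,
    embeddings_initializer=tf.keras.initializers.Ones(),
    name="decode_pos_embed")
# Full-length bias, as in Example #4's cached call.
position_bias = pos_embed(seq_length, seq_length)

cache = _create_cache(batch_size, seq_length, num_heads, head_size)
token = tf.ones((batch_size, 1, d_model), dtype=tf.float32)
for decode_position in range(seq_length):
    # Single-token mask, mirroring the cached-decoding call in Example #4.
    mask = t5.make_causal_mask(tf.ones((batch_size, 1)))
    outputs = layer(hidden_states=token,
                    cache=cache,
                    attention_mask=mask,
                    decode_position=decode_position,
                    position_bias=position_bias)
    cache = outputs["cache"]          # carry the updated key/value cache forward
    token = outputs["layer_output"]   # next-step query (for illustration only)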