Example #1
    def test_qkv_relative_local_attention_full_att_implementation(
            self, local_radius, use_one_hot_lookup):
        # We check the validity of the `att_implementation` option
        # by confirming both internal implementations return the same output.

        tf.compat.v1.random.set_random_seed(1234)
        np.random.seed(1234)

        batch_size = 3
        long_len = 64
        side_len = 6
        num_heads = 5
        key_size_per_head = 2
        value_size_per_head = 3
        relative_vocab_size = 7
        # Note: block_len = local_radius + 1
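        # (In the blocked local implementation, each block of block_len tokens
        # attends to itself and its two neighboring blocks, which covers the
        # full 2 * local_radius + 1 attention window.)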

        queries = tf.constant(
            np.random.normal(
                size=[batch_size, long_len, num_heads, key_size_per_head]),
            tf.float32)
        keys = tf.constant(
            np.random.normal(
                size=[batch_size, long_len, num_heads, key_size_per_head]),
            tf.float32)
        values = tf.constant(
            np.random.normal(
                size=[batch_size, long_len, num_heads, value_size_per_head]),
            tf.float32)
        att_mask = tf.constant(
            np.random.binomial(
                n=1, p=0.8, size=[batch_size, long_len, 2 * local_radius + 1]),
            dtype=tf.int32)
        relative_att_ids = tf.constant(
            np.random.randint(
                relative_vocab_size,
                size=[batch_size, long_len, 2 * local_radius + 1]),
            dtype=tf.int32)
        side_keys = tf.constant(
            np.random.normal(
                size=[batch_size, side_len, num_heads, key_size_per_head]),
            tf.float32)
        side_values = tf.constant(
            np.random.normal(
                size=[batch_size, side_len, num_heads, value_size_per_head]),
            tf.float32)
        side_att_mask = tf.constant(
            np.random.binomial(
                n=1, p=0.8, size=[batch_size, long_len, side_len]),
            dtype=tf.int32)
        side_relative_att_ids = tf.constant(
            np.random.randint(
                relative_vocab_size, size=[batch_size, long_len, side_len]),
            dtype=tf.int32)

        layer = etc_layers.QkvRelativeLocalAttention(
            local_radius=local_radius,
            relative_vocab_size=relative_vocab_size,
            use_one_hot_lookup=use_one_hot_lookup)

        sparse_implementation_result = layer(
            queries,
            keys,
            values,
            att_mask=att_mask,
            relative_att_ids=relative_att_ids,
            side_keys=side_keys,
            side_values=side_values,
            side_att_mask=side_att_mask,
            side_relative_att_ids=side_relative_att_ids,
            att_implementation='sparse')

        full_implementation_result = layer(
            queries,
            keys,
            values,
            att_mask=att_mask,
            relative_att_ids=relative_att_ids,
            side_keys=side_keys,
            side_values=side_values,
            side_att_mask=side_att_mask,
            side_relative_att_ids=side_relative_att_ids,
            att_implementation='full')

        self.evaluate(tf.compat.v1.global_variables_initializer())
        self.assertAllClose(sparse_implementation_result,
                            full_implementation_result)
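
Both examples are methods of a TensorFlow test case and omit the module-level
scaffolding they depend on. A minimal sketch of that setup, assuming the layer
lives in the google-research `etcmodel` package (the import path, class name,
and parameter grid below are assumptions based on the identifiers used in the
snippets):

from absl.testing import parameterized
import numpy as np
import tensorflow as tf

from etcmodel import layers as etc_layers  # assumed import path


class QkvRelativeLocalAttentionTest(tf.test.TestCase,
                                    parameterized.TestCase):

    # `local_radius` and `use_one_hot_lookup` are supplied by a parameterized
    # decorator; this particular grid is illustrative.
    @parameterized.named_parameters(
        ('radius_1_one_hot', 1, True),
        ('radius_7_gather', 7, False),
    )
    def test_qkv_relative_local_attention_full_att_implementation(
            self, local_radius, use_one_hot_lookup):
        ...
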
Example #2
    def test_qkv_relative_local_attention(self,
                                          local_radius,
                                          use_one_hot_lookup=False,
                                          att_implementation='sparse'):
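        # Checks output shapes with and without the optional masks, relative
        # ids, and side inputs, and confirms the three configurations produce
        # different outputs.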
        tf.compat.v1.random.set_random_seed(1234)
        np.random.seed(1234)

        batch_size = 2
        long_len = 64
        side_len = 6
        num_heads = 5
        key_size_per_head = 2
        value_size_per_head = 3
        relative_vocab_size = 7
        # Note: block_len = local_radius + 1

        queries = tf.constant(
            np.random.normal(
                size=[batch_size, long_len, num_heads, key_size_per_head]),
            tf.float32)
        keys = tf.constant(
            np.random.normal(
                size=[batch_size, long_len, num_heads, key_size_per_head]),
            tf.float32)
        values = tf.constant(
            np.random.normal(
                size=[batch_size, long_len, num_heads, value_size_per_head]),
            tf.float32)
        att_mask = tf.constant(
            np.random.binomial(
                n=1, p=0.9, size=[batch_size, long_len, 2 * local_radius + 1]))
        relative_att_ids = tf.constant(
            np.random.randint(
                relative_vocab_size,
                size=[batch_size, long_len, 2 * local_radius + 1]))

        side_keys = tf.constant(
            np.random.normal(
                size=[batch_size, side_len, num_heads, key_size_per_head]),
            tf.float32)
        side_values = tf.constant(
            np.random.normal(
                size=[batch_size, side_len, num_heads, value_size_per_head]),
            tf.float32)
        side_att_mask = tf.constant(
            np.random.binomial(n=1,
                               p=0.9,
                               size=[batch_size, long_len, side_len]))
        side_relative_att_ids = tf.constant(
            np.random.randint(relative_vocab_size,
                              size=[batch_size, long_len, side_len]))

        layer = etc_layers.QkvRelativeLocalAttention(
            local_radius=local_radius,
            relative_vocab_size=relative_vocab_size,
            use_one_hot_lookup=use_one_hot_lookup)

        result1 = layer(queries,
                        keys,
                        values,
                        att_mask=att_mask,
                        relative_att_ids=relative_att_ids,
                        side_keys=side_keys,
                        side_values=side_values,
                        side_att_mask=side_att_mask,
                        side_relative_att_ids=side_relative_att_ids,
                        att_implementation=att_implementation)
        self.assertAllEqual(
            [batch_size, long_len, num_heads, value_size_per_head],
            result1.shape)

        result2 = layer(queries,
                        keys,
                        values,
                        att_mask=None,
                        relative_att_ids=None,
                        side_keys=side_keys,
                        side_values=side_values,
                        side_att_mask=None,
                        side_relative_att_ids=None,
                        att_implementation=att_implementation)
        self.assertAllEqual(
            [batch_size, long_len, num_heads, value_size_per_head],
            result2.shape)

        result3 = layer(queries,
                        keys,
                        values,
                        att_mask=att_mask,
                        relative_att_ids=relative_att_ids,
                        side_keys=None,
                        side_values=None,
                        side_att_mask=None,
                        side_relative_att_ids=None,
                        att_implementation=att_implementation)
        self.assertAllEqual(
            [batch_size, long_len, num_heads, value_size_per_head],
            result3.shape)

        self.evaluate(tf.compat.v1.global_variables_initializer())
        self.assertNotAllClose(result1, result2)
        self.assertNotAllClose(result2, result3)
        self.assertNotAllClose(result1, result3)
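
Outside the test harness, the layer can be exercised directly. A minimal
standalone sketch using the shapes from the tests above (the import path is an
assumption, and the all-defaults call relies on the optional mask, relative-id,
and side inputs defaulting to None, as the `result2` and `result3` calls
suggest):

import tensorflow as tf

from etcmodel import layers as etc_layers  # assumed import path

batch_size, long_len, num_heads = 2, 64, 5
key_size_per_head, value_size_per_head = 2, 3
local_radius = 7  # each position sees a 2 * 7 + 1 = 15-token local window

layer = etc_layers.QkvRelativeLocalAttention(
    local_radius=local_radius,
    relative_vocab_size=7,
    use_one_hot_lookup=False)

queries = tf.random.normal(
    [batch_size, long_len, num_heads, key_size_per_head])
keys = tf.random.normal(
    [batch_size, long_len, num_heads, key_size_per_head])
values = tf.random.normal(
    [batch_size, long_len, num_heads, value_size_per_head])

# Masks, relative ids, and side inputs are all optional here.
output = layer(queries, keys, values)
print(output.shape)  # (2, 64, 5, 3): [batch, long_len, heads, value_size]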