Example #1
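These snippets are test methods from a `tf.test.TestCase` subclass exercising `etc_layers.FusedGlobalLocalAttention`. They assume roughly the following imports (a sketch; the exact package path for `etc_layers` depends on how the ETC code is laid out):

import numpy as np
import tensorflow as tf

from etcmodel import layers as etc_layers  # assumed import path
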
    def test_fused_global_local_attention_shared_sublayers(
            self,
            share_kv_projections=False,
            share_qkv_projections=False,
            share_att_output_projection=False):

        hidden_size = 10

        layer = etc_layers.FusedGlobalLocalAttention(
            long_hidden_size=hidden_size,
            global_hidden_size=hidden_size,
            num_heads=5,
            local_radius=7,
            relative_vocab_size=9,
            share_kv_projections=share_kv_projections,
            share_qkv_projections=share_qkv_projections,
            share_att_output_projection=share_att_output_projection)
        # Run layer to make sure all variables are built.
        layer(long_input=tf.ones([1, 1, hidden_size]),
              global_input=tf.ones([1, 1, hidden_size]))

        if share_qkv_projections:
            self.assertIs(layer.long_query_projection,
                          layer.global_query_projection)
            self.assert_all_identical(layer.l2l_key_projection,
                                      layer.l2g_key_projection,
                                      layer.g2g_key_projection,
                                      layer.g2l_key_projection)
            self.assert_all_identical(layer.l2l_value_projection,
                                      layer.l2g_value_projection,
                                      layer.g2g_value_projection,
                                      layer.g2l_value_projection)
        elif share_kv_projections:
            self.assertIsNot(layer.long_query_projection,
                             layer.global_query_projection)
            self.assertIs(layer.l2l_key_projection, layer.l2g_key_projection)
            self.assertIs(layer.g2g_key_projection, layer.g2l_key_projection)
            self.assertIsNot(layer.l2l_key_projection,
                             layer.g2g_key_projection)
            self.assertIs(layer.l2l_value_projection,
                          layer.l2g_value_projection)
            self.assertIs(layer.g2g_value_projection,
                          layer.g2l_value_projection)
            self.assertIsNot(layer.l2l_value_projection,
                             layer.g2g_value_projection)
        else:
            self.assertIsNot(layer.long_query_projection,
                             layer.global_query_projection)
            self.assertIsNot(layer.l2l_key_projection,
                             layer.l2g_key_projection)
            self.assertIsNot(layer.l2l_key_projection,
                             layer.g2g_key_projection)
            self.assertIsNot(layer.l2l_value_projection,
                             layer.l2g_value_projection)
            self.assertIsNot(layer.l2l_value_projection,
                             layer.g2g_value_projection)

        self.assertIsNot(layer.long_query_projection, layer.l2l_key_projection)
        self.assertIsNot(layer.long_query_projection,
                         layer.l2l_value_projection)
        self.assertIsNot(layer.l2l_key_projection, layer.l2l_value_projection)

        if share_att_output_projection:
            self.assertIs(layer.long_output_projection,
                          layer.global_output_projection)
        else:
            self.assertIsNot(layer.long_output_projection,
                             layer.global_output_projection)
Example #2
    def test_fused_global_local_attention_custom_total_att_size(self):
        tf.compat.v1.random.set_random_seed(1234)
        np.random.seed(1234)

        batch_size = 3
        long_seq_len = 12
        global_seq_len = 6
        hidden_size = 11
        num_heads = 5
        local_radius = 2
        total_att_size = 10
        relative_vocab_size = 9

        long_input = tf.constant(
            np.random.normal(size=[batch_size, long_seq_len, hidden_size]))
        global_input = tf.constant(
            np.random.normal(size=[batch_size, global_seq_len, hidden_size]))
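        # Attention masks below use 1 = may attend, 0 = masked. The long-to-long
        # mask has the banded local shape
        # [batch_size, long_seq_len, 2 * local_radius + 1]; the other masks are
        # dense [batch_size, from_seq_len, to_seq_len].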
        l2l_att_mask = tf.constant(
            np.random.binomial(
                n=1,
                p=0.9,
                size=[batch_size, long_seq_len, 2 * local_radius + 1]))
        g2g_att_mask = tf.constant(
            np.random.binomial(
                n=1, p=0.9, size=[batch_size, global_seq_len, global_seq_len]))
        l2g_att_mask = tf.constant(
            np.random.binomial(n=1,
                               p=0.9,
                               size=[batch_size, long_seq_len,
                                     global_seq_len]))
        g2l_att_mask = tf.constant(
            np.random.binomial(n=1,
                               p=0.9,
                               size=[batch_size, global_seq_len,
                                     long_seq_len]))
        l2l_relative_att_ids = tf.constant(
            np.random.randint(
                relative_vocab_size,
                size=[batch_size, long_seq_len, 2 * local_radius + 1]))
        g2g_relative_att_ids = tf.constant(
            np.random.randint(
                relative_vocab_size,
                size=[batch_size, global_seq_len, global_seq_len]))
        l2g_relative_att_ids = tf.constant(
            np.random.randint(relative_vocab_size,
                              size=[batch_size, long_seq_len, global_seq_len]))
        g2l_relative_att_ids = tf.constant(
            np.random.randint(relative_vocab_size,
                              size=[batch_size, global_seq_len, long_seq_len]))

        fused_att_layer = etc_layers.FusedGlobalLocalAttention(
            long_hidden_size=hidden_size,
            global_hidden_size=hidden_size,
            num_heads=num_heads,
            local_radius=local_radius,
            long_total_att_size=total_att_size,
            global_total_att_size=total_att_size,
            relative_vocab_size=relative_vocab_size,
            share_qkv_projections=True,
            share_att_output_projection=True)

        long_output, global_output = fused_att_layer(
            long_input,
            global_input,
            l2l_att_mask=l2l_att_mask,
            g2g_att_mask=g2g_att_mask,
            l2g_att_mask=l2g_att_mask,
            g2l_att_mask=g2l_att_mask,
            l2l_relative_att_ids=l2l_relative_att_ids,
            g2g_relative_att_ids=g2g_relative_att_ids,
            l2g_relative_att_ids=l2g_relative_att_ids,
            g2l_relative_att_ids=g2l_relative_att_ids)

        self.evaluate(tf.compat.v1.global_variables_initializer())

        self.assertAllEqual([batch_size, long_seq_len, hidden_size],
                            long_output.shape)
        self.assertAllEqual([batch_size, global_seq_len, hidden_size],
                            global_output.shape)
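
The same layer can also be exercised outside the test harness. A minimal sketch under TF2 eager execution, reusing the hyperparameters above and omitting the masks and relative attention ids (Example #3 shows they are optional):

import tensorflow as tf

layer = etc_layers.FusedGlobalLocalAttention(
    long_hidden_size=11,
    global_hidden_size=11,
    num_heads=5,
    local_radius=2,
    long_total_att_size=10,
    global_total_att_size=10,
    relative_vocab_size=9,
    share_qkv_projections=True,
    share_att_output_projection=True)

long_out, global_out = layer(
    tf.random.normal([3, 12, 11]),  # [batch_size, long_seq_len, hidden_size]
    tf.random.normal([3, 6, 11]))   # [batch_size, global_seq_len, hidden_size]
print(long_out.shape, global_out.shape)  # (3, 12, 11) (3, 6, 11)
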
Example #3
    def test_fused_global_local_attention_special_case_equivalence(self):
        # To test for correctness, we make sure the output is equivalent to
        # standard attention in the special case where `local_radius` covers the
        # entire long sequence length and projection weights are shared.
        # For simplicity, we don't use attention masks or relative attention ids
        # in this test.

        tf.compat.v1.random.set_random_seed(1234)
        np.random.seed(1234)

        batch_size = 3
        long_seq_len = 12
        global_seq_len = 6
        hidden_size = 10
        num_heads = 5
        local_radius = 15  # Must be >= `long_seq_len - 1` to remove sparsity.
        # relative_vocab_size = 9

        long_input = tf.constant(
            np.random.normal(size=[batch_size, long_seq_len, hidden_size]))
        global_input = tf.constant(
            np.random.normal(size=[batch_size, global_seq_len, hidden_size]))

        fused_att_layer = etc_layers.FusedGlobalLocalAttention(
            long_hidden_size=hidden_size,
            global_hidden_size=hidden_size,
            num_heads=num_heads,
            local_radius=local_radius,
            share_qkv_projections=True,
            share_att_output_projection=True)

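        # 'sparse' presumably runs the banded local-attention implementation;
        # its output is compared against the 'full' implementation below.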
        long_output, global_output = fused_att_layer(
            long_input, global_input, att_implementation='sparse')

        # [batch_size, long_seq_len + global_seq_len, hidden_size]
        fused_output = tf.concat([long_output, global_output], axis=1)

        # Create concatenated input for standard attention.

        # [batch_size, long_seq_len + global_seq_len, hidden_size]
        concat_input = tf.concat([long_input, global_input], axis=1)

        standard_att_layer = etc_layers.RelativeAttention(
            hidden_size=hidden_size,
            num_heads=num_heads,
            query_projection=fused_att_layer.long_query_projection,
            key_projection=fused_att_layer.l2l_key_projection,
            value_projection=fused_att_layer.l2l_value_projection,
            output_projection=fused_att_layer.long_output_projection)

        expected_output = standard_att_layer(concat_input)

        self.evaluate(tf.compat.v1.global_variables_initializer())
        self.assertAllClose(expected_output, fused_output)

        # Make sure 'full' att_implementation gives the same output.
        long_output_full_att, global_output_full_att = fused_att_layer(
            long_input, global_input, att_implementation='full')

        self.assertAllClose(long_output, long_output_full_att)
        self.assertAllClose(global_output, global_output_full_att)