Example #1
    def test_residual_block_with_relative_attention(self,
                                                    use_pre_activation_order):
        np.random.seed(1234)

        batch_size = 2
        seq_len = 4
        hidden_size = 10
        inputs = tf.constant(
            np.random.normal(size=[batch_size, seq_len, hidden_size]),
            tf.float32)

        att_mask = tf.stack([
            # Force each element in the first example to only attend to itself.
            tf.eye(seq_len, dtype=tf.int32),
            # The second example can attend everywhere.
            tf.ones([seq_len, seq_len], dtype=tf.int32)
        ])

        inner_layer = etc_layers.RelativeAttention(
            hidden_size=hidden_size,
            num_heads=2,
            relative_vocab_size=2,
            initializer=tf.keras.initializers.Identity())

        residual_block = etc_layers.ResidualBlock(
            inner_layer=inner_layer,
            normalization_layer=tf.keras.layers.Lambda(lambda x: x),
            dropout_probability=0.0,
            use_pre_activation_order=use_pre_activation_order)

        relative_att_ids1 = tf.zeros([batch_size, seq_len, seq_len],
                                     dtype=tf.int32)
        result1 = residual_block(inputs,
                                 att_mask=att_mask,
                                 relative_att_ids=relative_att_ids1)

        self.evaluate(tf.compat.v1.global_variables_initializer())
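        # Each token in the first example attends only to itself, so its softmax
        # weight is 1, and the identity-initialized projections return its value
        # unchanged; the residual sum is therefore exactly twice the input.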
        self.assertAllClose(2 * inputs[0], result1[0])

        relative_att_ids2 = tf.tile([[[0, 1, 0, 1]]], [batch_size, seq_len, 1])
        result2 = residual_block(inputs,
                                 att_mask=att_mask,
                                 relative_att_ids=relative_att_ids2)

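        # The first example still attends to only one position per token, so its
        # softmax weights (and hence its output) are unaffected by the relative
        # ids. The second example attends everywhere, so changing the relative
        # ids changes its output.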
        self.assertAllClose(result1[0], result2[0])
        self.assertNotAllClose(result1[1], result2[1])
Example #2
    def test_relative_attention_self_attention(self, use_one_hot_lookup):
        tf.compat.v1.random.set_random_seed(1234)
        np.random.seed(1234)

        batch_size = 3
        seq_len = 16
        num_heads = 5
        input_hidden_size = 11
        output_hidden_size = 12
        total_key_size = 10
        total_value_size = 15
        relative_vocab_size = 21

        inputs = tf.constant(
            np.random.normal(size=[batch_size, seq_len, input_hidden_size]),
            tf.float32)
        att_mask = tf.constant(
            np.random.binomial(n=1, p=0.9, size=[batch_size, seq_len,
                                                 seq_len]))
        relative_att_ids = tf.constant(
            np.random.randint(relative_vocab_size,
                              size=[batch_size, seq_len, seq_len]))

        layer = etc_layers.RelativeAttention(
            hidden_size=output_hidden_size,
            num_heads=num_heads,
            total_key_size=total_key_size,
            total_value_size=total_value_size,
            relative_vocab_size=relative_vocab_size,
            use_one_hot_lookup=use_one_hot_lookup)

        result1 = layer(inputs,
                        att_mask=att_mask,
                        relative_att_ids=relative_att_ids)
        self.assertAllEqual([batch_size, seq_len, output_hidden_size],
                            result1.shape)

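        # Passing the same tensor explicitly as `from_seq` and `to_seq` should
        # reproduce the single-input self-attention result above.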
        result2 = layer(from_seq=inputs,
                        to_seq=inputs,
                        att_mask=att_mask,
                        relative_att_ids=relative_att_ids)
        self.evaluate(tf.compat.v1.global_variables_initializer())
        self.assertAllEqual(result1, result2)
Example #3
    def test_relative_attention(self, use_one_hot_lookup):
        tf.compat.v1.random.set_random_seed(1234)
        np.random.seed(1234)

        batch_size = 3
        from_seq_len = 16
        to_seq_len = 17
        num_heads = 5
        from_hidden_size = 11
        to_hidden_size = 12
        output_hidden_size = 13
        total_key_size = 10
        total_value_size = 15
        relative_vocab_size = 21

        from_seq = tf.random.normal(
            [batch_size, from_seq_len, from_hidden_size])
        to_seq = tf.random.normal([batch_size, to_seq_len, to_hidden_size])
        att_mask = tf.constant(
            np.random.binomial(n=1,
                               p=0.9,
                               size=[batch_size, from_seq_len, to_seq_len]))
        relative_att_ids = tf.random.uniform(
            [batch_size, from_seq_len, to_seq_len],
            maxval=relative_vocab_size,
            dtype=tf.int32)

        layer = etc_layers.RelativeAttention(
            hidden_size=output_hidden_size,
            num_heads=num_heads,
            total_key_size=total_key_size,
            total_value_size=total_value_size,
            relative_vocab_size=relative_vocab_size,
            use_one_hot_lookup=use_one_hot_lookup)

        result = layer(from_seq=from_seq,
                       to_seq=to_seq,
                       att_mask=att_mask,
                       relative_att_ids=relative_att_ids)
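        # The output keeps the `from_seq` length, and its depth is the configured
        # `hidden_size` (`output_hidden_size` here) rather than `to_hidden_size`.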
        self.assertAllEqual([batch_size, from_seq_len, output_hidden_size],
                            result.shape)
Example #4
    def test_fused_global_local_attention_special_case_equivalence(self):
        # To test for correctness, we make sure the output is equivalent to
        # standard attention in the special case where `local_radius` covers the
        # entire long sequence length and projection weights are shared.
        # For simplicity, we don't use attention masks or relative attention ids
        # in this test.

        tf.compat.v1.random.set_random_seed(1234)
        np.random.seed(1234)

        batch_size = 3
        long_seq_len = 12
        global_seq_len = 6
        hidden_size = 10
        num_heads = 5
        local_radius = 15  # Must be >= `long_seq_len - 1` to remove sparsity.
        # relative_vocab_size = 9

        long_input = tf.constant(
            np.random.normal(size=[batch_size, long_seq_len, hidden_size]))
        global_input = tf.constant(
            np.random.normal(size=[batch_size, global_seq_len, hidden_size]))

        fused_att_layer = etc_layers.FusedGlobalLocalAttention(
            long_hidden_size=hidden_size,
            global_hidden_size=hidden_size,
            num_heads=num_heads,
            local_radius=local_radius,
            share_qkv_projections=True,
            share_att_output_projection=True)

        long_output, global_output = fused_att_layer(
            long_input, global_input, att_implementation='sparse')

        # [batch_size, long_seq_len + global_seq_len, hidden_size]
        fused_output = tf.concat([long_output, global_output], axis=1)

        # Create concatenated input for standard attention.

        # [batch_size, long_seq_len + global_seq_len, hidden_size]
        concat_input = tf.concat([long_input, global_input], axis=1)

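        # Reusing the fused layer's shared projections makes standard attention
        # over the concatenated sequence compute the same function as the fused
        # layer with its full-coverage `local_radius`.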
        standard_att_layer = etc_layers.RelativeAttention(
            hidden_size=hidden_size,
            num_heads=num_heads,
            query_projection=fused_att_layer.long_query_projection,
            key_projection=fused_att_layer.l2l_key_projection,
            value_projection=fused_att_layer.l2l_value_projection,
            output_projection=fused_att_layer.long_output_projection)

        expected_output = standard_att_layer(concat_input)

        self.evaluate(tf.compat.v1.global_variables_initializer())
        self.assertAllClose(expected_output, fused_output)

        # Make sure 'full' att_implementation gives the same output.
        long_output_full_att, global_output_full_att = fused_att_layer(
            long_input, global_input, att_implementation='full')

        self.assertAllClose(long_output, long_output_full_att)
        self.assertAllClose(global_output, global_output_full_att)
Example #5
    def test_relative_attention_shared_sublayers(self, use_one_hot_lookup):
        tf.compat.v1.random.set_random_seed(1234)
        np.random.seed(1234)

        batch_size = 3
        from_seq_len = 16
        to_seq_len = 17
        num_heads = 5
        from_hidden_size = 11
        to_hidden_size = 12
        output_hidden_size = 13
        total_key_size = 10
        total_value_size = 15
        relative_vocab_size = 9

        from_seq = tf.constant(
            np.random.random(
                size=[batch_size, from_seq_len, from_hidden_size]))
        to_seq = tf.constant(
            np.random.random(size=[batch_size, to_seq_len, to_hidden_size]))
        att_mask = tf.constant(
            np.random.binomial(n=1,
                               p=0.9,
                               size=[batch_size, from_seq_len, to_seq_len]))

        layer = etc_layers.RelativeAttention(
            hidden_size=output_hidden_size,
            num_heads=num_heads,
            total_key_size=total_key_size,
            total_value_size=total_value_size,
            relative_vocab_size=relative_vocab_size,
            use_one_hot_lookup=use_one_hot_lookup)

        sharing_layer = etc_layers.RelativeAttention(
            hidden_size=output_hidden_size,
            num_heads=num_heads,
            total_key_size=total_key_size,
            total_value_size=total_value_size,
            query_projection=layer.query_projection,
            key_projection=layer.key_projection,
            value_projection=layer.value_projection,
            qkv_relative_attention=layer.qkv_relative_attention,
            output_projection=layer.output_projection)

        different_layer = etc_layers.RelativeAttention(
            hidden_size=output_hidden_size,
            num_heads=num_heads,
            total_key_size=total_key_size,
            total_value_size=total_value_size,
            relative_vocab_size=relative_vocab_size,
            use_one_hot_lookup=use_one_hot_lookup)

        result1 = layer(from_seq=from_seq,
                        to_seq=to_seq,
                        att_mask=att_mask,
                        relative_att_ids=None)

        result2 = sharing_layer(from_seq=from_seq,
                                to_seq=to_seq,
                                att_mask=att_mask,
                                relative_att_ids=None)

        result3 = different_layer(from_seq=from_seq,
                                  to_seq=to_seq,
                                  att_mask=att_mask,
                                  relative_att_ids=None)

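        # `sharing_layer` reuses every sublayer of `layer`, so the two must agree
        # exactly; `different_layer` has its own freshly initialized weights, so
        # its output should differ.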
        self.evaluate(tf.compat.v1.global_variables_initializer())
        self.assertAllEqual(result1, result2)
        self.assertNotAllClose(result1, result3)
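
For reference, the methods above are shown without their enclosing module. A minimal sketch of the scaffolding they appear to assume is given below; the import path for `etc_layers`, the class name, and the way the boolean arguments (`use_pre_activation_order`, `use_one_hot_lookup`) are supplied are assumptions, not taken from the listing above.

import numpy as np
import tensorflow as tf
from absl.testing import parameterized

# Assumed import path for the ETC layers package; adjust to match your checkout.
from etcmodel import layers as etc_layers


class RelativeAttentionLayersTest(tf.test.TestCase, parameterized.TestCase):
    """Hypothetical test class that the example methods would live in."""

    # Boolean arguments such as `use_one_hot_lookup` would presumably be supplied
    # by a decorator on each test method, for example:
    #
    #   @parameterized.parameters(True, False)
    #   def test_relative_attention(self, use_one_hot_lookup):
    #       ...
    #
    # `self.evaluate`, `assertAllClose`, `assertNotAllClose`, and `assertAllEqual`
    # are inherited from `tf.test.TestCase`.


if __name__ == '__main__':
    tf.test.main()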