# Example 1
    def test_permutation_input_uni_mask(self):
        """Checks the masks when input, permutation and causal masks combine.

        With unidirectional ("uni") attention, a padding mask on the last
        position and a permutation mask that blocks every position from
        itself, the query mask should be strictly lower-triangular and the
        content mask lower-triangular including the diagonal.
        """
        seq_length = 4
        batch_size = 1
        memory_length = 0

        # Last position is padding.
        padding = np.array([[1, 1, 1, 0]])
        # Every position is masked from attending to itself.
        perm = np.array([[
            [0, 1, 1, 1],
            [1, 0, 1, 1],
            [1, 1, 0, 1],
            [1, 1, 1, 0],
        ]])

        ones = np.ones((seq_length, seq_length))
        # Strictly lower triangular: no self-attention in the query stream.
        want_query = np.tril(ones, k=-1)[None, None, :, :]
        # Lower triangular with diagonal: content stream sees itself.
        want_content = np.tril(ones, k=0)[None, None, :, :]

        query_mask, content_mask = xlnet_base._compute_attention_mask(
            input_mask=padding,
            permutation_mask=perm,
            attention_type="uni",
            seq_length=seq_length,
            memory_length=memory_length,
            batch_size=batch_size,
            dtype=tf.float32)

        self.assertAllClose(query_mask, want_query)
        self.assertAllClose(content_mask, want_content)
# Example 2
    def test_permutation_mask_no_input_mask(self):
        """Checks masks when only a permutation mask is provided.

        With bidirectional attention and no padding mask, the query mask is
        just the permutation mask broadcast to 4D, and the content mask is
        the same with the diagonal forced to 1 (content stream may attend
        to its own position).
        """
        seq_length = 2
        batch_size = 1
        memory_length = 0

        perm = np.array([
            [[1, 0], [1, 0]],
        ])

        # Broadcast the permutation mask over the (singleton) head axis.
        want_query = perm[:, None, :, :]
        # Same mask, but position 1 regains visibility of itself.
        want_content = np.array([[[[1, 0], [1, 1]]]])

        query_mask, content_mask = xlnet_base._compute_attention_mask(
            input_mask=None,
            permutation_mask=perm,
            attention_type="bi",
            seq_length=seq_length,
            memory_length=memory_length,
            batch_size=batch_size,
            dtype=tf.float32)

        self.assertAllClose(query_mask, want_query)
        self.assertAllClose(content_mask, want_content)
# Example 3
    def test_input_mask_no_permutation(self):
        """Checks masks when only an input mask is provided.

        When just one of the input / permutation masks is given and the
        attention type is bidirectional, the query mask is a broadcast of
        the provided mask. The content mask equals the query mask with the
        diagonal set to 1, since the content stream is always allowed to
        attend to its own position (visible in the expected values below:
        positions 2 and 3 are padded out everywhere except on the
        diagonal).
        """
        seq_length = 4
        batch_size = 1
        memory_length = 0

        # Positions 2 and 3 are padding.
        padding = np.array([[1, 1, 0, 0]])

        # Query mask: the padding mask broadcast to 4D.
        want_query = padding[None, None, :, :]
        # Content mask: same rows, diagonal forced to 1.
        want_content = np.array([[[[1, 1, 0, 0], [1, 1, 0, 0],
                                   [1, 1, 1, 0], [1, 1, 0, 1]]]])

        query_mask, content_mask = xlnet_base._compute_attention_mask(
            input_mask=padding,
            permutation_mask=None,
            attention_type="bi",
            seq_length=seq_length,
            memory_length=memory_length,
            batch_size=batch_size,
            dtype=tf.float32)

        self.assertAllClose(query_mask, want_query)
        self.assertAllClose(content_mask, want_content)
# Example 4
    def test_compute_attention_mask_smoke(self, use_input_mask,
                                          use_permutation_mask, attention_type,
                                          memory_length):
        """Smoke-tests mask computation across parameterized configurations.

        Only checks the output shape; exact values are covered by the
        dedicated tests above. When neither mask is supplied the function
        returns None masks, so the shape check is skipped in that case.
        """
        batch_size, seq_length = 2, 8

        input_mask = (
            tf.zeros(shape=(batch_size, seq_length))
            if use_input_mask else None)
        permutation_mask = (
            tf.zeros(shape=(batch_size, seq_length, seq_length))
            if use_permutation_mask else None)

        _, content_mask = xlnet_base._compute_attention_mask(
            input_mask=input_mask,
            permutation_mask=permutation_mask,
            attention_type=attention_type,
            seq_length=seq_length,
            memory_length=memory_length,
            batch_size=batch_size,
            dtype=tf.float32)

        if use_input_mask or use_permutation_mask:
            # Key axis covers the memory positions plus the sequence itself.
            self.assertEqual(
                content_mask.shape,
                (batch_size, 1, seq_length, seq_length + memory_length))
# Example 5
 def test_no_input_masks(self):
     """Both returned masks are None when no masks are supplied."""
     masks = xlnet_base._compute_attention_mask(
         input_mask=None,
         permutation_mask=None,
         attention_type="uni",
         seq_length=8,
         memory_length=2,
         batch_size=2,
         dtype=tf.float32)
     # Returns a (query_mask, content_mask) pair; both should be None.
     for mask in masks:
         self.assertIsNone(mask)