# Example 1
    def test_mask_invalid_locations(self):
        """Check that _mask_invalid_locations writes -inf into the expected number of slots."""
        # Shape the fixture as (batch=1, seq_len=8, hidden=4), then chunk it.
        hidden_states = tf.reshape(self._get_hidden_states(), (1, 8, 4))
        chunked = TFLongformerSelfAttention._chunk(hidden_states, window_overlap=2)

        def inf_count(tensor):
            # Total number of -inf entries produced by the masking.
            return tf.math.reduce_sum(tf.cast(tf.math.is_inf(tensor), tf.dtypes.int32))

        masked_overlap_1 = TFLongformerSelfAttention._mask_invalid_locations(chunked, 1)
        masked_overlap_2 = TFLongformerSelfAttention._mask_invalid_locations(chunked, 2)
        masked_narrow_last = TFLongformerSelfAttention._mask_invalid_locations(chunked[:, :, :, :3], 2)
        masked_short_seq = TFLongformerSelfAttention._mask_invalid_locations(chunked[:, :, 2:, :], 2)

        self.assertTrue(inf_count(masked_overlap_1) == 8)
        self.assertTrue(inf_count(masked_overlap_2) == 24)
        self.assertTrue(inf_count(masked_narrow_last) == 24)
        self.assertTrue(inf_count(masked_short_seq) == 12)
# Example 2
    def test_diagonalize(self):
        """_pad_and_diagonalize shifts each row and zero-pads it on the opposite side."""
        # seq length = 8, hidden dim = 4
        hidden_states = tf.reshape(self._get_hidden_states(), (1, 8, 4))
        chunked_hidden_states = TFLongformerSelfAttention._chunk(hidden_states, window_overlap=2)
        window_overlap_size = shape_list(chunked_hidden_states)[2]
        self.assertTrue(window_overlap_size == 4)

        padded_hidden_states = TFLongformerSelfAttention._pad_and_diagonalize(chunked_hidden_states)

        # Diagonalizing widens the last dim by (window_overlap_size - 1).
        expected_width = shape_list(chunked_hidden_states)[-1] + window_overlap_size - 1
        self.assertTrue(shape_list(padded_hidden_states)[-1] == expected_width)

        zero_pad = tf.zeros((3,), dtype=tf.dtypes.float32)

        # first row => [0.4983,  2.6918, -0.0071,  1.0492, 0.0000,  0.0000,  0.0000]
        tf.debugging.assert_near(padded_hidden_states[0, 0, 0, :4], chunked_hidden_states[0, 0, 0], rtol=1e-3)
        tf.debugging.assert_near(padded_hidden_states[0, 0, 0, 4:], zero_pad, rtol=1e-3)

        # last row => [0.0000,  0.0000,  0.0000, 2.0514, -1.1600,  0.5372,  0.2629]
        tf.debugging.assert_near(padded_hidden_states[0, 0, -1, 3:], chunked_hidden_states[0, 0, -1], rtol=1e-3)
        tf.debugging.assert_near(padded_hidden_states[0, 0, -1, :3], zero_pad, rtol=1e-3)
# Example 3
    def test_chunk(self):
        """_chunk splits a (1, 8, 4) sequence into 3 overlapping chunks of size 4."""
        hidden_states = tf.reshape(self._get_hidden_states(), (1, 8, 4))  # (batch, seq_len, hidden)

        chunked = TFLongformerSelfAttention._chunk(hidden_states, window_overlap=2)

        # 8 positions with overlap 2 -> 3 chunks, each 2*overlap long.
        self.assertTrue(shape_list(chunked) == [1, 3, 4, 4])

        # expected slices across chunk and seq length dim
        along_seq_length = tf.convert_to_tensor([0.4983, -0.7584, -1.6944], dtype=tf.dtypes.float32)
        along_chunk = tf.convert_to_tensor([0.4983, -1.8348, -0.7584, 2.0514], dtype=tf.dtypes.float32)

        tf.debugging.assert_near(chunked[0, :, 0, 0], along_seq_length, rtol=1e-3)
        tf.debugging.assert_near(chunked[0, 0, :, 0], along_chunk, rtol=1e-3)
# Example 4
    def test_pad_and_transpose_last_two_dims(self):
        """_pad_and_transpose_last_two_dims pads the chunk dim and swaps the last two dims."""
        hidden_states = self._get_hidden_states()
        # BUG FIX: the original called `self.assertTrue(shape_list(...), [1, 8, 4])` —
        # the second argument of assertTrue is only the failure *message*, so the
        # shape was never actually checked. The expected shape was also wrong: the
        # downstream assertion of [1, 1, 8, 5] only holds for a (1, 4, 8) input
        # (chunk(w=2) -> (1, 1, 4, 8), pad -> (1, 1, 5, 8), transpose -> (1, 1, 8, 5)).
        self.assertEqual(shape_list(hidden_states), [1, 4, 8])

        # pad one position along the chunk-length dim (axis 2), nothing elsewhere
        paddings = tf.constant([[0, 0], [0, 0], [0, 1], [0, 0]], dtype=tf.dtypes.int32)

        hidden_states = TFLongformerSelfAttention._chunk(hidden_states, window_overlap=2)
        padded_hidden_states = TFLongformerSelfAttention._pad_and_transpose_last_two_dims(hidden_states, paddings)
        self.assertTrue(shape_list(padded_hidden_states) == [1, 1, 8, 5])

        # after the transpose, the padded column shows up as an all-zero last row
        expected_added_dim = tf.zeros((5,), dtype=tf.dtypes.float32)
        tf.debugging.assert_near(expected_added_dim, padded_hidden_states[0, 0, -1, :], rtol=1e-6)
        # the last original row survives the pad+transpose at flat offsets 24:32
        tf.debugging.assert_near(
            hidden_states[0, 0, -1, :], tf.reshape(padded_hidden_states, (1, -1))[0, 24:32], rtol=1e-6
        )