def test_mask_invalid_locations(self):
    """Count the infinite entries left by `_mask_invalid_locations`.

    The fixture is reshaped to (1, 8, 4), chunked with ``window_overlap=2``
    and then masked four times (twice on the full chunked tensor, twice on
    slices of it). Each masked result must contain a fixed number of
    infinite values.
    """
    batch_size, seq_length, hidden_size = 1, 8, 4

    def count_inf(tensor):
        # Number of elements of `tensor` for which tf.math.is_inf is true.
        return tf.math.reduce_sum(tf.cast(tf.math.is_inf(tensor), tf.dtypes.int32))

    chunked = TFLongformerSelfAttention._chunk(
        tf.reshape(self._get_hidden_states(), (batch_size, seq_length, hidden_size)),
        window_overlap=2,
    )
    mask = TFLongformerSelfAttention._mask_invalid_locations

    # Full chunked tensor with second argument 1 and 2.
    masked_full_1 = mask(chunked, 1)
    masked_full_2 = mask(chunked, 2)
    # Same call on a tensor trimmed along the last axis, then along the third axis.
    masked_last_axis_trimmed = mask(chunked[:, :, :, :3], 2)
    masked_third_axis_trimmed = mask(chunked[:, :, 2:, :], 2)

    self.assertTrue(count_inf(masked_full_1) == 8)
    self.assertTrue(count_inf(masked_full_2) == 24)
    self.assertTrue(count_inf(masked_last_axis_trimmed) == 24)
    self.assertTrue(count_inf(masked_third_axis_trimmed) == 12)
def test_diagonalize(self):
    """Check the output width and zero padding of `_pad_and_diagonalize`.

    The fixture is reshaped to (1, 8, 4) and chunked with
    ``window_overlap=2``. The diagonalized result must be
    ``window_overlap_size - 1`` wider than the chunked tensor on the last
    axis, with the first row padded by zeros on the right and the last row
    padded by zeros on the left.
    """
    # set seq length = 8, hidden dim = 4
    reshaped = tf.reshape(self._get_hidden_states(), (1, 8, 4))
    chunked = TFLongformerSelfAttention._chunk(reshaped, window_overlap=2)

    window_overlap_size = shape_list(chunked)[2]
    self.assertTrue(window_overlap_size == 4)

    diagonalized = TFLongformerSelfAttention._pad_and_diagonalize(chunked)

    expected_width = shape_list(chunked)[-1] + window_overlap_size - 1
    self.assertTrue(shape_list(diagonalized)[-1] == expected_width)

    zero_padding = tf.zeros((3,), dtype=tf.dtypes.float32)

    # first row => [0.4983, 2.6918, -0.0071, 1.0492, 0.0000, 0.0000, 0.0000]
    first_row = diagonalized[0, 0, 0]
    tf.debugging.assert_near(first_row[:4], chunked[0, 0, 0], rtol=1e-3)
    tf.debugging.assert_near(first_row[4:], zero_padding, rtol=1e-3)

    # last row => [0.0000, 0.0000, 0.0000, 2.0514, -1.1600, 0.5372, 0.2629]
    last_row = diagonalized[0, 0, -1]
    tf.debugging.assert_near(last_row[3:], chunked[0, 0, -1], rtol=1e-3)
    tf.debugging.assert_near(last_row[:3], zero_padding, rtol=1e-3)
def test_chunk(self):
    """Check the shape and two reference slices of `_chunk`'s output.

    The fixture is reshaped to (1, 8, 4) and chunked with
    ``window_overlap=2``. The result must have shape [1, 3, 4, 4] and match
    the hard-coded reference values along axis 1 and along axis 2.
    """
    batch_size, seq_length, hidden_size = 1, 8, 4
    reshaped = tf.reshape(self._get_hidden_states(), (batch_size, seq_length, hidden_size))
    chunked = TFLongformerSelfAttention._chunk(reshaped, window_overlap=2)

    self.assertTrue(shape_list(chunked) == [1, 3, 4, 4])

    # expected slices across chunk and seq length dim, paired with the
    # slice of the chunked tensor they are compared against
    slice_checks = (
        (chunked[0, :, 0, 0], [0.4983, -0.7584, -1.6944]),
        (chunked[0, 0, :, 0], [0.4983, -1.8348, -0.7584, 2.0514]),
    )
    for actual_slice, reference_values in slice_checks:
        expected_slice = tf.convert_to_tensor(reference_values, dtype=tf.dtypes.float32)
        tf.debugging.assert_near(actual_slice, expected_slice, rtol=1e-3)
def test_pad_and_transpose_last_two_dims(self):
    """Check the shape and contents produced by `_pad_and_transpose_last_two_dims`.

    The raw fixture is chunked with ``window_overlap=2``, padded by one
    entry at the end of the third axis, and passed through
    `_pad_and_transpose_last_two_dims`. The result must have shape
    [1, 1, 8, 5], its last row must be all zeros, and elements 24..31 of the
    flattened result must equal the last row of the chunked input.
    """
    hidden_states = self._get_hidden_states()
    # This used to be `assertTrue(shape, [1, 8, 4])`, which can never fail:
    # the second argument of assertTrue is the failure message and a
    # non-empty list is always truthy. Compare the shape for real instead.
    # The [1, 1, 8, 5] result asserted below (one chunk, one padded entry,
    # last two dims swapped) requires a (1, 4, 8) input, not (1, 8, 4).
    # NOTE(review): confirm (1, 4, 8) against `_get_hidden_states`.
    self.assertEqual(shape_list(hidden_states), [1, 4, 8])

    # pad along seq length dim
    paddings = tf.constant([[0, 0], [0, 0], [0, 1], [0, 0]], dtype=tf.dtypes.int32)

    hidden_states = TFLongformerSelfAttention._chunk(hidden_states, window_overlap=2)
    padded_hidden_states = TFLongformerSelfAttention._pad_and_transpose_last_two_dims(hidden_states, paddings)
    # assertEqual reports both shapes on failure, unlike assertTrue(a == b).
    self.assertEqual(shape_list(padded_hidden_states), [1, 1, 8, 5])

    # The last row of the result must be all zeros.
    expected_added_dim = tf.zeros((5,), dtype=tf.dtypes.float32)
    tf.debugging.assert_near(expected_added_dim, padded_hidden_states[0, 0, -1, :], rtol=1e-6)
    # The last row of the chunked input must reappear at flat offsets 24..31.
    tf.debugging.assert_near(
        hidden_states[0, 0, -1, :], tf.reshape(padded_hidden_states, (1, -1))[0, 24:32], rtol=1e-6
    )