Ejemplo n.º 1
0
 def test_mask_position(self):
   """Masking an interior index replaces exactly that token with mask_id.

   A single-row batch holding tokens 1..10 is masked at index 5 with
   mask_id=0; only the sixth token (value 6) should become 0.
   """
   tokens = tf.expand_dims(tf.range(1, 11), 0)  # shape (1, 10)
   masked = importance.mask_position(tokens, 5, mask_id=0)
   want = tf.constant([[1, 2, 3, 4, 5, 0, 7, 8, 9, 10]], dtype=tf.int32)
   tf.debugging.assert_equal(want, masked)
Ejemplo n.º 2
0
 def test_masking_last_position(self):
   """Edge case: masking the final index of the sequence.

   Index 9 is the last valid position in a length-10 row, so the trailing
   token (value 10) should be replaced by mask_id=0.
   """
   tokens = tf.expand_dims(tf.range(1, 11), 0)  # shape (1, 10)
   last_index = 9
   masked = importance.mask_position(tokens, last_index, mask_id=0)
   want = tf.constant([[1, 2, 3, 4, 5, 6, 7, 8, 9, 0]], dtype=tf.int32)
   tf.debugging.assert_equal(want, masked)
Ejemplo n.º 3
0
 def test_mask_position_too_big_index(self):
   """Out-of-bounds (too large) positions must be rejected.

   The row holds 10 tokens, so valid indices are 0..9. Index 10 is the first
   invalid value, and it is the one an off-by-one upper-bound check (`<=`
   instead of `<`) would wrongly accept. Index 11 is kept as a further-out
   case.
   """
   sentence = tf.expand_dims(tf.range(1, 11), 0)  # shape (1, 10)
   for position in (10, 11):
     # subTest reports which boundary value failed instead of stopping at
     # the first one.
     with self.subTest(position=position):
       with self.assertRaises(AssertionError):
         importance.mask_position(sentence, position, mask_id=0)
Ejemplo n.º 4
0
 def test_mask_position_mask_id_requirement(self):
   """Calling mask_position without mask_id should raise AssertionError."""
   tokens = tf.expand_dims(tf.range(1, 11), 0)
   # The other arguments are valid, so only the missing mask_id can trigger
   # the failure.
   self.assertRaises(AssertionError, importance.mask_position, tokens, 5)
Ejemplo n.º 5
0
 def test_mask_position_negative_index(self):
   """Negative positions are rejected even though Python would accept them.

   In plain Python indexing, -1 would address the last token; mask_position
   is expected to refuse it with an AssertionError instead.
   """
   tokens = tf.expand_dims(tf.range(1, 11), 0)
   self.assertRaises(
       AssertionError, importance.mask_position, tokens, -1, mask_id=0)