# Example #1
# 0
 def test_create_masked_lm_predictions(self):
   """Subword-level masking (no whole-word constraint) fills the full budget.

   With masked_lm_prob=1.0 every token is a candidate, so exactly
   max_predictions_per_seq (3) positions must be masked on every run.
   """
   tokens = ["[CLS]", "a", "##a", "b", "##b", "c", "##c", "[SEP]"]
   rng = random.Random(123)
   for _ in range(5):
     out_tokens, positions, labels = cpd.create_masked_lm_predictions(
         tokens=tokens,
         masked_lm_prob=1.0,
         max_predictions_per_seq=3,
         vocab_words=_VOCAB_WORDS,
         rng=rng,
         do_whole_word_mask=False,
         max_ngram_size=None)
     # Individual subwords are maskable, so the prediction cap is always hit.
     self.assertEqual(3, len(positions))
     self.assertEqual(3, len(labels))
     self.assertTokens(tokens, out_tokens, positions, labels)
 def test_create_masked_lm_predictions_ngram(self):
   """N-gram masking over a long sequence still reaches the prediction cap.

   A 512-token sequence with max_ngram_size=3 and masked_lm_prob=1.0 gives
   plenty of candidates, so all 76 prediction slots are used each run.
   """
   tokens = ["[CLS]"] + ["tok{}".format(i) for i in range(512)] + ["[SEP]"]
   rng = random.Random(345)
   for _ in range(5):
     out_tokens, positions, labels = cpd.create_masked_lm_predictions(
         tokens=tokens,
         masked_lm_prob=1.0,
         max_predictions_per_seq=76,
         vocab_words=_VOCAB_WORDS,
         rng=rng,
         do_whole_word_mask=True,
         max_ngram_size=3)
     self.assertEqual(76, len(positions))
     self.assertEqual(76, len(labels))
     self.assertTokens(tokens, out_tokens, positions, labels)
 def test_create_masked_lm_predictions_whole_word(self):
   """Whole-word masking never splits a word, so the cap can be undershot."""
   tokens = ["[CLS]", "a", "##a", "b", "##b", "c", "##c", "[SEP]"]
   rng = random.Random(345)
   for _ in range(5):
     out_tokens, positions, labels = cpd.create_masked_lm_predictions(
         tokens=tokens,
         masked_lm_prob=1.0,
         max_predictions_per_seq=3,
         vocab_words=_VOCAB_WORDS,
         rng=rng,
         do_whole_word_mask=True,
         max_ngram_size=None)
     # Every word here is two subwords, so exactly three masked tokens is
     # impossible without splitting a word; one whole word (2 tokens) is
     # taken instead.
     self.assertEqual(2, len(positions))
     self.assertEqual(2, len(labels))
     self.assertTokens(tokens, out_tokens, positions, labels)
     # The masked span must be a complete word, never a partial one.
     self.assertIn(labels, [["a", "##a"], ["b", "##b"], ["c", "##c"]])