Esempio n. 1
0
 def _encode_line(line: str,
                  half: bool,
                  from_end: bool = False) -> Optional[Sentence]:
     """Encode and mask one raw text line, then apply the length bounds.

     The line has trailing whitespace stripped and is encoded with
     ``text_encoder``. It is masked via ``msk_sentence`` (given
     ``len(text_encoder)``, ``keep_prob``, ``mask_prob`` and ``rand_prob``).
     The result is then passed to ``check_sent_len``.

     Args:
         line: Raw input line; trailing whitespace (e.g. the newline) is
             stripped before encoding.
         half: If true, both enclosing bounds ``_min_len`` and ``_max_len``
             are halved with floor division before being applied.
         from_end: Forwarded unchanged to ``check_sent_len``.

     Returns:
         The result of ``check_sent_len``: the (possibly truncated)
         sentence, or ``None`` when it is rejected by the bounds.
     """
     divisor = 2 if half else 1
     # check_sent_len accepts None for a bound (meaning "no bound"); keep
     # that meaning instead of failing on ``None // divisor``.
     min_len = None if _min_len is None else _min_len // divisor
     max_len = None if _max_len is None else _max_len // divisor
     masked = msk_sentence(text_encoder.encode(line.rstrip()),
                           len(text_encoder), keep_prob, mask_prob,
                           rand_prob)
     return check_sent_len(masked, min_len, max_len, from_end=from_end)
Esempio n. 2
0
 def test_check_sent_len(self):
     """check_sent_len enforces min/max length and re-indexes sentence tasks."""
     orig_length = 10
     class_target = 2
     sc_ok_index = 5
     original_sent = self.generate_sentence(orig_length)
     # 'sc' targets index 0, inside the prefix removed by the truncation
     # below, and is expected to be dropped; 'sc_ok' targets index 5 and is
     # expected to survive with a shifted index.
     original_sent.sentence_classification['sc'] = SentenceTaskData(
         class_target, 0)
     original_sent.sentence_classification['sc_ok'] = SentenceTaskData(
         class_target + 1, sc_ok_index)

     # A length-10 sentence passes min_len=10 but is rejected (None) by 11.
     assert check_sent_len(original_sent, min_len=10,
                           max_len=None) is not None
     assert check_sent_len(original_sent, min_len=11, max_len=None) is None

     max_len = 7
     dropped = orig_length - max_len
     res = check_sent_len(original_sent,
                          min_len=None,
                          max_len=max_len,
                          from_end=False)
     # Fail clearly here rather than with an AttributeError on None below.
     assert res is not None
     lm = res.token_classification['lm']
     assert len(res.tokens) == len(res.padding_mask) == len(
         lm.target) == len(lm.target_mask) == max_len
     # With from_end=False the first ``dropped`` tokens are removed.
     assert res.tokens[0] == original_sent.tokens[dropped]
     assert set(res.sentence_classification.keys()) == {'sc_ok'}
     assert res.sentence_classification['sc_ok'].target == class_target + 1
     assert (res.sentence_classification['sc_ok'].target_index
             == sc_ok_index - dropped)
Esempio n. 3
0
 def _encode_line(line: str) -> Optional[Sentence]:
     """Encode, mask and length-check a single raw text line.

     Trailing whitespace is stripped before encoding with ``text_encoder``.
     The ids are masked via ``msk_sentence`` and the result is returned
     through ``check_sent_len`` with the enclosing ``_min_len`` /
     ``_max_len`` bounds.
     """
     token_ids = text_encoder.encode(line.rstrip())
     encoder_len = len(text_encoder)
     masked = msk_sentence(token_ids, encoder_len,
                           keep_prob, mask_prob, rand_prob)
     return check_sent_len(masked, _min_len, _max_len)