示例#1
0
 def _yield_sentence(sent1: Sentence,
                     sent2: Optional[Sentence] = None) -> Sentence:
     """Join one or two sentences into a single delimited ``Sentence``.

     The output token layout is ``[bos] + left + [del] + right + [eos]``.

     * With ``sent2`` given, ``left`` is ``sent1`` and ``right`` is ``sent2``.
     * With ``sent2`` None, ``sent1`` is cut at an index drawn by
       ``random.randint(_min_len // 2, len(sent1.tokens) - _min_len // 2)``.
       The two halves then act as ``left`` and ``right``.

     Every per-token field is rebuilt to match that layout:

     * padding mask: ``True`` at the three inserted special positions;
     * segments: ``0`` for bos, ``left`` and del; ``1`` for ``right`` and eos;
     * ``'lm'`` task: target ``0`` and target-mask ``False`` at the three
       special positions.

     The sentence-classification dict of the result is empty.
     """
     lm_a = sent1.token_classification['lm']
     if sent2 is None:
         # Single-sentence mode: split sent1 at a random point.
         margin = _min_len // 2
         cut = random.randint(margin, len(sent1.tokens) - margin)
         tok_a, tok_b = sent1.tokens[:cut], sent1.tokens[cut:]
         pad_a, pad_b = sent1.padding_mask[:cut], sent1.padding_mask[cut:]
         tgt_a, tgt_b = lm_a.target[:cut], lm_a.target[cut:]
         msk_a, msk_b = lm_a.target_mask[:cut], lm_a.target_mask[cut:]
         n_a, n_b = cut, len(sent1.tokens) - cut
     else:
         lm_b = sent2.token_classification['lm']
         tok_a, tok_b = sent1.tokens, sent2.tokens
         pad_a, pad_b = sent1.padding_mask, sent2.padding_mask
         tgt_a, tgt_b = lm_a.target, lm_b.target
         msk_a, msk_b = lm_a.target_mask, lm_b.target_mask
         n_a, n_b = len(sent1.tokens), len(sent2.tokens)

     # Segment 0 spans bos + left + del (n_a + 2); segment 1 spans
     # right + eos (n_b + 1).
     segment_ids = [0] * (n_a + 2) + [1] * (n_b + 1)
     return Sentence(
         [text_encoder.bos_id] + tok_a + [text_encoder.del_id] + tok_b +
         [text_encoder.eos_id],
         [True] + pad_a + [True] + pad_b + [True],
         segment_ids,
         {
             'lm':
             TokenTaskData([0] + tgt_a + [0] + tgt_b + [0],
                           [False] + msk_a + [False] + msk_b + [False])
         },
         {},
     )
示例#2
0
 def generate_sentence(self, length: int) -> Sentence:
     """Return a random single-segment ``Sentence`` with ``length`` tokens.

     The tokens and the ``'lm'`` target come from two separate calls to
     ``self.generate_random_seq``. The ``'lm'`` target mask comes from
     ``self.generate_random_mask``. Every position is ``True`` in the padding
     mask and belongs to segment ``0``. No sentence-level tasks are attached.
     """
     # Call order (tokens, target, mask) is kept in case the generators
     # share random state.
     token_ids = self.generate_random_seq(length)
     lm_task = TokenTaskData(self.generate_random_seq(length),
                             self.generate_random_mask(length))
     every_position = [True] * length
     segment_ids = [0] * length
     return Sentence(token_ids, every_position, segment_ids, {'lm': lm_task},
                     {})
示例#3
0
 def _yield_sentence(sent: Sentence) -> Sentence:
     """Wrap ``sent`` as ``[bos] + sent + [eos]`` in a single segment.

     All per-token fields are extended to the wrapped length
     ``len(sent.tokens) + 2``:

     * padding mask: ``True`` at the added bos/eos positions;
     * segments: ``0`` at every position, bos and eos included;
     * ``'lm'`` task: target ``0`` and target-mask ``False`` at bos/eos.

     The sentence-classification dict of the result is empty.
     """
     lm = sent.token_classification['lm']
     # bos and eos are added around the content, so every per-token field
     # must grow by two to stay aligned with ``tokens``.
     wrapped_len = len(sent.tokens) + 2
     return Sentence(
         [text_encoder.bos_id] + sent.tokens + [text_encoder.eos_id],
         [True] + sent.padding_mask + [True],
         # One segment id per output token (bos/eos included), matching the
         # lengths of tokens, padding_mask and the lm target/mask.
         [0] * wrapped_len,
         {
             'lm':
             TokenTaskData([0] + lm.target + [0],
                           [False] + lm.target_mask + [False])
         },
         {},
     )
示例#4
0
 def dummy_generator():
     """Yield ``steps`` random ``Sentence`` objects for a toy training task.

     Each sentence has a length drawn from ``random.randint(1, max_len - 1)``.
     Its token ids are drawn from ``range(vocab_size)``, and the last token is
     forced to ``eos_id``.

     * When ``easy`` is true, the ``'lm'`` and ``'lm_untied'`` tasks both use
       the token list itself (the same object) as their target.
     * Otherwise each task gets its own freshly drawn random target.

     Both target masks are all ``True``. The ``'count'`` sentence task is
     built from ``seq_len % 2`` and ``seq_len - 1``; presumably these are the
     parity target and the last (eos) index, to be confirmed against
     ``SentenceTaskData``.
     """
     def draw_ids(count):
         # Fresh list of ``count`` uniform ids from ``range(vocab_size)``.
         return [random.randrange(vocab_size) for _ in range(count)]

     for _ in range(steps):
         seq_len = random.randint(1, max_len - 1)
         tokens = draw_ids(seq_len)
         tokens[-1] = eos_id
         # Draw order (tokens, then lm, then lm_untied) is kept deliberately.
         lm_targets = tokens if easy else draw_ids(seq_len)
         untied_targets = tokens if easy else draw_ids(seq_len)
         yield Sentence(
             tokens=tokens,
             padding_mask=[True] * seq_len,
             segments=[0] * seq_len,
             token_classification={
                 'lm': TokenTaskData(lm_targets, [True] * seq_len),
                 'lm_untied': TokenTaskData(untied_targets, [True] * seq_len),
             },
             sentence_classification={
                 'count': SentenceTaskData(seq_len % 2, seq_len - 1),
             },
         )