def _yield_sentence(sent1: Sentence, sent2: Optional[Sentence] = None) -> Sentence:
    """Wrap one or two sentences with special tokens into a single Sentence.

    With only ``sent1``, the sentence is cut at a random point and the two
    halves are joined as BOS + left + DEL + right + EOS (segment ids 0/1).
    With ``sent2``, the pair is joined as BOS + sent1 + DEL + sent2 + EOS.
    The 'lm' token-classification targets/masks are padded in lockstep so
    every parallel list keeps the same length as the token list.
    """
    lm1 = sent1.token_classification['lm']
    if sent2 is None:
        # Keep at least _min_len // 2 tokens on each side of the cut.
        margin = _min_len // 2
        split_idx = random.randint(margin, len(sent1.tokens) - margin)
        tok_l, tok_r = sent1.tokens[:split_idx], sent1.tokens[split_idx:]
        pad_l, pad_r = sent1.padding_mask[:split_idx], sent1.padding_mask[split_idx:]
        tgt_l, tgt_r = lm1.target[:split_idx], lm1.target[split_idx:]
        msk_l, msk_r = lm1.target_mask[:split_idx], lm1.target_mask[split_idx:]
        tokens = [text_encoder.bos_id] + tok_l + [text_encoder.del_id] + tok_r + [text_encoder.eos_id]
        padding = [True] + pad_l + [True] + pad_r + [True]
        # Segment 0 covers BOS + left half + DEL; segment 1 covers right half + EOS.
        segments = [0] * (split_idx + 2) + [1] * (1 + len(sent1.tokens) - split_idx)
        lm_data = TokenTaskData([0] + tgt_l + [0] + tgt_r + [0],
                                [False] + msk_l + [False] + msk_r + [False])
        return Sentence(tokens, padding, segments, {'lm': lm_data}, {})
    lm2 = sent2.token_classification['lm']
    tokens = [text_encoder.bos_id] + sent1.tokens + [text_encoder.del_id] + sent2.tokens + [text_encoder.eos_id]
    padding = [True] + sent1.padding_mask + [True] + sent2.padding_mask + [True]
    # Segment 0 covers BOS + sent1 + DEL; segment 1 covers sent2 + EOS.
    segments = [0] * (2 + len(sent1.tokens)) + [1] * (1 + len(sent2.tokens))
    lm_data = TokenTaskData([0] + lm1.target + [0] + lm2.target + [0],
                            [False] + lm1.target_mask + [False] + lm2.target_mask + [False])
    return Sentence(tokens, padding, segments, {'lm': lm_data}, {})
def generate_sentence(self, length: int) -> Sentence:
    """Build a fully random Sentence of ``length`` tokens.

    Tokens and the 'lm' targets are independent random sequences; every
    position is unpadded (True) and placed in segment 0.
    """
    lm_task = TokenTaskData(self.generate_random_seq(length),
                            self.generate_random_mask(length))
    return Sentence(
        self.generate_random_seq(length),
        [True] * length,
        [0] * length,
        {'lm': lm_task},
        {},
    )
def _yield_sentence(sent: Sentence) -> Sentence:
    """Wrap a single sentence as BOS + tokens + EOS in one Sentence.

    The padding mask and the 'lm' targets/masks are extended with dummy
    entries for the two special tokens so all parallel lists stay aligned
    with the token list.
    """
    lm = sent.token_classification['lm']
    return Sentence(
        [text_encoder.bos_id] + sent.tokens + [text_encoder.eos_id],
        [True] + sent.padding_mask + [True],
        # BUG FIX: segments must cover BOS and EOS too, so its length matches
        # the token list (len + 2) — consistent with the two-sentence variant,
        # where segment length always equals token length.
        [0] * (len(sent.tokens) + 2),
        {
            'lm': TokenTaskData([0] + lm.target + [0],
                                [False] + lm.target_mask + [False])
        },
        {})
def dummy_generator():
    """Yield ``steps`` random Sentences for testing.

    Each sentence has a random length in [1, max_len - 1], ends with
    ``eos_id``, and carries 'lm' / 'lm_untied' token tasks (targets equal
    the tokens when ``easy`` is set, otherwise fresh random ids) plus a
    'count' sentence task holding the length's parity at the last position.
    """
    for _ in range(steps):
        seq_len = random.randint(1, max_len - 1)
        tokens = [random.randrange(vocab_size) for _ in range(seq_len)]
        tokens[-1] = eos_id

        def _targets():
            # Easy mode: predict the input itself; otherwise random labels.
            if easy:
                return tokens
            return [random.randrange(vocab_size) for _ in range(seq_len)]

        token_tasks = {
            'lm': TokenTaskData(_targets(), [True] * seq_len),
            'lm_untied': TokenTaskData(_targets(), [True] * seq_len),
        }
        sentence_tasks = {'count': SentenceTaskData(seq_len % 2, seq_len - 1)}
        yield Sentence(
            tokens=tokens,
            padding_mask=[True] * seq_len,
            segments=[0] * seq_len,
            token_classification=token_tasks,
            sentence_classification=sentence_tasks,
        )