def build_instances(self, all_documents):
    # Split the token stream into fixed-length chunks of seq_length tokens.
    instances = []
    instances_num = len(all_documents) // self.seq_length
    for i in range(instances_num):
        src = all_documents[i * self.seq_length: (i + 1) * self.seq_length]
        seg_pos = [len(src)]

        if not self.dynamic_masking:
            # Static masking: mask the sequence once at preprocessing time.
            src, tgt = mask_seq(src, self.tokenizer, self.whole_word_masking, self.span_masking, self.span_geo_prob, self.span_max_length)
            instance = (src, tgt, seg_pos)
        else:
            # Dynamic masking: masking is deferred to the dataloader.
            instance = (src, seg_pos)

        instances.append(instance)

    # Pad the remaining tail so that it also forms a full-length instance.
    src = all_documents[instances_num * self.seq_length:]
    seg_pos = [len(src)]

    while len(src) != self.seq_length:
        src.append(self.vocab.get(PAD_TOKEN))

    if not self.dynamic_masking:
        src, tgt = mask_seq(src, self.tokenizer, self.whole_word_masking, self.span_masking, self.span_geo_prob, self.span_max_length)
        instance = (src, tgt, seg_pos)
    else:
        instance = (src, seg_pos)

    instances.append(instance)

    return instances

def create_ins_from_doc(self, document):
    # Build sentence-pair instances from a single document.
    # Three positions are reserved for [CLS], [SEP], [SEP].
    max_num_tokens = self.seq_length - 3
    target_seq_length = max_num_tokens
    if random.random() < self.short_seq_prob:
        target_seq_length = random.randint(2, max_num_tokens)
    instances = []
    current_chunk = []
    current_length = 0
    i = 0
    while i < len(document):
        segment = document[i]
        current_chunk.append(segment)
        current_length += len(segment)
        if i == len(document) - 1 or current_length >= target_seq_length:
            if current_chunk:
                a_end = 1
                if len(current_chunk) >= 2:
                    a_end = random.randint(1, len(current_chunk) - 1)

                tokens_a = []
                for j in range(a_end):
                    tokens_a.extend(current_chunk[j])

                tokens_b = []
                is_wrong_order = 0
                for j in range(a_end, len(current_chunk)):
                    tokens_b.extend(current_chunk[j])

                # Sentence-order prediction: swap the two segments half of the time.
                if random.random() < 0.5:
                    is_wrong_order = 1
                    tmp = tokens_a
                    tokens_a = tokens_b
                    tokens_b = tmp

                truncate_seq_pair(tokens_a, tokens_b, max_num_tokens)

                src = []
                src.append(self.vocab.get(CLS_TOKEN))
                src.extend(tokens_a)
                src.append(self.vocab.get(SEP_TOKEN))
                seg_pos = [len(src)]
                src.extend(tokens_b)
                src.append(self.vocab.get(SEP_TOKEN))
                seg_pos.append(len(src))

                while len(src) != self.seq_length:
                    src.append(self.vocab.get(PAD_TOKEN))

                if not self.dynamic_masking:
                    src, tgt_mlm = mask_seq(src, self.tokenizer, self.whole_word_masking, self.span_masking, self.span_geo_prob, self.span_max_length)
                    instance = (src, tgt_mlm, is_wrong_order, seg_pos)
                else:
                    instance = (src, is_wrong_order, seg_pos)

                instances.append(instance)
            current_chunk = []
            current_length = 0
        i += 1
    return instances

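# truncate_seq_pair is called above (and in create_ins_from_doc further below) but is not
# defined in this excerpt. The following is a minimal sketch, assuming it mirrors the
# original BERT preprocessing helper: trim one token at a time from the longer of the two
# sequences, randomly from the front or the back, until the pair fits into max_num_tokens.
# It relies on the `random` module already used above; the actual implementation in this
# repository may differ in details.
def truncate_seq_pair(tokens_a, tokens_b, max_num_tokens):
    while True:
        total_length = len(tokens_a) + len(tokens_b)
        if total_length <= max_num_tokens:
            break
        # Always shorten the longer of the two sequences, in place.
        trunc_tokens = tokens_a if len(tokens_a) > len(tokens_b) else tokens_b
        assert len(trunc_tokens) >= 1
        # Truncate from the front or the back at random to avoid positional bias.
        if random.random() < 0.5:
            del trunc_tokens[0]
        else:
            trunc_tokens.pop()
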
def __iter__(self):
    while True:
        while self._empty():
            self._fill_buf()
        if self.start + self.batch_size >= self.end:
            instances = self.buffer[self.start:]
        else:
            instances = self.buffer[self.start: self.start + self.batch_size]

        self.start += self.batch_size

        src = []
        tgt_in = []
        tgt_out = []
        seg = []

        for _, ins in enumerate(instances):
            # Restore the padding that was deferred at preprocessing time.
            src_single, pad_num = ins[0]
            for _ in range(pad_num):
                src_single.append(self.vocab.get(PAD_TOKEN))
            tgt_single, pad_num = ins[1]
            for _ in range(pad_num):
                tgt_single.append(self.vocab.get(PAD_TOKEN))

            # Only the masked source is needed here; the targets come from ins[1].
            src_single, _ = mask_seq(src_single, self.tokenizer, self.whole_word_masking, self.span_masking, self.span_geo_prob, self.span_max_length)
            seg_pos = ins[2][0]

            # Teacher-forcing pairs: decoder input drops the last token, decoder output drops the first.
            tgt_in.append(tgt_single[:-1])
            tgt_out.append(tgt_single[1:])

            # Collapse each run of consecutive [MASK] tokens into a single mask (span infilling).
            MASK_ID = self.vocab.get(MASK_TOKEN)
            src_with_span_mask = []
            for token_id in src_single:
                if token_id == MASK_ID:
                    if len(src_with_span_mask) > 0 and src_with_span_mask[-1] == MASK_ID:
                        seg_pos -= 1
                    else:
                        src_with_span_mask.append(MASK_ID)
                else:
                    src_with_span_mask.append(token_id)

            # Keep the source length fixed by padding the positions freed by collapsed spans.
            while len(src_with_span_mask) < len(src_single):
                src_with_span_mask.append(self.vocab.get(PAD_TOKEN))

            seg.append([1] * seg_pos + [0] * (len(src_single) - seg_pos))
            src.append(src_with_span_mask)

        yield torch.LongTensor(src), \
            torch.LongTensor(tgt_in), \
            torch.LongTensor(tgt_out), \
            torch.LongTensor(seg)

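# Worked example of the span collapsing above (token names are illustrative, not real ids):
# src_single (after masking) = [A, MASK, MASK, D, E],  seg_pos = 5
# src_with_span_mask         = [A, MASK, D, E, PAD]    # the two-token masked span collapses to one [MASK]
# seg                        = [1, 1, 1, 1, 0]         # seg_pos shrank to 4 along with the span
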
def __iter__(self):
    while True:
        while self._empty():
            self._fill_buf()
        if self.start + self.batch_size >= self.end:
            instances = self.buffer[self.start:]
        else:
            instances = self.buffer[self.start: self.start + self.batch_size]

        self.start += self.batch_size

        src = []
        tgt_mlm = []
        is_next = []
        seg = []

        masked_words_num = 0

        for ins in instances:
            # Restore the padding that was deferred at preprocessing time.
            src_single, pad_num = ins[0]
            for _ in range(pad_num):
                src_single.append(self.vocab.get(PAD_TOKEN))

            if len(ins) == 4:
                # Static masking: the instance already carries (position, token id) targets.
                src.append(src_single)
                masked_words_num += len(ins[1])
                tgt_mlm.append([0] * len(src_single))
                for mask in ins[1]:
                    tgt_mlm[-1][mask[0]] = mask[1]
                is_next.append(ins[2])
                seg.append([1] * ins[3][0] + [2] * (ins[3][1] - ins[3][0]) + [0] * pad_num)
            else:
                # Dynamic masking: mask the sequence on the fly.
                src_single, tgt_mlm_single = mask_seq(src_single, self.tokenizer, self.whole_word_masking, self.span_masking, self.span_geo_prob, self.span_max_length)
                masked_words_num += len(tgt_mlm_single)
                src.append(src_single)
                tgt_mlm.append([0] * len(src_single))
                for mask in tgt_mlm_single:
                    tgt_mlm[-1][mask[0]] = mask[1]
                is_next.append(ins[1])
                seg.append([1] * ins[2][0] + [2] * (ins[2][1] - ins[2][0]) + [0] * pad_num)

        # Skip batches in which no token was masked.
        if masked_words_num == 0:
            continue

        yield torch.LongTensor(src), \
            torch.LongTensor(tgt_mlm), \
            torch.LongTensor(is_next), \
            torch.LongTensor(seg)

def build_instances(self, all_documents):
    instances = []
    instances_num = len(all_documents) // self.seq_length
    for i in range(instances_num):
        src = all_documents[i * self.seq_length: (i + 1) * self.seq_length]
        seg_pos = [len(src)]

        if not self.dynamic_masking:
            src, tgt = mask_seq(src, self.tokenizer, self.whole_word_masking, self.span_masking, self.span_geo_prob, self.span_max_length)
            instance = ((src, 0), tgt, seg_pos)
        else:
            instance = ((src, 0), seg_pos)

        instances.append(instance)

    src = all_documents[instances_num * self.seq_length:]
    if len(src) == 0:
        return instances

    seg_pos = [len(src)]
    pad_num = self.seq_length - len(src)

    if not self.dynamic_masking:
        src, tgt = mask_seq(src, self.tokenizer, self.whole_word_masking, self.span_masking, self.span_geo_prob, self.span_max_length)
        instance = ((src, pad_num), tgt, seg_pos)
    else:
        instance = ((src, pad_num), seg_pos)

    instances.append(instance)

    return instances

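# Illustrative example of the instance layout produced above (the numbers are made up):
# with seq_length = 128 and a leftover tail of 100 tokens, the last instance under dynamic
# masking is ((src, 28), [100]), i.e. the 100 tail tokens, a pad_num of 28 that the
# dataloader applies later, and seg_pos marking where the real tokens end.
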
def __iter__(self):
    while True:
        while self._empty():
            self._fill_buf()
        if self.start + self.batch_size >= self.end:
            instances = self.buffer[self.start:]
        else:
            instances = self.buffer[self.start: self.start + self.batch_size]

        self.start += self.batch_size

        src = []
        tgt = []
        seg = []

        masked_words_num = 0

        for ins in instances:
            if len(ins) == 3:
                src.append(ins[0])
                masked_words_num += len(ins[1])
                tgt.append([0] * len(ins[0]))
                for mask in ins[1]:
                    tgt[-1][mask[0]] = mask[1]
                seg.append([1] * ins[2][0] + [0] * (len(ins[0]) - ins[2][0]))
            else:
                src_single, tgt_single = mask_seq(ins[0], self.tokenizer, self.whole_word_masking, self.span_masking, self.span_geo_prob, self.span_max_length)
                masked_words_num += len(tgt_single)
                src.append(src_single)
                tgt.append([0] * len(ins[0]))
                for mask in tgt_single:
                    tgt[-1][mask[0]] = mask[1]
                seg.append([1] * ins[1][0] + [0] * (len(ins[0]) - ins[1][0]))

        if masked_words_num == 0:
            continue

        yield torch.LongTensor(src), \
            torch.LongTensor(tgt), \
            torch.LongTensor(seg)

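# Illustrative usage of the dataloader above; the names `mlm_dataloader` and `model` are
# placeholders, not part of this excerpt:
#
#     for src, tgt, seg in mlm_dataloader:
#         # Each tensor has shape (batch, seq_length), with batch <= batch_size.
#         # tgt is zero everywhere except at masked positions, which hold the original token ids.
#         loss = model(src, tgt, seg)
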
def __iter__(self):
    while True:
        while self._empty():
            self._fill_buf()
        if self.start + self.batch_size >= self.end:
            instances = self.buffer[self.start:]
        else:
            instances = self.buffer[self.start: self.start + self.batch_size]

        self.start += self.batch_size

        src = []
        tgt_in = []
        tgt_out = []
        seg = []

        tgt_seq_length = 0

        for _, ins in enumerate(instances):
            # Restore the padding that was deferred at preprocessing time.
            src_single, pad_num = ins[0]
            for _ in range(pad_num):
                src_single.append(self.vocab.get(PAD_TOKEN))

            if len(ins) == 3:
                # Static masking: the (position, token id) targets were built at preprocessing time.
                tgt_single = ins[1]
                seg.append([1] * ins[2][0] + [0] * pad_num)
            else:
                # Dynamic masking: mask the sequence on the fly.
                src_single, tgt_single = mask_seq(src_single, self.tokenizer, self.whole_word_masking, self.span_masking, self.span_geo_prob, self.span_max_length)
                seg.append([1] * ins[1][0] + [0] * pad_num)

            MASK_ID = self.vocab.get(MASK_TOKEN)
            SENTINEL_ID = self.vocab.get(SENTINEL_TOKEN)
            PAD_ID = self.vocab.get(PAD_TOKEN)

            # Make sure every masked position in the source actually holds [MASK].
            for src_index, _ in tgt_single:
                if src_single[src_index] != MASK_ID:
                    src_single[src_index] = MASK_ID

            tgt_in_single = [self.vocab.get(CLS_TOKEN)]
            mask_index = 0
            src_with_sentinel = []
            # Replace each masked span with a single, incrementing sentinel token on the source side
            # and emit "sentinel + original span tokens" on the target side.
            for token_id in src_single:
                if token_id == MASK_ID:
                    if len(src_with_sentinel) > 0 and src_with_sentinel[-1] == (SENTINEL_ID - 1):
                        # Still inside the current span: reuse the previous sentinel.
                        pass
                    else:
                        src_with_sentinel.append(SENTINEL_ID)
                        tgt_in_single.append(SENTINEL_ID)
                        if SENTINEL_ID < len(self.vocab) - 1:
                            SENTINEL_ID += 1
                    tgt_in_single.append(tgt_single[mask_index][1])
                    mask_index += 1
                else:
                    src_with_sentinel.append(token_id)
            tgt_in_single.append(SENTINEL_ID)
            tgt_in_single.append(self.vocab.get(SEP_TOKEN))

            # Keep the source length fixed by padding the positions freed by collapsed spans.
            while len(src_with_sentinel) < len(src_single):
                src_with_sentinel.append(PAD_ID)

            if len(tgt_in_single) > tgt_seq_length:
                tgt_seq_length = len(tgt_in_single)

            src.append(src_with_sentinel)
            tgt_in.append(tgt_in_single)
            tgt_out.append(tgt_in[-1][1:] + [PAD_ID])

        # Pad all targets in the batch to the same length.
        for i in range(len(tgt_in)):
            while len(tgt_in[i]) != tgt_seq_length:
                tgt_in[i].append(PAD_ID)
                tgt_out[i].append(PAD_ID)

        yield torch.LongTensor(src), \
            torch.LongTensor(tgt_in), \
            torch.LongTensor(tgt_out), \
            torch.LongTensor(seg)

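# Worked example of the sentinel replacement above (token names are illustrative; S0, S1 are
# consecutive sentinel ids starting at SENTINEL_TOKEN):
# src_single        = [A, B, MASK, MASK, E]
# tgt_single        = [(2, C), (3, D)]          # (position, original id) pairs
# src_with_sentinel = [A, B, S0, E, PAD]        # the two-token span collapses to one sentinel
# tgt_in            = [CLS, S0, C, D, S1, SEP]
# tgt_out           = [S0, C, D, S1, SEP, PAD]  # tgt_in shifted left by one
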
def create_ins_from_doc(self, all_documents, document_index):
    document = all_documents[document_index]
    # Three positions are reserved for [CLS], [SEP], [SEP].
    max_num_tokens = self.seq_length - 3
    target_seq_length = max_num_tokens
    if random.random() < self.short_seq_prob:
        target_seq_length = random.randint(2, max_num_tokens)
    instances = []
    current_chunk = []
    current_length = 0
    i = 0
    while i < len(document):
        segment = document[i]
        current_chunk.append(segment)
        current_length += len(segment)
        if i == len(document) - 1 or current_length >= target_seq_length:
            if current_chunk:
                a_end = 1
                if len(current_chunk) >= 2:
                    a_end = random.randint(1, len(current_chunk) - 1)

                tokens_a = []
                for j in range(a_end):
                    tokens_a.extend(current_chunk[j])

                tokens_b = []

                # Next sentence prediction: half of the time, take segment B from a random document.
                is_random_next = 0
                if len(current_chunk) == 1 or random.random() < 0.5:
                    is_random_next = 1
                    target_b_length = target_seq_length - len(tokens_a)

                    for _ in range(10):
                        random_document_index = random.randint(0, len(all_documents) - 1)
                        if random_document_index != document_index:
                            break

                    random_document = all_documents[random_document_index]
                    random_start = random.randint(0, len(random_document) - 1)
                    for j in range(random_start, len(random_document)):
                        tokens_b.extend(random_document[j])
                        if len(tokens_b) >= target_b_length:
                            break

                    # Put the unused segments back so they are not wasted.
                    num_unused_segments = len(current_chunk) - a_end
                    i -= num_unused_segments
                else:
                    is_random_next = 0
                    for j in range(a_end, len(current_chunk)):
                        tokens_b.extend(current_chunk[j])

                truncate_seq_pair(tokens_a, tokens_b, max_num_tokens)

                src = []
                src.append(self.vocab.get(CLS_TOKEN))
                src.extend(tokens_a)
                src.append(self.vocab.get(SEP_TOKEN))
                seg_pos = [len(src)]
                src.extend(tokens_b)
                src.append(self.vocab.get(SEP_TOKEN))
                seg_pos.append(len(src))

                while len(src) != self.seq_length:
                    src.append(self.vocab.get(PAD_TOKEN))

                if not self.dynamic_masking:
                    src, tgt_mlm = mask_seq(src, self.tokenizer, self.whole_word_masking, self.span_masking, self.span_geo_prob, self.span_max_length)
                    instance = (src, tgt_mlm, is_random_next, seg_pos)
                else:
                    instance = (src, is_random_next, seg_pos)

                instances.append(instance)
            current_chunk = []
            current_length = 0
        i += 1
    return instances

def worker(self, proc_id, start, end):
    print("Worker %d is building dataset ... " % proc_id)
    set_seed(self.seed)
    dataset_writer = open("dataset-tmp-" + str(proc_id) + ".pt", "wb")
    pos = 0
    with open(self.corpus_path, mode="r", encoding="utf-8") as f:
        # Skip to this worker's starting line.
        while pos < start:
            f.readline()
            pos += 1
        while True:
            line = f.readline()
            pos += 1

            line = line.strip().split("\t")
            if len(line) == 2:
                # Single-sentence input: label \t text.
                label = int(line[0])
                text = line[1]

                src = [self.vocab.get(CLS_TOKEN)] + self.tokenizer.convert_tokens_to_ids(self.tokenizer.tokenize(text)) + [self.vocab.get(SEP_TOKEN)]
                tgt_cls = label
                seg_pos = [len(src)]
            elif len(line) == 3:
                # Sentence-pair input: label \t text_a \t text_b.
                label = int(line[0])
                text_a, text_b = line[1], line[2]

                src_a = self.tokenizer.convert_tokens_to_ids(self.tokenizer.tokenize(text_a))
                src_a = [self.vocab.get(CLS_TOKEN)] + src_a + [self.vocab.get(SEP_TOKEN)]
                src_b = self.tokenizer.convert_tokens_to_ids(self.tokenizer.tokenize(text_b))
                src_b = src_b + [self.vocab.get(SEP_TOKEN)]

                src = src_a + src_b
                tgt_cls = label
                seg_pos = [len(src_a)] + [len(src_b)]
            else:
                # Malformed line: skip it.
                if pos >= end:
                    break
                continue

            if len(src) >= self.seq_length:
                # Truncate; no padding is needed.
                pad_num = 0
                src = (src[:self.seq_length], pad_num)
                if len(seg_pos) == 1:
                    seg_pos = [self.seq_length]
                else:
                    if len(src_a) >= self.seq_length:
                        seg_pos = [self.seq_length]
                    else:
                        seg_pos = [len(src_a)] + [self.seq_length - len(src_a)]
            else:
                # Defer padding to the dataloader via pad_num.
                pad_num = self.seq_length - len(src)
                src = (src, pad_num)

            if not self.dynamic_masking:
                src_single, pad_num = src
                src_single, tgt_mlm = mask_seq(src_single, self.tokenizer, self.whole_word_masking, self.span_masking, self.span_geo_prob, self.span_max_length)
                src = (src_single, pad_num)
                instance = (src, tgt_mlm, tgt_cls, seg_pos)
            else:
                instance = (src, tgt_cls, seg_pos)

            pickle.dump(instance, dataset_writer)

            if pos >= end:
                break

    dataset_writer.close()

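# The per-worker "dataset-tmp-<proc_id>.pt" files written above still have to be merged into
# a single dataset file. That merge step is not shown in this excerpt; below is a minimal
# sketch under the assumption that the sequentially pickled records can simply be concatenated
# byte-for-byte (the function name and arguments are illustrative, not the repository's API).
import os

def merge_tmp_datasets(dataset_path, workers_num):
    with open(dataset_path, "wb") as dataset_writer:
        for proc_id in range(workers_num):
            tmp_path = "dataset-tmp-" + str(proc_id) + ".pt"
            with open(tmp_path, "rb") as tmp_reader:
                while True:
                    chunk = tmp_reader.read(1 << 20)  # copy in 1 MiB chunks
                    if not chunk:
                        break
                    dataset_writer.write(chunk)
            os.remove(tmp_path)  # clean up the temporary file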