def load_file(self, file_path):
    with open(file_path, 'r', encoding='utf-8') as fin:
        for i, line in enumerate(fin):
            cols = convert_to_unicode(line).strip().split(";")
            cols = list(map(lambda x: list(map(int, x.split(" "))), cols))
            if len(cols) > 3:
                cols = cols[:3]
            token_ids, type_ids, pos_ids = cols
            if self.mode == 'test':
                tgt_start_idx = len(cols[0])
            else:
                tgt_start_idx = token_ids.index(self.bos_id, 1)
            data_id = i
            sample = [token_ids, type_ids, pos_ids, tgt_start_idx]
            yield sample
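# Illustrative sketch, not part of the original source: how one input line for
# load_file above is parsed. The concrete ids, bos_id = 1 and 'train' mode are
# assumptions chosen only for demonstration.
example_line = "1 23 45 67 2 1 89 10 2;0 0 0 0 0 1 1 1 1;0 1 2 3 4 5 6 7 8"
bos_id = 1  # hypothetical BOS token id

# Same parsing as load_file: three ";"-separated groups of space-separated ids.
token_ids, type_ids, pos_ids = [
    [int(tok) for tok in group.split(" ")] for group in example_line.split(";")
]
# In 'train' mode the target starts at the second occurrence of the BOS id.
tgt_start_idx = token_ids.index(bos_id, 1)
print(tgt_start_idx)                                   # -> 5
print([token_ids, type_ids, pos_ids, tgt_start_idx])   # the yielded sample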
def create_training_instances(input_files, tokenizer, max_seq_length,
                              dupe_factor, short_seq_prob, masked_lm_prob,
                              max_predictions_per_seq, rng):
    """Create `TrainingInstance`s from raw text."""
    all_documents = [[]]

    # Input file format:
    # (1) One sentence per line. These should ideally be actual sentences, not
    # entire paragraphs or arbitrary spans of text. (Because we use the
    # sentence boundaries for the "next sentence prediction" task).
    # (2) Blank lines between documents. Document boundaries are needed so
    # that the "next sentence prediction" task doesn't span between documents.
    for input_file in input_files:
        print("creating instance from {}".format(input_file))
        with open(input_file, "r") as reader:
            while True:
                line = convert_to_unicode(reader.readline())
                if not line:
                    break
                line = line.strip()

                # Empty lines are used as document delimiters
                if not line:
                    all_documents.append([])
                # tokens = tokenizer.tokenize(line)
                tokens = tokenizer(line)
                if tokens:
                    all_documents[-1].append(tokens)

    # Remove empty documents
    all_documents = [x for x in all_documents if x]
    rng.shuffle(all_documents)

    # vocab_words = list(tokenizer.vocab.keys())
    vocab_words = list(tokenizer.vocab.token_to_idx.keys())
    instances = []
    for _ in range(dupe_factor):
        for document_index in range(len(all_documents)):
            instances.extend(
                create_instances_from_document(
                    all_documents, document_index, max_seq_length,
                    short_seq_prob, masked_lm_prob, max_predictions_per_seq,
                    vocab_words, rng))

    rng.shuffle(instances)
    return instances
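# Minimal usage sketch, not from the original source. It assumes a `tokenizer`
# built elsewhere that is callable on a sentence and exposes
# `vocab.token_to_idx` (as required above), and that convert_to_unicode and
# create_instances_from_document are available in the same module. The file
# name and hyperparameter values are placeholders.
import random

input_files = ["corpus_part_0.txt"]   # one sentence per line, blank lines between documents
rng = random.Random(12345)            # seeded so instance creation is reproducible

instances = create_training_instances(
    input_files=input_files,
    tokenizer=tokenizer,
    max_seq_length=128,
    dupe_factor=5,                    # how many times each document is re-used
    short_seq_prob=0.1,               # chance of producing a shorter-than-max sequence
    masked_lm_prob=0.15,              # fraction of tokens masked for MLM
    max_predictions_per_seq=20,
    rng=rng)
print("generated {} training instances".format(len(instances)))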
def load_file(self, file_path): """ 读取ID化的文件,文件每一行代表一个样本,一个样本由三部分id构成,分别为token_ids, type_ids和pos_ids """ with open(file_path, 'r', encoding='utf-8') as fin: for i, line in enumerate(fin): cols = convert_to_unicode(line).strip().split(";") cols = list(map(lambda x: list(map(int, x.split(" "))), cols)) if len(cols) > 3: cols = cols[:3] token_ids, type_ids, pos_ids = cols # 找打label序列的起始位置 if self.mode == 'test': tgt_start_idx = len(cols[0]) else: tgt_start_idx = token_ids.index(self.bos_id, 1) # data_id = i sample = [token_ids, type_ids, pos_ids, tgt_start_idx] yield sample