from transformers import BertTokenizerFast


class WordpieceTokenizer:

    def __init__(self, vocab_path, strip_accents, clean_text, lowercase):
        common_params = {
            'strip_accents': strip_accents,
            'clean_text': clean_text,
            'lowercase': lowercase
        }
        self._tokenizer = BertTokenizerFast(vocab_file=vocab_path,
                                            **common_params)

    def encode(self, text, add_cls_token=True):
        ids = []
        if add_cls_token:
            ids.append(self.cls_token_id)
        ids.extend(self._tokenizer.encode(text, add_special_tokens=False))
        return ids

    @property
    def cls_token_id(self):
        return self._tokenizer.cls_token_id

    @property
    def pad_token_id(self):
        return self._tokenizer.pad_token_id

    @property
    def sep_token_id(self):
        return self._tokenizer.sep_token_id

    @property
    def vocab_size(self):
        return self._tokenizer.vocab_size
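# A minimal usage sketch for the wrapper above; the vocab path and the sample
# text are hypothetical examples, not values taken from this code.
tok = WordpieceTokenizer(vocab_path='vocab/vocab.txt',
                         strip_accents=False,
                         clean_text=True,
                         lowercase=True)
ids = tok.encode('hello world')  # [CLS] id followed by the wordpiece ids
assert ids[0] == tok.cls_token_id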
import os
from datetime import datetime

import torch
import torch.nn.functional as F
from transformers import BertTokenizerFast, GPT2LMHeadModel

# set_args, create_logger and top_k_top_p_filtering are project-local helpers;
# a sketch of top_k_top_p_filtering is given after this function.


def main():
    args = set_args()
    logger = create_logger(args)
    # Use the GPU when the user requests it and one is available
    args.cuda = torch.cuda.is_available() and not args.no_cuda
    device = 'cuda' if args.cuda else 'cpu'
    logger.info('using device:{}'.format(device))
    os.environ["CUDA_VISIBLE_DEVICES"] = args.device
    tokenizer = BertTokenizerFast(vocab_file=args.vocab_path,
                                  sep_token="[SEP]",
                                  pad_token="[PAD]",
                                  cls_token="[CLS]")
    # tokenizer = BertTokenizer(vocab_file=args.vocab_path)
    model = GPT2LMHeadModel.from_pretrained(args.model_path)
    model = model.to(device)
    model.eval()
    if args.save_samples_path:
        if not os.path.exists(args.save_samples_path):
            os.makedirs(args.save_samples_path)
        samples_file = open(args.save_samples_path + '/samples.txt',
                            'a',
                            encoding='utf8')
        samples_file.write("Chat log {}:\n".format(datetime.now()))
    # Chat history; every utterance is stored as a list of token ids
    history = []
    print('Start chatting with the chatbot; press Ctrl + Z to exit')

    while True:
        try:
            text = input("user:")
            # text = "你好"
            if args.save_samples_path:
                samples_file.write("user:{}\n".format(text))
            text_ids = tokenizer.encode(text, add_special_tokens=False)
            history.append(text_ids)
            input_ids = [tokenizer.cls_token_id]  # every input starts with [CLS]
            for history_id, history_utr in enumerate(
                    history[-args.max_history_len:]):
                input_ids.extend(history_utr)
                input_ids.append(tokenizer.sep_token_id)
            input_ids = torch.tensor(input_ids).long().to(device)
            input_ids = input_ids.unsqueeze(0)
            response = []  # the response generated from the context
            # Generate at most max_len tokens
            for _ in range(args.max_len):
                outputs = model(input_ids=input_ids)
                logits = outputs.logits
                next_token_logits = logits[0, -1, :]
                # Penalize every token already present in the response,
                # lowering its probability of being generated again
                for token_id in set(response):
                    next_token_logits[token_id] /= args.repetition_penalty
                next_token_logits = next_token_logits / args.temperature
                # Set the [UNK] logit to -inf so the model can never
                # predict the [UNK] token
                next_token_logits[tokenizer.convert_tokens_to_ids(
                    '[UNK]')] = -float('Inf')
                filtered_logits = top_k_top_p_filtering(next_token_logits,
                                                        top_k=args.topk,
                                                        top_p=args.topp)
                # torch.multinomial draws num_samples elements from the
                # candidates without replacement; the higher the weight, the
                # more likely an element is drawn; it returns the indices
                next_token = torch.multinomial(F.softmax(filtered_logits,
                                                         dim=-1),
                                               num_samples=1)
                if next_token == tokenizer.sep_token_id:
                    # [SEP] marks the end of the response
                    break
                response.append(next_token.item())
                input_ids = torch.cat((input_ids, next_token.unsqueeze(0)),
                                      dim=1)
                # his_text = tokenizer.convert_ids_to_tokens(curr_input_tensor.tolist())
                # print("his_text:{}".format(his_text))
            history.append(response)
            text = tokenizer.convert_ids_to_tokens(response)
            print("chatbot:" + "".join(text))
            if args.save_samples_path:
                samples_file.write("chatbot:{}\n".format("".join(text)))
        except KeyboardInterrupt:
            if args.save_samples_path:
                samples_file.close()
            break
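# main() above calls top_k_top_p_filtering, which is not defined in this code.
# A minimal sketch following the widely used reference implementation from the
# HuggingFace generation examples (an assumption, not taken from this project):
import torch
import torch.nn.functional as F


def top_k_top_p_filtering(logits, top_k=0, top_p=0.0, filter_value=-float('Inf')):
    """Filter a 1-D logits tensor with top-k and/or nucleus (top-p) filtering."""
    assert logits.dim() == 1
    top_k = min(top_k, logits.size(-1))
    if top_k > 0:
        # Remove every token whose logit is below the k-th largest logit
        indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
        logits[indices_to_remove] = filter_value
    if top_p > 0.0:
        sorted_logits, sorted_indices = torch.sort(logits, descending=True)
        cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
        # Remove tokens once the cumulative probability exceeds top_p
        sorted_indices_to_remove = cumulative_probs > top_p
        # Shift right so the first token above the threshold is kept
        sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
        sorted_indices_to_remove[..., 0] = 0
        indices_to_remove = sorted_indices[sorted_indices_to_remove]
        logits[indices_to_remove] = filter_value
    return logits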
import os
import shutil

from tokenizers import BertWordPieceTokenizer
from transformers import BertTokenizerFast


class WordpieceTokenizer(BaseTokenizer):  # BaseTokenizer is project-local

    def __init__(self,
                 vocab_path,
                 strip_accents,
                 clean_text,
                 lowercase,
                 from_pretrained=False):
        common_params = {
            'strip_accents': strip_accents,
            'clean_text': clean_text,
            'lowercase': lowercase
        }
        if from_pretrained:
            self._tokenizer = BertTokenizerFast.from_pretrained(
                pretrained_model_name_or_path=vocab_path, **common_params)
        else:
            self._tokenizer = BertTokenizerFast(vocab_file=vocab_path,
                                                **common_params)

    @classmethod
    def from_corpus(cls, corpus, corpus_save_path, tokenizer_save_path,
                    tokenizer_name, vocab_size, min_frequency, strip_accents,
                    clean_text, lowercase):
        # Dump the corpus to disk, one document per line
        with open(corpus_save_path, 'wb') as f:
            f.write('\n'.join(corpus).encode())
        tokenizer = BertWordPieceTokenizer(
            strip_accents=strip_accents,
            clean_text=clean_text,
            lowercase=lowercase,
        )
        tokenizer.train(
            [corpus_save_path],
            vocab_size=vocab_size,
            min_frequency=min_frequency,
            show_progress=True,
            special_tokens=['[PAD]', '[UNK]', '[CLS]', '[SEP]', '[MASK]'],
            wordpieces_prefix="##",
        )
        # Recreate the output directory and save the trained vocab
        if os.path.exists(tokenizer_save_path):
            shutil.rmtree(tokenizer_save_path)
        os.mkdir(tokenizer_save_path)
        tokenizer.save_model(tokenizer_save_path, tokenizer_name)
        vocab_path = os.path.join(tokenizer_save_path,
                                  f'{tokenizer_name}-vocab.txt')
        return cls(vocab_path, strip_accents, clean_text, lowercase)

    def encode(self, text, add_cls_token=True):
        ids = []
        if add_cls_token:
            ids.append(self.cls_token_id)
        ids.extend(self._tokenizer.encode(text, add_special_tokens=False))
        return ids

    @property
    def cls_token_id(self):
        return self._tokenizer.cls_token_id

    @property
    def pad_token_id(self):
        return self._tokenizer.pad_token_id

    @property
    def mask_token_id(self):
        return self._tokenizer.mask_token_id

    @property
    def sep_token_id(self):
        return self._tokenizer.sep_token_id

    @property
    def unk_token_id(self):
        return self._tokenizer.unk_token_id

    @property
    def vocab_size(self):
        return self._tokenizer.vocab_size
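# A hedged usage sketch for from_corpus above; the toy corpus, the paths and
# the tokenizer name are hypothetical examples, not values from this code.
corpus = ['hello world', 'hello wordpiece tokenizer']
tok = WordpieceTokenizer.from_corpus(corpus=corpus,
                                     corpus_save_path='corpus.txt',
                                     tokenizer_save_path='tokenizer_out',
                                     tokenizer_name='demo',
                                     vocab_size=100,
                                     min_frequency=1,
                                     strip_accents=False,
                                     clean_text=True,
                                     lowercase=True)
print(tok.vocab_size, tok.encode('hello'))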
import argparse
import pickle

import numpy as np
from tqdm import tqdm
from transformers import BertTokenizerFast


def preprocess():
    """
    Tokenize the raw corpus, turning each dialogue into the form:
    "[CLS]utterance1[SEP]utterance2[SEP]utterance3[SEP]"
    """
    # Parse arguments
    parser = argparse.ArgumentParser()
    parser.add_argument('--vocab_path',
                        default='vocab/vocab.txt',
                        type=str,
                        required=False,
                        help='path of the vocab file')
    parser.add_argument('--log_path',
                        default='data/preprocess.log',
                        type=str,
                        required=False,
                        help='where to store the preprocessing log')
    parser.add_argument('--train_path',
                        default='data/train.txt',
                        type=str,
                        required=False,
                        help='path of the raw training data')
    parser.add_argument('--save_path',
                        default='data/train.pkl',
                        type=str,
                        required=False,
                        help='where to store the tokenized training set')
    args = parser.parse_args()

    # Initialize the logger (create_logger is a project-local helper)
    logger = create_logger(args.log_path)

    # Initialize the tokenizer
    tokenizer = BertTokenizerFast(vocab_file=args.vocab_path,
                                  sep_token="[SEP]",
                                  pad_token="[PAD]",
                                  cls_token="[CLS]")
    sep_id = tokenizer.sep_token_id
    cls_id = tokenizer.cls_token_id
    logger.info("preprocessing data, data path:{}, save path:{}".format(
        args.train_path, args.save_path))

    # Read the training set
    with open(args.train_path, 'rb') as f:
        data = f.read().decode("utf-8")

    # Line endings differ between Linux and Windows
    if "\r\n" in data:
        train_data = data.split("\r\n\r\n")
    else:
        train_data = data.split("\n\n")
    logger.info("there are {} dialogues in the dataset".format(
        len(train_data)))

    # Tokenize every dialogue into the form:
    # "[CLS]utterance1[SEP]utterance2[SEP]utterance3[SEP]"
    dialogue_len = []  # tokenized lengths, for the mean/median statistics
    dialogue_list = []
    for index, dialogue in enumerate(tqdm(train_data)):
        if "\r\n" in data:
            utterances = dialogue.split("\r\n")
        else:
            utterances = dialogue.split("\n")

        input_ids = [cls_id]  # every dialogue starts with [CLS]
        for utterance in utterances:
            input_ids += tokenizer.encode(utterance,
                                          add_special_tokens=False)
            input_ids.append(sep_id)  # [SEP] marks the end of each utterance
        dialogue_len.append(len(input_ids))
        dialogue_list.append(input_ids)
    len_mean = np.mean(dialogue_len)
    len_median = np.median(dialogue_len)
    len_max = np.max(dialogue_len)
    with open(args.save_path, "wb") as f:
        pickle.dump(dialogue_list, f)
    logger.info("finished preprocessing data, the result is stored in {}".format(
        args.save_path))
    logger.info(
        "mean of dialogue len:{}, median of dialogue len:{}, max len:{}".format(
            len_mean, len_median, len_max))
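# A minimal sketch of how the train.pkl written by preprocess() can be loaded
# back for training; the path matches the --save_path default above.
import pickle

with open('data/train.pkl', 'rb') as f:
    dialogue_list = pickle.load(f)  # one list of token ids per dialogue
print(len(dialogue_list), dialogue_list[0][:10])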