Example #1
class WordpieceTokenizer:

    def __init__(self, vocab_path, strip_accents, clean_text, lowercase):
        common_params = {'strip_accents': strip_accents, 'clean_text': clean_text, 'lowercase': lowercase}
        self._tokenizer = BertTokenizerFast(
            vocab_file=vocab_path, **common_params
        )

    def encode(self, text, add_cls_token=True):
        ids = []
        if add_cls_token:
            ids.append(self.cls_token_id)
        ids.extend(self._tokenizer.encode(text, add_special_tokens=False))
        return ids

    @property
    def cls_token_id(self):
        return self._tokenizer.cls_token_id

    @property
    def pad_token_id(self):
        return self._tokenizer.pad_token_id

    @property
    def sep_token_id(self):
        return self._tokenizer.sep_token_id

    @property
    def vocab_size(self):
        return self._tokenizer.vocab_size
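
A minimal usage sketch of the class above; the vocab path and sample text are made up for illustration. Note that encode() prepends [CLS] but does not append [SEP], so the caller adds it when needed.

# Usage sketch: the vocab path below is a hypothetical example
tokenizer = WordpieceTokenizer(vocab_path='vocab/vocab.txt',
                               strip_accents=True,
                               clean_text=True,
                               lowercase=True)
ids = tokenizer.encode("hello world")   # [cls_id, wordpiece ids...]
ids.append(tokenizer.sep_token_id)      # append [SEP] manually if required
print(tokenizer.vocab_size, ids)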
Example #2
def main():
    args = set_args()
    logger = create_logger(args)
    # Restrict the visible GPUs before any CUDA call so the setting takes effect
    os.environ["CUDA_VISIBLE_DEVICES"] = args.device
    # Use the GPU only when it is available and the user has not disabled it
    args.cuda = torch.cuda.is_available() and not args.no_cuda
    device = 'cuda' if args.cuda else 'cpu'
    logger.info('using device:{}'.format(device))
    tokenizer = BertTokenizerFast(vocab_file=args.vocab_path,
                                  sep_token="[SEP]",
                                  pad_token="[PAD]",
                                  cls_token="[CLS]")
    # tokenizer = BertTokenizer(vocab_file=args.vocab_path)
    model = GPT2LMHeadModel.from_pretrained(args.model_path)
    model = model.to(device)
    model.eval()
    if args.save_samples_path:
        if not os.path.exists(args.save_samples_path):
            os.makedirs(args.save_samples_path)
        samples_file = open(args.save_samples_path + '/samples.txt',
                            'a',
                            encoding='utf8')
        samples_file.write("聊天记录{}:\n".format(datetime.now()))
    # 存储聊天记录,每个utterance以token的id的形式进行存储
    history = []
    print('开始和chatbot聊天,输入CTRL + Z以退出')

    while True:
        try:
            text = input("user:"******"你好"
            if args.save_samples_path:
                samples_file.write("user:{}\n".format(text))
            text_ids = tokenizer.encode(text, add_special_tokens=False)
            history.append(text_ids)
            input_ids = [tokenizer.cls_token_id]  # every input starts with [CLS]

            for history_id, history_utr in enumerate(
                    history[-args.max_history_len:]):
                input_ids.extend(history_utr)
                input_ids.append(tokenizer.sep_token_id)
            input_ids = torch.tensor(input_ids).long().to(device)
            input_ids = input_ids.unsqueeze(0)
            response = []  # the response generated from the context
            # generate at most max_len tokens
            for _ in range(args.max_len):
                outputs = model(input_ids=input_ids)
                logits = outputs.logits
                next_token_logits = logits[0, -1, :]
                # Apply a repetition penalty to every token already generated, lowering its probability
                for token_id in set(response):
                    next_token_logits[token_id] /= args.repetition_penalty
                next_token_logits = next_token_logits / args.temperature
                # Set the logit of [UNK] to -inf so the model can never predict the [UNK] token
                next_token_logits[tokenizer.convert_tokens_to_ids(
                    '[UNK]')] = -float('Inf')
                filtered_logits = top_k_top_p_filtering(next_token_logits,
                                                        top_k=args.topk,
                                                        top_p=args.topp)
                # torch.multinomial draws num_samples elements from the candidates without replacement; tokens with higher weights are more likely to be drawn, and their indices are returned
                next_token = torch.multinomial(F.softmax(filtered_logits,
                                                         dim=-1),
                                               num_samples=1)
                if next_token == tokenizer.sep_token_id:  # [SEP] marks the end of the response
                    break
                response.append(next_token.item())
                input_ids = torch.cat((input_ids, next_token.unsqueeze(0)),
                                      dim=1)
                # his_text = tokenizer.convert_ids_to_tokens(input_ids[0].tolist())
                # print("his_text:{}".format(his_text))
            history.append(response)
            text = tokenizer.convert_ids_to_tokens(response)
            print("chatbot:" + "".join(text))
            if args.save_samples_path:
                samples_file.write("chatbot:{}\n".format("".join(text)))
        except KeyboardInterrupt:
            if args.save_samples_path:
                samples_file.close()
            break
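
Example #2 calls a helper top_k_top_p_filtering that is not shown in the snippet. Below is a minimal sketch of such a filter for a 1-D logits tensor, following the widely used top-k / nucleus sampling recipe; it is an assumed stand-in, not necessarily the exact helper used by the project.

import torch
import torch.nn.functional as F

def top_k_top_p_filtering(logits, top_k=0, top_p=0.0, filter_value=-float('Inf')):
    """Mask a 1-D logits tensor with top-k and/or nucleus (top-p) filtering."""
    top_k = min(top_k, logits.size(-1))
    if top_k > 0:
        # Drop every token whose logit is below the k-th largest logit
        indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
        logits[indices_to_remove] = filter_value
    if top_p > 0.0:
        sorted_logits, sorted_indices = torch.sort(logits, descending=True)
        cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
        # Drop tokens once the cumulative probability exceeds top_p,
        # shifting right so the first token above the threshold is kept
        sorted_indices_to_remove = cumulative_probs > top_p
        sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
        sorted_indices_to_remove[..., 0] = 0
        indices_to_remove = sorted_indices[sorted_indices_to_remove]
        logits[indices_to_remove] = filter_value
    return logits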
Example #3
class WordpieceTokenizer(BaseTokenizer):
    def __init__(self,
                 vocab_path,
                 strip_accents,
                 clean_text,
                 lowercase,
                 from_pretrained=False):
        common_params = {
            'strip_accents': strip_accents,
            'clean_text': clean_text,
            'lowercase': lowercase
        }
        if from_pretrained:
            self._tokenizer = BertTokenizerFast.from_pretrained(
                pretrained_model_name_or_path=vocab_path, **common_params)
        else:
            self._tokenizer = BertTokenizerFast(vocab_file=vocab_path,
                                                **common_params)

    @classmethod
    def from_corpus(cls, corpus, corpus_save_path, tokenizer_save_path,
                    tokenizer_name, vocab_size, min_frequency, strip_accents,
                    clean_text, lowercase):
        with open(corpus_save_path, 'wb') as f:
            f.write('\n'.join(corpus).encode())

        tokenizer = BertWordPieceTokenizer(
            strip_accents=strip_accents,
            clean_text=clean_text,
            lowercase=lowercase,
        )
        tokenizer.train(
            [corpus_save_path],
            vocab_size=vocab_size,
            min_frequency=min_frequency,
            show_progress=True,
            special_tokens=['[PAD]', '[UNK]', '[CLS]', '[SEP]', '[MASK]'],
            wordpieces_prefix="##",
        )

        if os.path.exists(tokenizer_save_path):
            shutil.rmtree(tokenizer_save_path)
        os.mkdir(tokenizer_save_path)

        tokenizer.save_model(tokenizer_save_path, tokenizer_name)
        vocab_path = os.path.join(tokenizer_save_path,
                                  f'{tokenizer_name}-vocab.txt')
        return cls(vocab_path, strip_accents, clean_text, lowercase)

    def encode(self, text, add_cls_token=True):
        ids = []
        if add_cls_token:
            ids.append(self.cls_token_id)
        ids.extend(self._tokenizer.encode(text, add_special_tokens=False))
        return ids

    @property
    def cls_token_id(self):
        return self._tokenizer.cls_token_id

    @property
    def pad_token_id(self):
        return self._tokenizer.pad_token_id

    @property
    def mask_token_id(self):
        return self._tokenizer.mask_token_id

    @property
    def sep_token_id(self):
        return self._tokenizer.sep_token_id

    @property
    def unk_token_id(self):
        return self._tokenizer.unk_token_id

    @property
    def vocab_size(self):
        return self._tokenizer.vocab_size
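
A sketch of training and using the tokenizer above from an in-memory corpus; the corpus, paths, and hyperparameter values are placeholders chosen for illustration.

# Illustrative call: the corpus, paths and sizes below are placeholder values
corpus = ["hello world", "wordpiece tokenizers split rare words into subwords"]
tok = WordpieceTokenizer.from_corpus(corpus=corpus,
                                     corpus_save_path='corpus.txt',
                                     tokenizer_save_path='tokenizer_out',
                                     tokenizer_name='demo',
                                     vocab_size=1000,
                                     min_frequency=1,
                                     strip_accents=True,
                                     clean_text=True,
                                     lowercase=True)
print(tok.vocab_size, tok.encode("hello wordpiece"))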
Example #4
def preprocess():
    """
    对原始语料进行tokenize,将每段对话处理成如下形式:"[CLS]utterance1[SEP]utterance2[SEP]utterance3[SEP]"
    """
    # Set up command-line arguments
    parser = argparse.ArgumentParser()
    parser.add_argument('--vocab_path',
                        default='vocab/vocab.txt',
                        type=str,
                        required=False,
                        help='path to the vocab file')
    parser.add_argument('--log_path',
                        default='data/preprocess.log',
                        type=str,
                        required=False,
                        help='where to store the preprocessing log')
    parser.add_argument('--train_path',
                        default='data/train.txt',
                        type=str,
                        required=False,
                        help='path to the training data')
    parser.add_argument('--save_path',
                        default='data/train.pkl',
                        type=str,
                        required=False,
                        help='where to save the tokenized training set')
    args = parser.parse_args()

    # Initialize the logger
    logger = create_logger(args.log_path)

    # Initialize the tokenizer
    tokenizer = BertTokenizerFast(vocab_file=args.vocab_path,
                                  sep_token="[SEP]",
                                  pad_token="[PAD]",
                                  cls_token="[CLS]")
    sep_id = tokenizer.sep_token_id
    cls_id = tokenizer.cls_token_id
    logger.info("preprocessing data,data path:{}, save path:{}".format(
        args.train_path, args.save_path))

    # Read the training data
    with open(args.train_path, 'rb') as f:
        data = f.read().decode("utf-8")

    # Distinguish between Linux and Windows line endings
    if "\r\n" in data:
        train_data = data.split("\r\n\r\n")
    else:
        train_data = data.split("\n\n")
    logger.info("there are {} dialogue in dataset".format(len(train_data)))

    # Start tokenizing
    # Collect all dialogues; each entry has the form "[CLS]utterance1[SEP]utterance2[SEP]utterance3[SEP]"
    dialogue_len = []  # lengths of the tokenized dialogues, used to compute the median and mean
    dialogue_list = []
    with open(args.save_path, "w", encoding="utf-8") as f:
        for index, dialogue in enumerate(tqdm(train_data)):
            if "\r\n" in data:
                utterances = dialogue.split("\r\n")
            else:
                utterances = dialogue.split("\n")

            input_ids = [cls_id]  # 每个dialogue以[CLS]开头
            for utterance in utterances:
                input_ids += tokenizer.encode(utterance,
                                              add_special_tokens=False)
                input_ids.append(sep_id)  # 每个utterance之后添加[SEP],表示utterance结束
            dialogue_len.append(len(input_ids))
            dialogue_list.append(input_ids)
    len_mean = np.mean(dialogue_len)
    len_median = np.median(dialogue_len)
    len_max = np.max(dialogue_len)
    with open(args.save_path, "wb") as f:
        pickle.dump(dialogue_list, f)
    logger.info("finish preprocessing data,the result is stored in {}".format(
        args.save_path))
    logger.info(
        "mean of dialogue len:{},median of dialogue len:{},max len:{}".format(
            len_mean, len_median, len_max))
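
The pickled file produced by preprocess() can later be loaded back as a list of token-id lists, for example when building the training dataset. A minimal sketch, assuming the default --save_path:

import pickle

# Load the dialogues written by preprocess() (default --save_path)
with open('data/train.pkl', 'rb') as f:
    dialogue_list = pickle.load(f)

# Each entry is one dialogue: [CLS] utterance1 [SEP] utterance2 [SEP] ...
print(len(dialogue_list), dialogue_list[0][:10])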