def prepare_ref(lines: List[str], ltp_tokenizer: LTP,
                bert_tokenizer: BertTokenizer):
    ltp_res = []

    for i in range(0, len(lines), 100):
        res = ltp_tokenizer.seg(lines[i:i + 100])[0]
        res = [get_chinese_word(r) for r in res]
        ltp_res.extend(res)
    assert len(ltp_res) == len(lines)

    bert_res = []
    for i in range(0, len(lines), 100):
        res = bert_tokenizer(lines[i:i + 100],
                             add_special_tokens=True,
                             truncation=True,
                             max_length=512)
        bert_res.extend(res["input_ids"])
    assert len(bert_res) == len(lines)

    ref_ids = []
    for input_ids, chinese_word in zip(bert_res, ltp_res):

        input_tokens = []
        for id in input_ids:
            token = bert_tokenizer._convert_id_to_token(id)
            input_tokens.append(token)
        input_tokens = add_sub_symbol(input_tokens, chinese_word)
        ref_id = []
        # We only save pos of chinese subwords start with ##, which mean is part of a whole word.
        for i, token in enumerate(input_tokens):
            if token[:2] == "##":
                clean_token = token[2:]
                # save chinese tokens' pos
                if len(clean_token) == 1 and _is_chinese_char(
                        ord(clean_token)):
                    ref_id.append(i)
        ref_ids.append(ref_id)

    assert len(ref_ids) == len(bert_res)

    return ref_ids
def prepare_ref(lines: List[str],
                ltp_tokenizer: LTP,
                bert_tokenizer: BertTokenizer,
                batch_size=1000):
    """
    Args:
        lines:  每行一个中文段落,
        ltp_tokenizer: ltp的tokenizer处理器
        bert_tokenizer:  bert的tokenizer处理器
    Returns:

    """
    ltp_res = []
    # batch_size等于100,每次处理100行,
    print(f"开始用ltp模型进行分词处理...")
    for i in tqdm(range(0, len(lines), batch_size)):
        #调用ltp进行分词
        res = ltp_tokenizer.seg(lines[i:i + batch_size])[0]
        #过滤出分词后都是中文的部分
        res = [get_chinese_word(r) for r in res]
        #加到ltp_res
        ltp_res.extend(res)
    assert len(ltp_res) == len(lines)
    # eg: ltp_res中的文本处理的结果 [ ['效果', '一直', '用户', '感觉'],....]
    #bert也进行tokenizer, 每次处理100行
    print(f"开始用bert tokenizer模型进行token处理...")
    bert_res = []
    for i in tqdm(range(0, len(lines), batch_size)):
        res = bert_tokenizer(lines[i:i + batch_size],
                             add_special_tokens=True,
                             truncation=True,
                             max_length=512)
        bert_res.extend(res["input_ids"])
    # eg: bert_res [ [101, 5439, 4500, 2787, 749, 8024, 671, 4684, 1762, 4500, 4007, 2051, 8024, 2697, 6230, 2190, 2971, 4576, 2971, 3779, 3126, 3362, 2923, 1962, 4638, 102]...]
    #确保行数相同
    print(f"开始生成对应关系")
    assert len(bert_res) == len(lines)
    print_num = 5
    ref_ids = []
    for input_ids, chinese_word in zip(bert_res, ltp_res):
        input_tokens = []
        for id in input_ids:
            token = bert_tokenizer._convert_id_to_token(id)
            input_tokens.append(token)
        # eg : ['[CLS]', '古', '##龙', '洗', '发', '##水', ',', '洗', '完', '头', '##发', '不', '干', '##燥', '、', '也', '不', '容', '##易', '油', '、', '不', '痒', ',', '味', '##道', '持', '##久', ',', '非', '##常', '柔', '##顺', ',', '而', '##且', '泡', '##泡', '很', '容', '##易', '冲', '##洗', '干', '##净', '泡', '##沫', '非', '##常', '细', '##腻', ',', '洗', '后', '头', '##发', '很', '滑', '很', '顺', ',', '洗', '了', '之', '##后', '就', '头', '##发', '很', '蓬', '##松', ',', '很', '香', ',', '而', '##且', '我', '洗', '了', '是', '没', '##有', '头', '##皮', '##屑', '的', '[SEP]']
        input_tokens = add_sub_symbol(input_tokens, chinese_word)
        ref_id = []
        # 我们只保存以##开头的中文子词的位置,这意味着它是全词的一部分。
        for i, token in enumerate(input_tokens):
            if token[:2] == "##":
                clean_token = token[2:]
                # 只保存中文子词的后半部分,把和bert的对应关系,保存到ref_id中,ref_id是这个句子的所有子词的后半部分映射
                if len(clean_token) == 1 and _is_chinese_char(
                        ord(clean_token)):
                    ref_id.append(i)
        #打印前5个示例
        if print_num > 0:
            example_num = 5 - print_num
            print(f"第{example_num}个样本是: {lines[example_num]}")
            print(f"第{example_num}个样本的ltp分词后结果: {ltp_res[example_num]}")
            print(
                f"第{example_num}个样本的bert toknizer后结果: {bert_res[example_num]}")
            print(
                f"第{example_num}个样本的bert toknizer被ltp的全词处理后的结果: {input_tokens}"
            )
            print(
                f"第{example_num}个样本的bert的token对应的子词的后半部分的位置的最终的ref_id: {ref_id}"
            )
            print()
            print_num -= 1
        ref_ids.append(ref_id)
    #判断每个句子的子词的映射关系都保存了
    assert len(ref_ids) == len(bert_res)

    return ref_ids