コード例 #1
0
ファイル: preprocess.py プロジェクト: sunnyhuma171/Macadam
    def analysis_max_length(self,
                            path: str = None,
                            rate: float = 0.95,
                            length_bert_max: int = 512):
        """
        Analyse the train file and choose a max sequence length covering `rate` of it.

        Runs only when `self.length_max` is None (i.e. no length was forced).
        Every non-blank line of the file is parsed as JSON; the length of one
        sample is len(x["text"]) plus the lengths of all entries of
        x["texts2"]. Special tokens such as [CLS]/[SEP] are not counted per
        sample; a fixed +2 is added to the chosen percentile instead
        (presumably to reserve room for them).

        Args:
            path: str, train data file, eg. "/home/data/textclassification/baidu_qa_2019/train.json";
                  replaces `self.path` only if it exists on disk
            rate: float, coverage rate of all samples, eg. 0.95 (1.0 means the longest sample)
            length_bert_max: int, upper bound of the result (max sequence length of bert-like models)
        Returns:
            None; the result is stored in `self.length_max` (left unchanged if the file has no samples)
        """
        if path and os.path.exists(path):
            self.path = path
        # A length_max forced by the caller is never overridden.
        if self.length_max is not None:
            return
        len_sents = []
        with open(self.path, "r", encoding=self.encoding) as fr:
            for xy in tqdm(fr, desc="analysis max length of sentence"):
                line = xy.strip()
                if not line:  # tolerate blank lines, e.g. a trailing newline
                    continue
                x = json.loads(line).get("x")
                # texts2 may be missing or null for single-text samples
                second_texts = x.get("texts2") or []
                len_sents.append(len(x.get("text")) + sum(len(st) for st in second_texts))
        if not len_sents:
            logger.info("analysis max length of sentence skipped: no samples in {0}".format(self.path))
            return
        len_sents.sort()
        # Clamp the percentile index: rate=1.0 would otherwise index one past the end.
        idx = max(0, min(int(rate * len(len_sents)), len(len_sents) - 1))
        self.length_max = min(len_sents[idx] + 2, length_bert_max)
        logger.info("analysis max length of sentence is {0}".format(
            self.length_max))
コード例 #2
0
ファイル: preprocess.py プロジェクト: sunnyhuma171/Macadam
 def build_vocab_y(self, data: List):
     """
     Build the label dictionaries l2i (label -> index) and i2l (index -> label).

     Labels are read from the "y" field of every non-blank JSON line in
     `data` and registered in order of first appearance; a list-valued "y"
     (sequence labeling) contributes each of its elements. For the SL task
     with an embedding type in EMBEDDING_TYPE, `self.y_start` and
     `self.y_end` are registered as well. Labels already present in
     `self.l2i` keep their existing index.

     Args:
         data: List[str], train data of all, one JSON string per item,
               eg. ['{"x":{"text":"你", "texts2":["是", "是不是"]}, "y":"YES"}']
     Returns:
         None
     """

     def add_label(label):
         # Only unseen labels are added; re-assigning an existing label would
         # give it index len(l2i) and overwrite another label's i2l entry.
         if label not in self.l2i:
             self.l2i[label] = len(self.l2i)
             self.i2l[len(self.l2i) - 1] = label

     # Collect labels in first-seen order (dict: insertion-ordered, O(1) membership).
     ys = {}
     for xy in tqdm(data, desc="build dict of l2i"):
         line = xy.strip()
         if not line:  # skip blank lines, e.g. a trailing newline
             continue
         y = json.loads(line).get("y")
         for yi in (y if isinstance(y, list) else [y]):
             ys.setdefault(yi, None)
     for ysc in ys:
         add_label(ysc)
     # Per the original note: extra labels for [CLS]/[SEP] positions (or "O") in NER.
     if self.task == SL and self.embed_type in EMBEDDING_TYPE:
         add_label(self.y_start)
         add_label(self.y_end)
     logger.info("build vocab of l2i is {0}".format(self.l2i))
コード例 #3
0
ファイル: preprocess.py プロジェクト: sunnyhuma171/Macadam
 def analysis_max_length(self,
                         data: List,
                         rate: float = 0.95,
                         length_bert_max: int = 512):
     """
     Analyse the train data and choose a max sequence length covering `rate` of it.

     Runs only when `self.length_max` is None (i.e. no length was forced).
     Every non-blank item of `data` is parsed as JSON; the length of one
     sample is len(x["text"]) plus the lengths of all entries of
     x["texts2"]. A fixed +2 is added to the chosen percentile (presumably
     room for special tokens such as [CLS]/[SEP]).

     Args:
         data: List[str], train data of all, one JSON string per item,
               eg. ['{"x":{"text":"你", "texts2":["是", "是不是"]}, "y":"YES"}']
         rate: float, coverage rate of all samples, eg. 0.95 (1.0 means the longest sample)
         length_bert_max: int, upper bound of the result (max sequence length of bert-like models)
     Returns:
         None; the result is stored in `self.length_max` (left unchanged if `data` has no samples)
     """
     # A length_max forced by the caller is never overridden.
     if self.length_max is not None:
         return
     len_sents = []
     for xy in tqdm(data, desc="analysis max length of sentence"):
         line = xy.strip()
         if not line:  # tolerate blank lines, e.g. a trailing newline
             continue
         x = json.loads(line).get("x")
         # texts2 may be missing or null for single-text samples
         second_texts = x.get("texts2") or []
         len_sents.append(len(x.get("text")) + sum(len(st) for st in second_texts))
     if not len_sents:
         logger.info("analysis max length of sentence skipped: no samples")
         return
     len_sents.sort()
     # Clamp the percentile index: rate=1.0 would otherwise index one past the end.
     idx = max(0, min(int(rate * len(len_sents)), len(len_sents) - 1))
     self.length_max = min(len_sents[idx] + 2, length_bert_max)
     logger.info("analysis max length of sentence is {0}".format(
         self.length_max))
コード例 #4
0
ファイル: preprocess.py プロジェクト: sunnyhuma171/Macadam
 def build_vocab_y(self, path: str = None):
     """
     Build the label dictionaries l2i (label -> index) and i2l (index -> label) from the train file.

     Labels are read from the "y" field of every non-blank JSON line of the
     file and registered in order of first appearance; a list-valued "y"
     (sequence labeling) contributes each of its elements. Per-label counts
     are logged (ordered by `dict_sort`). For the SL task with an embedding
     type in EMBEDDING_TYPE, `self.y_start` and `self.y_end` are registered
     as well. Labels already present in `self.l2i` keep their existing index.

     Args:
         path: str, train data file, eg. "/home/data/textclassification/baidu_qa_2019/train.json";
               replaces `self.path` only if it exists on disk
     Returns:
         None
     """
     if path and os.path.exists(path):
         self.path = path

     def add_label(label):
         # Only unseen labels are added; re-assigning an existing label would
         # give it index len(l2i) and overwrite another label's i2l entry.
         if label not in self.l2i:
             self.l2i[label] = len(self.l2i)
             self.i2l[len(self.l2i) - 1] = label

     # Count labels and remember them in first-seen order.
     ys_counter = Counter()
     ys = {}
     with open(self.path, "r", encoding=self.encoding) as fr:
         for xy in tqdm(fr, desc="build dict of l2i"):
             line = xy.strip()
             if not line:  # skip blank lines, e.g. a trailing newline
                 continue
             y = json.loads(line).get("y")
             labels = y if isinstance(y, list) else [y]
             ys_counter.update(labels)
             for yi in labels:
                 ys.setdefault(yi, None)
     # Label statistics
     ys_counter_dict_sort = dict_sort(dict(ys_counter))
     logger.info(
         json.dumps(ys_counter_dict_sort, ensure_ascii=False, indent=4))
     # Build the dictionaries
     for ysc in ys:
         add_label(ysc)
     # Per the original note: extra labels for [CLS]/[SEP] positions (or "O") in NER.
     if self.task == SL and self.embed_type in EMBEDDING_TYPE:
         add_label(self.y_start)
         add_label(self.y_end)
     logger.info("build vocab of l2i is {0}".format(self.l2i))
コード例 #5
0
ファイル: utils.py プロジェクト: yynnxu/Macadam
def txt_write(lines: List[str],
              path: str,
              model: str = "w",
              encoding: str = "utf-8"):
    """
    Write a list of strings to a file (best effort).

    The strings are written as-is via `writelines`, so each element must
    already end with a newline if line breaks are wanted. Any error (e.g. a
    missing directory or a permission problem) is logged and swallowed, not
    raised.

    Args:
        lines: lines of list<str> which need save
        path: path of save file, such as "txt"
        model: open mode, such as "w", "a+" (name kept for backward compatibility)
        encoding: type of encoding, such as "utf-8", "gbk"
    Returns:
        None
    """
    try:
        # `with` closes the handle even when writelines fails midway.
        with open(path, model, encoding=encoding) as file:
            file.writelines(lines)
    except Exception as e:
        logger.info(str(e))
コード例 #6
0
ファイル: utils.py プロジェクト: yynnxu/Macadam
def txt_read(path: str, encoding: str = "utf-8") -> List[str]:
    """
    Read all lines of a text file (best effort).

    Args:
        path: path of the file to read, such as "txt"
        encoding: type of encoding, such as "utf-8", "gbk"
    Returns:
        List[str], the lines of the file including their trailing newlines;
        an empty list if the file cannot be opened or decoded (the error is logged)
    """
    try:
        # `with` closes the handle even when decoding fails during readlines.
        with open(path, "r", encoding=encoding) as file:
            return file.readlines()
    except Exception as e:
        # Best effort: callers get [] instead of an exception. Unlike a
        # `return` inside `finally`, this does not swallow KeyboardInterrupt/SystemExit.
        logger.info(str(e))
        return []
コード例 #7
0
def trainer(
    path_model_dir,
    path_embed,
    path_train,
    path_dev,
    path_checkpoint,
    path_config,
    path_vocab,
    network_type="FastText",
    embed_type="BERT",
    token_type="CHAR",
    task="TC",
    is_length_max=False,
    use_onehot=True,
    use_file=False,
    layer_idx=None,
    length_max=128,
    embed_size=768,
    learning_rate=5e-5,
    batch_size=32,
    epochs=20,
    early_stop=3,
    decay_rate=0.999,
    decay_step=1000,
    rate=1.0,
):
    """
    Train a text-classification model on top of a pre-trained embedding.

    Args:
        path_model_dir: str, directory of model save, eg. "/home/model/text_cnn"
        path_embed: str, directory of pre-train embedding, eg. "/home/embedding/bert"
        path_train: str, path of file(json) of train data, eg. "/home/data/text_classification/THUCNews/train.json"
        path_dev: str, path of file(json) of dev data, eg. "/home/data/text_classification/THUCNews/dev.json"
        path_checkpoint: str, path of checkpoint file of pre-train embedding
        path_config: str, path of config file of pre-train embedding
        path_vocab: str, path of vocab file of pre-train embedding
        network_type: str, network of text-classification, eg."FastText","TextCNN", "BiRNN", "RCNN", "CRNN", "SelfAttention"
        embed_type: str, type of pre-train embedding, eg. "Bert", "Albert", "Roberta", "Electra"
        token_type: str, token type stored in params["sharing"], eg. "CHAR"
        task: str, task of model, eg. "sl"(sequence-labeling), "tc"(text-classification), "re"(relation-extraction)
        is_length_max: bool, True keeps `length_max` as given; False derives it from the corpus, eg.False
        use_onehot: bool, whether use onehot of y(label); also selects the loss, eg.False
        use_file: bool, use FilePrerocessXY (iterate files) instead of ListPrerocessXY (in-memory lists)
        layer_idx: List[int], layers which you select of bert-like model, eg.[-2]; None means [-1]
        length_max: int, max length of sequence, eg.128
        embed_size: int, dim of bert-like model, eg.768
        learning_rate: float, lr of training, eg.1e-3, 5e-5
        batch_size: int, samples each step when training, eg.32
        epochs: int, max epoch of training, eg.20
        early_stop: int, stop training when the metric stops increasing for N epochs, eg.3
        decay_rate: float, decay rate of lr, eg.0.999
        decay_step: int, decay step of training, eg.1000
        rate: float, fraction of train/dev lines to use (applied here only when use_file is False;
              also forwarded to graph.fit)
    Returns:
        str, training wall-clock time in seconds (str() of a float)
    """
    # A None default avoids sharing one mutable list between calls.
    if layer_idx is None:
        layer_idx = [-1]
    # Resolve the embedding and graph classes.
    Embedding = embedding_map[embed_type.upper()]
    Graph = graph_map[network_type.upper()]
    # os.environ[...] raised KeyError when the variable was unset; .get() only logs None.
    logger.info("CUDA_VISIBLE_DEVICES: {0}".format(
        os.environ.get("CUDA_VISIBLE_DEVICES")))

    # Parameters shared by the embedding (bert etc.) and the graph.
    params = {
        "embed": {
            "path_embed": path_embed,
            "layer_idx": layer_idx,
        },
        "sharing": {
            "length_max": length_max,
            "embed_size": embed_size,
            "token_type": token_type.upper(),
        },
        "graph": {
            "loss": "categorical_crossentropy"
            if use_onehot else "sparse_categorical_crossentropy",  # loss function
            "use_onehot": use_onehot,  # whether labels are one-hot encoded
            "use_crf": False  # whether to use CRF (and store its transition matrix)
        },
        "train": {
            "learning_rate":
            learning_rate,  # most important knob; word2vec usually 1e-3, bert 5e-5 or 2e-5
            "decay_rate": decay_rate,  # multiplicative lr decay, lr = lr * rate
            "decay_step": decay_step,  # decay the lr every N steps
            "batch_size":
            batch_size,  # too small hurts convergence, too large hurts generalization
            "early_stop": early_stop,  # stop after N epochs without metric improvement
            "epochs": epochs,  # max number of training epochs
        },
        "save": {
            "path_model_dir":
            path_model_dir,  # model directory (save_best_only=True, save_weights_only=True)
            "path_model_info": os.path.join(path_model_dir,
                                            "model_info.json"),  # hyper-parameter file
        },
        "data": {
            "train_data": path_train,  # train data
            "val_data": path_dev  # dev data
        },
    }

    embed = Embedding(params)
    embed.build_embedding(path_checkpoint=path_checkpoint,
                          path_config=path_config,
                          path_vocab=path_vocab)

    # Graph initialisation
    graph = Graph(params)

    # Preprocessor: 1. is_length_max: keep the given length_max, otherwise derive it from the corpus.
    #               2. use_file: iterate over file paths instead of in-memory lists.
    length_max_forced = length_max if is_length_max else None
    if use_file:
        train_data = path_train
        dev_data = path_dev
        pxy = FilePrerocessXY(embedding=embed,
                              path=train_data,
                              path_dir=path_model_dir,
                              length_max=length_max_forced,
                              use_onehot=use_onehot,
                              embed_type=embed_type,
                              task=task)
        from macadam.base.preprocess import FileGenerator as generator_xy
    else:
        # One JSON object per line, example: {"x":{"text":"你是谁", "texts2":["你是谁呀", "是不是"]}, "y":"YES"}
        train_data = txt_read(path_train)
        dev_data = txt_read(path_dev)
        # Only the list-based preprocessor supports subsampling by `rate`.
        train_data = train_data[:int(len(train_data) * rate)]
        dev_data = dev_data[:int(len(dev_data) * rate)]
        pxy = ListPrerocessXY(embedding=embed,
                              data=train_data,
                              path_dir=path_model_dir,
                              length_max=length_max_forced,
                              use_onehot=use_onehot,
                              embed_type=embed_type,
                              task=task)
        from macadam.base.preprocess import ListGenerator as generator_xy
    logger.info("训练/验证语料读取完成")
    if is_length_max:
        logger.info("强制使用序列最大长度为{0}, 即文本最大截断或padding长度".format(length_max))
    logger.info("预处理类初始化完成")

    if not pxy.length_max:
        # Corpus analysis produced no usable length (e.g. empty data): keep the
        # previous hard-coded fallback of 33.
        logger.info("length_max is {0} after corpus analysis, fall back to 33".format(
            pxy.length_max))
        pxy.length_max = 33
    # Update max sequence length and number of labels
    graph.length_max = pxy.length_max
    graph.label = len(pxy.l2i)
    graph.hyper_parameters["sharing"]["length_max"] = graph.length_max
    graph.hyper_parameters["train"]["label"] = graph.label

    # The embedding was built with the requested length_max; rebuild it (and
    # hand it to the preprocessor) when the corpus analysis chose another one.
    if length_max != graph.length_max and not is_length_max:
        logger.info("根据bert-embedding等的最大长度不大于512, 根据语料自动确定序列最大长度为{0}".format(
            graph.length_max))
        params["sharing"]["length_max"] = graph.length_max
        embed = Embedding(params)
        embed.build_embedding(path_checkpoint=path_checkpoint,
                              path_config=path_config,
                              path_vocab=path_vocab)
        pxy.embedding = embed

    # Update the embedding dimension
    graph.embed_size = embed.embed_size
    graph.hyper_parameters["sharing"]["embed_size"] = graph.embed_size

    logger.info("预训练模型加载完成")
    graph.build_model(inputs=embed.model.input, outputs=embed.model.output)
    graph.create_compile()
    logger.info("网络(network or graph)初始化完成")
    logger.info("开始训练: ")
    time_start = time.time()
    graph.fit(pxy, generator_xy, train_data, dev_data=dev_data, rate=rate)
    # Measure once so the logged and returned durations agree.
    time_collection = str(time.time() - time_start)
    logger.info("训练完成, 耗时:" + time_collection)
    return time_collection
コード例 #8
0
ファイル: tet_embed+.py プロジェクト: sumerzhang/Macadam
# Build the pre-trained embedding. Embedding, params and the path_* names are
# defined earlier in this script, outside this fragment.
embed = Embedding(params)
embed.build_embedding(path_checkpoint=path_checkpoint,
                      path_config=path_config,
                      path_vocab=path_vocab)

# Read train/dev data; one JSON object per line, example: {"x":{"text":"你是谁", "texts2":["你是谁呀", "是不是"]}, "y":"YES"}
train_data = txt_read(path_train)
dev_data = txt_read(path_dev)

# Keep only the first `rate` fraction of each split.
len_train_rate = int(len(train_data) * rate)
len_dev_rate = int(len(dev_data) * rate)

train_data = train_data[:len_train_rate]
dev_data = dev_data[:len_dev_rate]

logger.info("训练/验证语料读取完成")
# Initialise the data preprocessor; its label dictionary l2i is used below
# to size the output layer (presumably built from train_data during init).
preprocess_xy = ListPrerocessXY(embed,
                                train_data,
                                path_dir=path_model_dir,
                                length_max=length_max)

# Take position 0 along axis 1 of the embedding output (layer named
# "Token-CLS", i.e. presumably the [CLS] token vector).
x = L.Lambda(lambda x: x[:, 0], name="Token-CLS")(embed.model.output)

# Finally a softmax classifier with one unit per label.
outputs = L.Dense(
    len(preprocess_xy.l2i),
    activation="softmax",
    kernel_initializer=keras.initializers.TruncatedNormal(stddev=0.02))(x)
model = M.Model(embed.model.input, outputs)
# Print the model summary; 132 is passed as the first positional argument
# (line_length in the Keras Model.summary API).
model.summary(132)
コード例 #9
0
ファイル: tet_graph.py プロジェクト: yynnxu/Macadam
def train(hyper_parameters=None, use_onehot=False, rate=1):
    """
    Train a text-classification graph on a locally stored BERT checkpoint.

    Smoke-test helper. The BERT directory is hard-coded, the GPU is hidden
    (CUDA_VISIBLE_DEVICES="-1") and TF_KERAS is set to "1" before anything
    is built. path_model_dir, path_train, path_dev and Graph are read from
    module scope.

    :param hyper_parameters: json, hyper-parameters (currently not read by this function)
    :param use_onehot: bool, whether labels are one-hot; also selects the loss
    :param rate: float, fraction of the train/dev lines used for training
    :return: None
    """
    started_at = time.time()
    os.environ["CUDA_VISIBLE_DEVICES"] = "-1"
    os.environ["TF_KERAS"] = "1"

    bert_dir = "D:/soft_install/dataset/bert-model/chinese_L-12_H-768_A-12"
    seq_len = 128
    if use_onehot:
        loss_name = "categorical_crossentropy"
    else:
        loss_name = "sparse_categorical_crossentropy"

    save_cfg = {
        "path_model": path_model_dir,  # model directory
        "path_hyper_parameters": os.path.join(path_model_dir, "hyper_parameters.json"),  # hyper-parameter file
        "path_fineture": os.path.join(path_model_dir, "embedding.json"),  # fine-tuned embedding file
    }
    params = {
        "embed": {"path_embed": bert_dir, "layer_idx": [-2]},
        "sharing": {"length_max": seq_len, "embed_size": 768},
        "graph": {"loss": loss_name},
        "save": save_cfg,
    }

    embedding = BertEmbedding(params)
    embedding.build_embedding(path_checkpoint=bert_dir + "/bert_model.ckpt",
                              path_config=bert_dir + "/bert_config.json",
                              path_vocab=bert_dir + "/vocab.txt")
    model_graph = Graph(params)

    def take_head(lines):
        # keep the first `rate` fraction of the lines
        return lines[:int(len(lines) * rate)]

    # One JSON object per line, example: {"x":{"text":"你是谁", "texts2":["你是谁呀", "是不是"]}, "y":"YES"}
    train_lines = take_head(txt_read(path_train))
    dev_lines = take_head(txt_read(path_dev))
    preprocessor = ListPrerocessXY(embedding=embedding, data=train_lines, path_dir=path_model_dir,
                                   length_max=seq_len, use_onehot=use_onehot,
                                   embed_type="BERT", task="TC")
    from macadam.base.preprocess import ListGenerator as generator_xy
    logger.info("强制使用序列最大长度为{0}, 即文本最大截断或padding长度".format(seq_len))

    # Propagate sequence length, label count and embedding size to the graph.
    model_graph.length_max = preprocessor.length_max
    model_graph.label = len(preprocessor.l2i)
    model_graph.embed_size = embedding.embed_size

    model_graph.build_model(inputs=embedding.model.inputs, outputs=embedding.model.output)
    model_graph.create_compile()
    model_graph.fit(preprocessor, generator_xy, train_lines, dev_data=dev_lines)
    elapsed = time.time() - started_at
    print("耗时:" + str(elapsed))
コード例 #10
0
ファイル: tet_preprocess_xy.py プロジェクト: yynnxu/Macadam
def preprocess(
    path_model_dir,
    path_embed,
    path_train,
    path_dev,
    path_checkpoint,
    path_config,
    path_vocab,
    network_type="CRF",
    embed_type="BERT",
    token_type="CHAR",
    task="SL",
    is_length_max=False,
    use_onehot=False,
    use_file=False,
    layer_idx=None,
    length_max=128,
    embed_size=768,
    learning_rate=5e-5,
    batch_size=32,
    epochs=20,
    early_stop=3,
    decay_rate=0.999,
    decay_step=1000,
    rate=1.0,
):
    """
    Set up embedding, graph and preprocessor for sequence labeling and sanity-check the train data.

    No model is trained. After the same set-up a trainer performs:
      - use_file=True: a file generator is built over the train data and its forfit() is called;
      - use_file=False: preprocess_x / preprocess_y run on every train line, and
        each distinct id-sequence length is logged (with its sample) the first
        time it is seen, so samples with unexpected lengths stand out.

    Args:
        path_model_dir: str, directory of model save, eg. "/home/model/text_cnn"
        path_embed: str, directory of pre-train embedding, eg. "/home/embedding/bert"
        path_train: str, path of file(json) of train data, eg. "/home/data/name_entity_recognition/people_1998/train.json"
        path_dev: str, path of file(json) of dev data, eg. "/home/data/name_entity_recognition/people_1998/dev.json"
        path_checkpoint: str, path of checkpoint file of pre-train embedding
        path_config: str, path of config file of pre-train embedding
        path_vocab: str, path of vocab file of pre-train embedding
        network_type: str, network of sequence-labeling, eg. "CRF", "Bi-LSTM-LAN" (case-insensitive)
        embed_type: str, type of pre-train embedding, eg. "Bert", "Albert", "Roberta", "Electra"
        token_type: str, token type stored in params["sharing"], eg. "CHAR"
        task: str, task of model, eg. "sl"(sequence-labeling), "tc"(text-classification), "re"(relation-extraction)
        is_length_max: bool, True keeps `length_max` as given; False derives it from the corpus, eg.False
        use_onehot: bool, whether use onehot of y(label); also selects the loss, eg.False
        use_file: bool, use FilePrerocessXY (iterate files) instead of ListPrerocessXY (in-memory lists)
        layer_idx: List[int], layers which you select of bert-like model, eg.[-2]; None means [-1]
        length_max: int, max length of sequence, eg.128
        embed_size: int, dim of bert-like model, eg.768
        learning_rate: float, lr of training, eg.1e-3, 5e-5
        batch_size: int, samples each step, eg.32
        epochs: int, max epoch of training, eg.20
        early_stop: int, stop training when the metric stops increasing for N epochs, eg.3
        decay_rate: float, decay rate of lr, eg.0.999
        decay_step: int, decay step of training, eg.1000
        rate: float, fraction of train/dev lines to use (only when use_file is False)
    Returns:
        None
    """
    # A None default avoids sharing one mutable list between calls.
    if layer_idx is None:
        layer_idx = [-1]
    # Resolve the embedding and graph classes.
    Embedding = embedding_map[embed_type.upper()]
    Graph = graph_map[network_type.upper()]

    time_start = time.time()
    # Parameters shared by the embedding (bert etc.) and the graph.
    params = {
        "embed": {
            "path_embed": path_embed,
            "layer_idx": layer_idx,
        },
        "sharing": {
            "length_max": length_max,
            "embed_size": embed_size,
            "token_type": token_type.upper(),
        },
        "graph": {
            "loss": "categorical_crossentropy"
            if use_onehot else "sparse_categorical_crossentropy",  # loss function
            "use_onehot": use_onehot,  # whether labels are one-hot encoded
            # whether to use CRF (and store its transition matrix); compared
            # upper-cased, like the graph_map lookup above
            "use_crf": network_type.upper() not in ["BI-LSTM-LAN"],
        },
        "train": {
            "learning_rate":
            learning_rate,  # most important knob; word2vec usually 1e-3, bert 5e-5 or 2e-5
            "decay_rate": decay_rate,  # multiplicative lr decay, lr = lr * rate
            "decay_step": decay_step,  # decay the lr every N steps
            "batch_size":
            batch_size,  # too small hurts convergence, too large hurts generalization
            "early_stop": early_stop,  # stop after N epochs without metric improvement
            "epochs": epochs,  # max number of training epochs
        },
        "save": {
            "path_model_dir":
            path_model_dir,  # model directory (save_best_only=True, save_weights_only=True)
            "path_model_info": os.path.join(path_model_dir,
                                            "model_info.json"),  # hyper-parameter file
        },
        "data": {
            "train_data": path_train,  # train data
            "val_data": path_dev  # dev data
        },
    }
    embed = Embedding(params)
    embed.build_embedding(path_checkpoint=path_checkpoint,
                          path_config=path_config,
                          path_vocab=path_vocab)
    # Graph initialisation
    graph = Graph(params)
    # Preprocessor: 1. is_length_max: keep the given length_max, otherwise derive it from the corpus.
    #               2. use_file: iterate over file paths instead of in-memory lists.
    length_max_forced = length_max if is_length_max else None
    if use_file:
        train_data = path_train
        dev_data = path_dev
        pxy = FilePrerocessXY(embedding=embed,
                              path=path_train,
                              path_dir=path_model_dir,
                              length_max=length_max_forced,
                              use_onehot=use_onehot,
                              embed_type=embed_type,
                              task=task)
        from macadam.base.preprocess import FileGenerator as generator_xy
    else:
        # One JSON object per line, example: {"x":{"text":"你是谁", "texts2":["你是谁呀", "是不是"]}, "y":"YES"}
        train_data = txt_read(path_train)
        dev_data = txt_read(path_dev)
        # Only the list-based preprocessor supports subsampling by `rate`.
        train_data = train_data[:int(len(train_data) * rate)]
        dev_data = dev_data[:int(len(dev_data) * rate)]
        pxy = ListPrerocessXY(embedding=embed,
                              data=train_data,
                              path_dir=path_model_dir,
                              length_max=length_max_forced,
                              use_onehot=use_onehot,
                              embed_type=embed_type,
                              task=task)
        from macadam.base.preprocess import ListGenerator as generator_xy
    if is_length_max:
        logger.info("强制使用序列最大长度为{0}, 即文本最大截断或padding长度".format(length_max))
    logger.info("预处理类初始化完成")
    graph.length_max = pxy.length_max
    graph.label = len(pxy.l2i)

    # The embedding was built with the requested length_max; rebuild it (and
    # hand it to the preprocessor) when the corpus analysis chose another one.
    if length_max != graph.length_max and not is_length_max:
        logger.info("根据语料自动确认序列最大长度为{0}, 且bert-embedding等的最大长度不大于512".format(
            graph.length_max))
        params["sharing"]["length_max"] = graph.length_max
        embed = Embedding(params)
        embed.build_embedding(path_checkpoint=path_checkpoint,
                              path_config=path_config,
                              path_vocab=path_vocab)
        pxy.embedding = embed

    logger.info("预训练模型加载完成")
    if use_file:
        # The generator must iterate the same data whose length is passed as len_data.
        len_train = pxy.analysis_len_data(train_data)
        gxy = generator_xy(train_data,
                           pxy,
                           batch_size=batch_size,
                           len_data=len_train)
        gxy.forfit()
    else:
        # Report every distinct length of the preprocessed id sequences once,
        # together with the first sample producing it.
        seen_lengths = set()
        for td in train_data:
            line = td.strip()
            if not line:  # skip blank lines, e.g. a trailing newline
                continue
            line_json = json.loads(line)
            x_id = pxy.preprocess_x(line_json.get("x"))
            y_id = pxy.preprocess_y(line_json.get("y"))
            for seq_name, seq in (("x_id[0]", x_id[0]), ("x_id[1]", x_id[1]), ("y_id", y_id)):
                seq_len = len(seq)
                if seq_len not in seen_lengths:
                    logger.info("new length of {0}: {1}, sample: {2}".format(
                        seq_name, seq_len, line_json))
                    seen_lengths.add(seq_len)
    logger.info("训练完成, 耗时:" + str(time.time() - time_start))