Example #1
    def fit(self, x_train, y_train, x_dev, y_dev):
        """
            训练
        :param x_train: 
        :param y_train: 
        :param x_dev: 
        :param y_dev: 
        :return: 
        """
        # Save the hyperparameters. They are switched to inference-time values
        # first, because the saved JSON is what prediction later reloads.
        self.hyper_parameters['model']['is_training'] = False
        self.hyper_parameters['model']['trainable'] = False
        self.hyper_parameters['model']['dropout'] = 1.0

        save_json(jsons=self.hyper_parameters,
                  json_path=self.path_hyper_parameters)
        # Train the model
        self.model.fit(x_train,
                       y_train,
                       batch_size=self.batch_size,
                       epochs=self.epochs,
                       validation_data=(x_dev, y_dev),
                       shuffle=True,
                       callbacks=self.callback())
        # Save the embedding weights if they were fine-tuned (non-static)
        if self.trainable:
            self.word_embedding.model.save(self.path_fineture)
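
A minimal driver sketch for `fit`; the enclosing wrapper class (called `Graph` here), its constructor, and the data shapes are assumptions for illustration and do not appear in the example itself:

    import numpy as np

    model = Graph(hyper_parameters)  # hypothetical wrapper holding self.model etc.
    # dummy data: 32 samples, 50 token ids each, 10 classes one-hot encoded
    x_train = np.random.randint(0, 100, (32, 50))
    y_train = np.eye(10)[np.random.randint(0, 10, 32)]
    model.fit(x_train, y_train, x_train, y_train)  # dev set reused for brevity
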
Example #2
    def fit_generator(self, embed, rate=1):
        """

        :param data_fit_generator: yield, 训练数据
        :param data_dev_generator: yield, 验证数据
        :param steps_per_epoch: int, 训练一轮步数
        :param validation_steps: int, 验证一轮步数
        :return: 
        """
        # Save the hyperparameters. They are switched to inference-time values
        # first, because the saved JSON is what prediction later reloads.
        self.hyper_parameters['model']['is_training'] = False
        self.hyper_parameters['model']['trainable'] = False
        self.hyper_parameters['model']['dropout'] = 1.0

        save_json(jsons=self.hyper_parameters,
                  json_path=self.path_hyper_parameters)

        pg = PreprocessGenerator()
        _, len_train = pg.preprocess_get_label_set(
            self.hyper_parameters['data']['train_data'])
        data_fit_generator = pg.preprocess_label_ques_to_idx(
            embedding_type=self.hyper_parameters['embedding_type'],
            batch_size=self.batch_size,
            path=self.hyper_parameters['data']['train_data'],
            embed=embed,
            rate=rate)
        _, len_val = pg.preprocess_get_label_set(
            self.hyper_parameters['data']['val_data'])
        data_dev_generator = pg.preprocess_label_ques_to_idx(
            embedding_type=self.hyper_parameters['embedding_type'],
            batch_size=self.batch_size,
            path=self.hyper_parameters['data']['val_data'],
            embed=embed,
            rate=rate)
        steps_per_epoch = len_train // self.batch_size
        validation_steps = len_val // self.batch_size
        # Train the model
        self.model.fit_generator(generator=data_fit_generator,
                                 validation_data=data_dev_generator,
                                 callbacks=self.callback(),
                                 epochs=self.epochs,
                                 steps_per_epoch=steps_per_epoch,
                                 validation_steps=validation_steps)
        # Save the embedding weights if they were fine-tuned (non-static)
        if self.trainable:
            self.word_embedding.model.save(self.path_fineture)
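
A note on the step arithmetic above: Keras' `fit_generator` treats the generator as an endless stream of batches, and `steps_per_epoch` / `validation_steps` tell it how many batches make up one pass. A quick sanity check with invented numbers:

    # Illustrative arithmetic only; len_train and batch_size are made up.
    len_train, batch_size = 10000, 32
    steps_per_epoch = len_train // batch_size  # 312 full batches per epoch
    # the trailing 10000 - 312 * 32 = 16 rows are not seen in that pass
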
    def preprocess_label_ques_to_idx(self,
                                     embedding_type,
                                     batch_size,
                                     path,
                                     embed,
                                     rate=1,
                                     epoch=20):
        label_set, len_all = self.preprocess_get_label_set(path)
        # Build the label<->index mappings; if they already exist on disk
        # (e.g. when processing the dev set), load them instead of rebuilding
        if not os.path.exists(self.path_fast_text_model_l2i_i2l):
            count = 0
            label2index = {}
            index2label = {}
            for label_one in label_set:
                label2index[label_one] = count
                index2label[count] = label_one
                count = count + 1

            l2i_i2l = {}
            l2i_i2l['l2i'] = label2index
            l2i_i2l['i2l'] = index2label
            save_json(l2i_i2l, self.path_fast_text_model_l2i_i2l)
        else:
            l2i_i2l = load_json(self.path_fast_text_model_l2i_i2l)

        # Number of lines to read, scaled by `rate`
        len_ql = int(rate * len_all)
        if len_ql <= 500:  # ignore the sampling rate for small corpora, so there is enough data to train on
            len_ql = len_all

        def process_line(line):
            # Parse one line: extract the label and the question's token indices
            line_sp = line.split(",", 1)  # split once, in case the question itself contains commas
            ques = str(line_sp[1]).strip().upper()
            label = str(line_sp[0]).strip().upper()
            label = "NAN" if label == "" else label
            que_embed = embed.sentence2idx(ques)
            # one-hot encode the label
            label_zeros = [0] * len(l2i_i2l['l2i'])
            label_zeros[l2i_i2l['l2i'][label]] = 1
            return que_embed, label_zeros

        for _ in range(epoch):
            with open(path, "r", encoding="utf-8") as file_csv:
                count_all_line = 0
                cnt = 0
                x, y = [], []
                for line in file_csv:
                    count_all_line += 1
                    # stop once the sampled fraction of the corpus is consumed
                    if count_all_line > len_ql:
                        break
                    if count_all_line > 1:  # the first line is the 'label,ques' header; skip it
                        x_line, y_line = process_line(line)
                        x.append(x_line)
                        y.append(y_line)
                        cnt += 1
                        if cnt == batch_size:
                            if embedding_type in ['bert', 'albert']:
                                x_, y_ = np.array(x), np.array(y)
                                x_1 = np.array([xi[0] for xi in x_])  # token ids
                                x_2 = np.array([xi[1] for xi in x_])  # segment ids
                                x_all = [x_1, x_2]
                            elif embedding_type == 'xlnet':
                                x_, y_ = x, np.array(y)
                                x_1 = np.array([xi[0][0] for xi in x_])
                                x_2 = np.array([xi[1][0] for xi in x_])
                                x_3 = np.array([xi[2][0] for xi in x_])
                                x_all = [x_1, x_2, x_3]
                            else:
                                x_all, y_ = np.array(x), np.array(y)

                            cnt = 0
                            yield (x_all, y_)
                            x, y = [], []
        print("preprocess_label_ques_to_idx ok")
    def preprocess_label_ques_to_idx(self,
                                     embedding_type,
                                     batch_size,
                                     path,
                                     embed,
                                     rate=1,
                                     epoch=20):
        label_set, len_all = self.preprocess_get_label_set(path)
        # Build the label<->index mappings; if they already exist on disk
        # (e.g. when processing the dev set), load them instead of rebuilding
        if not os.path.exists(self.path_fast_text_model_l2i_i2l):
            count = 0
            label2index = {}
            index2label = {}
            for label_one in label_set:
                label2index[label_one] = count
                index2label[count] = label_one
                count = count + 1

            l2i_i2l = {}
            l2i_i2l['l2i'] = label2index
            l2i_i2l['i2l'] = index2label
            save_json(l2i_i2l, self.path_fast_text_model_l2i_i2l)
        else:
            l2i_i2l = load_json(self.path_fast_text_model_l2i_i2l)

        # Number of lines to read, scaled by `rate`
        len_ql = int(rate * len_all)
        if len_ql <= 500:  # ignore the sampling rate for small corpora, so there is enough data to train on
            len_ql = len_all

        def process_line(line):
            # Parse one JSON line: extract the label and build the index features for both sentences
            data = json.loads(line)
            label = data['label']
            ques_1 = data['sentence1']
            ques_2 = data['sentence2']
            offset = data['offset']
            mention_1 = data["mention"]
            offset_i = int(offset)
            que_embed_1 = embed.sentence2idx(text=ques_1)
            que_embed_2 = embed.sentence2idx(text=ques_2)
            """ques1"""
            [input_id_1, input_type_id_1, input_mask_1] = que_embed_1
            input_start_mask_1 = [0] * len(input_id_1)
            input_start_mask_1[offset_i] = 1
            input_end_mask_1 = [0] * len(input_id_1)
            input_end_mask_1[offset_i + len(mention_1) - 1] = 1
            input_entity_mask_1 = [0] * len(input_id_1)
            input_entity_mask_1[offset_i:offset_i +
                                len(mention_1)] = [1] * len(mention_1)
            """ques2"""
            [input_id_2, input_type_id_2, input_mask_2] = que_embed_2
            kind_2 = [0] * len(input_type_id_2)
            kind_21 = [0] * len(input_type_id_2)
            que_2_sp = ques_2.split("|")
            if len(que_2_sp) >= 2:
                que_2_sp_sp = que_2_sp[0].split(":")
                if len(que_2_sp_sp) == 2:
                    kind_2_start = len(que_2_sp_sp[0]) - 1
                    kind_2_end = kind_2_start + len(que_2_sp_sp[1]) - 1
                    kind_2[kind_2_start:kind_2_end] = [1] * (kind_2_end -
                                                             kind_2_start)
                if "标签:" in que_2_sp[1]:
                    que_21_sp_sp = que_2_sp[1].split(":")
                    kind_21_start = len(que_2_sp[0]) + len(que_21_sp_sp[0]) - 1
                    kind_21_end = len(que_2_sp[0]) + len(
                        que_21_sp_sp[0]) + len(que_21_sp_sp[1]) - 1
                    kind_21[kind_21_start:kind_21_end] = [1] * (kind_21_end -
                                                                kind_21_start)
            que_embed_x = [
                input_id_1, input_type_id_1, input_mask_1, input_start_mask_1,
                input_end_mask_1, input_entity_mask_1, input_id_2,
                input_type_id_2, input_mask_2, kind_2, kind_21
            ]
            label_zeros = [0] * len(l2i_i2l['l2i'])
            label_zeros[l2i_i2l['l2i'][label]] = 1
            return que_embed_x, label_zeros

        for _ in range(epoch):
            with open(path, "r", encoding="utf-8") as file_csv:
                count_all_line = 0
                cnt = 0
                x, y = [], []
                for line in file_csv:
                    count_all_line += 1
                    # stop once the sampled fraction of the corpus is consumed
                    if count_all_line > len_ql:
                        break
                    x_line, y_line = process_line(line)
                    x.append(x_line)
                    y.append(y_line)
                    cnt += 1
                    if cnt == batch_size:
                        if embedding_type in ['bert', 'albert']:
                            x_, y_ = np.array(x), np.array(y)
                            # split the stacked per-sample features back into
                            # one batch array per model input
                            x_all = []
                            for i in range(len(x_[0])):
                                x_all.append(np.array([xi[i] for xi in x_]))
                        elif embedding_type == 'xlnet':
                            x_, y_ = x, np.array(y)
                            x_1 = np.array([xi[0][0] for xi in x_])
                            x_2 = np.array([xi[1][0] for xi in x_])
                            x_3 = np.array([xi[2][0] for xi in x_])
                            x_all = [x_1, x_2, x_3]
                        else:
                            x_all, y_ = np.array(x), np.array(y)

                        cnt = 0
                        yield (x_all, y_)
                        x, y = [], []
        print("preprocess_label_ques_to_idx ok")
    def preprocess_label_ques_to_idx_old(self,
                                         embedding_type,
                                         batch_size,
                                         path,
                                         embed,
                                         rate=1,
                                         epoch=20):
        label_set, len_all = self.preprocess_get_label_set(path)
        # Build the label<->index mappings; if they already exist on disk
        # (e.g. when processing the dev set), load them instead of rebuilding
        if not os.path.exists(self.path_fast_text_model_l2i_i2l):
            count = 0
            label2index = {}
            index2label = {}
            for label_one in label_set:
                label2index[label_one] = count
                index2label[count] = label_one
                count = count + 1

            l2i_i2l = {}
            l2i_i2l['l2i'] = label2index
            l2i_i2l['i2l'] = index2label
            save_json(l2i_i2l, self.path_fast_text_model_l2i_i2l)
        else:
            l2i_i2l = load_json(self.path_fast_text_model_l2i_i2l)

        # Number of lines to read, scaled by `rate`
        len_ql = int(rate * len_all)
        if len_ql <= 500:  # ignore the sampling rate for small corpora, so there is enough data to train on
            len_ql = len_all

        def process_line(line):
            # Parse one JSON line: extract the label and encode the sentence pair
            data = json.loads(line)
            label = data['label']
            ques_1 = data['sentence1']
            ques_2 = data['sentence2']
            offset = data['offset']
            mention = data["mention"]
            offset_i = int(offset)
            # if data.get("label_l2i"):
            #     ques_entity = data.get("label_l2i") + "#" + ques_1[:offset_i] + "#" + mention + "#" + ques_1[offset_i+len(mention):]
            # else:
            #     ques_entity = ques_1[:offset_i] + "#" + mention + "#" + ques_1[offset_i+len(mention):] + "$$" + ques_2
            # que_embed = embed.sentence2idx(text=ques_entity)
            que_embed = embed.sentence2idx(ques_1, second_text=ques_2)
            label_zeros = [0] * len(l2i_i2l['l2i'])
            label_zeros[l2i_i2l['l2i'][label]] = 1
            return que_embed, label_zeros

        for _ in range(epoch):
            with open(path, "r", encoding="utf-8") as file_csv:
                count_all_line = 0
                cnt = 0
                x, y = [], []
                for line in file_csv:
                    count_all_line += 1
                    # stop once the sampled fraction of the corpus is consumed
                    if count_all_line > len_ql:
                        break
                    x_line, y_line = process_line(line)
                    x.append(x_line)
                    y.append(y_line)
                    cnt += 1
                    if cnt == batch_size:
                        if embedding_type in ['bert', 'albert']:
                            x_, y_ = np.array(x), np.array(y)
                            x_1 = np.array([xi[0] for xi in x_])  # token ids
                            x_2 = np.array([xi[1] for xi in x_])  # segment ids
                            x_all = [x_1, x_2]
                        elif embedding_type == 'xlnet':
                            x_, y_ = x, np.array(y)
                            x_1 = np.array([xi[0][0] for xi in x_])
                            x_2 = np.array([xi[1][0] for xi in x_])
                            x_3 = np.array([xi[2][0] for xi in x_])
                            x_all = [x_1, x_2, x_3]
                        else:
                            x_all, y_ = np.array(x), np.array(y)

                        cnt = 0
                        yield (x_all, y_)
                        x, y = [], []
        print("preprocess_label_ques_to_idx ok")