Example #1
    def _prepare_test_data(self, file_corpus=conf.TEST_FILE):
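        # Build the generator from the training corpus (for vocab/label
        # metadata), read the test corpus, and convert one-hot labels to
        # class indices.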
        print(
            '***************************prepare corpus***************************'
        )
        self.fit_gen = FitGeneratorWrapper(
            type=self.conf.type,
            file_vocab=self.conf.VOCAB_FILE,
            file_corpus=self.conf.TRAIN_FILE,
            batch_size=self.conf.batch_size,
            max_len=self.conf.maxlen,
            vector_dim=self.conf.vector_dim,
            ner_vocab=self.conf.pretrain_vocab,
            label_dict_file=self.conf.LABEL_DICT_FILE)
        self.vocab_size = self.fit_gen.get_vocab_size()
        self.labels_num = self.fit_gen.get_label_count()

        self.x_test, self.y_test = self.fit_gen.read_corpus(
            file_corpus=file_corpus)

        labels = []
        for label in self.y_test:
            index = np.argmax(label)
            labels.append(index)

        return self.x_test, np.array(labels)
Example #2
def file_infer_test(model_file):
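    # Encode train/test sentences with the loaded model, fit one
    # GaussianMixture per training label, then assign each test vector the
    # label of the highest-scoring mixture and report accuracy and errors.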
    # fit_gen_test = FitGeneratorWrapper(type=conf.type, file_vocab=conf.VOCAB_FILE, file_corpus=conf.TEST_FILE,
    #                                    batch_size=conf.batch_size, max_len=conf.maxlen, vector_dim=conf.vector_dim,
    #                                    pretrain_vocab=conf.pretrain_vocab, label_dict_file=conf.LABEL_DICT_FILE)
    path_data = '../data/new_all.csv.test'

    list_data = []
    with open(path_data, 'r') as f:
        for i in f.read().splitlines():
            list_data.append(i.split('\t'))
    fit_gen_train = FitGeneratorWrapper(type=conf.type, file_vocab=conf.VOCAB_FILE, file_corpus=conf.TRAIN_FILE,
                                        batch_size=conf.batch_size, max_len=conf.maxlen, vector_dim=conf.vector_dim,
                                        ner_vocab=conf.pretrain_vocab, label_dict_file=conf.LABEL_DICT_FILE)
    vocab_size_train = fit_gen_train.get_vocab_size()
    sentences_train, labels_train = fit_gen_train.read_raw_corpus(file_corpus=conf.TRAIN_FILE)
    sentences_test, labels_test = fit_gen_train.read_raw_corpus(file_corpus=conf.TEST_FILE)

    model = ModelWrapper.model(conf, train=False, vocab_size=vocab_size_train, labels_num=0)
    model.load_weights(model_file, by_name=True)
    model.summary()
    vectors_train = __predict_vectors(model, sentences_train, conf.vector_dim)
    vectors_test = __predict_vectors(model, sentences_test, conf.vector_dim)

    dic_all, labels_list_after_set = gauss_change_data(vectors_train, labels_train)
    models = {}
    n_components = 3
    model_dir = "/Users/chenhengxi/PycharmProjects/work2/sentence-encoding-qa/data/model"
    for domain in labels_list_after_set:
        modelx = GaussianMixture(n_components=n_components, covariance_type='diag', reg_covar=0.0001, max_iter=200,
                                 verbose=0, verbose_interval=1)
        data = np.array(dic_all[domain])
        modelx.fit(data)
        models[domain] = modelx
        joblib.dump(modelx, "{0}/{1}.joblib".format(model_dir, domain))
    final_dic = {}
    final_num = 0
    error = []
    for i in range(len(vectors_test)):
        print(i)
        # Score the test vector under each per-domain GMM and keep the domain
        # with the highest log-likelihood. The fitted models are still in
        # memory, so there is no need to reload them from disk per sample.
        a = np.squeeze(vectors_test[i])
        best_score, label_final = None, None
        for domain in labels_list_after_set:
            point = models[domain].score_samples(a.reshape(1, conf.vector_dim))[0]
            if best_score is None or point > best_score:
                best_score, label_final = point, domain
        final_dic[str(vectors_test[i])] = label_final
        if list_data[i][1] != label_final:
            final_num += 1
            error.append([list_data[i][0], list_data[i][1], label_final])
    print(1 - final_num / len(vectors_test))
    print(error)
Example #3
	def prepare_data(self):
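		# Build the generator on the training corpus and preload the test
		# split; the train split is only loaded into memory when not using
		# fit-generator streaming.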
		print('***************************prepare corpus***************************')
		self.fit_gen = FitGeneratorWrapper(type=self.type, file_vocab=conf.VOCAB_FILE, file_corpus=conf.TRAIN_FILE,
                                           batch_size=conf.batch_size, max_len=conf.maxlen, vector_dim=conf.vector_dim,
                                           ner_vocab=conf.pretrain_vocab, label_dict_file=conf.LABEL_DICT_FILE)
		self.vocab_size = self.fit_gen.get_vocab_size()
		self.corpus_size = self.fit_gen.get_line_count()
		self.labels_num = self.fit_gen.get_label_count()
		self.x_test, self.y_test = self.fit_gen.read_corpus(file_corpus=conf.TEST_FILE)
		if not conf.FIT_GENERATE:
			self.x_train, self.y_train = self.fit_gen.read_corpus(file_corpus=conf.TRAIN_FILE)
Example #4
    def _prepare_test_data(self, file_corpus=conf.PAIR_TEST_FILE):
        print(
            '***************************prepare corpus***************************'
        )
        self.fit_gen = FitGeneratorWrapper(
            type=self.conf.type,
            file_vocab=self.conf.VOCAB_FILE,
            file_corpus=self.conf.PAIR_TEST_FILE,
            batch_size=self.conf.batch_size,
            max_len=self.conf.maxlen,
            vector_dim=self.conf.vector_dim,
            ner_vocab=self.conf.pretrain_vocab,
            label_dict_file=self.conf.LABEL_DICT_FILE)
        self.vocab_size = self.fit_gen.get_vocab_size()
        self.labels_num = self.fit_gen.get_label_count()
        sentences1, sentences2, labels = self.fit_gen.read_raw_corpus(
            file_corpus=file_corpus)
        return [sentences1, sentences2], labels
Example #5
class InferTextMatchTest(InferTest):
    def __init__(self):
        super(InferTextMatchTest, self).__init__(backbone="TEXT_MATCH",
                                                 type="pair",
                                                 header="value")

    def _prepare_test_data(self, file_corpus=conf.PAIR_TEST_FILE):
        print(
            '***************************prepare corpus***************************'
        )
        self.fit_gen = FitGeneratorWrapper(
            type=self.conf.type,
            file_vocab=self.conf.VOCAB_FILE,
            file_corpus=self.conf.PAIR_TEST_FILE,
            batch_size=self.conf.batch_size,
            max_len=self.conf.maxlen,
            vector_dim=self.conf.vector_dim,
            ner_vocab=self.conf.pretrain_vocab,
            label_dict_file=self.conf.LABEL_DICT_FILE)
        self.vocab_size = self.fit_gen.get_vocab_size()
        self.labels_num = self.fit_gen.get_label_count()
        sentences1, sentences2, labels = self.fit_gen.read_raw_corpus(
            file_corpus=file_corpus)
        return [sentences1, sentences2], labels

    def file_infer_test(self, model_file, values_file=None):
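        # Predict similarity values for the test pairs, cache them to
        # values_file, then sweep thresholds for the best cosine cut-off.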
        sentences, labels = self._prepare_test_data()

        print(
            '***************************build model***************************'
        )
        model = ModelWrapper.model(self.conf,
                                   train=False,
                                   vocab_size=self.vocab_size,
                                   labels_num=1)
        model.summary()
        self._load_model(model_file, model)

        print(
            '***************************infer test***************************')
        # if os.path.exists(values_file):
        #     print("load cache file " + values_file)
        #     values = np.load(values_file)
        # else:
        values = self.do_predict(model,
                                 sentences,
                                 vector_dim=self.conf.vector_dim,
                                 header=self.conf.predict_header,
                                 type=self.conf.type)
        print("save cache file " + values_file)
        np.save(values_file, values)

        tpr, fpr, accuracy, best_thresholds = evaluate_best_threshold_value(
            values, labels, nrof_folds=10)
        tpr = np.mean(tpr)
        fpr = np.mean(fpr)
        accuracy = np.mean(accuracy)
        best_thresholds = np.mean(best_thresholds)
        print(
            "cosine: tpr (positive recall, tp/(tp+fn))={} fpr (negative error rate, fp/(fp+tn))={} acc={} threshold={}"
            .format(tpr, fpr, accuracy, best_thresholds))
        return best_thresholds
Example #6
class InferClassTest(InferTest):
    def __init__(self):
        super(InferClassTest, self).__init__(backbone="CNN_CLASS",
                                             type="class",
                                             header="index")

    def _prepare_test_data(self, file_corpus=conf.TEST_FILE):
        print(
            '***************************prepare corpus***************************'
        )
        self.fit_gen = FitGeneratorWrapper(
            type=self.conf.type,
            file_vocab=self.conf.VOCAB_FILE,
            file_corpus=self.conf.TRAIN_FILE,
            batch_size=self.conf.batch_size,
            max_len=self.conf.maxlen,
            vector_dim=self.conf.vector_dim,
            ner_vocab=self.conf.pretrain_vocab,
            label_dict_file=self.conf.LABEL_DICT_FILE)
        self.vocab_size = self.fit_gen.get_vocab_size()
        self.labels_num = self.fit_gen.get_label_count()

        self.x_test, self.y_test = self.fit_gen.read_corpus(
            file_corpus=file_corpus)

        labels = []
        for label in self.y_test:
            index = np.argmax(label)
            labels.append(index)

        return self.x_test, np.array(labels)

    def _get_num_2(self, values, first):
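        # Zero out the top-1 position, then take max/argmax again to get the
        # second-highest score and its index.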
        reset_values = []
        for i in range(len(values)):
            if i != first:
                reset_values.append(values[i])
            else:
                reset_values.append(0)
        reset_values = np.array(reset_values)
        value2 = np.max(reset_values)
        index2 = np.argmax(reset_values)
        return value2, index2

    def _class_predict_index(self, model, sentences):
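        # Classify sentence-by-sentence, recording top-1 and top-2 scores and
        # indices for later gap/threshold analysis.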
        indexs = np.zeros((sentences.shape[0]), dtype=int)
        values = np.zeros((sentences.shape[0]), dtype=float)
        indexs2 = np.zeros((sentences.shape[0]), dtype=int)
        values2 = np.zeros((sentences.shape[0]), dtype=float)
        all_values = []
        max_len = sentences.shape[1]
        i = 0
        start_t = datetime.datetime.now()
        for sentence in sentences:
            sentence = sentence.reshape(1, max_len)
            logits_output = model.predict(sentence)
            squeeze_array = np.squeeze(logits_output)
            all_values.append(squeeze_array)
            values[i] = np.max(squeeze_array)
            indexs[i] = np.argmax(squeeze_array)
            value2, index2 = self._get_num_2(squeeze_array, indexs[i])
            values2[i] = value2
            indexs2[i] = index2
            i += 1
        end_t = datetime.datetime.now()
        print("{} sentences infer time {} seconds".format(
            sentences.shape[0], (end_t - start_t).seconds))
        return indexs, values, all_values, indexs2, values2

    def file_infer_test(self, model_file, values_file=None):
        sentences, labels = self._prepare_test_data()

        print(
            '***************************build model***************************'
        )
        model = ModelWrapper.model(self.conf,
                                   train=False,
                                   vocab_size=self.vocab_size,
                                   labels_num=self.labels_num)
        model.summary()
        self._load_model(model_file, model)

        print(
            '***************************infer test***************************')
        indexs, values, all_values, indexs2, values2 = self.do_predict(
            model,
            sentences,
            vector_dim=self.conf.vector_dim,
            header=self.conf.predict_header,
            type=self.conf.type)

        correct_num = 0
        for i in range(len(indexs)):
            if indexs[i] != labels[i]:
                labels[i] = 0
            else:
                labels[i] = 1
                correct_num += 1

        print("validat set precise {}, error number {}".format(
            correct_num / len(labels),
            len(labels) - correct_num))

        tpr, fpr, accuracy, best_thresholds = evaluate_best_threshold_value(
            values, labels, nrof_folds=10)
        tpr = np.mean(tpr)
        fpr = np.mean(fpr)
        accuracy = np.mean(accuracy)
        best_thresholds = np.mean(best_thresholds)
        print(
            "cosine: tpr (positive recall, tp/(tp+fn))={} fpr (negative error rate, fp/(fp+tn))={} acc={} threshold={}"
            .format(tpr, fpr, accuracy, best_thresholds))

    def find_best_threshold_for_second(self, model_file):
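        # Collect top-1/top-2 statistics separately for correct and wrong
        # predictions, then search for a score gap and a second-best-score
        # threshold that best isolate the errors.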
        sentences, labels = self._prepare_test_data()

        print(
            '***************************build model***************************'
        )
        model = ModelWrapper.model(self.conf,
                                   train=False,
                                   vocab_size=self.vocab_size,
                                   labels_num=self.labels_num)
        model.summary()
        self._load_model(model_file, model)

        print(
            '***************************infer test***************************')
        indexs, values, all_values, indexs2, values2 = self.do_predict(
            model,
            sentences,
            vector_dim=self.conf.vector_dim,
            header=self.conf.predict_header,
            type=self.conf.type)

        ground_truth = labels.copy()
        correct_num = 0
        err_min_second = 1.0  # minimum second-best score among wrong cases
        err_max_gap = 0.0  # maximum top1-top2 gap among wrong cases
        correct_max_second = 0.0  # maximum second-best score among correct cases
        correct_min_gap = 1.0  # minimum top1-top2 gap among correct cases
        for i in range(len(indexs)):
            if indexs[i] != labels[i]:
                labels[i] = 0
                if err_min_second > values2[i]:
                    err_min_second = values2[i]
                if err_max_gap < values[i] - values2[i]:
                    err_max_gap = values[i] - values2[i]
            else:
                labels[i] = 1
                correct_num += 1
                if correct_max_second < values2[i]:
                    correct_max_second = values2[i]
                if correct_min_gap > values[i] - values2[i]:
                    correct_min_gap = values[i] - values2[i]

        print("validat set precise {}, error number {}".format(
            correct_num / len(labels),
            len(labels) - correct_num))
        print("err_min_second {}, err_max_gap {}".format(
            err_min_second, err_max_gap))
        print("correct_max_second {}, correct_min_gap {}".format(
            correct_max_second, correct_min_gap))

        best_gap, best_gap_rate = self.evaluate_gaps(ground_truth, indexs,
                                                     values, indexs2, values2,
                                                     correct_min_gap,
                                                     err_max_gap)
        print("best_gap {}, best_gap_rate {}".format(best_gap, best_gap_rate))

        best_threshold, best_threshold_rate = self.evaluate_second_thresholds(
            ground_truth, indexs, values2, err_min_second, correct_max_second)
        print("best_threshold {}, best_threshold_rate {}".format(
            best_threshold, best_threshold_rate))

    def evaluate_gaps(self, ground_truth, indexs, values, indexs2, values2,
                      gap1, gap2):
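        # Sweep candidate gaps in 0.005 steps between the observed bounds and
        # return the gap with the best error-recall rate.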
        num = max(int((gap2 - gap1) / 0.005), 1)
        gaps = np.linspace(gap1, gap2 + 0.005, num=num)
        recall_rate = []
        for gap_idx, gap in enumerate(gaps):
            rate = self.evaluate_gap(ground_truth, indexs, values, values2,
                                     gap)
            recall_rate.append(rate)
        best_index = recall_rate.index(max(recall_rate))
        return gaps[best_index], recall_rate[best_index]

    def evaluate_gap(self, ground_truth, indexs, values, values2, gap):
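        # A gap "captures" an error when the example's top1-top2 margin is
        # below it; the rate is errors / all picked-up examples.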
        err_indexs = []
        for i in range(len(ground_truth)):
            if indexs[i] != ground_truth[i]:
                err_indexs.append(i)

        pickup_list = []
        for i in range(len(ground_truth)):
            if values[i] - values2[i] < gap:
                pickup_list.append(i)

        print("gap={} pickup={} err={}".format(gap, pickup_list, err_indexs))
        if pickup_list and set(err_indexs) <= set(pickup_list):
            return len(err_indexs) / len(pickup_list)
        else:
            return 0.0

    def evaluate_second_thresholds(self, ground_truth, indexs, values2,
                                   threshold1, threshold2):
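        # Same sweep as evaluate_gaps, but over the second-best score itself.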
        num = max(int((threshold2 - threshold1) / 0.005), 1)
        thresholds = np.linspace(threshold1, threshold2 + 0.005, num=num)
        recall_rate = []
        for threshold_idx, threshold in enumerate(thresholds):
            rate = self.evaluate_second_threshold(ground_truth, indexs,
                                                  values2, threshold)
            recall_rate.append(rate)
        best_index = recall_rate.index(max(recall_rate))
        return thresholds[best_index], recall_rate[best_index]

    def evaluate_second_threshold(self, ground_truth, indexs, values2,
                                  threshold):
        err_indexs = []
        for i in range(len(ground_truth)):
            if indexs[i] != ground_truth[i]:
                err_indexs.append(i)

        pickup_list = []
        for i in range(len(ground_truth)):
            if values2[i] > threshold:
                pickup_list.append(i)

        print("gap={} pickup={} err={}".format(threshold, pickup_list,
                                               err_indexs))
        if pickup_list and set(err_indexs) <= set(pickup_list):
            return len(err_indexs) / len(pickup_list)
        else:
            return 0.0
Example #7
class TextMatchTrain(TrainBase):
    def __init__(self, backbone="TEXT_MATCH", type="pair"):
        super(TextMatchTrain, self).__init__(backbone=backbone, type=type)

    def prepare_data(self):
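        # Build the pair generator on the training pairs and preload the test
        # pairs for validation.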
        print(
            '***************************prepare corpus***************************'
        )
        self.fit_gen = FitGeneratorWrapper(
            type=self.type,
            file_vocab=conf.VOCAB_FILE,
            file_corpus=conf.PAIR_TRAIN_FILE,
            batch_size=conf.batch_size,
            max_len=conf.maxlen,
            vector_dim=conf.vector_dim,
            ner_vocab=conf.pretrain_vocab,
            label_dict_file=conf.LABEL_DICT_FILE)
        self.vocab_size = self.fit_gen.get_vocab_size()
        self.corpus_size = self.fit_gen.get_line_count()

        self.x_test1, self.x_test2, self.y_test = self.fit_gen.read_corpus(
            file_corpus=conf.PAIR_TEST_FILE)
        if not conf.FIT_GENERATE:
            self.x_train1, self.x_train2, self.y_train = self.fit_gen.read_corpus(
                file_corpus=conf.PAIR_TRAIN_FILE)

    def build(self):
        print(
            '***************************build model***************************'
        )
        self.model = ModelWrapper.model(conf,
                                        train=True,
                                        vocab_size=self.vocab_size,
                                        labels_num=1)

        self.model.compile('adam', 'mse', metrics=['accuracy'])

        self.model.summary()

    def do_train(self):
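        # Train with checkpoint saving, early stopping, LR reduction on
        # plateau, and TensorBoard logging.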
        print(
            "***************************start training***************************"
        )
        save_callback = SaveCallback(
            save_path=conf.SAVE_DIR,
            backbone=conf.backbone,
            model=self.model,
            timestamp=self.timestamp,
            save_name=self.save_name)  # , validation_data=[x_test, y_test])
        early_stop_callback = callbacks.EarlyStopping(
            monitor='val_loss',
            restore_best_weights=True,
            patience=conf.early_stop_patience,
            verbose=1,
            mode='auto')
        reduce_lr_callback = callbacks.ReduceLROnPlateau(
            monitor='val_loss',
            factor=conf.reduce_lr_factor,
            patience=conf.reduce_lr_patience,
            verbose=1,
            mode='auto',
            min_delta=0.0001,
            cooldown=0,
            min_lr=0.00001)
        tensorboard_callback = TensorBoard(log_dir=conf.OUT_DIR)

        callbacks_list = []
        callbacks_list.append(save_callback)
        callbacks_list.append(early_stop_callback)
        callbacks_list.append(reduce_lr_callback)
        callbacks_list.append(tensorboard_callback)

        if conf.FIT_GENERATE:
            self.model.fit(self.fit_gen.generate(),
                           epochs=conf.epochs,
                           steps_per_epoch=self.corpus_size // conf.batch_size,
                           callbacks=callbacks_list,
                           validation_data=([self.x_test1,
                                             self.x_test2], self.y_test),
                           verbose=1)
        else:
            self.model.fit(x=[self.x_train1, self.x_train2],
                           y=self.y_train,
                           batch_size=conf.batch_size,
                           epochs=conf.epochs,
                           callbacks=callbacks_list,
                           validation_data=([self.x_test1,
                                             self.x_test2], self.y_test),
                           verbose=1)

    def post_test(self):
        print("no test")
Example #8
class TrainBase(object):
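	# Base trainer: wires up config, GPU selection, loss choice (including an
	# optional gradient-penalty GAN loss), data preparation, callbacks,
	# training, saving, and a post-training inference test.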
	def __init__(self, backbone="CNN_ARCFACE", type="class", loss='categorical_crossentropy', optimizer='adam', metrics=["accuracy"], gan_enable=False):
		conf.backbone = backbone
		self.type = type
		if gan_enable:
			self.loss = self.gan_loss_with_gradient_penalty
		else:
			self.loss = loss
		self.optimizer = optimizer
		self.metrics = metrics

		print(os.getcwd() + "/../")
		sys.path.append(os.getcwd() + "/../")

		if len(sys.argv) == 2 and sys.argv[1] == "gpu":
			if tf.test.is_gpu_available():
				os.environ["CUDA_VISIBLE_DEVICES"] = "1"  # 0 for V100, 1 for P100
		else:
			os.environ["CUDA_VISIBLE_DEVICES"] = "-1"

		self.timestamp = time.strftime("%Y-%m-%d-%H-%M-%S", time.localtime())
		if not conf.attention_enable:
			self.save_name = "{}{}_{}.h5".format(conf.SAVE_DIR, conf.backbone, self.timestamp)
		else:
			self.save_name = "{}{}_{}_{}.h5".format(conf.SAVE_DIR, conf.backbone, "ATTENTION", self.timestamp)

	def focal_loss(self, gamma=2., alpha=1):
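		# Binary focal loss: down-weights easy examples via the
		# (1 - p_t)**gamma modulating factor.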
		def focal_loss_fixed(y_true, y_pred):
			pt_1 = tf.where(tf.equal(y_true, 1), y_pred, tf.ones_like(y_pred))
			pt_0 = tf.where(tf.equal(y_true, 0), y_pred, tf.zeros_like(y_pred))
			return -K.mean(alpha * K.pow(1. - pt_1, gamma) * K.log(pt_1)) - K.mean((1 - alpha) * K.pow(pt_0, gamma) * K.log(1. - pt_0))
		return focal_loss_fixed

	def gan_loss_with_gradient_penalty(self, y_true, y_pred, epsilon=2):
		"""带梯度惩罚的loss
        """
		loss = K.mean(K.categorical_crossentropy(y_true, y_pred))
		# Find the Embedding layer
		for output in self.model.outputs:
			embedding_layer = search_layer(output, "embedding")
			if embedding_layer is not None:
				break
		if embedding_layer is None:
			raise Exception('Embedding layer not found')

		embeddings = embedding_layer.embeddings
		gp = K.sum(K.gradients(loss, [embeddings])[0].values ** 2)
		return loss + 0.5 * epsilon * gp


	def prepare_data(self):
		print('***************************prepare corpus***************************')
		self.fit_gen = FitGeneratorWrapper(type=self.type, file_vocab=conf.VOCAB_FILE, file_corpus=conf.TRAIN_FILE,
                                           batch_size=conf.batch_size, max_len=conf.maxlen, vector_dim=conf.vector_dim,
                                           ner_vocab=conf.pretrain_vocab, label_dict_file=conf.LABEL_DICT_FILE)
		self.vocab_size = self.fit_gen.get_vocab_size()
		self.corpus_size = self.fit_gen.get_line_count()
		self.labels_num = self.fit_gen.get_label_count()
		self.x_test, self.y_test = self.fit_gen.read_corpus(file_corpus=conf.TEST_FILE)
		if not conf.FIT_GENERATE:
			self.x_train, self.y_train = self.fit_gen.read_corpus(file_corpus=conf.TRAIN_FILE)


	def build(self):
		print('***************************build model***************************')
		self.model = ModelWrapper.model(conf, train=True, vocab_size=self.vocab_size, labels_num=self.labels_num)
		self.model.compile(loss=self.loss, optimizer=self.optimizer, metrics=self.metrics)  # optimizer=Adam()  keras.optimizers.Adam(lr=args.learning_rate, beta_1=0.9, beta_2=0.999, epsilon=1e-8)
		# model.compile(loss=focal_loss(), optimizer='adam', metrics=["accuracy"])   #optimizer=Adam()  keras.optimizers.Adam(lr=args.learning_rate, beta_1=0.9, beta_2=0.999, epsilon=1e-8)

		self.model.summary()

	def do_train(self):
		print("***************************start training***************************")
		save_callback = SaveCallback(save_path=conf.SAVE_DIR, backbone=conf.backbone, model=self.model,
									 timestamp=self.timestamp, save_name=self.save_name)  # , validation_data=[x_test, y_test])
		early_stop_callback = callbacks.EarlyStopping(monitor='val_loss', patience=conf.early_stop_patience, verbose=1, mode='auto', restore_best_weights=True)
		reduce_lr_callback = callbacks.ReduceLROnPlateau(monitor='val_acc', factor=conf.reduce_lr_factor, patience=conf.reduce_lr_patience, verbose=1,
														 mode='auto', min_delta=0.0001, cooldown=0, min_lr=0.00001)
		tensorboard_callback = TensorBoard(log_dir=conf.OUT_DIR)

		callbacks_list = []
		callbacks_list.append(save_callback)
		callbacks_list.append(early_stop_callback)
		#callbacks_list.append(reduce_lr_callback)
		callbacks_list.append(tensorboard_callback)

		if conf.FIT_GENERATE:
			self.model.fit(self.fit_gen.generate(),
						   epochs=conf.epochs,
						   steps_per_epoch=self.corpus_size // conf.batch_size,
						   callbacks=callbacks_list,
						   validation_data=([self.x_test, self.y_test], self.y_test), verbose=1)
		else:
			self.model.fit(x=[self.x_train, self.y_train],
						   y=self.y_train,
						   batch_size=conf.batch_size,
						   epochs=conf.epochs,
						   callbacks=callbacks_list,
						   validation_data=([self.x_test, self.y_test], self.y_test),  # validation_split=0.02,
						   verbose=1)
		print("***************************train done***************************")

	def post_test(self):
		print("***************************start infer test***************************")
		infer = "cd " + os.getcwd() + "/../infer;python " + os.getcwd() + "/../infer/infer_test.py " + self.save_name
		os.system(infer)


	def save(self):
		print("save to {}".format(self.save_name))
		if not os.path.exists(conf.SAVE_DIR):
			os.mkdir(conf.SAVE_DIR)
		self.model.save(self.save_name)

		print("***************************save done***************************")