Example #1
def test():
    print("Loading test data...")
    start_time = time.time()
    x_test, y_test = process_file(test_dir, word_to_id, cat_to_id,
                                  config.seq_length)

    session = tf.Session()
    session.run(tf.global_variables_initializer())
    saver = tf.train.Saver()
    saver.restore(sess=session, save_path=save_path)  # restore the saved model

    print('Testing...')
    loss_test, acc_test = evaluate(session, x_test, y_test)
    msg = 'Test Loss: {0:>6.2}, Test Acc: {1:>7.2%}'
    print(msg.format(loss_test, acc_test))

    batch_size = 128
    data_len = len(x_test)
    num_batch = int((data_len - 1) / batch_size) + 1

    y_test_cls = np.argmax(y_test, 1)
    y_pred_cls = np.zeros(shape=len(x_test), dtype=np.int32)  # buffer for predictions
    for i in range(num_batch):  # process batch by batch
        start_id = i * batch_size
        end_id = min((i + 1) * batch_size, data_len)
        feed_dict = {
            model.input_x: x_test[start_id:end_id],
            model.keep_prob: 1.0
        }
        y_pred_cls[start_id:end_id] = session.run(model.y_pred_cls,
                                                  feed_dict=feed_dict)

    # evaluation metrics
    print("Precision, Recall and F1-Score...")
    print(
        metrics.classification_report(y_test_cls,
                                      y_pred_cls,
                                      target_names=categories))

    # confusion matrix
    print("Confusion Matrix...")
    cm = metrics.confusion_matrix(y_test_cls, y_pred_cls)
    print(cm)

    time_dif = get_time_dif(start_time)
    print("Time usage:", time_dif)
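The `test` functions in these examples call an `evaluate` helper that is not part of the snippets. A minimal sketch of what it plausibly looks like, assuming the same global `model` with `input_x`, `input_y`, `keep_prob`, `loss`, and `acc` tensors used above, and NumPy arrays as inputs (the batch size of 128 matches the prediction loop above):

def evaluate(sess, x_, y_):
    """Average loss and accuracy over a dataset, processed batch by batch."""
    data_len = len(x_)
    batch_size = 128
    num_batch = int((data_len - 1) / batch_size) + 1
    total_loss, total_acc = 0.0, 0.0
    for i in range(num_batch):
        start_id = i * batch_size
        end_id = min((i + 1) * batch_size, data_len)
        batch_len = end_id - start_id
        feed_dict = {
            model.input_x: x_[start_id:end_id],
            model.input_y: y_[start_id:end_id],
            model.keep_prob: 1.0  # disable dropout for evaluation
        }
        loss, acc = sess.run([model.loss, model.acc], feed_dict=feed_dict)
        # weight by batch size so the smaller final batch is averaged correctly
        total_loss += loss * batch_len
        total_acc += acc * batch_len
    return total_loss / data_len, total_acc / data_len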
Example #2
def test(rnn):
    start_time = time.time()
    word2id, label2id, labels = build_vocab_labels(FLAGS.train_data_path)
    x_test, y_test = process_file(FLAGS.test_data_path, word2id, label2id,
                                  rnn.seq_length)

    session = tf.Session()
    session.run(tf.global_variables_initializer())
    saver = tf.train.Saver()
    saver.restore(session, save_path=FLAGS.save_path)

    loss_test, acc_test = evaluate(rnn, session, x_test, y_test)
    msg = 'Test Loss: {0:>6.2}, Test Acc: {1:>7.2%}'
    print(msg.format(loss_test, acc_test))

    batch_size = rnn.batch_size
    data_len = len(x_test)
    num_batch = int((data_len - 1) / batch_size) + 1

    y_test_cls = np.argmax(y_test, 1)
    y_pred_cls = np.zeros(shape=len(x_test), dtype=np.int32)

    for i in range(num_batch):
        start_id = i * batch_size
        end_id = min((i + 1) * batch_size, data_len)
        feed_dict = {
            rnn.input_x: x_test[start_id:end_id],
            rnn.input_y: y_test[start_id:end_id],
            rnn.keep_prob: 1.0
        }
        y_pred_cls[start_id:end_id] = session.run(rnn.y_pred_cls,
                                                  feed_dict=feed_dict)

    # evaluation metrics
    print('Precision, Recall and F1-score...')
    print(
        metrics.classification_report(y_test_cls,
                                      y_pred_cls,
                                      target_names=labels))

    # confusion matrix
    print("Confusion Matrix")
    print(metrics.confusion_matrix(y_test_cls, y_pred_cls))

    print('Time usage: {}'.format(get_time_dif(start_time)))
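Both variants time themselves with `get_time_dif`, which is also not shown. It is typically a two-liner; a sketch (the rounding to whole seconds is an assumption, but matches the `Time usage:` output seen above):

from datetime import timedelta
import time

def get_time_dif(start_time):
    """Elapsed wall-clock time since start_time, rounded to whole seconds."""
    return timedelta(seconds=int(round(time.time() - start_time)))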
Example #3
def train():
    print("Configuring TensorBoard and Saver...")
    # Configure TensorBoard; delete the tensorboard folder before retraining, otherwise the new graph is drawn over the old one
    tensorboard_dir = 'tensorboard/textcnn'
    if not os.path.exists(tensorboard_dir):
        os.makedirs(tensorboard_dir)
    tf.summary.scalar("loss", model.loss)
    tf.summary.scalar("accuracy", model.acc)
    merged_summary = tf.summary.merge_all()
    writer = tf.summary.FileWriter(tensorboard_dir)

    # 配置 Saver
    saver = tf.train.Saver()
    if not os.path.exists(save_dir):
        os.makedirs(save_dir)
    print("Loading training and validation data...")
    start_time = time.time()
    # load training and validation sets
    x_train, y_train = process_file(train_dir, word_to_id, cat_to_id,
                                    config.seq_length)
    x_val, y_val = process_file(val_dir, word_to_id, cat_to_id,
                                config.seq_length)
    time_dif = get_time_dif(start_time)
    print("Time usage:", time_dif)
    # create session
    session = tf.Session()
    session.run(tf.global_variables_initializer())
    writer.add_graph(session.graph)
    print('Training and evaluating...')
    start_time = time.time()
    total_batch = 0  # total batches seen so far
    best_acc_val = 0.0  # best validation accuracy
    last_improved = 0  # batch at which the last improvement happened
    require_improvement = 1000  # stop early if no improvement for 1000 batches
    flag = False
    for epoch in range(config.num_epochs):
        print('Epoch:', epoch + 1)
        batch_train = batch_iter(x_train, y_train, config.batch_size)
        for x_batch, y_batch in batch_train:
            feed_dict = feed_data(x_batch, y_batch, config.dropout_keep_prob)
            #print("x_batch is {}".format(x_batch.shape))
            s = session.run(merged_summary, feed_dict=feed_dict)
            writer.add_summary(s, total_batch)
            if total_batch % config.print_per_batch == 0:
                # every print_per_batch steps, report performance on the train and validation sets
                feed_dict[model.keep_prob] = 1.0
                loss_train, acc_train = session.run([model.loss, model.acc],
                                                    feed_dict=feed_dict)
                loss_val, acc_val = evaluate(session, x_val, y_val)  # todo

                if acc_val > best_acc_val:
                    # save the best result so far
                    best_acc_val = acc_val
                    last_improved = total_batch
                    saver.save(sess=session, save_path=save_path)
                    improved_str = '*'
                else:
                    improved_str = ''
                time_dif = get_time_dif(start_time)
                msg = 'Iter: {0:>6}, Train Loss: {1:>6.2}, Train Acc: {2:>7.2%},' \
                      + ' Val Loss: {3:>6.2}, Val Acc: {4:>7.2%}, Time: {5} {6}'
                print(
                    msg.format(total_batch, loss_train, acc_train, loss_val,
                               acc_val, time_dif, improved_str))

            session.run(model.optim, feed_dict=feed_dict)  # run the optimizer
            total_batch += 1

            if total_batch - last_improved > require_improvement:
                # validation accuracy has not improved for a long time: stop early
                print("No optimization for a long time, auto-stopping...")
                flag = True
                break  # exit the batch loop
        if flag:  # propagate the early stop out of the epoch loop
            break
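The training loop iterates over `batch_iter` and builds its feeds through `feed_data`, neither of which is shown. A sketch under the assumption that `x` and `y` are NumPy arrays and that `model.input_x`, `model.input_y`, and `model.keep_prob` are the placeholders used throughout these examples:

import numpy as np

def batch_iter(x, y, batch_size=64):
    """Shuffle once per call and yield the data in mini-batches."""
    data_len = len(x)
    num_batch = int((data_len - 1) / batch_size) + 1
    indices = np.random.permutation(np.arange(data_len))
    x_shuffled, y_shuffled = x[indices], y[indices]
    for i in range(num_batch):
        start_id = i * batch_size
        end_id = min((i + 1) * batch_size, data_len)
        yield x_shuffled[start_id:end_id], y_shuffled[start_id:end_id]

def feed_data(x_batch, y_batch, keep_prob):
    """Map one mini-batch onto the model's placeholders."""
    return {
        model.input_x: x_batch,
        model.input_y: y_batch,
        model.keep_prob: keep_prob
    }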
Example #4
import math

import tensorflow as tf

import data_helper as DH
from textCNN import TextCNN
import importlib, sys
importlib.reload(sys)
"""
Train the model.
"""

print("Loading train and validation data sets...")
# load the data and compute related parameters
categories, cat_to_id = DH.Get_Categories()
num_classes = len(categories)
words, word_to_id = DH.read_vocab(DH.TextConfig.vocab_filename)
vocab_size = len(words) + 1
x_train, y_train = DH.process_file(DH.TextConfig.trainFileCsv, word_to_id,
                                   cat_to_id, num_classes,
                                   DH.TextConfig.seq_length)
x_val, y_val = DH.process_file(DH.TextConfig.val_filename, word_to_id,
                               cat_to_id, num_classes,
                               DH.TextConfig.seq_length)
pad_seq_len = len(x_train[0])

num_epochs = 40
batch_size = 256
checkpoint_every = int(
    int(math.ceil(len(x_train) / batch_size)) * (num_epochs / 10))

# Data Parameters

# Model Hyperparameters
# the original snippet is truncated mid-call here; a plausible completion:
tf.flags.DEFINE_float("learning_rate", 0.001, "initial learning rate")
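The snippet cuts off in the middle of its flag definitions. For orientation, the TF1 `tf.flags` pattern it uses looks like the following; the flag names and values here are illustrative, not the original's:

tf.flags.DEFINE_integer("batch_size", 256, "mini-batch size")
tf.flags.DEFINE_string("save_path", "checkpoints/textcnn", "checkpoint directory")
FLAGS = tf.flags.FLAGS  # defined flags are later read back as FLAGS.learning_rate etc.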
Example #5
def train(filename=None):
    print("Loading training and validation data...")
    # load the training and validation sets
    start_time = time.time()
    if filename is not None:
        df = process_train_raw_file(filename)
    elif os.path.exists("data/train_test_files/qttnews.train.csv"):
        df = pd.read_csv("data/train_test_files/qttnews.train.csv",
                         header=0,
                         index_col=None)
    elif os.path.exists("data/raw_data/model_data_2019-1-23_data2.txt"):
        df = process_train_raw_file(
            "data/raw_data/model_data_2019-1-23_data2.txt")
    title_pad, content_pad, keyword_pad, auxilary, y_pad = process_file(
        df,
        word_to_id,
        title_max_length=20,
        content_max_length=6000,
        keyword_max_length=8)
    # Randomly shuffle data
    np.random.seed(10)
    shuffle_indices = np.random.permutation(np.arange(len(y_pad)))
    title_pad = title_pad[shuffle_indices]
    content_pad = content_pad[shuffle_indices]
    keyword_pad = keyword_pad[shuffle_indices]
    auxilary = auxilary[shuffle_indices]
    y_pad = y_pad[shuffle_indices]
    # Split train/val set
    # TODO: This is very crude, should use cross-validation
    dev_sample_index = int(dev_sample_percentage * float(len(y_pad)))
    title_train, title_val = title_pad[dev_sample_index:], title_pad[:dev_sample_index]
    content_train, content_val = content_pad[dev_sample_index:], content_pad[:dev_sample_index]
    keyword_train, keyword_val = keyword_pad[dev_sample_index:], keyword_pad[:dev_sample_index]
    auxilary_train, auxilary_val = auxilary[dev_sample_index:], auxilary[:dev_sample_index]
    y_train, y_val = y_pad[dev_sample_index:], y_pad[:dev_sample_index]
    time_dif = get_time_dif(start_time)
    print("Time usage:", time_dif)
    print("begin the training")
    # Training
    # ==================================================

    print("Configuring TensorBoard and Saver...")
    # Configure TensorBoard; delete the tensorboard folder before retraining, otherwise the new graph is drawn over the old one
    tensorboard_dir = 'tensorboard/textcnn'
    if not os.path.exists(tensorboard_dir):
        os.makedirs(tensorboard_dir)
    else:
        for the_file in os.listdir(tensorboard_dir):
            file_path = os.path.join(tensorboard_dir, the_file)
            try:
                if os.path.isfile(file_path):
                    os.unlink(file_path)
                elif os.path.isdir(file_path):
                    shutil.rmtree(file_path)
            except Exception as e:
                print(e)
    tf.summary.scalar("loss", model.loss)
    tf.summary.scalar("accuracy", model.acc)
    # tf.summary.scalar("auc", model.auc)
    merged_summary = tf.summary.merge_all()
    writer = tf.summary.FileWriter(tensorboard_dir)
    # configure Saver
    saver = tf.train.Saver()
    if not os.path.exists(save_dir):
        os.makedirs(save_dir)

    # create session
    sess = tf.Session()
    sess.run(tf.global_variables_initializer())
    # session.run(tf.initialize_local_variables())
    writer.add_graph(sess.graph)
    if tf.train.checkpoint_exists(save_path):
        print("restoring the model")
        saver.restore(sess=sess, save_path=save_path)  # restore the saved model

    # if not config.disable_word_embeddings:
    #     if os.path.exists(word_vector_dir):
    #         initW = pd.read_csv(word_vector_dir,header=None,index_col=None).values
    #         sess.run(model.embedding.assign(initW))

    print('Training and evaluating...')
    start_time = time.time()
    total_batch = 0  # total batches seen so far
    best_acc_val = 0.0  # best validation accuracy
    # best_auc_val = 0.0  # best validation AUC
    last_improved = 0  # batch at which the last improvement happened
    require_improvement = 1000  # stop early if no improvement for 1000 batches

    flag = False
    for epoch in range(config.num_epochs):
        print('Epoch:', epoch + 1)
        batch_train = batch_iter(title_train, content_train, keyword_train,
                                 auxilary_train, y_train, config.batch_size)
        for x_title_batch, x_content_batch, x_keyword_batch, x_auxilary_batch, y_batch in batch_train:
            feed_dict = feed_data(x_title_batch, x_content_batch,
                                  x_keyword_batch, x_auxilary_batch, y_batch,
                                  config.dropout_keep_prob)

            _, global_step, train_summaries, train_loss, train_accuracy = sess.run(
                [
                    model.optim, model.global_step, merged_summary, model.loss,
                    model.acc
                ],
                feed_dict=feed_dict)
            if global_step % config.save_per_batch == 0:
                # every save_per_batch steps, write training summaries to TensorBoard
                s = sess.run(merged_summary, feed_dict=feed_dict)
                writer.add_summary(s, global_step)

            if global_step % config.print_per_batch == 0:
                # every print_per_batch steps, report performance on the train and validation sets
                feed_dict[model.keep_prob] = 1.0
                loss_train, acc_train = sess.run([model.loss, model.acc],
                                                 feed_dict=feed_dict)

                loss_val, acc_val = evaluate(sess, title_val, content_val,
                                             keyword_val, auxilary_val,
                                             y_val)  # todo

                if acc_val > best_acc_val:
                    # save the best result so far
                    best_acc_val = acc_val
                    last_improved = global_step
                    saver.save(sess=sess, save_path=save_path)
                    improved_str = '*'
                else:
                    improved_str = ''

                time_dif = get_time_dif(start_time)
                # print("Time usage:", time_dif)
                msg = 'Iter: {0:>6}, Train Loss: {1:>6.2}, Train Acc: {2:>7.2%},' \
                      + ' Val Loss: {3:>6.2}, Val Acc: {4:>7.2%}, Time: {5} {6}'
                print(
                    msg.format(total_batch, loss_train, acc_train, loss_val,
                               acc_val, time_dif, improved_str))

            # sess.run(model.optim, feed_dict=feed_dict)  # superseded by the combined sess.run above
            total_batch += 1
            # print(total_batch)
            if global_step - last_improved > require_improvement:
                # validation accuracy has not improved for a long time: stop early
                print("No optimization for a long time, auto-stopping...")
                flag = True
                break  # exit the batch loop

        if flag:  # propagate the early stop out of the epoch loop
            break
        config.lr *= config.lr_decay
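One caveat with the final line: `config.lr *= config.lr_decay` only updates a Python-side value. It affects training only if the graph reads the learning rate from something fed each step, such as a placeholder. A sketch of that wiring (the `learning_rate` placeholder name is an assumption; this snippet's model definition is not shown):

# inside the model definition
self.learning_rate = tf.placeholder(tf.float32, name='learning_rate')
self.optim = tf.train.AdamOptimizer(self.learning_rate).minimize(
    self.loss, global_step=self.global_step)

# inside the training loop, next to the other feeds
feed_dict[model.learning_rate] = config.lr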
Example #6
def test():
    print("Loading test data...")
    start_time = time.time()
    df = pd.read_csv(test_dir, header=0, index_col=None)
    x_title_test, x_content_test, x_keyword_test, x_auxilary_test, y_test = process_file(
        df, word_to_id, config.title_seq_length, config.content_seq_length,
        config.keyword_seq_length)

    # x_test, y_test = process_file_v1(test_dir, word_to_id, config.seq_length)
    # Testing
    # ==================================================
    sess = tf.Session()
    sess.run(tf.global_variables_initializer())
    # session.run(tf.initialize_local_variables())
    saver = tf.train.Saver()
    saver.restore(sess=sess, save_path=save_path)  # restore the saved model

    print('Testing...')
    loss_test, acc_test = evaluate(sess, x_title_test, x_content_test,
                                   x_keyword_test, x_auxilary_test, y_test)
    msg = 'Test Loss: {0:>6.2}, Test Acc: {1:>7.2%}'
    print(msg.format(loss_test, acc_test))

    batch_size = config.batch_size
    data_len = len(y_test)
    num_batch = int((data_len - 1) / batch_size) + 1

    y_test_cls = np.argmax(y_test, 1)
    y_pred_cls = np.zeros(shape=len(y_test), dtype=np.int32)  # buffer for predicted classes
    # probabilities must be stored as floats: the original int32 buffer would
    # truncate every probability to 0 and make the AUC below meaningless
    y_pred_prob = np.zeros(shape=len(y_test), dtype=np.float32)  # buffer for predicted probabilities
    for i in range(num_batch):  # process batch by batch
        start_id = i * batch_size
        end_id = min((i + 1) * batch_size, data_len)
        feed_dict = {
            model.input_x_title: x_title_test[start_id:end_id],
            model.input_x_content: x_content_test[start_id:end_id],
            model.input_x_keyword: x_keyword_test[start_id:end_id],
            model.input_x_auxilary: x_auxilary_test[start_id:end_id],
            model.keep_prob: 1.0
        }
        # fetch classes and probabilities in a single forward pass
        y_pred_cls[start_id:end_id], y_pred_prob[start_id:end_id] = sess.run(
            [model.y_pred_cls, model.y_pred_prob], feed_dict=feed_dict)

    # evaluation metrics
    print("Precision, Recall and F1-Score...")
    print(
        metrics.classification_report(
            y_test_cls,
            y_pred_cls,
            target_names=["low_quality", "high_quality"]))

    # confusion matrix
    print("Confusion Matrix...")
    cm = metrics.confusion_matrix(y_test_cls, y_pred_cls)
    print(cm)
    # evaluate AUC
    print("AUC...")
    print(roc_auc_score(y_test_cls, y_pred_prob) if len(y_test_cls) > 0 else 0)

    time_dif = get_time_dif(start_time)
    print("Time usage:", time_dif)
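`roc_auc_score` needs continuous scores, which is why the probability buffer above must be float (see the dtype fix): an int32 buffer silently truncates every probability below 1.0 to 0. A tiny self-contained check:

import numpy as np
from sklearn.metrics import roc_auc_score

y_true = np.array([0, 0, 1, 1])
y_prob = np.array([0.1, 0.4, 0.35, 0.8])
print(roc_auc_score(y_true, y_prob))                   # 0.75
print(roc_auc_score(y_true, y_prob.astype(np.int32)))  # 0.5: all scores truncated to 0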
Example #7
    def predict(self, json_message):
        # models trained under Python 2 or Python 3 can run in either environment
        df = self.parse_json_to_df(json_message)
        type_cat = [16, 1, 4, 42, 9, 7, 8, 6, 17, 19, 25, 5, 12, 13, 23, 10, 27, 15, 14, 18]
        rank_cat = [1, 2, 3, 4, 5]
        df_notype = df[~df['type'].isin(type_cat)][["content_id"]]
        df_norank = df[~df['rank'].isin(rank_cat)][["content_id"]]
        if len(df_notype):
            df_notype["score"] = 0
        if len(df_norank):
            df_norank["score"] = 0
        df = df[df['type'].isin(type_cat)]
        df = df[df['rank'].isin(rank_cat)]


        # rows containing NaN values: set their content_ids aside
        df_null_content_id = df[df.isnull().any(axis=1)][["content_id"]]
        if len(df_null_content_id):
            df_null_content_id["score"] = 0
        df = df[~df.isnull().any(axis=1)]

        # strip HTML tags from detail
        df["detail"] = df['detail'].apply(get_content_from_htmlpage)
        # print(df['detail'])
        df['create_time'] = pd.to_datetime(df['create_time'])
        current_time = datetime.datetime.now()
        df["update_time"] = current_time
        # items of type 6 or 13 (entertainment/sports) get a freshness decay factor
        df_entertain_sports = df[df.type.isin([6, 13])][["content_id", "create_time"]]
        if len(df_entertain_sports):
            # age in hours, then decay: 1.0 up to 4 hours, 1/log_4(age) afterwards
            df_entertain_sports["create_time"] = df_entertain_sports["create_time"].apply(
                lambda x: current_time - x) / np.timedelta64(1, 'h')
            df_entertain_sports["create_time"] = df_entertain_sports["create_time"].apply(
                lambda x: 1.0 if x <= 4 else 1.0 / math.log(x, 4))
            df_entertain_sports_dict = dict(
                zip(df_entertain_sports["content_id"].tolist(),
                    df_entertain_sports["create_time"].tolist()))
        # df['weekday'] = df['create_time'].dt.dayofweek
        # cal = calendar()
        # # print(df['create_time'])
        # holidays = cal.holidays(start=df['create_time'].min(), end=df['create_time'].max())
        # df['holiday'] = df['create_time'].isin(holidays).apply(lambda x: 1 if x else 0)
        df["publish_to_update_hour"] = (df['update_time'] - df['create_time']) / np.timedelta64(1, 'h')
        df_too_old = df[df["publish_to_update_hour"] > 36][["content_id"]]
        if len(df_too_old):
            df_too_old["score"] = 0
        df = df[df["publish_to_update_hour"] <= 36]
        df["publish_to_update_hour"] = df["publish_to_update_hour"].apply(lambda x: 1.0 if x <= 3 else math.log(x, 3))

        df.drop(labels=['create_time', 'update_time'], axis=1, inplace=True)

        wc = word_cutter("data/dict")
        df['title_length'] = df['title'].apply(lambda x: len(wc.clean_html(x)))
        df['title'] = df['title'].apply(lambda x: wc.cut_words(x))
        df["title_token_length"] = df['title'].apply(lambda x: len(x.split(" ")))
        df['detail_length'] = df['detail'].apply(lambda x: len(wc.clean_html(x)))
        df['detail'] = df['detail'].apply(lambda x: wc.cut_words(x))
        df['detail_token_length'] = df['detail'].apply(lambda x: len(x.split(" ")))
        # print("after tokenization\n",df.describe(include="all"))
        # set aside items whose title or body is too short to score reliably
        df_short_length_content_id = df[
            (df['title_length'] <= 1) | (df['title_token_length'] <= 1)
            | (df['detail_length'] <= 10) | (df['detail_token_length'] <= 3)][["content_id"]]

        if len(df_short_length_content_id):
            df_short_length_content_id["score"] = 0
        df = df[(df['title_length'] > 1) & (df['title_token_length'] > 1)
                & (df['detail_length'] > 10) & (df['detail_token_length'] > 3)]
        df['detail_length'] = df['detail_length'].apply(lambda x: math.log2(x))

        print(df.info())
        print("begin to make predictions")
        title_test, content_test, keyword_test, auxilary_test = process_file(
            df, self.word_to_id, self.config.title_seq_length,
            self.config.content_seq_length, self.config.keyword_seq_length,
            test=True)

        feed_dict = {
            self.model.input_x_title: title_test,
            self.model.input_x_content: content_test,
            self.model.input_x_keyword: keyword_test,
            self.model.input_x_auxilary: auxilary_test,
            self.model.keep_prob: 1.0
        }

        y_pred_prob = self.session.run(self.model.y_pred_prob, feed_dict=feed_dict)
        print(y_pred_prob)
        # assign the raw array positionally: df's index is no longer 0..n-1 after
        # the filtering above, so pd.Series(y_pred_prob) would misalign to NaNs
        df["score"] = y_pred_prob

        df_result = df[["content_id", "score"]]
        if len(df_notype):
            df_result = pd.concat([df_result, df_notype], axis=0)
        if len(df_norank):
            df_result = pd.concat([df_result, df_norank], axis=0)
        if len(df_null_content_id):
            df_result = pd.concat([df_result, df_null_content_id], axis=0)
        if len(df_too_old):
            df_result = pd.concat([df_result, df_too_old], axis=0)
        if len(df_short_length_content_id):
            df_result = pd.concat([df_result, df_short_length_content_id], axis=0)
        if len(df_entertain_sports):
            # apply the freshness decay factor computed earlier
            for index, row in df_result.iterrows():
                if row["content_id"] in df_entertain_sports_dict:
                    df_result.loc[index, "score"] = (
                        df_entertain_sports_dict[row['content_id']]
                        * df_result.loc[index, "score"])
        df_result["score"] = df_result["score"] * 1000
        print(df_result.describe())
        tf.reset_default_graph()
        return df_result.to_json(orient='records')
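The score assignment above is a classic pandas trap, which is why the fix assigns the raw NumPy array: `pd.Series(y_pred_prob)` carries a fresh 0..n-1 index, and column assignment from a Series aligns by index, while `df` has been filtered and keeps its original labels. A minimal reproduction:

import numpy as np
import pandas as pd

df = pd.DataFrame({"x": [1, 2, 3, 4]})
df = df[df["x"] > 2]                       # surviving index labels: 2 and 3
df["a"] = pd.Series(np.array([0.9, 0.1]))  # aligns on labels 0 and 1 -> both NaN
df["b"] = np.array([0.9, 0.1])             # positional assignment -> 0.9, 0.1
print(df)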
Example #8
def train(rnn):
    if not os.path.exists(FLAGS.tensorboard_path):
        os.makedirs(FLAGS.tensorboard_path)

    tf.summary.scalar('loss', rnn.loss)
    tf.summary.scalar('accuracy', rnn.acc)
    merged_summary = tf.summary.merge_all()
    writer = tf.summary.FileWriter(FLAGS.tensorboard_path)

    saver = tf.train.Saver()
    if not os.path.exists(FLAGS.save_path):
        os.makedirs(FLAGS.save_path)

    # load the training and test sets
    start_time = time.time()
    word2id, label2id, labels = build_vocab_labels(FLAGS.train_data_path)
    x_train, y_train = process_file(FLAGS.train_data_path, word2id, label2id,
                                    rnn.seq_length)
    x_test, y_test = process_file(FLAGS.test_data_path, word2id, label2id,
                                  rnn.seq_length)

    # create session
    session = tf.Session()
    session.run(tf.global_variables_initializer())
    writer.add_graph(session.graph)

    total_batch = 0  # total batches seen so far
    best_acc_test = 0.0  # best test accuracy

    # Print the number of trainable parameters
    print("the number of parameters is: {}".format(
        np.sum([
            np.prod(v.get_shape().as_list()) for v in tf.trainable_variables()
        ])))

    for epoch in range(rnn.num_epoches):
        print('Epoch: {}'.format(epoch + 1))
        batch_train = batch_iter(x_train, y_train, batch_size=rnn.batch_size)
        for x_batch, y_batch in batch_train:
            feed_dict = {
                rnn.input_x: x_batch,
                rnn.input_y: y_batch,
                rnn.keep_prob: rnn.dropout_keep_prob
            }

            if total_batch % rnn.save_per_batch == 0:
                # every save_per_batch steps, write summaries to TensorBoard
                s = session.run(merged_summary, feed_dict=feed_dict)
                # pass the step so summaries are plotted over time, not stacked at 0
                writer.add_summary(s, total_batch)

            if total_batch % rnn.print_per_batch == 0:
                # every print_per_batch steps, report train and test performance
                feed_dict[rnn.keep_prob] = 1.0
                loss_train, acc_train = session.run([rnn.loss, rnn.acc],
                                                    feed_dict=feed_dict)
                loss_test, acc_test = evaluate(rnn, session, x_test, y_test)

                if acc_test > best_acc_test:
                    # save the best result so far
                    best_acc_test = acc_test
                    saver.save(sess=session, save_path=FLAGS.save_path)

                time_dif = get_time_dif(start_time)
                # in format specs, ^, <, and > mean center-, left-, and right-aligned,
                # followed by the field width; the character right after the colon is
                # the fill character (a single char; defaults to a space if omitted).
                msg = 'Iter: {0:>6}, Train Loss: {1:>6.2}, Train Acc: {2:>7.2%}, Test Loss: ' \
                      '{3:>6.2}, Test Acc: {4:>7.2%}, Time: {5}'
                print(
                    msg.format(total_batch, loss_train, acc_train, loss_test,
                               acc_test, time_dif))

            session.run(rnn.optim, feed_dict=feed_dict)
            total_batch += 1
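Examples #2 and #8 call `evaluate(rnn, session, x, y)` with the model passed in explicitly instead of read from a global (compare the sketch after Example #1). Under the same assumptions about the placeholders, the adaptation is mechanical:

def evaluate(rnn, session, x_, y_):
    """Average loss and accuracy over the whole set for the passed-in model."""
    data_len = len(x_)
    num_batch = int((data_len - 1) / rnn.batch_size) + 1
    total_loss, total_acc = 0.0, 0.0
    for i in range(num_batch):
        start_id = i * rnn.batch_size
        end_id = min((i + 1) * rnn.batch_size, data_len)
        batch_len = end_id - start_id
        feed_dict = {
            rnn.input_x: x_[start_id:end_id],
            rnn.input_y: y_[start_id:end_id],
            rnn.keep_prob: 1.0  # disable dropout for evaluation
        }
        loss, acc = session.run([rnn.loss, rnn.acc], feed_dict=feed_dict)
        total_loss += loss * batch_len
        total_acc += acc * batch_len
    return total_loss / data_len, total_acc / data_len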