Exemple #1
0
def evaluate(sess, x_, y_, wv_model=None):
    """Evaluate the model's average loss and accuracy on a dataset.

    Feeds the data batch by batch (dropout keep probability fed as 1.0) and
    runs ``model.loss`` / ``model.acc`` for each batch, returning averages
    weighted by the number of samples in each batch.

    Args:
        sess: TensorFlow session holding the model variables.
        x_: Input samples.
        y_: Labels aligned with ``x_``.
        wv_model: Optional word-vector model, forwarded to ``batch_iter``.

    Returns:
        Tuple ``(mean_loss, mean_accuracy)``.

    Raises:
        ValueError: If no samples were evaluated (e.g. ``x_`` is empty).
    """
    batch_eval = batch_iter(x_, y_, wv_model, config.batch_size)
    total_loss = 0.0
    total_acc = 0.0
    # Count the samples actually evaluated instead of trusting len(x_), so the
    # averages stay correct even if batch_iter drops or pads samples.
    seen = 0
    for x_batch, y_batch in batch_eval:
        batch_len = len(x_batch)
        feed_dict = feed_data(x_batch, y_batch, 1.0)
        loss, acc = sess.run([model.loss, model.acc], feed_dict=feed_dict)
        # Presumably loss/acc are per-batch means; weighting by batch size
        # yields a per-sample average over the whole dataset.
        total_loss += loss * batch_len
        total_acc += acc * batch_len
        seen += batch_len

    if seen == 0:
        raise ValueError('evaluate() received no samples to evaluate')
    return total_loss / seen, total_acc / seen
Exemple #2
0
def train2(restore=False):
    """Train the CNN text classifier, validating and checkpointing once per epoch.

    Loads the training/validation split via ``process_file``, then runs
    ``config.num_epochs`` epochs. It logs training loss/accuracy every
    ``config.print_per_batch`` batches. After each epoch it evaluates on the
    validation set and saves a checkpoint to ``save_path`` whenever validation
    accuracy improves. Every ``config.save_per_epoch`` epochs it writes a
    TensorBoard summary.

    Args:
        restore: If True and a ``checkpoint`` index file exists in
            ``save_dir``, resume from the latest checkpoint instead of
            initializing variables from scratch.

    Raises:
        ValueError: If ``batch_iter`` yields no training batches for an epoch.
    """
    logging.info('Configuring TensorBoard and Saver...')
    # Configure TensorBoard.
    tensorboard_dir = 'text_cnn/tmp'
    if not os.path.exists(tensorboard_dir):
        os.makedirs(tensorboard_dir)
    tf.summary.scalar('loss', model.loss)
    tf.summary.scalar('accuracy', model.acc)
    merged_summary = tf.summary.merge_all()
    writer = tf.summary.FileWriter(tensorboard_dir)

    # Configure the Saver.
    saver = tf.train.Saver()
    if not os.path.exists(save_dir):
        os.makedirs(save_dir)

    logging.info('Loading training and validation data...')
    # Load the training and validation sets.
    start_time = time.time()
    train_set_data = process_file(train_set_file, config.num_classes,
                                  config.seq_length)
    x_train, y_train = train_set_data['train_set']
    x_val, y_val = train_set_data['validate_set']
    del train_set_data
    logging.info('Time usage: {}'.format(get_time_dif(start_time)))

    # Create the session. The finally clause below releases the session and
    # flushes/closes the summary writer even if training raises.
    conf = tf.ConfigProto(gpu_options=tf.GPUOptions(allow_growth=True))
    session = tf.Session(config=conf)
    try:
        # os.path.join works whether or not save_dir ends with a separator;
        # plain string concatenation misses the file when it does not.
        if restore and os.path.exists(os.path.join(save_dir, "checkpoint")):
            logging.info("Restoring Variables from Checkpoint for cnn model.")
            saver.restore(session, tf.train.latest_checkpoint(save_dir))
        else:
            logging.info('first training cnn model, Initializing Variables')
            session.run(tf.global_variables_initializer())

        writer.add_graph(session.graph)

        logging.info('Training and evaluating...')
        start_time = time.time()

        best_acc_val = 0.0
        # Batches run across all epochs; used as the TensorBoard step so that
        # summaries from different epochs do not collide on the same x value.
        batches_seen = 0

        # NOTE(review): wv_model is not defined in this function; it must be
        # provided as a module-level name.
        for epoch in range(1, config.num_epochs + 1):
            logging.info('Epoch: {}'.format(epoch))
            batch_train = batch_iter(x_train, y_train, wv_model, config.batch_size)
            total_batch = 0  # batches run in the current epoch
            feed_dict = None
            for x_batch, y_batch in batch_train:
                feed_dict = feed_data(x_batch, y_batch, config.dropout_keep_prob)

                if total_batch % config.print_per_batch == 0:
                    loss_train, acc_train = session.run([model.loss, model.acc],
                                                        feed_dict=feed_dict)
                    time_dif = get_time_dif(start_time)
                    logging.info(
                        'epoch:{4: >3}, Iter: {0:>6}, Train Loss: {1:>6.8}, Train Acc: {2:>7.8%}, Time: {3}'
                        .format(total_batch, loss_train, acc_train, time_dif,
                                epoch))

                # Run one optimisation step.
                session.run(model.optim, feed_dict=feed_dict)
                total_batch += 1

            if feed_dict is None:
                raise ValueError(
                    'batch_iter produced no training batches in epoch {}'.format(epoch))
            batches_seen += total_batch

            # End of epoch: score the last training batch with dropout off,
            # then the full validation set.
            feed_dict[model.keep_prob] = 1.0
            loss_train, acc_train = session.run([model.loss, model.acc],
                                                feed_dict=feed_dict)
            loss_val, acc_val = evaluate(session, x_val, y_val, wv_model=wv_model)

            if acc_val > best_acc_val:
                best_acc_val = acc_val
                saver.save(sess=session, save_path=save_path)
                improved_str = '*'
            else:
                improved_str = ''

            time_dif = get_time_dif(start_time)
            msg = 'epoch:{0: >3}, Iter: {1:>6}, Train Loss: {2:>6.8}, Train Acc: {3:>7.8%},' \
                  'Val Loss: {4:>6.8}, Val Acc: {5:>7.8%}, Time: {6} {7}'
            logging.info(
                msg.format(epoch, total_batch, loss_train, acc_train, loss_val,
                           acc_val, time_dif, improved_str))

            if epoch % config.save_per_epoch == 0:
                s = session.run(merged_summary, feed_dict=feed_dict)
                writer.add_summary(s, batches_seen)
    finally:
        session.close()
        writer.close()
Exemple #3
0
def train(restore=False):
    """Train the CNN text classifier with periodic logging and per-epoch validation.

    Loads the training/validation split via ``process_file``, then runs
    ``config.num_epochs`` epochs:

    * It writes TensorBoard summaries every ``config.save_per_batch`` batches.
    * It logs training loss/accuracy, with dropout disabled, every
      ``config.print_per_batch`` batches.
    * After each epoch it evaluates on the validation set and saves a
      checkpoint to ``save_path`` whenever validation accuracy improves.

    Args:
        restore: If True and a ``checkpoint`` index file exists in
            ``save_dir``, resume from the latest checkpoint instead of
            initializing variables from scratch.

    Raises:
        ValueError: If ``batch_iter`` yields no training batches for an epoch.
    """
    logging.info('Configuring TensorBoard and Saver...')
    # Configure TensorBoard.
    tensorboard_dir = 'textcnn/tensorboard'
    if not os.path.exists(tensorboard_dir):
        os.makedirs(tensorboard_dir)

    tf.summary.scalar('loss', model.loss)
    tf.summary.scalar('accuracy', model.acc)
    merged_summary = tf.summary.merge_all()
    writer = tf.summary.FileWriter(tensorboard_dir)

    # Configure the Saver.
    saver = tf.train.Saver()
    if not os.path.exists(save_dir):
        os.makedirs(save_dir)

    logging.info('Loading training and validation data...')
    # Load the training and validation sets.
    start_time = time.time()
    train_set_data = process_file(train_set_file, config.num_classes, config.seq_length)
    x_train, y_train = train_set_data['train_set']
    x_val, y_val = train_set_data['validate_set']
    logging.info('Time usage: {}'.format(get_time_dif(start_time)))
    random_vector_generate(train_set_file)

    # Create the session. The finally clause below releases the session and
    # flushes/closes the summary writer even if training raises.
    conf = tf.ConfigProto(gpu_options=tf.GPUOptions(allow_growth=True))
    session = tf.Session(config=conf)
    try:
        # os.path.join works whether or not save_dir ends with a separator;
        # plain string concatenation misses the file when it does not.
        if restore and os.path.exists(os.path.join(save_dir, "checkpoint")):
            logging.info("Restoring Variables from Checkpoint for cnn model.")
            saver.restore(session, tf.train.latest_checkpoint(save_dir))
        else:
            logging.info('first training cnn model, Initializing Variables')
            session.run(tf.global_variables_initializer())

        writer.add_graph(session.graph)

        logging.info('Training and evaluating...')
        start_time = time.time()
        total_batch = 0  # batches run across all epochs
        best_acc_val = 0.0

        for epoch in range(config.num_epochs):
            logging.info('Epoch: {}'.format(epoch + 1))
            batch_train = batch_iter(x_train, y_train, config.batch_size)
            feed_dict = None
            for x_batch, y_batch in batch_train:
                feed_dict = feed_data(x_batch, y_batch, config.dropout_keep_prob)

                if total_batch % config.save_per_batch == 0:
                    # Write training summaries to TensorBoard.
                    s = session.run(merged_summary, feed_dict=feed_dict)
                    writer.add_summary(s, total_batch)

                if total_batch % config.print_per_batch == 0:
                    # Measure training performance with dropout disabled. Work on
                    # a copy: mutating feed_dict would make the optimisation step
                    # below train without dropout.
                    eval_feed = dict(feed_dict)
                    eval_feed[model.keep_prob] = 1.0
                    loss_train, acc_train = session.run([model.loss, model.acc], feed_dict=eval_feed)
                    time_dif = get_time_dif(start_time)
                    msg = 'Iter: {0:>6},  Train Loss: {1:>6.4},  Train Acc: {2:>8.4%},  ' \
                          'Time: {3}'
                    logging.info(msg.format(total_batch, loss_train, acc_train, time_dif))

                # Run one optimisation step (with the configured dropout).
                session.run(model.optim, feed_dict=feed_dict)
                total_batch += 1

            if feed_dict is None:
                raise ValueError(
                    'batch_iter produced no training batches in epoch {}'.format(epoch + 1))

            # End of epoch: score the last training batch with dropout off,
            # then the full validation set.
            eval_feed = dict(feed_dict)
            eval_feed[model.keep_prob] = 1.0
            loss_train, acc_train = session.run([model.loss, model.acc], feed_dict=eval_feed)
            loss_val, acc_val = evaluate(session, x_val, y_val)

            if acc_val > best_acc_val:
                best_acc_val = acc_val
                saver.save(sess=session, save_path=save_path)
                improved_str = '*'
            else:
                improved_str = ''

            time_dif = get_time_dif(start_time)
            msg = 'Iter: {0:>6}, Train Loss: {1:>6.4}, Train Acc: {2:>7.4%},' \
                  'Val Loss: {3:>6.4}, Val Acc: {4:>7.4%}, Time: {5} {6}'
            logging.info(msg.format(total_batch, loss_train, acc_train, loss_val, acc_val,
                                    time_dif, improved_str))
    finally:
        session.close()
        writer.close()