def test(time_str):
    """Evaluate the latest CNN checkpoint on the test set.

    Restores the most recent checkpoint under config["model_path"], runs the
    model over the test data batch by batch, then prints and saves a
    classification report and a confusion matrix via output_utils.

    Args:
        time_str: Timestamp string identifying the training run whose config
            and checkpoints should be loaded.
    """
    config = conf_utils.init_test_config(time_str)
    batch_size = config["batch_size"]

    test_input, test_target = data_utils.gen_test_data(config)
    target_vocab = data_utils.get_vocab(config["target_vocab_path"])

    print(">> build model...")
    model = cnn.Model(config)
    _, pred, _ = model.cnn()

    with tf.Session() as sess:
        saver = tf.train.Saver()
        latest_checkpoint_name = tf.train.latest_checkpoint(
            config["model_path"])
        print(f">> last checkpoint: {latest_checkpoint_name}")
        saver.restore(sess, latest_checkpoint_name)

        batch_gen = batch_utils.make_batch(test_input, test_target, batch_size,
                                           False)
        input_target_list = []
        pred_target_list = []
        # Trailing samples that do not fill a whole batch are dropped.
        for _ in range(len(test_input) // batch_size):
            test_input_batch, test_target_batch = next(batch_gen)

            pred_target_arr = sess.run(pred,
                                       feed_dict={
                                           model.input_holder:
                                           test_input_batch,
                                           model.target_holder:
                                           test_target_batch
                                       })

            # Targets are one-hot encoded; argmax recovers the class index.
            input_target_arr = np.argmax(test_target_batch, 1)
            input_target_list.extend(input_target_arr.tolist())
            pred_target_list.extend(pred_target_arr.tolist())

        # Map class indices back to label strings for readable reports.
        input_target_list = [target_vocab[i] for i in input_target_list]
        pred_target_list = [target_vocab[p] for p in pred_target_list]

        report = metrics.classification_report(input_target_list,
                                               pred_target_list)
        print(f"\n>> REPORT:\n{report}")
        output_utils.save_metrics(config, "report.txt", report)

        cm = metrics.confusion_matrix(input_target_list, pred_target_list)
        print(f"\n>> Confusion Matrix:\n{cm}")
        output_utils.save_metrics(config, "confusion_matrix.txt", str(cm))
def get_server_sess(time_str):
    """Restore the newest checkpoint into an interactive session for serving.

    Builds the CNN graph from the run's config, restores the latest
    checkpoint, and returns everything a serving layer needs to run
    predictions.

    Args:
        time_str: Timestamp string identifying the training run to load.

    Returns:
        Tuple of (session, prediction op, target vocab, input vocab, model).
    """
    config = conf_utils.init_test_config(time_str)
    input_vocab = data_utils.get_vocab(config["input_vocab_path"])
    target_vocab = data_utils.get_vocab(config["target_vocab_path"])

    print(">> build model...")
    model = cnn.Model(config)
    _, pred, _ = model.cnn()

    sess = tf.InteractiveSession()
    saver = tf.train.Saver()
    checkpoint_path = tf.train.latest_checkpoint(config["model_path"])
    print(f">> last checkpoint: {checkpoint_path}")
    saver.restore(sess, checkpoint_path)

    return sess, pred, target_vocab, input_vocab, model
def train(config_path):
    """Train the CNN model described by the given config file.

    Runs mini-batch training for up to config["epoch_size"] epochs, logging
    train/validation accuracy every 5 batches, checkpointing every
    config["num_save_epoch"] epochs, and stopping early when validation
    accuracy has not improved for 200 consecutive evaluations (~1000
    batches, since one evaluation happens every 5 batches).

    Args:
        config_path: Path to the training configuration file.
    """
    config = conf_utils.init_train_config(config_path)
    batch_size = config["batch_size"]
    epoch_size = config["epoch_size"]
    num_save_epoch = config["num_save_epoch"]

    train_input, train_target, validate_input, validate_target = data_utils.gen_train_data(
        config)

    # A single fixed validation batch is reused for every evaluation.
    validate_input_batch = validate_input[:batch_size]
    validate_target_batch = validate_target[:batch_size]

    print(">> build model...")
    model = cnn.Model(config)
    train_step, pred, cost = model.cnn()

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())

        saver = tf.train.Saver()

        merged = tf.summary.merge_all()
        train_writer = tf.summary.FileWriter(config["log_train_path"],
                                             sess.graph)

        max_val_acc = 0
        stale_evals = 0  # evaluations since validation accuracy last improved
        # Iterate 1..epoch_size directly instead of rebinding the loop var.
        for epoch in range(1, epoch_size + 1):
            batch_gen = batch_utils.make_batch(train_input, train_target,
                                               batch_size)
            # Trailing samples that do not fill a whole batch are dropped.
            for batch_num in range(len(train_input) // batch_size):
                train_input_batch, train_target_batch = next(batch_gen)

                _, loss, train_pred, summary = sess.run(
                    [train_step, cost, pred, merged],
                    feed_dict={
                        model.input_holder: train_input_batch,
                        model.target_holder: train_target_batch
                    })
                train_writer.add_summary(summary)

                # Evaluate train/validation accuracy every 5 batches.
                if not batch_num % 5:
                    # Targets are one-hot; argmax recovers the class index.
                    input_train_arr = np.argmax(train_target_batch, 1)
                    target_train_arr = np.array(train_pred)
                    acc_train = np.sum(input_train_arr == target_train_arr
                                       ) * 100 / len(input_train_arr)

                    # Fetch `pred` directly: the original wrapped it in a
                    # list, giving the result an extra leading dimension
                    # that only compared correctly via broadcasting.
                    validate_pred = sess.run(
                        pred,
                        feed_dict={
                            model.input_holder: validate_input_batch,
                            model.target_holder: validate_target_batch
                        })
                    input_validate_arr = np.argmax(validate_target_batch, 1)
                    target_validate_arr = np.array(validate_pred)
                    acc_val = np.sum(input_validate_arr == target_validate_arr
                                     ) * 100 / len(input_validate_arr)
                    print(
                        f">> e:{epoch:3} s:{batch_num:2} loss:{loss:5.4} acc_t: {acc_train:3f} acc_v: {acc_val:3f}"
                    )

                    if acc_val > max_val_acc:
                        max_val_acc = acc_val
                        stale_evals = 0
                    else:
                        stale_evals += 1

            if not epoch % num_save_epoch:
                saver.save(sess,
                           config["model_path"] + "model",
                           global_step=epoch)
                print(">> save model...")

            # Early stop: no val-acc improvement for ~1000 batches
            # (200 evaluations * 5 batches per evaluation).
            if stale_evals > 200:
                print(">> No optimization for a long time, auto stopping...")
                break

        # Flush and close the summary writer so TensorBoard logs are complete.
        train_writer.close()

        time_str = config["time_now"]
        print(
            f">> use this command for test:\npython -m run.tensorflow_cnn test {time_str} "
        )