with open(os.path.join(save_dir, "result.json"), 'w') as f:
        json.dump(result_dict, f)
    x = range(epoch_final)
    plt.cla()
    plt.plot(x, train_loss_list, label="train loss")
    plt.plot(x, test_loss_list, label="test loss")
    plt.legend()
    plt.savefig(os.path.join(save_dir, "loss.png"))
    plt.cla()
    plt.plot(x, train_acc_list, label="train acc")
    plt.plot(x, test_acc_list, label="test acc")
    plt.legend()
    plt.savefig(os.path.join(save_dir, "acc.png"))


# Load MNIST once at module level; the same arrays are handed to
# train_and_save for every config processed in the loop below.
train_data, test_data, train_label, test_label = load_mnist_2d('data')

# Read the experiment queue. Keys used below: 'default_config' (dict of base
# settings), 'train_config' (list of per-run overrides still to run) and
# 'finish_config' (list of overrides already run).
# JSON text is UTF-8 by spec (RFC 8259), so decode it explicitly instead of
# relying on the platform's locale-dependent default encoding.
with open("config.json", 'r', encoding="utf-8") as f:
    config_all = json.load(f)

while len(config_all['train_config']) != 0:
    train_config = config_all['train_config'][0]
    config = copy.copy(config_all['default_config'])
    for key, value in train_config.items():
        config[key] = value
    LOG_INFO('Using config %s now' % (config['name']))
    model, loss = build_model(config)
    train_and_save(model, loss, train_data, test_data, train_label, test_label)
    config_all['finish_config'].append(train_config)
    del config_all['train_config'][0]
    with open("config.json", 'w') as f:
    loss /= times
    acc /= times
    return acc, loss


def inference(model, sess, X):
    """Evaluate ``model.pred`` on the input batch ``X`` (test-time pass).

    Args:
        model: object exposing the ``pred`` output and the ``x_`` and
            ``keep_prob`` inputs.
        sess: session whose ``run`` evaluates ``model.pred`` (a ``tf.Session``
            in this script).
        X: value fed into ``model.x_``.

    Returns:
        The first element of the fetch result, i.e. the evaluated
        ``model.pred``.
    """
    # keep_prob is pinned to 1.0 at inference; presumably this is the dropout
    # keep probability, so nothing is dropped. Confirm against Model.
    feed = {model.x_: X, model.keep_prob: 1.0}
    fetched = sess.run([model.pred], feed)
    return fetched[0]


with tf.Session() as sess:
    if not os.path.exists(FLAGS.train_dir):
        os.mkdir(FLAGS.train_dir)
    if not os.path.exists(FLAGS.img_dir):
        os.mkdir(FLAGS.img_dir)
    if FLAGS.is_train:
        X_train, X_test, y_train, y_test = load_mnist_2d(FLAGS.data_dir)
        X_val, y_val = X_train[50000:], y_train[50000:]
        X_train, y_train = X_train[:50000], y_train[:50000]
        mlp_model = Model(True)
        if tf.train.get_checkpoint_state(FLAGS.train_dir):
            mlp_model.saver.restore(
                sess, tf.train.latest_checkpoint(FLAGS.train_dir))
        else:
            tf.global_variables_initializer().run()

        pre_losses = [1e18] * 3
        best_val_acc = 0.0

        loss_plot_list = []
        acc_plot_list = []