# Example #1
 def input_fn():
     """Build the tf.data pipeline and return the next-batch tensors.

     Reads ``is_train`` and ``opts`` from the enclosing scope.  The
     training pipeline is shuffled and repeated indefinitely; the test
     pipeline is read once, in order.  All dataset ops are pinned to the
     CPU so the accelerator is left free for the model itself.
     """
     with tf.device('/cpu:0'):
         if not is_train:
             ds = test_dataset(opts.data_dir, fashion=opts.fashion)
         else:
             ds = train_dataset(opts.data_dir, fashion=opts.fashion)
             # count=None repeats forever; buffer of 5 batches for shuffling.
             ds = ds.apply(tf.contrib.data.shuffle_and_repeat(
                 buffer_size=5 * opts.batch_size, count=None))
         ds = ds.batch(batch_size=opts.batch_size)
         return ds.make_one_shot_iterator().get_next()
def _print_flag_summary():
    """Echo the FLAGS controlling this run (printed at start and end)."""
    print("The input_dim is", FLAGS.input_dim, "The hidden_dim is",
          FLAGS.hidden_dim)
    print("The output_dim is", FLAGS.output_dim, "The keep_prob is",
          FLAGS.keep_prob)
    print("The batch_size is", FLAGS.batch_size, "The test is",
          FLAGS.test_batch_size)
    print("The model is", FLAGS.model, "The number of layer is", FLAGS.layer)
    print("The truncated number is", FLAGS.trun_num, "The reverse is",
          FLAGS.reverse)


def main(_):
    """Train the truncated-BPTT RNN model and evaluate on held-out hosts.

    Loads the train/test split via ``read_data``, builds shared-variable
    train/test graphs, decays the learning rate in thirds of the epoch
    budget (x1, x0.1, x0.01), and after each epoch reports the train/test
    loss (rescaled by the data's variance) plus the loss on the
    axp0/axp7/sahara/themis evaluation traces.
    """
    print(
        "==============================================================================="
    )
    _print_flag_summary()
    X_train, y_train, X_test, y_test, _, cpu_load_std = read_data(
        FLAGS.data_path, FLAGS.input_dim, FLAGS.output_dim, FLAGS.input_dim)

    # Held-out per-host traces, used for evaluation only.
    # NOTE(review): pickle.load assumes these files are trusted local data.
    eval_sets = {}
    for host in ("axp0", "axp7", "sahara", "themis"):
        with open("./data/%s.pkl" % host, 'rb') as f:
            raw = pickle.load(f)
        _, _, X_host, y_host, std_host = test_dataset(
            raw, FLAGS.input_dim, FLAGS.output_dim, FLAGS.input_dim)
        eval_sets[host] = (X_host, y_host, std_host)

    train_data_len = X_train.shape[1]
    train_len, train_index = truncated_index(train_data_len, FLAGS.trun_num,
                                             FLAGS.reverse)
    print("train length", train_len)
    print(train_index)

    with tf.Graph().as_default(), tf.Session() as session:
        # Train and test graphs share the same variables via scope reuse.
        with tf.variable_scope("model", reuse=None):
            m_train = RNNModel(is_training=True,
                               batch_size=FLAGS.batch_size,
                               length=train_len)
        with tf.variable_scope("model", reuse=True):
            m_test = RNNModel(is_training=False,
                              batch_size=FLAGS.batch_size,
                              length=len(y_test[0]))
            m_new_test = RNNModel(is_training=False,
                                  batch_size=FLAGS.test_batch_size,
                                  length=len(y_test[0]))

        # tf.initialize_all_variables() is deprecated; use the TF 1.x API.
        session.run(tf.global_variables_initializer())

        scale = cpu_load_std**2  # losses are computed on normalized data
        test_best = 0.0
        training_time = []
        for i in range(FLAGS.epoch):
            # Step-wise decay: full LR for the first third, then /10, /100.
            if i < FLAGS.epoch / 3:
                lr_decay = 1
            elif i < FLAGS.epoch * 2 / 3:
                lr_decay = 0.1
            else:
                lr_decay = 0.01
            m_train.assign_lr(session, FLAGS.lr * lr_decay)

            train_loss_list = []
            train_state_list = []
            start = time.time()
            for j in range(FLAGS.trun_num):
                train_loss, train_state = run_train_epoch(
                    session, m_train, X_train[:, train_index[j], :],
                    y_train[:, train_index[j], :], m_train.train_op)
                train_loss_list.append(train_loss)
                # Keep only the state of the chronologically last chunk
                # (chunk 0 when the truncation order is reversed).
                if j == (0 if FLAGS.reverse else FLAGS.trun_num - 1):
                    train_state_list.append(train_state)
            training_time.append(time.time() - start)

            test_loss, _ = run_test_epoch(session, m_test, X_test, y_test,
                                          tf.no_op(), train_state_list[0])
            host_loss = {
                host: run_new_load(session, m_new_test, X_host, y_host,
                                   tf.no_op())
                for host, (X_host, y_host, _) in eval_sets.items()
            }

            if i == 0 or test_loss < test_best:
                test_best = test_loss

            print(
                "epoch:%3d, lr %.5f, train_loss %.6f, test_loss %.6f, speed %.2f seconds/epoch"
                % (i + 1, session.run(m_train.lr), np.mean(train_loss_list) *
                   scale, test_loss * scale, training_time[i]))
            print("axp0 loss %.6f, axp7 loss %.6f" %
                  (host_loss["axp0"] * eval_sets["axp0"][2]**2,
                   host_loss["axp7"] * eval_sets["axp7"][2]**2))
            print("sahara loss %.6f, themis loss %.6f" %
                  (host_loss["sahara"] * eval_sets["sahara"][2]**2,
                   host_loss["themis"] * eval_sets["themis"][2]**2))
            if i == FLAGS.epoch - 1:
                print("Best test loss %.6f" % (test_best * scale))
                print("Average %.4f seconds for one epoch" %
                      (np.mean(training_time)))

    # Final summary mirrors the header (the original accidentally printed
    # keep_prob/batch_size twice here).
    _print_flag_summary()
    print(
        "==============================================================================="
    )
# Example #3
def main(_):
    """Train the ESN readout and evaluate on four held-out host traces.

    The reservoir weights (``Win``, ``W``) are random and fixed — ``W`` is
    rescaled so its spectral radius equals ``rho`` — and only the readout
    inside ``ESN`` is trained.  The learning rate decays in thirds of the
    epoch budget (x1, x0.1, x0.01), and each epoch reports the train/test
    loss (rescaled by the data's variance) plus the loss on the
    axp0/axp7/sahara/themis traces.
    """
    print(
        "==============================================================================="
    )
    print("The input_dim is", FLAGS.input_dim, "The hidden_dim is",
          FLAGS.hidden_dim)
    print("The output_dim is", FLAGS.output_dim)
    print("The batch_size is", FLAGS.batch_size, "The test is ",
          FLAGS.test_batch_size)
    print("The data_path is", FLAGS.data_path)
    X_train, y_train, X_test, y_test, _, cpu_load_std = read_data(
        FLAGS.data_path, FLAGS.input_dim, FLAGS.output_dim, FLAGS.input_dim)

    # Held-out per-host traces, used for evaluation only.
    # NOTE(review): pickle.load assumes these files are trusted local data.
    eval_sets = {}
    for host in ("axp0", "axp7", "sahara", "themis"):
        with open("./data/%s.pkl" % host, 'rb') as f:
            raw = pickle.load(f)
        _, _, X_host, y_host, std_host = test_dataset(
            raw, FLAGS.input_dim, FLAGS.output_dim, FLAGS.input_dim)
        eval_sets[host] = (X_host, y_host, std_host)

    # Fixed (untrained) reservoir weights.
    inSize = FLAGS.input_dim
    resSize = FLAGS.hidden_dim
    rho = 0.1  # target spectral radius of the recurrent matrix
    Win = np.float32(np.random.rand(inSize, resSize) / 5 - 0.1)  # ~U(-0.1, 0.1)
    W = (np.random.rand(resSize * resSize) - 0.5).reshape((resSize, resSize))
    # Rescale so the largest |eigenvalue| (spectral radius) equals rho,
    # the standard echo-state-property construction.
    rhoW = max(abs(linalg.eig(W)[0]))
    W *= rho / rhoW
    W = np.float32(W)

    with tf.Graph().as_default(), tf.Session() as session:
        # Train and test graphs share the same variables via scope reuse.
        with tf.variable_scope("model", reuse=None):
            m_train = ESN(is_training=True,
                          batch_size=FLAGS.batch_size,
                          length=len(y_train[0]))
        with tf.variable_scope("model", reuse=True):
            m_test = ESN(is_training=False,
                         batch_size=FLAGS.batch_size,
                         length=len(y_test[0]))
            m_new_test = ESN(is_training=False,
                             batch_size=FLAGS.test_batch_size,
                             length=len(y_test[0]))

        # tf.initialize_all_variables() is deprecated; use the TF 1.x API.
        session.run(tf.global_variables_initializer())

        scale = cpu_load_std**2  # losses are computed on normalized data
        train_best = test_best = 0.0
        for i in range(FLAGS.epoch):
            # Step-wise decay: full LR for the first third, then /10, /100.
            if i < FLAGS.epoch / 3:
                lr_decay = 1
            elif i < FLAGS.epoch * 2 / 3:
                lr_decay = 0.1
            else:
                lr_decay = 0.01
            m_train.assign_lr(session, FLAGS.lr * lr_decay)

            # The first 50 target steps are dropped — presumably reservoir
            # washout; confirm against run_train_epoch.
            train_loss, train_state = run_train_epoch(
                session, m_train, Win, W, X_train, y_train[:, 50:, :],
                m_train.train_op)
            test_loss, _ = run_test_epoch(session, m_test, Win, W, X_test,
                                          y_test, tf.no_op(), train_state)
            host_loss = {
                host: run_new_load(session, m_new_test, Win, W, X_host,
                                   y_host, tf.no_op())
                for host, (X_host, y_host, _) in eval_sets.items()
            }

            if i == 0 or train_loss < train_best:
                train_best = train_loss
            if i == 0 or test_loss < test_best:
                test_best = test_loss

            print(
                "epoch:%3d, learning rate %.5f, train_loss %.6f, test_loss %.6f"
                % (i + 1, session.run(
                    m_train.lr), train_loss * scale, test_loss * scale))
            print("axp0 loss %.6f, axp7 loss %.6f" %
                  (host_loss["axp0"] * eval_sets["axp0"][2]**2,
                   host_loss["axp7"] * eval_sets["axp7"][2]**2))
            print("sahara loss %.6f, themis loss %.6f" %
                  (host_loss["sahara"] * eval_sets["sahara"][2]**2,
                   host_loss["themis"] * eval_sets["themis"][2]**2))
            if i == FLAGS.epoch - 1:
                print("Best train, test loss %.6f %.6f" %
                      (train_best * scale, test_best * scale))

    print("The input_dim is", FLAGS.input_dim, "The hidden_dim is",
          FLAGS.hidden_dim)
    print("The output_dim is", FLAGS.output_dim)
    print("The batch_size is", FLAGS.batch_size, "The test is ",
          FLAGS.test_batch_size)
    print("The data_path is", FLAGS.data_path)
    print(
        "==============================================================================="
    )