示例#1
0
def main(_):
    """Train the poem language model, checkpointing and validating each epoch.

    Builds a training model and a validation model that share variables
    under the "trainModel" variable scope. For each of 100 epochs it sets the
    learning rate to ``lr_decay ** epoch``, runs one training pass, saves a
    checkpoint under the prefix ``model.ckpt`` (with the epoch index as
    ``global_step``) and prints the training and validation perplexities.

    Args:
        _: Unused positional argument (presumably argv from ``tf.app.run``;
            confirm against the caller).
    """
    # Load data.
    index_to_char, char_to_index = getDicts(vocab_size)
    data = read_poems(char_to_index)

    print(data.shape)
    # NOTE(review): the validation slice is a subset of the training data, so
    # val_perplexity is not a held-out measurement -- confirm this is intended.
    train_data = data
    val_data = data[0:10000]

    with tf.Graph().as_default(), tf.Session() as session:
        initializer = tf.random_uniform_initializer(-0.1, 0.1)
        with tf.variable_scope("trainModel",
                               reuse=None,
                               initializer=initializer):
            t_train = trainModel(training=True)
        # Second instance reuses the variables created above.
        with tf.variable_scope("trainModel",
                               reuse=True,
                               initializer=initializer):
            t_valid = trainModel(training=False)

        # Build the learning-rate update and the validation no-op once.
        # Creating them inside the epoch loop would add new nodes to the
        # graph on every epoch.
        new_lr = tf.placeholder(t_train._learning_rate.dtype.base_dtype,
                                shape=[])
        lr_update = tf.assign(t_train._learning_rate, new_lr)
        valid_op = tf.no_op()

        tf.initialize_all_variables().run()

        print("-------", type(tf.all_variables()))

        for a in tf.all_variables():
            print(a.name)

        print("-------")

        # Number of epochs for training.
        for i in range(100):
            # Let the learning rate decay; i is never negative, so no clamp
            # is needed on the exponent.
            learning_decay = lr_decay**i
            session.run(lr_update, feed_dict={new_lr: learning_decay})
            print("Epoch ", i + 1)
            # Train.
            train_perplexity = train(session,
                                     t_train,
                                     train_data,
                                     t_train._train,
                                     index_to_char,
                                     verbose=True)
            print("train_perplexity: ", train_perplexity)
            checkpoint_path = os.path.join("", 'model.ckpt')
            t_train._saver.save(session, checkpoint_path, global_step=i)
            # Validate: the no-op stands in for the train op so no weights
            # are updated by this pass.
            val_perplexity = train(session, t_valid, val_data, valid_op,
                                   index_to_char)
            print("val_perplexity: ", val_perplexity)
示例#2
0
def main(_):
    """Run one training epoch of the poem model, then checkpoint and validate.

    Rows 0..999 of the loaded data are used for validation and rows
    1000..147540 for training. A training model and a validation model are
    built under the shared "trainModel" variable scope; the epoch loop sets
    the learning rate to ``lr_decay ** epoch``, trains, saves a checkpoint
    under the prefix ``model.ckpt`` and prints both perplexities.

    Args:
        _: Unused positional argument (presumably argv from ``tf.app.run``;
            confirm against the caller).
    """
    # Load data.
    index_to_char, char_to_index = getDicts(vocab_size)
    data = read_poems(char_to_index)

    # NOTE(review): 147541 is a hard-coded upper bound, presumably the size
    # of the dataset -- confirm, or derive it from the data.
    train_data = data[1000:147541]
    val_data = data[0:1000]

    with tf.Graph().as_default(), tf.Session() as session:
        initializer = tf.random_uniform_initializer(-0.1, 0.1)
        with tf.variable_scope("trainModel", reuse=None, initializer=initializer):
            t_train = trainModel(training=True)
        # The validation model reuses the variables created for training.
        with tf.variable_scope("trainModel", reuse=True, initializer=initializer):
            t_valid = trainModel(training=False)

        # Create the learning-rate assignment and the validation no-op a
        # single time; building them per epoch would keep adding nodes to
        # the graph.
        lr_input = tf.placeholder(t_train._learning_rate.dtype.base_dtype, shape=[])
        set_learning_rate = tf.assign(t_train._learning_rate, lr_input)
        no_train_op = tf.no_op()

        tf.initialize_all_variables().run()

        print("-------", type(tf.all_variables()))

        for a in tf.all_variables():
            print(a.name)

        print("-------")

        for i in range(1):
            # Let the learning rate decay; the epoch index is never negative,
            # so the exponent needs no clamping.
            learning_decay = lr_decay ** i
            session.run(set_learning_rate, feed_dict={lr_input: learning_decay})
            print("Epoch ", i+1)
            # Train.
            train_perplexity = train(session, t_train, train_data, t_train._train, index_to_char, verbose=True)
            print("train_perplexity: ", train_perplexity)
            checkpoint_path = os.path.join("", 'model.ckpt')
            t_train._saver.save(session, checkpoint_path, global_step=i)
            # Validate: the no-op replaces the train op so weights stay fixed.
            val_perplexity = train(session, t_valid, val_data, no_train_op, index_to_char)
            print("val_perplexity: ", val_perplexity)
示例#3
0
import os
import sys

from reader import getDicts
from reader import read_poems
from train import trainModel

# Script setup: builds the model with infer=True and looks up the checkpoint
# state in the working directory. Written with Python 2 print statements.

# Directory searched for checkpoint state; os.path.join('.') is simply '.'.
checkpoint_dir = os.path.join('.')
# NOTE(review): not referenced in the visible part of this script -- purpose
# unclear from here; confirm before removing.
exclusion = ['*']

# Echo the first command-line argument; raises IndexError if none is given.
print 'Character: ', sys.argv[1]

# Vocabulary size handed to getDicts below.
vocab_size = 2000

index_to_char, char_to_index = getDicts(vocab_size)
# NOTE(review): `data` is not used in the visible part of this script.
data = read_poems(char_to_index)
# Build the model under the "trainModel" variable scope.
# NOTE(review): presumably the scope name must match the one used at training
# time so that checkpointed variable names line up -- confirm.
# NOTE(review): `tf` is not bound by the import lines visible above; an
# `import tensorflow as tf` appears to be missing.
with tf.variable_scope("trainModel"):
    model = trainModel(training=False, infer=True)

with tf.Session() as sess:
    tf.initialize_all_variables().run()
    # Saver covering every variable currently in the graph.
    saver = tf.train.Saver(tf.all_variables())
    # Debug dump of the variable collection type and each variable's name.
    print '-------', type(tf.all_variables())
    for a in tf.all_variables():
        print a.name
    print '-------'

    ckpt = tf.train.get_checkpoint_state(checkpoint_dir)

    # NOTE(review): get_checkpoint_state returns None when no checkpoint
    # state is found, in which case this line raises AttributeError.
    print 'ckpt.model_checkpoint_path: ', ckpt.model_checkpoint_path