Пример #1
0
def run_training():
    """Train the LSTM poem model, resuming from the latest checkpoint if any.

    Reads the corpus from ``FLAGS.file_path``, builds the ``rnn_model`` graph
    and trains for epochs ``[start_epoch, FLAGS.epochs)``, printing the loss
    of every batch. ``start_epoch`` is the epoch number encoded in the latest
    checkpoint under ``FLAGS.model_dir`` (that epoch is trained again). A
    checkpoint is saved every 6 epochs and when training is interrupted with
    Ctrl-C.
    """
    if not os.path.exists(FLAGS.model_dir):
        os.makedirs(FLAGS.model_dir)

    poems_vector, word_to_int, vocabularies = process_poems(FLAGS.file_path)
    batches_inputs, batches_outputs = generate_batch(FLAGS.batch_size,
                                                     poems_vector, word_to_int)

    input_data = tf.placeholder(tf.int32, [FLAGS.batch_size, None])
    output_targets = tf.placeholder(tf.int32, [FLAGS.batch_size, None])

    # Pass the same batch size the placeholders were built with instead of a
    # hard-coded 64, so changing --batch_size cannot desynchronize them.
    end_points = rnn_model(model='lstm',
                           input_data=input_data,
                           output_data=output_targets,
                           vocab_size=len(vocabularies),
                           rnn_size=128,
                           num_layers=2,
                           batch_size=FLAGS.batch_size,
                           learning_rate=FLAGS.learning_rate)

    saver = tf.train.Saver(tf.global_variables())
    init_op = tf.group(tf.global_variables_initializer(),
                       tf.local_variables_initializer())
    checkpoint_path = os.path.join(FLAGS.model_dir, FLAGS.model_prefix)
    with tf.Session() as sess:
        sess.run(init_op)

        start_epoch = 0
        checkpoint = tf.train.latest_checkpoint(FLAGS.model_dir)
        if checkpoint:
            saver.restore(sess, checkpoint)
            print("## restore from the checkpoint {0}".format(checkpoint))
            # Checkpoints are saved with global_step=epoch, so the path ends
            # in "-<epoch>".
            start_epoch += int(checkpoint.split('-')[-1])
        print('## start training...')

        # Number of batches per epoch; identical for every epoch.
        n_chunk = len(poems_vector) // FLAGS.batch_size
        # Bound before the loop so the interrupt handler can always use it.
        epoch = start_epoch
        try:
            for epoch in range(start_epoch, FLAGS.epochs):
                for batch in range(n_chunk):
                    loss, _, _ = sess.run([
                        end_points['total_loss'],
                        end_points['last_state'],
                        end_points['train_op']
                    ], feed_dict={input_data: batches_inputs[batch],
                                  output_targets: batches_outputs[batch]})
                    print('Epoch: %d, batch: %d, training loss: %.6f' % (epoch, batch, loss))
                if epoch % 6 == 0:
                    saver.save(sess, checkpoint_path, global_step=epoch)
        except KeyboardInterrupt:
            print('## Interrupt manually, try saving checkpoint for now...')
            saver.save(sess, checkpoint_path, global_step=epoch)
            print(
                '## Last epoch were saved, next time will start from epoch {}.'
                .format(epoch))
Пример #2
0
    def __init__(self, model_name, model_dir, corpus_file, substr_len):
        """Load the corpus, build the prediction graph and restore weights.

        Args:
            model_name: model identifier passed to ``RNNModel``; also names
                the log directory ``./log/predict/<model_name>``.
            model_dir: directory searched for the latest checkpoint.
            corpus_file: poem corpus file passed to ``process_poems``.
            substr_len: must be truthy; stored as ``substr_len + 2``
                (NOTE(review): presumably room for start/end markers — confirm).

        Raises:
            ValueError: if ``model_dir`` contains no checkpoint.
        """
        self.model = model_name
        self.model_dir = model_dir
        self.corpus_file = corpus_file
        self.log_dir = "./log/predict/%s" % self.model
        assert substr_len
        self.substr_len = substr_len + 2

        # The corpus is read from corpus_file, so report that path.
        print('## loading corpus from %s' % self.corpus_file)
        # word_int_map: dict mapping each character to its integer id.
        # vocabularies: vocabulary list, ordered from most to least frequent.
        # The encoded corpus matrix itself is not needed for prediction.
        _, self.word_int_map, self.vocabularies = process_poems(self.corpus_file)

        # Build the RNN prediction graph in its own tf.Graph.
        graph = tf.Graph()
        with graph.as_default():
            self.input_data = tf.placeholder(tf.int32, [1, 2, 1], name='character')
            self.pos_mat = tf.placeholder(tf.int32, [1, 2, 1], name='position')
            rnn = RNNModel(
                self.model, num_layers=2, rnn_size=64, batch_size=64, vocabularies=self.vocabularies,
                add_dim=add_feature_dim, substr_len=self.substr_len
            )
            self.endpoints = rnn.predict(input_data=self.input_data, add_data=self.pos_mat)
            saver = tf.train.Saver(tf.global_variables())
            init_op = tf.group(tf.global_variables_initializer(), tf.local_variables_initializer())

        # The session is kept on the instance for later predictions.
        self.sess = tf.Session(graph=graph)
        self.sess.run(init_op)

        # Restore trained weights from the most recent checkpoint.
        checkpoint = tf.train.latest_checkpoint(self.model_dir)
        if checkpoint is None:
            # Fail with a clear message instead of restore(None), and do not
            # leak the session that was just opened.
            self.sess.close()
            raise ValueError('no checkpoint found in %s' % self.model_dir)
        saver.restore(self.sess, checkpoint)
Пример #3
0
def gen_poem(begin_word):
    """Generate a single poem with the trained LSTM model.

    The network is first fed ``start_token``. Generation then starts from
    ``begin_word`` (or, when it is falsy, from the character predicted after
    ``start_token``) and stops when ``end_token`` is predicted or 24
    characters have been produced.

    Args:
        begin_word: first character of the poem, or a falsy value to let the
            model choose. It is looked up with ``word_int_map[...]``, so an
            unknown character raises ``KeyError``.

    Returns:
        The generated poem as a string, without ``end_token``.
    """
    batch_size = 1
    # The corpus is read from corpus_file, so report that path.
    print('## loading corpus from %s' % corpus_file)
    _, word_int_map, vocabularies = process_poems(corpus_file)

    input_data = tf.placeholder(tf.int32, [batch_size, None])

    end_points = rnn_model(model='lstm',
                           input_data=input_data,
                           output_data=None,
                           vocab_size=len(vocabularies),
                           rnn_size=128,
                           num_layers=2,
                           batch_size=64,
                           learning_rate=lr)

    saver = tf.train.Saver(tf.global_variables())
    init_op = tf.group(tf.global_variables_initializer(),
                       tf.local_variables_initializer())
    with tf.Session() as sess:
        sess.run(init_op)

        checkpoint = tf.train.latest_checkpoint(model_dir)
        saver.restore(sess, checkpoint)

        # Prime the network with the start token.
        x = np.array([list(map(word_int_map.get, start_token))])
        [predict, last_state] = sess.run(
            [end_points['prediction'], end_points['last_state']],
            feed_dict={input_data: x})
        word = begin_word if begin_word else to_word(predict, vocabularies)

        poem_ = ''
        i = 0
        while word != end_token:
            poem_ += word
            i += 1
            if i >= 24:
                break
            # Feed the latest character as a (1, 1) integer batch and carry
            # the RNN state forward.
            x = np.array([[word_int_map[word]]])
            [predict, last_state] = sess.run(
                [end_points['prediction'], end_points['last_state']],
                feed_dict={
                    input_data: x,
                    end_points['initial_state']: last_state
                })
            word = to_word(predict, vocabularies)

        return poem_
Пример #4
0
def run_training():
    """Train the LSTM poem model, resuming from the latest checkpoint if any.

    Reads the corpus from ``FLAGS.file_path`` and trains for epochs
    ``[start_epoch, FLAGS.epochs)``, printing the loss of every batch.
    ``start_epoch`` is the epoch number encoded in the latest checkpoint under
    ``FLAGS.model_dir`` (that epoch is trained again). A checkpoint is saved
    every 6 epochs and when training is interrupted with Ctrl-C.
    """
    if not os.path.exists(FLAGS.model_dir):
        os.makedirs(FLAGS.model_dir)

    poems_vector, word_to_int, vocabularies = process_poems(FLAGS.file_path)
    batches_inputs, batches_outputs = generate_batch(FLAGS.batch_size, poems_vector, word_to_int)

    input_data = tf.placeholder(tf.int32, [FLAGS.batch_size, None])
    output_targets = tf.placeholder(tf.int32, [FLAGS.batch_size, None])

    # Use the placeholders' batch size rather than a hard-coded 64.
    end_points = rnn_model(model='lstm', input_data=input_data, output_data=output_targets,
                           vocab_size=len(vocabularies), rnn_size=128, num_layers=2,
                           batch_size=FLAGS.batch_size, learning_rate=FLAGS.learning_rate)

    saver = tf.train.Saver(tf.global_variables())
    init_op = tf.group(tf.global_variables_initializer(), tf.local_variables_initializer())
    checkpoint_path = os.path.join(FLAGS.model_dir, FLAGS.model_prefix)
    with tf.Session() as sess:
        sess.run(init_op)

        start_epoch = 0
        checkpoint = tf.train.latest_checkpoint(FLAGS.model_dir)
        if checkpoint:
            saver.restore(sess, checkpoint)
            print("## restore from the checkpoint {0}".format(checkpoint))
            # Checkpoint paths end in "-<epoch>" (saved with global_step=epoch).
            start_epoch += int(checkpoint.split('-')[-1])
        print('## start training...')

        # The number of batches per epoch never changes; compute it once.
        n_chunk = len(poems_vector) // FLAGS.batch_size
        # Bound before the loop so the interrupt handler can always use it.
        epoch = start_epoch
        try:
            for epoch in range(start_epoch, FLAGS.epochs):
                for batch in range(n_chunk):
                    loss, _, _ = sess.run([
                        end_points['total_loss'],
                        end_points['last_state'],
                        end_points['train_op']
                    ], feed_dict={input_data: batches_inputs[batch], output_targets: batches_outputs[batch]})
                    print('Epoch: %d, batch: %d, training loss: %.6f' % (epoch, batch, loss))
                if epoch % 6 == 0:
                    saver.save(sess, checkpoint_path, global_step=epoch)
        except KeyboardInterrupt:
            print('## Interrupt manually, try saving checkpoint for now...')
            saver.save(sess, checkpoint_path, global_step=epoch)
            print('## Last epoch were saved, next time will start from epoch {}.'.format(epoch))
Пример #5
0
def gen_poem(begin_word):
    """Generate a single poem with the trained LSTM model.

    The network is first fed ``start_token``. Generation then starts from
    ``begin_word`` (or, when it is falsy, from the character predicted after
    ``start_token``) and stops when ``end_token`` is predicted or 24
    characters have been produced.

    Args:
        begin_word: first character of the poem, or a falsy value to let the
            model choose. It is looked up with ``word_int_map[...]``, so an
            unknown character raises ``KeyError``.

    Returns:
        The generated poem as a string, without ``end_token``.
    """
    batch_size = 1
    # The corpus is read from corpus_file, so report that path.
    print('## loading corpus from %s' % corpus_file)
    # Previously referenced the undefined name ``corpus_fileh`` (typo).
    _, word_int_map, vocabularies = process_poems(corpus_file)

    input_data = tf.placeholder(tf.int32, [batch_size, None])

    end_points = rnn_model(model='lstm', input_data=input_data, output_data=None, vocab_size=len(
        vocabularies), rnn_size=128, num_layers=2, batch_size=64, learning_rate=FLAGS.learning_rate)

    saver = tf.train.Saver(tf.global_variables())
    init_op = tf.group(tf.global_variables_initializer(), tf.local_variables_initializer())
    with tf.Session() as sess:
        sess.run(init_op)

        checkpoint = tf.train.latest_checkpoint(model_dir)
        saver.restore(sess, checkpoint)

        # Prime the network with the start token.
        x = np.array([list(map(word_int_map.get, start_token))])
        [predict, last_state] = sess.run([end_points['prediction'], end_points['last_state']],
                                         feed_dict={input_data: x})
        word = begin_word if begin_word else to_word(predict, vocabularies)

        poem_ = ''
        i = 0
        while word != end_token:
            poem_ += word
            i += 1
            if i >= 24:
                break
            # Feed the latest character as a (1, 1) integer batch and carry
            # the RNN state forward.
            x = np.array([[word_int_map[word]]])
            [predict, last_state] = sess.run([end_points['prediction'], end_points['last_state']],
                                             feed_dict={input_data: x, end_points['initial_state']: last_state})
            word = to_word(predict, vocabularies)

        return poem_
Пример #6
0
 def __init__(self, s_word):
     """Store generation settings, load the corpus and build the LSTM graph.

     Args:
         s_word: word the poem should begin with; kept as ``self.begin_word``.
     """
     self.begin_word = s_word
     # Marker characters 'B'/'E' (presumably poem start/end — confirm
     # against process_poems).
     self.start_token, self.end_token = 'B', 'E'
     # Checkpoint directory and corpus location.
     self.model_dir, self.corpus_file = './model/', './data/poems.txt'
     self.lr = 0.0002
     self.RATE = 0.008
     self.max_len = 78
     corpus = process_poems(self.corpus_file)
     self.poems_vector, self.word_int_map, self.vocabs = corpus
     # One sequence of variable length per run.
     self.input_data = tf.placeholder(tf.int32, [1, None])
     model_kwargs = dict(model='lstm',
                         input_data=self.input_data,
                         output_data=None,
                         vocab_size=len(self.vocabs),
                         rnn_size=128,
                         num_layers=2,
                         batch_size=64,
                         learning_rate=self.lr)
     self.end_points = rnn_model(**model_kwargs)
     self._parse_input()
Пример #7
0
def run_training():
    """Train the LSTM poem model with TensorBoard logging and checkpointing.

    Creates ``model_dir`` and ``log_dir`` when missing, reads the corpus from
    ``corpus_path`` and trains for epochs ``[start_epoch, FLAGS.epochs)``,
    where ``start_epoch`` is one past the epoch number of the latest
    checkpoint. Every ``FLAGS.print_every_steps`` steps the loss is printed
    and a summary is written under ``<log_dir>/train``. A checkpoint is saved
    to ``model_file`` every ``FLAGS.save_every_epoch`` epochs and on Ctrl-C.
    """
    if not os.path.exists(model_dir):
        os.makedirs(model_dir)

    if not os.path.exists(log_dir):
        os.makedirs(log_dir)

    poems_vector, word_to_int, vocabularies = process_poems(corpus_path)
    batches_inputs, batches_outputs = generate_batch(FLAGS.batch_size,
                                                     poems_vector, word_to_int)

    print("## top ten vocabularies: %s" % str(vocabularies[:10]))
    print("## tail ten vocabularies: %s" % str(vocabularies[-10:]))
    print("## len(first vector)=%d, first vector[:50]: %s" %
          (len(poems_vector[0]), poems_vector[0][:50]))
    # Label fixed: this prints the last vector, not the second one.
    print("## len(last vector)=%d, last vector[:50]: %s" %
          (len(poems_vector[-1]), poems_vector[-1][:50]))

    input_data = tf.placeholder(tf.int32, [FLAGS.batch_size, None])
    output_targets = tf.placeholder(tf.int32, [FLAGS.batch_size, None])

    end_points = rnn_model(model='lstm',
                           input_data=input_data,
                           output_data=output_targets,
                           vocab_size=len(vocabularies),
                           rnn_size=128,
                           num_layers=2,
                           batch_size=FLAGS.batch_size,
                           learning_rate=FLAGS.learning_rate)

    saver = tf.train.Saver(tf.global_variables())
    init_op = tf.group(tf.global_variables_initializer(),
                       tf.local_variables_initializer())
    summary_op = tf.summary.merge_all()
    with tf.Session() as sess:
        sess.run(init_op)

        start_epoch = 0
        checkpoint = tf.train.latest_checkpoint(model_dir)
        if checkpoint:
            saver.restore(sess, checkpoint)
            print("## restore from the checkpoint {0}".format(checkpoint),
                  flush=True)
            # Periodic checkpoints are written after their epoch finished, so
            # resume at the next epoch.
            start_epoch += int(checkpoint.split('-')[-1]) + 1
        print('## start training...', flush=True)

        n_chunk = len(poems_vector) // FLAGS.batch_size

        train_writer = tf.summary.FileWriter(os.path.join(log_dir, "train"),
                                             sess.graph)
        # Bound before the loop so the interrupt handler can always use it.
        epoch = start_epoch
        try:
            for epoch in range(start_epoch, FLAGS.epochs):
                for batch in range(n_chunk):
                    step = epoch * n_chunk + batch
                    feed = {
                        input_data: batches_inputs[batch],
                        output_targets: batches_outputs[batch]
                    }
                    if step % FLAGS.print_every_steps == 0:
                        # Logging step: also fetch the loss and the summaries.
                        loss, _, _, train_summary = sess.run(
                            [
                                end_points['total_loss'],
                                end_points['last_state'],
                                end_points['train_op'], summary_op
                            ],
                            feed_dict=feed)
                        train_writer.add_summary(train_summary,
                                                 global_step=step)
                        print(
                            '[%s] Step: %d, Epoch: %d, batch: %d, training loss: %.6f'
                            % (time.strftime('%Y-%m-%d %H:%M:%S'), step, epoch,
                               batch, loss),
                            flush=True)
                    else:
                        sess.run(
                            [end_points['last_state'], end_points['train_op']],
                            feed_dict=feed)
                if epoch % FLAGS.save_every_epoch == 0:
                    saver.save(sess, model_file, global_step=epoch)
                    print("[%s] Saving checkpoint for epoch %d" %
                          (time.strftime('%Y-%m-%d %H:%M:%S'), epoch),
                          flush=True)
        except KeyboardInterrupt:
            print('## Interrupt manually, try saving checkpoint for now...')
            saver.save(sess, model_file, global_step=epoch)
            # Resuming adds 1 to the saved epoch, so the next run starts at
            # epoch + 1 (the rest of the interrupted epoch is skipped).
            print(
                '## Last epoch were saved, next time will start from epoch {}.'
                .format(epoch + 1),
                flush=True)
        finally:
            # Flush pending summaries and release the event file.
            train_writer.close()
Пример #8
0
# Command-line flags controlling training.
tf.app.flags.DEFINE_integer('batch_size', 64, 'batch size.')
tf.app.flags.DEFINE_float('learning_rate', 0.01, 'learning rate.')
tf.app.flags.DEFINE_string('model_dir', os.path.abspath('./model'), 'model save path.')
tf.app.flags.DEFINE_string('file_path', os.path.abspath('./data/poems.txt'), 'file name of poems.')
tf.app.flags.DEFINE_string('model_prefix', 'poems', 'model save prefix.')
tf.app.flags.DEFINE_integer('epochs', 50, 'train how many epochs.')

FLAGS = tf.app.flags.FLAGS


def run_training():
    """Build the LSTM training graph and initialize its variables.

    NOTE(review): in this excerpt the function ends right after variable
    initialization; the training loop and checkpoint restore appear to be
    missing.
    """
    # Create the model directory if it does not exist.
    if not os.path.exists(FLAGS.model_dir):
        os.makedirs(FLAGS.model_dir)

    # Preprocess the corpus: poem vectors, character-to-id map, vocabulary.
    poems_vector, word_to_int, vocabularies = process_poems(FLAGS.file_path)
    # Split the corpus into batches.
    batches_inputs, batches_outputs = generate_batch(FLAGS.batch_size, poems_vector, word_to_int)

    input_data = tf.placeholder(tf.int32, [FLAGS.batch_size, None])
    output_targets = tf.placeholder(tf.int32, [FLAGS.batch_size, None])

    # Build the RNN model, using the placeholders' batch size.
    end_points = rnn_model(model='lstm', input_data=input_data, output_data=output_targets,
                           vocab_size=len(vocabularies), rnn_size=128, num_layers=2,
                           batch_size=FLAGS.batch_size, learning_rate=FLAGS.learning_rate)

    saver = tf.train.Saver(tf.global_variables())
    init_op = tf.group(tf.global_variables_initializer(), tf.local_variables_initializer())
    with tf.Session() as sess:
        sess.run(init_op)
Пример #9
0
    def run(self):
        """Producer loop that keeps the shared ``poemList`` topped up.

        Builds the LSTM inference graph, restores the latest checkpoint from
        ``model_dir`` and then loops forever. While ``poemList`` holds fewer
        than 100 entries, a new poem is generated. Each poem is primed from
        ``start_token``, stops at ``end_token`` or after 41 characters, and
        is passed through ``pretty_print_poem``. It is appended while holding
        ``mutex``. Otherwise the loop sleeps for one second.
        """
        batch_size = 1
        begin_word = None
        # The corpus is read from corpus_file, so report that path.
        print('## loading corpus from %s' % corpus_file)
        _, word_int_map, vocabularies = process_poems(corpus_file)

        input_data = tf.placeholder(tf.int32, [batch_size, None])

        end_points = rnn_model(model='lstm',
                               input_data=input_data,
                               output_data=None,
                               vocab_size=len(vocabularies),
                               rnn_size=128,
                               num_layers=2,
                               batch_size=64,
                               learning_rate=lr)

        saver = tf.train.Saver(tf.global_variables())
        init_op = tf.group(tf.global_variables_initializer(),
                           tf.local_variables_initializer())
        with tf.Session() as sess:
            sess.run(init_op)

            checkpoint = tf.train.latest_checkpoint(model_dir)
            saver.restore(sess, checkpoint)

            # Encoded start token. Every poem is primed from it; previously
            # only the first one was, and later poems were primed from the
            # last character fed for the previous poem.
            start_x = np.array([list(map(word_int_map.get, start_token))])
            while True:
                if len(poemList) >= 100:
                    time.sleep(1)
                    continue

                [predict, last_state] = sess.run(
                    [end_points['prediction'], end_points['last_state']],
                    feed_dict={input_data: start_x})
                word = begin_word or to_word(predict, vocabularies)

                poem_ = ''
                i = 0
                while word != end_token:
                    poem_ += word
                    i += 1
                    if i > 40:
                        break
                    x = np.array([[word_int_map[word]]])
                    [predict, last_state] = sess.run(
                        [
                            end_points['prediction'],
                            end_points['last_state']
                        ],
                        feed_dict={
                            input_data: x,
                            end_points['initial_state']: last_state
                        })
                    word = to_word(predict, vocabularies)

                poem = pretty_print_poem(poem_)
                if "" != poem and "                                         。" != poem:
                    mutex.acquire()
                    try:
                        poemList.append(poem)
                    finally:
                        mutex.release()
Пример #10
0
def gen_poem(begin_word):
    """Compose a poem starting from a given first character.

    The network is first fed ``start_token``. Generation then starts from
    ``begin_word`` (or, when it is falsy, from the character predicted after
    ``start_token``) and stops when ``end_token`` is predicted or 24
    characters have been produced.

    Args:
        begin_word: first character of the poem, or a falsy value to let the
            model choose. It is looked up with ``word_int_map[...]``, so an
            unknown character raises ``KeyError``.

    Returns:
        The generated poem as a string, without ``end_token``.
    """
    # Generation runs one sequence at a time.
    batch_size = 1
    # The corpus is read from corpus_file, so report that path.
    print('## loading corpus from %s' % corpus_file)

    # Read the corpus: poems as id sequences (unused here), the
    # character-to-id map and the list of all characters.
    _, word_int_map, vocabularies = process_poems(corpus_file)
    # Input placeholder.
    input_data = tf.placeholder(tf.int32, [batch_size, None])
    # Build the RNN model and get its endpoints.
    end_points = rnn_model(model='lstm', input_data=input_data, output_data=None, vocab_size=len(
        vocabularies), rnn_size=128, num_layers=2, batch_size=64, learning_rate=lr)
    # Set up the saver and the session.
    saver = tf.train.Saver(tf.global_variables())
    init_op = tf.group(tf.global_variables_initializer(), tf.local_variables_initializer())
    with tf.Session() as sess:
        sess.run(init_op)
        # Load the most recent model parameters.
        checkpoint = tf.train.latest_checkpoint(model_dir)

        # NOTE: without a checkpoint this fails; trained parameters are required.
        saver.restore(sess, checkpoint)
        # Use the id(s) of the start token as the initial input.
        x = np.array([list(map(word_int_map.get, start_token))])
        # Get the prediction and the current RNN state.
        [predict, last_state] = sess.run([end_points['prediction'], end_points['last_state']],
                                         feed_dict={input_data: x})
        # Use the caller's first character if given, otherwise the model's
        # first prediction.
        word = begin_word if begin_word else to_word(predict, vocabularies)

        poem_ = ''

        i = 0
        # Keep predicting the next character until the end token appears.
        while word != end_token:
            # Append each predicted character to the result.
            poem_ += word
            i += 1
            if i >= 24:
                break
            # Input is a (1, 1) integer batch holding the current character's id.
            x = np.array([[word_int_map[word]]])

            # Predict from the current character and context state
            # (last_state); returns the prediction and the updated state.
            [predict, last_state] = sess.run([end_points['prediction'], end_points['last_state']],
                                             feed_dict={input_data: x, end_points['initial_state']: last_state})

            # Map the prediction back to a character.
            word = to_word(predict, vocabularies)

        return poem_
Пример #11
0
def run_training():
    """Train the LSTM poem model, resuming from the latest checkpoint if any.

    Trains for epochs ``[start_epoch, FLAGS.epochs)``, printing the loss of
    every batch, where ``start_epoch`` is the epoch number encoded in the
    latest checkpoint under ``FLAGS.model_dir``. A checkpoint is saved every
    6 epochs and when training is interrupted with Ctrl-C.
    """
    # Directory to save the model in. The old test was inverted: it called
    # mkdir only when the directory already existed (which fails) and never
    # created a missing one.
    if not os.path.exists(FLAGS.model_dir):
        os.makedirs(FLAGS.model_dir)

    poems_vector, word_to_int, vocabularies = process_poems(FLAGS.file_path)
    batches_input, batches_outputs = generate_batch(FLAGS.batch_size,
                                                    poems_vector, word_to_int)

    input_data = tf.placeholder(tf.int32, [FLAGS.batch_size, None])
    output_targets = tf.placeholder(tf.int32, [FLAGS.batch_size, None])

    # Use the placeholders' batch size rather than a hard-coded 64.
    end_points = rnn_model(model="lstm",
                           input_data=input_data,
                           output_data=output_targets,
                           vocab_size=len(vocabularies),
                           rnn_size=128,
                           num_layers=2,
                           batch_size=FLAGS.batch_size,
                           learning_rate=FLAGS.learning_rate)

    saver = tf.train.Saver(tf.global_variables())

    init_op = tf.group(tf.global_variables_initializer(),
                       tf.local_variables_initializer())
    checkpoint_path = os.path.join(FLAGS.model_dir, FLAGS.model_prefix)

    with tf.Session() as sess:

        # Initialize all variables.
        sess.run(init_op)

        start_epoch = 0
        checkpoint = tf.train.latest_checkpoint(FLAGS.model_dir)

        if checkpoint:
            saver.restore(sess, checkpoint)
            print("### restore from the checkpoint {0}".format(checkpoint))
            start_epoch += int(checkpoint.split('-')[-1])

        print(' ## start training... ')

        # The number of batches per epoch never changes; compute it once.
        n_chunk = len(poems_vector) // FLAGS.batch_size
        # Bound before the loop so the interrupt handler can always use it.
        epoch = start_epoch
        try:

            for epoch in range(start_epoch, FLAGS.epochs):
                for batch in range(n_chunk):
                    loss, _, _ = sess.run(
                        [
                            end_points['total_loss'], end_points['last_state'],
                            end_points['train_op']
                        ],
                        feed_dict={
                            input_data: batches_input[batch],
                            output_targets: batches_outputs[batch]
                        })
                    print('Epoch: %d, batch: %d, training loss: %.6f' %
                          (epoch, batch, loss))

                # Save every 6th epoch (the old ``epoch % 6`` test was
                # inverted and saved on every other epoch instead).
                if epoch % 6 == 0:
                    saver.save(sess, checkpoint_path, global_step=epoch)
        except KeyboardInterrupt:
            print('## Interrupt manually, try saving checkpoint for now...')
            saver.save(sess, checkpoint_path, global_step=epoch)
            print(
                '## Last epoch were saved, next time will start from epoch {}.'
                .format(epoch))
Пример #12
0
def run_training():
    """Train the paired-line ``RNNModel`` and checkpoint after every epoch.

    Reads the corpus from ``FLAGS.file_path`` and builds the training graph
    in a private ``tf.Graph``. It trains for epochs
    ``[start_epoch, FLAGS.epochs)``, where ``start_epoch`` is one past the
    epoch of the latest checkpoint in ``FLAGS.model_dir``. Summaries are
    written to ``FLAGS.log_path`` every 16 batches, and only the most recent
    checkpoint is kept.
    """
    if not os.path.exists(FLAGS.model_dir):
        os.makedirs(FLAGS.model_dir)

    # poems_vector: 3-D ndarray corpus matrix; each slice is one poem split
    #   into its two lines (2 x ?), each character encoded by its id.
    # word_to_int: pair of dicts mapping characters to ids.
    # vocabularies: pair of lists, vocabulary ordered by descending frequency.
    poems_vector, word_to_int, vocabularies = process_poems(FLAGS.file_path)

    _, _, substr_len = poems_vector.shape
    # Split the corpus matrix into chunks of batch_size.
    # batches_inputs: 4-D ndarray, one chunk per entry, each item 2 x substr_len.
    # batches_outputs: 4-D ndarray, batches_inputs shifted left by one.
    batches_inputs, batches_outputs = generate_batch(FLAGS.batch_size,
                                                     poems_vector, word_to_int)

    graph = tf.Graph()
    with graph.as_default():
        # Placeholders of shape (batch_size, 2, substr_len).
        input_data = tf.placeholder(tf.int32,
                                    [FLAGS.batch_size, 2, substr_len],
                                    name="left_word")
        output_targets = tf.placeholder(tf.int32,
                                        [FLAGS.batch_size, 2, substr_len],
                                        name="right_word")
        add_mat = tf.placeholder(tf.int32, [FLAGS.batch_size, 2, substr_len],
                                 name="additional_feature")
        # Build the model.
        rnn = RNNModel(model_name,
                       num_layers=2,
                       rnn_size=64,
                       batch_size=64,
                       vocabularies=vocabularies,
                       add_dim=add_feature_dim,
                       substr_len=substr_len)
        endpoints = rnn.train(input_data=input_data,
                              add_data=add_mat,
                              label_data=output_targets,
                              learning_rate=FLAGS.learning_rate)
        # Keep only the latest checkpoint file.
        saver = tf.train.Saver(tf.global_variables(), max_to_keep=1)
        init_op = tf.group(tf.global_variables_initializer(),
                           tf.local_variables_initializer())

    # Let TensorFlow grow GPU memory on demand.
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True

    checkpoint_path = os.path.join(FLAGS.model_dir, FLAGS.model_prefix)
    with tf.Session(config=config, graph=graph) as sess:
        sess.run(init_op)
        summary_writer = tf.summary.FileWriter(FLAGS.log_path, graph=graph)
        try:
            start_epoch = 0
            checkpoint = tf.train.latest_checkpoint(FLAGS.model_dir)
            # Clear the console; 'cls' only exists on Windows.
            os.system('cls' if os.name == 'nt' else 'clear')
            if checkpoint:
                saver.restore(sess, checkpoint)
                print("## restore from checkpoint {0}".format(checkpoint))
                # Checkpoints are written after their epoch completes, so
                # resume at the next epoch instead of repeating it.
                start_epoch += int(checkpoint.split('-')[-1]) + 1

            print('## start training...')
            # Point TensorBoard at the directory the writer actually uses.
            print("## run `tensorboard --logdir %s`, and view localhost:6006." %
                  (os.path.abspath(FLAGS.log_path)))
            n_chunk = len(poems_vector) // FLAGS.batch_size
            # Make the graph read-only so accidental op creation in the loop fails.
            graph.finalize()
            # Epoch the next run will start from; stays valid even when the
            # loop below does not execute.
            next_epoch = start_epoch
            for epoch in range(start_epoch, FLAGS.epochs):
                bar = Bar("epoch%d" % epoch, max=n_chunk)
                for batch in range(n_chunk):
                    summary = easyTrain(
                        sess,
                        endpoints,
                        inputs=(input_data, batches_inputs[batch]),
                        label=(output_targets, batches_outputs[batch]),
                        pos_data=(add_mat,
                                  generate_add_mat(batches_inputs[batch],
                                                   'binary')))
                    # Write summaries only every 16 batches to reduce IO.
                    if batch % 16 == 0:
                        summary_writer.add_summary(summary,
                                                   epoch * n_chunk + batch)
                        bar.next(16)
                # Save at the end of each epoch.
                saver.save(sess, checkpoint_path, global_step=epoch)
                bar.finish()
                next_epoch = epoch + 1
        finally:
            summary_writer.close()
        # Every finished epoch was already checkpointed inside the loop.
        print('## Last epoch were saved, next time will start from epoch {}.'.
              format(next_epoch))
Пример #13
0
def gen_poem():
    """Interactively generate poems from user-supplied first characters.

    Restores the latest checkpoint from ``model_dir``, then repeatedly
    prompts for a first character. Empty input or a character not in the
    vocabulary is replaced by a random vocabulary entry. Each poem consists
    of the first entry followed by ``FLAGS.gen_sequence_len - 1`` predicted
    entries and is printed. The loop ends on end-of-input.
    """
    batch_size = 1

    _, word_int_map, vocabularies = process_poems(corpus_path)

    input_data = tf.placeholder(tf.int32, [batch_size, None])

    end_points = rnn_model(model='lstm',
                           input_data=input_data,
                           output_data=None,
                           vocab_size=len(vocabularies),
                           rnn_size=128,
                           num_layers=2,
                           batch_size=64,
                           learning_rate=lr)

    saver = tf.train.Saver(tf.global_variables())
    init_op = tf.group(tf.global_variables_initializer(),
                       tf.local_variables_initializer())
    with tf.Session() as sess:
        sess.run(init_op)

        checkpoint = tf.train.latest_checkpoint(model_dir)
        # This is the model checkpoint, not the corpus.
        print('## loading model from %s' % checkpoint)
        saver.restore(sess, checkpoint)

        while True:
            try:
                begin_word = input('## please input the first character:')
            except EOFError:
                # End of input (Ctrl-D / Ctrl-Z): leave the loop cleanly.
                return
            if begin_word and begin_word in vocabularies:
                word = begin_word
            else:
                word = vocabularies[np.random.randint(len(vocabularies))]
                print(
                    '## begin word not in vocabularies, use random begin word:'
                    + word)

            poem_ = word

            x = np.array([list(map(word_int_map.get, word))])
            [predict, last_state
             ] = sess.run([end_points['prediction'], end_points['last_state']],
                          feed_dict={input_data: x})
            # Second character of the poem.
            word = to_word(predict, vocabularies)

            i = 1
            while i < FLAGS.gen_sequence_len:
                poem_ += word
                i += 1
                # Feed the latest character as a (1, 1) integer batch and
                # carry the RNN state forward.
                x = np.array([[word_int_map[word]]])
                [predict, last_state] = sess.run(
                    [end_points['prediction'], end_points['last_state']],
                    feed_dict={
                        input_data: x,
                        end_points['initial_state']: last_state
                    })
                word = to_word(predict, vocabularies)

            print(poem_, flush=True)