Code example #1
File: tang_poems.py    Project: liaozhihui/lzh
def gen_poem(begin_word):
    batch_size = 1
    print('[INFO] loading corpus from %s' % FLAGS.file_path)
    poems_vector, word_int_map, vocabularies = process_poems(FLAGS.file_path)
    input_data = tf.placeholder(tf.int32, [batch_size, None])

    # No output targets: the graph is only used for prediction here.
    end_points = rnn_model(model='lstm',
                           input_data=input_data,
                           output_data=None,
                           vocab_size=len(vocabularies),
                           rnn_size=128,
                           num_layers=2,
                           batch_size=64,
                           learning_rate=FLAGS.learning_rate)

    saver = tf.train.Saver(tf.global_variables())
    init_op = tf.group(tf.global_variables_initializer(),
                       tf.local_variables_initializer())
    with tf.Session() as sess:
        sess.run(init_op)
        # Restore the latest trained weights.
        checkpoint = tf.train.latest_checkpoint(FLAGS.checkpoints_dir)
        saver.restore(sess, checkpoint)

        # Feed the start token to obtain the initial prediction and state.
        x = np.array([list(map(word_int_map.get, start_token))])
        [predict, last_state] = sess.run(
            [end_points['prediction'], end_points['last_state']],
            feed_dict={input_data: x})

        if begin_word:
            word = begin_word
        else:
            word = to_word(predict, vocabularies)

        # Generate one character at a time until the end token is produced.
        poem = ''
        while word != end_token:
            poem += word
            x = np.zeros((1, 1), dtype=np.int32)
            x[0, 0] = word_int_map[word]
            [predict, last_state] = sess.run(
                [end_points['prediction'], end_points['last_state']],
                feed_dict={input_data: x,
                           end_points['initial_state']: last_state})
            word = to_word(predict, vocabularies)

        return poem
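
For reference, a minimal driver for gen_poem() might look like the sketch below. It is an assumption for illustration, not part of the tang_poems.py file shown here, and it presumes the FLAGS referenced in the example are already defined in the same module.

if __name__ == '__main__':
    # Ask for the first character of the poem and print the result,
    # one sentence per line (split on the Chinese full stop).
    begin_word = input('Enter the first character of the poem: ')
    poem = gen_poem(begin_word)
    for sentence in poem.split('。'):
        if sentence:
            print(sentence + '。')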
Code example #2
File: tang_poems.py    Project: liaozhihui/lzh
def run_training():
    # Create the checkpoint directory (and its parent) if missing.
    if not os.path.exists(os.path.dirname(FLAGS.checkpoints_dir)):
        os.mkdir(os.path.dirname(FLAGS.checkpoints_dir))
    if not os.path.exists(FLAGS.checkpoints_dir):
        os.mkdir(FLAGS.checkpoints_dir)

    poems_vector, word_to_int, vocabularies = process_poems(FLAGS.file_path)
    batch_inputs, batch_outputs = generate_batch(FLAGS.batch_size,
                                                 poems_vector, word_to_int)

    input_data = tf.placeholder(tf.int32, [FLAGS.batch_size, None])
    output_targets = tf.placeholder(tf.int32, [FLAGS.batch_size, None])

    # NOTE: batch_size is hard-coded to 64 here, so it only matches the
    # placeholders above when FLAGS.batch_size == 64.
    end_points = rnn_model(model='lstm',
                           input_data=input_data,
                           output_data=output_targets,
                           vocab_size=len(vocabularies),
                           rnn_size=128,
                           num_layers=2,
                           batch_size=64,
                           learning_rate=FLAGS.learning_rate)

    saver = tf.train.Saver(tf.global_variables())
    init_op = tf.group(tf.global_variables_initializer(),
                       tf.local_variables_initializer())
    with tf.Session() as sess:
        sess.run(init_op)
        start_epoch = 0
        # Resume from the latest checkpoint if one exists.
        checkpoint = tf.train.latest_checkpoint(FLAGS.checkpoints_dir)
        if checkpoint:
            saver.restore(sess, checkpoint)
            print("[INFO] restore from the checkpoint {0}".format(checkpoint))
            start_epoch += int(checkpoint.split('-')[-1])
        print("[INFO] start training...")

        try:
            for epoch in range(start_epoch, FLAGS.epoches):
                n = 0
                n_chunk = len(poems_vector) // FLAGS.batch_size
                for batch in range(n_chunk):
                    loss, _, _ = sess.run(
                        [end_points['total_loss'],
                         end_points['last_state'],
                         end_points['train_op']],
                        feed_dict={input_data: batch_inputs[n],
                                   output_targets: batch_outputs[n]})
                    n += 1
                    print("[INFO] Epoch: %d, batch: %d, training loss: %.6f" %
                          (epoch, batch, loss))
                # Save a checkpoint every 6 epochs.
                if epoch % 6 == 0:
                    saver.save(sess,
                               os.path.join(FLAGS.checkpoints_dir,
                                            FLAGS.model_prefix),
                               global_step=epoch)
        except KeyboardInterrupt:
            print("[INFO] Interrupt manually, try saving checkpoint for now...")
            saver.save(sess,
                       os.path.join(FLAGS.checkpoints_dir, FLAGS.model_prefix),
                       global_step=epoch)
            print("[INFO] Last epoch was saved, next time will start "
                  "from epoch {}".format(epoch))
Code example #3
def run_training():
    # Create the model directory if it does not exist.
    if not os.path.exists(FLAGS.model_dir):
        os.makedirs(FLAGS.model_dir)

    # process_poems preprocesses the poem corpus.
    poems_vector, word_to_int, vocabularies = process_poems(FLAGS.file_path)

    # Placeholders for the input and target sequences.
    input_data = tf.placeholder(tf.int32, [FLAGS.batch_size, None])
    output_targets = tf.placeholder(tf.int32, [FLAGS.batch_size, None])

    # Build the LSTM training graph.
    end_points = rnn_model(model='lstm',
                           input_data=input_data,
                           output_data=output_targets,
                           vocab_size=len(vocabularies),
                           rnn_size=128,
                           num_layers=2,
                           batch_size=FLAGS.batch_size,
                           learning_rate=FLAGS.learning_rate)

    saver = tf.train.Saver(tf.global_variables())
    init_op = tf.group(tf.global_variables_initializer(),
                       tf.local_variables_initializer())
    with tf.Session() as sess:
        # sess = tf_debug.LocalCLIDebugWrapperSession(sess=sess)
        # sess.add_tensor_filter("has_inf_or_nan", tf_debug.has_inf_or_nan)
        sess.run(init_op)
        start_epoch = 0
        checkpoint = tf.train.latest_checkpoint(FLAGS.model_dir)
        # If the model was trained before, resume from the saved checkpoint.
        if checkpoint:
            saver.restore(sess, checkpoint)
            print("## restore from the checkpoint {0}".format(checkpoint))
            start_epoch += int(checkpoint.split('-')[-1])
        print('## start training...')
        try:
            n_chunk = len(poems_vector) // FLAGS.batch_size
            for epoch in range(start_epoch, FLAGS.epochs):
                # Shuffle the corpus once per epoch, then rebuild the batches.
                random.shuffle(poems_vector)
                batches_inputs, batches_outputs = generate_batch(
                    FLAGS.batch_size, poems_vector, word_to_int)
                n = 0
                for batch in range(n_chunk):
                    loss, _, _ = sess.run(
                        [
                            end_points['total_loss'], end_points['last_state'],
                            end_points['train_op']
                        ],
                        feed_dict={
                            input_data: batches_inputs[n],
                            output_targets: batches_outputs[n]
                        })
                    n += 1
                    print('Epoch: %d/%d, batch: %d/%d, training loss: %.6f' %
                          (epoch, FLAGS.epochs - 1, batch, n_chunk - 1, loss))
                if epoch % 6 == 0:
                    saver.save(sess,
                               os.path.join(FLAGS.model_dir,
                                            FLAGS.model_prefix),
                               global_step=epoch)
                # start_time is assumed to be set at module level before training.
                alltime = time.time() - start_time
                print("Time: %d h %d min %d s" %
                      (alltime // 3600, (alltime % 3600) // 60, alltime % 60))
        except KeyboardInterrupt:
            print('## Interrupt manually, try saving checkpoint for now...')
            saver.save(sess,
                       os.path.join(FLAGS.model_dir, FLAGS.model_prefix),
                       global_step=epoch)
            print('## Last epoch was saved, next time will start from '
                  'epoch {}.'.format(epoch))
Code example #4
File: main.py    Project: SuperrrWu/deep-learning
def run_training():
    if not os.path.exists(os.path.dirname(FLAGS.checkpoints_dir)):
        os.mkdir(os.path.dirname(FLAGS.checkpoints_dir))
    if not os.path.exists(FLAGS.checkpoints_dir):
        os.mkdir(FLAGS.checkpoints_dir)
    # Poems as integer vectors, the word-to-int mapping, and the vocabulary.
    poems_vector, word_to_int, vocabularies = process_poems(FLAGS.file_path)
    # Input batches and their shifted target batches.
    batches_inputs, batches_outputs = generate_batch(FLAGS.batch_size,
                                                     poems_vector, word_to_int)
    # Placeholders for the batched data.
    input_data = tf.placeholder(tf.int32, [FLAGS.batch_size, None])
    output_targets = tf.placeholder(tf.int32, [FLAGS.batch_size, None])

    # NOTE: batch_size is hard-coded to 64; this only matches the placeholders
    # above when FLAGS.batch_size == 64.
    end_points = rnn_model(model='lstm',
                           input_data=input_data,
                           output_data=output_targets,
                           vocab_size=len(vocabularies),
                           rnn_size=128,
                           num_layers=2,
                           batch_size=64,
                           learning_rate=FLAGS.learning_rate)
    # Saver for model checkpoints.
    saver = tf.train.Saver(tf.global_variables())
    # Initialize global and local variables.
    init_op = tf.group(tf.global_variables_initializer(),
                       tf.local_variables_initializer())
    with tf.Session() as sess:
        # sess = tf_debug.LocalCLIDebugWrapperSession(sess=sess)
        # sess.add_tensor_filter("has_inf_or_nan", tf_debug.has_inf_or_nan)
        # Run variable initialization first.
        sess.run(init_op)

        start_epoch = 0
        # Pick up the latest checkpoint from a previous run, if any.
        checkpoint = tf.train.latest_checkpoint(FLAGS.checkpoints_dir)
        if checkpoint:
            # Restore the previously saved model.
            saver.restore(sess, checkpoint)
            print("[INFO] restore from the checkpoint {0}".format(checkpoint))
            start_epoch += int(checkpoint.split('-')[-1])
        print('[INFO] start training...')
        try:
            for epoch in range(start_epoch, FLAGS.epochs):
                n = 0
                # Number of poems // batch size = batches per epoch.
                n_chunk = len(poems_vector) // FLAGS.batch_size
                for batch in range(n_chunk):
                    loss, _, _ = sess.run(
                        [
                            end_points['total_loss'],   # training loss
                            end_points['last_state'],   # final RNN state
                            end_points['train_op']      # optimization step
                        ],
                        feed_dict={
                            input_data: batches_inputs[n],
                            output_targets: batches_outputs[n]
                        })
                    n += 1
                    print(
                        '[INFO] Epoch: %d , batch: %d , training loss: %.6f' %
                        (epoch, batch, loss))
                if epoch % 6 == 0:  # save a checkpoint every 6 epochs
                    # Note: the directory path itself is used as the checkpoint
                    # prefix here (no model_prefix is joined onto it).
                    saver.save(sess, FLAGS.checkpoints_dir, global_step=epoch)
        except KeyboardInterrupt:
            print(
                '[INFO] Interrupt manually, try saving checkpoint for now...')
            saver.save(sess, FLAGS.checkpoints_dir, global_step=epoch)
            print(
                '[INFO] Last epoch was saved, next time will start from epoch {}.'
                .format(epoch))
Code example #5
File: main.py    Project: SuperrrWu/deep-learning
def gen_poem(begin_words, num):
    batch_size = 1
    print('[INFO] loading corpus from %s' % FLAGS.file_path)
    # Poems as integer vectors, the word-to-int mapping, and the vocabulary.
    poems_vector, word_int_map, vocabularies = process_poems(FLAGS.file_path)
    # Only one sequence is fed at a time during generation.
    input_data = tf.placeholder(tf.int32, [batch_size, None])
    # Build the inference graph (no output targets).
    end_points = rnn_model(model='lstm',
                           input_data=input_data,
                           output_data=None,
                           vocab_size=len(vocabularies),
                           rnn_size=128,
                           num_layers=2,
                           batch_size=64,
                           learning_rate=FLAGS.learning_rate)

    saver = tf.train.Saver(tf.global_variables())
    init_op = tf.group(tf.global_variables_initializer(),
                       tf.local_variables_initializer())
    with tf.Session() as sess:
        sess.run(init_op)
        # Restore the latest saved model.
        checkpoint = tf.train.latest_checkpoint(FLAGS.checkpoints_dir)
        saver.restore(sess, checkpoint)

        # Start from the start token looked up in the mapping.
        x = np.array([list(map(word_int_map.get, start_token))])

        [predict, last_state] = sess.run(
            [end_points['prediction'], end_points['last_state']],
            feed_dict={input_data: x})
        poem = ''
        for begin_word in begin_words:
            while True:
                if begin_word:
                    word = begin_word
                else:
                    word = to_word(predict, vocabularies)
                sentence = ''
                while word != end_token:
                    sentence += word
                    x = np.zeros((1, 1), dtype=np.int32)
                    x[0, 0] = word_int_map[word]
                    [predict, last_state] = sess.run(
                        [end_points['prediction'], end_points['last_state']],
                        feed_dict={
                            input_data: x,
                            end_points['initial_state']: last_state
                        })
                    word = to_word(predict, vocabularies)
                # Accept the sentence only if it has the expected length, a
                # single comma exactly at position num, and no stray
                # punctuation or unknown characters; otherwise try again.
                if (len(sentence) == 2 + 2 * num
                        and ',' not in sentence[:num]
                        and '?' not in sentence[:num]
                        and ',' not in sentence[num + 1:-1]
                        and '?' not in sentence[num + 1:-1]
                        and sentence[num] == ','
                        and '□' not in sentence):
                    poem += sentence
                    break
                else:
                    print("Still composing this line...")

        return poem
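
The acceptance test at the end of the loop above is easier to read when pulled out into a small predicate. The sketch below restates exactly the same checks (expected length, a single comma at position num, no stray punctuation, no unknown '□' characters); the function name is illustrative and not part of the project.

def is_valid_sentence(sentence, num):
    # Same checks as in gen_poem above, written out explicitly.
    if len(sentence) != 2 + 2 * num:
        return False
    if sentence[num] != ',':
        return False
    for ch in (',', '?'):
        if ch in sentence[:num] or ch in sentence[num + 1:-1]:
            return False
    return '□' not in sentence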
Code example #6
File: WritePoem.py    Project: Jack1203187498/PoemLstm
def gen_poem(begin_word):
    batch_size = 1
    print('## loading corpus from %s' % model_dir)
    poems_vector, word_int_map, vocabularies = process_poems(corpus_file)

    input_data = tf.placeholder(tf.int32, [batch_size, None])

    end_points = rnn_model(model='lstm',
                           input_data=input_data,
                           output_data=None,
                           vocab_size=len(vocabularies),
                           rnn_size=128,
                           num_layers=2,
                           batch_size=64,
                           learning_rate=lr)
    saver = tf.train.Saver(tf.global_variables())
    init_op = tf.group(tf.global_variables_initializer(),
                       tf.local_variables_initializer())
    with tf.Session() as sess:
        sess.run(init_op)

        # Restore the latest checkpoint from model_dir.
        checkpoint = tf.train.latest_checkpoint(model_dir)
        saver.restore(sess, checkpoint)

        # Feed the start token to get the first prediction and state.
        x = np.array([list(map(word_int_map.get, start_token))])
        [predict, last_state] = sess.run(
            [end_points['prediction'], end_points['last_state']],
            feed_dict={input_data: x})

        poem_ = ''
        word = begin_word or to_word(predict, vocabularies, poem=poem_)

        print("Start composing the poem")
        poem_ = ''
        # Generate word by word until the end (or start) token appears.
        while word != end_token and word != start_token:
            poem_ += word
            print(poem_)
            x = np.array([[word_int_map[word]]])
            [predict, last_state] = sess.run(
                [end_points['prediction'], end_points['last_state']],
                feed_dict={
                    input_data: x,
                    end_points['initial_state']: last_state
                })
            # Re-seed before sampling the next word.
            time.sleep(1)
            random.seed(time.time())
            word = to_word(predict, vocabularies, poem_)
        print('\n')
        return poem_
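
All of the generation examples rely on a to_word() helper, defined elsewhere in each project, to turn the prediction vector into the next character (note that its signature varies: some projects also pass the partial poem). A common way to implement such a helper is to sample from the predicted distribution rather than take the argmax, which keeps the output varied. The following is only a sketch of that idea under those assumptions, not any project's actual code.

import numpy as np

def to_word_sketch(predict, vocabs):
    # Illustrative sketch: sample the next character from the predicted
    # probability distribution over the vocabulary.
    probs = np.asarray(predict[0], dtype=np.float64)
    probs /= probs.sum()                     # normalize to probabilities
    cumulative = np.cumsum(probs)
    sample = np.searchsorted(cumulative, np.random.rand())
    return vocabs[min(sample, len(vocabs) - 1)]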
Code example #7
File: train.py    Project: liuxinghui01/wolf-ai
def run_training():
    if not os.path.exists(FLAGS.model_dir):
        os.makedirs(FLAGS.model_dir)

    poems_vector, word_to_int, vocabularies = process_poems(FLAGS.file_path)
    batches_inputs, batches_outputs = generate_batch(FLAGS.batch_size,
                                                     poems_vector, word_to_int)

    input_data = tf.placeholder(tf.int32, [FLAGS.batch_size, None])
    output_targets = tf.placeholder(tf.int32, [FLAGS.batch_size, None])

    end_points = rnn_model(model='lstm',
                           input_data=input_data,
                           output_data=output_targets,
                           vocab_size=len(vocabularies),
                           rnn_size=128,
                           num_layers=2,
                           batch_size=64,
                           learning_rate=FLAGS.learning_rate)

    saver = tf.train.Saver(tf.global_variables())
    init_op = tf.group(tf.global_variables_initializer(),
                       tf.local_variables_initializer())
    with tf.Session() as sess:
        # sess = tf_debug.LocalCLIDebugWrapperSession(sess=sess)
        # sess.add_tensor_filter("has_inf_or_nan", tf_debug.has_inf_or_nan)
        sess.run(init_op)

        start_epoch = 0
        checkpoint = tf.train.latest_checkpoint(FLAGS.model_dir)
        if checkpoint:
            saver.restore(sess, checkpoint)
            print("## restore from the checkpoint {0}".format(checkpoint))
            start_epoch += int(checkpoint.split('-')[-1])
        print('## start training...')
        try:
            for epoch in range(start_epoch, FLAGS.epochs):
                n = 0
                n_chunk = len(poems_vector) // FLAGS.batch_size
                for batch in range(n_chunk):
                    loss, _, _ = sess.run(
                        [
                            end_points['total_loss'], end_points['last_state'],
                            end_points['train_op']
                        ],
                        feed_dict={
                            input_data: batches_inputs[n],
                            output_targets: batches_outputs[n]
                        })
                    n += 1
                    print('Epoch: %d, batch: %d, training loss: %.6f' %
                          (epoch, batch, loss))
                if epoch % 6 == 0:
                    saver.save(sess,
                               os.path.join(FLAGS.model_dir,
                                            FLAGS.model_prefix),
                               global_step=epoch)
        except KeyboardInterrupt:
            print('## Interrupt manually, try saving checkpoint for now...')
            saver.save(sess,
                       os.path.join(FLAGS.model_dir, FLAGS.model_prefix),
                       global_step=epoch)
            print(
                '## Last epoch was saved, next time will start from epoch {}.'
                .format(epoch))
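
All of the examples above use the TensorFlow 1.x graph API (tf.placeholder, tf.Session, tf.train.Saver). If you want to run them under a TensorFlow 2.x installation, the usual workaround is to import the v1 compatibility module and disable v2 behavior before building the graph, roughly as follows:

# Compatibility shim for running TF 1.x-style code under TensorFlow 2.x.
import tensorflow.compat.v1 as tf
tf.disable_v2_behavior()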