Exemplo n.º 1
0
#-*- coding:utf-8 -*-
from generate_poetry import Poetry
from poetry_model import poetryModel
import tensorflow as tf
import numpy as np

if __name__ == '__main__':
    # Script entry point: builds a TF1 graph from poetryModel (model, loss,
    # optimizer, saver) and then loops over batches from Poetry. Presumably
    # a training script -- the loop body continues beyond this view.
    batch_size = 50
    # NOTE(review): `epoch` is not used in the visible lines -- confirm it is
    # consumed later (e.g. for learning-rate decay or a stop condition).
    epoch = 20
    rnn_size = 128
    num_layers = 2
    poetrys = Poetry()
    # Vocabulary size, taken from the dataset's word -> id mapping.
    words_size = len(poetrys.word_to_id)
    # Integer id batches: first dimension fixed to batch_size, second
    # dimension left variable (None), i.e. any sequence length.
    inputs = tf.placeholder(tf.int32, [batch_size, None])
    targets = tf.placeholder(tf.int32, [batch_size, None])
    # Float scalar fed at run time; presumably the dropout keep probability
    # -- confirm against poetryModel.create_model.
    keep_prob = tf.placeholder(tf.float32, name='keep_prob')
    model = poetryModel()
    # NOTE(review): the positional `True` looks like an is-training flag --
    # confirm against poetryModel.create_model's signature.
    logits, probs, initial_state, last_state = model.create_model(
        inputs, batch_size, rnn_size, words_size, num_layers, True, keep_prob)
    loss = model.loss_model(words_size, targets, logits)
    # Non-trainable variable so the learning rate can be changed from the
    # session via tf.assign without rebuilding the optimizer.
    learning_rate = tf.Variable(0.0, trainable=False)
    optimizer = model.optimizer_model(loss, learning_rate)
    saver = tf.train.Saver()
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        # NOTE(review): this assigns the constant 0.002 * 0.97 exactly once.
        # An exponential schedule such as 0.002 * (0.97 ** epoch) may have
        # been intended -- confirm.
        sess.run(tf.assign(learning_rate, 0.002 * 0.97))
        # Evaluate the model's initial state; presumably fed back in as the
        # recurrent state for the first batch -- confirm in the loop body.
        next_state = sess.run(initial_state)
        step = 0
        # No exit condition is visible here; presumably the loop breaks or
        # returns further down.
        while True:
            x_batch, y_batch = poetrys.next_batch(batch_size)
            feed = {
Exemplo n.º 2
0
#-*- coding:utf-8 -*-
from generate_poetry import Poetry
from poetry_model import poetryModel
from operator import itemgetter
import tensorflow as tf
import numpy as np
import random

if __name__ == '__main__':
    # Script entry point for the second example: configures the model and
    # defines a sampling helper (to_word); the rest continues beyond this view.
    # One sequence per step.
    batch_size = 1
    # NOTE(review): presumably these must equal the values used when the
    # checkpoint was trained so saved weights can be restored -- confirm.
    rnn_size = 128
    num_layers = 2
    poetrys = Poetry()
    # Vocabulary size, taken from the dataset's word -> id mapping.
    words_size = len(poetrys.word_to_id)

    def to_word(prob, top_k=10, threshold=0.9):
        """Pick a word from one step of the model's output probabilities.

        Only the first row, ``prob[0]``, is used. If the most likely id has
        probability strictly greater than ``threshold``, that id is chosen
        greedily. Otherwise one of the ``top_k`` most likely ids is chosen
        uniformly at random.

        Args:
            prob: indexable whose first element is a 1-D sequence of per-id
                probabilities (presumably the model's ``probs`` output for a
                single step -- TODO confirm shape).
            top_k: number of highest-probability ids to sample from when the
                top id is not confident enough. Clamped to the vocabulary size.
            threshold: probability above which the top id is taken greedily.

        Returns:
            The word for the chosen id, looked up in ``poetrys.id_to_word``.
        """
        row = np.asarray(prob[0])
        # Ascending order with stable tie-breaking: the same ordering that
        # sorted(enumerate(row), key=itemgetter(1)) produces, computed in C.
        order = np.argsort(row, kind='stable')
        vocab_size = len(order)
        # Clamp so a vocabulary smaller than top_k cannot yield a negative
        # index (which would silently wrap to the low-probability end).
        k = min(top_k, vocab_size)
        # Draw a scalar: int() on a 1-element array is deprecated since
        # NumPy 1.25. The draw happens even on the greedy path, so the global
        # RNG stream is consumed exactly as before.
        rand_num = int(np.random.rand() * k)
        best = int(order[-1])
        if row[best] > threshold:
            sample = best
        else:
            sample = int(order[vocab_size - 1 - rand_num])
        return poetrys.id_to_word[sample]

    inputs = tf.placeholder(tf.int32, [batch_size, None])
    keep_prob = tf.placeholder(tf.float32, name='keep_prob')
    model = poetryModel()
    logits, probs, initial_state, last_state = model.create_model(