Example #1
def gen_name(begin_word):
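    """Sample a name from the trained char-rnn one character at a time, optionally seeded with begin_word."""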
    batch_size = 1
    print('## loading corpus from %s' % corpus_file)
    names_vector, word_int_map, vocabularies = build_dataset(corpus_file)

    input_data = tf.placeholder(tf.int32, [batch_size, None])

    end_points = char_rnn(model='lstm',
                          input_data=input_data,
                          output_data=None,
                          vocab_size=len(vocabularies),
                          rnn_size=128,
                          num_layers=2,
                          batch_size=batch_size,  # the inference graph runs one sequence at a time
                          learning_rate=lr)

    saver = tf.train.Saver(tf.global_variables())
    init_op = tf.group(tf.global_variables_initializer(),
                       tf.local_variables_initializer())
    with tf.Session() as sess:
        sess.run(init_op)

        checkpoint = tf.train.latest_checkpoint(model_dir)
        saver.restore(sess, checkpoint)

        # start_token is a module-level constant (the begin-of-name marker),
        # presumably defined alongside end_token = 'E' (see Example #2)
        x = np.array([list(map(word_int_map.get, start_token))])

        predict, last_state = sess.run(
            [end_points['prediction'], end_points['last_state']],
            feed_dict={input_data: x})
        if begin_word:
            word = begin_word
        else:
            word = to_word(predict, vocabularies)
        name_ = ''

        i = 0
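        # feed each sampled character back into the network until the end
        # token appears, capping the name at 24 characters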
        while word != end_token:
            name_ += word
            i += 1
            if i >= 24:
                break
            x = np.zeros((1, 1))
            x[0, 0] = word_int_map[word]
            predict, last_state = sess.run(
                [end_points['prediction'], end_points['last_state']],
                feed_dict={
                    input_data: x,
                    end_points['initial_state']: last_state
                })
            word = to_word(predict, vocabularies)

        return name_
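
A minimal call sketch, assuming the module-level constants this function reads (start_token, end_token, model_dir, corpus_file, lr) are defined as in Example #2:

name = gen_name(None)  # or e.g. gen_name('张') to seed the first character
print(name)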
Example #2
end_token = 'E'
model_dir = 'result/name'
corpus_file = 'data/names.txt'

lr = 0.0002

batch_size = 1
print('## loading corpus from %s' % corpus_file)
names_vector, word_int_map, vocabularies = build_name_dataset(corpus_file)

input_data = tf.placeholder(tf.int32, [batch_size, None])

end_points = char_rnn(model='lstm',
                      input_data=input_data,
                      output_data=None,
                      vocab_size=len(vocabularies),
                      rnn_size=256,
                      num_layers=3,
                      batch_size=batch_size,  # the inference graph runs one sequence at a time
                      learning_rate=lr)


def to_word(predict, vocabs):
    predict = predict[0]
    predict /= np.sum(predict)
    sample = np.random.choice(np.arange(len(predict)), p=predict)
    if sample >= len(vocabs):
        # clamp an out-of-range sample index to the last vocabulary entry
        return vocabs[-1]
    else:
        return vocabs[sample]
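
A toy check of to_word's sampling behavior (hypothetical inputs; the real predict row is the model's softmax output):

probe = np.array([[0.1, 0.2, 0.7]])     # shape (1, vocab_size)
print(to_word(probe, ['a', 'b', 'E']))  # prints 'E' about 70% of the time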

Example #3
def train():

    # create the directory where results will be saved
    if not os.path.exists(FLAGS.result_dir):
        os.makedirs(FLAGS.result_dir)
    if FLAGS.model_prefix == 'poems':
        poems_vector, word_to_int, vocabularies = build_dataset(
            FLAGS.file_path)
    elif FLAGS.model_prefix == 'names':
        poems_vector, word_to_int, vocabularies = build_name_dataset(
            FLAGS.file_path)

    batches_inputs, batches_outputs = generate_batch(FLAGS.batch_size,
                                                     poems_vector, word_to_int)

    input_data = tf.placeholder(tf.int32, [FLAGS.batch_size, None])
    output_targets = tf.placeholder(tf.int32, [FLAGS.batch_size, None])

    end_points = char_rnn(model='lstm',
                          input_data=input_data,
                          output_data=output_targets,
                          vocab_size=len(vocabularies),
                          rnn_size=128,
                          num_layers=2,
                          batch_size=FLAGS.batch_size,
                          learning_rate=FLAGS.learning_rate)

    saver = tf.train.Saver(tf.global_variables())
    init_op = tf.group(tf.global_variables_initializer(),
                       tf.local_variables_initializer())
    with tf.Session() as sess:
        sess.run(init_op)

        start_epoch = 0
        checkpoint = tf.train.latest_checkpoint(FLAGS.result_dir)
        if checkpoint:  # resume training from where the last run left off
            saver.restore(sess, checkpoint)
            print("## restore from the checkpoint {0}".format(checkpoint))
            start_epoch += int(checkpoint.split('-')[-1])
        print('## start training...')
        epoch = start_epoch  # ensures the except-handler can save even before the first epoch completes
        try:
            for epoch in range(start_epoch, FLAGS.epochs):
                n = 0
                n_chunk = len(poems_vector) // FLAGS.batch_size
                for batch in range(n_chunk):
                    loss, _, _ = sess.run(
                        [
                            end_points['total_loss'], end_points['last_state'],
                            end_points['train_op']
                        ],
                        feed_dict={
                            input_data: batches_inputs[n],
                            output_targets: batches_outputs[n]
                        })
                    n += 1
                    print('Epoch: %d, batch: %d, training loss: %.6f' %
                          (epoch, batch, loss))
                if epoch % 10 == 0:
                    saver.save(sess,
                               os.path.join(FLAGS.result_dir,
                                            FLAGS.model_prefix),
                               global_step=epoch)
        except KeyboardInterrupt:
            print('## Interrupted manually; trying to save a checkpoint now...')
            saver.save(sess,
                       os.path.join(FLAGS.result_dir, FLAGS.model_prefix),
                       global_step=epoch)
            print(
                '## Last epoch was saved; next run will resume from epoch {}.'
                .format(epoch))
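
generate_batch is referenced above but not shown in these examples; below is a minimal sketch of the usual scheme for this kind of char-rnn trainer (an assumption about its behavior, not the repository's actual code):

import numpy as np

def generate_batch(batch_size, vectors, word_to_int):
    # Split the corpus into fixed-size batches; pad each batch to its longest
    # sequence and use the inputs shifted left by one step as the targets.
    n_chunk = len(vectors) // batch_size
    x_batches, y_batches = [], []
    for i in range(n_chunk):
        chunk = vectors[i * batch_size:(i + 1) * batch_size]
        length = max(len(v) for v in chunk)
        # assumes ' ' is in the vocabulary and acts as the pad token
        x = np.full((batch_size, length), word_to_int[' '], np.int32)
        for row, v in enumerate(chunk):
            x[row, :len(v)] = v
        y = np.copy(x)
        y[:, :-1] = x[:, 1:]  # next-character prediction targets
        x_batches.append(x)
        y_batches.append(y)
    return x_batches, y_batches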
Example #4
import tensorflow.compat.v1 as tf
from model import char_rnn, FLAGS
from utils import build_dataset
import numpy as np
from tensorflow.python.saved_model import tag_constants
from tensorflow.python.saved_model import signature_constants
tf.disable_eager_execution()

FLAG = FLAGS()
poems_vector, word_int_map, vocabularies = build_dataset(
    FLAG.poems_path, FLAG.name_path)
input_data = tf.placeholder(tf.int32, [1, None])
end_points = char_rnn(model='lstm',
                      input_data=input_data,
                      output_data=None,
                      vocab_size=len(vocabularies),
                      rnn_size=FLAG.rnn_size,
                      num_layers=FLAG.num_layers,
                      batch_size=FLAG.batch_size,
                      learning_rate=FLAG.learning_rate)


def to_word(predict, vocabs):
    predict = predict[0]
    predict /= np.sum(predict)
    sample = np.random.choice(np.arange(len(predict)), p=predict)
    if sample >= len(vocabs):
        # clamp an out-of-range sample index to the last vocabulary entry
        return vocabs[-1]
    else:
        return vocabs[sample]

Example #5
def train():
    FLAG = FLAGS()
    poems_vector, word_to_int, vocabularies = build_dataset(
        FLAG.poems_path, FLAG.name_path)

    batches_inputs, batches_outputs = generate_batch(FLAG.batch_size,
                                                     poems_vector, word_to_int)

    input_data = tf.placeholder(tf.int32, [FLAG.batch_size, None],
                                name="Input")
    output_targets = tf.placeholder(tf.int32, [FLAG.batch_size, None])
    #z = tf.log(output_targets, name="namemodel")
    end_points = char_rnn(model='lstm',
                          input_data=input_data,
                          output_data=output_targets,
                          vocab_size=len(vocabularies),
                          rnn_size=FLAG.rnn_size,
                          num_layers=FLAG.num_layers,
                          batch_size=FLAG.batch_size,
                          learning_rate=FLAG.learning_rate)
    saver = tf.train.Saver(tf.global_variables())
    init_op = tf.group(tf.global_variables_initializer(),
                       tf.local_variables_initializer())
    with tf.Session() as sess:
        sess.run(init_op)

        start_epoch = 0
        checkpoint = tf.train.latest_checkpoint(FLAG.result_dir)
        if checkpoint:  # resume training from where the last run left off
            saver.restore(sess, checkpoint)
            print("## restore from the checkpoint {0}".format(checkpoint))
            start_epoch += int(checkpoint.split('-')[-1])
        print('## start training...')
        epoch = start_epoch  # ensures the except-handler can save even before the first epoch completes

        try:
            for epoch in range(start_epoch, FLAG.epochs):
                n = 0
                n_chunk = len(poems_vector) // FLAG.batch_size
                for batch in range(n_chunk):
                    loss, _, _ = sess.run(
                        [
                            end_points['total_loss'], end_points['last_state'],
                            end_points['train_op']
                        ],
                        feed_dict={
                            input_data: batches_inputs[n],
                            output_targets: batches_outputs[n]
                        })
                    n += 1
                    print('Epoch: %d, batch: %d, training loss: %.6f' %
                          (epoch, batch, loss))
                if epoch % 10 == 0:
                    saver.save(sess,
                               os.path.join(FLAG.result_dir,
                                            FLAG.model_prefix),
                               global_step=epoch)
        except KeyboardInterrupt:
            print('## Interrupted manually; trying to save a checkpoint now...')
            saver.save(sess,
                       os.path.join(FLAG.result_dir, FLAG.model_prefix),
                       global_step=epoch)
            print(
                '## Last epoch was saved; next run will resume from epoch {}.'
                .format(epoch))
        #saver.save(sess, FLAG.result_dir+'/model/'+"model.ckpt")
        #tf.train.write_graph(sess.graph_def, FLAG.result_dir+'/model/', 'graph.pb')

        builder = tf.saved_model.builder.SavedModelBuilder(
            FLAG.result_dir + "/model_complex")
        SignatureDef = tf.saved_model.signature_def_utils.build_signature_def(
            inputs={
                'input_data': tf.saved_model.utils.build_tensor_info(input_data),
                'output_targets': tf.saved_model.utils.build_tensor_info(output_targets)
            },
            outputs={
                'prediction': tf.saved_model.utils.build_tensor_info(end_points['prediction'])
            })
        builder.add_meta_graph_and_variables(
            sess, [tag_constants.TRAINING],
            signature_def_map={
                tf.saved_model.signature_constants.CLASSIFY_INPUTS: SignatureDef
            })
        builder.save()
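
A sketch of loading the exported SavedModel back (an assumption-level example: it relies on the signature stored above under the CLASSIFY_INPUTS key, and on end_points['prediction'] not actually depending on output_targets):

import numpy as np
import tensorflow.compat.v1 as tf
from model import FLAGS
from tensorflow.python.saved_model import tag_constants, signature_constants

FLAG = FLAGS()
with tf.Session(graph=tf.Graph()) as sess:
    # loader.load restores the variables and returns the MetaGraphDef,
    # which carries the signature map written by the builder
    meta = tf.saved_model.loader.load(sess, [tag_constants.TRAINING],
                                      FLAG.result_dir + '/model_complex')
    sig = meta.signature_def[signature_constants.CLASSIFY_INPUTS]
    # dummy token ids; the shape must match the training placeholder
    x = np.zeros((FLAG.batch_size, 1), np.int32)
    predict = sess.run(sig.outputs['prediction'].name,
                       feed_dict={sig.inputs['input_data'].name: x})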