Пример #1
0
class Model(object):
    """GRU encoder with two GRU decoders ("pre" and "post"), TF1 graph mode.

    One encoder reads a sentence; its final state is used as the initial
    state of two decoders that are trained on the ``decode_pre_*`` and
    ``decode_post_*`` inputs respectively. Both decoders share one output
    projection (see ``softmax_variable``). Checkpoints are written to
    ``./output/skipThought.model`` by ``train`` and read back by ``gen``.
    """

    def __init__(self):
        """Create the ``Story`` data helper and set the hyper-parameters."""
        self.story = Story()
        # Vocabulary comes from the data helper (the original comment here
        # read "initial sentence vector").
        self.vocab = self.story.vocab
        # Two less than the loader's batch size; 126 per the original note.
        self.batch_size = self.story.batch_size - 2
        self.chunk_size = self.story.chunk_size
        self.embedding_dim = 300
        self.num_units = 500
        self.learning_rate = 0.001
        self.epoch = 25
        self.sample_size = 50

    def gru_encoder(self, encode_emb, length, train=True):
        """Run the encoder GRU over the embedded input.

        Args:
            encode_emb: Embedded encoder input fed to ``tf.nn.dynamic_rnn``.
            length: Per-example sequence lengths.
            train: When true the zero state is built for ``self.batch_size``
                examples, otherwise for a single example.

        Returns:
            Tuple ``(initial_state, final_state)`` of the encoder cell. The
            initial state tensor is returned so callers can feed it.
        """
        batch_size = self.batch_size if train else 1
        with tf.variable_scope('encoder'):
            cell = tf.nn.rnn_cell.GRUCell(num_units=self.num_units)
            initial_state = cell.zero_state(batch_size, tf.float32)
            _, final_state = tf.nn.dynamic_rnn(cell,
                                               encode_emb,
                                               initial_state=initial_state,
                                               sequence_length=length)
        return initial_state, final_state

    def softmax_variable(self, num_units, vocab_size, reuse=False):
        """Create (or reuse) the output projection weights and bias.

        Args:
            num_units: Input width of the projection.
            vocab_size: Output width of the projection.
            reuse: Passed to ``tf.variable_scope``; set it to fetch the
                variables created by an earlier call.

        Returns:
            Tuple ``(w, b)`` with shapes ``[num_units, vocab_size]`` and
            ``[vocab_size]``.
        """
        with tf.variable_scope('softmax_variable', reuse=reuse):
            w = tf.get_variable('w', [num_units, vocab_size])
            b = tf.get_variable('b', [vocab_size])
        return w, b

    def gru_decoder(self, decode_emb, length, state, scope, reuse=False):
        """Run a decoder GRU and project its outputs onto the vocabulary.

        Args:
            decode_emb: Embedded decoder input.
            length: Per-example sequence lengths.
            state: Initial state of the decoder cell.
            scope: Variable scope name holding this decoder's GRU weights.
            reuse: Forwarded to ``softmax_variable`` so a second decoder can
                share the projection created by the first.

        Returns:
            Tuple ``(logits, prediction, final_state)``. ``logits`` and
            ``prediction`` are 2-D: the decoder outputs are flattened to
            ``[-1, num_units]`` before the projection.
        """
        with tf.variable_scope(scope):
            cell = tf.nn.rnn_cell.GRUCell(num_units=self.num_units)
            outputs, final_state = tf.nn.dynamic_rnn(cell,
                                                     decode_emb,
                                                     initial_state=state,
                                                     sequence_length=length)
        # Flatten every time step of every example into one row each.
        x = tf.reshape(outputs, [-1, self.num_units])
        w, b = self.softmax_variable(self.num_units,
                                     len(self.vocab),
                                     reuse=reuse)
        logits = tf.matmul(x, w) + b
        # NOTE(review): the op name 'redictions' looks like a typo for
        # 'predictions'; kept unchanged because it is a runtime graph name.
        prediction = tf.nn.softmax(logits, name='redictions')
        return logits, prediction, final_state

    def _loss(self, logits, targets, scope='loss'):
        """Return the mean softmax cross-entropy between logits and targets.

        Args:
            logits: 2-D logits as produced by ``gru_decoder``.
            targets: Integer target ids; one-hot encoded and flattened here
                so that they line up row by row with ``logits``.
            scope: Variable scope name for the loss ops.

        Returns:
            Scalar loss tensor (mean over all rows; no masking is applied
            in this method).
        """
        vocab_size = len(self.vocab)
        with tf.variable_scope(scope):
            y_one_hot = tf.one_hot(targets, vocab_size)
            y_reshaped = tf.reshape(y_one_hot, [-1, vocab_size])
            loss = tf.reduce_mean(
                tf.nn.softmax_cross_entropy_with_logits(logits=logits,
                                                        labels=y_reshaped))
        return loss

    def _optimizer(self, loss, scope='optimizer'):
        """Build an Adam training op with global-norm gradient clipping.

        Args:
            loss: Scalar loss tensor to minimise.
            scope: Variable scope name for the optimizer ops.

        Returns:
            The op that applies the clipped gradients to all trainable
            variables that exist when this method is called.
        """
        with tf.variable_scope(scope):
            grad_clip = 5
            tvars = tf.trainable_variables()
            grads, _ = tf.clip_by_global_norm(tf.gradients(loss, tvars),
                                              grad_clip)
            op = tf.train.AdamOptimizer(self.learning_rate)
            optimizer = op.apply_gradients(zip(grads, tvars))
        return optimizer

    def _inputs(self):
        """Create the input placeholders as attributes of ``self``."""
        with tf.variable_scope('inputs'):
            # Encoder token ids. This placeholder is required by
            # `_embedding`, `train` and `gen`, which all read `self.encode`.
            self.encode = tf.placeholder(tf.int32,
                                         shape=[None, None],
                                         name='encode')
            self.encode_length = tf.placeholder(tf.int32,
                                                shape=[
                                                    None,
                                                ],
                                                name='encode_length')
            self.decode_pre_x = tf.placeholder(tf.int32,
                                               shape=[None, None],
                                               name='decode_pre_x')
            self.decode_pre_y = tf.placeholder(tf.int32,
                                               shape=[None, None],
                                               name='decode_pre_y')
            self.decode_pre_length = tf.placeholder(tf.int32,
                                                    shape=[
                                                        None,
                                                    ],
                                                    name='decode_pre_length')
            self.decode_post_x = tf.placeholder(tf.int32,
                                                shape=[None, None],
                                                name='decode_post_x')
            self.decode_post_y = tf.placeholder(tf.int32,
                                                shape=[None, None],
                                                name='decode_post_y')
            self.decode_post_length = tf.placeholder(tf.int32,
                                                     shape=[
                                                         None,
                                                     ],
                                                     name='decode_post_length')

    def _embedding(self):
        """Create the shared embedding matrix and look up all three inputs.

        Requires ``_inputs`` to have been called first.
        """
        with tf.variable_scope('embedding'):
            self.embedding = tf.get_variable(
                name='embedding',
                shape=[len(self.vocab), self.embedding_dim],
                initializer=tf.random_uniform_initializer(-0.1, 0.1))
            self.encode_emb = tf.nn.embedding_lookup(self.embedding,
                                                     self.encode,
                                                     name='encode_emb')
            self.decode_pre_emb = tf.nn.embedding_lookup(self.embedding,
                                                         self.decode_pre_x,
                                                         name='decode_pre_emb')
            self.decode_post_emb = tf.nn.embedding_lookup(
                self.embedding, self.decode_post_x, name='decode_post_emb')

    def build_model(self):
        """Assemble the training graph: inputs, encoder, decoders, losses."""
        self._inputs()
        self._embedding()
        self.initial_state, self.final_state = self.gru_encoder(
            self.encode_emb, self.encode_length)
        self.pre_logits, self.pre_prediction, self.pre_state = self.gru_decoder(
            self.decode_pre_emb,
            self.decode_pre_length,
            self.final_state,
            scope='decoder_pre')
        # reuse=True: share the projection created by the first decoder.
        self.post_logits, self.post_prediction, self.post_state = self.gru_decoder(
            self.decode_post_emb,
            self.decode_post_length,
            self.final_state,
            scope='decoder_post',
            reuse=True)

        self.pre_loss = self._loss(self.pre_logits,
                                   self.decode_pre_y,
                                   scope='decoder_pre_loss')
        self.pre_loss_sum = scalar_summary("pre_loss", self.pre_loss)
        self.post_loss = self._loss(self.post_logits,
                                    self.decode_post_y,
                                    scope='decoder_post_loss')
        self.post_loss_sum = scalar_summary("post_loss", self.post_loss)

        self.pre_optimizer = self._optimizer(self.pre_loss,
                                             scope='decoder_pre_op')
        self.post_optimizer = self._optimizer(self.post_loss,
                                              scope='decoder_post_op')

    def train(self):
        """Build the graph and train for ``self.epoch`` epochs.

        Loss summaries are written to ``./output/logs`` and a checkpoint is
        saved under ``./output/`` after every epoch. Batches whose size
        differs from ``self.batch_size`` are skipped.
        """
        model_path = './output/skipThought.model'
        self.build_model()
        self.saver = tf.train.Saver()

        with tf.Session() as sess:
            self.writer = SummaryWriter("./output/logs", sess.graph)
            # Close the writer even if training fails so buffered events
            # are flushed (assumes SummaryWriter follows the TF FileWriter
            # API, as its constructor and add_summary usage suggest).
            try:
                self._sum = merge_summary(
                    [self.pre_loss_sum, self.post_loss_sum])
                step = 0
                sess.run(tf.global_variables_initializer())
                new_state = sess.run(self.initial_state)

                for epoch in range(self.epoch):
                    batches = self.story.batch()
                    for encode_x, decode_pre_x, decode_pre_y, \
                        decode_post_x, decode_post_y, encode_length, \
                        decode_pre_length, decode_post_length in batches:
                        # The encoder zero state is built for a fixed batch
                        # size, so differently sized batches are dropped.
                        if len(encode_x) != self.batch_size:
                            continue
                        feed = {
                            self.initial_state: new_state,
                            self.encode: encode_x,
                            self.encode_length: encode_length,
                            self.decode_pre_x: decode_pre_x,
                            self.decode_pre_y: decode_pre_y,
                            self.decode_pre_length: decode_pre_length,
                            self.decode_post_x: decode_post_x,
                            self.decode_post_y: decode_post_y,
                            self.decode_post_length: decode_post_length
                        }
                        # The post-decoder state is carried over and fed as
                        # the encoder's initial state on the next batch.
                        _, pre_loss, _, _, post_loss, new_state, summary_str = sess.run(
                            [
                                self.pre_optimizer, self.pre_loss,
                                self.pre_state, self.post_optimizer,
                                self.post_loss, self.post_state, self._sum
                            ],
                            feed_dict=feed)
                        self.writer.add_summary(summary_str, step)

                        print(' epoch:', epoch, ' step:', step, ' pre_loss',
                              pre_loss, ' post_loss', post_loss)
                        step += 1
                    self.saver.save(sess, model_path, global_step=step)
            finally:
                self.writer.close()

    def gen(self):
        """Restore the latest checkpoint and print generated sentences.

        Builds an inference graph (encoder for a single example plus the
        "post" decoder), then generates ``self.sample_size`` sentences:
        each one is decoded token by token, always taking the arg-max of
        the last step, and the result becomes the next encoder input.
        """
        self._inputs()

        self._embedding()
        self.initial_state, self.final_state = self.gru_encoder(
            self.encode_emb, self.encode_length, train=False)

        self.post_logits, self.post_prediction, self.post_state = self.gru_decoder(
            self.decode_post_emb,
            self.decode_post_length,
            self.final_state,
            scope='decoder_post')
        saver = tf.train.Saver()
        with tf.Session() as sess:
            sess.run(tf.global_variables_initializer())
            new_state = sess.run(self.initial_state)
            saver.restore(sess, tf.train.latest_checkpoint('./output/'))
            encode_x = [[self.story.word_to_int[c] for c in '宝玉归来家中']]
            samples = [[] for _ in range(self.sample_size)]
            # samples[0] aliases the seed list, so the first generated
            # sentence is appended to the seed text.
            samples[0] = encode_x[0]
            # Loop-invariant token ids, looked up once.
            go_id = self.story.word_to_int['<GO>']
            eos_id = self.story.word_to_int['<EOS>']
            for i in range(self.sample_size):
                decode_x = [[go_id]]
                # NOTE(review): decoding only stops on <EOS>; there is no
                # upper bound on the generated length.
                while decode_x[0][-1] != eos_id:
                    feed = {
                        self.encode: encode_x,
                        self.encode_length: [len(encode_x[0])],
                        self.initial_state: new_state,
                        self.decode_post_x: decode_x,
                        self.decode_post_length: [len(decode_x[0])]
                    }
                    predict, state = sess.run(
                        [self.post_prediction, self.post_state],
                        feed_dict=feed)
                    # Arg-max over the vocabulary for the last time step.
                    int_word = np.argmax(predict, 1)[-1]
                    decode_x[0] += [int_word]
                # Drop the leading <GO> and trailing <EOS>.
                samples[i] += decode_x[0][1:-1]
                encode_x = [samples[i]]
                new_state = state
                print(''.join(
                    [self.story.int_to_word[sample] for sample in samples[i]]))