Ejemplo n.º 1
0
    def _decoder_training(self, z):
        """Build the training-time decoder graph.

        Fetches the already-created ``tied_embedding`` variable from the
        ``'encoder'`` scope, builds the look-back cell's initial state from a
        ReLU dense projection of ``self.z``, decodes the embedded
        ``self.dec_inp`` with a ``TrainingHelper``-driven ``BasicDecoder``,
        and applies a ``vocab_size``-unit dense layer to the decoder output.

        Args:
            z: Not referenced by the body; the latent is read from ``self.z``.
                NOTE(review): presumably callers pass ``self.z`` -- confirm.

        Returns:
            A pair ``(rnn_output, projected)`` where ``projected`` is the
            ``vocab_size``-unit dense layer applied to ``rnn_output``.
        """
        embedding_shape = [self.params['vocab_size'], args.embedding_dim]
        # reuse=True: the variable must already exist under 'encoder'.
        with tf.variable_scope('encoder', reuse=True):
            embedding_table = tf.get_variable('tied_embedding',
                                              embedding_shape)

        with tf.variable_scope('decoding'):
            cell = self._lookback_rnn_cell()
            zero_state = cell.zero_state(self._batch_size, tf.float32)
            # The first state slot comes from the latent; the second slot is
            # taken unchanged from what `zero_state` returned.
            latent_projection = tf.layers.dense(self.z, args.rnn_size,
                                                tf.nn.relu)
            initial_state = (latent_projection, zero_state[1])

            # Constructed inside the 'decoding' scope; applied after it exits.
            vocab_projection = tf.layers.Dense(self.params['vocab_size'],
                                               _scope='decoder/dense')

            embedded_inputs = tf.nn.embedding_lookup(embedding_table,
                                                     self.dec_inp)
            training_helper = tf.contrib.seq2seq.TrainingHelper(
                inputs=embedded_inputs,
                sequence_length=self.dec_seq_len)
            training_decoder = BasicDecoder(cell=cell,
                                            helper=training_helper,
                                            initial_state=initial_state,
                                            concat_z=self.z)
            outputs, _, _ = tf.contrib.seq2seq.dynamic_decode(
                decoder=training_decoder)

        rnn_output = outputs.rnn_output
        return rnn_output, vocab_projection.apply(rnn_output)
Ejemplo n.º 2
0
    def generator(self, latent_vec, reuse=None, inference=False):
        """Build the generator graph for training or for beam-search decoding.

        The ``embedding`` variable is fetched (with ``reuse=True``) from the
        scope ``self.scopes['E']``; everything else is built under
        ``self.scopes['G']``.

        Args:
            latent_vec: Latent tensor. It is projected through an ELU dense
                layer to form the initial decoder state and is also handed to
                the decoder as ``concat_z``.
            reuse: Forwarded to the ``'G'`` variable scope, the dense layers
                and ``self.rnn_cell`` in the training branch. The inference
                branch always uses ``reuse=True``.
            inference: When truthy, build a ``BeamSearchDecoder``; otherwise
                build a ``TrainingHelper``-driven ``BasicDecoder``.

        Returns:
            Training branch: a pair ``(rnn_output, projected)`` where
            ``projected`` is the ``vocab_size``-unit dense layer applied to
            ``rnn_output``.
            Inference branch: ``predicted_ids[:, :, 0]`` from the beam-search
            output (presumably the top beam -- confirm against
            ``BeamSearchDecoder``).
        """
        vocab_size = self.params['vocab_size']
        with tf.variable_scope(self.scopes['E'], reuse=True):
            embedding_table = tf.get_variable(
                'embedding', [vocab_size, args.embedding_dim])

        if inference:
            with tf.variable_scope(self.scopes['G'], reuse=True):
                beam_width = args.beam_width
                seed_state = tf.layers.dense(latent_vec,
                                             args.rnn_size,
                                             tf.nn.elu,
                                             reuse=True)

                beam_cell = self.rnn_cell(reuse=True)
                # One <start> id repeated `self.batch_size` times.
                start_ids = tf.tile(
                    tf.constant([self.params['<start>']], dtype=tf.int32),
                    [self.batch_size])
                tiled_state = tf.contrib.seq2seq.tile_batch(seed_state,
                                                            beam_width)
                # No explicit _scope here, unlike the training branch's layer.
                vocab_projection = tf.layers.Dense(vocab_size, _reuse=True)
                # Insert an axis at position 1 and repeat it beam_width times.
                tiled_latent = tf.tile(tf.expand_dims(latent_vec, 1),
                                       [1, beam_width, 1])

                beam_decoder = BeamSearchDecoder(
                    cell=beam_cell,
                    embedding=embedding_table,
                    start_tokens=start_ids,
                    end_token=self.params['<end>'],
                    initial_state=tiled_state,
                    beam_width=beam_width,
                    output_layer=vocab_projection,
                    concat_z=tiled_latent)
                # Decoding is capped at twice the maximum of self.enc_seq_len.
                step_limit = 2 * tf.reduce_max(self.enc_seq_len)
                beam_output, _, _ = tf.contrib.seq2seq.dynamic_decode(
                    decoder=beam_decoder,
                    maximum_iterations=step_limit)
                return beam_output.predicted_ids[:, :, 0]

        with tf.variable_scope(self.scopes['G'], reuse=reuse):
            seed_state = tf.layers.dense(latent_vec,
                                         args.rnn_size,
                                         tf.nn.elu,
                                         reuse=reuse)
            vocab_projection = tf.layers.Dense(vocab_size,
                                               _scope='decoder/dense',
                                               _reuse=reuse)

            embedded_inputs = tf.nn.embedding_lookup(embedding_table,
                                                     self.dec_inp)
            training_helper = tf.contrib.seq2seq.TrainingHelper(
                inputs=embedded_inputs,
                sequence_length=self.dec_seq_len)
            training_decoder = BasicDecoder(cell=self.rnn_cell(reuse=reuse),
                                            helper=training_helper,
                                            initial_state=seed_state,
                                            concat_z=latent_vec)
            train_output, _, _ = tf.contrib.seq2seq.dynamic_decode(
                decoder=training_decoder)
            rnn_output = train_output.rnn_output
            return rnn_output, vocab_projection.apply(rnn_output)
Ejemplo n.º 3
0
    def _decoder_training(self, init_state):
        """Build the training-time decoder graph.

        Embeds ``self._decoder_input()`` with ``self.tied_embedding``, decodes
        it with a ``TrainingHelper``-driven ``BasicDecoder`` over a
        ``MultiRNNCell`` of ``args.decoder_layers`` cells, and applies an
        ``args.vocab_size``-unit dense layer to the decoder output.

        Args:
            init_state: Passed unchanged to ``BasicDecoder`` as its
                ``initial_state``.

        Returns:
            A pair ``(rnn_output, projected)`` where ``projected`` is the
            dense layer applied to ``rnn_output``.
        """
        vocab_projection = Dense(args.vocab_size, _scope='decoder/dense')

        embedded_inputs = tf.nn.embedding_lookup(self.tied_embedding,
                                                 self._decoder_input())
        training_helper = tf.contrib.seq2seq.TrainingHelper(
            inputs=embedded_inputs,
            sequence_length=self.seq_length + 1)

        stacked_cell = tf.nn.rnn_cell.MultiRNNCell(
            [self._rnn_cell() for _ in range(args.decoder_layers)])
        # output_layer=None: the dense layer is applied after decoding.
        training_decoder = BasicDecoder(cell=stacked_cell,
                                        helper=training_helper,
                                        initial_state=init_state,
                                        z=self.z,
                                        output_layer=None)

        # Decoding is capped at the maximum of `seq_length + 1`.
        step_limit = tf.reduce_max(self.seq_length + 1)
        outputs, _, _ = tf.contrib.seq2seq.dynamic_decode(
            decoder=training_decoder,
            impute_finished=True,
            maximum_iterations=step_limit)

        rnn_output = outputs.rnn_output
        return rnn_output, vocab_projection.apply(rnn_output)
Ejemplo n.º 4
0
    def _decoder_training(self, z):
        """Build the training-time decoder graph.

        Converts ``z`` to the decoder's initial state via
        ``self._z_to_dec_state``, fetches the already-created
        ``tied_embedding`` variable from the ``'encoder'`` scope, decodes the
        embedded ``self.dec_inp`` with a ``TrainingHelper``-driven
        ``BasicDecoder`` over a ``MultiRNNCell`` of ``args.decoder_layers``
        cells, and applies an ``args.vocab_size``-unit dense layer to the
        decoder output.

        Args:
            z: Latent input used to derive the initial decoder state.
                NOTE(review): ``concat_z`` is fed from ``self.z`` rather than
                this argument -- presumably the same tensor; confirm.

        Returns:
            A pair ``(rnn_output, projected)`` where ``projected`` is the
            dense layer applied to ``rnn_output``.
        """
        initial_state = self._z_to_dec_state(z)

        embedding_shape = [args.vocab_size, args.embedding_dim]
        # reuse=True: the variable must already exist under 'encoder'.
        with tf.variable_scope('encoder', reuse=True):
            embedding_table = tf.get_variable('tied_embedding',
                                              embedding_shape)

        with tf.variable_scope('decoding'):
            # Constructed inside the 'decoding' scope; applied after it exits.
            vocab_projection = tf.layers.Dense(args.vocab_size,
                                               _scope='decoder/dense')

            embedded_inputs = tf.nn.embedding_lookup(embedding_table,
                                                     self.dec_inp)
            training_helper = tf.contrib.seq2seq.TrainingHelper(
                inputs=embedded_inputs,
                sequence_length=self.dec_seq_len)

            stacked_cell = tf.nn.rnn_cell.MultiRNNCell(
                [self._rnn_cell() for _ in range(args.decoder_layers)])
            training_decoder = BasicDecoder(cell=stacked_cell,
                                            helper=training_helper,
                                            initial_state=initial_state,
                                            concat_z=self.z)
            outputs, _, _ = tf.contrib.seq2seq.dynamic_decode(
                decoder=training_decoder)

        rnn_output = outputs.rnn_output
        return rnn_output, vocab_projection.apply(rnn_output)