Example #1
0
class CVAE(object):
    def __init__(self, tfFLAGS, embed=None):
        """Build the full CVAE graph: inputs, encoders, latent networks, decoders.

        Args:
            tfFLAGS: configuration object.  The attributes read here are
                vocab_size, embed_size, num_units, num_layers, beam_width,
                use_lstm, attn_mode, keep_prob, max_decode_len, bi_encode,
                recog_hidden_units, prior_hidden_units, z_dim, full_kl_step,
                opt, learning_rate and, depending on ``opt``,
                learning_rate_decay_factor ('SGD') or momentum ('Momentum').
            embed: optional initial value for the shared embedding matrix,
                forwarded to ``_make_input``.

        Side effects:
            Prints the decoder initial state (LSTM only) and every trainable
            variable to stdout once the graph is built.
        """
        # Hyper-parameters copied from the flags object.
        self.vocab_size = tfFLAGS.vocab_size
        self.embed_size = tfFLAGS.embed_size
        self.num_units = tfFLAGS.num_units
        self.num_layers = tfFLAGS.num_layers
        self.beam_width = tfFLAGS.beam_width
        self.use_lstm = tfFLAGS.use_lstm
        self.attn_mode = tfFLAGS.attn_mode
        self.train_keep_prob = tfFLAGS.keep_prob
        self.max_decode_len = tfFLAGS.max_decode_len
        self.bi_encode = tfFLAGS.bi_encode
        self.recog_hidden_units = tfFLAGS.recog_hidden_units
        self.prior_hidden_units = tfFLAGS.prior_hidden_units
        self.z_dim = tfFLAGS.z_dim
        self.full_kl_step = tfFLAGS.full_kl_step

        self.global_step = tf.Variable(0, name="global_step", trainable=False)
        self.max_gradient_norm = 5.0
        if tfFLAGS.opt == 'SGD':
            # The rate lives in a variable so that running
            # learning_rate_decay_op can shrink it during training.
            self.learning_rate = tf.Variable(float(tfFLAGS.learning_rate),
                                             trainable=False,
                                             dtype=tf.float32)
            self.learning_rate_decay_op = self.learning_rate.assign(
                self.learning_rate * tfFLAGS.learning_rate_decay_factor)
            self.opt = tf.train.GradientDescentOptimizer(self.learning_rate)
        elif tfFLAGS.opt == 'Momentum':
            # Record the rate here too so the attribute exists for every
            # optimizer choice.
            self.learning_rate = tfFLAGS.learning_rate
            self.opt = tf.train.MomentumOptimizer(
                learning_rate=self.learning_rate, momentum=tfFLAGS.momentum)
        else:
            # NOTE(review): the configured rate is stored but NOT handed to
            # Adam, which therefore runs with its built-in default.  Confirm
            # whether AdamOptimizer(self.learning_rate) was intended.
            self.learning_rate = tfFLAGS.learning_rate
            self.opt = tf.train.AdamOptimizer()

        self._make_input(embed)

        # Projection from decoder outputs to vocabulary logits; the same
        # layer object is handed to every decoder built in _build_decoder.
        with tf.variable_scope("output_layer"):
            self.output_layer = Dense(
                self.vocab_size,
                kernel_initializer=tf.truncated_normal_initializer(stddev=0.1))

        with tf.variable_scope("encoders",
                               initializer=tf.orthogonal_initializer()):
            # Three separate encoders (own scope each): post, reference and
            # gold response.
            self.enc_post_outputs, self.enc_post_state = self._build_encoder(
                scope='post_encoder',
                inputs=self.enc_post,
                sequence_length=self.post_len)
            self.enc_ref_outputs, self.enc_ref_state = self._build_encoder(
                scope='ref_encoder',
                inputs=self.enc_ref,
                sequence_length=self.ref_len)
            self.enc_response_outputs, self.enc_response_state = self._build_encoder(
                scope='resp_encoder',
                inputs=self.enc_response,
                sequence_length=self.response_len)

            self.post_state = self._get_representation_from_enc_state(
                self.enc_post_state)
            self.ref_state = self._get_representation_from_enc_state(
                self.enc_ref_state)
            self.response_state = self._get_representation_from_enc_state(
                self.enc_response_state)
            # Condition vector: post and reference representations joined
            # along the feature axis.
            self.cond_embed = tf.concat([self.post_state, self.ref_state],
                                        axis=-1)

        # Recognition network: sees the condition AND the response, emits
        # 2 * z_dim values that are split into mean and log-variance.
        with tf.variable_scope("RecognitionNetwork"):
            recog_input = tf.concat([self.cond_embed, self.response_state],
                                    axis=-1)
            recog_hidden = tf.layers.dense(inputs=recog_input,
                                           units=self.recog_hidden_units,
                                           activation=tf.nn.tanh)
            recog_mulogvar = tf.layers.dense(inputs=recog_hidden,
                                             units=self.z_dim * 2,
                                             activation=None)
            recog_mu, recog_logvar = tf.split(recog_mulogvar, 2, axis=-1)

        # Prior network: same shape of output, but conditioned only on
        # cond_embed.
        with tf.variable_scope("PriorNetwork"):
            prior_input = self.cond_embed
            prior_hidden = tf.layers.dense(inputs=prior_input,
                                           units=self.prior_hidden_units,
                                           activation=tf.nn.tanh)
            prior_mulogvar = tf.layers.dense(inputs=prior_hidden,
                                             units=self.z_dim * 2,
                                             activation=None)
            prior_mu, prior_logvar = tf.split(prior_mulogvar, 2, axis=-1)

        with tf.variable_scope("GenerationNetwork"):
            # The use_prior placeholder picks, at run time, which
            # distribution the latent sample is drawn from.
            latent_sample = tf.cond(
                self.use_prior,
                lambda: sample_gaussian(prior_mu, prior_logvar),
                lambda: sample_gaussian(recog_mu, recog_logvar),
                name='latent_sample')

            gen_input = tf.concat([self.cond_embed, latent_sample], axis=-1)
            # One independent linear projection of gen_input per decoder
            # layer (and per c/h slot for LSTM) forms the initial state.
            if self.use_lstm:
                self.dec_init_state = tuple([
                    tf.contrib.rnn.LSTMStateTuple(
                        c=tf.layers.dense(inputs=gen_input,
                                          units=self.num_units,
                                          activation=None),
                        h=tf.layers.dense(inputs=gen_input,
                                          units=self.num_units,
                                          activation=None))
                    for _ in range(self.num_layers)
                ])
                # Single-argument print(...) behaves the same on Python 2
                # and Python 3.
                print(self.dec_init_state)
            else:
                self.dec_init_state = tuple([
                    tf.layers.dense(inputs=gen_input,
                                    units=self.num_units,
                                    activation=None)
                    for _ in range(self.num_layers)
                ])

            kld = gaussian_kld(recog_mu, recog_logvar, prior_mu, prior_logvar)
            self.avg_kld = tf.reduce_mean(kld)
            # KL weight grows linearly with global_step and is capped at 1.0
            # once global_step reaches full_kl_step.
            self.kl_weights = tf.minimum(
                tf.to_float(self.global_step) / self.full_kl_step, 1.0)
            self.kl_loss = self.kl_weights * self.avg_kld

        self._build_decoder()
        self.saver = tf.train.Saver(write_version=tf.train.SaverDef.V2,
                                    max_to_keep=1,
                                    pad_step_number=True,
                                    keep_checkpoint_every_n_hours=1.0)
        for var in tf.trainable_variables():
            print(var)

    def _make_input(self, embed):
        """Create lookup tables, placeholders, embeddings and decoder inputs.

        Args:
            embed: initial value for the shared embedding matrix, or None to
                create a [vocab_size, embed_size] matrix with the default
                initializer.

        Sets on ``self``: symbol2index / index2symbol tables, the string and
        length placeholders for post / ref / response, their id tensors and
        embedded versions, ``response_input`` / ``dec_inp`` for teacher
        forcing, ``batch_size`` / ``batch_len``, and the ``keep_prob`` and
        ``use_prior`` placeholders.
        """
        # String <-> id tables; unknown keys map to UNK_ID / _UNK.  The
        # tables are empty until initialize() inserts the vocabulary.
        self.symbol2index = MutableHashTable(key_dtype=tf.string,
                                             value_dtype=tf.int64,
                                             default_value=UNK_ID,
                                             shared_name="in_table",
                                             name="in_table",
                                             checkpoint=True)
        self.index2symbol = MutableHashTable(key_dtype=tf.int64,
                                             value_dtype=tf.string,
                                             default_value=_UNK,
                                             shared_name="out_table",
                                             name="out_table",
                                             checkpoint=True)
        with tf.variable_scope("input"):
            # Token strings are fed as 2-D tensors, lengths as 1-D tensors.
            self.post_string = tf.placeholder(tf.string, (None, None),
                                              'post_string')
            self.ref_string = tf.placeholder(tf.string, (None, None),
                                             'ref_string')
            self.response_string = tf.placeholder(tf.string, (None, None),
                                                  'response_string')

            self.post = self.symbol2index.lookup(self.post_string)
            self.post_len = tf.placeholder(tf.int32, (None, ), 'post_len')
            self.ref = self.symbol2index.lookup(self.ref_string)
            self.ref_len = tf.placeholder(tf.int32, (None, ), 'ref_len')
            self.response = self.symbol2index.lookup(self.response_string)
            self.response_len = tf.placeholder(tf.int32, (None, ),
                                               'response_len')

            with tf.variable_scope("embedding"):
                # Encoder and decoder share a single embedding variable.
                if embed is None:
                    # initialize the embedding randomly
                    self.emb_enc = self.emb_dec = tf.get_variable(
                        "emb_share", [self.vocab_size, self.embed_size],
                        dtype=tf.float32)
                else:
                    # initialize the embedding by pre-trained word vectors
                    # (single-argument print(...) works on Python 2 and 3)
                    print("share pre-trained embed")
                    self.emb_enc = self.emb_dec = tf.get_variable(
                        'emb_share', dtype=tf.float32, initializer=embed)

            self.enc_post = tf.nn.embedding_lookup(self.emb_enc, self.post)
            self.enc_ref = tf.nn.embedding_lookup(self.emb_enc, self.ref)
            self.enc_response = tf.nn.embedding_lookup(self.emb_enc,
                                                       self.response)

            self.batch_len = tf.shape(self.response)[1]
            self.batch_size = tf.shape(self.response)[0]
            # Teacher-forcing input: a GO_ID column is prepended and the last
            # column of the response is dropped (shift right by one step).
            self.response_input = tf.concat([
                tf.ones((self.batch_size, 1), dtype=tf.int64) * GO_ID,
                tf.split(self.response, [self.batch_len - 1, 1], axis=1)[0]
            ], 1)
            self.dec_inp = tf.nn.embedding_lookup(self.emb_dec,
                                                  self.response_input)

            # Dropout keep probability defaults to 1.0 unless it is fed.
            self.keep_prob = tf.placeholder_with_default(1.0, ())
            # Run-time switch between prior and recognition sampling.
            self.use_prior = tf.placeholder(dtype=tf.bool, name="use_prior")

    def _build_encoder(self, scope, inputs, sequence_length):
        """Run a (bi)directional RNN encoder inside its own variable scope.

        Args:
            scope: name of the variable scope the encoder is created in.
            inputs: embedded input sequence passed to the RNN.
            sequence_length: per-example lengths passed to the RNN.

        Returns:
            (outputs, state).  In the bidirectional case the forward and
            backward outputs are concatenated on the last axis and the
            per-layer states are concatenated pairwise into a tuple.
        """
        with tf.variable_scope(scope):
            if not self.bi_encode:
                # Plain unidirectional stack.
                single_cell = self._build_encoder_cell()
                uni_outputs, uni_state = tf.nn.dynamic_rnn(
                    cell=single_cell,
                    inputs=inputs,
                    sequence_length=sequence_length,
                    dtype=tf.float32)
                return uni_outputs, uni_state

            forward_cell, backward_cell = self._build_biencoder_cell()
            bi_outputs, bi_states = tf.nn.bidirectional_dynamic_rnn(
                cell_fw=forward_cell,
                cell_bw=backward_cell,
                inputs=inputs,
                sequence_length=sequence_length,
                dtype=tf.float32)
            merged_outputs = tf.concat(bi_outputs, axis=-1)
            fw_layers, bw_layers = bi_states

            def join_layer(fw, bw):
                # Join one layer's forward and backward state on the
                # feature axis, keeping the LSTM (c, h) structure.
                if self.use_lstm:
                    joined_c = tf.concat([fw.c, bw.c], axis=-1)
                    joined_h = tf.concat([fw.h, bw.h], axis=-1)
                    return tf.contrib.rnn.LSTMStateTuple(c=joined_c,
                                                         h=joined_h)
                return tf.concat([fw, bw], axis=-1)

            merged_state = tuple(
                join_layer(fw, bw) for fw, bw in zip(fw_layers, bw_layers))
            return merged_outputs, merged_state

    def _get_representation_from_enc_state(self, enc_state):
        """Flatten a multi-layer encoder state into one vector per example.

        For LSTM states only the ``h`` component of each layer is used; for
        other cells the per-layer states are used directly.  The pieces are
        concatenated along the last axis.
        """
        pieces = [layer.h for layer in enc_state] if self.use_lstm else enc_state
        return tf.concat(pieces, axis=-1)

    def _build_decoder(self):
        """Build the training, greedy-inference and beam-search decoders.

        All three decoders live in the "decode" variable scope; the second
        and third blocks open it with reuse=True so they share the variables
        created by the training decoder.

        Sets on ``self``: sen_loss, ppl_loss, elbo, train_op, train_out,
        inference and beam_out.
        """
        with tf.variable_scope("decode",
                               initializer=tf.orthogonal_initializer()):
            # --- Training decoder (teacher forcing on dec_inp) ---
            dec_cell, init_state = self._build_decoder_cell(
                self.enc_post_outputs, self.post_len, self.dec_init_state)

            train_helper = tf.contrib.seq2seq.TrainingHelper(
                inputs=self.dec_inp, sequence_length=self.response_len)
            train_decoder = tf.contrib.seq2seq.BasicDecoder(
                cell=dec_cell,
                helper=train_helper,
                initial_state=init_state,
                output_layer=self.output_layer)
            train_output, _, _ = tf.contrib.seq2seq.dynamic_decode(
                decoder=train_decoder,
                maximum_iterations=self.max_decode_len,
            )
            logits = train_output.rnn_output

            # 1.0 for real tokens, 0.0 for padding, padded out to batch_len.
            mask = tf.sequence_mask(self.response_len,
                                    self.batch_len,
                                    dtype=tf.float32)

            # NOTE(review): labels and mask span batch_len steps, while the
            # logits' time dimension comes out of dynamic_decode (bounded by
            # response_len and max_decode_len).  Presumably callers keep
            # these equal -- TODO confirm, otherwise the shapes disagree.
            crossent = tf.nn.sparse_softmax_cross_entropy_with_logits(
                labels=self.response, logits=logits)
            # Total masked cross-entropy over the whole batch.
            crossent = tf.reduce_sum(crossent * mask)
            # Per-sentence loss: the total divided by the batch size.
            self.sen_loss = crossent / tf.to_float(self.batch_size)

            # Per-token loss (total divided by the number of unmasked
            # tokens).  The original author notes this is the same as:
            # self.loss = tf.contrib.seq2seq.sequence_loss(train_output.rnn_output,
            #                                              self.response,
            #                                              mask)
            self.ppl_loss = crossent / tf.reduce_sum(mask)

            # Training objective: sentence loss plus the weighted KL term.
            self.elbo = self.sen_loss + self.kl_loss

            # Calculate and clip gradients
            # (trainable variables are snapshotted at this point of graph
            # construction).
            params = tf.trainable_variables()
            gradients = tf.gradients(self.elbo, params)
            clipped_gradients, _ = tf.clip_by_global_norm(
                gradients, self.max_gradient_norm)
            # apply_gradients also increments global_step.
            self.train_op = self.opt.apply_gradients(
                zip(clipped_gradients, params), global_step=self.global_step)

            # Teacher-forced predictions mapped back to token strings.
            self.train_out = self.index2symbol.lookup(tf.cast(
                train_output.sample_id, tf.int64),
                                                      name='train_out')

        # --- Greedy inference decoder (shares the variables above) ---
        with tf.variable_scope("decode", reuse=True):
            dec_cell, init_state = self._build_decoder_cell(
                self.enc_post_outputs, self.post_len, self.dec_init_state)

            # Every sequence starts from GO_ID; EOS_ID is the end token.
            start_tokens = tf.tile(tf.constant([GO_ID], dtype=tf.int32),
                                   [self.batch_size])
            end_token = EOS_ID
            infer_helper = tf.contrib.seq2seq.GreedyEmbeddingHelper(
                self.emb_dec, start_tokens, end_token)
            infer_decoder = tf.contrib.seq2seq.BasicDecoder(
                cell=dec_cell,
                helper=infer_helper,
                initial_state=init_state,
                output_layer=self.output_layer)
            infer_output, _, _ = tf.contrib.seq2seq.dynamic_decode(
                decoder=infer_decoder,
                maximum_iterations=self.max_decode_len,
            )

            self.inference = self.index2symbol.lookup(tf.cast(
                infer_output.sample_id, tf.int64),
                                                      name='inference')

        # --- Beam-search decoder (shares the variables above) ---
        with tf.variable_scope("decode", reuse=True):
            # Initial state, attention memory and memory lengths are tiled
            # beam_width times before the cell is built.
            dec_init_state = tf.contrib.seq2seq.tile_batch(
                self.dec_init_state, self.beam_width)
            enc_outputs = tf.contrib.seq2seq.tile_batch(
                self.enc_post_outputs, self.beam_width)
            post_len = tf.contrib.seq2seq.tile_batch(self.post_len,
                                                     self.beam_width)

            dec_cell, init_state = self._build_decoder_cell(
                enc_outputs,
                post_len,
                dec_init_state,
                beam_width=self.beam_width)

            beam_decoder = tf.contrib.seq2seq.BeamSearchDecoder(
                cell=dec_cell,
                embedding=self.emb_dec,
                start_tokens=tf.ones_like(self.post_len) * GO_ID,
                end_token=EOS_ID,
                initial_state=init_state,
                beam_width=self.beam_width,
                output_layer=self.output_layer)
            beam_output, _, beam_lengths = tf.contrib.seq2seq.dynamic_decode(
                decoder=beam_decoder,
                maximum_iterations=self.max_decode_len,
            )

            self.beam_out = self.index2symbol.lookup(tf.cast(
                beam_output.predicted_ids, tf.int64),
                                                     name='beam_out')

    def _build_encoder_cell(self):
        """Return a stacked, dropout-wrapped encoder cell.

        The stack has ``num_layers`` layers of ``num_units`` units each,
        LSTM when ``use_lstm`` is set and GRU otherwise.  ``keep_prob`` is
        passed as the second positional argument of DropoutWrapper.
        """
        rnn = tf.contrib.rnn
        unit_cls = rnn.LSTMCell if self.use_lstm else rnn.GRUCell
        layers = []
        for _ in range(self.num_layers):
            layers.append(
                rnn.DropoutWrapper(unit_cls(self.num_units), self.keep_prob))
        return rnn.MultiRNNCell(layers)

    def _build_biencoder_cell(self):
        """Return (forward, backward) stacked cells for the bi-encoder.

        Each direction has ``num_layers`` dropout-wrapped layers of
        ``num_units // 2`` units (LSTM when ``use_lstm`` is set, GRU
        otherwise), so the concatenated directions are ``num_units`` wide
        when ``num_units`` is even.

        Returns:
            Tuple ``(cell_fw, cell_bw)`` of two independent MultiRNNCell
            stacks.
        """
        rnn = tf.contrib.rnn
        unit_cls = rnn.LSTMCell if self.use_lstm else rnn.GRUCell
        # Floor division keeps the unit count an int on Python 3 as well;
        # plain "/" would produce a float there.
        half_units = self.num_units // 2

        def build_stack():
            # One direction's stack of dropout-wrapped layers.
            return rnn.MultiRNNCell([
                rnn.DropoutWrapper(unit_cls(half_units), self.keep_prob)
                for _ in range(self.num_layers)
            ])

        cell_fw = build_stack()
        cell_bw = build_stack()
        return cell_fw, cell_bw

    def _build_decoder_cell(self,
                            memory,
                            memory_len,
                            encode_state,
                            beam_width=1):
        """Build the decoder cell and its initial state.

        Args:
            memory: encoder outputs used as attention memory.
            memory_len: lengths of the memory sequences.
            encode_state: state used to initialise the decoder stack.
            beam_width: multiplier applied to ``batch_size`` when sizing the
                attention wrapper's zero state (1 when not beam searching).

        Returns:
            ``(cell, initial_state)``.  When ``attn_mode`` is neither
            'Luong' nor 'Bahdanau' the bare stacked cell and ``encode_state``
            are returned; otherwise an AttentionWrapper and its zero state
            with ``cell_state`` replaced by ``encode_state``.
        """
        rnn = tf.contrib.rnn
        unit_cls = rnn.LSTMCell if self.use_lstm else rnn.GRUCell
        cell = rnn.MultiRNNCell([
            rnn.DropoutWrapper(unit_cls(self.num_units), self.keep_prob)
            for _ in range(self.num_layers)
        ])

        if self.attn_mode == 'Luong':
            attention_mechanism = tf.contrib.seq2seq.LuongAttention(
                num_units=self.num_units,
                memory=memory,
                memory_sequence_length=memory_len,
                scale=True)
        elif self.attn_mode == 'Bahdanau':
            # BahdanauAttention has no ``scale`` keyword; its equivalent
            # switch is ``normalize``.  Passing scale=True raised TypeError.
            attention_mechanism = tf.contrib.seq2seq.BahdanauAttention(
                num_units=self.num_units,
                memory=memory,
                memory_sequence_length=memory_len,
                normalize=True)
        else:
            # No attention requested.
            return cell, encode_state

        attn_cell = tf.contrib.seq2seq.AttentionWrapper(
            cell=cell,
            attention_mechanism=attention_mechanism,
            attention_layer_size=self.num_units,
        )
        # Start from the wrapper's zero state and swap in the supplied
        # decoder state for the wrapped cell.
        initial_state = attn_cell.zero_state(
            self.batch_size * beam_width,
            tf.float32).clone(cell_state=encode_state)
        return attn_cell, initial_state

    def initialize(self, sess, vocab):
        """Initialise all variables and fill both vocabulary tables.

        Args:
            sess: session used to run the initialisation ops.
            vocab: sequence of token strings; the token at position ``i`` is
                assigned id ``i``.
        """
        # A concrete list is required: on Python 3 ``range`` is a lazy object
        # that the constant op does not accept.
        ids = list(range(len(vocab)))
        op_in = self.symbol2index.insert(
            constant_op.constant(vocab),
            constant_op.constant(ids, dtype=tf.int64))
        op_out = self.index2symbol.insert(
            constant_op.constant(ids, dtype=tf.int64),
            constant_op.constant(vocab))
        # Variables first, then the table contents.
        sess.run(tf.global_variables_initializer())
        sess.run([op_in, op_out])

    def step(self, sess, data, is_train=False):
        """Run one training or evaluation step on a batch.

        Args:
            sess: session to run the graph in.
            data: mapping with keys 'post', 'post_len', 'ref', 'ref_len',
                'response' and 'response_len'.
            is_train: when True the train op is run and dropout uses
                ``train_keep_prob``; otherwise only the losses are fetched
                and ``keep_prob`` keeps its default.

        Returns:
            The list returned by ``sess.run``:
            [ppl_loss, elbo, sen_loss, kl_loss, avg_kld, kl_weights],
            preceded by the train op's result when ``is_train`` is True.
        """
        input_feed = {
            self.post_string: data['post'],
            self.post_len: data['post_len'],
            self.ref_string: data['ref'],
            self.ref_len: data['ref_len'],
            self.response_string: data['response'],
            self.response_len: data['response_len'],
            # Training must draw z from the recognition network so that the
            # reconstruction term trains it; the prior is used otherwise.
            # The flag was previously fed as ``is_train``, i.e. inverted.
            self.use_prior: not is_train,
        }
        output_feed = [
            self.ppl_loss,
            self.elbo,
            self.sen_loss,
            self.kl_loss,
            self.avg_kld,
            self.kl_weights,
        ]
        if is_train:
            # Keep the train op first so result positions are unchanged.
            output_feed = [self.train_op] + output_feed
            input_feed[self.keep_prob] = self.train_keep_prob
        return sess.run(output_feed, input_feed)
Example #2
0
class Model(object):
    def __init__(self,
                 word_embed,
                 entity_embed,
                 vocab_size=30000,
                 num_embed_units=300,
                 num_units=512,
                 num_layers=2,
                 num_entities=0,
                 num_trans_units=100,
                 max_length=60,
                 learning_rate=0.0001,
                 learning_rate_decay_factor=0.95,
                 max_gradient_norm=5.0,
                 num_samples=500,
                 output_alignments=True):
        """Build the knowledge-grounded encoder/decoder graph and train op.

        Args:
            word_embed: initial word embedding matrix, or None to create a
                [vocab_size, num_embed_units] variable.
            entity_embed: initial entity embedding matrix, or None to create
                a [num_entities, num_trans_units] variable.  Created with
                trainable=False in both cases.
            vocab_size: number of rows of the word embedding when created
                here; also stored for the loss.
            num_embed_units: word embedding width.
            num_units: RNN hidden size.
            num_layers: number of stacked RNN layers.
            num_entities: number of rows of the entity embedding when
                created here.
            num_trans_units: entity embedding width.
            max_length: stored as ``self.max_length``.
            learning_rate: learning rate handed to Adam.
            learning_rate_decay_factor: not read in this method.
            max_gradient_norm: global-norm clipping threshold.
            num_samples: stored and later passed to ``loss_computation``.
            output_alignments: stored as ``self.output_alignments``.
        """
        # initialize params
        self.vocab_size = vocab_size
        self.num_embed_units = num_embed_units
        self.num_units = num_units
        self.num_layers = num_layers
        self.num_entities = num_entities
        self.num_trans_units = num_trans_units
        self.learning_rate = learning_rate
        self.max_gradient_norm = max_gradient_norm
        self.num_samples = num_samples
        self.max_length = max_length
        self.output_alignments = output_alignments

        # build the embedding table (index to vector)
        if word_embed is None:
            # initialize the embedding randomly
            self.word_embed = tf.get_variable(
                'word_embed', [self.vocab_size, self.num_embed_units],
                tf.float32)
        else:
            # initialize the embedding by pre-trained word vectors
            self.word_embed = tf.get_variable('word_embed',
                                              dtype=tf.float32,
                                              initializer=word_embed)
        if entity_embed is None:
            # initialize the embedding randomly
            self.entity_trans = tf.get_variable(
                'entity_embed', [num_entities, num_trans_units],
                tf.float32,
                trainable=False)
        else:
            # initialize the embedding by pre-trained trans vectors
            self.entity_trans = tf.get_variable('entity_embed',
                                                dtype=tf.float32,
                                                initializer=entity_embed,
                                                trainable=False)

        # initialize inputs and outputs
        self.posts = tf.placeholder(tf.string, (None, None),
                                    'enc_inps')  # batch*len
        # "(None)" is plain None, not a 1-tuple; "(None,)" declares the
        # intended rank-1 length vector.
        self.posts_length = tf.placeholder(tf.int32, (None,),
                                           'enc_lens')  # batch
        self.responses = tf.placeholder(tf.string, (None, None),
                                        'dec_inps')  # batch*len
        self.responses_length = tf.placeholder(tf.int32, (None,),
                                               'dec_lens')  # batch
        self.entities = tf.placeholder(tf.string, (None, None, None),
                                       'entities')  # batch
        self.entity_masks = tf.placeholder(tf.string, (None, None),
                                           'entity_masks')  # batch
        self.triples = tf.placeholder(tf.string, (None, None, None, 3),
                                      'triples')  # batch
        self.posts_triple = tf.placeholder(tf.int32, (None, None, 1),
                                           'enc_triples')  # batch
        self.responses_triple = tf.placeholder(tf.string, (None, None, 3),
                                               'dec_triples')  # batch
        self.match_triples = tf.placeholder(tf.int32, (None, None, None),
                                            'match_triples')  # batch
        self._init_vocabs()

        # build the vocab table (string to index)
        self.posts_word_id = self.symbol2index.lookup(self.posts)  # batch*len
        self.posts_entity_id = self.entity2index.lookup(
            self.posts)  # batch*len
        self.responses_target = self.symbol2index.lookup(
            self.responses)  # batch*len
        batch_size, decoder_len = tf.shape(self.responses)[0], tf.shape(
            self.responses)[1]
        # Decoder input: GO_ID prepended, last target column dropped.
        self.responses_word_id = tf.concat([
            tf.ones([batch_size, 1], dtype=tf.int64) * GO_ID,
            tf.split(self.responses_target, [decoder_len - 1, 1], 1)[0]
        ], 1)  # batch*len
        # Reverse cumulative sum of a one-hot at position length-1 gives
        # 1.0 for every step up to and including that position, 0.0 after.
        self.decoder_mask = tf.reshape(
            tf.cumsum(tf.one_hot(self.responses_length - 1, decoder_len),
                      reverse=True,
                      axis=1), [-1, decoder_len])

        # build entity embeddings
        entity_trans_transformed = tf.layers.dense(self.entity_trans,
                                                   self.num_trans_units,
                                                   activation=tf.tanh,
                                                   name='trans_transformation')
        # Seven zero-initialised rows are placed in front of the transformed
        # entity vectors; presumably they cover reserved/special entity ids
        # -- TODO confirm against the entity vocabulary.
        padding_entity = tf.get_variable('entity_padding_embed',
                                         [7, self.num_trans_units],
                                         dtype=tf.float32,
                                         initializer=tf.zeros_initializer())
        self.entity_embed = tf.concat(
            [padding_entity, entity_trans_transformed], axis=0)

        # get knowledge graph embedding, knowledge triple embedding
        self.triples_embedding, self.entities_word_embedding, self.graph_embedding = self._build_kg_embedding(
        )

        # build knowledge graph
        graph_embed_input, triple_embed_input = self._build_kg_graph()

        # build encoder
        encoder_output, encoder_state = self._build_encoder(graph_embed_input)

        # build decoder (expected to set self.decoder_loss, which is read
        # below)
        self._build_decoder(encoder_output, encoder_state, triple_embed_input)

        # initialize training process
        self.global_step = tf.Variable(0, trainable=False)
        # NOTE(review): gradients are taken w.r.t. ALL global variables, which
        # includes the entity embedding created with trainable=False and
        # global_step itself.  Confirm this is intended before switching to
        # tf.trainable_variables(); self.params may be used elsewhere.
        self.params = tf.global_variables()

        gradients = tf.gradients(self.decoder_loss, self.params)
        self.clipped_gradients, self.gradient_norm = tf.clip_by_global_norm(
            gradients, self.max_gradient_norm)
        optimizer = tf.train.AdamOptimizer(learning_rate=self.learning_rate)
        self.update = optimizer.apply_gradients(zip(self.clipped_gradients,
                                                    self.params),
                                                global_step=self.global_step)

        # Summaries: scalar loss plus a histogram per trainable variable.
        tf.summary.scalar('decoder_loss', self.decoder_loss)
        for each in tf.trainable_variables():
            tf.summary.histogram(each.name, each)
        self.merged_summary_op = tf.summary.merge_all()

    def _init_vocabs(self):
        """Create the four checkpointed lookup tables (words and entities).

        symbol2index / index2symbol map word strings to ids and back;
        entity2index / index2entity do the same for entities.  Missing keys
        fall back to UNK_ID / '_UNK' and NONE_ID / '_NONE' respectively.
        """

        def make_table(key_dtype, value_dtype, default_value, table_name):
            # The table's shared_name and name are always the same string.
            return MutableHashTable(key_dtype=key_dtype,
                                    value_dtype=value_dtype,
                                    default_value=default_value,
                                    shared_name=table_name,
                                    name=table_name,
                                    checkpoint=True)

        self.symbol2index = make_table(tf.string, tf.int64, UNK_ID,
                                       "in_table")
        self.index2symbol = make_table(tf.int64, tf.string, '_UNK',
                                       "out_table")
        self.entity2index = make_table(tf.string, tf.int64, NONE_ID,
                                       "entity_in_table")
        self.index2entity = make_table(tf.int64, tf.string, '_NONE',
                                       "entity_out_table")

    def _build_kg_embedding(self):
        """Embed triples and entities and attend over each graph's triples.

        Returns:
            (triples_embedding, entities_word_embedding, graph_embedding):
            the looked-up triple vectors reshaped so the last axis holds
            3 * num_trans_units values, the word embeddings of the entity
            strings, and the attention-weighted sum of [head; tail] vectors
            taken over axis 2.
        """
        units = self.num_trans_units
        n_batch, _ = tf.unstack(tf.shape(self.posts))
        n_graphs = tf.shape(self.triples)[1]

        # Triple strings -> entity ids -> vectors, head/relation/tail of one
        # triple laid side by side on the last axis.
        triple_ids = self.entity2index.lookup(self.triples)
        triple_vectors = tf.nn.embedding_lookup(self.entity_embed, triple_ids)
        triples_embedding = tf.reshape(triple_vectors,
                                       [n_batch, n_graphs, -1, 3 * units])

        # Entity strings are additionally embedded with the word vocabulary.
        entity_word_ids = self.symbol2index.lookup(self.entities)
        entity_word_vectors = tf.nn.embedding_lookup(self.word_embed,
                                                     entity_word_ids)
        entities_word_embedding = tf.reshape(
            entity_word_vectors, [n_batch, -1, self.num_embed_units])

        head, relation, tail = tf.split(triples_embedding, [units] * 3,
                                        axis=3)
        with tf.variable_scope('graph_attention', reuse=tf.AUTO_REUSE):
            head_tail = tf.concat([head, tail], axis=3)
            head_tail_transformed = tf.layers.dense(
                head_tail,
                units,
                activation=tf.tanh,
                name='head_tail_transform')
            relation_transformed = tf.layers.dense(relation,
                                                   units,
                                                   name='relation_transform')
            # Score per triple: dot product of the two transformed vectors.
            scores = tf.reduce_sum(
                relation_transformed * head_tail_transformed, axis=3)
            # Softmax over the last axis of the scores.
            attention = tf.nn.softmax(scores)
            weighted = tf.expand_dims(attention, 3) * head_tail
            graph_embedding = tf.reduce_sum(weighted, axis=2)
        return triples_embedding, entities_word_embedding, graph_embedding

    def _build_kg_graph(self):
        """Select per-token graph vectors and embed the response triples.

        Returns:
            (graph_embed_input, triple_embed_input): for every post position
            the graph embedding indexed by ``posts_triple``, and for every
            response position the embedded triple reshaped to
            3 * num_trans_units values.
        """
        post_batch, post_len = tf.unstack(tf.shape(self.posts))
        resp_batch = tf.shape(self.responses)[0]
        resp_len = tf.shape(self.responses)[1]

        # knowledge graph vectors
        # gather_nd indices are (batch row, graph index) pairs: the batch row
        # number is repeated along the post length and joined with the graph
        # index stored in posts_triple.
        row_ids = tf.reshape(tf.range(post_batch, dtype=tf.int32), [-1, 1, 1])
        row_ids = tf.tile(row_ids, [1, post_len, 1])
        gather_index = tf.concat([row_ids, self.posts_triple], axis=2)
        graph_embed_input = tf.gather_nd(self.graph_embedding, gather_index)

        # knowledge triple vectors
        response_triple_ids = self.entity2index.lookup(self.responses_triple)
        response_triple_vectors = tf.nn.embedding_lookup(
            self.entity_embed, response_triple_ids)
        triple_embed_input = tf.reshape(
            response_triple_vectors,
            [resp_batch, resp_len, 3 * self.num_trans_units])

        return graph_embed_input, triple_embed_input

    def _build_encoder(self, graph_embed_input):
        """Run a stacked-GRU encoder over post words joined with graph vectors.

        Args:
            graph_embed_input: tensor concatenated to the post word embeddings
                along axis 2 (so it must agree with them on the other axes).

        Returns:
            The ``(encoder_output, encoder_state)`` pair from ``dynamic_rnn``.
        """
        word_vectors = tf.nn.embedding_lookup(
            self.word_embed, self.posts_word_id)  # batch*len*unit

        gru_layers = []
        for _ in range(self.num_layers):
            gru_layers.append(GRUCell(self.num_units))
        stacked_cell = MultiRNNCell(gru_layers)

        # encoder input: e(x_t) = [w(x_t); g_i]
        rnn_input = tf.concat([word_vectors, graph_embed_input], axis=2)
        outputs, final_state = dynamic_rnn(stacked_cell,
                                           rnn_input,
                                           self.posts_length,
                                           dtype=tf.float32,
                                           scope="encoder")
        # outputs shape: [batch_size, max_time, cell.output_size]
        return outputs, final_state

    def _build_decoder(self, encoder_output, encoder_state,
                       triple_embed_input):
        """Build the training and the inference decoder sub-graphs.

        Nothing is returned; the results are stored on ``self``:
        ``decoder_loss``, ``ppx_loss`` and ``sentence_ppx`` from the training
        branch, and ``generation`` from the inference branch.

        Args:
            encoder_output: handed to ``prepare_attention`` in both branches.
            encoder_state: handed to both the training and the inference
                decoder functions (presumably as the decoder's initial
                state -- confirm against the helper signatures).
            triple_embed_input: concatenated with the response word
                embeddings along axis 2 to form the training decoder input.
        """
        # decoder input: e(y_t) = [w(y_t); k_j]
        # NOTE: encoder_len is not used below; only encoder_batch_size is
        # (in the inference branch).
        encoder_batch_size, encoder_len = tf.unstack(tf.shape(self.posts))
        response_word_input = tf.nn.embedding_lookup(
            self.word_embed, self.responses_word_id)  # batch*len*unit
        decoder_input = tf.concat([response_word_input, triple_embed_input],
                                  axis=2)
        # Debug output; like the other print calls in this method it runs
        # whenever the method itself is executed.
        print("decoder_input:", decoder_input.shape)

        # define cell (the same cell object is used by both branches)
        decoder_cell = MultiRNNCell(
            [GRUCell(self.num_units) for _ in range(self.num_layers)])

        # get loss functions (sequence_loss is not used in this method)
        sequence_loss, total_loss = loss_computation(
            self.vocab_size, num_samples=self.num_samples)

        # decoder training process
        with tf.variable_scope('decoder'):
            # prepare attention
            attention_keys, attention_values, attention_score_fn, attention_construct_fn \
                = prepare_attention(encoder_output, 'bahdanau', self.num_units, scope_name="decoder",
                                    imem=(self.graph_embedding, self.triples_embedding),
                                    output_alignments=self.output_alignments)
            print("graph_embedding:", self.graph_embedding.shape)
            print("triples_embedding:", self.triples_embedding.shape)
            decoder_fn_train = attention_decoder_fn_train(
                encoder_state,
                attention_keys,
                attention_values,
                attention_score_fn,
                attention_construct_fn,
                output_alignments=self.output_alignments,
                max_length=tf.reduce_max(self.responses_length))
            # train decoder
            decoder_output, _, decoder_context_state = dynamic_rnn_decoder(
                decoder_cell,
                decoder_fn_train,
                decoder_input,
                self.responses_length,
                scope="decoder_rnn")
            # Two heads over the decoder output: output_fn and selector_fn.
            output_fn, selector_fn = output_projection(
                self.vocab_size, scope_name="decoder_rnn")
            output_logits = output_fn(decoder_output)
            selector_logits = selector_fn(decoder_output)
            print("decoder_output:",
                  decoder_output.shape)  # shape: [batch, seq, num_units]
            print("output_logits:", output_logits.shape)
            print("selector_fn:", selector_logits.name)

            # One-hot encode match_triples with depth equal to dimension 2 of
            # self.triples, then sum over axes 2 and 3 to obtain use_triples.
            triple_len = tf.shape(self.triples)[2]
            one_hot_triples = tf.one_hot(self.match_triples, triple_len)
            use_triples = tf.reduce_sum(one_hot_triples, axis=[2, 3])
            # Stack the decoder context state and swap its first two axes
            # (presumably time-major -> batch-major -- TODO confirm).
            alignments = tf.transpose(decoder_context_state.stack(),
                                      perm=[1, 0, 2, 3])
            self.decoder_loss, self.ppx_loss, self.sentence_ppx \
                = total_loss(output_logits,
                             selector_logits,
                             self.responses_target,
                             self.decoder_mask,
                             alignments,
                             use_triples,
                             one_hot_triples)
            # Re-wrap sentence_ppx in an identity op named "ppx_loss" (within
            # the current name scope). Note that this names sentence_ppx, not
            # self.ppx_loss.
            self.sentence_ppx = tf.identity(self.sentence_ppx, name="ppx_loss")

        # decoder inference process
        # The 'decoder' scope is re-entered with reuse=True so that variables
        # created by the training branch are reused rather than re-created.
        with tf.variable_scope('decoder', reuse=True):
            # prepare attention
            attention_keys, attention_values, attention_score_fn, attention_construct_fn \
                = prepare_attention(encoder_output, 'bahdanau', self.num_units, scope_name="decoder",
                                    imem=(self.graph_embedding, self.triples_embedding),
                                    output_alignments=self.output_alignments,
                                    reuse=True)
            # NOTE(review): training calls output_projection with
            # scope_name="decoder_rnn" while this call passes scope_name=None
            # with reuse=True; whether both resolve to the same variables
            # depends on output_projection -- confirm.
            output_fn, selector_fn = output_projection(self.vocab_size,
                                                       scope_name=None,
                                                       reuse=True)
            # self.max_length is read here but is not defined in this method.
            # triples_embedding is reshaped to
            # [encoder_batch_size, -1, 3 * num_trans_units] before being
            # passed as the second element of imem.
            decoder_fn_inference \
                = attention_decoder_fn_inference(output_fn, encoder_state,
                                                 attention_keys, attention_values,
                                                 attention_score_fn, attention_construct_fn,
                                                 self.word_embed, GO_ID, EOS_ID, self.max_length, self.vocab_size,
                                                 imem=(self.entities_word_embedding,
                                                       tf.reshape(self.triples_embedding,
                                                                  [encoder_batch_size, -1, 3 * self.num_trans_units])),
                                                 selector_fn=selector_fn)

            # get decoder output
            decoder_distribution, _, infer_context_state \
                = dynamic_rnn_decoder(decoder_cell, decoder_fn_inference, scope="decoder_rnn")

            # The number of decoded steps is dimension 1 of the distribution;
            # the ids for those steps are gathered from the context state and
            # transposed.
            output_len = tf.shape(decoder_distribution)[1]
            output_ids = tf.transpose(
                infer_context_state.gather(tf.range(output_len)))
            # Word branch: ids clipped into [0, vocab_size] and cast to int64
            # for the index2symbol lookup below.
            word_ids = tf.cast(
                tf.clip_by_value(output_ids, 0, self.vocab_size), tf.int64)
            # Entity branch: the negated ids, clipped into [0, vocab_size],
            # are offset by row * (dimension 1 of entities_word_embedding) and
            # flattened so they can index the flattened self.entities.
            # NOTE(review): presumably non-positive ids encode an entity
            # position for the row -- confirm against the inference decoder fn.
            entity_ids = tf.reshape(
                tf.clip_by_value(-output_ids, 0, self.vocab_size) + tf.reshape(
                    tf.range(encoder_batch_size) *
                    tf.shape(self.entities_word_embedding)[1], [-1, 1]), [-1])
            entities = tf.reshape(
                tf.gather(tf.reshape(self.entities, [-1]), entity_ids),
                [-1, output_len])
            words = self.index2symbol.lookup(word_ids)
            # Positive ids take the vocabulary word; all others take the
            # gathered entity.
            self.generation = tf.where(output_ids > 0, words, entities)
            # Wrapped in an identity op named 'generation' (within the current
            # name scope).
            self.generation = tf.identity(self.generation, name='generation')

    def set_vocabs(self, session, vocab, entity_vocab, relation_vocab):
        """Fill the four lookup tables and return the session.

        Args:
            session: object whose ``run`` executes each insert op.
            vocab: word symbols; paired with ids ``0 .. vocab_size - 1``.
            entity_vocab: entity symbols; they receive the first ids of the
                joint entity/relation table.
            relation_vocab: relation symbols; their ids follow the entities.

        Returns:
            The ``session`` that was passed in.
        """
        # word tables: symbol -> id, then id -> symbol
        word_ids = list(range(self.vocab_size))
        session.run(
            self.symbol2index.insert(
                constant_op.constant(vocab),
                constant_op.constant(word_ids, dtype=tf.int64)))
        session.run(
            self.index2symbol.insert(
                constant_op.constant(word_ids, dtype=tf.int64),
                constant_op.constant(vocab)))

        # knowledge tables: entities and relations share one id space
        kg_symbols = entity_vocab + relation_vocab
        kg_ids = list(range(len(entity_vocab) + len(relation_vocab)))
        session.run(
            self.entity2index.insert(
                constant_op.constant(kg_symbols),
                constant_op.constant(kg_ids, dtype=tf.int64)))
        session.run(
            self.index2entity.insert(
                constant_op.constant(kg_ids, dtype=tf.int64),
                constant_op.constant(kg_symbols)))
        return session

    def print_parameters(self):
        """Print one ``name: shape`` line for every entry of ``self.params``."""
        for param in self.params:
            param_name = param.name
            # get_shape().as_list() gives the shape as a plain Python list
            shape_list = param.get_shape().as_list()
            print('%s: %s' % (param_name, shape_list))

    def step_train(self, session, data, forward_only=False, summary=False):
        """Run one step on a batch and return whatever ``session.run`` returns.

        Args:
            session: object whose ``run(fetches, feed)`` executes the step.
            data: mapping holding the batch under the keys ``posts``,
                ``posts_length``, ``responses``, ``responses_length``,
                ``triples``, ``posts_triple``, ``responses_triple`` and
                ``match_triples``.
            forward_only: when true, only ``sentence_ppx`` is fetched; the
                loss and the update op are left out.
            summary: when true, ``merged_summary_op`` is appended to the
                fetches.

        Returns:
            The result of ``session.run`` for the selected fetches, in the
            order sentence_ppx[, decoder_loss, update][, merged_summary_op].
        """
        feed_pairs = (
            (self.posts, 'posts'),
            (self.posts_length, 'posts_length'),
            (self.responses, 'responses'),
            (self.responses_length, 'responses_length'),
            (self.triples, 'triples'),
            (self.posts_triple, 'posts_triple'),
            (self.responses_triple, 'responses_triple'),
            (self.match_triples, 'match_triples'),
        )
        input_feed = {placeholder: data[key] for placeholder, key in feed_pairs}

        output_feed = [self.sentence_ppx]
        if not forward_only:
            output_feed.extend([self.decoder_loss, self.update])
        if summary:
            output_feed.append(self.merged_summary_op)

        return session.run(output_feed, input_feed)