Example #1
    def __init__(self, data, args, embed):
        self.init_states = tf.placeholder(tf.float32, (None, args.ch_size),
                                          'ctx_inps')  # batch*ch_size
        self.posts = tf.placeholder(tf.int32, (None, None),
                                    'enc_inps')  # batch*len
        self.posts_length = tf.placeholder(tf.int32, (None, ),
                                           'enc_lens')  # batch
        self.prev_posts = tf.placeholder(tf.int32, (None, None),
                                         'enc_prev_inps')
        self.prev_posts_length = tf.placeholder(tf.int32, (None, ),
                                                'enc_prev_lens')

        self.kgs = tf.placeholder(tf.int32, (None, None, None),
                                  'kg_inps')  # batch*kg_nums*kg_len
        self.kgs_h_length = tf.placeholder(tf.int32, (None, None),
                                           'kg_h_lens')  # batch*kg_nums
        self.kgs_hr_length = tf.placeholder(tf.int32, (None, None),
                                            'kg_hr_lens')  # batch*kg_nums
        self.kgs_hrt_length = tf.placeholder(tf.int32, (None, None),
                                             'kg_hrt_lens')  # batch*kg_nums
        self.kgs_index = tf.placeholder(tf.float32, (None, None),
                                        'kg_indices')  # batch*kg_nums

        self.origin_responses = tf.placeholder(tf.int32, (None, None),
                                               'dec_inps')  # batch*len
        self.origin_responses_length = tf.placeholder(tf.int32, (None, ),
                                                      'dec_lens')  # batch
        self.context_length = tf.placeholder(tf.int32, (None, ), 'ctx_lens')
        self.is_train = tf.placeholder(tf.bool)

        num_past_turns = tf.shape(self.posts)[0] // tf.shape(
            self.origin_responses)[0]

        # deal with original data to adapt encoder and decoder
        batch_size, decoder_len = tf.shape(self.origin_responses)[0], tf.shape(
            self.origin_responses)[1]
        self.responses = tf.split(self.origin_responses, [1, decoder_len - 1],
                                  1)[1]  # no go_id
        self.responses_length = self.origin_responses_length - 1
        self.responses_input = tf.split(self.origin_responses,
                                        [decoder_len - 1, 1],
                                        1)[0]  # no eos_id
        self.responses_target = self.responses
        decoder_len = decoder_len - 1
        self.posts_input = self.posts  # batch*len
        self.decoder_mask = tf.reshape(
            tf.cumsum(tf.one_hot(self.responses_length - 1, decoder_len),
                      reverse=True,
                      axis=1), [-1, decoder_len])
        kg_len = tf.shape(self.kgs)[2]
        kg_h_mask = tf.reshape(
            tf.cumsum(tf.one_hot(self.kgs_h_length - 1, kg_len),
                      reverse=True,
                      axis=2), [batch_size, -1, kg_len, 1])
        kg_hr_mask = tf.reshape(
            tf.cumsum(tf.one_hot(self.kgs_hr_length - 1, kg_len),
                      reverse=True,
                      axis=2), [batch_size, -1, kg_len, 1])
        kg_hrt_mask = tf.reshape(
            tf.cumsum(tf.one_hot(self.kgs_hrt_length - 1, kg_len),
                      reverse=True,
                      axis=2), [batch_size, -1, kg_len, 1])
        kg_key_mask = kg_hr_mask
        kg_value_mask = kg_hrt_mask - kg_hr_mask

        # initialize the training process
        self.learning_rate = tf.Variable(float(args.lr),
                                         trainable=False,
                                         dtype=tf.float32)
        self.learning_rate_decay_op = self.learning_rate.assign(
            self.learning_rate * args.lr_decay)
        self.global_step = tf.Variable(0, trainable=False)

        # build the embedding table and embedding input
        if embed is None:
            # initialize the embedding randomly
            self.embed = tf.get_variable(
                'embed', [data.vocab_size, args.embedding_size], tf.float32)
        else:
            # initialize the embedding by pre-trained word vectors
            self.embed = tf.get_variable('embed',
                                         dtype=tf.float32,
                                         initializer=embed)

        self.encoder_input = tf.nn.embedding_lookup(self.embed, self.posts)
        self.decoder_input = tf.nn.embedding_lookup(self.embed,
                                                    self.responses_input)
        self.kg_input = tf.nn.embedding_lookup(self.embed, self.kgs)
        #self.knowledge_max = tf.reduce_max(tf.where(tf.cast(tf.tile(knowledge_mask, [1, 1, args.embedding_size]), tf.bool), self.knowledge_input, -mask_value), axis=1)
        #self.knowledge_min = tf.reduce_max(tf.where(tf.cast(tf.tile(knowledge_mask, [1, 1, args.embedding_size]), tf.bool), self.knowledge_input, mask_value), axis=1)
        self.kg_key_avg = tf.reduce_sum(
            self.kg_input * kg_key_mask, axis=2) / tf.maximum(
                tf.reduce_sum(kg_key_mask, axis=2),
                tf.ones_like(tf.expand_dims(self.kgs_hrt_length, -1),
                             dtype=tf.float32))
        self.kg_value_avg = tf.reduce_sum(
            self.kg_input * kg_value_mask, axis=2) / tf.maximum(
                tf.reduce_sum(kg_value_mask, axis=2),
                tf.ones_like(tf.expand_dims(self.kgs_hrt_length, -1),
                             dtype=tf.float32))

        #self.encoder_input = tf.cond(self.is_train,
        #							 lambda: tf.nn.dropout(tf.nn.embedding_lookup(self.embed, self.posts_input), 0.8),
        #							 lambda: tf.nn.embedding_lookup(self.embed, self.posts_input))  # batch*len*unit
        #self.decoder_input = tf.cond(self.is_train,
        #							 lambda: tf.nn.dropout(tf.nn.embedding_lookup(self.embed, self.responses_input), 0.8),
        #							 lambda: tf.nn.embedding_lookup(self.embed, self.responses_input))

        # build rnn_cell
        cell_enc = tf.nn.rnn_cell.GRUCell(args.eh_size)
        cell_ctx = tf.nn.rnn_cell.GRUCell(args.ch_size)
        cell_dec = tf.nn.rnn_cell.GRUCell(args.dh_size)

        # build encoder
        with tf.variable_scope('encoder'):
            encoder_output, encoder_state = dynamic_rnn(cell_enc,
                                                        self.encoder_input,
                                                        self.posts_length,
                                                        dtype=tf.float32,
                                                        scope="encoder_rnn")

        with tf.variable_scope('encoder', reuse=tf.AUTO_REUSE):
            prev_output, _ = dynamic_rnn(cell_enc,
                                         tf.nn.embedding_lookup(
                                             self.embed, self.prev_posts),
                                         self.prev_posts_length,
                                         dtype=tf.float32,
                                         scope="encoder_rnn")

        with tf.variable_scope('context'):
            encoder_state_reshape = tf.reshape(
                encoder_state, [-1, num_past_turns, args.eh_size])
            _, self.context_state = dynamic_rnn(cell_ctx,
                                                encoder_state_reshape,
                                                self.context_length,
                                                dtype=tf.float32,
                                                scope='context_rnn')

        # get output projection function
        output_fn = MyDense(data.vocab_size, use_bias=True)
        sampled_sequence_loss = output_projection_layer(
            args.dh_size, data.vocab_size, args.softmax_samples)

        # construct attention
        '''
		encoder_len = tf.shape(encoder_output)[1]
		attention_memory = tf.reshape(encoder_output, [batch_size, -1, args.eh_size])
		attention_mask = tf.reshape(tf.sequence_mask(self.posts_length, encoder_len), [batch_size, -1])
		attention_mask = tf.concat([tf.ones([batch_size, 1], tf.bool), attention_mask[:,1:]], axis=1)
		attn_mechanism = MyAttention(args.dh_size, attention_memory, attention_mask)
		'''
        attn_mechanism = tf.contrib.seq2seq.BahdanauAttention(
            args.dh_size,
            prev_output,
            memory_sequence_length=tf.maximum(self.prev_posts_length, 1))
        cell_dec_attn = tf.contrib.seq2seq.AttentionWrapper(
            cell_dec, attn_mechanism, attention_layer_size=args.dh_size)
        ctx_state_shaping = tf.layers.dense(self.context_state,
                                            args.dh_size,
                                            activation=None)
        dec_start = cell_dec_attn.zero_state(
            batch_size, dtype=tf.float32).clone(cell_state=ctx_state_shaping)

        # calculate kg embedding
        with tf.variable_scope('knowledge'):
            query = tf.reshape(
                tf.layers.dense(tf.concat(self.context_state, axis=-1),
                                args.embedding_size,
                                use_bias=False),
                [batch_size, 1, args.embedding_size])
        kg_score = tf.reduce_sum(query * self.kg_key_avg, axis=2)
        kg_score = tf.where(tf.greater(self.kgs_hrt_length, 0), kg_score,
                            -tf.ones_like(kg_score) * np.inf)
        kg_alignment = tf.nn.softmax(kg_score)
        kg_max = tf.argmax(kg_alignment, axis=-1)
        kg_max_onehot = tf.one_hot(kg_max,
                                   tf.shape(kg_alignment)[1],
                                   dtype=tf.float32)
        self.kg_acc = tf.reduce_sum(
            kg_max_onehot * self.kgs_index) / tf.maximum(
                tf.reduce_sum(tf.reduce_max(self.kgs_index, axis=-1)),
                tf.constant(1.0))
        self.kg_loss = tf.reduce_sum(
            -tf.log(tf.clip_by_value(kg_alignment, 1e-12, 1.0)) *
            self.kgs_index,
            axis=1) / tf.maximum(tf.reduce_sum(self.kgs_index, axis=1),
                                 tf.ones([batch_size], dtype=tf.float32))
        self.kg_loss = tf.reduce_mean(self.kg_loss)

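        # NOTE: kg_num_mask is not defined in this snippet; it is presumably a
        # [batch, max_kg_nums, 1] mask over non-empty triples defined elsewhere in
        # the original source (Example #6 computes the same sum without it).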
        self.knowledge_embed = tf.reduce_sum(
            tf.expand_dims(kg_alignment, axis=-1) * self.kg_value_avg *
            tf.cast(kg_num_mask, tf.float32),
            axis=1)
        #self.knowledge_embed = tf.Print(self.knowledge_embed, ['acc', self.kg_acc, 'loss', self.kg_loss])
        knowledge_embed_extend = tf.tile(
            tf.expand_dims(self.knowledge_embed, axis=1), [1, decoder_len, 1])
        self.decoder_input = tf.concat(
            [self.decoder_input, knowledge_embed_extend], axis=2)
        # construct helper
        train_helper = tf.contrib.seq2seq.TrainingHelper(
            self.decoder_input, tf.maximum(self.responses_length, 1))
        infer_helper = MyInferenceHelper(self.embed,
                                         tf.fill([batch_size], data.go_id),
                                         data.eos_id, self.knowledge_embed)
        #infer_helper = tf.contrib.seq2seq.GreedyEmbeddingHelper(self.embed, tf.fill([batch_size], data.go_id), data.eos_id)

        # build decoder (train)
        with tf.variable_scope('decoder'):
            decoder_train = tf.contrib.seq2seq.BasicDecoder(
                cell_dec_attn, train_helper, dec_start)
            train_outputs, _, _ = tf.contrib.seq2seq.dynamic_decode(
                decoder_train, impute_finished=True, scope="decoder_rnn")
            self.decoder_output = train_outputs.rnn_output
            #self.decoder_output = tf.nn.dropout(self.decoder_output, 0.8)
            self.decoder_distribution_teacher, self.decoder_loss, self.decoder_all_loss = \
             sampled_sequence_loss(self.decoder_output, self.responses_target, self.decoder_mask)

        # build decoder (test)
        with tf.variable_scope('decoder', reuse=True):
            decoder_infer = tf.contrib.seq2seq.BasicDecoder(
                cell_dec_attn, infer_helper, dec_start, output_layer=output_fn)
            infer_outputs, _, _ = tf.contrib.seq2seq.dynamic_decode(
                decoder_infer,
                impute_finished=True,
                maximum_iterations=args.max_sent_length,
                scope="decoder_rnn")
            self.decoder_distribution = infer_outputs.rnn_output
            self.generation_index = tf.argmax(
                tf.split(self.decoder_distribution, [2, data.vocab_size - 2],
                         2)[1], 2) + 2  # for removing UNK

        # calculate the gradient of parameters and update
        self.params = [
            k for k in tf.trainable_variables() if args.name in k.name
        ]
        opt = tf.train.AdamOptimizer(self.learning_rate)
        self.loss = self.decoder_loss + self.kg_loss
        gradients = tf.gradients(self.loss, self.params)
        clipped_gradients, self.gradient_norm = tf.clip_by_global_norm(
            gradients, args.grad_clip)
        self.update = opt.apply_gradients(zip(clipped_gradients, self.params),
                                          global_step=self.global_step)

        # save checkpoint
        self.latest_saver = tf.train.Saver(
            write_version=tf.train.SaverDef.V2,
            max_to_keep=args.checkpoint_max_to_keep,
            pad_step_number=True,
            keep_checkpoint_every_n_hours=1.0)
        self.best_saver = tf.train.Saver(write_version=tf.train.SaverDef.V2,
                                         max_to_keep=1,
                                         pad_step_number=True,
                                         keep_checkpoint_every_n_hours=1.0)

        # create summary for tensorboard
        self.create_summary(args)
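
The masks above (decoder_mask, kg_h_mask, kg_hr_mask, kg_hrt_mask) are built with a reversed cumulative sum over a one-hot vector of the last valid position. A minimal sketch, assuming TensorFlow 1.x, showing that this trick produces the same 0/1 mask as tf.sequence_mask:

import tensorflow as tf

lengths = tf.constant([2, 4])
max_len = 4
# reversed cumsum of one_hot(length - 1) places 1s at positions 0..length-1
mask_a = tf.cumsum(tf.one_hot(lengths - 1, max_len), reverse=True, axis=1)
mask_b = tf.cast(tf.sequence_mask(lengths, max_len), tf.float32)
with tf.Session() as sess:
    print(sess.run(mask_a))  # [[1. 1. 0. 0.] [1. 1. 1. 1.]]
    print(sess.run(mask_b))  # identical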
Example #2
    def __init__(self, data, args, embed):

        self.posts = tf.placeholder(tf.int32, (None, None),
                                    'enc_inps')  # batch*len
        self.posts_length = tf.placeholder(tf.int32, (None, ),
                                           'enc_lens')  # batch
        self.origin_responses = tf.placeholder(tf.int32, (None, None),
                                               'dec_inps')  # batch*len
        self.origin_responses_length = tf.placeholder(tf.int32, (None, ),
                                                      'dec_lens')  # batch
        self.is_train = tf.placeholder(tf.bool)

        # deal with original data to adapt encoder and decoder
        batch_size, decoder_len = tf.shape(self.origin_responses)[0], tf.shape(
            self.origin_responses)[1]
        self.responses_input = tf.split(self.origin_responses,
                                        [decoder_len - 1, 1], 1)[0]
        self.responses_target = tf.split(self.origin_responses,
                                         [1, decoder_len - 1], 1)[1]
        self.responses_length = self.origin_responses_length - 1
        decoder_len = decoder_len - 1

        self.decoder_mask = tf.reshape(
            tf.cumsum(tf.one_hot(self.responses_length - 1, decoder_len),
                      reverse=True,
                      axis=1), [-1, decoder_len])

        # initialize the training process
        self.learning_rate = tf.Variable(float(args.lr),
                                         trainable=False,
                                         dtype=tf.float32)
        self.learning_rate_decay_op = self.learning_rate.assign(
            self.learning_rate * args.lr_decay)
        self.global_step = tf.Variable(0, trainable=False)

        # build the embedding table and embedding input
        if embed is None:
            # initialize the embedding randomly
            self.embed = tf.get_variable(
                'embed', [data.vocab_size, args.embedding_size], tf.float32)
        else:
            # initialize the embedding by pre-trained word vectors
            self.embed = tf.get_variable('embed',
                                         dtype=tf.float32,
                                         initializer=embed)

        self.encoder_input = tf.nn.embedding_lookup(self.embed, self.posts)
        self.decoder_input = tf.nn.embedding_lookup(self.embed,
                                                    self.responses_input)
        #self.decoder_input = tf.cond(self.is_train,
        #							 lambda: tf.nn.dropout(tf.nn.embedding_lookup(self.embed, self.responses_input), 0.8),
        #							 lambda: tf.nn.embedding_lookup(self.embed, self.responses_input))

        # build rnn_cell
        cell = tf.nn.rnn_cell.GRUCell(args.eh_size)

        # get output projection function
        output_fn = MyDense(data.vocab_size, use_bias=True)
        sampled_sequence_loss = output_projection_layer(
            args.dh_size, data.vocab_size, args.softmax_samples)

        # build encoder
        with tf.variable_scope('decoder', reuse=tf.AUTO_REUSE):
            _, self.encoder_state = dynamic_rnn(cell,
                                                self.encoder_input,
                                                self.posts_length,
                                                dtype=tf.float32,
                                                scope="decoder_rnn")

        # construct helper and attention
        infer_helper = tf.contrib.seq2seq.GreedyEmbeddingHelper(
            self.embed, tf.fill([batch_size], data.eos_id), data.eos_id)

        dec_start = tf.cond(
            self.is_train,
            lambda: tf.zeros([batch_size, args.dh_size], dtype=tf.float32),
            lambda: self.encoder_state)

        # build decoder (train)
        with tf.variable_scope('decoder', reuse=tf.AUTO_REUSE):
            self.decoder_output, _ = dynamic_rnn(
                cell,
                self.decoder_input,
                self.responses_length,
                dtype=tf.float32,
                initial_state=self.encoder_state,
                scope='decoder_rnn')
            #self.decoder_output = tf.nn.dropout(self.decoder_output, 0.8)
            self.decoder_distribution_teacher, self.decoder_loss, self.decoder_all_loss = \
             sampled_sequence_loss(self.decoder_output, self.responses_target, self.decoder_mask)

        # build decoder (test)
        with tf.variable_scope('decoder', reuse=tf.AUTO_REUSE):
            decoder_infer = tf.contrib.seq2seq.BasicDecoder(
                cell, infer_helper, dec_start, output_layer=output_fn)
            infer_outputs, _, _ = tf.contrib.seq2seq.dynamic_decode(
                decoder_infer,
                impute_finished=True,
                maximum_iterations=args.max_sent_length,
                scope="decoder_rnn")
            self.decoder_distribution = infer_outputs.rnn_output
            self.generation_index = tf.argmax(
                tf.split(self.decoder_distribution, [2, data.vocab_size - 2],
                         2)[1], 2) + 2  # for removing UNK

        # calculate the gradient of parameters and update
        self.params = [
            k for k in tf.trainable_variables() if args.name in k.name
        ]
        opt = tf.train.AdamOptimizer(self.learning_rate)
        gradients = tf.gradients(self.decoder_loss, self.params)
        clipped_gradients, self.gradient_norm = tf.clip_by_global_norm(
            gradients, args.grad_clip)
        self.update = opt.apply_gradients(zip(clipped_gradients, self.params),
                                          global_step=self.global_step)

        # save checkpoint
        self.latest_saver = tf.train.Saver(
            write_version=tf.train.SaverDef.V2,
            max_to_keep=args.checkpoint_max_to_keep,
            pad_step_number=True,
            keep_checkpoint_every_n_hours=1.0)
        self.best_saver = tf.train.Saver(write_version=tf.train.SaverDef.V2,
                                         max_to_keep=1,
                                         pad_step_number=True,
                                         keep_checkpoint_every_n_hours=1.0)

        # create summary for tensorboard
        self.create_summary(args)
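
Both Example #1 and Example #2 compute generation_index by splitting off the first two vocabulary columns before the argmax and then adding 2 back, so <pad> and <unk> can never be emitted (the vocabulary order ["<pad>", "<unk>", "<go>", "<eos>", ...] is taken from the comments in Example #3 below). A small NumPy sketch of the same index arithmetic:

import numpy as np

logits = np.array([[0.9, 0.8, 0.1, 0.3, 0.7, 0.2]])  # <pad>/<unk> happen to score highest
best = np.argmax(logits[:, 2:], axis=-1) + 2          # drop columns 0-1, then restore the offset
print(best)  # [4] -> best token excluding <pad> and <unk>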
Example #3
    def __init__(self, data, args, embed):
        # define the input placeholders
        self.posts = tf.placeholder(tf.int32, (None, None),
                                    'enc_inps')  # batch*len
        self.posts_length = tf.placeholder(tf.int32, (None, ),
                                           'enc_lens')  # batch
        self.origin_responses = tf.placeholder(tf.int32, (None, None),
                                               'dec_inps')  # batch*len
        self.origin_responses_length = tf.placeholder(tf.int32, (None, ),
                                                      'dec_lens')  # batch
        self.is_train = tf.placeholder(tf.bool)

        # deal with original data to adapt encoder and decoder
        # record batch_size and the response length
        batch_size, decoder_len = tf.shape(self.origin_responses)[0], tf.shape(
            self.origin_responses)[1]
        # decoder input and target both have length decoder_len-1:
        # the input drops the last token, the target drops the first token
        self.responses_input = tf.split(self.origin_responses,
                                        [decoder_len - 1, 1], 1)[0]
        self.responses_target = tf.split(self.origin_responses,
                                         [1, decoder_len - 1], 1)[1]
        self.responses_length = self.origin_responses_length - 1
        decoder_len = decoder_len - 1

        # decoder mask matrix, [batch_size, decoder_len]
        self.decoder_mask = tf.reshape(
            tf.cumsum(tf.one_hot(self.responses_length - 1, decoder_len),
                      reverse=True,
                      axis=1), [-1, decoder_len])

        # initialize the training process
        # initialize training parameters such as the learning rate
        self.learning_rate = tf.Variable(float(args.lr),
                                         trainable=False,
                                         dtype=tf.float32)
        self.learning_rate_decay_op = self.learning_rate.assign(
            self.learning_rate * args.lr_decay)
        self.global_step = tf.Variable(0, trainable=False)

        # build the embedding table and embedding input
        # define the word embedding matrix
        if embed is None:
            # initialize the embedding randomly
            self.embed = tf.get_variable(
                'embed', [data.vocab_size, args.embedding_size], tf.float32)
        else:
            # initialize the embedding by pre-trained word vectors
            self.embed = tf.get_variable('embed',
                                         dtype=tf.float32,
                                         initializer=embed)

        self.encoder_input = tf.nn.embedding_lookup(self.embed, self.posts)
        self.decoder_input = tf.nn.embedding_lookup(self.embed,
                                                    self.responses_input)
        #self.decoder_input = tf.cond(self.is_train,
        #							 lambda: tf.nn.dropout(tf.nn.embedding_lookup(self.embed, self.responses_input), 0.8),
        #							 lambda: tf.nn.embedding_lookup(self.embed, self.responses_input))

        # build rnn_cell
        # define the GRU cell; eh_size is the number of hidden units
        cell = tf.nn.rnn_cell.GRUCell(args.eh_size)

        # get output projection function
        # project onto the vocabulary distribution
        output_fn = MyDense(data.vocab_size, use_bias=True)
        # dh_size is the decoder hidden size
        # vocab_size is the vocabulary size
        sampled_sequence_loss = output_projection_layer(
            args.dh_size, data.vocab_size, args.softmax_samples)

        # build encoder
        # mainly used at test time: encode the post into a hidden state
        # during training this state is replaced by an all-zero matrix
        with tf.variable_scope('decoder', reuse=tf.AUTO_REUSE):
            _, self.encoder_state = tf.nn.dynamic_rnn(cell,
                                                      self.encoder_input,
                                                      self.posts_length,
                                                      dtype=tf.float32,
                                                      scope="decoder_rnn")

        # construct helper and attention
        # eos_id is used as the start token because, by the data packing
        # convention, the decoder's first input token is <eos>
        infer_helper = tf.contrib.seq2seq.GreedyEmbeddingHelper(
            self.embed, tf.fill([batch_size], data.eos_id), data.eos_id)

        dec_start = tf.cond(
            self.is_train,
            lambda: tf.zeros([batch_size, args.dh_size], dtype=tf.float32),
            lambda: self.encoder_state)

        # build decoder (train)
        with tf.variable_scope('decoder', reuse=tf.AUTO_REUSE):
            # the output is [batch_size, decoder_length, hidden_size]
            self.decoder_output, _ = tf.nn.dynamic_rnn(cell,
                                                       self.decoder_input,
                                                       self.responses_length,
                                                       dtype=tf.float32,
                                                       initial_state=dec_start,
                                                       scope='decoder_rnn')
            #self.decoder_output = tf.nn.dropout(self.decoder_output, 0.8)
            # decoder_output: [batch_size, decoder_length, hidden_size]
            # responses_target: [batch_size, decoder_length]
            # decoder_mask: [batch_size, decoder_length]
            # decoder_distribution_teacher: [batch_size, decoder_len, vocab_size]
            # decoder_loss: scalar average loss
            # decoder_all_loss: per-sentence loss, [batch_size,]
            self.decoder_distribution_teacher, self.decoder_loss, self.decoder_all_loss = \
             sampled_sequence_loss(self.decoder_output, self.responses_target, self.decoder_mask)

        # build decoder (test)
        with tf.variable_scope('decoder', reuse=tf.AUTO_REUSE):
            decoder_infer = tf.contrib.seq2seq.BasicDecoder(
                cell, infer_helper, dec_start, output_layer=output_fn)
            infer_outputs, _, _ = tf.contrib.seq2seq.dynamic_decode(
                decoder_infer,
                impute_finished=True,
                maximum_iterations=args.max_decoder_length,
                scope="decoder_rnn")
            # [batch_size, max_decoder_length, vocab_size]
            self.decoder_distribution = infer_outputs.rnn_output
            # indices of the generated tokens, [batch_size, max_decoder_length]
            # the first four vocabulary entries are ["<pad>", "<unk>", "<go>", "<eos>"],
            # so <pad> and <unk> are excluded before taking the argmax
            self.generation_index = tf.argmax(
                tf.split(self.decoder_distribution, [2, data.vocab_size - 2],
                         2)[1], 2) + 2  # for removing UNK

        # calculate the gradient of parameters and update
        self.params = [
            k for k in tf.trainable_variables() if args.name in k.name
        ]
        opt = tf.train.AdamOptimizer(self.learning_rate)
        gradients = tf.gradients(self.decoder_loss, self.params)
        clipped_gradients, self.gradient_norm = tf.clip_by_global_norm(
            gradients, args.grad_clip)
        self.update = opt.apply_gradients(zip(clipped_gradients, self.params),
                                          global_step=self.global_step)

        # save checkpoint
        self.latest_saver = tf.train.Saver(
            write_version=tf.train.SaverDef.V2,
            max_to_keep=args.checkpoint_max_to_keep,
            pad_step_number=True,
            keep_checkpoint_every_n_hours=1.0)
        self.best_saver = tf.train.Saver(write_version=tf.train.SaverDef.V2,
                                         max_to_keep=1,
                                         pad_step_number=True,
                                         keep_checkpoint_every_n_hours=1.0)

        # create summary for tensorboard
        # log the parameters
        self.create_summary(args)
Example #4
    def __init__(self, data, args, embed):
        # posts: the dialogue history fed to the encoder, [batch, encoder_len]
        # posts_length: the actual length of each input utterance, [batch]
        # prevs_length: the length of the turns before the last one (including <go> and <eos>), [batch]
        self.posts = tf.placeholder(tf.int32, (None, None),
                                    'enc_inps')  # batch*len
        self.posts_length = tf.placeholder(tf.int32, (None, ),
                                           'enc_lens')  # batch
        self.prevs_length = tf.placeholder(tf.int32, (None, ),
                                           'enc_lens_prev')  # batch

        # origin_responses: the response tokens, [batch, resp_len]
        # origin_responses_length: the actual length of each response, [batch, ]
        self.origin_responses = tf.placeholder(tf.int32, (None, None),
                                               'dec_inps')  # batch*len
        self.origin_responses_length = tf.placeholder(tf.int32, (None, ),
                                                      'dec_lens')  # batch
        self.is_train = tf.placeholder(tf.bool)

        # deal with original data to adapt encoder and decoder
        batch_size, decoder_len = tf.shape(self.origin_responses)[0], tf.shape(
            self.origin_responses)[1]
        # split the response, dropping the leading go_id
        self.responses = tf.split(self.origin_responses, [1, decoder_len - 1],
                                  1)[1]  # no go_id
        self.responses_length = self.origin_responses_length - 1
        # decoder input and target: the input drops the trailing eos_id and the
        # target drops the leading go_id, so the two stay aligned
        # [batch, decoder_len] (decoder_len here equals resp_len-1)
        self.responses_input = tf.split(self.origin_responses,
                                        [decoder_len - 1, 1],
                                        1)[0]  # no eos_id
        self.responses_target = self.responses
        decoder_len = decoder_len - 1
        # encoder input, [batch, encoder_len]
        self.posts_input = self.posts  # batch*len
        # compute the decoder mask
        # shape [batch, decoder_len]
        self.decoder_mask = tf.reshape(
            tf.cumsum(tf.one_hot(self.responses_length - 1, decoder_len),
                      reverse=True,
                      axis=1), [-1, decoder_len])

        # initialize the training process
        self.learning_rate = tf.Variable(float(args.lr),
                                         trainable=False,
                                         dtype=tf.float32)
        self.learning_rate_decay_op = self.learning_rate.assign(
            self.learning_rate * args.lr_decay)
        self.global_step = tf.Variable(0, trainable=False)

        # build the embedding table and embedding input
        if embed is None:
            # initialize the embedding randomly
            self.embed = tf.get_variable(
                'embed', [data.vocab_size, args.embedding_size], tf.float32)
        else:
            # initialize the embedding by pre-trained word vectors
            self.embed = tf.get_variable('embed',
                                         dtype=tf.float32,
                                         initializer=embed)

        # look up word embeddings for the encoder and decoder inputs
        # encoder_input: [batch, encoder_len, embed_size]
        # decoder_input: [batch, decoder_len, embed_size]
        self.encoder_input = tf.nn.embedding_lookup(self.embed, self.posts)
        self.decoder_input = tf.nn.embedding_lookup(self.embed,
                                                    self.responses_input)
        #self.encoder_input = tf.cond(self.is_train,
        #							 lambda: tf.nn.dropout(tf.nn.embedding_lookup(self.embed, self.posts_input), 0.8),
        #							 lambda: tf.nn.embedding_lookup(self.embed, self.posts_input)) #batch*len*unit
        #self.decoder_input = tf.cond(self.is_train,
        #							 lambda: tf.nn.dropout(tf.nn.embedding_lookup(self.embed, self.responses_input), 0.8),
        #							 lambda: tf.nn.embedding_lookup(self.embed, self.responses_input))

        # build rnn_cell
        cell_enc = tf.nn.rnn_cell.GRUCell(args.eh_size)
        cell_dec = tf.nn.rnn_cell.GRUCell(args.dh_size)

        # build encoder
        with tf.variable_scope('encoder'):
            # encoder_output: [batch, encoder_len, eh_size]
            # encoder_state: [batch, eh_size]
            encoder_output, encoder_state = tf.nn.dynamic_rnn(
                cell_enc,
                self.encoder_input,
                self.posts_length,
                dtype=tf.float32,
                scope="encoder_rnn")

        # get output projection function
        output_fn = MyDense(data.vocab_size, use_bias=True)
        sampled_sequence_loss = output_projection_layer(
            args.dh_size, data.vocab_size, args.softmax_samples)

        encoder_len = tf.shape(encoder_output)[1]
        # compute prefix masks for posts and prevs
        posts_mask = tf.sequence_mask(self.posts_length, encoder_len)
        prevs_mask = tf.sequence_mask(self.prevs_length, encoder_len)
        # XOR: 1 where the two masks differ, 0 where they agree
        # i.e., attend only to the last turn, [batch, encoder_len]
        attention_mask = tf.reshape(tf.logical_xor(posts_mask, prevs_mask),
                                    [batch_size, encoder_len])

        # construct helper and attention
        train_helper = tf.contrib.seq2seq.TrainingHelper(
            self.decoder_input, self.responses_length)
        # at inference time the start positions are all filled with go_id,
        # as defined when the input data is packed
        infer_helper = tf.contrib.seq2seq.GreedyEmbeddingHelper(
            self.embed, tf.fill([batch_size], data.go_id), data.eos_id)

        # the encoder encodes all dialogue turns,
        # but the decoder attends only to the last turn;
        # define that attention here
        attn_mechanism = MyAttention(args.dh_size, encoder_output,
                                     attention_mask)
        cell_dec_attn = tf.contrib.seq2seq.AttentionWrapper(
            cell_dec, attn_mechanism, attention_layer_size=args.dh_size)
        # project the encoder's final hidden state to the decoder hidden size
        # [batch, dh_size]
        enc_state_shaping = tf.layers.dense(encoder_state,
                                            args.dh_size,
                                            activation=None)
        dec_start = cell_dec_attn.zero_state(
            batch_size, dtype=tf.float32).clone(cell_state=enc_state_shaping)

        # build decoder (train)
        with tf.variable_scope('decoder'):
            decoder_train = tf.contrib.seq2seq.BasicDecoder(
                cell_dec_attn, train_helper, dec_start)
            train_outputs, _, _ = tf.contrib.seq2seq.dynamic_decode(
                decoder_train, impute_finished=True, scope="decoder_rnn")
            self.decoder_output = train_outputs.rnn_output
            #self.decoder_output = tf.nn.dropout(self.decoder_output, 0.8)
            # compute the loss and the output distribution
            # decoder_distribution_teacher: [batch, decoder_length, vocab_size] (log probabilities)
            # decoder_loss: scalar loss over all tokens in the batch
            # decoder_all_loss: per-sentence loss, [batch, ]
            self.decoder_distribution_teacher, self.decoder_loss, self.decoder_all_loss = \
             sampled_sequence_loss(self.decoder_output, self.responses_target, self.decoder_mask)

        # build decoder (test)
        with tf.variable_scope('decoder', reuse=True):
            # output_fn reuses the weights and biases defined above
            decoder_infer = tf.contrib.seq2seq.BasicDecoder(
                cell_dec_attn, infer_helper, dec_start, output_layer=output_fn)
            infer_outputs, _, _ = tf.contrib.seq2seq.dynamic_decode(
                decoder_infer,
                impute_finished=True,
                maximum_iterations=args.max_decoder_length,
                scope="decoder_rnn")
            # [batch, max_decoder_len, vocab_size]
            self.decoder_distribution = infer_outputs.rnn_output
            # take the argmax while excluding the first two tokens, <pad> and <unk>
            self.generation_index = tf.argmax(
                tf.split(self.decoder_distribution, [2, data.vocab_size - 2],
                         2)[1], 2) + 2  # for removing UNK

        # calculate the gradient of parameters and update
        self.params = [
            k for k in tf.trainable_variables() if args.name in k.name
        ]
        opt = tf.train.AdamOptimizer(self.learning_rate)  # define the optimizer
        gradients = tf.gradients(self.decoder_loss, self.params)  # compute parameter gradients
        clipped_gradients, self.gradient_norm = tf.clip_by_global_norm(
            gradients, args.grad_clip)  # clip gradients by global norm
        self.update = opt.apply_gradients(
            zip(clipped_gradients,
                self.params), global_step=self.global_step)  # apply the parameter updates

        # save checkpoint
        self.latest_saver = tf.train.Saver(
            write_version=tf.train.SaverDef.V2,
            max_to_keep=args.checkpoint_max_to_keep,
            pad_step_number=True,
            keep_checkpoint_every_n_hours=1.0)
        self.best_saver = tf.train.Saver(write_version=tf.train.SaverDef.V2,
                                         max_to_keep=1,
                                         pad_step_number=True,
                                         keep_checkpoint_every_n_hours=1.0)

        # create summary for tensorboard
        self.create_summary(args)
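
Example #4 builds its attention mask as the XOR of two prefix masks, which keeps exactly the token positions of the last turn (covered by posts_length but not by prevs_length). A toy NumPy sketch of that selection, with made-up lengths:

import numpy as np

max_len, posts_len, prevs_len = 8, 7, 4
posts_mask = np.arange(max_len) < posts_len    # [ T T T T T T T F ]
prevs_mask = np.arange(max_len) < prevs_len    # [ T T T T F F F F ]
print(np.logical_xor(posts_mask, prevs_mask))  # [ F F F F T T T F ] -> only last-turn tokens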
Example #5
    def __init__(self, data, args, embed):
        #self.init_states = tf.placeholder(tf.float32, (None, args.ch_size), 'ctx_inps')  # batch*ch_size
        # posts: [batch*(num_turns-1), max_post_length]
        # posts_length: [batch*(num_turns-1),]
        self.posts = tf.placeholder(tf.int32, (None, None),
                                    'enc_inps')  # batch * num_turns-1 * len
        self.posts_length = tf.placeholder(tf.int32, (None, ),
                                           'enc_lens')  # batch * num_turns-1

        # prev_posts: [batch, max_prev_length], i.e., the last turn of the posts above
        # prev_posts_length: [batch], the actual length of each last-turn utterance
        self.prev_posts = tf.placeholder(tf.int32, (None, None),
                                         'enc_prev_inps')
        self.prev_posts_length = tf.placeholder(tf.int32, (None, ),
                                                'enc_prev_lens')

        # origin_responses: [batch, max_response_length]
        # origin_responses_length: [batch]
        self.origin_responses = tf.placeholder(tf.int32, (None, None),
                                               'dec_inps')  # batch*len
        self.origin_responses_length = tf.placeholder(tf.int32, (None, ),
                                                      'dec_lens')  # batch
        # context_length: [batch], the actual number of turns in each dialogue
        self.context_length = tf.placeholder(tf.int32, (None, ), 'ctx_lens')
        self.is_train = tf.placeholder(tf.bool)

        # corresponds to num_turns-1 (it may also be smaller);
        # the actual maximum number of past turns in this batch
        num_past_turns = tf.shape(self.posts)[0] // tf.shape(
            self.origin_responses)[0]

        # deal with original data to adapt encoder and decoder
        # build the decoder input and target:
        # the input drops the trailing <eos>, the target drops the leading <go>
        batch_size, decoder_len = tf.shape(self.origin_responses)[0], tf.shape(
            self.origin_responses)[1]
        self.responses = tf.split(self.origin_responses, [1, decoder_len - 1],
                                  1)[1]  # no go_id
        self.responses_length = self.origin_responses_length - 1
        self.responses_input = tf.split(self.origin_responses,
                                        [decoder_len - 1, 1],
                                        1)[0]  # no eos_id
        self.responses_target = self.responses
        decoder_len = decoder_len - 1
        self.posts_input = self.posts  # batch*len
        # [batch, decoder_length]
        self.decoder_mask = tf.reshape(
            tf.cumsum(tf.one_hot(self.responses_length - 1, decoder_len),
                      reverse=True,
                      axis=1), [-1, decoder_len])

        # initialize the training process
        self.learning_rate = tf.Variable(float(args.lr),
                                         trainable=False,
                                         dtype=tf.float32)
        self.learning_rate_decay_op = self.learning_rate.assign(
            self.learning_rate * args.lr_decay)
        self.global_step = tf.Variable(0, trainable=False)

        # build the embedding table and embedding input
        if embed is None:
            # initialize the embedding randomly
            self.embed = tf.get_variable(
                'embed', [data.vocab_size, args.embedding_size], tf.float32)
        else:
            # initialize the embedding by pre-trained word vectors
            self.embed = tf.get_variable('embed',
                                         dtype=tf.float32,
                                         initializer=embed)

        self.encoder_input = tf.nn.embedding_lookup(self.embed, self.posts)
        self.decoder_input = tf.nn.embedding_lookup(self.embed,
                                                    self.responses_input)
        # self.encoder_input = tf.cond(self.is_train,
        #							 lambda: tf.nn.dropout(tf.nn.embedding_lookup(self.embed, self.posts_input), 0.8),
        #							 lambda: tf.nn.embedding_lookup(self.embed, self.posts_input))  # batch*len*unit
        # self.decoder_input = tf.cond(self.is_train,
        #							 lambda: tf.nn.dropout(tf.nn.embedding_lookup(self.embed, self.responses_input), 0.8),
        #							 lambda: tf.nn.embedding_lookup(self.embed, self.responses_input))

        # build rnn_cell
        cell_enc = tf.nn.rnn_cell.GRUCell(args.eh_size)
        cell_ctx = tf.nn.rnn_cell.GRUCell(args.ch_size)
        cell_dec = tf.nn.rnn_cell.GRUCell(args.dh_size)

        # build encoder
        with tf.variable_scope('encoder'):
            # encoder_output: [batch*(num_turns-1), max_post_length, eh_size]
            # encoder_state: [batch*(num_turns-1), eh_size]
            encoder_output, encoder_state = tf.nn.dynamic_rnn(
                cell_enc,
                self.encoder_input,
                self.posts_length,
                dtype=tf.float32,
                scope="encoder_rnn")

        with tf.variable_scope('encoder', reuse=tf.AUTO_REUSE):
            # prev_output: [batch, max_prev_length, eh_size]
            prev_output, _ = tf.nn.dynamic_rnn(cell_enc,
                                               tf.nn.embedding_lookup(
                                                   self.embed,
                                                   self.prev_posts),
                                               self.prev_posts_length,
                                               dtype=tf.float32,
                                               scope="encoder_rnn")

        # encoder_hidden_size = tf.shape(encoder_state)[-1]

        with tf.variable_scope('context'):
            # encoder_state_reshape: [batch, num_turns-1, eh_size]
            # context_output: [batch, num_turns-1, ch_size]
            # context_state: [batch, ch_size]
            encoder_state_reshape = tf.reshape(
                encoder_state, [-1, num_past_turns, args.eh_size])
            context_output, self.context_state = tf.nn.dynamic_rnn(
                cell_ctx,
                encoder_state_reshape,
                self.context_length,
                dtype=tf.float32,
                scope='context_rnn')

        # get output projection function
        output_fn = MyDense(data.vocab_size, use_bias=True)
        sampled_sequence_loss = output_projection_layer(
            args.dh_size, data.vocab_size, args.softmax_samples)

        # construct helper and attention
        train_helper = tf.contrib.seq2seq.TrainingHelper(
            self.decoder_input, tf.maximum(self.responses_length, 1))
        infer_helper = tf.contrib.seq2seq.GreedyEmbeddingHelper(
            self.embed, tf.fill([batch_size], data.go_id), data.eos_id)

        #encoder_len = tf.shape(encoder_output)[1]
        #attention_memory = tf.reshape(encoder_output, [batch_size, -1, args.eh_size])
        #attention_mask = tf.reshape(tf.sequence_mask(self.posts_length, encoder_len), [batch_size, -1])
        '''
        attention_memory = context_output
        attention_mask = tf.reshape(tf.sequence_mask(self.context_length, self.num_turns - 1), [batch_size, -1])
        '''
        #attention_mask = tf.concat([tf.ones([batch_size, 1], tf.bool), attention_mask[:, 1:]], axis=1)
        #attn_mechanism = MyAttention(args.dh_size, attention_memory, attention_mask)
        # note: the attention memory here is the encoding of the last utterance,
        # i.e. [batch_size, prev_post_length, eh_size]; inside the attention, if the query
        # dimension does not match the memory, the query is first projected by a linear layer
        attn_mechanism = tf.contrib.seq2seq.BahdanauAttention(
            args.dh_size,
            prev_output,
            memory_sequence_length=tf.maximum(self.prev_posts_length, 1))
        cell_dec_attn = tf.contrib.seq2seq.AttentionWrapper(
            cell_dec, attn_mechanism, attention_layer_size=args.dh_size)
        # project the context state to the decoder hidden size, [batch, dh_size]
        ctx_state_shaping = tf.layers.dense(self.context_state,
                                            args.dh_size,
                                            activation=None)
        dec_start = cell_dec_attn.zero_state(
            batch_size, dtype=tf.float32).clone(cell_state=ctx_state_shaping)

        # build decoder (train)
        with tf.variable_scope('decoder'):
            decoder_train = tf.contrib.seq2seq.BasicDecoder(
                cell_dec_attn, train_helper, dec_start)
            train_outputs, _, _ = tf.contrib.seq2seq.dynamic_decode(
                decoder_train, impute_finished=True, scope="decoder_rnn")
            # decoder_output: [batch, decoder_length, dh_size]
            self.decoder_output = train_outputs.rnn_output
            # self.decoder_output = tf.nn.dropout(self.decoder_output, 0.8)
            # decoder_distribution_teacher: [batch, decoder_length, vocab_size]
            # decoder_loss: scalar
            # decoder_all_loss: [batch, ], per-sentence log loss
            self.decoder_distribution_teacher, self.decoder_loss, self.decoder_all_loss = \
             sampled_sequence_loss(self.decoder_output, self.responses_target, self.decoder_mask)

        # build decoder (test)
        with tf.variable_scope('decoder', reuse=True):
            decoder_infer = tf.contrib.seq2seq.BasicDecoder(
                cell_dec_attn, infer_helper, dec_start, output_layer=output_fn)
            infer_outputs, _, _ = tf.contrib.seq2seq.dynamic_decode(
                decoder_infer,
                impute_finished=True,
                maximum_iterations=args.max_decoder_length,
                scope="decoder_rnn")
            # [batch, max_decoder_length, vocab_size]
            self.decoder_distribution = infer_outputs.rnn_output
            # token indices at each decoding step, [batch, max_decoder_length]
            self.generation_index = tf.argmax(
                tf.split(self.decoder_distribution, [2, data.vocab_size - 2],
                         2)[1], 2) + 2  # for removing UNK

        # calculate the gradient of parameters and update
        self.params = [
            k for k in tf.trainable_variables() if args.name in k.name
        ]
        opt = tf.train.AdamOptimizer(self.learning_rate)
        gradients = tf.gradients(self.decoder_loss, self.params)
        clipped_gradients, self.gradient_norm = tf.clip_by_global_norm(
            gradients, args.grad_clip)
        self.update = opt.apply_gradients(zip(clipped_gradients, self.params),
                                          global_step=self.global_step)

        # save checkpoint
        self.latest_saver = tf.train.Saver(
            write_version=tf.train.SaverDef.V2,
            max_to_keep=args.checkpoint_max_to_keep,
            pad_step_number=True,
            keep_checkpoint_every_n_hours=1.0)
        self.best_saver = tf.train.Saver(write_version=tf.train.SaverDef.V2,
                                         max_to_keep=1,
                                         pad_step_number=True,
                                         keep_checkpoint_every_n_hours=1.0)

        # create summary for tensorboard
        self.create_summary(args)
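
In Example #5 the utterance encoder runs over batch*(num_turns-1) flattened utterances, and num_past_turns recovers the turn count so that the per-utterance final states can be reshaped into [batch, num_turns-1, eh_size] for the context RNN. A shape-only NumPy sketch of that reshape, with made-up sizes:

import numpy as np

batch, num_past_turns, eh_size = 2, 3, 4
flat_states = np.zeros((batch * num_past_turns, eh_size))      # like encoder_state
per_dialog = flat_states.reshape(-1, num_past_turns, eh_size)  # like encoder_state_reshape
print(per_dialog.shape)  # (2, 3, 4)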
Example #6
    def __init__(self, data, args, embed):
        # the inputs here are the same as in the seq2seq models above
        self.posts = tf.placeholder(tf.int32, (None, None),
                                    'enc_inps')  # batch*len
        self.posts_length = tf.placeholder(tf.int32, (None, ),
                                           'enc_lens')  # batch
        self.prevs_length = tf.placeholder(tf.int32, (None, ),
                                           'enc_lens_prev')  # batch
        self.origin_responses = tf.placeholder(tf.int32, (None, None),
                                               'dec_inps')  # batch*len
        self.origin_responses_length = tf.placeholder(tf.int32, (None, ),
                                                      'dec_lens')  # batch

        # kgs: all knowledge triples of the dialogue this example belongs to, [batch, max_kg_nums, max_kg_length]
        # kgs_h_length: length of each triple's head entity, [batch, max_kg_nums]
        # kgs_hr_length: length of each triple's head entity plus relation, [batch, max_kg_nums]
        # kgs_hrt_length: length of each triple's head, relation and tail, [batch, max_kg_nums]
        # kgs_index: indicator of the triples actually used by this response, [batch, max_kg_nums]
        # (1 for a used triple, 0 otherwise)
        self.kgs = tf.placeholder(tf.int32, (None, None, None), 'kg_inps')
        self.kgs_h_length = tf.placeholder(tf.int32, (None, None), 'kg_h_lens')
        self.kgs_hr_length = tf.placeholder(tf.int32, (None, None),
                                            'kg_hr_lens')
        self.kgs_hrt_length = tf.placeholder(tf.int32, (None, None),
                                             'kg_hrt_lens')
        self.kgs_index = tf.placeholder(tf.float32, (None, None), 'kg_indices')

        # hyperparameter balancing the decoding loss and the kg loss
        self.lamb = tf.placeholder(tf.float32, name='lamb')
        self.is_train = tf.placeholder(tf.bool)

        # deal with original data to adapt encoder and decoder
        # build the decoder input and target
        batch_size, decoder_len = tf.shape(self.origin_responses)[0], tf.shape(
            self.origin_responses)[1]
        self.responses = tf.split(self.origin_responses, [1, decoder_len - 1],
                                  1)[1]  # no go_id
        self.responses_length = self.origin_responses_length - 1
        self.responses_input = tf.split(self.origin_responses,
                                        [decoder_len - 1, 1],
                                        1)[0]  # no eos_id
        self.responses_target = self.responses
        decoder_len = decoder_len - 1
        # build the encoder input
        self.posts_input = self.posts  # batch*len
        # decoder mask, masking out padding positions
        self.decoder_mask = tf.reshape(
            tf.cumsum(tf.one_hot(self.responses_length - 1, decoder_len),
                      reverse=True,
                      axis=1), [-1, decoder_len])
        kg_len = tf.shape(self.kgs)[2]
        #kg_len = tf.Print(kg_len, [batch_size, kg_len, decoder_len, self.kgs_length])
        # kg_h_mask = tf.reshape(tf.cumsum(tf.one_hot(self.kgs_h_length-1,
        # 	kg_len), reverse=True, axis=2), [batch_size, -1, kg_len, 1])
        # mask over the key tokens (h+r): [batch_size, max_kg_nums, max_kg_length, 1]
        # and over the value tokens (t): [batch_size, max_kg_nums, max_kg_length, 1]
        kg_hr_mask = tf.reshape(
            tf.cumsum(tf.one_hot(self.kgs_hr_length - 1, kg_len),
                      reverse=True,
                      axis=2), [batch_size, -1, kg_len, 1])
        kg_hrt_mask = tf.reshape(
            tf.cumsum(tf.one_hot(self.kgs_hrt_length - 1, kg_len),
                      reverse=True,
                      axis=2), [batch_size, -1, kg_len, 1])
        kg_key_mask = kg_hr_mask
        kg_value_mask = kg_hrt_mask - kg_hr_mask

        # initialize the training process
        self.learning_rate = tf.Variable(float(args.lr),
                                         trainable=False,
                                         dtype=tf.float32)
        self.learning_rate_decay_op = self.learning_rate.assign(
            self.learning_rate * args.lr_decay)
        self.global_step = tf.Variable(0, trainable=False)

        # build the embedding table and embedding input
        if embed is None:
            # initialize the embedding randomly
            self.embed = tf.get_variable(
                'embed', [data.vocab_size, args.embedding_size], tf.float32)
        else:
            # initialize the embedding by pre-trained word vectors
            self.embed = tf.get_variable('embed',
                                         dtype=tf.float32,
                                         initializer=embed)
        # encoder_input: [batch, encoder_len, embed_size]
        self.encoder_input = tf.nn.embedding_lookup(self.embed, self.posts)
        # decoder_input: [batch, decoder_len, embed_size]
        self.decoder_input = tf.nn.embedding_lookup(self.embed,
                                                    self.responses_input)
        # kg_input: [batch, max_kg_nums, max_kg_length, embed_size]
        self.kg_input = tf.nn.embedding_lookup(self.embed, self.kgs)
        #self.encoder_input = tf.cond(self.is_train,
        #							 lambda: tf.nn.dropout(tf.nn.embedding_lookup(self.embed, self.posts_input), 0.8),
        #							 lambda: tf.nn.embedding_lookup(self.embed, self.posts_input)) #batch*len*unit
        #self.decoder_input = tf.cond(self.is_train,
        #							 lambda: tf.nn.dropout(tf.nn.embedding_lookup(self.embed, self.responses_input), 0.8),
        #							 lambda: tf.nn.embedding_lookup(self.embed, self.responses_input))

        # build rnn_cell
        cell_enc = tf.nn.rnn_cell.GRUCell(args.eh_size)
        cell_dec = tf.nn.rnn_cell.GRUCell(args.dh_size)

        # build encoder
        with tf.variable_scope('encoder'):
            encoder_output, encoder_state = tf.nn.dynamic_rnn(
                cell_enc,
                self.encoder_input,
                self.posts_length,
                dtype=tf.float32,
                scope="encoder_rnn")
        # key: mean embedding of each triple's h and r tokens, [batch, max_kg_nums, embed_size]
        # value: mean embedding of each triple's t tokens, [batch, max_kg_nums, embed_size]
        self.kg_key_avg = tf.reduce_sum(
            self.kg_input * kg_key_mask, axis=2) / tf.maximum(
                tf.reduce_sum(kg_key_mask, axis=2),
                tf.ones_like(tf.expand_dims(self.kgs_hrt_length, -1),
                             dtype=tf.float32))
        self.kg_value_avg = tf.reduce_sum(
            self.kg_input * kg_value_mask, axis=2) / tf.maximum(
                tf.reduce_sum(kg_value_mask, axis=2),
                tf.ones_like(tf.expand_dims(self.kgs_hrt_length, -1),
                             dtype=tf.float32))
        # project the encoder's final state to the embedding size
        # query: [batch, 1, embed_size]
        with tf.variable_scope('knowledge'):
            query = tf.reshape(
                tf.layers.dense(tf.concat(encoder_state, axis=-1),
                                args.embedding_size,
                                use_bias=False),
                [batch_size, 1, args.embedding_size])
        # [batch, max_kg_nums]
        kg_score = tf.reduce_sum(query * self.kg_key_avg, axis=2)
        # keep kg_score where the hrt length > 0 (a real triple); otherwise set the score to -inf
        kg_score = tf.where(tf.greater(self.kgs_hrt_length, 0), kg_score,
                            -tf.ones_like(kg_score) * np.inf)
        # attention weight of each triple, [batch, max_kg_nums]
        kg_alignment = tf.nn.softmax(kg_score)

        # from the kg attention weights, compute the kg selection accuracy and loss
        kg_max = tf.argmax(kg_alignment, axis=-1)
        kg_max_onehot = tf.one_hot(kg_max,
                                   tf.shape(kg_alignment)[1],
                                   dtype=tf.float32)
        self.kg_acc = tf.reduce_sum(
            kg_max_onehot * self.kgs_index) / tf.maximum(
                tf.reduce_sum(tf.reduce_max(self.kgs_index, axis=-1)),
                tf.constant(1.0))
        self.kg_loss = tf.reduce_sum(
            -tf.log(tf.clip_by_value(kg_alignment, 1e-12, 1.0)) *
            self.kgs_index,
            axis=1) / tf.maximum(tf.reduce_sum(self.kgs_index, axis=1),
                                 tf.ones([batch_size], dtype=tf.float32))
        self.kg_loss = tf.reduce_mean(self.kg_loss)
        # attention-weighted knowledge embedding, [batch, embed_size]
        self.knowledge_embed = tf.reduce_sum(
            tf.expand_dims(kg_alignment, axis=-1) * self.kg_value_avg, axis=1)
        # tile it to [batch, decoder_len, embed_size]
        knowledge_embed_extend = tf.tile(
            tf.expand_dims(self.knowledge_embed, axis=1), [1, decoder_len, 1])
        # concatenate the knowledge with the original decoder input: [batch, decoder_len, 2*embed_size]
        self.decoder_input = tf.concat(
            [self.decoder_input, knowledge_embed_extend], axis=2)

        # get output projection function
        output_fn = MyDense(data.vocab_size, use_bias=True)
        sampled_sequence_loss = output_projection_layer(
            args.dh_size, data.vocab_size, args.softmax_samples)

        encoder_len = tf.shape(encoder_output)[1]
        posts_mask = tf.sequence_mask(self.posts_length, encoder_len)
        prevs_mask = tf.sequence_mask(self.prevs_length, encoder_len)
        attention_mask = tf.reshape(tf.logical_xor(posts_mask, prevs_mask),
                                    [batch_size, encoder_len])

        # construct helper and attention
        train_helper = tf.contrib.seq2seq.TrainingHelper(
            self.decoder_input, self.responses_length)
        #infer_helper = tf.contrib.seq2seq.GreedyEmbeddingHelper(self.embed, tf.fill([batch_size], data.go_id), data.eos_id)
        # so that at inference time every decoder input is the previous output concatenated with the knowledge
        infer_helper = MyInferenceHelper(self.embed,
                                         tf.fill([batch_size], data.go_id),
                                         data.eos_id, self.knowledge_embed)
        #attn_mechanism = tf.contrib.seq2seq.BahdanauAttention(args.dh_size, encoder_output,
        #  memory_sequence_length=self.posts_length)
        # MyAttention works around BahdanauAttention only accepting encoder sequence lengths rather than an arbitrary mask
        attn_mechanism = MyAttention(args.dh_size, encoder_output,
                                     attention_mask)
        cell_dec_attn = tf.contrib.seq2seq.AttentionWrapper(
            cell_dec, attn_mechanism, attention_layer_size=args.dh_size)
        enc_state_shaping = tf.layers.dense(encoder_state,
                                            args.dh_size,
                                            activation=None)
        dec_start = cell_dec_attn.zero_state(
            batch_size, dtype=tf.float32).clone(cell_state=enc_state_shaping)

        # build decoder (train)
        with tf.variable_scope('decoder'):
            decoder_train = tf.contrib.seq2seq.BasicDecoder(
                cell_dec_attn, train_helper, dec_start)
            train_outputs, _, _ = tf.contrib.seq2seq.dynamic_decode(
                decoder_train, impute_finished=True, scope="decoder_rnn")
            self.decoder_output = train_outputs.rnn_output
            #self.decoder_output = tf.nn.dropout(self.decoder_output, 0.8)
            # output distribution and decoding loss
            self.decoder_distribution_teacher, self.decoder_loss, self.decoder_all_loss = \
             sampled_sequence_loss(self.decoder_output, self.responses_target, self.decoder_mask)

        # build decoder (test)
        with tf.variable_scope('decoder', reuse=True):
            decoder_infer = tf.contrib.seq2seq.BasicDecoder(
                cell_dec_attn, infer_helper, dec_start, output_layer=output_fn)
            infer_outputs, _, _ = tf.contrib.seq2seq.dynamic_decode(
                decoder_infer,
                impute_finished=True,
                maximum_iterations=args.max_decoder_length,
                scope="decoder_rnn")
            # output decoding distribution
            self.decoder_distribution = infer_outputs.rnn_output
            self.generation_index = tf.argmax(
                tf.split(self.decoder_distribution, [2, data.vocab_size - 2],
                         2)[1], 2) + 2  # for removing UNK

        # calculate the gradient of parameters and update
        self.params = [
            k for k in tf.trainable_variables() if args.name in k.name
        ]
        opt = tf.train.AdamOptimizer(self.learning_rate)
        # combine the decoding loss and the (weighted) kg loss
        self.loss = self.decoder_loss + self.lamb * self.kg_loss
        gradients = tf.gradients(self.loss, self.params)
        clipped_gradients, self.gradient_norm = tf.clip_by_global_norm(
            gradients, args.grad_clip)
        self.update = opt.apply_gradients(zip(clipped_gradients, self.params),
                                          global_step=self.global_step)

        # save checkpoint
        self.latest_saver = tf.train.Saver(
            write_version=tf.train.SaverDef.V2,
            max_to_keep=args.checkpoint_max_to_keep,
            pad_step_number=True,
            keep_checkpoint_every_n_hours=1.0)
        self.best_saver = tf.train.Saver(write_version=tf.train.SaverDef.V2,
                                         max_to_keep=1,
                                         pad_step_number=True,
                                         keep_checkpoint_every_n_hours=1.0)

        # create summary for tensorboard
        self.create_summary(args)
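
Example #6 (like Example #1) treats each triple as a key/value memory: the key mask covers the head+relation tokens, the value mask is the full-triple mask minus the key mask (i.e. only the tail tokens), and empty triple slots receive a score of -inf so the softmax gives them zero weight. A NumPy sketch of both ideas, with made-up lengths:

import numpy as np

kg_len, hr_len, hrt_len = 6, 3, 5
hr_mask = (np.arange(kg_len) < hr_len).astype(float)    # [1 1 1 0 0 0] -> key (h+r) tokens
hrt_mask = (np.arange(kg_len) < hrt_len).astype(float)  # [1 1 1 1 1 0] -> whole triple
print(hrt_mask - hr_mask)                               # [0 0 0 1 1 0] -> value (t) tokens

scores = np.array([0.5, 1.2, 0.0])                           # per-triple scores; last slot is empty
scores = np.where(np.array([5, 4, 0]) > 0, scores, -np.inf)  # hrt_length > 0 keeps real triples
print(np.exp(scores) / np.exp(scores).sum())                 # softmax: empty slot gets weight 0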