Code example #1
0
File: model.py  Project: zyjcs/ccm
    def __init__(self,
            num_symbols,
            num_embed_units,
            num_units,
            num_layers,
            embed,
            entity_embed=None,
            num_entities=0,
            num_trans_units=100,
            learning_rate=0.0001,
            learning_rate_decay_factor=0.95,
            max_gradient_norm=5.0,
            num_samples=512,
            max_length=60,
            output_alignments=True,
            use_lstm=False):
        """Build the complete train + inference graph of a knowledge-grounded seq2seq model.

        The GRU encoder reads `posts`; the decoder generates `responses` while
        attending both to encoder outputs and to embedded knowledge triples
        (`triples`). When `output_alignments` is True the decoder also produces
        alignment distributions over the triples, which feed the `total_loss`
        objective. Training uses Adam with global-norm gradient clipping.

        Args:
            num_symbols: word vocabulary size.
            num_embed_units: word-embedding dimension.
            num_units: GRU cell size for encoder and decoder.
            num_layers: number of stacked GRU layers.
            embed: pre-trained word-embedding matrix, or None to init randomly.
            entity_embed: pre-trained entity (TransE-style) embedding matrix,
                or None to init randomly; frozen either way (trainable=False).
            num_entities: entity count (used only when entity_embed is None).
            num_trans_units: entity-embedding dimension.
            learning_rate: Adam learning rate (also stored in a decayable
                `self.learning_rate` variable, though Adam uses the constant).
            learning_rate_decay_factor: factor applied by
                `self.learning_rate_decay_op`.
            max_gradient_norm: global-norm clipping threshold.
            num_samples: sampled-softmax sample count in the output projection.
            max_length: maximum decoding length at inference time.
            output_alignments: train/decode with triple-alignment loss if True.
            use_lstm: accepted for API compatibility; unused here (cells are
                always GRU in this implementation).
        """
        # --- input placeholders (shapes per the trailing comments) -----------
        self.posts = tf.placeholder(tf.string, (None, None), 'enc_inps')  # batch*len
        self.posts_length = tf.placeholder(tf.int32, (None), 'enc_lens')  # batch
        self.responses = tf.placeholder(tf.string, (None, None), 'dec_inps')  # batch*len
        self.responses_length = tf.placeholder(tf.int32, (None), 'dec_lens')  # batch
        self.entities = tf.placeholder(tf.string, (None, None), 'entities')  # batch*entity_num
        self.entity_masks = tf.placeholder(tf.string, (None, None), 'entity_masks')  # batch*len
        self.triples = tf.placeholder(tf.string, (None, None, 3), 'triples')  # batch*triple_num*3
        self.posts_triple = tf.placeholder(tf.int32, (None, None, 1), 'enc_triples')  # batch*len*1
        self.responses_triple = tf.placeholder(tf.string, (None, None, 3), 'dec_triples')  # batch*len*3
        self.match_triples = tf.placeholder(tf.int32, (None, None), 'match_triples')  # batch*len
        encoder_batch_size, encoder_len = tf.unstack(tf.shape(self.posts))
        triple_num = tf.shape(self.triples)[1]
        
        #use_triples = tf.reduce_sum(tf.cast(tf.greater_equal(self.match_triples, 0), tf.float32), axis=-1)
        # one_hot_triples: batch*len*triple_num; use_triples marks decoder
        # positions whose target word matched at least one triple.
        one_hot_triples = tf.one_hot(self.match_triples, triple_num)
        use_triples = tf.reduce_sum(one_hot_triples, axis=[2])

        # Mutable string<->index lookup tables, persisted in the checkpoint.
        self.symbol2index = MutableHashTable(
                key_dtype=tf.string,
                value_dtype=tf.int64,
                default_value=UNK_ID,
                shared_name="in_table",
                name="in_table",
                checkpoint=True)
        self.index2symbol = MutableHashTable(
                key_dtype=tf.int64,
                value_dtype=tf.string,
                default_value='_UNK',
                shared_name="out_table",
                name="out_table",
                checkpoint=True)
        self.entity2index = MutableHashTable(
                key_dtype=tf.string,
                value_dtype=tf.int64,
                default_value=NONE_ID,
                shared_name="entity_in_table",
                name="entity_in_table",
                checkpoint=True)
        self.index2entity = MutableHashTable(
                key_dtype=tf.int64,
                value_dtype=tf.string,
                default_value='_NONE',
                shared_name="entity_out_table",
                name="entity_out_table",
                checkpoint=True)
        # build the vocab table (string to index)


        self.posts_word_id = self.symbol2index.lookup(self.posts)   # batch*len
        self.posts_entity_id = self.entity2index.lookup(self.posts)   # batch*len
        #self.posts_word_id = tf.Print(self.posts_word_id, ['use_triples', use_triples, 'one_hot_triples', one_hot_triples], summarize=1e6)
        self.responses_target = self.symbol2index.lookup(self.responses)   #batch*len
        
        batch_size, decoder_len = tf.shape(self.responses)[0], tf.shape(self.responses)[1]
        # Decoder input = targets shifted right by one, with GO_ID prepended.
        self.responses_word_id = tf.concat([tf.ones([batch_size, 1], dtype=tf.int64)*GO_ID,
            tf.split(self.responses_target, [decoder_len-1, 1], 1)[0]], 1)   # batch*len
        # decoder_mask: 1.0 at positions < responses_length, 0.0 after (via a
        # reversed cumsum over the one-hot of the last valid index).
        self.decoder_mask = tf.reshape(tf.cumsum(tf.one_hot(self.responses_length-1, 
            decoder_len), reverse=True, axis=1), [-1, decoder_len])
        
        # build the embedding table (index to vector)
        if embed is None:
            # initialize the embedding randomly
            self.embed = tf.get_variable('word_embed', [num_symbols, num_embed_units], tf.float32)
        else:
            # initialize the embedding by pre-trained word vectors
            self.embed = tf.get_variable('word_embed', dtype=tf.float32, initializer=embed)
        if entity_embed is None:
            # initialize the embedding randomly (frozen: trainable=False)
            self.entity_trans = tf.get_variable('entity_embed', [num_entities, num_trans_units], tf.float32, trainable=False)
        else:
            # initialize the embedding by pre-trained entity vectors (frozen)
            self.entity_trans = tf.get_variable('entity_embed', dtype=tf.float32, initializer=entity_embed, trainable=False)

        # Trainable tanh projection on top of the frozen entity embeddings.
        self.entity_trans_transformed = tf.layers.dense(self.entity_trans, num_trans_units, activation=tf.tanh, name='trans_transformation')
        # 7 zero-initialized rows prepended before the real entities —
        # presumably reserved for special/padding entity ids 0..6; TODO confirm
        # against the entity vocabulary layout.
        padding_entity = tf.get_variable('entity_padding_embed', [7, num_trans_units], dtype=tf.float32, initializer=tf.zeros_initializer())

        self.entity_embed = tf.concat([padding_entity, self.entity_trans_transformed], axis=0)

        # Each triple (head, relation, tail) flattened to 3*num_trans_units.
        triples_embedding = tf.reshape(tf.nn.embedding_lookup(self.entity_embed, self.entity2index.lookup(self.triples)), [encoder_batch_size, triple_num, 3 * num_trans_units])
        # Word embeddings of the entity surface forms (copy candidates).
        entities_word_embedding = tf.reshape(tf.nn.embedding_lookup(self.embed, self.symbol2index.lookup(self.entities)), [encoder_batch_size, -1, num_embed_units])


        self.encoder_input = tf.nn.embedding_lookup(self.embed, self.posts_word_id) #batch*len*unit
        self.decoder_input = tf.nn.embedding_lookup(self.embed, self.responses_word_id) #batch*len*unit

        # One fresh GRU cell per layer (cells must not be shared across layers).
        encoder_cell = MultiRNNCell([GRUCell(num_units) for _ in range(num_layers)])
        decoder_cell = MultiRNNCell([GRUCell(num_units) for _ in range(num_layers)])
        
        # rnn encoder
        encoder_output, encoder_state = dynamic_rnn(encoder_cell, self.encoder_input, 
                self.posts_length, dtype=tf.float32, scope="encoder")

        # get output projection function
        output_fn, selector_fn, sequence_loss, sampled_sequence_loss, total_loss = output_projection_layer(num_units, 
                num_symbols, num_samples)

        

        # --- training decoder (teacher forcing) ------------------------------
        with tf.variable_scope('decoder'):
            # get attention function (Bahdanau over encoder outputs, with the
            # triple embeddings as an external attention memory)
            attention_keys_init, attention_values_init, attention_score_fn_init, attention_construct_fn_init \
                    = prepare_attention(encoder_output, 'bahdanau', num_units, imem=triples_embedding, output_alignments=output_alignments)#'luong', num_units)

            decoder_fn_train = attention_decoder_fn_train(
                    encoder_state, attention_keys_init, attention_values_init,
                    attention_score_fn_init, attention_construct_fn_init, output_alignments=output_alignments, max_length=tf.reduce_max(self.responses_length))
            self.decoder_output, _, alignments_ta = dynamic_rnn_decoder(decoder_cell, decoder_fn_train, 
                    self.decoder_input, self.responses_length, scope="decoder_rnn")
            if output_alignments: 
                # alignments_ta is time-major; transpose to batch*time*triple.
                self.alignments = tf.transpose(alignments_ta.stack(), perm=[1,0,2])
                #self.alignments = tf.Print(self.alignments, [self.alignments], summarize=1e8)
                self.decoder_loss, self.ppx_loss, self.sentence_ppx = total_loss(self.decoder_output, self.responses_target, self.decoder_mask, self.alignments, triples_embedding, use_triples, one_hot_triples)
                self.sentence_ppx = tf.identity(self.sentence_ppx, 'ppx_loss')
                #self.decoder_loss = tf.Print(self.decoder_loss, ['decoder_loss', self.decoder_loss], summarize=1e6)
            else:
                self.decoder_loss, self.sentence_ppx = sequence_loss(self.decoder_output, 
                        self.responses_target, self.decoder_mask)
                self.sentence_ppx = tf.identity(self.sentence_ppx, 'ppx_loss')
         
        # --- inference decoder (shares variables with the training decoder) --
        with tf.variable_scope('decoder', reuse=True):
            # get attention function
            attention_keys, attention_values, attention_score_fn, attention_construct_fn \
                    = prepare_attention(encoder_output, 'bahdanau', num_units, reuse=True, imem=triples_embedding, output_alignments=output_alignments)#'luong', num_units)
            decoder_fn_inference = attention_decoder_fn_inference(
                    output_fn, encoder_state, attention_keys, attention_values, 
                    attention_score_fn, attention_construct_fn, self.embed, GO_ID, 
                    EOS_ID, max_length, num_symbols, imem=entities_word_embedding, selector_fn=selector_fn)

                
            self.decoder_distribution, _, output_ids_ta = dynamic_rnn_decoder(decoder_cell,
                    decoder_fn_inference, scope="decoder_rnn")
            if output_alignments:
                output_len = tf.shape(self.decoder_distribution)[1]
                output_ids = tf.transpose(output_ids_ta.gather(tf.range(output_len)))
                # Convention here: positive ids index the word vocabulary,
                # negative ids index copied entities (negated position).
                # NOTE(review): upper clip bound is num_symbols, not
                # num_symbols-1 — confirm the off-by-one is intended.
                word_ids = tf.cast(tf.clip_by_value(output_ids, 0, num_symbols), tf.int64)
                # Flatten (batch, pos) entity indices into self.entities rows.
                entity_ids = tf.reshape(tf.clip_by_value(-output_ids, 0, num_symbols) + tf.reshape(tf.range(encoder_batch_size) * tf.shape(entities_word_embedding)[1], [-1, 1]), [-1])
                entities = tf.reshape(tf.gather(tf.reshape(self.entities, [-1]), entity_ids), [-1, output_len])
                words = self.index2symbol.lookup(word_ids)
                # Per position: emit the vocab word when id > 0, else the copied entity.
                self.generation = tf.where(output_ids > 0, words, entities, name='generation')
            else:
                self.generation_index = tf.argmax(self.decoder_distribution, 2)
                
                self.generation = self.index2symbol.lookup(self.generation_index, name='generation') 
        

        # initialize the training process
        self.learning_rate = tf.Variable(float(learning_rate), 
                trainable=False, dtype=tf.float32)
        self.learning_rate_decay_op = self.learning_rate.assign(
                self.learning_rate * learning_rate_decay_factor)
        self.global_step = tf.Variable(0, trainable=False)

        self.params = tf.global_variables()
            
        # calculate the gradient of parameters
        #opt = tf.train.GradientDescentOptimizer(self.learning_rate)
        opt = tf.train.AdamOptimizer(learning_rate=learning_rate)
        # NOTE(review): reads the optimizer's private `_lr` attribute — fragile
        # across TF versions; exposed here only for logging, presumably.
        self.lr = opt._lr
       
        gradients = tf.gradients(self.decoder_loss, self.params)
        clipped_gradients, self.gradient_norm = tf.clip_by_global_norm(gradients, 
                max_gradient_norm)
        self.update = opt.apply_gradients(zip(clipped_gradients, self.params), 
                global_step=self.global_step)

        # Summaries for TensorBoard: scalar loss + histogram per trainable var.
        tf.summary.scalar('decoder_loss', self.decoder_loss)
        for each in tf.trainable_variables():
            tf.summary.histogram(each.name, each)

        self.merged_summary_op = tf.summary.merge_all()
        
        self.saver = tf.train.Saver(write_version=tf.train.SaverDef.V2, 
                max_to_keep=3, pad_step_number=True, keep_checkpoint_every_n_hours=1.0)
        
        # Separate saver for per-epoch checkpoints (kept essentially forever).
        self.saver_epoch = tf.train.Saver(write_version=tf.train.SaverDef.V2, max_to_keep=1000, pad_step_number=True)
Code example #2
0
    def __init__(self,
                 num_symbols,
                 num_embed_units,
                 num_units,
                 num_layers,
                 is_train,
                 vocab=None,
                 embed=None,
                 learning_rate=0.1,
                 learning_rate_decay_factor=0.95,
                 max_gradient_norm=5.0,
                 num_samples=512,
                 max_length=30,
                 use_lstm=True):
        """Build a 4-turn hierarchical seq2seq graph with graph attention over KB triples.

        Each `posts_k` (k = 1..4) is one utterance of dialogue context. The
        model chains encode/decode steps across turns — the decoder for turn
        k+1 attends over turn k's outputs and over a per-turn graph-attention
        embedding of that turn's knowledge triples — and finally decodes
        `responses`. When `is_train` is true the training ops (Momentum with
        global-norm gradient clipping, plus auxiliary per-turn losses) are
        built; otherwise the greedy inference graph is built.

        Args:
            num_symbols: vocabulary size.
            num_embed_units: word-embedding dimension.
            num_units: RNN cell size.
            num_layers: number of stacked RNN layers.
            is_train: build the training graph if true, inference otherwise.
            vocab: vocabulary array initializing the symbol table
                (used when `is_train` is true).
            embed: optional pre-trained embedding matrix.
            learning_rate: initial learning rate.
            learning_rate_decay_factor: factor applied by
                `self.learning_rate_decay_op`.
            max_gradient_norm: global-norm gradient clipping threshold.
            num_samples: sampled-softmax sample count.
            max_length: maximum decoding length at inference time.
            use_lstm: use LSTM cells instead of GRU cells.
        """
        # --- placeholders: four context turns, their triples/masks, response -
        self.posts_1 = tf.placeholder(tf.string, shape=(None, None))  # batch*len
        self.posts_2 = tf.placeholder(tf.string, shape=(None, None))
        self.posts_3 = tf.placeholder(tf.string, shape=(None, None))
        self.posts_4 = tf.placeholder(tf.string, shape=(None, None))

        # Triples per turn: batch * len * triple_num * 3 (head, relation, tail).
        self.entity_1 = tf.placeholder(tf.string, shape=(None, None, None, 3))
        self.entity_2 = tf.placeholder(tf.string, shape=(None, None, None, 3))
        self.entity_3 = tf.placeholder(tf.string, shape=(None, None, None, 3))
        self.entity_4 = tf.placeholder(tf.string, shape=(None, None, None, 3))

        # 1/0 masks marking valid triples: batch * len * triple_num.
        self.entity_mask_1 = tf.placeholder(tf.float32,
                                            shape=(None, None, None))
        self.entity_mask_2 = tf.placeholder(tf.float32,
                                            shape=(None, None, None))
        self.entity_mask_3 = tf.placeholder(tf.float32,
                                            shape=(None, None, None))
        self.entity_mask_4 = tf.placeholder(tf.float32,
                                            shape=(None, None, None))

        self.posts_length_1 = tf.placeholder(tf.int32, shape=(None))
        self.posts_length_2 = tf.placeholder(tf.int32, shape=(None))
        self.posts_length_3 = tf.placeholder(tf.int32, shape=(None))
        self.posts_length_4 = tf.placeholder(tf.int32, shape=(None))

        self.responses = tf.placeholder(tf.string, shape=(None, None))
        self.responses_length = tf.placeholder(tf.int32, shape=(None))

        self.epoch = tf.Variable(0, trainable=False, name='epoch')
        self.epoch_add_op = self.epoch.assign(self.epoch + 1)

        # At inference time the real vocab is restored from the checkpoint, so
        # a dummy array of the right size is enough here.
        if is_train:
            self.symbols = tf.Variable(vocab, trainable=False, name="symbols")
        else:
            self.symbols = tf.Variable(np.array(['.'] * num_symbols),
                                       name="symbols")

        # string -> index vocab table (OOV maps to UNK_ID).
        self.symbol2index = HashTable(KeyValueTensorInitializer(
            self.symbols,
            tf.Variable(
                np.array([i for i in range(num_symbols)], dtype=np.int32),
                False)),
                                      default_value=UNK_ID,
                                      name="symbol2index")

        self.posts_input_1 = self.symbol2index.lookup(self.posts_1)

        # Turns 2..4 serve double duty: decoder targets for the previous turn
        # and (shifted) encoder inputs for the next step.
        self.posts_2_target = self.posts_2_embed = self.symbol2index.lookup(
            self.posts_2)
        self.posts_3_target = self.posts_3_embed = self.symbol2index.lookup(
            self.posts_3)
        self.posts_4_target = self.posts_4_embed = self.symbol2index.lookup(
            self.posts_4)

        self.responses_target = self.symbol2index.lookup(self.responses)

        batch_size, decoder_len = tf.shape(self.posts_1)[0], tf.shape(
            self.responses)[1]

        # Shift-right with GO_ID prepended (standard teacher-forcing input).
        self.posts_input_2 = tf.concat([
            tf.ones([batch_size, 1], dtype=tf.int32) * GO_ID,
            tf.split(self.posts_2_embed, [tf.shape(self.posts_2)[1] - 1, 1],
                     1)[0]
        ], 1)
        self.posts_input_3 = tf.concat([
            tf.ones([batch_size, 1], dtype=tf.int32) * GO_ID,
            tf.split(self.posts_3_embed, [tf.shape(self.posts_3)[1] - 1, 1],
                     1)[0]
        ], 1)
        self.posts_input_4 = tf.concat([
            tf.ones([batch_size, 1], dtype=tf.int32) * GO_ID,
            tf.split(self.posts_4_embed, [tf.shape(self.posts_4)[1] - 1, 1],
                     1)[0]
        ], 1)

        self.responses_input = tf.concat([
            tf.ones([batch_size, 1], dtype=tf.int32) * GO_ID,
            tf.split(self.responses_target, [decoder_len - 1, 1], 1)[0]
        ], 1)

        # Sequence masks: 1.0 for positions < length, 0.0 after (reversed
        # cumsum over the one-hot of the last valid index).
        self.encoder_2_mask = tf.reshape(
            tf.cumsum(tf.one_hot(self.posts_length_2 - 1,
                                 tf.shape(self.posts_2)[1]),
                      reverse=True,
                      axis=1), [-1, tf.shape(self.posts_2)[1]])
        self.encoder_3_mask = tf.reshape(
            tf.cumsum(tf.one_hot(self.posts_length_3 - 1,
                                 tf.shape(self.posts_3)[1]),
                      reverse=True,
                      axis=1), [-1, tf.shape(self.posts_3)[1]])
        self.encoder_4_mask = tf.reshape(
            tf.cumsum(tf.one_hot(self.posts_length_4 - 1,
                                 tf.shape(self.posts_4)[1]),
                      reverse=True,
                      axis=1), [-1, tf.shape(self.posts_4)[1]])

        self.decoder_mask = tf.reshape(
            tf.cumsum(tf.one_hot(self.responses_length - 1, decoder_len),
                      reverse=True,
                      axis=1), [-1, decoder_len])

        # Embedding table: random init or pre-trained vectors.
        if embed is None:
            self.embed = tf.get_variable('embed',
                                         [num_symbols, num_embed_units],
                                         tf.float32)
        else:
            self.embed = tf.get_variable('embed',
                                         dtype=tf.float32,
                                         initializer=embed)

        self.encoder_input_1 = tf.nn.embedding_lookup(self.embed,
                                                      self.posts_input_1)
        self.encoder_input_2 = tf.nn.embedding_lookup(self.embed,
                                                      self.posts_input_2)
        self.encoder_input_3 = tf.nn.embedding_lookup(self.embed,
                                                      self.posts_input_3)
        self.encoder_input_4 = tf.nn.embedding_lookup(self.embed,
                                                      self.posts_input_4)

        self.decoder_input = tf.nn.embedding_lookup(self.embed,
                                                    self.responses_input)

        # Per-turn triple embeddings: (h, r, t) word vectors flattened to
        # 3 * num_embed_units on the last axis.
        entity_embedding_1 = tf.reshape(
            tf.nn.embedding_lookup(self.embed,
                                   self.symbol2index.lookup(self.entity_1)),
            [
                batch_size,
                tf.shape(self.entity_1)[1],
                tf.shape(self.entity_1)[2], 3 * num_embed_units
            ])
        entity_embedding_2 = tf.reshape(
            tf.nn.embedding_lookup(self.embed,
                                   self.symbol2index.lookup(self.entity_2)),
            [
                batch_size,
                tf.shape(self.entity_2)[1],
                tf.shape(self.entity_2)[2], 3 * num_embed_units
            ])
        entity_embedding_3 = tf.reshape(
            tf.nn.embedding_lookup(self.embed,
                                   self.symbol2index.lookup(self.entity_3)),
            [
                batch_size,
                tf.shape(self.entity_3)[1],
                tf.shape(self.entity_3)[2], 3 * num_embed_units
            ])
        entity_embedding_4 = tf.reshape(
            tf.nn.embedding_lookup(self.embed,
                                   self.symbol2index.lookup(self.entity_4)),
            [
                batch_size,
                tf.shape(self.entity_4)[1],
                tf.shape(self.entity_4)[2], 3 * num_embed_units
            ])

        head_1, relation_1, tail_1 = tf.split(entity_embedding_1,
                                              [num_embed_units] * 3,
                                              axis=3)
        head_2, relation_2, tail_2 = tf.split(entity_embedding_2,
                                              [num_embed_units] * 3,
                                              axis=3)
        head_3, relation_3, tail_3 = tf.split(entity_embedding_3,
                                              [num_embed_units] * 3,
                                              axis=3)
        head_4, relation_4, tail_4 = tf.split(entity_embedding_4,
                                              [num_embed_units] * 3,
                                              axis=3)

        # Static graph attention (Zhou et al.-style): score each triple by
        # <W_r * r, tanh(W_ht * [h;t])>, softmax over triples, then pool the
        # masked [h;t] vectors. The same variables are reused for all 4 turns.
        with tf.variable_scope('graph_attention'):
            #[batch_size, max_reponse_length, max_triple_num, 2*embed_units]
            head_tail_1 = tf.concat([head_1, tail_1], axis=3)
            #[batch_size, max_reponse_length, max_triple_num, embed_units]
            head_tail_transformed_1 = tf.layers.dense(
                head_tail_1,
                num_embed_units,
                activation=tf.tanh,
                name='head_tail_transform')
            #[batch_size, max_reponse_length, max_triple_num, embed_units]
            relation_transformed_1 = tf.layers.dense(relation_1,
                                                     num_embed_units,
                                                     name='relation_transform')
            #[batch_size, max_reponse_length, max_triple_num]
            e_weight_1 = tf.reduce_sum(relation_transformed_1 *
                                       head_tail_transformed_1,
                                       axis=3)
            #[batch_size, max_reponse_length, max_triple_num]
            alpha_weight_1 = tf.nn.softmax(e_weight_1)
            #[batch_size, max_reponse_length, embed_units]
            graph_embed_1 = tf.reduce_sum(
                tf.expand_dims(alpha_weight_1, 3) *
                (tf.expand_dims(self.entity_mask_1, 3) * head_tail_1),
                axis=2)

        with tf.variable_scope('graph_attention', reuse=True):
            head_tail_2 = tf.concat([head_2, tail_2], axis=3)
            head_tail_transformed_2 = tf.layers.dense(
                head_tail_2,
                num_embed_units,
                activation=tf.tanh,
                name='head_tail_transform')
            relation_transformed_2 = tf.layers.dense(relation_2,
                                                     num_embed_units,
                                                     name='relation_transform')
            e_weight_2 = tf.reduce_sum(relation_transformed_2 *
                                       head_tail_transformed_2,
                                       axis=3)
            alpha_weight_2 = tf.nn.softmax(e_weight_2)
            graph_embed_2 = tf.reduce_sum(
                tf.expand_dims(alpha_weight_2, 3) *
                (tf.expand_dims(self.entity_mask_2, 3) * head_tail_2),
                axis=2)

        with tf.variable_scope('graph_attention', reuse=True):
            head_tail_3 = tf.concat([head_3, tail_3], axis=3)
            head_tail_transformed_3 = tf.layers.dense(
                head_tail_3,
                num_embed_units,
                activation=tf.tanh,
                name='head_tail_transform')
            relation_transformed_3 = tf.layers.dense(relation_3,
                                                     num_embed_units,
                                                     name='relation_transform')
            e_weight_3 = tf.reduce_sum(relation_transformed_3 *
                                       head_tail_transformed_3,
                                       axis=3)
            alpha_weight_3 = tf.nn.softmax(e_weight_3)
            graph_embed_3 = tf.reduce_sum(
                tf.expand_dims(alpha_weight_3, 3) *
                (tf.expand_dims(self.entity_mask_3, 3) * head_tail_3),
                axis=2)

        with tf.variable_scope('graph_attention', reuse=True):
            head_tail_4 = tf.concat([head_4, tail_4], axis=3)
            head_tail_transformed_4 = tf.layers.dense(
                head_tail_4,
                num_embed_units,
                activation=tf.tanh,
                name='head_tail_transform')
            relation_transformed_4 = tf.layers.dense(relation_4,
                                                     num_embed_units,
                                                     name='relation_transform')
            e_weight_4 = tf.reduce_sum(relation_transformed_4 *
                                       head_tail_transformed_4,
                                       axis=3)
            alpha_weight_4 = tf.nn.softmax(e_weight_4)
            graph_embed_4 = tf.reduce_sum(
                tf.expand_dims(alpha_weight_4, 3) *
                (tf.expand_dims(self.entity_mask_4, 3) * head_tail_4),
                axis=2)

        # BUGFIX: build one fresh cell per layer. The original
        # `MultiRNNCell([LSTMCell(num_units)] * num_layers)` replicated a
        # single cell object, making every layer share the same weights
        # (and newer TF 1.x versions reject the duplicated instance outright).
        if use_lstm:
            cell = MultiRNNCell([LSTMCell(num_units) for _ in range(num_layers)])
        else:
            cell = MultiRNNCell([GRUCell(num_units) for _ in range(num_layers)])

        output_fn, sampled_sequence_loss = output_projection_layer(
            num_units, num_symbols, num_samples)

        # Turn 1: plain RNN encoding.
        encoder_output_1, encoder_state_1 = dynamic_rnn(cell,
                                                        self.encoder_input_1,
                                                        self.posts_length_1,
                                                        dtype=tf.float32,
                                                        scope="encoder")

        # Turns 2..4: each turn is decoded (teacher-forced) while attending
        # over the previous turn's outputs plus that turn's graph embedding;
        # its outputs then become the next turn's attention memory. Each turn
        # also contributes a sampled-softmax auxiliary loss.
        attention_keys_1, attention_values_1, attention_score_fn_1, attention_construct_fn_1 \
                = attention_decoder_fn.prepare_attention(graph_embed_1, encoder_output_1, 'luong', num_units)
        decoder_fn_train_1 = attention_decoder_fn.attention_decoder_fn_train(
            encoder_state_1,
            attention_keys_1,
            attention_values_1,
            attention_score_fn_1,
            attention_construct_fn_1,
            max_length=tf.reduce_max(self.posts_length_2))
        encoder_output_2, encoder_state_2, alignments_ta_2 = dynamic_rnn_decoder(
            cell,
            decoder_fn_train_1,
            self.encoder_input_2,
            self.posts_length_2,
            scope="decoder")
        # Time-major TensorArray -> batch*time*memory.
        self.alignments_2 = tf.transpose(alignments_ta_2.stack(),
                                         perm=[1, 0, 2])

        self.decoder_loss_2 = sampled_sequence_loss(encoder_output_2,
                                                    self.posts_2_target,
                                                    self.encoder_2_mask)

        with variable_scope.variable_scope('', reuse=True):
            attention_keys_2, attention_values_2, attention_score_fn_2, attention_construct_fn_2 \
                    = attention_decoder_fn.prepare_attention(graph_embed_2, encoder_output_2, 'luong', num_units)
            decoder_fn_train_2 = attention_decoder_fn.attention_decoder_fn_train(
                encoder_state_2,
                attention_keys_2,
                attention_values_2,
                attention_score_fn_2,
                attention_construct_fn_2,
                max_length=tf.reduce_max(self.posts_length_3))
            encoder_output_3, encoder_state_3, alignments_ta_3 = dynamic_rnn_decoder(
                cell,
                decoder_fn_train_2,
                self.encoder_input_3,
                self.posts_length_3,
                scope="decoder")
            self.alignments_3 = tf.transpose(alignments_ta_3.stack(),
                                             perm=[1, 0, 2])

            self.decoder_loss_3 = sampled_sequence_loss(
                encoder_output_3, self.posts_3_target, self.encoder_3_mask)

            attention_keys_3, attention_values_3, attention_score_fn_3, attention_construct_fn_3 \
                    = attention_decoder_fn.prepare_attention(graph_embed_3, encoder_output_3, 'luong', num_units)
            decoder_fn_train_3 = attention_decoder_fn.attention_decoder_fn_train(
                encoder_state_3,
                attention_keys_3,
                attention_values_3,
                attention_score_fn_3,
                attention_construct_fn_3,
                max_length=tf.reduce_max(self.posts_length_4))
            encoder_output_4, encoder_state_4, alignments_ta_4 = dynamic_rnn_decoder(
                cell,
                decoder_fn_train_3,
                self.encoder_input_4,
                self.posts_length_4,
                scope="decoder")
            self.alignments_4 = tf.transpose(alignments_ta_4.stack(),
                                             perm=[1, 0, 2])

            self.decoder_loss_4 = sampled_sequence_loss(
                encoder_output_4, self.posts_4_target, self.encoder_4_mask)

            # Final attention over turn 4, used by the response decoder below.
            attention_keys, attention_values, attention_score_fn, attention_construct_fn \
                    = attention_decoder_fn.prepare_attention(graph_embed_4, encoder_output_4, 'luong', num_units)

        if is_train:
            # --- training: decode the response and build the update op ------
            with variable_scope.variable_scope('', reuse=True):
                decoder_fn_train = attention_decoder_fn.attention_decoder_fn_train(
                    encoder_state_4,
                    attention_keys,
                    attention_values,
                    attention_score_fn,
                    attention_construct_fn,
                    max_length=tf.reduce_max(self.responses_length))
                self.decoder_output, _, alignments_ta = dynamic_rnn_decoder(
                    cell,
                    decoder_fn_train,
                    self.decoder_input,
                    self.responses_length,
                    scope="decoder")
                self.alignments = tf.transpose(alignments_ta.stack(),
                                               perm=[1, 0, 2])

                self.decoder_loss = sampled_sequence_loss(
                    self.decoder_output, self.responses_target,
                    self.decoder_mask)

            self.params = tf.trainable_variables()

            self.learning_rate = tf.Variable(float(learning_rate),
                                             trainable=False,
                                             dtype=tf.float32)
            self.learning_rate_decay_op = self.learning_rate.assign(
                self.learning_rate * learning_rate_decay_factor)
            self.global_step = tf.Variable(0, trainable=False)

            #opt = tf.train.GradientDescentOptimizer(self.learning_rate)
            opt = tf.train.MomentumOptimizer(self.learning_rate, 0.9)

            # Joint objective: response loss plus the per-turn auxiliary losses.
            gradients = tf.gradients(
                self.decoder_loss + self.decoder_loss_2 + self.decoder_loss_3 +
                self.decoder_loss_4, self.params)
            clipped_gradients, self.gradient_norm = tf.clip_by_global_norm(
                gradients, max_gradient_norm)
            self.update = opt.apply_gradients(zip(clipped_gradients,
                                                  self.params),
                                              global_step=self.global_step)

        else:
            # --- inference: greedy decode, excluding PAD/UNK (ids 0 and 1) --
            with variable_scope.variable_scope('', reuse=True):
                decoder_fn_inference = attention_decoder_fn.attention_decoder_fn_inference(
                    output_fn, encoder_state_4, attention_keys,
                    attention_values, attention_score_fn,
                    attention_construct_fn, self.embed, GO_ID, EOS_ID,
                    max_length, num_symbols)
                self.decoder_distribution, _, alignments_ta = dynamic_rnn_decoder(
                    cell, decoder_fn_inference, scope="decoder")
                output_len = tf.shape(self.decoder_distribution)[1]
                self.alignments = tf.transpose(
                    alignments_ta.gather(tf.range(output_len)), [1, 0, 2])

            self.generation_index = tf.argmax(
                tf.split(self.decoder_distribution, [2, num_symbols - 2],
                         2)[1], 2) + 2  # for removing UNK
            self.generation = tf.nn.embedding_lookup(self.symbols,
                                                     self.generation_index,
                                                     name="generation")

            self.params = tf.trainable_variables()

        self.saver = tf.train.Saver(tf.global_variables(),
                                    write_version=tf.train.SaverDef.V2,
                                    max_to_keep=10,
                                    pad_step_number=True,
                                    keep_checkpoint_every_n_hours=1.0)
Code example #3
0
File: model.py  Project: streamride/seq2seq
    def __init__(self,
                 num_symbols,
                 num_embed_units,
                 num_units,
                 num_layers,
                 beam_size,
                 embed,
                 learning_rate=0.5,
                 remove_unk=False,
                 learning_rate_decay_factor=0.95,
                 max_gradient_norm=5.0,
                 num_samples=512,
                 max_length=8,
                 use_lstm=False):
        """Build an attention seq2seq graph with greedy and beam-search decoding.

        Constructs string-in/string-out placeholders, in-graph vocabulary hash
        tables, an RNN encoder, a teacher-forced training decoder with sampled
        softmax loss, a greedy inference decoder, a beam-search inference
        decoder, SGD training ops, and a TF-Serving exporter.

        Args:
            num_symbols: vocabulary size.
            num_embed_units: word embedding dimension.
            num_units: hidden units per RNN layer.
            num_layers: number of stacked RNN layers.
            beam_size: beam width for beam-search decoding.
            embed: pre-trained embedding matrix, or None for random init.
            learning_rate: initial SGD learning rate.
            remove_unk: if True, the beam decoder suppresses UNK tokens.
            learning_rate_decay_factor: multiplier applied by the decay op.
            max_gradient_norm: global-norm gradient clipping threshold.
            num_samples: number of samples for sampled softmax.
            num_samples: sample count for the sampled softmax loss.
            max_length: maximum decoding length at inference time.
            use_lstm: use LSTM cells instead of GRU cells.
        """
        # Raw string inputs; integer ids are produced in-graph by hash tables.
        # Shape (None,) declares a rank-1 (batch) tensor; a bare None would
        # leave the rank unconstrained.
        self.posts = tf.placeholder(tf.string, (None, None),
                                    'enc_inps')  # batch*len
        self.posts_length = tf.placeholder(tf.int32, (None,),
                                           'enc_lens')  # batch
        self.responses = tf.placeholder(tf.string, (None, None),
                                        'dec_inps')  # batch*len
        self.responses_length = tf.placeholder(tf.int32, (None,),
                                               'dec_lens')  # batch

        # Training bookkeeping: decayable learning rate and global step.
        self.learning_rate = tf.Variable(float(learning_rate),
                                         trainable=False,
                                         dtype=tf.float32)
        self.learning_rate_decay_op = self.learning_rate.assign(
            self.learning_rate * learning_rate_decay_factor)
        self.global_step = tf.Variable(0, trainable=False)

        # word -> id and id -> word lookup tables, checkpointed with the model.
        self.symbol2index = MutableHashTable(key_dtype=tf.string,
                                             value_dtype=tf.int64,
                                             default_value=UNK_ID,
                                             shared_name="in_table",
                                             name="in_table",
                                             checkpoint=True)
        self.index2symbol = MutableHashTable(key_dtype=tf.int64,
                                             value_dtype=tf.string,
                                             default_value='_UNK',
                                             shared_name="out_table",
                                             name="out_table",
                                             checkpoint=True)

        # Map input strings to vocabulary ids.
        self.posts_input = self.symbol2index.lookup(self.posts)  # batch*len
        self.responses_target = self.symbol2index.lookup(
            self.responses)  # batch*len

        batch_size, decoder_len = tf.shape(self.responses)[0], tf.shape(
            self.responses)[1]
        # Decoder input: drop the last target column and prepend GO_ID.
        self.responses_input = tf.concat([
            tf.ones([batch_size, 1], dtype=tf.int64) * GO_ID,
            tf.split(self.responses_target, [decoder_len - 1, 1], 1)[0]
        ], 1)  # batch*len
        # Mask: 1 inside each response's true length, 0 on padding (built by
        # reverse cumsum over the one-hot of the last valid position).
        self.decoder_mask = tf.reshape(
            tf.cumsum(tf.one_hot(self.responses_length - 1, decoder_len),
                      reverse=True,
                      axis=1), [-1, decoder_len])

        # Embedding table (index to vector): random init or pre-trained.
        if embed is None:
            self.embed = tf.get_variable('embed',
                                         [num_symbols, num_embed_units],
                                         tf.float32)
        else:
            self.embed = tf.get_variable('embed',
                                         dtype=tf.float32,
                                         initializer=embed)

        self.encoder_input = tf.nn.embedding_lookup(
            self.embed, self.posts_input)  # batch*len*unit
        self.decoder_input = tf.nn.embedding_lookup(self.embed,
                                                    self.responses_input)

        # Build a DISTINCT cell object per layer. The previous
        # `[Cell(num_units)] * num_layers` repeated one cell instance, which
        # ties the parameters of every layer together.
        if use_lstm:
            cell = MultiRNNCell(
                [LSTMCell(num_units) for _ in range(num_layers)])
        else:
            cell = MultiRNNCell(
                [GRUCell(num_units) for _ in range(num_layers)])

        # RNN encoder over the post.
        encoder_output, encoder_state = dynamic_rnn(cell,
                                                    self.encoder_input,
                                                    self.posts_length,
                                                    dtype=tf.float32,
                                                    scope="encoder")

        # Projection from num_units to the vocabulary plus sampled loss.
        output_fn, sampled_sequence_loss = output_projection_layer(
            num_units, num_symbols, num_samples)

        # Luong attention over the encoder outputs.
        attention_keys, attention_values, attention_score_fn, attention_construct_fn \
                = attention_decoder_fn.prepare_attention(encoder_output, 'luong', num_units)

        # Training decoder (teacher forcing) and its loss.
        with tf.variable_scope('decoder'):
            decoder_fn_train = attention_decoder_fn.attention_decoder_fn_train(
                encoder_state, attention_keys, attention_values,
                attention_score_fn, attention_construct_fn)
            self.decoder_output, _, _ = dynamic_rnn_decoder(
                cell,
                decoder_fn_train,
                self.decoder_input,
                self.responses_length,
                scope="decoder_rnn")
            self.decoder_loss = sampled_sequence_loss(self.decoder_output,
                                                      self.responses_target,
                                                      self.decoder_mask)

        # Greedy inference decoder; shares variables with the training decoder.
        with tf.variable_scope('decoder', reuse=True):
            decoder_fn_inference = attention_decoder_fn.attention_decoder_fn_inference(
                output_fn, encoder_state, attention_keys, attention_values,
                attention_score_fn, attention_construct_fn, self.embed, GO_ID,
                EOS_ID, max_length, num_symbols)

            self.decoder_distribution, _, _ = dynamic_rnn_decoder(
                cell, decoder_fn_inference, scope="decoder_rnn")
            # Argmax over symbols 2..num_symbols (skipping the first two ids,
            # presumably PAD/UNK), then +2 to recover true vocabulary indices.
            self.generation_index = tf.argmax(
                tf.split(self.decoder_distribution, [2, num_symbols - 2],
                         2)[1], 2) + 2  # for removing UNK
            self.generation = self.index2symbol.lookup(self.generation_index,
                                                       name='generation')

        # Beam-search inference decoder; results come back through the
        # decoder's context_state as stacked TensorArrays.
        with tf.variable_scope('decoder', reuse=True):
            decoder_fn_beam_inference = attention_decoder_fn_beam_inference(
                output_fn, encoder_state, attention_keys, attention_values,
                attention_score_fn, attention_construct_fn, self.embed, GO_ID,
                EOS_ID, max_length, num_symbols, beam_size, remove_unk)
            _, _, self.context_state = dynamic_rnn_decoder(
                cell, decoder_fn_beam_inference, scope="decoder_rnn")
            (log_beam_probs, beam_parents, beam_symbols, result_probs,
             result_parents, result_symbols) = self.context_state

            # Reshape the stacked beams from time-major to batch-major:
            # [max_length+1, batch*beam] -> [batch, max_length+1, beam].
            self.beam_parents = tf.transpose(tf.reshape(
                beam_parents.stack(), [max_length + 1, -1, beam_size]),
                                             [1, 0, 2],
                                             name='beam_parents')
            self.beam_symbols = tf.transpose(
                tf.reshape(beam_symbols.stack(),
                           [max_length + 1, -1, beam_size]), [1, 0, 2])
            self.beam_symbols = self.index2symbol.lookup(tf.cast(
                self.beam_symbols, tf.int64),
                                                         name="beam_symbols")

            # Finished hypotheses keep 2*beam_size candidates per step.
            self.result_probs = tf.transpose(tf.reshape(
                result_probs.stack(), [max_length + 1, -1, beam_size * 2]),
                                             [1, 0, 2],
                                             name='result_probs')
            self.result_symbols = tf.transpose(
                tf.reshape(result_symbols.stack(),
                           [max_length + 1, -1, beam_size * 2]), [1, 0, 2])
            self.result_parents = tf.transpose(tf.reshape(
                result_parents.stack(), [max_length + 1, -1, beam_size * 2]),
                                               [1, 0, 2],
                                               name='result_parents')
            self.result_symbols = self.index2symbol.lookup(
                tf.cast(self.result_symbols, tf.int64), name='result_symbols')

        self.params = tf.trainable_variables()

        # SGD with global-norm gradient clipping on the training loss.
        opt = tf.train.GradientDescentOptimizer(self.learning_rate)
        gradients = tf.gradients(self.decoder_loss, self.params)
        clipped_gradients, self.gradient_norm = tf.clip_by_global_norm(
            gradients, max_gradient_norm)
        self.update = opt.apply_gradients(zip(clipped_gradients, self.params),
                                          global_step=self.global_step)

        self.saver = tf.train.Saver(write_version=tf.train.SaverDef.V2,
                                    max_to_keep=3,
                                    pad_step_number=True,
                                    keep_checkpoint_every_n_hours=1.0)

        # Exporter for TF-Serving: expose the beam-search outputs.
        self.model_exporter = exporter.Exporter(self.saver)
        inputs = {"enc_inps:0": self.posts, "enc_lens:0": self.posts_length}
        outputs = {
            "beam_symbols": self.beam_symbols,
            "beam_parents": self.beam_parents,
            "result_probs": self.result_probs,
            "result_symbols": self.result_symbols,
            "result_parents": self.result_parents
        }
        self.model_exporter.init(tf.get_default_graph().as_graph_def(),
                                 named_graph_signatures={
                                     "inputs":
                                     exporter.generic_signature(inputs),
                                     "outputs":
                                     exporter.generic_signature(outputs)
                                 })
Code example #4
0
File: model.py  Project: juvu/seq2seq_cn
    def __init__(
            self,
            num_symbols,  # vocabulary size
            num_embed_units,  # word embedding dimension
            num_units,  # RNN units per layer
            num_layers,  # number of RNN layers
            embed,  # pre-trained word embeddings (or None)
            entity_embed=None,  # pre-trained entity/relation embeddings
            num_entities=0,  # number of entities and relations
            num_trans_units=100,  # entity (TransE) embedding dimension
            learning_rate=0.0001,
            learning_rate_decay_factor=0.95,  # factor applied by the decay op
            max_gradient_norm=5.0,  # global-norm clipping threshold
            num_samples=500,  # sample count for sampled softmax
            max_length=60,
            mem_use=True,
            output_alignments=True,
            use_lstm=False):
        """Knowledge-grounded seq2seq (CCM-style) with static graph attention.

        Encodes a post together with attended knowledge-graph embeddings and
        decodes a response that may copy entity words from the attached
        knowledge triples. Builds the full training/inference graph, Adam
        training ops, summaries, and savers.
        """
        # Raw string inputs; ids are produced in-graph by hash tables.
        # Shape (None,) declares rank-1 (batch); bare None would leave the
        # rank unconstrained.
        self.posts = tf.placeholder(tf.string, (None, None),
                                    'enc_inps')  # batch_size * encoder_len
        self.posts_length = tf.placeholder(tf.int32, (None,),
                                           'enc_lens')  # batch_size
        self.responses = tf.placeholder(tf.string, (None, None),
                                        'dec_inps')  # batch_size * decoder_len
        self.responses_length = tf.placeholder(tf.int32, (None,),
                                               'dec_lens')  # batch_size
        self.entities = tf.placeholder(
            tf.string, (None, None, None),
            'entities')  # batch_size * triple_num * triple_len
        self.entity_masks = tf.placeholder(tf.string, (None, None),
                                           'entity_masks')  # unused here
        self.triples = tf.placeholder(
            tf.string, (None, None, None, 3),
            'triples')  # batch_size * triple_num * triple_len * 3
        self.posts_triple = tf.placeholder(
            tf.int32, (None, None, 1),
            'enc_triples')  # batch_size * encoder_len
        self.responses_triple = tf.placeholder(
            tf.string, (None, None, 3),
            'dec_triples')  # batch_size * decoder_len * 3
        self.match_triples = tf.placeholder(
            tf.int32, (None, None, None),
            'match_triples')  # batch_size * decoder_len * triple_num

        encoder_batch_size, encoder_len = tf.unstack(tf.shape(self.posts))
        # Number of (padded) knowledge graphs attached to each post.
        triple_num = tf.shape(self.triples)[1]
        # Number of (padded) triples inside each knowledge graph.
        triple_len = tf.shape(self.triples)[2]

        # One-hot of which triple each decoder step matched.
        one_hot_triples = tf.one_hot(
            self.match_triples,
            triple_len)  # batch_size * decoder_len * triple_num * triple_len
        # 1 at decoder steps whose output word came from a knowledge triple.
        use_triples = tf.reduce_sum(one_hot_triples,
                                    axis=[2, 3])  # batch_size * decoder_len

        # word -> id hash table (checkpointed; shared_name shares it across
        # sessions).
        self.symbol2index = MutableHashTable(key_dtype=tf.string,
                                             value_dtype=tf.int64,
                                             default_value=UNK_ID,
                                             shared_name="in_table",
                                             name="in_table",
                                             checkpoint=True)

        # id -> word hash table.
        self.index2symbol = MutableHashTable(key_dtype=tf.int64,
                                             value_dtype=tf.string,
                                             default_value='_UNK',
                                             shared_name="out_table",
                                             name="out_table",
                                             checkpoint=True)

        # entity -> id hash table.
        self.entity2index = MutableHashTable(key_dtype=tf.string,
                                             value_dtype=tf.int64,
                                             default_value=NONE_ID,
                                             shared_name="entity_in_table",
                                             name="entity_in_table",
                                             checkpoint=True)

        # id -> entity hash table.
        self.index2entity = MutableHashTable(key_dtype=tf.int64,
                                             value_dtype=tf.string,
                                             default_value='_NONE',
                                             shared_name="entity_out_table",
                                             name="entity_out_table",
                                             checkpoint=True)

        # Map post strings to word ids and entity ids.
        self.posts_word_id = self.symbol2index.lookup(
            self.posts)  # batch_size * encoder_len
        self.posts_entity_id = self.entity2index.lookup(
            self.posts)  # batch_size * encoder_len

        # Map response strings to word ids (the decoder targets).
        self.responses_target = self.symbol2index.lookup(
            self.responses)  # batch_size * decoder_len
        batch_size, decoder_len = tf.shape(self.responses)[0], tf.shape(
            self.responses)[1]
        # Decoder input: drop the last target column and prepend GO_ID.
        self.responses_word_id = tf.concat([
            tf.ones([batch_size, 1], dtype=tf.int64) * GO_ID,
            tf.split(self.responses_target, [decoder_len - 1, 1], 1)[0]
        ], 1)  # batch_size * decoder_len

        # Mask: 1 inside each response's true length, 0 on padding (reverse
        # cumsum over the one-hot of the last valid position).
        self.decoder_mask = tf.reshape(
            tf.cumsum(tf.one_hot(self.responses_length - 1, decoder_len),
                      reverse=True,
                      axis=1), [-1, decoder_len])  # batch_size * decoder_len

        # Word and entity embeddings: use pre-trained values when given,
        # otherwise random init. Entity embeddings are frozen.
        if embed is None:
            self.embed = tf.get_variable('word_embed',
                                         [num_symbols, num_embed_units],
                                         tf.float32)
        else:
            self.embed = tf.get_variable('word_embed',
                                         dtype=tf.float32,
                                         initializer=embed)
        if entity_embed is None:
            self.entity_trans = tf.get_variable(
                'entity_embed', [num_entities, num_trans_units],
                tf.float32,
                trainable=False)
        else:
            self.entity_trans = tf.get_variable('entity_embed',
                                                dtype=tf.float32,
                                                initializer=entity_embed,
                                                trainable=False)

        # Trainable tanh projection on top of the frozen entity embeddings
        # (lets the model adapt them without touching the TransE vectors).
        self.entity_trans_transformed = tf.layers.dense(
            self.entity_trans,
            num_trans_units,
            activation=tf.tanh,
            name='trans_transformation')
        # Zero vectors prepended for the 7 special/padding entity ids, so
        # their lookups contribute nothing.
        padding_entity = tf.get_variable('entity_padding_embed',
                                         [7, num_trans_units],
                                         dtype=tf.float32,
                                         initializer=tf.zeros_initializer())

        self.entity_embed = tf.concat(
            [padding_entity, self.entity_trans_transformed], axis=0)

        # embedding_lookup adds one dimension (for the triple's 3 slots);
        # reshape folds it back into the feature axis.
        triples_embedding = tf.reshape(
            tf.nn.embedding_lookup(self.entity_embed,
                                   self.entity2index.lookup(self.triples)),
            [encoder_batch_size, triple_num, -1, 3 * num_trans_units])
        entities_word_embedding = tf.reshape(
            tf.nn.embedding_lookup(self.embed,
                                   self.symbol2index.lookup(self.entities)),
            [encoder_batch_size, -1, num_embed_units
             ])  # [batch_size, triple_num*triple_len, num_embed_units]

        # Split each triple embedding into head, relation, tail.
        head, relation, tail = tf.split(triples_embedding,
                                        [num_trans_units] * 3,
                                        axis=3)

        # Static graph attention: score each triple by its relation against
        # its head/tail pair, then build one vector per knowledge graph.
        with tf.variable_scope('graph_attention'):
            head_tail = tf.concat(
                [head, tail],
                axis=3)  # batch_size * triple_num * triple_len * 200

            # tanh(W1 [head; tail])
            head_tail_transformed = tf.layers.dense(
                head_tail,
                num_trans_units,
                activation=tf.tanh,
                name='head_tail_transform'
            )  # batch_size * triple_num * triple_len * 100

            # W2 relation
            relation_transformed = tf.layers.dense(
                relation, num_trans_units, name='relation_transform'
            )  # batch_size * triple_num * triple_len * 100

            # Elementwise product + sum == dot product of the two vectors.
            e_weight = tf.reduce_sum(
                relation_transformed * head_tail_transformed,
                axis=3)  # batch_size * triple_num * triple_len

            # Per-graph softmax over the triples.
            alpha_weight = tf.nn.softmax(
                e_weight)  # batch_size * triple_num * triple_len

            # Weighted sum of [head; tail] -> one 100-d vector per graph.
            graph_embed = tf.reduce_sum(
                tf.expand_dims(alpha_weight, 3) * head_tail,
                axis=2)  # batch_size * triple_num * 100

        # Build [batch, encoder_len, 2] indices pairing every encoder step
        # with the graph index it refers to (posts_triple), then gather the
        # corresponding graph vector for each step:
        #   indices[b, t] = [b, posts_triple[b, t, 0]]
        # Result: encoder_batch_size * encoder_len * 100.
        graph_embed_input = tf.gather_nd(
            graph_embed,
            tf.concat([
                tf.tile(
                    tf.reshape(tf.range(encoder_batch_size, dtype=tf.int32),
                               [-1, 1, 1]), [1, encoder_len, 1]),
                self.posts_triple
            ],
                      axis=2))

        # Entity embeddings of the triple attached to each response word:
        # batch_size * decoder_len * (3 * num_trans_units).
        triple_embed_input = tf.reshape(
            tf.nn.embedding_lookup(
                self.entity_embed,
                self.entity2index.lookup(self.responses_triple)),
            [batch_size, decoder_len, 3 * num_trans_units])

        # Word embeddings of posts and responses.
        post_word_input = tf.nn.embedding_lookup(
            self.embed, self.posts_word_id)  # batch_size * encoder_len * 300
        response_word_input = tf.nn.embedding_lookup(
            self.embed,
            self.responses_word_id)  # batch_size * decoder_len * 300

        # Encoder input: word embedding + per-step graph vector.
        self.encoder_input = tf.concat(
            [post_word_input, graph_embed_input],
            axis=2)  # batch_size * encoder_len * 400
        # Decoder input: word embedding + per-step triple embedding.
        self.decoder_input = tf.concat(
            [response_word_input, triple_embed_input],
            axis=2)  # batch_size * decoder_len * 600

        # Stacked GRU cells, one distinct cell object per layer.
        encoder_cell = MultiRNNCell(
            [GRUCell(num_units) for _ in range(num_layers)])
        decoder_cell = MultiRNNCell(
            [GRUCell(num_units) for _ in range(num_layers)])

        # RNN encoder over the post.
        encoder_output, encoder_state = dynamic_rnn(encoder_cell,
                                                    self.encoder_input,
                                                    self.posts_length,
                                                    dtype=tf.float32,
                                                    scope="encoder")

        # Projection from num_units up to the vocabulary; returns the output
        # fn, the entity-selector fn, and three loss variants.
        output_fn, selector_fn, sequence_loss, sampled_sequence_loss, total_loss = output_projection_layer(
            num_units, num_symbols, num_samples)

        # Training decoder (teacher forcing).
        with tf.variable_scope('decoder'):
            # Bahdanau attention over encoder outputs, with the graph and
            # triple embeddings as external memory (imem).
            attention_keys_init, attention_values_init, attention_score_fn_init, attention_construct_fn_init \
                    = prepare_attention(encoder_output, 'bahdanau', num_units, imem=(graph_embed, triples_embedding), output_alignments=output_alignments and mem_use)#'luong', num_units)

            decoder_fn_train = attention_decoder_fn_train(
                encoder_state,
                attention_keys_init,
                attention_values_init,
                attention_score_fn_init,
                attention_construct_fn_init,
                output_alignments=output_alignments and mem_use,
                max_length=tf.reduce_max(self.responses_length))

            # Returns outputs, final state, and a TensorArray of per-step
            # attention alignments over the knowledge triples.
            self.decoder_output, _, alignments_ta = dynamic_rnn_decoder(
                decoder_cell,
                decoder_fn_train,
                self.decoder_input,
                self.responses_length,
                scope="decoder_rnn")

            if output_alignments:
                # BUGFIX: self.alignments was consumed below but never
                # assigned; materialize the TensorArray and make it
                # batch-major. Stack shape is presumably
                # [decoder_len, batch, triple_num, triple_len] -> transpose
                # to [batch, decoder_len, triple_num, triple_len].
                self.alignments = tf.transpose(alignments_ta.stack(),
                                               perm=[1, 0, 2, 3])
                self.decoder_loss, self.ppx_loss, self.sentence_ppx = total_loss(
                    self.decoder_output, self.responses_target,
                    self.decoder_mask, self.alignments, triples_embedding,
                    use_triples, one_hot_triples)
                # Name the per-sentence perplexity op for external fetches.
                self.sentence_ppx = tf.identity(self.sentence_ppx,
                                                name='ppx_loss')
            else:
                self.decoder_loss = sequence_loss(self.decoder_output,
                                                  self.responses_target,
                                                  self.decoder_mask)

        # Inference decoder; shares variables with the training decoder.
        with tf.variable_scope('decoder', reuse=True):
            attention_keys, attention_values, attention_score_fn, attention_construct_fn \
                    = prepare_attention(encoder_output, 'bahdanau', num_units, reuse=True, imem=(graph_embed, triples_embedding), output_alignments=output_alignments and mem_use)#'luong', num_units)
            # imem: (entity word embeddings, flattened triple embeddings) —
            # the copy memory the decoder can select entities from.
            decoder_fn_inference = attention_decoder_fn_inference(
                output_fn,
                encoder_state,
                attention_keys,
                attention_values,
                attention_score_fn,
                attention_construct_fn,
                self.embed,
                GO_ID,
                EOS_ID,
                max_length,
                num_symbols,
                imem=(entities_word_embedding,
                      tf.reshape(
                          triples_embedding,
                          [encoder_batch_size, -1, 3 * num_trans_units])),
                selector_fn=selector_fn)

            self.decoder_distribution, _, output_ids_ta = dynamic_rnn_decoder(
                decoder_cell, decoder_fn_inference, scope="decoder_rnn")

            output_len = tf.shape(self.decoder_distribution)[1]  # decoder_len
            output_ids = tf.transpose(
                output_ids_ta.gather(
                    tf.range(output_len)))  # [batch_size, decoder_len]

            # Positive ids are vocabulary words; negative ids index copied
            # entities. Clip to the word range for the vocab lookup.
            # NOTE(review): upper bound num_symbols is one past the last
            # valid id (num_symbols - 1) — confirm against output_fn.
            word_ids = tf.cast(tf.clip_by_value(output_ids, 0, num_symbols),
                               tf.int64)  # [batch_size, decoder_len]

            # Positions of copied entity words inside the flattened entities
            # tensor: relative position (-output_ids, clipped) plus a
            # per-batch-row offset of triple_num * triple_len.
            entity_ids = tf.reshape(
                tf.clip_by_value(-output_ids, 0, num_symbols) + tf.reshape(
                    tf.range(encoder_batch_size) *
                    tf.shape(entities_word_embedding)[1], [-1, 1]), [-1])

            # Gather the actual entity strings: [batch_size, output_len].
            entities = tf.reshape(
                tf.gather(tf.reshape(self.entities, [-1]), entity_ids),
                [-1, output_len])

            words = self.index2symbol.lookup(word_ids)  # id -> word string
            # Where the id is positive take the vocab word, otherwise the
            # copied entity.
            self.generation = tf.where(output_ids > 0, words, entities)
            self.generation = tf.identity(self.generation, name='generation')

        # Training bookkeeping: decayable learning rate and global step.
        self.learning_rate = tf.Variable(float(learning_rate),
                                         trainable=False,
                                         dtype=tf.float32)
        self.learning_rate_decay_op = self.learning_rate.assign(
            self.learning_rate * learning_rate_decay_factor)
        self.global_step = tf.Variable(0, trainable=False)

        # NOTE: global_variables (not trainable_variables), so frozen
        # embeddings are still saved/restored with the rest.
        self.params = tf.global_variables()

        # Adam optimizer; note the decay op above does not affect it since
        # Adam is constructed with the raw learning_rate argument.
        opt = tf.train.AdamOptimizer(learning_rate=learning_rate)

        # HACK: reaches into a private attribute to expose the optimizer's lr.
        self.lr = opt._lr

        # Gradients of decoder_loss with global-norm clipping.
        gradients = tf.gradients(self.decoder_loss, self.params)
        clipped_gradients, self.gradient_norm = tf.clip_by_global_norm(
            gradients, max_gradient_norm)
        self.update = opt.apply_gradients(zip(clipped_gradients, self.params),
                                          global_step=self.global_step)

        # TensorBoard summaries: loss scalar plus a histogram per variable.
        tf.summary.scalar('decoder_loss', self.decoder_loss)
        for each in tf.trainable_variables():
            tf.summary.histogram(each.name, each)

        self.merged_summary_op = tf.summary.merge_all()

        # Rolling checkpoints plus a long-lived per-epoch saver.
        self.saver = tf.train.Saver(write_version=tf.train.SaverDef.V2,
                                    max_to_keep=3,
                                    pad_step_number=True,
                                    keep_checkpoint_every_n_hours=1.0)
        self.saver_epoch = tf.train.Saver(write_version=tf.train.SaverDef.V2,
                                          max_to_keep=1000,
                                          pad_step_number=True)
Code example #5
0
    def __init__(self,
            num_symbols,
            num_qwords, #modify
            num_embed_units,
            num_units,
            num_layers,
            is_train,
            vocab=None,
            embed=None,
            question_data=True,
            learning_rate=0.5,
            learning_rate_decay_factor=0.95,
            max_gradient_norm=5.0,
            num_samples=512,
            max_length=30,
            use_lstm=False):

        self.posts = tf.placeholder(tf.string, shape=(None, None))  # batch*len
        self.posts_length = tf.placeholder(tf.int32, shape=(None))  # batch
        self.responses = tf.placeholder(tf.string, shape=(None, None))  # batch*len
        self.responses_length = tf.placeholder(tf.int32, shape=(None))  # batch
        self.keyword_tensor = tf.placeholder(tf.float32, shape=(None, 3, None)) #(batch * len) * 3 * numsymbol
        self.word_type = tf.placeholder(tf.int32, shape=(None))   #(batch * len)

        # build the vocab table (string to index)
        if is_train:
            self.symbols = tf.Variable(vocab, trainable=False, name="symbols")
        else:
            self.symbols = tf.Variable(np.array(['.']*num_symbols), name="symbols")
        self.symbol2index = HashTable(KeyValueTensorInitializer(self.symbols,
            tf.Variable(np.array([i for i in range(num_symbols)], dtype=np.int32), False)),
            default_value=UNK_ID, name="symbol2index")
        self.posts_input = self.symbol2index.lookup(self.posts)   # batch*len
        self.responses_target = self.symbol2index.lookup(self.responses)   #batch*len
        
        batch_size, decoder_len = tf.shape(self.responses)[0], tf.shape(self.responses)[1]
        self.responses_input = tf.concat([tf.ones([batch_size, 1], dtype=tf.int32)*GO_ID,
            tf.split(self.responses_target, [decoder_len-1, 1], 1)[0]], 1)   # batch*len
        #delete the last column of responses_target) and add 'GO at the front of it.
        self.decoder_mask = tf.reshape(tf.cumsum(tf.one_hot(self.responses_length-1,
            decoder_len), reverse=True, axis=1), [-1, decoder_len]) # bacth * len

        print "embedding..."
        # build the embedding table (index to vector)
        if embed is None:
            # initialize the embedding randomly
            self.embed = tf.get_variable('embed', [num_symbols, num_embed_units], tf.float32)
        else:
            print len(vocab), len(embed), len(embed[0])
            print embed
            # initialize the embedding by pre-trained word vectors
            self.embed = tf.get_variable('embed', dtype=tf.float32, initializer=embed)

        self.encoder_input = tf.nn.embedding_lookup(self.embed, self.posts_input) #batch*len*unit
        self.decoder_input = tf.nn.embedding_lookup(self.embed, self.responses_input)

        print "embedding finished"

        if use_lstm:
            cell = MultiRNNCell([LSTMCell(num_units)] * num_layers)
        else:
            cell = MultiRNNCell([GRUCell(num_units)] * num_layers)

        # rnn encoder
        encoder_output, encoder_state = dynamic_rnn(cell, self.encoder_input,
                self.posts_length, dtype=tf.float32, scope="encoder")
        # get output projection function
        output_fn, sampled_sequence_loss = output_projection_layer(num_units,
                num_symbols, num_qwords, num_samples, question_data)

        print "encoder_output.shape:", encoder_output.get_shape()

        # get attention function
        attention_keys, attention_values, attention_score_fn, attention_construct_fn \
              = attention_decoder_fn.prepare_attention(encoder_output, 'luong', num_units)

        # get decoding loop function
        decoder_fn_train = attention_decoder_fn.attention_decoder_fn_train(encoder_state,
                attention_keys, attention_values, attention_score_fn, attention_construct_fn)
        decoder_fn_inference = attention_decoder_fn.attention_decoder_fn_inference(output_fn,
                self.keyword_tensor,
                encoder_state, attention_keys, attention_values, attention_score_fn,
                attention_construct_fn, self.embed, GO_ID, EOS_ID, max_length, num_symbols)

        if is_train:
            # rnn decoder
            self.decoder_output, _, _ = dynamic_rnn_decoder(cell, decoder_fn_train,
                    self.decoder_input, self.responses_length, scope="decoder")
            # calculate the loss of decoder
            # self.decoder_output = tf.Print(self.decoder_output, [self.decoder_output])
            self.decoder_loss, self.log_perplexity = sampled_sequence_loss(self.decoder_output,
                    self.responses_target, self.decoder_mask, self.keyword_tensor, self.word_type)

            # building graph finished and get all parameters
            self.params = tf.trainable_variables()

            for item in tf.trainable_variables():
                print item.name, item.get_shape()

            # initialize the training process
            self.learning_rate = tf.Variable(float(learning_rate), trainable=False,
                    dtype=tf.float32)
            self.learning_rate_decay_op = self.learning_rate.assign(
                    self.learning_rate * learning_rate_decay_factor)

            self.global_step = tf.Variable(0, trainable=False)

            # calculate the gradient of parameters

            opt = tf.train.GradientDescentOptimizer(self.learning_rate)
            gradients = tf.gradients(self.decoder_loss, self.params)
            clipped_gradients, self.gradient_norm = tf.clip_by_global_norm(gradients,
                    max_gradient_norm)
            self.update = opt.apply_gradients(zip(clipped_gradients, self.params),
                    global_step=self.global_step)

        else:
            # rnn decoder
            self.decoder_distribution, _, _ = dynamic_rnn_decoder(cell, decoder_fn_inference,
                    scope="decoder")
            print("self.decoder_distribution.shape():",self.decoder_distribution.get_shape())
            self.decoder_distribution = tf.Print(self.decoder_distribution, ["distribution.shape()", tf.reduce_sum(self.decoder_distribution)])
            # generating the response
            self.generation_index = tf.argmax(tf.split(self.decoder_distribution,
                [2, num_symbols-2], 2)[1], 2) + 2 # for removing UNK
            self.generation = tf.nn.embedding_lookup(self.symbols, self.generation_index)

            self.params = tf.trainable_variables()

        self.saver = tf.train.Saver(tf.global_variables(), write_version=tf.train.SaverDef.V2,
                max_to_keep=3, pad_step_number=True, keep_checkpoint_every_n_hours=1.0)
コード例 #6
0
    def __init__(self,
            num_symbols,
            num_embed_units,
            num_units,
            is_train,
            vocab=None,
            content_pos=None,
            rhetoric_pos = None,
            embed=None,
            learning_rate=0.1,
            learning_rate_decay_factor=0.9995,
            max_gradient_norm=5.0,
            max_length=30,
            latent_size=128,
            use_lstm=False,
            num_classes=3,
            full_kl_step=80000,
            mem_slot_num=4,
            mem_size=128):
        """Build a CVAE-style sentence-rewriting graph (TensorFlow 1.x).

        A shared bidirectional RNN encodes the origin sentence and (with
        reused weights) the target/response sentence.  A recognition network
        (conditioned on both sentences) and a prior network (origin only)
        each parameterize a Gaussian over a latent variable; `use_prior`
        selects which one is sampled.  The latent sample feeds a pattern
        classifier and, together with a label embedding, initializes an
        attention decoder.  Training loss = reconstruction loss + annealed
        KL divergence + classifier cross-entropy.

        Args:
            num_symbols: vocabulary size.
            num_embed_units: word-embedding dimension.
            num_units: RNN units per encoder direction (decoder: 2*num_units).
            is_train: True builds the training branch, False the inference one.
            vocab: vocabulary strings; required when is_train (table keys).
            content_pos: vocab indices forming the "content word" mask.
            rhetoric_pos: vocab indices forming the "rhetoric word" mask.
            embed: optional pre-trained embedding matrix; random if None.
            learning_rate: initial SGD-with-momentum learning rate.
            learning_rate_decay_factor: multiplicative decay applied by
                `learning_rate_decay_op`.
            max_gradient_norm: global-norm gradient clipping threshold.
            max_length: maximum decoding length at inference time.
            latent_size: dimension of the latent Gaussian variable.
            use_lstm: use LSTM cells instead of GRU cells.
            num_classes: number of rhetoric-pattern classes.
            full_kl_step: global step at which the KL weight reaches 1.0.
            mem_slot_num: number of topic-memory slots.
            mem_size: dimension of each topic-memory slot.
        """
        
        # Feed tensors: sentences are word strings, looked up to ids below.
        self.ori_sents = tf.placeholder(tf.string, shape=(None, None))        # [batch, ori_len]
        self.ori_sents_length = tf.placeholder(tf.int32, shape=(None))        # [batch]
        self.rep_sents = tf.placeholder(tf.string, shape=(None, None))        # [batch, rep_len]
        self.rep_sents_length = tf.placeholder(tf.int32, shape=(None))        # [batch]
        self.labels = tf.placeholder(tf.float32, shape=(None, num_classes))   # one-hot pattern labels
        self.use_prior = tf.placeholder(tf.bool)   # True -> sample latent from prior net (inference)
        self.global_t = tf.placeholder(tf.int32)   # current global step, drives KL annealing
        # Multi-hot vocabulary masks: 1.0 at each content/rhetoric word index.
        self.content_mask = tf.reduce_sum(tf.one_hot(content_pos, num_symbols, 1.0, 0.0), axis = 0)
        self.rhetoric_mask = tf.reduce_sum(tf.one_hot(rhetoric_pos, num_symbols, 1.0, 0.0), axis = 0)

        # NOTE(review): tf.zeros does not accept a None (unknown) dimension in
        # `shape`; this looks like it should use the dynamic batch size
        # (tf.shape(...)[0]) computed below — confirm against the original repo.
        topic_memory = tf.zeros(name="topic_memory", dtype=tf.float32,
                                  shape=[None, mem_slot_num, mem_size])

        w_topic_memory = tf.get_variable(name="w_topic_memory", dtype=tf.float32,
                                    initializer=tf.random_uniform([mem_size, mem_size], -0.1, 0.1))

        # build the vocab table (string to index)
        if is_train:
            self.symbols = tf.Variable(vocab, trainable=False, name="symbols")
        else:
            # Placeholder content; the real vocabulary is restored from the checkpoint.
            self.symbols = tf.Variable(np.array(['.']*num_symbols), name="symbols")
        self.symbol2index = HashTable(KeyValueTensorInitializer(self.symbols, 
            tf.Variable(np.array([i for i in range(num_symbols)], dtype=np.int32), False)), 
            default_value=UNK_ID, name="symbol2index")

        self.ori_sents_input = self.symbol2index.lookup(self.ori_sents)   # [batch, ori_len] int ids
        self.rep_sents_target = self.symbol2index.lookup(self.rep_sents)  # [batch, rep_len] int ids
        batch_size, decoder_len = tf.shape(self.rep_sents)[0], tf.shape(self.rep_sents)[1]
        # Decoder input = target shifted right by one: drop the last column,
        # prepend GO_ID (teacher forcing).
        self.rep_sents_input = tf.concat([tf.ones([batch_size, 1], dtype=tf.int32)*GO_ID,
            tf.split(self.rep_sents_target, [decoder_len-1, 1], 1)[0]], 1)
        # Loss mask: 1.0 for positions < rep_sents_length, 0.0 after
        # (reverse cumsum of a one-hot at the last valid index).
        self.decoder_mask = tf.reshape(tf.cumsum(tf.one_hot(self.rep_sents_length-1,
            decoder_len), reverse=True, axis=1), [-1, decoder_len])        
        
        # build the embedding table (index to vector)
        if embed is None:
            # initialize the embedding randomly
            self.embed = tf.get_variable('embed', [num_symbols, num_embed_units], tf.float32)
        else:
            # initialize the embedding by pre-trained word vectors
            self.embed = tf.get_variable('embed', dtype=tf.float32, initializer=embed)

        # One learned embedding row per rhetoric-pattern class.
        self.pattern_embed = tf.get_variable('pattern_embed', [num_classes, num_embed_units], tf.float32)
        
        self.encoder_input = tf.nn.embedding_lookup(self.embed, self.ori_sents_input)   # [batch, ori_len, embed]
        self.decoder_input = tf.nn.embedding_lookup(self.embed, self.rep_sents_input)   # [batch, rep_len, embed]

        if use_lstm:
            cell_fw = LSTMCell(num_units)
            cell_bw = LSTMCell(num_units)
            cell_dec = LSTMCell(2*num_units)   # decoder width matches concatenated fw+bw encoder
        else:
            cell_fw = GRUCell(num_units)
            cell_bw = GRUCell(num_units)
            cell_dec = GRUCell(2*num_units)

        # origin sentence encoder
        with variable_scope.variable_scope("encoder"):
            encoder_output, encoder_state = tf.nn.bidirectional_dynamic_rnn(cell_fw, cell_bw, self.encoder_input, 
                self.ori_sents_length, dtype=tf.float32)
            post_sum_state = tf.concat(encoder_state, 1)     # [batch, 2*num_units] fw+bw final states
            encoder_output = tf.concat(encoder_output, 2)    # [batch, ori_len, 2*num_units]

        # response sentence encoder (same weights, reused scope)
        # NOTE(review): names are misleading — bidirectional_dynamic_rnn returns
        # (outputs, final_states), so `decoder_state` holds per-step outputs and
        # `decoder_last_state` the final states.
        with variable_scope.variable_scope("encoder", reuse = True):
            decoder_state, decoder_last_state = tf.nn.bidirectional_dynamic_rnn(cell_fw, cell_bw, self.decoder_input, 
                self.rep_sents_length, dtype=tf.float32)
            response_sum_state = tf.concat(decoder_last_state, 1)   # [batch, 2*num_units]

        # recognition network: q(z | origin, response) — training-time posterior
        with variable_scope.variable_scope("recog_net"):
            recog_input = tf.concat([post_sum_state, response_sum_state], 1)
            recog_mulogvar = tf.contrib.layers.fully_connected(recog_input, latent_size * 2, activation_fn=None, scope="muvar")
            recog_mu, recog_logvar = tf.split(recog_mulogvar, 2, axis=1)

        # prior network: p(z | origin) — used at inference time
        with variable_scope.variable_scope("prior_net"):
            prior_fc1 = tf.contrib.layers.fully_connected(post_sum_state, latent_size * 2, activation_fn=tf.tanh, scope="fc1")
            prior_mulogvar = tf.contrib.layers.fully_connected(prior_fc1, latent_size * 2, activation_fn=None, scope="muvar")
            prior_mu, prior_logvar = tf.split(prior_mulogvar, 2, axis=1)

        # Reparameterized sample; `use_prior` switches posterior -> prior.
        latent_sample = tf.cond(self.use_prior,
                                lambda: sample_gaussian(prior_mu, prior_logvar),
                                lambda: sample_gaussian(recog_mu, recog_logvar))


        # classifier: predict the rhetoric pattern from the latent sample
        with variable_scope.variable_scope("classifier"):
            classifier_input = latent_sample
            pattern_fc1 = tf.contrib.layers.fully_connected(classifier_input, latent_size, activation_fn=tf.tanh, scope="pattern_fc1")
            self.pattern_logits = tf.contrib.layers.fully_connected(pattern_fc1, num_classes, activation_fn=None, scope="pattern_logits")

        # Soft label embedding: one-hot (or soft) labels times the class table.
        self.label_embedding = tf.matmul(self.labels, self.pattern_embed)

        # Output projection + loss fn, aware of the content/rhetoric vocab masks
        # (project helper; exact contract defined elsewhere).
        output_fn, my_sequence_loss = output_projection_layer(2*num_units, num_symbols, latent_size, num_embed_units, self.content_mask, self.rhetoric_mask)

        attention_keys, attention_values, attention_score_fn, attention_construct_fn = my_attention_decoder_fn.prepare_attention(encoder_output, 'luong', 2*num_units)

        # Initial decoder state: MLP over [origin summary; label embedding; latent sample].
        with variable_scope.variable_scope("dec_start"):
            temp_start = tf.concat([post_sum_state, self.label_embedding, latent_sample], 1)
            dec_fc1 = tf.contrib.layers.fully_connected(temp_start, 2*num_units, activation_fn=tf.tanh, scope="dec_start_fc1")
            dec_fc2 = tf.contrib.layers.fully_connected(dec_fc1, 2*num_units, activation_fn=None, scope="dec_start_fc2")

        if is_train:
            # rnn decoder
            # NOTE(review): `update_memory` is defined elsewhere; `topic_memory`
            # starts as a rank-3 tensor, yet it is concatenated with rank-2
            # tensors on axis 1 below — presumably update_memory returns a
            # flattened [batch, ...] summary. Confirm against its definition.
            topic_memory = self.update_memory(topic_memory, encoder_output)
            extra_info = tf.concat([self.label_embedding, latent_sample, topic_memory], 1)

            decoder_fn_train = my_attention_decoder_fn.attention_decoder_fn_train(dec_fc2, 
                attention_keys, attention_values, attention_score_fn, attention_construct_fn, extra_info)
            self.decoder_output, _, _ = my_seq2seq.dynamic_rnn_decoder(cell_dec, decoder_fn_train, 
                self.decoder_input, self.rep_sents_length, scope = "decoder")

            # calculate the loss: masked sequence reconstruction loss
            self.decoder_loss = my_loss.sequence_loss(logits = self.decoder_output, 
                targets = self.rep_sents_target, weights = self.decoder_mask,
                extra_information = latent_sample, label_embedding = self.label_embedding, softmax_loss_function = my_sequence_loss)
            # KL(q||p), annealed linearly from 0 to 1 over full_kl_step steps.
            temp_klloss = tf.reduce_mean(gaussian_kld(recog_mu, recog_logvar, prior_mu, prior_logvar))
            self.kl_weight = tf.minimum(tf.to_float(self.global_t)/full_kl_step, 1.0)
            self.klloss = self.kl_weight * temp_klloss
            temp_labels = tf.argmax(self.labels, 1)
            self.classifierloss = tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(logits=self.pattern_logits, labels=temp_labels))
            self.loss = self.decoder_loss + self.klloss + self.classifierloss  # need to anneal the kl_weight
            
            # building graph finished and get all parameters
            self.params = tf.trainable_variables()
        
            # initialize the training process
            self.learning_rate = tf.Variable(float(learning_rate), trainable=False, dtype=tf.float32)
            self.learning_rate_decay_op = self.learning_rate.assign(self.learning_rate * learning_rate_decay_factor)
            self.global_step = tf.Variable(0, trainable=False)
            
            # calculate the gradient of parameters (momentum SGD + global-norm clipping)
            opt = tf.train.MomentumOptimizer(self.learning_rate, 0.9)
            gradients = tf.gradients(self.loss, self.params)
            clipped_gradients, self.gradient_norm = tf.clip_by_global_norm(gradients, 
                    max_gradient_norm)
            self.update = opt.apply_gradients(zip(clipped_gradients, self.params), 
                    global_step=self.global_step)

        else:
            # rnn decoder (greedy inference, fed by its own embeddings)
            topic_memory = self.update_memory(topic_memory, encoder_output)
            extra_info = tf.concat([self.label_embedding, latent_sample, topic_memory], 1)
            decoder_fn_inference = my_attention_decoder_fn.attention_decoder_fn_inference(output_fn, 
                dec_fc2, attention_keys, attention_values, attention_score_fn, 
                attention_construct_fn, self.embed, GO_ID, EOS_ID, max_length, num_symbols, extra_info)
            self.decoder_distribution, _, _ = my_seq2seq.dynamic_rnn_decoder(cell_dec, decoder_fn_inference, scope="decoder")
            # Argmax over the vocab with the first two ids excluded
            # (presumably PAD/UNK — hence "+ 2" to restore absolute indices).
            self.generation_index = tf.argmax(tf.split(self.decoder_distribution,
                [2, num_symbols-2], 2)[1], 2) + 2 # for removing UNK
            self.generation = tf.nn.embedding_lookup(self.symbols, self.generation_index)  # ids -> word strings
            
            self.params = tf.trainable_variables()

        # Saver over ALL globals so the vocab table contents are checkpointed too.
        self.saver = tf.train.Saver(tf.global_variables(), write_version=tf.train.SaverDef.V2, 
                max_to_keep=3, pad_step_number=True, keep_checkpoint_every_n_hours=1.0)
コード例 #7
0
    def __init__(self,
                 num_items,
                 num_embed_units,
                 num_units,
                 num_layers,
                 embed=None,
                 learning_rate=1e-4,
                 action_num=10,
                 learning_rate_decay_factor=0.95,
                 max_gradient_norm=5.0,
                 use_lstm=True):
        """Build a session-based RNN recommendation agent (TensorFlow 1.x).

        Training branch: an RNN encodes a session of clicked item ids; a
        sampled-softmax sequence loss over the next item is weighted per step
        by an externally supplied reward. The per-step probabilities are
        gathered over each step's candidate list (`rec_lists`) to produce
        top-k recommendation indices.

        Inference branch (reused weights): a single RNN step from a fed-in
        recurrent state, producing top-k recommendations either greedily or
        via Gumbel-noise perturbation for exploration.

        Args:
            num_items: item vocabulary size (including special ids).
            num_embed_units: item-embedding dimension.
            num_units: RNN units per layer.
            num_layers: number of stacked RNN layers.
            embed: optional pre-trained item-embedding matrix; random if None.
            learning_rate: initial Adam learning rate.
            action_num: k for the top-k recommendation outputs.
            learning_rate_decay_factor: multiplicative decay applied by
                `learning_rate_decay_op`.
            max_gradient_norm: global-norm gradient clipping threshold.
            use_lstm: use LSTM cells instead of GRU cells.
        """

        # Epoch counter persisted in the graph, bumped via epoch_add_op.
        self.epoch = tf.Variable(0, trainable=False, name='agn/epoch')
        self.epoch_add_op = self.epoch.assign(self.epoch + 1)

        self.sessions_input = tf.placeholder(tf.int32, shape=(None, None))          # [batch, len] clicked item ids
        self.rec_lists = tf.placeholder(tf.int32, shape=(None, None, None))         # [batch, len, rec_len] candidate ids per step
        self.rec_mask = tf.placeholder(tf.float32, shape=(None, None, None))        # [batch, len, rec_len] 1.0 for valid candidates
        self.aims_idx = tf.placeholder(tf.int32, shape=(None, None))                # [batch, len] index of the chosen item in rec_lists
        self.sessions_length = tf.placeholder(tf.int32, shape=(None))               # [batch]
        self.reward = tf.placeholder(tf.float32, shape=(None))                      # per-step reward, flattened to [batch*len] below

        if embed is None:
            self.embed = tf.get_variable(
                'agn/embed', [num_items, num_embed_units],
                tf.float32,
                initializer=tf.truncated_normal_initializer(0, 1))
        else:
            self.embed = tf.get_variable('agn/embed',
                                         dtype=tf.float32,
                                         initializer=embed)

        batch_size, encoder_length, rec_length = tf.shape(
            self.sessions_input)[0], tf.shape(
                self.sessions_input)[1], tf.shape(self.rec_lists)[2]

        # Step mask: 1.0 for positions <= sessions_length - 2 (reverse cumsum of
        # a one-hot). The last input position is excluded — presumably because
        # it has no next-item target; confirm against the data pipeline.
        encoder_mask = tf.reshape(
            tf.cumsum(tf.one_hot(self.sessions_length - 2, encoder_length),
                      reverse=True,
                      axis=1), [-1, encoder_length])
        # [batch_size, length] — next-item targets: inputs shifted left, PAD at the end.
        self.sessions_target = tf.concat([
            self.sessions_input[:, 1:],
            tf.ones([batch_size, 1], dtype=tf.int32) * PAD_ID
        ], 1)
        # [batch_size, length, embed_units]
        self.encoder_input = tf.nn.embedding_lookup(self.embed,
                                                    self.sessions_input)
        # [batch_size, length, rec_length] — one-hot of the chosen candidate per step.
        self.aims = tf.one_hot(self.aims_idx, rec_length)

        # Fresh cell objects per layer (no accidental weight sharing).
        if use_lstm:
            cell = MultiRNNCell(
                [LSTMCell(num_units) for _ in range(num_layers)])
        else:
            cell = MultiRNNCell(
                [GRUCell(num_units) for _ in range(num_layers)])

        # Training
        with tf.variable_scope("agn"):
            # Project helper: returns (logit fn, sampled-softmax sequence loss fn).
            output_fn, sampled_sequence_loss = output_projection_layer(
                num_units, num_items)
            self.encoder_output, self.encoder_state = dynamic_rnn(
                cell,
                self.encoder_input,
                self.sessions_length,
                dtype=tf.float32,
                scope="encoder")

            # Build [batch, len, rec_len, 3] indices (batch, step, item_id) so
            # gather_nd can pull each candidate's probability out of y_prob.
            tmp_dim_1 = tf.tile(
                tf.reshape(tf.range(batch_size), [batch_size, 1, 1, 1]),
                [1, encoder_length, rec_length, 1])
            tmp_dim_2 = tf.tile(
                tf.reshape(tf.range(encoder_length),
                           [1, encoder_length, 1, 1]),
                [batch_size, 1, rec_length, 1])
            # [batch_size, length, rec_length, 3]
            gather_idx = tf.concat(
                [tmp_dim_1, tmp_dim_2,
                 tf.expand_dims(self.rec_lists, 3)], 3)

            # [batch_size, length, num_items], [batch_size*length]
            y_prob, local_loss, total_size = sampled_sequence_loss(
                self.encoder_output, self.sessions_target, encoder_mask)

            # Compute recommendation rank given rec_list
            # [batch_size, length, num_items] — zero out the first two item ids
            # (presumably PAD/UNK special tokens) before ranking.
            y_prob = tf.reshape(y_prob, [batch_size, encoder_length, num_items]) * \
                tf.concat([tf.zeros([batch_size, encoder_length, 2], dtype=tf.float32),
                            tf.ones([batch_size, encoder_length, num_items-2], dtype=tf.float32)], 2)
            # [batch_size, length, rec_len]
            ini_prob = tf.reshape(tf.gather_nd(y_prob, gather_idx),
                                  [batch_size, encoder_length, rec_length])
            # [batch_size, length, rec_len] — invalid candidates forced to 0.
            mul_prob = ini_prob * self.rec_mask

            # [batch_size, length, action_num]
            _, self.index = tf.nn.top_k(mul_prob, k=action_num)
            # [batch_size, length, metric_num] — extra slots for offline metrics
            # (FLAGS is a module-level config object defined elsewhere).
            _, self.metric_index = tf.nn.top_k(mul_prob,
                                               k=(FLAGS['metric'].value + 1))

            # Reward-weighted (policy-gradient-style) mean sequence loss.
            self.loss = tf.reduce_sum(
                tf.reshape(self.reward, [-1]) * local_loss) / total_size

        # Inference (same weights via scope reuse)
        with tf.variable_scope("agn", reuse=True):
            # tf.get_variable_scope().reuse_variables()
            # NOTE(review): shape (2, 2, None, num_units) hard-codes exactly
            # 2 layers and LSTM (c, h) tuples — this branch breaks if
            # num_layers != 2 or use_lstm=False. Confirm intended usage.
            self.lstm_state = tf.placeholder(tf.float32,
                                             shape=(2, 2, None, num_units))
            self.ini_state = (tf.contrib.rnn.LSTMStateTuple(
                self.lstm_state[0, 0, :, :], self.lstm_state[0, 1, :, :]),
                              tf.contrib.rnn.LSTMStateTuple(
                                  self.lstm_state[1, 0, :, :],
                                  self.lstm_state[1, 1, :, :]))
            # [batch_size, length, num_units]
            self.encoder_output_predict, self.encoder_state_predict = dynamic_rnn(
                cell,
                self.encoder_input,
                self.sessions_length,
                initial_state=self.ini_state,
                dtype=tf.float32,
                scope="encoder")

            # [batch_size, num_units] — output of the final time step only.
            self.final_output_predict = tf.reshape(
                self.encoder_output_predict[:, -1, :], [-1, num_units])
            # [batch_size, num_items]
            self.rec_logits = output_fn(self.final_output_predict)
            # [batch_size, action_num] — greedy top-k, skipping the special
            # vocab ids at the front; indices shifted back to absolute ids.
            _, self.rec_index = tf.nn.top_k(
                self.rec_logits[:, len(_START_VOCAB):], action_num)
            self.rec_index += len(_START_VOCAB)

            def gumbel_max(inp, alpha, beta):
                # Perturb logits with Gumbel(0, 1) noise (-log(-log U)) for
                # stochastic exploration; alpha scales the noise, beta the
                # softmax temperature.
                # assert len(tf.shape(inp)) == 2
                g = tf.random_uniform(tf.shape(inp), 0.0001, 0.9999)
                g = -tf.log(-tf.log(g))
                inp_g = tf.nn.softmax(
                    (tf.nn.log_softmax(inp / 1.0) + g * alpha) * beta)
                return inp_g

            # [batch_size, action_num] — exploratory (sampled) top-k.
            _, self.random_rec_index = tf.nn.top_k(
                gumbel_max(self.rec_logits[:, len(_START_VOCAB):], 1, 1),
                action_num)
            self.random_rec_index += len(_START_VOCAB)

        # initialize the training process
        self.learning_rate = tf.Variable(float(learning_rate),
                                         trainable=False,
                                         dtype=tf.float32)
        self.learning_rate_decay_op = self.learning_rate.assign(
            self.learning_rate * learning_rate_decay_factor)

        self.global_step = tf.Variable(0, trainable=False)
        self.params = tf.trainable_variables()
        # Adam with global-norm gradient clipping.
        gradients = tf.gradients(self.loss, self.params)
        clipped_gradients, self.gradient_norm = tf.clip_by_global_norm(
            gradients, max_gradient_norm)
        self.update = tf.train.AdamOptimizer(
            self.learning_rate).apply_gradients(zip(clipped_gradients,
                                                    self.params),
                                                global_step=self.global_step)

        self.saver = tf.train.Saver(tf.global_variables(),
                                    write_version=tf.train.SaverDef.V2,
                                    max_to_keep=100,
                                    pad_step_number=True,
                                    keep_checkpoint_every_n_hours=1.0)
コード例 #8
0
    def __init__(
            self,
            num_symbols,  # 词汇表size
            num_embed_units,  # 词嵌入size
            num_units,  # RNN 每层单元数
            num_layers,  # RNN 层数
            embed,  # 词嵌入
            entity_embed=None,  # 实体+关系的嵌入
            num_entities=0,  # 实体+关系的总个数
            num_trans_units=100,  # 实体嵌入的维度
            memory_units=100,
            learning_rate=0.0001,  # 学习率
            learning_rate_decay_factor=0.95,  # 学习率衰退,并没有采用这种方式
            max_gradient_norm=5.0,  #
            num_samples=500,  # 样本个数,sampled softmax
            max_length=60,
            mem_use=True,
            output_alignments=True,
            use_lstm=False):

        self.posts = tf.placeholder(tf.string, (None, None),
                                    'enc_inps')  # [batch_size, encoder_len]
        self.posts_length = tf.placeholder(tf.int32, (None),
                                           'enc_lens')  # [batch_size]
        self.responses = tf.placeholder(
            tf.string, (None, None), 'dec_inps')  # [batch_size, decoder_len]
        self.responses_length = tf.placeholder(tf.int32, (None),
                                               'dec_lens')  # [batch_size]
        self.entities = tf.placeholder(
            tf.string, (None, None, None),
            'entities')  # [batch_size, triple_num, triple_len]
        self.entity_masks = tf.placeholder(tf.string, (None, None),
                                           'entity_masks')  # 没用到
        self.triples = tf.placeholder(
            tf.string, (None, None, None, 3),
            'triples')  # [batch_size, triple_num, triple_len, 3]
        self.posts_triple = tf.placeholder(
            tf.int32, (None, None, 1),
            'enc_triples')  # [batch_size, encoder_len, 1]
        self.responses_triple = tf.placeholder(
            tf.string, (None, None, 3),
            'dec_triples')  # [batch_size, decoder_len, 3]
        self.match_triples = tf.placeholder(
            tf.int32, (None, None, None),
            'match_triples')  # [batch_size, decoder_len, triple_num]

        # 编码器batch_size,编码器encoder_len
        encoder_batch_size, encoder_len = tf.unstack(tf.shape(self.posts))
        triple_num = tf.shape(self.triples)[1]  # 知识图个数
        triple_len = tf.shape(self.triples)[2]  # 知识三元组个数

        # 使用的知识三元组
        one_hot_triples = tf.one_hot(
            self.match_triples,
            triple_len)  # [batch_size, decoder_len, triple_num, triple_len]
        # 用 1 标注了哪个时间步产生的回复用了知识三元组
        use_triples = tf.reduce_sum(one_hot_triples,
                                    axis=[2, 3])  # [batch_size, decoder_len]

        # 词汇映射到index的hash table
        self.symbol2index = MutableHashTable(
            key_dtype=tf.string,  # key张量的类型
            value_dtype=tf.int64,  # value张量的类型
            default_value=UNK_ID,  # 缺少key的默认值
            shared_name=
            "in_table",  # If non-empty, this table will be shared under the given name across multiple sessions
            name="in_table",  # 操作名
            checkpoint=True
        )  # if True, the contents of the table are saved to and restored from checkpoints. If shared_name is empty for a checkpointed table, it is shared using the table node name.

        # index映射到词汇的hash table
        self.index2symbol = MutableHashTable(key_dtype=tf.int64,
                                             value_dtype=tf.string,
                                             default_value='_UNK',
                                             shared_name="out_table",
                                             name="out_table",
                                             checkpoint=True)

        # 实体映射到index的hash table
        self.entity2index = MutableHashTable(key_dtype=tf.string,
                                             value_dtype=tf.int64,
                                             default_value=NONE_ID,
                                             shared_name="entity_in_table",
                                             name="entity_in_table",
                                             checkpoint=True)

        # index映射到实体的hash table
        self.index2entity = MutableHashTable(key_dtype=tf.int64,
                                             value_dtype=tf.string,
                                             default_value='_NONE',
                                             shared_name="entity_out_table",
                                             name="entity_out_table",
                                             checkpoint=True)

        self.posts_word_id = self.symbol2index.lookup(
            self.posts)  # [batch_size, encoder_len]
        self.posts_entity_id = self.entity2index.lookup(
            self.posts)  # [batch_size, encoder_len]

        self.responses_target = self.symbol2index.lookup(
            self.responses)  # [batch_size, decoder_len]
        # 获得解码器的batch_size,decoder_len
        batch_size, decoder_len = tf.shape(self.responses)[0], tf.shape(
            self.responses)[1]
        # 去掉responses_target的最后一列,给第一列加上GO_ID
        self.responses_word_id = tf.concat([
            tf.ones([batch_size, 1], dtype=tf.int64) * GO_ID,
            tf.split(self.responses_target, [decoder_len - 1, 1], 1)[0]
        ], 1)  # [batch_size, decoder_len]

        # 得到response的mask
        self.decoder_mask = tf.reshape(
            tf.cumsum(tf.one_hot(self.responses_length - 1, decoder_len),
                      reverse=True,
                      axis=1), [-1, decoder_len])  # [batch_size, decoder_len]

        # 初始化词嵌入和实体嵌入,传入了参数就直接赋值,没有的话就随机初始化
        if embed is None:
            self.embed = tf.get_variable('word_embed',
                                         [num_symbols, num_embed_units],
                                         tf.float32)
        else:
            self.embed = tf.get_variable('word_embed',
                                         dtype=tf.float32,
                                         initializer=embed)
        if entity_embed is None:  # 实体嵌入不随着模型的训练而更新
            self.entity_trans = tf.get_variable(
                'entity_embed', [num_entities, num_trans_units],
                tf.float32,
                trainable=False)
        else:
            self.entity_trans = tf.get_variable('entity_embed',
                                                dtype=tf.float32,
                                                initializer=entity_embed,
                                                trainable=False)

        # 将实体嵌入传入一个全连接层
        self.entity_trans_transformed = tf.layers.dense(
            self.entity_trans,
            num_trans_units,
            activation=tf.tanh,
            name='trans_transformation')
        # 添加['_NONE', '_PAD_H', '_PAD_R', '_PAD_T', '_NAF_H', '_NAF_R', '_NAF_T']这7个的嵌入
        padding_entity = tf.get_variable('entity_padding_embed',
                                         [7, num_trans_units],
                                         dtype=tf.float32,
                                         initializer=tf.zeros_initializer())
        self.entity_embed = tf.concat(
            [padding_entity, self.entity_trans_transformed], axis=0)

        # triples_embedding: [batch_size, triple_num, triple_len, 3*num_trans_units] 知识图三元组的嵌入
        triples_embedding = tf.reshape(
            tf.nn.embedding_lookup(self.entity_embed,
                                   self.entity2index.lookup(self.triples)),
            [encoder_batch_size, triple_num, -1, 3 * num_trans_units])
        # entities_word_embedding: [batch_size, triple_num*triple_len, num_embed_units] 知识图中用到的所有实体的嵌入
        entities_word_embedding = tf.reshape(
            tf.nn.embedding_lookup(self.embed,
                                   self.symbol2index.lookup(self.entities)),
            [encoder_batch_size, -1, num_embed_units])
        # 分离知识图三元组的头、关系和尾 [batch_size, triple_num, triple_len, num_trans_units]
        head, relation, tail = tf.split(triples_embedding,
                                        [num_trans_units] * 3,
                                        axis=3)

        # 静态图注意力机制
        with tf.variable_scope('graph_attention'):
            # 将头尾连接起来 [batch_size, triple_num, triple_len, 2*num_trans_units]
            head_tail = tf.concat([head, tail], axis=3)
            # 将头尾送入全连接层 [batch_size, triple_num, triple_len, num_trans_units]
            head_tail_transformed = tf.layers.dense(head_tail,
                                                    num_trans_units,
                                                    activation=tf.tanh,
                                                    name='head_tail_transform')
            # 将关系送入全连接层 [batch_size, triple_num, triple_len, num_trans_units]
            relation_transformed = tf.layers.dense(relation,
                                                   num_trans_units,
                                                   name='relation_transform')
            # 求头尾和关系两个向量的内积,获得对三元组的注意力系数
            e_weight = tf.reduce_sum(
                relation_transformed * head_tail_transformed,
                axis=3)  # [batch_size, triple_num, triple_len]
            alpha_weight = tf.nn.softmax(
                e_weight)  # [batch_size, triple_num, triple_len]
            # tf.expand_dims 使 alpha_weight 维度+1 [batch_size, triple_num, triple_len, 1]
            # 对第2个维度求和,由此产生静态图的向量表示
            graph_embed = tf.reduce_sum(
                tf.expand_dims(alpha_weight, 3) * head_tail,
                axis=2)  # [batch_size, triple_num, 2*num_trans_units]
        """graph_embed_input
        1、首先一维的range列表[0, 1, 2... encoder_batch_size个]转化成三维的[encoder_batch_size, 1, 1]的矩阵
        [[[0]], [[1]], [[2]],...]
        2、然后tf.tile将矩阵的第1维复制encoder_len遍,变成[encoder_batch_size, encoder_len, 1]
        [[[0],[0]...]],...]
        3、与posts_triple: [batch_size, encoder_len, 1]在第2维上进行拼接,形成一个indices: [batch_size, encoder_len, 2]矩阵,
        indices矩阵:
        [
         [[0 0], [0 0], [0 0], [0 0], [0 1], [0 0], [0 2], [0 0],...encoder_len],
         [[1 0], [1 0], [1 0], [1 0], [1 1], [1 0], [1 2], [1 0],...encoder_len],
         [[2 0], [2 0], [2 0], [2 0], [2 1], [2 0], [2 2], [2 0],...encoder_len]
         ,...batch_size
        ]
        4、tf.gather_nd根据索引检索graph_embed: [batch_size, triple_num, 2*num_trans_units]再回填至indices矩阵
        indices矩阵最后一个维度是2,例如有[0, 2],表示这个时间步第1个batch用了第2个图,
        则找到这个知识图的静态图向量填入到indices矩阵的[0, 2]位置最后得到结果维度
        [encoder_batch_size, encoder_len, 2*num_trans_units]表示每个时间步用的静态图向量
        """
        # graph_embed_input = tf.gather_nd(graph_embed, tf.concat(
        #     [tf.tile(tf.reshape(tf.range(encoder_batch_size, dtype=tf.int32), [-1, 1, 1]), [1, encoder_len, 1]),
        #      self.posts_triple],
        #     axis=2))

        # 将responses_triple转化成实体嵌入 [batch_size, decoder_len, 300],标识了response每个时间步用了哪个三元组的嵌入
        # triple_embed_input = tf.reshape(
        #     tf.nn.embedding_lookup(self.entity_embed, self.entity2index.lookup(self.responses_triple)),
        #     [batch_size, decoder_len, 3 * num_trans_units])

        post_word_input = tf.nn.embedding_lookup(
            self.embed,
            self.posts_word_id)  # [batch_size, encoder_len, num_embed_units]
        response_word_input = tf.nn.embedding_lookup(
            self.embed, self.responses_word_id
        )  # [batch_size, decoder_len, num_embed_units]

        # post_word_input和graph_embed_input拼接构成编码器输入 [batch_size, encoder_len, num_embed_units+2*num_trans_units]
        # self.encoder_input = tf.concat([post_word_input, graph_embed_input], axis=2)
        # response_word_input和triple_embed_input拼接构成解码器输入 [batch_size, decoder_len, num_embed_units+3*num_trans_units]
        # self.decoder_input = tf.concat([response_word_input, triple_embed_input], axis=2)

        encoder_cell = MultiRNNCell(
            [GRUCell(num_units) for _ in range(num_layers)])
        decoder_cell = MultiRNNCell(
            [GRUCell(num_units) for _ in range(num_layers)])

        # rnn encoder
        # encoder_state: [num_layers, 2, batch_size, num_units] 编码器输出状态 LSTM GRU:[num_layers, batch_size, num_units]
        encoder_output, encoder_state = tf.nn.dynamic_rnn(encoder_cell,
                                                          post_word_input,
                                                          self.posts_length,
                                                          dtype=tf.float32,
                                                          scope="encoder")

        # self.encoder_state_shape = tf.shape(encoder_state)

        ########记忆网络                                                                                                     ###
        response_encoder_cell = MultiRNNCell(
            [GRUCell(num_units) for _ in range(num_layers)])
        response_encoder_output, response_encoder_state = tf.nn.dynamic_rnn(
            response_encoder_cell,
            response_word_input,
            self.responses_length,
            dtype=tf.float32,
            scope="response_encoder")

        # graph_embed: [batch_size, triple_num, 2*num_trans_units] 静态图向量
        # encoder_state: [num_layers, batch_size, num_units]
        with tf.variable_scope("post_memory_network"):
            # 将静态知识图转化成输入向量m
            post_input = tf.layers.dense(graph_embed,
                                         memory_units,
                                         use_bias=False,
                                         name="post_weight_a")
            post_input = tf.tile(
                tf.reshape(post_input,
                           (1, encoder_batch_size, triple_num, memory_units)),
                multiples=(
                    num_layers, 1, 1,
                    1))  # [num_layers, batch_size, triple_num, memory_units]
            # 将静态知识库转化成输出向量c
            post_output = tf.layers.dense(graph_embed,
                                          memory_units,
                                          use_bias=False,
                                          name="post_weight_c")
            post_output = tf.tile(
                tf.reshape(post_output,
                           (1, encoder_batch_size, triple_num, memory_units)),
                multiples=(
                    num_layers, 1, 1,
                    1))  # [num_layers, batch_size, triple_num, memory_units]
            # 将question转化成状态向量u
            encoder_hidden_state = tf.reshape(
                tf.concat(encoder_state,
                          axis=0), (num_layers, encoder_batch_size, num_units))
            post_state = tf.layers.dense(encoder_hidden_state,
                                         memory_units,
                                         use_bias=False,
                                         name="post_weight_b")
            post_state = tf.tile(
                tf.reshape(post_state,
                           (num_layers, encoder_batch_size, 1, memory_units)),
                multiples=(
                    1, 1, triple_num,
                    1))  # [num_layers, batch_size, triple_num, memory_units]
            # 概率p
            post_p = tf.reshape(
                tf.nn.softmax(tf.reduce_sum(post_state * post_input, axis=3)),
                (num_layers, encoder_batch_size, triple_num,
                 1))  # [num_layers, batch_size, triple_num, 1]
            # 输出o
            post_o = tf.reduce_sum(
                post_output * post_p,
                axis=2)  # [num_layers, batch_size, memory_units]
            post_xstar = tf.concat(
                [
                    tf.layers.dense(post_o,
                                    memory_units,
                                    use_bias=False,
                                    name="post_weight_r"), encoder_state
                ],
                axis=2)  # [num_layers, batch_size, num_units+memory_units]

        with tf.variable_scope("response_memory_network"):
            # 将静态知识图转化成输入向量m
            response_input = tf.layers.dense(graph_embed,
                                             memory_units,
                                             use_bias=False,
                                             name="response_weight_a")
            response_input = tf.tile(
                tf.reshape(response_input,
                           (1, batch_size, triple_num, memory_units)),
                multiples=(
                    num_layers, 1, 1,
                    1))  # [num_layers, batch_size, triple_num, memory_units]
            # 将静态知识库转化成输出向量c
            response_output = tf.layers.dense(graph_embed,
                                              memory_units,
                                              use_bias=False,
                                              name="response_weight_c")
            response_output = tf.tile(
                tf.reshape(response_output,
                           (1, batch_size, triple_num, memory_units)),
                multiples=(
                    num_layers, 1, 1,
                    1))  # [num_layers, batch_size, triple_num, memory_units]
            # 将question转化成状态向量u
            response_hidden_state = tf.reshape(
                tf.concat(response_encoder_state, axis=0),
                (num_layers, batch_size, num_units))
            response_state = tf.layers.dense(response_hidden_state,
                                             memory_units,
                                             use_bias=False,
                                             name="response_weight_b")
            response_state = tf.tile(
                tf.reshape(response_state,
                           (num_layers, batch_size, 1, memory_units)),
                multiples=(
                    1, 1, triple_num,
                    1))  # [num_layers, batch_size, triple_num, memory_units]
            # 概率p
            response_p = tf.reshape(
                tf.nn.softmax(
                    tf.reduce_sum(response_state * response_input, axis=3)),
                (num_layers, batch_size, triple_num,
                 1))  # [num_layers, batch_size, triple_num, 1]
            # 输出o
            response_o = tf.reduce_sum(
                response_output * response_p,
                axis=2)  # [num_layers, batch_size, memory_units]
            response_ystar = tf.concat(
                [
                    tf.layers.dense(response_o,
                                    memory_units,
                                    use_bias=False,
                                    name="response_weight_r"),
                    response_encoder_state
                ],
                axis=2)  # [num_layers, batch_size, num_units+memory_units]

        with tf.variable_scope("memory_network"):
            memory_hidden_state = tf.layers.dense(tf.concat(
                [post_xstar, response_ystar], axis=2),
                                                  num_units,
                                                  use_bias=False,
                                                  activation=tf.tanh,
                                                  name="output_weight")
            memory_hidden_state = tf.reshape(
                memory_hidden_state, (num_layers * batch_size, num_units))
            # [num_layers, batch_size, num_units]
            memory_hidden_state = tuple(
                tf.split(memory_hidden_state, [batch_size] * num_layers,
                         axis=0))
            # self.memory_hidden_state_shape = tf.shape(memory_hidden_state)
########                                                                                                             ###

        output_fn, selector_fn, sequence_loss, sampled_sequence_loss, total_loss =\
            output_projection_layer(num_units, num_symbols, num_samples)

        ########用于训练的decoder                                                                                            ###
        with tf.variable_scope('decoder'):
            attention_keys_init, attention_values_init, attention_score_fn_init, attention_construct_fn_init \
                    = prepare_attention(encoder_output,
                                        'bahdanau',
                                        num_units,
                                        imem=(graph_embed, triples_embedding),
                                        output_alignments=output_alignments and mem_use)

            # 训练时处理每个时间步输出和下个时间步输入的函数
            decoder_fn_train = attention_decoder_fn_train(
                memory_hidden_state,
                attention_keys_init,
                attention_values_init,
                attention_score_fn_init,
                attention_construct_fn_init,
                output_alignments=output_alignments and mem_use,
                max_length=tf.reduce_max(self.responses_length))

            self.decoder_output, _, alignments_ta = dynamic_rnn_decoder(
                decoder_cell,
                decoder_fn_train,
                response_word_input,
                self.responses_length,
                scope="decoder_rnn")

            if output_alignments:
                self.alignments = tf.transpose(alignments_ta.stack(),
                                               perm=[1, 0, 2, 3])
                self.decoder_loss, self.ppx_loss, self.sentence_ppx = total_loss(
                    self.decoder_output, self.responses_target,
                    self.decoder_mask, self.alignments, triples_embedding,
                    use_triples, one_hot_triples)
                self.sentence_ppx = tf.identity(self.sentence_ppx,
                                                name='ppx_loss')
            else:
                self.decoder_loss = sequence_loss(self.decoder_output,
                                                  self.responses_target,
                                                  self.decoder_mask)
########                                                                                                             ###
########用于推导的decoder                                                                                            ###
        with tf.variable_scope('decoder', reuse=True):
            attention_keys, attention_values, attention_score_fn, attention_construct_fn \
                    = prepare_attention(encoder_output,
                                        'bahdanau',
                                        num_units,
                                        reuse=True,
                                        imem=(graph_embed, triples_embedding),
                                        output_alignments=output_alignments and mem_use)

            decoder_fn_inference = \
                attention_decoder_fn_inference(output_fn,
                                               memory_hidden_state,
                                               attention_keys,
                                               attention_values,
                                               attention_score_fn,
                                               attention_construct_fn,
                                               self.embed,
                                               GO_ID,
                                               EOS_ID,
                                               max_length,
                                               num_symbols,
                                               imem=(entities_word_embedding,  # imem: ([batch_size,triple_num*triple_len,num_embed_units],
                                                     tf.reshape(triples_embedding, [encoder_batch_size, -1, 3*num_trans_units])),  # [encoder_batch_size, triple_num*triple_len, 3*num_trans_units]) 实体词嵌入和三元组嵌入的元组
                                               selector_fn=selector_fn)
            # decoder_distribution: [batch_size, decoder_len, num_symbols]
            # output_ids_ta: tensorarray: decoder_len [batch_size]
            self.decoder_distribution, _, output_ids_ta = dynamic_rnn_decoder(
                decoder_cell, decoder_fn_inference, scope="decoder_rnn")

            output_len = tf.shape(self.decoder_distribution)[1]  # decoder_len
            output_ids = tf.transpose(
                output_ids_ta.gather(
                    tf.range(output_len)))  # [batch_size, decoder_len]

            # 对output的值域行裁剪,因为存在负值表示用了实体词
            word_ids = tf.cast(tf.clip_by_value(output_ids, 0, num_symbols),
                               tf.int64)  # [batch_size, decoder_len]

            # 计算的是实体词在entities中的实际位置 [batch_size, decoder_len]
            # 1、tf.shape(entities_word_embedding)[1] = triple_num*triple_len
            # 2、tf.range(encoder_batch_size): [batch_size]
            # 3、tf.reshape(tf.range(encoder_batch_size) * tf.shape(entities_word_embedding)[1], [-1, 1]): [batch_size, 1] 实体词在entities中的基地址
            # 4、tf.clip_by_value(-output_ids, 0, num_symbols): [batch_size, decoder_len] 实体词在entities中的偏移量
            # 5、entity_ids: [batch_size, decoder_len] 实体词在entities中的实际位置
            entity_ids = tf.reshape(
                tf.clip_by_value(-output_ids, 0, num_symbols) + tf.reshape(
                    tf.range(encoder_batch_size) *
                    tf.shape(entities_word_embedding)[1], [-1, 1]), [-1])

            # 计算的是所用的实体词 [batch_size, decoder_len]
            # 1、entities: [batch_size, triple_num, triple_len]
            # 2、tf.reshape(self.entities, [-1]): [batch_size*triple_num*triple_len]
            # 3、tf.gather: [batch_size*decoder_len]
            # 4、entities: [batch_size, decoder_len]
            entities = tf.reshape(
                tf.gather(tf.reshape(self.entities, [-1]), entity_ids),
                [-1, output_len])

            words = self.index2symbol.lookup(word_ids)  # 将id转化为实际的词
            # output_ids>0为bool张量,True的位置用words中该位置的词替换
            self.generation = tf.where(output_ids > 0, words, entities)
            self.generation = tf.identity(
                self.generation,
                name='generation')  # [batch_size, decoder_len]
########                                                                                                             ###

# 初始化训练过程
        self.learning_rate = tf.Variable(float(learning_rate),
                                         trainable=False,
                                         dtype=tf.float32)

        # 并没有使用衰退的学习率
        self.learning_rate_decay_op = self.learning_rate.assign(
            self.learning_rate * learning_rate_decay_factor)

        # 更新参数的次数
        self.global_step = tf.Variable(0, trainable=False)

        # 要训练的参数
        self.params = tf.global_variables()

        # 选择优化算法
        opt = tf.train.AdamOptimizer(learning_rate=learning_rate)

        self.lr = opt._lr

        # 根据 decoder_loss 计算 params 梯度
        gradients = tf.gradients(self.decoder_loss, self.params)
        # 梯度裁剪
        clipped_gradients, self.gradient_norm = tf.clip_by_global_norm(
            gradients, max_gradient_norm)
        self.update = opt.apply_gradients(zip(clipped_gradients, self.params),
                                          global_step=self.global_step)

        # 记录损失
        tf.summary.scalar('decoder_loss', self.decoder_loss)
        for each in tf.trainable_variables():
            tf.summary.histogram(each.name, each)  # 记录变量的训练情况
        self.merged_summary_op = tf.summary.merge_all()

        self.saver = tf.train.Saver(write_version=tf.train.SaverDef.V2,
                                    max_to_keep=3,
                                    pad_step_number=True,
                                    keep_checkpoint_every_n_hours=1.0)
        self.saver_epoch = tf.train.Saver(write_version=tf.train.SaverDef.V2,
                                          max_to_keep=1000,
                                          pad_step_number=True)
Code example #9
Score: 0
File: model.py  Project: lolocn/ds-bot
    def __init__(self,
                 num_symbols,
                 num_embed_units,
                 num_units,
                 num_layers,
                 embed,
                 entity_embed=None,
                 num_entities=0,
                 num_trans_units=100,
                 learning_rate=0.0001,
                 learning_rate_decay_factor=0.95,
                 max_gradient_norm=5.0,
                 num_samples=500,
                 max_length=60,
                 mem_use=True,
                 output_alignments=True,
                 use_lstm=False):
        """Build the TF1 graph for a knowledge-grounded seq2seq dialogue model.

        Constructs, in order: string placeholders, string<->id lookup tables,
        word and entity (TransE) embeddings, static attention over knowledge
        triples, a stacked-GRU encoder, an attentive training decoder plus a
        reuse-scoped inference decoder, and the Adam training / summary /
        saver ops.

        Args:
            num_symbols: word vocabulary size (output projection size).
            num_embed_units: word embedding dimension.
            num_units: hidden size of each RNN layer.
            num_layers: number of stacked GRU layers in encoder and decoder.
            embed: pre-trained word embedding matrix, or None to random-init.
            entity_embed: pre-trained entity/relation embeddings, or None.
            num_entities: entity count (used only when entity_embed is None).
            num_trans_units: TransE embedding dimension.
            learning_rate: Adam learning rate.
            learning_rate_decay_factor: factor used by the (unused-by-default)
                learning-rate decay op.
            max_gradient_norm: global-norm gradient clipping threshold.
            num_samples: sample count for the sampled-softmax output layer.
            max_length: maximum decode length at inference time.
            mem_use: if True, decoders attend over the triple memory.
            output_alignments: if True, expose alignment weights and use the
                triple-aware total loss; otherwise plain sequence loss.
            use_lstm: accepted for API compatibility; not used in this body
                (cells are always GRU).
        """

        # Placeholders for raw (string) input batches.
        self.posts = tf.placeholder(tf.string, (None, None),
                                    'enc_inps')  # batch*len
        self.posts_length = tf.placeholder(tf.int32, (None),
                                           'enc_lens')  # batch
        self.responses = tf.placeholder(tf.string, (None, None),
                                        'dec_inps')  # batch*len
        self.responses_length = tf.placeholder(tf.int32, (None),
                                               'dec_lens')  # batch
        self.entities = tf.placeholder(tf.string, (None, None, None),
                                       'entities')  # batch*triple_num*triple_len
        self.entity_masks = tf.placeholder(tf.string, (None, None),
                                           'entity_masks')  # batch
        self.triples = tf.placeholder(tf.string, (None, None, None, 3),
                                      'triples')  # batch*triple_num*triple_len*3
        self.posts_triple = tf.placeholder(tf.int32, (None, None, 1),
                                           'enc_triples')  # batch*len*1
        self.responses_triple = tf.placeholder(tf.string, (None, None, 3),
                                               'dec_triples')  # batch*len*3
        self.match_triples = tf.placeholder(tf.int32, (None, None, None),
                                            'match_triples')  # batch

        encoder_batch_size, encoder_len = tf.unstack(tf.shape(self.posts))
        triple_num = tf.shape(self.triples)[1]
        triple_len = tf.shape(self.triples)[2]
        # One-hot over triple positions; summing over the last two axes gives,
        # per decoder step, whether any triple was matched (used by the loss).
        one_hot_triples = tf.one_hot(self.match_triples, triple_len)
        use_triples = tf.reduce_sum(one_hot_triples, axis=[2, 3])

        # Checkpointed string<->index lookup tables for words and entities.
        # (Original comment said "talbe" — vocabulary lookup tables.)
        self.symbol2index = MutableHashTable(key_dtype=tf.string,
                                             value_dtype=tf.int64,
                                             default_value=UNK_ID,
                                             shared_name="in_table",
                                             name="in_table",
                                             checkpoint=True)
        self.index2symbol = MutableHashTable(key_dtype=tf.int64,
                                             value_dtype=tf.string,
                                             default_value='_UNK',
                                             shared_name="out_table",
                                             name="out_table",
                                             checkpoint=True)
        self.entity2index = MutableHashTable(key_dtype=tf.string,
                                             value_dtype=tf.int64,
                                             default_value=NONE_ID,
                                             shared_name="entity_in_table",
                                             name="entity_in_table",
                                             checkpoint=True)
        self.index2entity = MutableHashTable(key_dtype=tf.int64,
                                             value_dtype=tf.string,
                                             default_value='_NONE',
                                             shared_name="entity_out_table",
                                             name="entity_out_table",
                                             checkpoint=True)

        self.posts_word_id = self.symbol2index.lookup(self.posts)  # batch*len
        self.posts_entity_id = self.entity2index.lookup(
            self.posts)  # batch*len
        #self.posts_word_id = tf.Print(self.posts_word_id, ['use_triples', use_triples, 'one_hot_triples', one_hot_triples], summarize=1e6)
        self.responses_target = self.symbol2index.lookup(
            self.responses)  # batch*len

        batch_size, decoder_len = tf.shape(self.responses)[0], tf.shape(
            self.responses)[1]
        # Decoder input ids: GO token followed by the target shifted right
        # by one (teacher forcing).
        self.responses_word_id = tf.concat([
            tf.ones([batch_size, 1], dtype=tf.int64) * GO_ID,
            tf.split(self.responses_target, [decoder_len - 1, 1], 1)[0]
        ], 1)  # batch*len
        # decoder_mask[i, t] == 1.0 for t < responses_length[i], else 0.0
        # (reverse cumsum of the one-hot at the last valid position).
        self.decoder_mask = tf.reshape(
            tf.cumsum(tf.one_hot(self.responses_length - 1, decoder_len),
                      reverse=True,
                      axis=1), [-1, decoder_len])

        # Word embedding table (index -> vector).
        if embed is None:
            # Randomly initialized word embeddings.
            self.embed = tf.get_variable('word_embed',
                                         [num_symbols, num_embed_units],
                                         tf.float32)
        else:
            # Initialized from pre-trained vectors (GloVe or Word2Vec).
            self.embed = tf.get_variable('word_embed',
                                         dtype=tf.float32,
                                         initializer=embed)
        if entity_embed is None:
            # Randomly initialized entity embeddings (frozen).
            self.entity_trans = tf.get_variable(
                'entity_embed', [num_entities, num_trans_units],
                tf.float32,
                trainable=False)
        else:
            # Pre-trained TransE entity/relation embeddings (frozen).
            self.entity_trans = tf.get_variable('entity_embed',
                                                dtype=tf.float32,
                                                initializer=entity_embed,
                                                trainable=False)

        # Project the frozen TransE embeddings through a trainable tanh layer.
        self.entity_trans_transformed = tf.layers.dense(
            self.entity_trans,
            num_trans_units,
            activation=tf.tanh,
            name='trans_transformation')
        # Zero vectors for the first 7 entity ids (reserved/special tokens —
        # presumably _PAD/_UNK/etc.; TODO confirm against the vocab builder).
        padding_entity = tf.get_variable('entity_padding_embed',
                                         [7, num_trans_units],
                                         dtype=tf.float32,
                                         initializer=tf.zeros_initializer())

        self.entity_embed = tf.concat(
            [padding_entity, self.entity_trans_transformed], axis=0)

        # [batch, triple_num, triple_len, 3*num_trans_units]: (h, r, t)
        # embeddings concatenated per triple.
        triples_embedding = tf.reshape(
            tf.nn.embedding_lookup(self.entity_embed,
                                   self.entity2index.lookup(self.triples)),
            [encoder_batch_size, triple_num, -1, 3 * num_trans_units])
        # [batch, triple_num*triple_len, num_embed_units]: word embeddings of
        # entity surface forms (for copying entities at inference time).
        entities_word_embedding = tf.reshape(
            tf.nn.embedding_lookup(self.embed,
                                   self.symbol2index.lookup(self.entities)),
            [encoder_batch_size, -1, num_embed_units])

        head, relation, tail = tf.split(triples_embedding,
                                        [num_trans_units] * 3,
                                        axis=3)

        # Static attention over triples for the knowledge-fusion layer.
        with tf.variable_scope('graph_attention'):
            # Concatenate head and tail embeddings.
            head_tail = tf.concat([head, tail], axis=3)
            # Fuse head+tail into one vector.
            head_tail_transformed = tf.layers.dense(head_tail,
                                                    num_trans_units,
                                                    activation=tf.tanh,
                                                    name='head_tail_transform')
            # Project the relation vector.
            relation_transformed = tf.layers.dense(relation,
                                                   num_trans_units,
                                                   name='relation_transform')
            # Attention logits: inner product of relation and head+tail.
            e_weight = tf.reduce_sum(relation_transformed *
                                     head_tail_transformed,
                                     axis=3)
            # Normalize to attention weights.
            alpha_weight = tf.nn.softmax(e_weight)
            # Weighted sum over triples -> static graph vector
            # [batch, triple_num, 2*num_trans_units].
            graph_embed = tf.reduce_sum(tf.expand_dims(alpha_weight, 3) *
                                        head_tail,
                                        axis=2)

        # Per encoder step, gather the static graph vector indicated by
        # posts_triple: indices are (batch_row, graph_index) pairs, yielding
        # [batch, encoder_len, 2*num_trans_units].
        graph_embed_input = tf.gather_nd(
            graph_embed,
            tf.concat([
                tf.tile(
                    tf.reshape(tf.range(encoder_batch_size, dtype=tf.int32),
                               [-1, 1, 1]), [1, encoder_len, 1]),
                self.posts_triple
            ],
                      axis=2))

        # Per decoder step, the embedding of the triple actually used:
        # [batch, decoder_len, 3*num_trans_units].
        triple_embed_input = tf.reshape(
            tf.nn.embedding_lookup(
                self.entity_embed,
                self.entity2index.lookup(self.responses_triple)),
            [batch_size, decoder_len, 3 * num_trans_units])

        post_word_input = tf.nn.embedding_lookup(
            self.embed, self.posts_word_id)  # batch*len*unit
        response_word_input = tf.nn.embedding_lookup(
            self.embed, self.responses_word_id)  # batch*len*unit

        # Encoder input: word embedding concatenated with the attended
        # graph vector for that step.
        self.encoder_input = tf.concat([post_word_input, graph_embed_input],
                                       axis=2)

        # Decoder input: word embedding concatenated with the used triple.
        self.decoder_input = tf.concat(
            [response_word_input, triple_embed_input], axis=2)

        # Encoder: num_layers stacked GRU cells.
        encoder_cell = MultiRNNCell(
            [GRUCell(num_units) for _ in range(num_layers)])

        # Decoder: num_layers stacked GRU cells.
        decoder_cell = MultiRNNCell(
            [GRUCell(num_units) for _ in range(num_layers)])

        # Run the RNN encoder.
        encoder_output, encoder_state = dynamic_rnn(encoder_cell,
                                                    self.encoder_input,
                                                    self.posts_length,
                                                    dtype=tf.float32,
                                                    scope="encoder")

        # Output projection (sampled softmax) and the loss functions.
        output_fn, selector_fn, sequence_loss, sampled_sequence_loss, total_loss = output_projection_layer(
            num_units, num_symbols, num_samples)
        # Training decoder.
        with tf.variable_scope('decoder'):
            # Build attention keys/values and score/construct functions.
            attention_keys_init, attention_values_init, attention_score_fn_init, attention_construct_fn_init \
                = prepare_attention(encoder_output, 'bahdanau', num_units, imem=(graph_embed, triples_embedding), output_alignments=output_alignments and mem_use)  # alternative: 'luong', num_units

            decoder_fn_train = attention_decoder_fn_train(
                encoder_state,
                attention_keys_init,
                attention_values_init,
                attention_score_fn_init,
                attention_construct_fn_init,
                output_alignments=output_alignments and mem_use,
                max_length=tf.reduce_max(self.responses_length))
            self.decoder_output, _, alignments_ta = dynamic_rnn_decoder(
                decoder_cell,
                decoder_fn_train,
                self.decoder_input,
                self.responses_length,
                scope="decoder_rnn")
            if output_alignments:
                # alignments: [batch, decoder_len, ...] after transposing the
                # time-major TensorArray stack.
                self.alignments = tf.transpose(alignments_ta.stack(),
                                               perm=[1, 0, 2, 3])
                self.decoder_loss, self.ppx_loss, self.sentence_ppx = total_loss(
                    self.decoder_output, self.responses_target,
                    self.decoder_mask, self.alignments, triples_embedding,
                    use_triples, one_hot_triples)
                self.sentence_ppx = tf.identity(self.sentence_ppx,
                                                name='ppx_loss')
            else:
                self.decoder_loss = sequence_loss(self.decoder_output,
                                                  self.responses_target,
                                                  self.decoder_mask)

        # Inference decoder (shares variables with the training decoder).
        with tf.variable_scope('decoder', reuse=True):
            # Rebuild the attention functions under reuse.
            attention_keys, attention_values, attention_score_fn, attention_construct_fn \
                = prepare_attention(encoder_output, 'bahdanau', num_units, reuse=True, imem=(graph_embed, triples_embedding), output_alignments=output_alignments and mem_use)  # alternative: 'luong', num_units
            decoder_fn_inference = attention_decoder_fn_inference(
                output_fn,
                encoder_state,
                attention_keys,
                attention_values,
                attention_score_fn,
                attention_construct_fn,
                self.embed,
                GO_ID,
                EOS_ID,
                max_length,
                num_symbols,
                imem=(entities_word_embedding,
                      tf.reshape(
                          triples_embedding,
                          [encoder_batch_size, -1, 3 * num_trans_units])),
                selector_fn=selector_fn)

            self.decoder_distribution, _, output_ids_ta = dynamic_rnn_decoder(
                decoder_cell, decoder_fn_inference, scope="decoder_rnn")

            output_len = tf.shape(self.decoder_distribution)[1]
            output_ids = tf.transpose(
                output_ids_ta.gather(tf.range(output_len)))
            # Negative ids mark copied entities; clip to get vocabulary ids.
            # NOTE(review): upper bound is num_symbols, not num_symbols - 1 —
            # confirm the intended id range.
            word_ids = tf.cast(tf.clip_by_value(output_ids, 0, num_symbols),
                               tf.int64)
            # Flattened positions of copied entities: per-row base offset
            # (row * triple_num*triple_len) plus the negated output id.
            entity_ids = tf.reshape(
                tf.clip_by_value(-output_ids, 0, num_symbols) + tf.reshape(
                    tf.range(encoder_batch_size) *
                    tf.shape(entities_word_embedding)[1], [-1, 1]), [-1])
            # Surface forms of the copied entities, [batch, decoder_len].
            entities = tf.reshape(
                tf.gather(tf.reshape(self.entities, [-1]), entity_ids),
                [-1, output_len])
            words = self.index2symbol.lookup(word_ids)

            # Final reply: vocabulary word where id > 0, else copied entity.
            self.generation = tf.where(output_ids > 0, words, entities)
            self.generation = tf.identity(self.generation, name='generation')

        # Training setup.
        self.learning_rate = tf.Variable(float(learning_rate),
                                         trainable=False,
                                         dtype=tf.float32)
        # Decay op (not applied automatically; caller must run it).
        self.learning_rate_decay_op = self.learning_rate.assign(
            self.learning_rate * learning_rate_decay_factor)
        self.global_step = tf.Variable(0, trainable=False)

        self.params = tf.global_variables()

        # Adam optimizer.
        opt = tf.train.AdamOptimizer(learning_rate=learning_rate)
        # NOTE(review): reads AdamOptimizer's private attribute; this breaks
        # on TF versions where it is named _learning_rate instead of _lr.
        self.lr = opt._lr

        # Clip gradients by global norm, then apply.
        gradients = tf.gradients(self.decoder_loss, self.params)
        clipped_gradients, self.gradient_norm = tf.clip_by_global_norm(
            gradients, max_gradient_norm)
        self.update = opt.apply_gradients(zip(clipped_gradients, self.params),
                                          global_step=self.global_step)

        # Summaries: scalar loss plus histograms of trainable variables.
        tf.summary.scalar('decoder_loss', self.decoder_loss)
        for each in tf.trainable_variables():
            tf.summary.histogram(each.name, each)

        self.merged_summary_op = tf.summary.merge_all()

        # Rolling checkpoints plus a long-retention per-epoch saver.
        self.saver = tf.train.Saver(write_version=tf.train.SaverDef.V2,
                                    max_to_keep=3,
                                    pad_step_number=True,
                                    keep_checkpoint_every_n_hours=1.0)
        self.saver_epoch = tf.train.Saver(write_version=tf.train.SaverDef.V2,
                                          max_to_keep=1000,
                                          pad_step_number=True)