Example 1
    def __init__(self,
                 num_symbols,
                 num_embed_units,
                 num_units,
                 num_layers,
                 vocab=None,
                 embed=None,
                 name_scope=None,
                 learning_rate=0.001,
                 learning_rate_decay_factor=0.95,
                 max_gradient_norm=5,
                 num_samples=512,
                 max_length=30):
        """Build a TF1 attention seq2seq graph (encoder over posts, attention
        decoder over responses) with a weighted-loss training path.

        Args:
            num_symbols: vocabulary size (rows of the embedding/output layer).
            num_embed_units: word embedding dimensionality.
            num_units: hidden size of each GRU layer.
            num_layers: number of stacked GRU layers in encoder and decoder.
            vocab: list/array of vocabulary strings used to build the
                string->index lookup table.
            embed: optional pre-trained embedding matrix; when None the
                embedding is learned from random init.
            name_scope: substring used to select this model's variables for
                the optimizer and saver (matched against variable names).
            learning_rate: initial Adam learning rate.
            learning_rate_decay_factor: multiplier applied by
                `learning_rate_decay_op`.
            max_gradient_norm: global-norm clipping threshold.
            num_samples: sample count for the sampled-softmax output layer.
            max_length: maximum decoding length at inference time.
        """

        # Feed targets: raw token strings plus per-example lengths and a
        # per-example loss weight.
        self.posts = tf.placeholder(tf.string, shape=[None,
                                                      None])  # batch * len
        self.posts_length = tf.placeholder(tf.int32, shape=[None])  # batch
        self.responses = tf.placeholder(tf.string, shape=[None,
                                                          None])  # batch*len
        self.responses_length = tf.placeholder(tf.int32, shape=[None])  # batch
        self.weight = tf.placeholder(tf.float32, shape=[None])  # batch

        # build the vocab table (string to index)
        self.symbols = tf.Variable(vocab, trainable=False, name="symbols")
        # In-graph hash table mapping token string -> vocab index; unknown
        # tokens fall back to UNK_ID.
        self.symbol2index = HashTable(KeyValueTensorInitializer(
            self.symbols,
            tf.Variable(
                np.array([i for i in range(num_symbols)], dtype=np.int32),
                False)),
                                      default_value=UNK_ID,
                                      name="symbol2index")

        # build the embedding table (index to vector)
        if embed is None:
            # initialize the embedding randomly
            self.embed = tf.get_variable('embed',
                                         [num_symbols, num_embed_units],
                                         tf.float32)
        else:
            # initialize the embedding by pre-trained word vectors
            self.embed = tf.get_variable('embed',
                                         dtype=tf.float32,
                                         initializer=embed)

        self.posts_input = self.symbol2index.lookup(
            self.posts)  # batch * utter_len
        self.encoder_input = tf.nn.embedding_lookup(
            self.embed, self.posts_input)  # batch * utter_len * embed_unit

        self.responses_target = self.symbol2index.lookup(
            self.responses)  # batch, len
        batch_size, decoder_len = tf.shape(self.responses)[0], tf.shape(
            self.responses)[1]
        # Decoder input = GO token followed by the target shifted right by one
        # (the last target token is dropped).
        self.responses_input = tf.concat([
            tf.ones([batch_size, 1], dtype=tf.int32) * GO_ID,
            tf.split(self.responses_target, [decoder_len - 1, 1], 1)[0]
        ], 1)  # batch, len
        # 1.0 for positions < responses_length, 0.0 after; built via a
        # reversed cumsum over a one-hot at the last valid position.
        self.decoder_mask = tf.reshape(
            tf.cumsum(tf.one_hot(self.responses_length - 1, decoder_len),
                      reverse=True,
                      axis=1), [-1, decoder_len])  # batch, len

        self.decoder_input = tf.nn.embedding_lookup(self.embed,
                                                    self.responses_input)

        # Construct multi-layer GRU cells for encoder and decoder
        cell_enc = MultiRNNCell(
            [GRUCell(num_units) for _ in range(num_layers)])
        cell_dec = MultiRNNCell(
            [GRUCell(num_units) for _ in range(num_layers)])

        # Encode the post sequence
        encoder_output, encoder_state = tf.nn.dynamic_rnn(cell_enc,
                                                          self.encoder_input,
                                                          self.posts_length,
                                                          dtype=tf.float32,
                                                          scope="encoder")

        # Sampled-softmax projection shared between training loss and
        # inference output.
        output_fn, sampled_sequence_loss = output_projection_layer(
            num_units, num_symbols, num_samples)
        attention_keys, attention_values, attention_score_fn, attention_construct_fn \
            = my_attention_decoder_fn.prepare_attention(encoder_output, 'bahdanau', num_units)

        # Decode the response sequence (Training)
        with variable_scope.variable_scope('decoder'):
            decoder_fn_train = my_attention_decoder_fn.attention_decoder_fn_train(
                encoder_state, attention_keys, attention_values,
                attention_score_fn, attention_construct_fn)
            self.decoder_output, _, _ = my_seq2seq.dynamic_rnn_decoder(
                cell_dec,
                decoder_fn_train,
                self.decoder_input,
                self.responses_length,
                scope='decoder_rnn')
            self.decoder_loss = my_loss.sequence_loss(
                self.decoder_output,
                self.responses_target,
                self.decoder_mask,
                softmax_loss_function=sampled_sequence_loss)
            # Per-example weighting (e.g. for RL/adversarial reweighting —
            # NOTE(review): exact use depends on the feeding code; verify).
            self.weighted_decoder_loss = self.decoder_loss * self.weight

        # Fresh attention prepared with reuse=True so inference shares the
        # training attention variables.
        attention_keys_infer, attention_values_infer, attention_score_fn_infer, attention_construct_fn_infer \
            = my_attention_decoder_fn.prepare_attention(encoder_output, 'bahdanau', num_units, reuse = True)

        # Decode the response sequence (Inference)
        with variable_scope.variable_scope('decoder', reuse=True):
            decoder_fn_inference = my_attention_decoder_fn.attention_decoder_fn_inference(
                output_fn, encoder_state, attention_keys_infer,
                attention_values_infer, attention_score_fn_infer,
                attention_construct_fn_infer, self.embed, GO_ID, EOS_ID,
                max_length, num_symbols)
            self.decoder_distribution, _, _ = my_seq2seq.dynamic_rnn_decoder(
                cell_dec, decoder_fn_inference, scope='decoder_rnn')
            # argmax over symbols 2..num_symbols (the first two ids are
            # excluded, then re-offset by +2) so UNK can never be generated.
            self.generation_index = tf.argmax(
                tf.split(self.decoder_distribution, [2, num_symbols - 2],
                         2)[1], 2) + 2  # for removing UNK
            self.generation = tf.nn.embedding_lookup(self.symbols,
                                                     self.generation_index)

        # Only variables whose name contains name_scope belong to this model.
        self.params = [
            k for k in tf.trainable_variables() if name_scope in k.name
        ]

        # initialize the training process
        self.learning_rate = tf.Variable(float(learning_rate),
                                         trainable=False,
                                         dtype=tf.float32)
        self.learning_rate_decay_op = self.learning_rate.assign(
            self.learning_rate * learning_rate_decay_factor)
        self.global_step = tf.Variable(0, trainable=False)
        self.adv_global_step = tf.Variable(0, trainable=False)

        # calculate the gradient of parameters
        self.cost = tf.reduce_mean(self.weighted_decoder_loss)
        self.unweighted_cost = tf.reduce_mean(self.decoder_loss)
        opt = tf.train.AdamOptimizer(self.learning_rate)
        gradients = tf.gradients(self.cost, self.params)
        clipped_gradients, self.gradient_norm = tf.clip_by_global_norm(
            gradients, max_gradient_norm)
        self.update = opt.apply_gradients(zip(clipped_gradients, self.params),
                                          global_step=self.global_step)

        all_variables = [
            k for k in tf.global_variables() if name_scope in k.name
        ]
        # Two savers over the same variable set — presumably one for normal
        # and one for adversarial checkpoints (see adv_global_step).
        self.saver = tf.train.Saver(all_variables,
                                    write_version=tf.train.SaverDef.V2,
                                    max_to_keep=5,
                                    pad_step_number=True,
                                    keep_checkpoint_every_n_hours=1.0)
        self.adv_saver = tf.train.Saver(all_variables,
                                        write_version=tf.train.SaverDef.V2,
                                        max_to_keep=5,
                                        pad_step_number=True,
                                        keep_checkpoint_every_n_hours=1.0)
Example 2
    def __init__(self,
                 num_symbols,
                 num_embed_units,
                 num_units,
                 num_layers,
                 is_train,
                 vocab=None,
                 embed=None,
                 learning_rate=0.1,
                 learning_rate_decay_factor=0.95,
                 max_gradient_norm=5.0,
                 num_samples=512,
                 max_length=30,
                 use_lstm=True):
        """Build a TF1 multi-turn dialogue graph: posts 1..4 are chained
        through a shared attention decoder (each turn attends over the
        previous turn's outputs plus a knowledge-graph attention embedding),
        and the response is decoded from the last turn's state.

        Args:
            num_symbols: vocabulary size.
            num_embed_units: word embedding dimensionality.
            num_units: hidden size of each RNN layer.
            num_layers: number of stacked RNN layers.
            is_train: True builds the training graph (losses + update op);
                False builds the inference graph (generation ops).
            vocab: vocabulary strings (required when is_train).
            embed: optional pre-trained embedding matrix.
            learning_rate: initial learning rate for momentum SGD.
            learning_rate_decay_factor: multiplier applied by
                `learning_rate_decay_op`.
            max_gradient_norm: global-norm clipping threshold.
            num_samples: sample count for the sampled-softmax loss.
            max_length: maximum decoding length at inference time.
            use_lstm: choose LSTM cells when True, otherwise GRU cells.
        """

        # Four conversation turns, as raw token strings.
        self.posts_1 = tf.placeholder(tf.string, shape=(None, None))
        self.posts_2 = tf.placeholder(tf.string, shape=(None, None))
        self.posts_3 = tf.placeholder(tf.string, shape=(None, None))
        self.posts_4 = tf.placeholder(tf.string, shape=(None, None))

        # Knowledge triples per turn; last axis of 3 holds the
        # (head, relation, tail) token strings.
        self.entity_1 = tf.placeholder(tf.string, shape=(None, None, None, 3))
        self.entity_2 = tf.placeholder(tf.string, shape=(None, None, None, 3))
        self.entity_3 = tf.placeholder(tf.string, shape=(None, None, None, 3))
        self.entity_4 = tf.placeholder(tf.string, shape=(None, None, None, 3))

        # 1/0 masks over valid triples for each turn.
        self.entity_mask_1 = tf.placeholder(tf.float32,
                                            shape=(None, None, None))
        self.entity_mask_2 = tf.placeholder(tf.float32,
                                            shape=(None, None, None))
        self.entity_mask_3 = tf.placeholder(tf.float32,
                                            shape=(None, None, None))
        self.entity_mask_4 = tf.placeholder(tf.float32,
                                            shape=(None, None, None))

        # BUGFIX: `shape=(None)` is just `shape=None` (no comma -> not a
        # tuple), which left these placeholders completely unconstrained.
        # `(None,)` pins them to rank-1 [batch] as intended.
        self.posts_length_1 = tf.placeholder(tf.int32, shape=(None,))
        self.posts_length_2 = tf.placeholder(tf.int32, shape=(None,))
        self.posts_length_3 = tf.placeholder(tf.int32, shape=(None,))
        self.posts_length_4 = tf.placeholder(tf.int32, shape=(None,))

        self.responses = tf.placeholder(tf.string, shape=(None, None))
        self.responses_length = tf.placeholder(tf.int32, shape=(None,))

        self.epoch = tf.Variable(0, trainable=False, name='epoch')
        self.epoch_add_op = self.epoch.assign(self.epoch + 1)

        if is_train:
            self.symbols = tf.Variable(vocab, trainable=False, name="symbols")
        else:
            # Placeholder vocab at inference time; the real symbols are
            # restored from the checkpoint.
            self.symbols = tf.Variable(np.array(['.'] * num_symbols),
                                       name="symbols")

        # In-graph hash table mapping token string -> vocab index; unknown
        # tokens fall back to UNK_ID.
        self.symbol2index = HashTable(KeyValueTensorInitializer(
            self.symbols,
            tf.Variable(
                np.array([i for i in range(num_symbols)], dtype=np.int32),
                False)),
                                      default_value=UNK_ID,
                                      name="symbol2index")

        self.posts_input_1 = self.symbol2index.lookup(self.posts_1)

        # Turns 2-4 serve both as decoder targets (teacher forcing on the
        # next turn) and as encoder inputs, hence the double alias.
        self.posts_2_target = self.posts_2_embed = self.symbol2index.lookup(
            self.posts_2)
        self.posts_3_target = self.posts_3_embed = self.symbol2index.lookup(
            self.posts_3)
        self.posts_4_target = self.posts_4_embed = self.symbol2index.lookup(
            self.posts_4)

        self.responses_target = self.symbol2index.lookup(self.responses)

        batch_size, decoder_len = tf.shape(self.posts_1)[0], tf.shape(
            self.responses)[1]

        # Decoder inputs = GO token followed by the target shifted right by
        # one (the last target token is dropped).
        self.posts_input_2 = tf.concat([
            tf.ones([batch_size, 1], dtype=tf.int32) * GO_ID,
            tf.split(self.posts_2_embed, [tf.shape(self.posts_2)[1] - 1, 1],
                     1)[0]
        ], 1)
        self.posts_input_3 = tf.concat([
            tf.ones([batch_size, 1], dtype=tf.int32) * GO_ID,
            tf.split(self.posts_3_embed, [tf.shape(self.posts_3)[1] - 1, 1],
                     1)[0]
        ], 1)
        self.posts_input_4 = tf.concat([
            tf.ones([batch_size, 1], dtype=tf.int32) * GO_ID,
            tf.split(self.posts_4_embed, [tf.shape(self.posts_4)[1] - 1, 1],
                     1)[0]
        ], 1)

        # (The original repeated the responses_target lookup and the
        # batch_size/decoder_len computation here verbatim; the redundant
        # duplicate graph nodes have been removed.)

        self.responses_input = tf.concat([
            tf.ones([batch_size, 1], dtype=tf.int32) * GO_ID,
            tf.split(self.responses_target, [decoder_len - 1, 1], 1)[0]
        ], 1)

        # Sequence masks: 1.0 for positions < length, 0.0 after, built via a
        # reversed cumsum over a one-hot at the last valid position.
        self.encoder_2_mask = tf.reshape(
            tf.cumsum(tf.one_hot(self.posts_length_2 - 1,
                                 tf.shape(self.posts_2)[1]),
                      reverse=True,
                      axis=1), [-1, tf.shape(self.posts_2)[1]])
        self.encoder_3_mask = tf.reshape(
            tf.cumsum(tf.one_hot(self.posts_length_3 - 1,
                                 tf.shape(self.posts_3)[1]),
                      reverse=True,
                      axis=1), [-1, tf.shape(self.posts_3)[1]])
        self.encoder_4_mask = tf.reshape(
            tf.cumsum(tf.one_hot(self.posts_length_4 - 1,
                                 tf.shape(self.posts_4)[1]),
                      reverse=True,
                      axis=1), [-1, tf.shape(self.posts_4)[1]])

        self.decoder_mask = tf.reshape(
            tf.cumsum(tf.one_hot(self.responses_length - 1, decoder_len),
                      reverse=True,
                      axis=1), [-1, decoder_len])

        # Build the embedding table (index to vector): random init when no
        # pre-trained vectors are supplied.
        if embed is None:
            self.embed = tf.get_variable('embed',
                                         [num_symbols, num_embed_units],
                                         tf.float32)
        else:
            self.embed = tf.get_variable('embed',
                                         dtype=tf.float32,
                                         initializer=embed)

        self.encoder_input_1 = tf.nn.embedding_lookup(self.embed,
                                                      self.posts_input_1)
        self.encoder_input_2 = tf.nn.embedding_lookup(self.embed,
                                                      self.posts_input_2)
        self.encoder_input_3 = tf.nn.embedding_lookup(self.embed,
                                                      self.posts_input_3)
        self.encoder_input_4 = tf.nn.embedding_lookup(self.embed,
                                                      self.posts_input_4)

        self.decoder_input = tf.nn.embedding_lookup(self.embed,
                                                    self.responses_input)

        # Embed each (head, relation, tail) triple and flatten the 3 token
        # embeddings into one 3*num_embed_units vector per triple.
        entity_embedding_1 = tf.reshape(
            tf.nn.embedding_lookup(self.embed,
                                   self.symbol2index.lookup(self.entity_1)),
            [
                batch_size,
                tf.shape(self.entity_1)[1],
                tf.shape(self.entity_1)[2], 3 * num_embed_units
            ])
        entity_embedding_2 = tf.reshape(
            tf.nn.embedding_lookup(self.embed,
                                   self.symbol2index.lookup(self.entity_2)),
            [
                batch_size,
                tf.shape(self.entity_2)[1],
                tf.shape(self.entity_2)[2], 3 * num_embed_units
            ])
        entity_embedding_3 = tf.reshape(
            tf.nn.embedding_lookup(self.embed,
                                   self.symbol2index.lookup(self.entity_3)),
            [
                batch_size,
                tf.shape(self.entity_3)[1],
                tf.shape(self.entity_3)[2], 3 * num_embed_units
            ])
        entity_embedding_4 = tf.reshape(
            tf.nn.embedding_lookup(self.embed,
                                   self.symbol2index.lookup(self.entity_4)),
            [
                batch_size,
                tf.shape(self.entity_4)[1],
                tf.shape(self.entity_4)[2], 3 * num_embed_units
            ])

        head_1, relation_1, tail_1 = tf.split(entity_embedding_1,
                                              [num_embed_units] * 3,
                                              axis=3)
        head_2, relation_2, tail_2 = tf.split(entity_embedding_2,
                                              [num_embed_units] * 3,
                                              axis=3)
        head_3, relation_3, tail_3 = tf.split(entity_embedding_3,
                                              [num_embed_units] * 3,
                                              axis=3)
        head_4, relation_4, tail_4 = tf.split(entity_embedding_4,
                                              [num_embed_units] * 3,
                                              axis=3)

        # Static graph attention (Zhou et al. style): score each triple by
        # relation . tanh(W [head; tail]) and sum the masked, softmax-weighted
        # [head; tail] vectors. The dense layers are shared across turns via
        # scope reuse.
        with tf.variable_scope('graph_attention'):
            #[batch_size, max_reponse_length, max_triple_num, 2*embed_units]
            head_tail_1 = tf.concat([head_1, tail_1], axis=3)
            #[batch_size, max_reponse_length, max_triple_num, embed_units]
            head_tail_transformed_1 = tf.layers.dense(
                head_tail_1,
                num_embed_units,
                activation=tf.tanh,
                name='head_tail_transform')
            #[batch_size, max_reponse_length, max_triple_num, embed_units]
            relation_transformed_1 = tf.layers.dense(relation_1,
                                                     num_embed_units,
                                                     name='relation_transform')
            #[batch_size, max_reponse_length, max_triple_num]
            e_weight_1 = tf.reduce_sum(relation_transformed_1 *
                                       head_tail_transformed_1,
                                       axis=3)
            #[batch_size, max_reponse_length, max_triple_num]
            alpha_weight_1 = tf.nn.softmax(e_weight_1)
            #[batch_size, max_reponse_length, embed_units]
            graph_embed_1 = tf.reduce_sum(
                tf.expand_dims(alpha_weight_1, 3) *
                (tf.expand_dims(self.entity_mask_1, 3) * head_tail_1),
                axis=2)

        with tf.variable_scope('graph_attention', reuse=True):
            head_tail_2 = tf.concat([head_2, tail_2], axis=3)
            head_tail_transformed_2 = tf.layers.dense(
                head_tail_2,
                num_embed_units,
                activation=tf.tanh,
                name='head_tail_transform')
            relation_transformed_2 = tf.layers.dense(relation_2,
                                                     num_embed_units,
                                                     name='relation_transform')
            e_weight_2 = tf.reduce_sum(relation_transformed_2 *
                                       head_tail_transformed_2,
                                       axis=3)
            alpha_weight_2 = tf.nn.softmax(e_weight_2)
            graph_embed_2 = tf.reduce_sum(
                tf.expand_dims(alpha_weight_2, 3) *
                (tf.expand_dims(self.entity_mask_2, 3) * head_tail_2),
                axis=2)

        with tf.variable_scope('graph_attention', reuse=True):
            head_tail_3 = tf.concat([head_3, tail_3], axis=3)
            head_tail_transformed_3 = tf.layers.dense(
                head_tail_3,
                num_embed_units,
                activation=tf.tanh,
                name='head_tail_transform')
            relation_transformed_3 = tf.layers.dense(relation_3,
                                                     num_embed_units,
                                                     name='relation_transform')
            e_weight_3 = tf.reduce_sum(relation_transformed_3 *
                                       head_tail_transformed_3,
                                       axis=3)
            alpha_weight_3 = tf.nn.softmax(e_weight_3)
            graph_embed_3 = tf.reduce_sum(
                tf.expand_dims(alpha_weight_3, 3) *
                (tf.expand_dims(self.entity_mask_3, 3) * head_tail_3),
                axis=2)

        with tf.variable_scope('graph_attention', reuse=True):
            head_tail_4 = tf.concat([head_4, tail_4], axis=3)
            head_tail_transformed_4 = tf.layers.dense(
                head_tail_4,
                num_embed_units,
                activation=tf.tanh,
                name='head_tail_transform')
            relation_transformed_4 = tf.layers.dense(relation_4,
                                                     num_embed_units,
                                                     name='relation_transform')
            e_weight_4 = tf.reduce_sum(relation_transformed_4 *
                                       head_tail_transformed_4,
                                       axis=3)
            alpha_weight_4 = tf.nn.softmax(e_weight_4)
            graph_embed_4 = tf.reduce_sum(
                tf.expand_dims(alpha_weight_4, 3) *
                (tf.expand_dims(self.entity_mask_4, 3) * head_tail_4),
                axis=2)

        # BUGFIX: the original built `[Cell(num_units)] * num_layers`, which
        # puts the SAME cell object in every layer — weights are shared
        # across layers and later TF 1.x releases reject it outright. Build
        # one independent cell per layer (matching the encoder/decoder
        # construction used elsewhere in this file).
        if use_lstm:
            cell = MultiRNNCell(
                [LSTMCell(num_units) for _ in range(num_layers)])
        else:
            cell = MultiRNNCell(
                [GRUCell(num_units) for _ in range(num_layers)])

        output_fn, sampled_sequence_loss = output_projection_layer(
            num_units, num_symbols, num_samples)

        # Turn 1 is plainly encoded; each following turn is decoded with
        # attention over the previous turn's outputs + its graph embedding,
        # all sharing the single `cell` and the "decoder" variable scope.
        encoder_output_1, encoder_state_1 = dynamic_rnn(cell,
                                                        self.encoder_input_1,
                                                        self.posts_length_1,
                                                        dtype=tf.float32,
                                                        scope="encoder")

        attention_keys_1, attention_values_1, attention_score_fn_1, attention_construct_fn_1 \
                = attention_decoder_fn.prepare_attention(graph_embed_1, encoder_output_1, 'luong', num_units)
        decoder_fn_train_1 = attention_decoder_fn.attention_decoder_fn_train(
            encoder_state_1,
            attention_keys_1,
            attention_values_1,
            attention_score_fn_1,
            attention_construct_fn_1,
            max_length=tf.reduce_max(self.posts_length_2))
        encoder_output_2, encoder_state_2, alignments_ta_2 = dynamic_rnn_decoder(
            cell,
            decoder_fn_train_1,
            self.encoder_input_2,
            self.posts_length_2,
            scope="decoder")
        self.alignments_2 = tf.transpose(alignments_ta_2.stack(),
                                         perm=[1, 0, 2])

        self.decoder_loss_2 = sampled_sequence_loss(encoder_output_2,
                                                    self.posts_2_target,
                                                    self.encoder_2_mask)

        with variable_scope.variable_scope('', reuse=True):
            attention_keys_2, attention_values_2, attention_score_fn_2, attention_construct_fn_2 \
                    = attention_decoder_fn.prepare_attention(graph_embed_2, encoder_output_2, 'luong', num_units)
            decoder_fn_train_2 = attention_decoder_fn.attention_decoder_fn_train(
                encoder_state_2,
                attention_keys_2,
                attention_values_2,
                attention_score_fn_2,
                attention_construct_fn_2,
                max_length=tf.reduce_max(self.posts_length_3))
            encoder_output_3, encoder_state_3, alignments_ta_3 = dynamic_rnn_decoder(
                cell,
                decoder_fn_train_2,
                self.encoder_input_3,
                self.posts_length_3,
                scope="decoder")
            self.alignments_3 = tf.transpose(alignments_ta_3.stack(),
                                             perm=[1, 0, 2])

            self.decoder_loss_3 = sampled_sequence_loss(
                encoder_output_3, self.posts_3_target, self.encoder_3_mask)

            attention_keys_3, attention_values_3, attention_score_fn_3, attention_construct_fn_3 \
                    = attention_decoder_fn.prepare_attention(graph_embed_3, encoder_output_3, 'luong', num_units)
            decoder_fn_train_3 = attention_decoder_fn.attention_decoder_fn_train(
                encoder_state_3,
                attention_keys_3,
                attention_values_3,
                attention_score_fn_3,
                attention_construct_fn_3,
                max_length=tf.reduce_max(self.posts_length_4))
            encoder_output_4, encoder_state_4, alignments_ta_4 = dynamic_rnn_decoder(
                cell,
                decoder_fn_train_3,
                self.encoder_input_4,
                self.posts_length_4,
                scope="decoder")
            self.alignments_4 = tf.transpose(alignments_ta_4.stack(),
                                             perm=[1, 0, 2])

            self.decoder_loss_4 = sampled_sequence_loss(
                encoder_output_4, self.posts_4_target, self.encoder_4_mask)

            attention_keys, attention_values, attention_score_fn, attention_construct_fn \
                    = attention_decoder_fn.prepare_attention(graph_embed_4, encoder_output_4, 'luong', num_units)

        if is_train:
            # Training: teacher-forced response decoding plus the update op
            # over the sum of all four turn losses.
            with variable_scope.variable_scope('', reuse=True):
                decoder_fn_train = attention_decoder_fn.attention_decoder_fn_train(
                    encoder_state_4,
                    attention_keys,
                    attention_values,
                    attention_score_fn,
                    attention_construct_fn,
                    max_length=tf.reduce_max(self.responses_length))
                self.decoder_output, _, alignments_ta = dynamic_rnn_decoder(
                    cell,
                    decoder_fn_train,
                    self.decoder_input,
                    self.responses_length,
                    scope="decoder")
                self.alignments = tf.transpose(alignments_ta.stack(),
                                               perm=[1, 0, 2])

                self.decoder_loss = sampled_sequence_loss(
                    self.decoder_output, self.responses_target,
                    self.decoder_mask)

            self.params = tf.trainable_variables()

            self.learning_rate = tf.Variable(float(learning_rate),
                                             trainable=False,
                                             dtype=tf.float32)
            self.learning_rate_decay_op = self.learning_rate.assign(
                self.learning_rate * learning_rate_decay_factor)
            self.global_step = tf.Variable(0, trainable=False)

            #opt = tf.train.GradientDescentOptimizer(self.learning_rate)
            opt = tf.train.MomentumOptimizer(self.learning_rate, 0.9)

            gradients = tf.gradients(
                self.decoder_loss + self.decoder_loss_2 + self.decoder_loss_3 +
                self.decoder_loss_4, self.params)
            clipped_gradients, self.gradient_norm = tf.clip_by_global_norm(
                gradients, max_gradient_norm)
            self.update = opt.apply_gradients(zip(clipped_gradients,
                                                  self.params),
                                              global_step=self.global_step)

        else:
            # Inference: free-running decoding from GO until EOS/max_length.
            with variable_scope.variable_scope('', reuse=True):
                decoder_fn_inference = attention_decoder_fn.attention_decoder_fn_inference(
                    output_fn, encoder_state_4, attention_keys,
                    attention_values, attention_score_fn,
                    attention_construct_fn, self.embed, GO_ID, EOS_ID,
                    max_length, num_symbols)
                self.decoder_distribution, _, alignments_ta = dynamic_rnn_decoder(
                    cell, decoder_fn_inference, scope="decoder")
                output_len = tf.shape(self.decoder_distribution)[1]
                self.alignments = tf.transpose(
                    alignments_ta.gather(tf.range(output_len)), [1, 0, 2])

            # argmax over symbols 2..num_symbols (the first two ids are
            # excluded, then re-offset by +2) so UNK can never be generated.
            self.generation_index = tf.argmax(
                tf.split(self.decoder_distribution, [2, num_symbols - 2],
                         2)[1], 2) + 2  # for removing UNK
            self.generation = tf.nn.embedding_lookup(self.symbols,
                                                     self.generation_index,
                                                     name="generation")

            self.params = tf.trainable_variables()

        self.saver = tf.train.Saver(tf.global_variables(),
                                    write_version=tf.train.SaverDef.V2,
                                    max_to_keep=10,
                                    pad_step_number=True,
                                    keep_checkpoint_every_n_hours=1.0)
class Model(BaseModel):
    def __init__(self, sess, config, api, log_dir, forward, scope=None):
        """Build the TF1 graph for a topic-conditioned hierarchical-latent dialog model.

        Args:
            sess: tf.Session, handed to optimize() and kept for later use.
            config: hyper-parameter object (cell sizes, embed sizes, init_lr,
                lr_decay, keep_prob, dec_keep_prob, sent_type, full_kl_step, ...).
            api: corpus wrapper exposing vocab / rev_vocab / topic_vocab /
                dialog_act_vocab.
            log_dir: summary directory; when None, KL annealing is disabled
                (kl_w is fixed at 1.0 below).
            forward: True -> build only the sampling (inference) decoder;
                False -> build the training decoder, loss, and optimizer.
            scope: optional scope name; stored but not applied in this method.
        """

        self.vocab = api.vocab
        self.rev_vocab = api.rev_vocab
        self.vocab_size = len(self.vocab)

        self.topic_vocab = api.topic_vocab
        self.topic_vocab_size = len(self.topic_vocab)

        # Dialog-act vocab is loaded but not used elsewhere in this method.
        self.da_vocab = api.dialog_act_vocab
        self.da_vocab_size = len(self.da_vocab)

        self.sess = sess
        self.scope = scope

        # Special-token ids resolved from the reverse vocab.
        self.pad_id = self.rev_vocab["<pad>"]
        self.sos_id = self.rev_vocab["<s>"]
        self.eos_id = self.rev_vocab["</s>"]
        self.unk_id = self.rev_vocab["<unk>"]

        self.context_cell_size = config.cxt_cell_size
        self.sent_cell_size = config.sent_cell_size
        self.dec_cell_size = config.dec_cell_size
        self.latent_size = config.latent_size

        with tf.name_scope("io"):

            # Raw token strings; converted to ids via the HashTable below.
            # Shapes: [batch, dialog_len, utt_len].
            self.input_contexts = tf.placeholder(dtype=tf.string,
                                                 shape=(None, None, None),
                                                 name="dialog_context")
            self.context_lens = tf.placeholder(dtype=tf.int32,
                                               shape=(None, ),
                                               name="context_lens")
            self.topics = tf.placeholder(dtype=tf.int32,
                                         shape=(None, ),
                                         name="topics")

            self.output_tokens = tf.placeholder(dtype=tf.string,
                                                shape=(None, None, None),
                                                name="output_token")
            self.output_lens = tf.placeholder(dtype=tf.int32,
                                              shape=(None, None),
                                              name="output_lens")

            # Learning rate lives in the graph so it can be decayed via an op.
            self.learning_rate = tf.Variable(float(config.init_lr),
                                             trainable=False,
                                             name="learning_rate")
            self.learning_rate_decay_op = self.learning_rate.assign(
                tf.multiply(self.learning_rate, config.lr_decay))
            # global_t drives KL-annealing; use_prior switches posterior/prior sampling.
            self.global_t = tf.placeholder(dtype=tf.int32, name="global_t")
            self.use_prior = tf.placeholder(dtype=tf.bool, name="use_prior")

        # Dynamic dimensions, read from the fed tensors at run time.
        batch_size = tf.shape(self.input_contexts)[0]
        max_dialog_len = tf.shape(self.input_contexts)[1]
        max_out_len = tf.shape(self.output_tokens)[2]

        with tf.variable_scope("tokenization"):
            # String -> index lookup table; unknown tokens map to unk_id.
            self.symbols = tf.Variable(self.vocab,
                                       trainable=False,
                                       name="symbols")
            self.symbol2index = HashTable(KeyValueTensorInitializer(
                self.symbols,
                tf.Variable(
                    np.array([i for i in range(self.vocab_size)],
                             dtype=np.int32), False)),
                                          default_value=self.unk_id,
                                          name="symbol2index")

            self.contexts = self.symbol2index.lookup(self.input_contexts)
            self.responses_target = self.symbol2index.lookup(
                self.output_tokens)

        with tf.variable_scope("topic_embedding"):
            t_embedding = tf.get_variable(
                "embedding", [self.topic_vocab_size, config.topic_embed_size],
                dtype=tf.float32)
            topic_embedding = tf.nn.embedding_lookup(t_embedding, self.topics)
            # [batch_size, topic_embed_size]

        with tf.variable_scope("word_embedding"):
            self.embedding = tf.get_variable(
                "embedding", [self.vocab_size, config.embed_size],
                dtype=tf.float32)
            # Zero out row 0 of the embedding table (presumably the <pad>
            # token, which has index 0 — TODO confirm against vocab build).
            embedding_mask = tf.constant(
                [0 if i == 0 else 1 for i in range(self.vocab_size)],
                dtype=tf.float32,
                shape=[self.vocab_size, 1])
            embedding = self.embedding * embedding_mask

            # Flatten dialogs so each utterance is one row of the batch.
            input_embedding = tf.nn.embedding_lookup(
                embedding, tf.reshape(self.contexts, [-1]))
            input_embedding = tf.reshape(
                input_embedding,
                [batch_size * max_dialog_len, -1, config.embed_size])
            output_embedding = tf.nn.embedding_lookup(
                embedding, tf.reshape(self.responses_target, [-1]))
            output_embedding = tf.reshape(
                output_embedding,
                [batch_size * max_dialog_len, -1, config.embed_size])

        with tf.variable_scope("uttrance_encoder"):

            # Encode each utterance into a fixed vector; the same cell(s) are
            # reused (reuse=True) for input and output utterances.
            if config.sent_type == "rnn":
                sent_cell = self.create_rnn_cell(self.sent_cell_size)
                input_embedding, sent_size = get_rnn_encode(input_embedding,
                                                            sent_cell,
                                                            scope="sent_rnn")
                output_embedding, _ = get_rnn_encode(output_embedding,
                                                     sent_cell,
                                                     tf.reshape(
                                                         self.output_lens,
                                                         [-1]),
                                                     scope="sent_rnn",
                                                     reuse=True)

            elif config.sent_type == "bi_rnn":
                fwd_sent_cell = self.create_rnn_cell(self.sent_cell_size)
                bwd_sent_cell = self.create_rnn_cell(self.sent_cell_size)
                input_embedding, sent_size = get_bi_rnn_encode(
                    input_embedding,
                    fwd_sent_cell,
                    bwd_sent_cell,
                    scope="sent_bi_rnn")
                output_embedding, _ = get_bi_rnn_encode(output_embedding,
                                                        fwd_sent_cell,
                                                        bwd_sent_cell,
                                                        tf.reshape(
                                                            self.output_lens,
                                                            [-1]),
                                                        scope="sent_bi_rnn",
                                                        reuse=True)
            else:
                raise ValueError(
                    "Unknown sent_type. Must be one of [rnn, bi_rnn]")

            # Restore the dialog structure: [batch, dialog_len, sent_size].
            input_embedding = tf.reshape(
                input_embedding, [batch_size, max_dialog_len, sent_size])
            if config.keep_prob < 1.0:
                # NOTE(review): dropout is applied unconditionally, i.e. also
                # at inference time when forward=True — confirm intended.
                input_embedding = tf.nn.dropout(input_embedding,
                                                config.keep_prob)

            output_embedding = tf.reshape(
                output_embedding, [batch_size, max_dialog_len, sent_size])

        with tf.variable_scope("context_encoder"):

            # Dialog-level RNN over utterance vectors.
            enc_cell = self.create_rnn_cell(self.context_cell_size)

            cxt_outputs, _ = tf.nn.dynamic_rnn(
                enc_cell,
                input_embedding,
                dtype=tf.float32,
                sequence_length=self.context_lens)
            # [batch_size, max_dialog_len, context_cell_size]

        # Broadcast the per-dialog topic vector across every dialog step and
        # concatenate with the context-encoder outputs.
        tile_topic_embedding = tf.reshape(
            tf.tile(topic_embedding, [1, max_dialog_len]),
            [batch_size, max_dialog_len, config.topic_embed_size])
        cond_embedding = tf.concat([tile_topic_embedding, cxt_outputs], -1)
        # [batch_size, max_dialog_len, context_cell_size + topic_embed_size]

        with tf.variable_scope("posterior_network"):
            # Recognition network sees the target response as well.
            recog_input = tf.concat([cond_embedding, output_embedding], -1)
            post_sample, recog_mu_1, recog_logvar_1, recog_mu_2, recog_logvar_2 = self.hierarchical_inference_net(
                recog_input)

        with tf.variable_scope("prior_network"):
            # Prior network is conditioned on context + topic only.
            prior_input = cond_embedding
            prior_sample, prior_mu_1, prior_logvar_1, prior_mu_2, prior_logvar_2 = self.hierarchical_inference_net(
                prior_input)

        # Sample from the prior at test time, from the posterior when training.
        latent_sample = tf.cond(self.use_prior, lambda: prior_sample,
                                lambda: post_sample)

        with tf.variable_scope("decoder"):

            dec_inputs = tf.concat([cond_embedding, latent_sample], -1)
            dec_inputs_dim = config.latent_size + config.topic_embed_size + self.context_cell_size
            dec_inputs = tf.reshape(
                dec_inputs, [batch_size * max_dialog_len, dec_inputs_dim])

            # Project the conditioning vector into the decoder's initial state.
            dec_init_state = tf.contrib.layers.fully_connected(
                dec_inputs,
                self.dec_cell_size,
                activation_fn=None,
                scope="init_state")
            dec_cell = self.create_rnn_cell(self.dec_cell_size)

            output_fn, sampled_sequence_loss = output_projection_layer(
                self.dec_cell_size, self.vocab_size)
            decoder_fn_train = decoder_fn.simple_decoder_fn_train(
                dec_init_state, dec_inputs)
            # Inference decodes up to 2x the target length before stopping.
            decoder_fn_inference = decoder_fn.simple_decoder_fn_inference(
                output_fn, dec_init_state, dec_inputs, embedding, self.sos_id,
                self.eos_id, max_out_len * 2, self.vocab_size)

            if forward:
                dec_outs, _, final_context_state = dynamic_rnn_decoder(
                    dec_cell, decoder_fn_inference, scope="decoder")
            else:
                # Teacher forcing: feed the target shifted right by one token.
                dec_input_embedding = tf.nn.embedding_lookup(
                    embedding, tf.reshape(self.responses_target, [-1]))
                dec_input_embedding = tf.reshape(
                    dec_input_embedding,
                    [batch_size * max_dialog_len, -1, config.embed_size])
                dec_input_embedding = dec_input_embedding[:, 0:-1, :]
                dec_seq_lens = tf.reshape(self.output_lens, [-1]) - 1

                if config.dec_keep_prob < 1.0:
                    # Word-level input dropout: randomly zero whole decoder
                    # input embeddings to weaken teacher forcing.
                    keep_mask = tf.less_equal(
                        tf.random_uniform(
                            (batch_size * max_dialog_len, max_out_len - 1),
                            minval=0.0,
                            maxval=1.0), config.dec_keep_prob)
                    keep_mask = tf.expand_dims(tf.to_float(keep_mask), 2)
                    dec_input_embedding = dec_input_embedding * keep_mask
                    dec_input_embedding = tf.reshape(
                        dec_input_embedding,
                        [-1, max_out_len - 1, config.embed_size])

                dec_outs, _, final_context_state = dynamic_rnn_decoder(
                    dec_cell,
                    decoder_fn_train,
                    dec_input_embedding,
                    dec_seq_lens,
                    scope="decoder")

                # Targets are the responses shifted left by one; mask out pads
                # (id 0 — tf.sign of the label is 0 only for pad).
                reshape_target = tf.reshape(self.responses_target,
                                            [batch_size * max_dialog_len, -1])
                labels = reshape_target[:, 1:]
                label_mask = tf.to_float(tf.sign(labels))
                local_loss = sampled_sequence_loss(dec_outs, labels,
                                                   label_mask)

            if final_context_state is not None:
                # Inference path: recover the emitted word ids from the
                # decoder's context state, masked to the produced length.
                final_context_state = final_context_state[:, 0:tf.
                                                          shape(dec_outs)[1]]
                mask = tf.to_int32(tf.sign(tf.reduce_max(dec_outs, axis=2)))
                dec_out_words = tf.multiply(
                    tf.reverse(final_context_state, axis=[1]), mask)
            else:
                dec_out_words = tf.argmax(dec_outs, 2)

            # Keep only the prediction for the final dialog step.
            self.dec_out_words = tf.reshape(
                dec_out_words, [batch_size, max_dialog_len, -1])[:, -1, :]

        if not forward:
            with tf.variable_scope("loss"):

                self.avg_rc_loss = tf.reduce_mean(local_loss)
                # Summed loss + word count are kept for perplexity computation.
                self.rc_ppl = tf.reduce_sum(local_loss)
                self.total_word = tf.reduce_sum(label_mask)

                new_recog_mu_2 = tf.reshape(recog_mu_2,
                                            [-1, config.latent_size])
                new_recog_logvar_2 = tf.reshape(recog_logvar_2,
                                                [-1, config.latent_size])
                new_prior_mu_1 = tf.reshape(prior_mu_1,
                                            [-1, config.latent_size])
                new_prior_logvar_1 = tf.reshape(prior_logvar_1,
                                                [-1, config.latent_size])
                new_recog_mu_1 = tf.reshape(recog_mu_1,
                                            [-1, config.latent_size])
                new_recog_logvar_1 = tf.reshape(recog_logvar_1,
                                                [-1, config.latent_size])
                new_prior_mu_2 = tf.reshape(prior_mu_2,
                                            [-1, config.latent_size])
                new_prior_logvar_2 = tf.reshape(prior_logvar_2,
                                                [-1, config.latent_size])

                # NOTE(review): KL pairs posterior group 2 with prior group 1
                # and vice versa (crossed indices) — confirm this is intended
                # rather than a 1-1 / 2-2 pairing.
                kld_1 = gaussian_kld(new_recog_mu_2, new_recog_logvar_2,
                                     new_prior_mu_1, new_prior_logvar_1)
                kld_2 = gaussian_kld(new_recog_mu_1, new_recog_logvar_1,
                                     new_prior_mu_2, new_prior_logvar_2)
                kld = kld_1 + kld_2

                self.avg_kld = tf.reduce_mean(kld)
                if log_dir is not None:
                    # Linear KL annealing from 0 to 1 over full_kl_step steps.
                    self.kl_w = tf.minimum(
                        tf.to_float(self.global_t) / config.full_kl_step, 1.0)
                else:
                    self.kl_w = tf.constant(1.0)

                # ELBO = reconstruction loss + annealed KL.
                aug_elbo = self.elbo = self.avg_rc_loss + self.kl_w * self.avg_kld

                tf.summary.scalar("rc_loss", self.avg_rc_loss)
                tf.summary.scalar("elbo", self.elbo)
                tf.summary.scalar("kld", self.avg_kld)
                self.summary_op = tf.summary.merge_all()
                """
				self.log_p_z = norm_log_liklihood(latent_sample, prior_mu, prior_logvar)
				self.log_q_z_xy = norm_log_liklihood(latent_sample, recog_mu, recog_logvar)
				self.est_marginal = tf.reduce_mean(- self.log_p_z + self.log_q_z_xy)
				"""

            self.optimize(sess, config, aug_elbo, log_dir)

        self.saver = tf.train.Saver(tf.global_variables(),
                                    write_version=tf.train.SaverDef.V2)

    def batch_2_feed(self, batch, global_t, use_prior, repeat=1):

        context, context_lens, floors, topics, my_profiles, ot_profiles, outputs, output_lens, output_das = batch

        feed_dict = {
            self.input_contexts: context,
            self.context_lens: context_lens,
            self.topics: topics,
            self.output_tokens: outputs,
            self.output_lens: output_lens,
            self.use_prior: use_prior
        }

        if repeat > 1:
            tiled_feed_dict = {}
            for key, val in feed_dict.items():
                if key is self.use_prior:
                    tiled_feed_dict[key] = val
                    continue
                multipliers = [1] * len(val.shape)
                multipliers[0] = repeat
                tiled_feed_dict[key] = np.tile(val, multipliers)
            feed_dict = tiled_feed_dict

        if global_t is not None:
            feed_dict[self.global_t] = global_t

        return feed_dict

    def train(self, global_t, sess, train_feed, update_limit=5000):
        """Run (at most) one epoch of training.

        Iterates batches from train_feed until it is exhausted or update_limit
        steps have run, executing self.train_ops with posterior sampling
        (use_prior=False) and logging summaries + periodic loss reports.

        Args:
            global_t: global step counter; incremented per batch.
            sess: tf.Session to run the graph in.
            train_feed: data feed exposing next_new_batch(), num_batch, ptr.
            update_limit: cap on batches this call, or None for no cap.

        Returns:
            (updated global_t, average elbo loss for the epoch).
        """
        elbo_losses = []
        rc_losses = []

        # Summed losses / word counts, kept separately to compute perplexity
        # as exp(total loss / total words).
        rc_ppls = []
        total_words = []

        kl_losses = []
        local_t = 0
        start_time = time.time()
        loss_names = ["elbo_loss", "rc_loss", "kl_loss"]
        while True:
            batch = train_feed.next_new_batch()
            if batch is None:
                break
            if update_limit is not None and local_t >= update_limit:
                break
            feed_dict = self.batch_2_feed(batch, global_t, use_prior=False)
            _, sum_op, elbo_loss, rc_loss, rc_ppl, kl_loss, total_word = sess.run(
                [
                    self.train_ops, self.summary_op, self.elbo,
                    self.avg_rc_loss, self.rc_ppl, self.avg_kld,
                    self.total_word
                ], feed_dict)
            self.train_summary_writer.add_summary(sum_op, global_t)

            total_words.append(total_word)
            elbo_losses.append(elbo_loss)
            rc_ppls.append(rc_ppl)
            rc_losses.append(rc_loss)
            kl_losses.append(kl_loss)

            global_t += 1
            local_t += 1
            # Report ~20 times per epoch. NOTE(review): relies on Python 2
            # integer division of num_batch / 20; under Python 3 the modulus
            # becomes a float and this condition would rarely fire.
            if local_t % (train_feed.num_batch / 20) == 0:
                kl_w = sess.run(self.kl_w, {self.global_t: global_t})
                self.print_loss(
                    "%.2f" % (train_feed.ptr / float(train_feed.num_batch)),
                    loss_names, [elbo_losses, rc_losses, kl_losses],
                    "kl_w %f, perplexity: %f" %
                    (kl_w, np.exp(np.sum(rc_ppls) / np.sum(total_words))))

        # finish epoch!
        epoch_time = time.time() - start_time
        avg_losses = self.print_loss(
            "Epoch Done", loss_names, [elbo_losses, rc_losses, kl_losses],
            "step time %.4f, perplexity: %f" %
            (epoch_time / train_feed.num_batch,
             np.exp(np.sum(rc_ppls) / np.sum(total_words))))

        return global_t, avg_losses[0]

    def valid(self, name, sess, valid_feed):
        """Evaluate losses over one full pass of valid_feed.

        Runs the graph with posterior sampling (use_prior=False), accumulates
        per-batch losses, prints a summary via print_loss, and returns the
        average elbo loss.
        """
        elbo_losses, rc_losses, rc_ppls = [], [], []
        kl_losses, total_words = [], []

        fetches = [
            self.elbo, self.avg_rc_loss, self.rc_ppl, self.avg_kld,
            self.total_word
        ]

        while True:
            batch = valid_feed.next_new_batch()
            if batch is None:
                break

            feed_dict = self.batch_2_feed(batch,
                                          None,
                                          use_prior=False,
                                          repeat=1)
            elbo_loss, rc_loss, rc_ppl, kl_loss, total_word = sess.run(
                fetches, feed_dict)

            elbo_losses.append(elbo_loss)
            rc_losses.append(rc_loss)
            rc_ppls.append(rc_ppl)
            kl_losses.append(kl_loss)
            total_words.append(total_word)

        # Corpus-level perplexity: exp(total summed loss / total word count).
        perplexity = np.exp(np.sum(rc_ppls) / np.sum(total_words))
        avg_losses = self.print_loss(
            name, ["elbo_loss", "rc_loss", "kl_loss"],
            [elbo_losses, rc_losses, kl_losses],
            "perplexity: %f" % perplexity)
        return avg_losses[0]

    def test(self, sess, test_feed, num_batch=None, repeat=5, dest=sys.stdout):
        """Generate samples from the prior and report BLEU against references.

        For each test batch the batch is tiled `repeat` times, decoded with
        use_prior=True, and the `repeat` samples per example are written to
        `dest` together with source/target text and BLEU statistics.

        Args:
            sess: tf.Session to run decoding in.
            test_feed: data feed exposing next_new_batch(), batch_size, etc.
            num_batch: optional cap on the number of batches decoded.
            repeat: number of samples drawn per example.
            dest: writable stream for the generation report.

        NOTE(review): Python 2 only — uses `print report` statement syntax
        and a print with a trailing comma below.
        """

        local_t = 0
        recall_bleus = []
        prec_bleus = []

        while True:
            batch = test_feed.next_new_batch()
            if batch is None or (num_batch is not None
                                 and local_t > num_batch):
                break
            # Tile the batch so one run yields `repeat` samples per example.
            feed_dict = self.batch_2_feed(batch,
                                          None,
                                          use_prior=True,
                                          repeat=repeat)
            word_outs = sess.run(self.dec_out_words, feed_dict)

            # Undo the tiling: one array of predictions per sample index.
            sample_words = np.split(word_outs, repeat, axis=0)

            true_srcs = feed_dict[self.input_contexts]
            true_src_lens = feed_dict[self.context_lens]
            # Reference response is the last utterance of each dialog.
            true_outs = feed_dict[self.output_tokens][:, -1, :]
            true_topics = feed_dict[self.topics]
            local_t += 1

            if dest != sys.stdout:
                # Progress marker roughly every 10% of batches (py2 int div).
                if local_t % (test_feed.num_batch / 10) == 0:
                    print("%.2f >> " %
                          (test_feed.ptr / float(test_feed.num_batch))),

            for b_id in range(test_feed.batch_size):
                dest.write(
                    "Batch %d index %d of topic %s\n" %
                    (local_t, b_id, self.topic_vocab[true_topics[b_id]]))

                # Show at most the last 5 context utterances.
                start = np.maximum(0, true_src_lens[b_id] - 5)
                for t_id in range(start, true_srcs.shape[1], 1):
                    src_str = " ".join([
                        w for w in true_srcs[b_id, t_id].tolist()
                        if w not in ["<pad>"]
                    ])
                    dest.write("Src %d: %s\n" % (t_id, src_str))

                # Reference tokens with special symbols stripped.
                true_tokens = [
                    w for w in true_outs[b_id].tolist()
                    if w not in ["<pad>", "<s>", "</s>"]
                ]
                true_str = " ".join(true_tokens).replace(" ' ", "'")
                dest.write("Target >> %s\n" % (true_str))

                local_tokens = []
                for r_id in range(repeat):
                    pred_outs = sample_words[r_id]
                    # pred_da = np.argmax(sample_das[r_id], axis=1)[0]
                    # Predictions are ids; map back through vocab and drop
                    # eos/pad/sos.
                    pred_tokens = [
                        self.vocab[e] for e in pred_outs[b_id].tolist()
                        if e not in [self.eos_id, self.pad_id, self.sos_id]
                    ]
                    pred_str = " ".join(pred_tokens).replace(" ' ", "'")
                    dest.write("Sample %d >> %s\n" % (r_id, pred_str))
                    local_tokens.append(pred_tokens)

                # max over samples -> recall-ish; avg -> precision-ish BLEU.
                max_bleu, avg_bleu = utils.get_bleu_stats(
                    true_tokens, local_tokens)
                recall_bleus.append(max_bleu)
                prec_bleus.append(avg_bleu)
                dest.write("\n")

        avg_recall_bleu = float(np.mean(recall_bleus))
        avg_prec_bleu = float(np.mean(prec_bleus))
        # Epsilon guards against division by zero when both BLEUs are 0.
        avg_f1 = 2 * (avg_prec_bleu * avg_recall_bleu) / (
            avg_prec_bleu + avg_recall_bleu + 10e-12)
        report = "Avg recall BLEU %f, avg precision BLEU %f and F1 %f (only 1 reference response. Not final result)" \
           % (avg_recall_bleu, avg_prec_bleu, avg_f1)
        print report
        dest.write(report + "\n")
        print("Done testing")

    def hierarchical_inference_net(self, inputs):
        """Build a two-group hierarchical Gaussian latent network.

        Splits the latent space into two groups of latent_size / 2 dims each.
        Group 1 is inferred from `inputs`; group 2 is inferred from `inputs`
        concatenated with the group-1 sample, so the second group is
        conditioned on the first.

        Args:
            inputs: conditioning tensor (last dim is the feature dim).

        Returns:
            (z_post, recog_mu_1, recog_logvar_1, recog_mu_2, recog_logvar_2)
            where z_post concatenates both group samples along the last axis.
        """
        # Each of the two latent groups gets half of the latent dimensions.
        # (Removed unused local `num_group = 2` from the original.)
        group_dim = int(self.latent_size / 2)

        # Group 1: mu/logvar projected directly from the inputs.
        recog_mulogvar_1 = tf.contrib.layers.fully_connected(
            inputs, group_dim * 2, activation_fn=None, scope="muvar")
        recog_mu_1, recog_logvar_1 = tf.split(recog_mulogvar_1, 2, axis=-1)
        z_post_1 = sample_gaussian(recog_mu_1, recog_logvar_1)

        # Group 2: conditioned on the sample from group 1 plus the inputs.
        cont_inputs = tf.concat([z_post_1, inputs], -1)
        recog_mulogvar_2 = tf.contrib.layers.fully_connected(
            cont_inputs, group_dim * 2, activation_fn=None, scope="muvar1")
        recog_mu_2, recog_logvar_2 = tf.split(recog_mulogvar_2, 2, axis=-1)
        z_post_2 = sample_gaussian(recog_mu_2, recog_logvar_2)

        # Full latent sample is the concatenation of both groups.
        z_post = tf.concat([z_post_1, z_post_2], -1)

        return z_post, recog_mu_1, recog_logvar_1, recog_mu_2, recog_logvar_2
    def __init__(self, sess, config, api, log_dir, forward, scope=None):

        self.vocab = api.vocab
        self.rev_vocab = api.rev_vocab
        self.vocab_size = len(self.vocab)

        self.topic_vocab = api.topic_vocab
        self.topic_vocab_size = len(self.topic_vocab)

        self.da_vocab = api.dialog_act_vocab
        self.da_vocab_size = len(self.da_vocab)

        self.sess = sess
        self.scope = scope

        self.pad_id = self.rev_vocab["<pad>"]
        self.sos_id = self.rev_vocab["<s>"]
        self.eos_id = self.rev_vocab["</s>"]
        self.unk_id = self.rev_vocab["<unk>"]

        self.context_cell_size = config.cxt_cell_size
        self.sent_cell_size = config.sent_cell_size
        self.dec_cell_size = config.dec_cell_size
        self.latent_size = config.latent_size

        with tf.name_scope("io"):

            self.input_contexts = tf.placeholder(dtype=tf.string,
                                                 shape=(None, None, None),
                                                 name="dialog_context")
            self.context_lens = tf.placeholder(dtype=tf.int32,
                                               shape=(None, ),
                                               name="context_lens")
            self.topics = tf.placeholder(dtype=tf.int32,
                                         shape=(None, ),
                                         name="topics")

            self.output_tokens = tf.placeholder(dtype=tf.string,
                                                shape=(None, None, None),
                                                name="output_token")
            self.output_lens = tf.placeholder(dtype=tf.int32,
                                              shape=(None, None),
                                              name="output_lens")

            self.learning_rate = tf.Variable(float(config.init_lr),
                                             trainable=False,
                                             name="learning_rate")
            self.learning_rate_decay_op = self.learning_rate.assign(
                tf.multiply(self.learning_rate, config.lr_decay))
            self.global_t = tf.placeholder(dtype=tf.int32, name="global_t")
            self.use_prior = tf.placeholder(dtype=tf.bool, name="use_prior")

        batch_size = tf.shape(self.input_contexts)[0]
        max_dialog_len = tf.shape(self.input_contexts)[1]
        max_out_len = tf.shape(self.output_tokens)[2]

        with tf.variable_scope("tokenization"):
            self.symbols = tf.Variable(self.vocab,
                                       trainable=False,
                                       name="symbols")
            self.symbol2index = HashTable(KeyValueTensorInitializer(
                self.symbols,
                tf.Variable(
                    np.array([i for i in range(self.vocab_size)],
                             dtype=np.int32), False)),
                                          default_value=self.unk_id,
                                          name="symbol2index")

            self.contexts = self.symbol2index.lookup(self.input_contexts)
            self.responses_target = self.symbol2index.lookup(
                self.output_tokens)

        with tf.variable_scope("topic_embedding"):
            t_embedding = tf.get_variable(
                "embedding", [self.topic_vocab_size, config.topic_embed_size],
                dtype=tf.float32)
            topic_embedding = tf.nn.embedding_lookup(t_embedding, self.topics)
            # [batch_size, topic_embed_size]

        with tf.variable_scope("word_embedding"):
            self.embedding = tf.get_variable(
                "embedding", [self.vocab_size, config.embed_size],
                dtype=tf.float32)
            embedding_mask = tf.constant(
                [0 if i == 0 else 1 for i in range(self.vocab_size)],
                dtype=tf.float32,
                shape=[self.vocab_size, 1])
            embedding = self.embedding * embedding_mask

            input_embedding = tf.nn.embedding_lookup(
                embedding, tf.reshape(self.contexts, [-1]))
            input_embedding = tf.reshape(
                input_embedding,
                [batch_size * max_dialog_len, -1, config.embed_size])
            output_embedding = tf.nn.embedding_lookup(
                embedding, tf.reshape(self.responses_target, [-1]))
            output_embedding = tf.reshape(
                output_embedding,
                [batch_size * max_dialog_len, -1, config.embed_size])

        with tf.variable_scope("uttrance_encoder"):

            if config.sent_type == "rnn":
                sent_cell = self.create_rnn_cell(self.sent_cell_size)
                input_embedding, sent_size = get_rnn_encode(input_embedding,
                                                            sent_cell,
                                                            scope="sent_rnn")
                output_embedding, _ = get_rnn_encode(output_embedding,
                                                     sent_cell,
                                                     tf.reshape(
                                                         self.output_lens,
                                                         [-1]),
                                                     scope="sent_rnn",
                                                     reuse=True)

            elif config.sent_type == "bi_rnn":
                fwd_sent_cell = self.create_rnn_cell(self.sent_cell_size)
                bwd_sent_cell = self.create_rnn_cell(self.sent_cell_size)
                input_embedding, sent_size = get_bi_rnn_encode(
                    input_embedding,
                    fwd_sent_cell,
                    bwd_sent_cell,
                    scope="sent_bi_rnn")
                output_embedding, _ = get_bi_rnn_encode(output_embedding,
                                                        fwd_sent_cell,
                                                        bwd_sent_cell,
                                                        tf.reshape(
                                                            self.output_lens,
                                                            [-1]),
                                                        scope="sent_bi_rnn",
                                                        reuse=True)
            else:
                raise ValueError(
                    "Unknown sent_type. Must be one of [rnn, bi_rnn]")

            input_embedding = tf.reshape(
                input_embedding, [batch_size, max_dialog_len, sent_size])
            if config.keep_prob < 1.0:
                input_embedding = tf.nn.dropout(input_embedding,
                                                config.keep_prob)

            output_embedding = tf.reshape(
                output_embedding, [batch_size, max_dialog_len, sent_size])

        with tf.variable_scope("context_encoder"):

            enc_cell = self.create_rnn_cell(self.context_cell_size)

            cxt_outputs, _ = tf.nn.dynamic_rnn(
                enc_cell,
                input_embedding,
                dtype=tf.float32,
                sequence_length=self.context_lens)
            # [batch_size, max_dialog_len, context_cell_size]

        tile_topic_embedding = tf.reshape(
            tf.tile(topic_embedding, [1, max_dialog_len]),
            [batch_size, max_dialog_len, config.topic_embed_size])
        cond_embedding = tf.concat([tile_topic_embedding, cxt_outputs], -1)
        # [batch_size, max_dialog_len, context_cell_size + topic_embed_size]

        with tf.variable_scope("posterior_network"):
            recog_input = tf.concat([cond_embedding, output_embedding], -1)
            post_sample, recog_mu_1, recog_logvar_1, recog_mu_2, recog_logvar_2 = self.hierarchical_inference_net(
                recog_input)

        with tf.variable_scope("prior_network"):
            prior_input = cond_embedding
            prior_sample, prior_mu_1, prior_logvar_1, prior_mu_2, prior_logvar_2 = self.hierarchical_inference_net(
                prior_input)

        latent_sample = tf.cond(self.use_prior, lambda: prior_sample,
                                lambda: post_sample)

        with tf.variable_scope("decoder"):

            dec_inputs = tf.concat([cond_embedding, latent_sample], -1)
            dec_inputs_dim = config.latent_size + config.topic_embed_size + self.context_cell_size
            dec_inputs = tf.reshape(
                dec_inputs, [batch_size * max_dialog_len, dec_inputs_dim])

            dec_init_state = tf.contrib.layers.fully_connected(
                dec_inputs,
                self.dec_cell_size,
                activation_fn=None,
                scope="init_state")
            dec_cell = self.create_rnn_cell(self.dec_cell_size)

            output_fn, sampled_sequence_loss = output_projection_layer(
                self.dec_cell_size, self.vocab_size)
            decoder_fn_train = decoder_fn.simple_decoder_fn_train(
                dec_init_state, dec_inputs)
            decoder_fn_inference = decoder_fn.simple_decoder_fn_inference(
                output_fn, dec_init_state, dec_inputs, embedding, self.sos_id,
                self.eos_id, max_out_len * 2, self.vocab_size)

            if forward:
                dec_outs, _, final_context_state = dynamic_rnn_decoder(
                    dec_cell, decoder_fn_inference, scope="decoder")
            else:
                dec_input_embedding = tf.nn.embedding_lookup(
                    embedding, tf.reshape(self.responses_target, [-1]))
                dec_input_embedding = tf.reshape(
                    dec_input_embedding,
                    [batch_size * max_dialog_len, -1, config.embed_size])
                dec_input_embedding = dec_input_embedding[:, 0:-1, :]
                dec_seq_lens = tf.reshape(self.output_lens, [-1]) - 1

                if config.dec_keep_prob < 1.0:
                    keep_mask = tf.less_equal(
                        tf.random_uniform(
                            (batch_size * max_dialog_len, max_out_len - 1),
                            minval=0.0,
                            maxval=1.0), config.dec_keep_prob)
                    keep_mask = tf.expand_dims(tf.to_float(keep_mask), 2)
                    dec_input_embedding = dec_input_embedding * keep_mask
                    dec_input_embedding = tf.reshape(
                        dec_input_embedding,
                        [-1, max_out_len - 1, config.embed_size])

                dec_outs, _, final_context_state = dynamic_rnn_decoder(
                    dec_cell,
                    decoder_fn_train,
                    dec_input_embedding,
                    dec_seq_lens,
                    scope="decoder")

                reshape_target = tf.reshape(self.responses_target,
                                            [batch_size * max_dialog_len, -1])
                labels = reshape_target[:, 1:]
                label_mask = tf.to_float(tf.sign(labels))
                local_loss = sampled_sequence_loss(dec_outs, labels,
                                                   label_mask)

            if final_context_state is not None:
                final_context_state = final_context_state[:, 0:tf.
                                                          shape(dec_outs)[1]]
                mask = tf.to_int32(tf.sign(tf.reduce_max(dec_outs, axis=2)))
                dec_out_words = tf.multiply(
                    tf.reverse(final_context_state, axis=[1]), mask)
            else:
                dec_out_words = tf.argmax(dec_outs, 2)

            self.dec_out_words = tf.reshape(
                dec_out_words, [batch_size, max_dialog_len, -1])[:, -1, :]

        if not forward:
            with tf.variable_scope("loss"):

                self.avg_rc_loss = tf.reduce_mean(local_loss)
                self.rc_ppl = tf.reduce_sum(local_loss)
                self.total_word = tf.reduce_sum(label_mask)

                new_recog_mu_2 = tf.reshape(recog_mu_2,
                                            [-1, config.latent_size])
                new_recog_logvar_2 = tf.reshape(recog_logvar_2,
                                                [-1, config.latent_size])
                new_prior_mu_1 = tf.reshape(prior_mu_1,
                                            [-1, config.latent_size])
                new_prior_logvar_1 = tf.reshape(prior_logvar_1,
                                                [-1, config.latent_size])
                new_recog_mu_1 = tf.reshape(recog_mu_1,
                                            [-1, config.latent_size])
                new_recog_logvar_1 = tf.reshape(recog_logvar_1,
                                                [-1, config.latent_size])
                new_prior_mu_2 = tf.reshape(prior_mu_2,
                                            [-1, config.latent_size])
                new_prior_logvar_2 = tf.reshape(prior_logvar_2,
                                                [-1, config.latent_size])

                kld_1 = gaussian_kld(new_recog_mu_2, new_recog_logvar_2,
                                     new_prior_mu_1, new_prior_logvar_1)
                kld_2 = gaussian_kld(new_recog_mu_1, new_recog_logvar_1,
                                     new_prior_mu_2, new_prior_logvar_2)
                kld = kld_1 + kld_2

                self.avg_kld = tf.reduce_mean(kld)
                if log_dir is not None:
                    self.kl_w = tf.minimum(
                        tf.to_float(self.global_t) / config.full_kl_step, 1.0)
                else:
                    self.kl_w = tf.constant(1.0)

                aug_elbo = self.elbo = self.avg_rc_loss + self.kl_w * self.avg_kld

                tf.summary.scalar("rc_loss", self.avg_rc_loss)
                tf.summary.scalar("elbo", self.elbo)
                tf.summary.scalar("kld", self.avg_kld)
                self.summary_op = tf.summary.merge_all()
                """
				self.log_p_z = norm_log_liklihood(latent_sample, prior_mu, prior_logvar)
				self.log_q_z_xy = norm_log_liklihood(latent_sample, recog_mu, recog_logvar)
				self.est_marginal = tf.reduce_mean(- self.log_p_z + self.log_q_z_xy)
				"""

            self.optimize(sess, config, aug_elbo, log_dir)

        self.saver = tf.train.Saver(tf.global_variables(),
                                    write_version=tf.train.SaverDef.V2)
Ejemplo n.º 5
0
    def __init__(self,
            num_symbols,
            num_qwords, #modify
            num_embed_units,
            num_units,
            num_layers,
            is_train,
            vocab=None,
            embed=None,
            question_data=True,
            learning_rate=0.5,
            learning_rate_decay_factor=0.95,
            max_gradient_norm=5.0,
            num_samples=512,
            max_length=30,
            use_lstm=False):

        self.posts = tf.placeholder(tf.string, shape=(None, None))  # batch*len
        self.posts_length = tf.placeholder(tf.int32, shape=(None))  # batch
        self.responses = tf.placeholder(tf.string, shape=(None, None))  # batch*len
        self.responses_length = tf.placeholder(tf.int32, shape=(None))  # batch
        self.keyword_tensor = tf.placeholder(tf.float32, shape=(None, 3, None)) #(batch * len) * 3 * numsymbol
        self.word_type = tf.placeholder(tf.int32, shape=(None))   #(batch * len)

        # build the vocab table (string to index)
        if is_train:
            self.symbols = tf.Variable(vocab, trainable=False, name="symbols")
        else:
            self.symbols = tf.Variable(np.array(['.']*num_symbols), name="symbols")
        self.symbol2index = HashTable(KeyValueTensorInitializer(self.symbols,
            tf.Variable(np.array([i for i in range(num_symbols)], dtype=np.int32), False)),
            default_value=UNK_ID, name="symbol2index")
        self.posts_input = self.symbol2index.lookup(self.posts)   # batch*len
        self.responses_target = self.symbol2index.lookup(self.responses)   #batch*len
        
        batch_size, decoder_len = tf.shape(self.responses)[0], tf.shape(self.responses)[1]
        self.responses_input = tf.concat([tf.ones([batch_size, 1], dtype=tf.int32)*GO_ID,
            tf.split(self.responses_target, [decoder_len-1, 1], 1)[0]], 1)   # batch*len
        #delete the last column of responses_target) and add 'GO at the front of it.
        self.decoder_mask = tf.reshape(tf.cumsum(tf.one_hot(self.responses_length-1,
            decoder_len), reverse=True, axis=1), [-1, decoder_len]) # bacth * len

        print "embedding..."
        # build the embedding table (index to vector)
        if embed is None:
            # initialize the embedding randomly
            self.embed = tf.get_variable('embed', [num_symbols, num_embed_units], tf.float32)
        else:
            print len(vocab), len(embed), len(embed[0])
            print embed
            # initialize the embedding by pre-trained word vectors
            self.embed = tf.get_variable('embed', dtype=tf.float32, initializer=embed)

        self.encoder_input = tf.nn.embedding_lookup(self.embed, self.posts_input) #batch*len*unit
        self.decoder_input = tf.nn.embedding_lookup(self.embed, self.responses_input)

        print "embedding finished"

        if use_lstm:
            cell = MultiRNNCell([LSTMCell(num_units)] * num_layers)
        else:
            cell = MultiRNNCell([GRUCell(num_units)] * num_layers)

        # rnn encoder
        encoder_output, encoder_state = dynamic_rnn(cell, self.encoder_input,
                self.posts_length, dtype=tf.float32, scope="encoder")
        # get output projection function
        output_fn, sampled_sequence_loss = output_projection_layer(num_units,
                num_symbols, num_qwords, num_samples, question_data)

        print "encoder_output.shape:", encoder_output.get_shape()

        # get attention function
        attention_keys, attention_values, attention_score_fn, attention_construct_fn \
              = attention_decoder_fn.prepare_attention(encoder_output, 'luong', num_units)

        # get decoding loop function
        decoder_fn_train = attention_decoder_fn.attention_decoder_fn_train(encoder_state,
                attention_keys, attention_values, attention_score_fn, attention_construct_fn)
        decoder_fn_inference = attention_decoder_fn.attention_decoder_fn_inference(output_fn,
                self.keyword_tensor,
                encoder_state, attention_keys, attention_values, attention_score_fn,
                attention_construct_fn, self.embed, GO_ID, EOS_ID, max_length, num_symbols)

        if is_train:
            # rnn decoder
            self.decoder_output, _, _ = dynamic_rnn_decoder(cell, decoder_fn_train,
                    self.decoder_input, self.responses_length, scope="decoder")
            # calculate the loss of decoder
            # self.decoder_output = tf.Print(self.decoder_output, [self.decoder_output])
            self.decoder_loss, self.log_perplexity = sampled_sequence_loss(self.decoder_output,
                    self.responses_target, self.decoder_mask, self.keyword_tensor, self.word_type)

            # building graph finished and get all parameters
            self.params = tf.trainable_variables()

            for item in tf.trainable_variables():
                print item.name, item.get_shape()

            # initialize the training process
            self.learning_rate = tf.Variable(float(learning_rate), trainable=False,
                    dtype=tf.float32)
            self.learning_rate_decay_op = self.learning_rate.assign(
                    self.learning_rate * learning_rate_decay_factor)

            self.global_step = tf.Variable(0, trainable=False)

            # calculate the gradient of parameters

            opt = tf.train.GradientDescentOptimizer(self.learning_rate)
            gradients = tf.gradients(self.decoder_loss, self.params)
            clipped_gradients, self.gradient_norm = tf.clip_by_global_norm(gradients,
                    max_gradient_norm)
            self.update = opt.apply_gradients(zip(clipped_gradients, self.params),
                    global_step=self.global_step)

        else:
            # rnn decoder
            self.decoder_distribution, _, _ = dynamic_rnn_decoder(cell, decoder_fn_inference,
                    scope="decoder")
            print("self.decoder_distribution.shape():",self.decoder_distribution.get_shape())
            self.decoder_distribution = tf.Print(self.decoder_distribution, ["distribution.shape()", tf.reduce_sum(self.decoder_distribution)])
            # generating the response
            self.generation_index = tf.argmax(tf.split(self.decoder_distribution,
                [2, num_symbols-2], 2)[1], 2) + 2 # for removing UNK
            self.generation = tf.nn.embedding_lookup(self.symbols, self.generation_index)

            self.params = tf.trainable_variables()

        self.saver = tf.train.Saver(tf.global_variables(), write_version=tf.train.SaverDef.V2,
                max_to_keep=3, pad_step_number=True, keep_checkpoint_every_n_hours=1.0)
Ejemplo n.º 6
0
    def __init__(self,
            num_symbols,
            num_embed_units,
            num_units,
            is_train,
            vocab=None,
            content_pos=None,
            rhetoric_pos = None,
            embed=None,
            learning_rate=0.1,
            learning_rate_decay_factor=0.9995,
            max_gradient_norm=5.0,
            max_length=30,
            latent_size=128,
            use_lstm=False,
            num_classes=3,
            full_kl_step=80000,
            mem_slot_num=4,
            mem_size=128):
        """Build a CVAE-style sentence-rewriting graph with a topic memory.

        Encodes the original sentence and (at train time) the rewritten
        sentence with a shared bidirectional RNN, samples a latent code from
        a recognition/prior network pair, classifies the rhetoric pattern
        from the latent sample, and decodes with attention conditioned on
        the pattern embedding, the latent sample and a topic memory.

        Args:
            num_symbols: vocabulary size.
            num_embed_units: word embedding dimension.
            num_units: encoder cell size per direction (decoder: 2*num_units).
            is_train: build the training graph (True) or the inference graph.
            vocab: vocabulary strings (required when is_train is True).
            content_pos, rhetoric_pos: vocab indices of content / rhetoric
                words, turned into 0/1 masks over the vocabulary.
            embed: optional pre-trained embedding matrix; random init if None.
            learning_rate, learning_rate_decay_factor: momentum-SGD schedule.
            max_gradient_norm: global-norm gradient clipping threshold.
            max_length: maximum decoding length at inference time.
            latent_size: latent variable dimension.
            use_lstm: use LSTM cells instead of GRU cells.
            num_classes: number of rhetoric pattern classes.
            full_kl_step: steps over which the KL weight anneals to 1.0.
            mem_slot_num, mem_size: topic memory slot count and slot size.
        """

        self.ori_sents = tf.placeholder(tf.string, shape=(None, None))
        # shape=(None,) pins the length vectors to rank 1; the original
        # shape=(None) is plain None, i.e. an unconstrained rank.
        self.ori_sents_length = tf.placeholder(tf.int32, shape=(None,))
        self.rep_sents = tf.placeholder(tf.string, shape=(None, None))
        self.rep_sents_length = tf.placeholder(tf.int32, shape=(None,))
        self.labels = tf.placeholder(tf.float32, shape=(None, num_classes))
        self.use_prior = tf.placeholder(tf.bool)
        self.global_t = tf.placeholder(tf.int32)
        # 0/1 masks over the vocabulary marking content / rhetoric words
        self.content_mask = tf.reduce_sum(tf.one_hot(content_pos, num_symbols, 1.0, 0.0), axis = 0)
        self.rhetoric_mask = tf.reduce_sum(tf.one_hot(rhetoric_pos, num_symbols, 1.0, 0.0), axis = 0)

        w_topic_memory = tf.get_variable(name="w_topic_memory", dtype=tf.float32,
                                    initializer=tf.random_uniform([mem_size, mem_size], -0.1, 0.1))

        # build the vocab table (string to index)
        if is_train:
            self.symbols = tf.Variable(vocab, trainable=False, name="symbols")
        else:
            # dummy vocab of the right size; real symbols come from checkpoint
            self.symbols = tf.Variable(np.array(['.']*num_symbols), name="symbols")
        self.symbol2index = HashTable(KeyValueTensorInitializer(self.symbols,
            tf.Variable(np.array([i for i in range(num_symbols)], dtype=np.int32), False)),
            default_value=UNK_ID, name="symbol2index")

        self.ori_sents_input = self.symbol2index.lookup(self.ori_sents)
        self.rep_sents_target = self.symbol2index.lookup(self.rep_sents)
        batch_size, decoder_len = tf.shape(self.rep_sents)[0], tf.shape(self.rep_sents)[1]
        # decoder input: drop the last target column and prepend GO
        self.rep_sents_input = tf.concat([tf.ones([batch_size, 1], dtype=tf.int32)*GO_ID,
            tf.split(self.rep_sents_target, [decoder_len-1, 1], 1)[0]], 1)
        self.decoder_mask = tf.reshape(tf.cumsum(tf.one_hot(self.rep_sents_length-1,
            decoder_len), reverse=True, axis=1), [-1, decoder_len])

        # BUG FIX: tf.zeros cannot take a None dimension — the original
        # tf.zeros(shape=[None, mem_slot_num, mem_size]) fails at graph-build
        # time.  Build the zero-initialized memory from the runtime batch
        # size instead (which is why it is created after batch_size above).
        topic_memory = tf.zeros([batch_size, mem_slot_num, mem_size],
                                dtype=tf.float32, name="topic_memory")

        # build the embedding table (index to vector)
        if embed is None:
            # initialize the embedding randomly
            self.embed = tf.get_variable('embed', [num_symbols, num_embed_units], tf.float32)
        else:
            # initialize the embedding by pre-trained word vectors
            self.embed = tf.get_variable('embed', dtype=tf.float32, initializer=embed)

        # one learned embedding per rhetoric pattern class
        self.pattern_embed = tf.get_variable('pattern_embed', [num_classes, num_embed_units], tf.float32)

        self.encoder_input = tf.nn.embedding_lookup(self.embed, self.ori_sents_input)
        self.decoder_input = tf.nn.embedding_lookup(self.embed, self.rep_sents_input)

        if use_lstm:
            cell_fw = LSTMCell(num_units)
            cell_bw = LSTMCell(num_units)
            cell_dec = LSTMCell(2*num_units)
        else:
            cell_fw = GRUCell(num_units)
            cell_bw = GRUCell(num_units)
            cell_dec = GRUCell(2*num_units)

        # origin sentence encoder
        with variable_scope.variable_scope("encoder"):
            encoder_output, encoder_state = tf.nn.bidirectional_dynamic_rnn(cell_fw, cell_bw, self.encoder_input,
                self.ori_sents_length, dtype=tf.float32)
            post_sum_state = tf.concat(encoder_state, 1)
            encoder_output = tf.concat(encoder_output, 2)

        # response sentence encoder (shares weights with the encoder above)
        with variable_scope.variable_scope("encoder", reuse = True):
            decoder_state, decoder_last_state = tf.nn.bidirectional_dynamic_rnn(cell_fw, cell_bw, self.decoder_input,
                self.rep_sents_length, dtype=tf.float32)
            response_sum_state = tf.concat(decoder_last_state, 1)

        # recognition network q(z | x, y)
        with variable_scope.variable_scope("recog_net"):
            recog_input = tf.concat([post_sum_state, response_sum_state], 1)
            recog_mulogvar = tf.contrib.layers.fully_connected(recog_input, latent_size * 2, activation_fn=None, scope="muvar")
            recog_mu, recog_logvar = tf.split(recog_mulogvar, 2, axis=1)

        # prior network p(z | x)
        with variable_scope.variable_scope("prior_net"):
            prior_fc1 = tf.contrib.layers.fully_connected(post_sum_state, latent_size * 2, activation_fn=tf.tanh, scope="fc1")
            prior_mulogvar = tf.contrib.layers.fully_connected(prior_fc1, latent_size * 2, activation_fn=None, scope="muvar")
            prior_mu, prior_logvar = tf.split(prior_mulogvar, 2, axis=1)

        # sample from the prior at inference, from the posterior at training
        latent_sample = tf.cond(self.use_prior,
                                lambda: sample_gaussian(prior_mu, prior_logvar),
                                lambda: sample_gaussian(recog_mu, recog_logvar))


        # classifier predicting the rhetoric pattern from the latent sample
        with variable_scope.variable_scope("classifier"):
            classifier_input = latent_sample
            pattern_fc1 = tf.contrib.layers.fully_connected(classifier_input, latent_size, activation_fn=tf.tanh, scope="pattern_fc1")
            self.pattern_logits = tf.contrib.layers.fully_connected(pattern_fc1, num_classes, activation_fn=None, scope="pattern_logits")

        self.label_embedding = tf.matmul(self.labels, self.pattern_embed)

        output_fn, my_sequence_loss = output_projection_layer(2*num_units, num_symbols, latent_size, num_embed_units, self.content_mask, self.rhetoric_mask)

        attention_keys, attention_values, attention_score_fn, attention_construct_fn = my_attention_decoder_fn.prepare_attention(encoder_output, 'luong', 2*num_units)

        # project [post state; label embedding; latent sample] into the
        # decoder's initial state
        with variable_scope.variable_scope("dec_start"):
            temp_start = tf.concat([post_sum_state, self.label_embedding, latent_sample], 1)
            dec_fc1 = tf.contrib.layers.fully_connected(temp_start, 2*num_units, activation_fn=tf.tanh, scope="dec_start_fc1")
            dec_fc2 = tf.contrib.layers.fully_connected(dec_fc1, 2*num_units, activation_fn=None, scope="dec_start_fc2")

        if is_train:
            # rnn decoder (teacher forcing)
            topic_memory = self.update_memory(topic_memory, encoder_output)
            extra_info = tf.concat([self.label_embedding, latent_sample, topic_memory], 1)

            decoder_fn_train = my_attention_decoder_fn.attention_decoder_fn_train(dec_fc2,
                attention_keys, attention_values, attention_score_fn, attention_construct_fn, extra_info)
            self.decoder_output, _, _ = my_seq2seq.dynamic_rnn_decoder(cell_dec, decoder_fn_train,
                self.decoder_input, self.rep_sents_length, scope = "decoder")

            # calculate the loss: reconstruction + annealed KL + classifier
            self.decoder_loss = my_loss.sequence_loss(logits = self.decoder_output,
                targets = self.rep_sents_target, weights = self.decoder_mask,
                extra_information = latent_sample, label_embedding = self.label_embedding, softmax_loss_function = my_sequence_loss)
            temp_klloss = tf.reduce_mean(gaussian_kld(recog_mu, recog_logvar, prior_mu, prior_logvar))
            self.kl_weight = tf.minimum(tf.to_float(self.global_t)/full_kl_step, 1.0)
            self.klloss = self.kl_weight * temp_klloss
            temp_labels = tf.argmax(self.labels, 1)
            self.classifierloss = tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(logits=self.pattern_logits, labels=temp_labels))
            self.loss = self.decoder_loss + self.klloss + self.classifierloss  # need to anneal the kl_weight

            # building graph finished and get all parameters
            self.params = tf.trainable_variables()

            # initialize the training process
            self.learning_rate = tf.Variable(float(learning_rate), trainable=False, dtype=tf.float32)
            self.learning_rate_decay_op = self.learning_rate.assign(self.learning_rate * learning_rate_decay_factor)
            self.global_step = tf.Variable(0, trainable=False)

            # calculate the gradient of parameters with global-norm clipping
            opt = tf.train.MomentumOptimizer(self.learning_rate, 0.9)
            gradients = tf.gradients(self.loss, self.params)
            clipped_gradients, self.gradient_norm = tf.clip_by_global_norm(gradients,
                    max_gradient_norm)
            self.update = opt.apply_gradients(zip(clipped_gradients, self.params),
                    global_step=self.global_step)

        else:
            # rnn decoder (free-running greedy decoding)
            topic_memory = self.update_memory(topic_memory, encoder_output)
            extra_info = tf.concat([self.label_embedding, latent_sample, topic_memory], 1)
            decoder_fn_inference = my_attention_decoder_fn.attention_decoder_fn_inference(output_fn,
                dec_fc2, attention_keys, attention_values, attention_score_fn,
                attention_construct_fn, self.embed, GO_ID, EOS_ID, max_length, num_symbols, extra_info)
            self.decoder_distribution, _, _ = my_seq2seq.dynamic_rnn_decoder(cell_dec, decoder_fn_inference, scope="decoder")
            # argmax over symbols 2.. (skip PAD/UNK)
            self.generation_index = tf.argmax(tf.split(self.decoder_distribution,
                [2, num_symbols-2], 2)[1], 2) + 2 # for removing UNK
            self.generation = tf.nn.embedding_lookup(self.symbols, self.generation_index)

            self.params = tf.trainable_variables()

        self.saver = tf.train.Saver(tf.global_variables(), write_version=tf.train.SaverDef.V2,
                max_to_keep=3, pad_step_number=True, keep_checkpoint_every_n_hours=1.0)
Ejemplo n.º 7
0
    def __init__(self,
                 num_symbols,
                 num_embed_units,
                 num_units,
                 vocab=None,
                 embed=None,
                 name_scope=None,
                 learning_rate=0.0001,
                 learning_rate_decay_factor=0.95,
                 max_gradient_norm=5,
                 l2_lambda=0.2):
        """Build a discriminator scoring post/response relevance.

        Encodes the post, the real response and a generated response with
        bidirectional GRUs (the two responses share one encoder via scope
        reuse), then scores (post, real) and (post, generated) pairs with a
        shared bilinear term plus a small MLP head.  Trains with a
        least-squares-GAN-style loss pushing real-pair scores towards 1 and
        generated-pair scores towards 0, and exposes the generated-pair
        score as `self.reward` — presumably consumed by a generator's
        policy-gradient update; confirm against the caller.

        Args:
            num_symbols: vocabulary size.
            num_embed_units: word embedding dimension.
            num_units: GRU cell size per direction.
            vocab: vocabulary strings for the string->index lookup table.
            embed: optional pre-trained embedding matrix; random init if None.
            name_scope: substring used to select this model's variables from
                the global collections.
            learning_rate: Adam learning rate.
            learning_rate_decay_factor: multiplicative decay applied by
                `learning_rate_decay_op`.
            max_gradient_norm: global-norm gradient clipping threshold.
            l2_lambda: L2 regularization weight on the scoring-head weights.
        """

        self.posts = tf.placeholder(tf.string, shape=[None,
                                                      None])  # batch * len
        self.posts_length = tf.placeholder(tf.int32, shape=[None])  # batch
        self.responses = tf.placeholder(tf.string, shape=[None,
                                                          None])  # batch*len
        self.responses_length = tf.placeholder(tf.int32, shape=[None])  # batch
        self.generation = tf.placeholder(tf.string, shape=[None,
                                                           None])  # batch*len
        self.generation_length = tf.placeholder(tf.int32,
                                                shape=[None])  # batch

        # build the vocab table (string to index)
        self.symbols = tf.Variable(vocab, trainable=False, name="symbols")
        self.symbol2index = HashTable(KeyValueTensorInitializer(
            self.symbols,
            tf.Variable(
                np.array([i for i in range(num_symbols)], dtype=np.int32),
                False)),
                                      default_value=UNK_ID,
                                      name="symbol2index")

        # build the embedding table (index to vector)
        if embed is None:
            # initialize the embedding randomly
            self.embed = tf.get_variable('embed',
                                         [num_symbols, num_embed_units],
                                         tf.float32)
        else:
            # initialize the embedding by pre-trained word vectors
            self.embed = tf.get_variable('embed',
                                         dtype=tf.float32,
                                         initializer=embed)

        # map the three string inputs to id tensors, then to embeddings
        self.posts_input = self.symbol2index.lookup(
            self.posts)  # batch * utter_len
        self.posts_input_embed = tf.nn.embedding_lookup(
            self.embed, self.posts_input)  #batch * utter_len * embed_unit
        self.responses_input = self.symbol2index.lookup(self.responses)
        self.responses_input_embed = tf.nn.embedding_lookup(
            self.embed, self.responses_input)  # batch * utter_len * embed_unit
        self.generation_input = self.symbol2index.lookup(self.generation)
        self.generation_input_embed = tf.nn.embedding_lookup(
            self.embed,
            self.generation_input)  # batch * utter_len * embed_unit

        # Construct bidirectional GRU cells for encoder / decoder
        cell_fw_post = GRUCell(num_units)
        cell_bw_post = GRUCell(num_units)
        cell_fw_resp = GRUCell(num_units)
        cell_bw_resp = GRUCell(num_units)

        # Encode the post sequence
        with variable_scope.variable_scope("post_encoder"):
            posts_state, posts_final_state = tf.nn.bidirectional_dynamic_rnn(
                cell_fw_post,
                cell_bw_post,
                self.posts_input_embed,
                self.posts_length,
                dtype=tf.float32)
            posts_final_state_bid = tf.concat(
                posts_final_state, 1)  # batch_size * (2 * num_units)

        # Encode the real response sequence
        with variable_scope.variable_scope("resp_encoder"):
            responses_state, responses_final_state = tf.nn.bidirectional_dynamic_rnn(
                cell_fw_resp,
                cell_bw_resp,
                self.responses_input_embed,
                self.responses_length,
                dtype=tf.float32)
            responses_final_state_bid = tf.concat(responses_final_state, 1)

        # Encode the generated response sequence
        # (reuse=True: shares weights with the real-response encoder above)
        with variable_scope.variable_scope("resp_encoder", reuse=True):
            generation_state, generation_final_state = tf.nn.bidirectional_dynamic_rnn(
                cell_fw_resp,
                cell_bw_resp,
                self.generation_input_embed,
                self.generation_length,
                dtype=tf.float32)
            generation_final_state_bid = tf.concat(generation_final_state, 1)

        # Calculate the relevance score between post and real response
        with variable_scope.variable_scope("calibration"):
            # bilinear interaction term: post^T * W * response
            self.W = tf.get_variable('W', [2 * num_units, 2 * num_units],
                                     tf.float32)
            vec_post = tf.reshape(posts_final_state_bid,
                                  [-1, 1, 2 * num_units])
            vec_resp = tf.reshape(responses_final_state_bid,
                                  [-1, 2 * num_units, 1])
            attn_score_true = tf.einsum(
                'aij,ajk->aik', tf.einsum('aij,jk->aik', vec_post, self.W),
                vec_resp)
            attn_score_true = tf.reshape(attn_score_true, [-1, 1])
            # MLP head over [post state; response state; bilinear score]
            fc_true_input = tf.concat([
                posts_final_state_bid, responses_final_state_bid,
                attn_score_true
            ], 1)

            self.output_fc_W = tf.get_variable("output_fc_W",
                                               [4 * num_units + 1, num_units],
                                               tf.float32)
            self.output_fc_b = tf.get_variable("output_fc_b", [num_units],
                                               tf.float32)
            fc_true = tf.nn.tanh(
                tf.nn.xw_plus_b(fc_true_input, self.output_fc_W,
                                self.output_fc_b))  # batch_size
            # final sigmoid score in (0, 1) for the (post, real) pair
            self.output_W = tf.get_variable("output_W", [num_units, 1],
                                            tf.float32)
            self.output_b = tf.get_variable("output_b", [1], tf.float32)
            self.cost_true = tf.nn.sigmoid(
                tf.nn.xw_plus_b(fc_true, self.output_W,
                                self.output_b))  # batch_size

        # Calculate the relevance score between post and generated response
        # (reuse=True: same W and MLP head as the real-pair scorer)
        with variable_scope.variable_scope("calibration", reuse=True):
            vec_gen = tf.reshape(generation_final_state_bid,
                                 [-1, 2 * num_units, 1])
            attn_score_false = tf.einsum(
                'aij,ajk->aik', tf.einsum('aij,jk->aik', vec_post, self.W),
                vec_gen)
            attn_score_false = tf.reshape(attn_score_false, [-1, 1])
            fc_false_input = tf.concat([
                posts_final_state_bid, generation_final_state_bid,
                attn_score_false
            ], 1)
            fc_false = tf.nn.tanh(
                tf.nn.xw_plus_b(fc_false_input, self.output_fc_W,
                                self.output_fc_b))  # batch_size
            self.cost_false = tf.nn.sigmoid(
                tf.nn.xw_plus_b(fc_false, self.output_W,
                                self.output_b))  # batch_size

        # least-squares targets: real-pair scores -> 1, generated -> 0
        self.PR_cost = tf.reduce_mean(
            tf.reduce_sum(tf.square(self.cost_true - 1.0), axis=1))
        self.PG_cost = tf.reduce_mean(
            tf.reduce_sum(tf.square(self.cost_false), axis=1))

        # Use the loss similar to least square GAN
        self.cost = self.PR_cost / 2.0 + self.PG_cost / 2.0 + l2_lambda * (
            tf.nn.l2_loss(self.output_fc_W) + tf.nn.l2_loss(self.output_fc_b) +
            tf.nn.l2_loss(self.output_W) + tf.nn.l2_loss(self.output_b) +
            tf.nn.l2_loss(self.W))

        # building graph finished and get all parameters
        # (filtered by name_scope so only this model's variables are trained)
        self.params = [
            k for k in tf.trainable_variables() if name_scope in k.name
        ]

        # initialize the training process
        self.learning_rate = tf.Variable(float(learning_rate),
                                         trainable=False,
                                         dtype=tf.float32)
        self.learning_rate_decay_op = self.learning_rate.assign(
            self.learning_rate * learning_rate_decay_factor)
        self.global_step = tf.Variable(0, trainable=False)
        self.adv_global_step = tf.Variable(0, trainable=False)

        # calculate the gradient of parameters with global-norm clipping
        opt = tf.train.AdamOptimizer(self.learning_rate)
        gradients = tf.gradients(self.cost, self.params)
        clipped_gradients, self.gradient_norm = tf.clip_by_global_norm(
            gradients, max_gradient_norm)
        self.update = opt.apply_gradients(zip(clipped_gradients, self.params),
                                          global_step=self.global_step)
        # per-example generated-pair score, exposed as a reward signal
        self.reward = tf.reduce_sum(self.cost_false, axis=1)  # batch

        # separate savers for pre-training and adversarial-training phases
        all_variables = [
            k for k in tf.global_variables() if name_scope in k.name
        ]
        self.saver = tf.train.Saver(all_variables,
                                    write_version=tf.train.SaverDef.V2,
                                    max_to_keep=5,
                                    pad_step_number=True,
                                    keep_checkpoint_every_n_hours=1.0)
        self.adv_saver = tf.train.Saver(all_variables,
                                        write_version=tf.train.SaverDef.V2,
                                        max_to_keep=5,
                                        pad_step_number=True,
                                        keep_checkpoint_every_n_hours=1.0)
Ejemplo n.º 8
0
class Seq2SeqModel(object):
    """Attention-based seq2seq dialogue model (TensorFlow 1.x static graph).

    Maps string posts/responses to vocabulary indices with an in-graph
    HashTable, encodes the post with a multi-layer RNN, and decodes with
    Luong attention.  When ``is_train`` is True the training graph
    (sampled-softmax sequence loss + clipped-gradient SGD update) is built;
    otherwise the greedy-inference graph is built instead.
    """

    def __init__(self,
                 num_symbols,
                 num_embed_units,
                 num_units,
                 num_layers,
                 is_train,
                 vocab=None,
                 embed=None,
                 learning_rate=0.5,
                 learning_rate_decay_factor=0.95,
                 max_gradient_norm=5.0,
                 num_samples=512,
                 max_length=30,
                 use_lstm=False):

        self.posts = tf.placeholder(tf.string, shape=(None, None))  # batch*len
        # FIX: shape=(None,) declares a rank-1 batch vector; the original
        # shape=(None) evaluates to None, i.e. a tensor of *unknown* rank.
        self.posts_length = tf.placeholder(tf.int32, shape=(None,))  # batch
        self.responses = tf.placeholder(tf.string,
                                        shape=(None, None))  # batch*len
        self.responses_length = tf.placeholder(tf.int32,
                                               shape=(None,))  # batch

        # build the vocab table (string to index)
        if is_train:
            self.symbols = tf.Variable(vocab, trainable=False, name="symbols")
        else:
            # At inference time the real vocab is restored from a checkpoint,
            # so a right-sized dummy string variable is enough here.
            self.symbols = tf.Variable(np.array(['.'] * num_symbols),
                                       name="symbols")
        self.symbol2index = HashTable(KeyValueTensorInitializer(
            self.symbols,
            tf.Variable(np.arange(num_symbols, dtype=np.int32), False)),
                                      default_value=UNK_ID,
                                      name="symbol2index")

        self.posts_input = self.symbol2index.lookup(self.posts)  # batch*len
        self.responses_target = self.symbol2index.lookup(
            self.responses)  # batch*len

        batch_size, decoder_len = tf.shape(self.responses)[0], tf.shape(
            self.responses)[1]
        # Decoder input = GO token followed by the target shifted right by one.
        self.responses_input = tf.concat([
            tf.ones([batch_size, 1], dtype=tf.int32) * GO_ID,
            tf.split(self.responses_target, [decoder_len - 1, 1], 1)[0]
        ], 1)  # batch*len
        # Mask is 1.0 up to and including each sequence's last real token,
        # 0.0 afterwards (reverse cumsum of a one-hot at position length-1).
        self.decoder_mask = tf.reshape(
            tf.cumsum(tf.one_hot(self.responses_length - 1, decoder_len),
                      reverse=True,
                      axis=1), [-1, decoder_len])

        # build the embedding table (index to vector)
        if embed is None:
            # initialize the embedding randomly
            self.embed = tf.get_variable('embed',
                                         [num_symbols, num_embed_units],
                                         tf.float32)
        else:
            # initialize the embedding by pre-trained word vectors
            self.embed = tf.get_variable('embed',
                                         dtype=tf.float32,
                                         initializer=embed)

        self.encoder_input = tf.nn.embedding_lookup(
            self.embed, self.posts_input)  # batch*len*unit
        self.decoder_input = tf.nn.embedding_lookup(self.embed,
                                                    self.responses_input)

        # FIX: build a *distinct* cell per layer.  The original
        # `[LSTMCell(num_units)] * num_layers` repeats the very same cell
        # object, which makes every layer share one set of weights (and
        # raises a ValueError in later TF 1.x releases).
        if use_lstm:
            cell = MultiRNNCell(
                [LSTMCell(num_units) for _ in range(num_layers)])
        else:
            cell = MultiRNNCell(
                [GRUCell(num_units) for _ in range(num_layers)])

        # rnn encoder
        encoder_output, encoder_state = dynamic_rnn(cell,
                                                    self.encoder_input,
                                                    self.posts_length,
                                                    dtype=tf.float32,
                                                    scope="encoder")

        # get output projection function (sampled softmax during training)
        output_fn, sampled_sequence_loss = output_projection_layer(
            num_units, num_symbols, num_samples)

        # get attention function
        attention_keys, attention_values, attention_score_fn, attention_construct_fn \
                = attention_decoder_fn.prepare_attention(encoder_output, 'luong', num_units)

        # get decoding loop function
        decoder_fn_train = attention_decoder_fn.attention_decoder_fn_train(
            encoder_state, attention_keys, attention_values,
            attention_score_fn, attention_construct_fn)
        decoder_fn_inference = attention_decoder_fn.attention_decoder_fn_inference(
            output_fn, encoder_state, attention_keys, attention_values,
            attention_score_fn, attention_construct_fn, self.embed, GO_ID,
            EOS_ID, max_length, num_symbols)

        if is_train:
            # rnn decoder (teacher forcing on the shifted responses)
            self.decoder_output, _, _ = dynamic_rnn_decoder(
                cell,
                decoder_fn_train,
                self.decoder_input,
                self.responses_length,
                scope="decoder")
            # calculate the loss of decoder
            self.decoder_loss = sampled_sequence_loss(self.decoder_output,
                                                      self.responses_target,
                                                      self.decoder_mask)

            # building graph finished and get all parameters
            self.params = tf.trainable_variables()

            # initialize the training process
            self.learning_rate = tf.Variable(float(learning_rate),
                                             trainable=False,
                                             dtype=tf.float32)
            self.learning_rate_decay_op = self.learning_rate.assign(
                self.learning_rate * learning_rate_decay_factor)
            self.global_step = tf.Variable(0, trainable=False)

            # calculate the gradient of parameters
            opt = tf.train.GradientDescentOptimizer(self.learning_rate)
            gradients = tf.gradients(self.decoder_loss, self.params)
            clipped_gradients, self.gradient_norm = tf.clip_by_global_norm(
                gradients, max_gradient_norm)
            self.update = opt.apply_gradients(zip(clipped_gradients,
                                                  self.params),
                                              global_step=self.global_step)

        else:
            # rnn decoder (free-running, fed by its own previous outputs)
            self.decoder_distribution, _, _ = dynamic_rnn_decoder(
                cell, decoder_fn_inference, scope="decoder")

            # Greedy decoding with the first two symbol columns split off
            # (presumably _PAD/_UNK — the original comment says "for
            # removing UNK"); +2 restores the real symbol ids.
            self.generation_index = tf.argmax(
                tf.split(self.decoder_distribution, [2, num_symbols - 2],
                         2)[1], 2) + 2  # for removing UNK
            self.generation = tf.nn.embedding_lookup(self.symbols,
                                                     self.generation_index)

            self.params = tf.trainable_variables()

        self.saver = tf.train.Saver(tf.global_variables(),
                                    write_version=tf.train.SaverDef.V2,
                                    max_to_keep=3,
                                    pad_step_number=True,
                                    keep_checkpoint_every_n_hours=1.0)

    def print_parameters(self):
        """Print the name and shape of every trainable parameter."""
        for item in self.params:
            print('%s: %s' % (item.name, item.get_shape()))

    def step_decoder(self, session, data, forward_only=False):
        """Run one training (or loss-only) step on a padded batch.

        Returns [loss] when forward_only, else [loss, grad_norm, update_op
        result].
        """
        input_feed = {
            self.posts: data['posts'],
            self.posts_length: data['posts_length'],
            self.responses: data['responses'],
            self.responses_length: data['responses_length']
        }
        if forward_only:
            output_feed = [self.decoder_loss]
        else:
            output_feed = [self.decoder_loss, self.gradient_norm, self.update]
        return session.run(output_feed, input_feed)

    def inference(self, session, data):
        """Generate responses (as symbol strings) for a batch of posts."""
        input_feed = {
            self.posts: data['posts'],
            self.posts_length: data['posts_length']
        }
        output_feed = [self.generation]
        return session.run(output_feed, input_feed)
# ----- Ejemplo n.º 9 (scraped example separator; stray vote count "0") -----
class lm_model(object):
    """Seq2seq language-model scorer (TensorFlow 1.x static graph).

    Encodes a query with a GRU, decodes the answer with a plain (non-
    attentional) GRU decoder, and exposes both a training step (Adam on the
    mean sequence loss) and an inference path that returns per-token decoder
    outputs for language-model scoring.
    """

    def __init__(self,
                 num_symbols,
                 num_embed_units,
                 num_units,
                 vocab=None,
                 embed=None,
                 name_scope=None,
                 learning_rate=0.001,
                 learning_rate_decay_factor=0.95,
                 max_gradient_norm=5,
                 num_samples=512,
                 max_length=30):

        self.posts = tf.placeholder(tf.string, shape=[None,
                                                      None])  # batch * len
        self.posts_length = tf.placeholder(tf.int32, shape=[None])  # batch
        self.responses = tf.placeholder(tf.string, shape=[None,
                                                          None])  # batch*len
        self.responses_length = tf.placeholder(tf.int32, shape=[None])  # batch

        # build the vocab table (string to index)
        self.symbols = tf.Variable(vocab, trainable=False, name="symbols")
        self.symbol2index = HashTable(KeyValueTensorInitializer(
            self.symbols,
            tf.Variable(np.arange(num_symbols, dtype=np.int32), False)),
                                      default_value=UNK_ID,
                                      name="symbol2index")

        # build the embedding table (index to vector)
        if embed is None:
            # initialize the embedding randomly
            self.embed = tf.get_variable('embed',
                                         [num_symbols, num_embed_units],
                                         tf.float32)
        else:
            # initialize the embedding by pre-trained word vectors
            self.embed = tf.get_variable('embed',
                                         dtype=tf.float32,
                                         initializer=embed)

        self.posts_input = self.symbol2index.lookup(
            self.posts)  # batch * utter_len
        self.encoder_input = tf.nn.embedding_lookup(
            self.embed, self.posts_input)  # batch * utter_len * embed_unit

        self.responses_target = self.symbol2index.lookup(
            self.responses)  # batch*len
        batch_size, decoder_len = tf.shape(self.responses)[0], tf.shape(
            self.responses)[1]
        # Decoder input = GO token followed by the target shifted right by one.
        self.responses_input = tf.concat([
            tf.ones([batch_size, 1], dtype=tf.int32) * GO_ID,
            tf.split(self.responses_target, [decoder_len - 1, 1], 1)[0]
        ], 1)  # batch*len
        # Mask is 1.0 up to and including each sequence's last real token,
        # 0.0 afterwards (reverse cumsum of a one-hot at position length-1).
        self.decoder_mask = tf.reshape(
            tf.cumsum(tf.one_hot(self.responses_length - 1, decoder_len),
                      reverse=True,
                      axis=1), [-1, decoder_len])  # batch * len

        self.decoder_input = tf.nn.embedding_lookup(self.embed,
                                                    self.responses_input)

        # Construct GRU cells for encoder / decoder
        cell_enc = GRUCell(num_units)
        cell_dec = GRUCell(num_units)

        # Encode the post
        _, encoder_state = tf.nn.dynamic_rnn(cell_enc,
                                             self.encoder_input,
                                             self.posts_length,
                                             dtype=tf.float32,
                                             scope="encoder")

        output_fn, sampled_sequence_loss = output_projection_layer(
            num_units, num_symbols, num_samples)

        # Decode the response (training phase)
        with variable_scope.variable_scope('decoder'):
            decoder_fn_train = my_simple_decoder_fn.simple_decoder_fn_train(
                encoder_state)
            self.decoder_output, _, _ = my_seq2seq.dynamic_rnn_decoder(
                cell_dec,
                decoder_fn_train,
                self.decoder_input,
                self.responses_length,
                scope="decoder_rnn")
            self.decoder_loss, self.all_decoder_output = my_loss.sequence_loss(
                self.decoder_output,
                self.responses_target,
                self.decoder_mask,
                softmax_loss_function=sampled_sequence_loss)

        # Decode the response (inference phase, reusing training weights)
        with variable_scope.variable_scope('decoder', reuse=True):
            decoder_fn_inference = my_simple_decoder_fn.simple_decoder_fn_inference(
                output_fn, encoder_state, self.embed, GO_ID, EOS_ID,
                max_length, num_symbols)
            self.decoder_distribution, _, _ = my_seq2seq.dynamic_rnn_decoder(
                cell_dec, decoder_fn_inference, scope="decoder_rnn")
            # Greedy decoding with the first two symbol columns split off
            # (presumably _PAD/_UNK — the original comment says "for
            # removing UNK"); +2 restores the real symbol ids.
            self.generation_index = tf.argmax(
                tf.split(self.decoder_distribution, [2, num_symbols - 2],
                         2)[1], 2) + 2  # for removing UNK
            self.generation = tf.nn.embedding_lookup(self.symbols,
                                                     self.generation_index)

        # FIX: guard against the declared default name_scope=None, which
        # made `name_scope in k.name` raise TypeError; None now means
        # "all variables".
        self.params = [
            k for k in tf.trainable_variables()
            if name_scope is None or name_scope in k.name
        ]

        # Initialize the training process
        self.learning_rate = tf.Variable(float(learning_rate),
                                         trainable=False,
                                         dtype=tf.float32)
        self.learning_rate_decay_op = self.learning_rate.assign(
            self.learning_rate * learning_rate_decay_factor)
        self.global_step = tf.Variable(0, trainable=False)

        # Calculate the gradient of parameters
        self.cost = tf.reduce_mean(self.decoder_loss)
        opt = tf.train.AdamOptimizer(self.learning_rate)
        gradients = tf.gradients(self.cost, self.params)
        clipped_gradients, self.gradient_norm = tf.clip_by_global_norm(
            gradients, max_gradient_norm)
        self.update = opt.apply_gradients(zip(clipped_gradients, self.params),
                                          global_step=self.global_step)

        all_variables = [
            k for k in tf.global_variables()
            if name_scope is None or name_scope in k.name
        ]
        self.saver = tf.train.Saver(all_variables,
                                    write_version=tf.train.SaverDef.V2,
                                    max_to_keep=3,
                                    pad_step_number=True,
                                    keep_checkpoint_every_n_hours=1.0)

    def print_parameters(self):
        """Print the name and shape of every trainable parameter."""
        for item in self.params:
            print('%s: %s' % (item.name, item.get_shape()))

    # Conduct a step of training
    def step(self, session, data, forward_only=False):
        """Run one training (or cost-only) step on a padded batch."""
        input_feed = {
            self.posts: data['query'],
            self.posts_length: data['len_query'],
            self.responses: data['ans'],
            self.responses_length: data['len_ans']
        }
        if forward_only:
            output_feed = [self.cost]
        else:
            output_feed = [self.cost, self.gradient_norm, self.update]
        return session.run(output_feed, input_feed)

    # Get the language model score during inference
    def inference(self, session, data):
        """Return per-token decoder outputs for LM scoring of a batch."""
        input_feed = {
            self.posts: data['query'],
            self.posts_length: data['len_query'],
            self.responses: data['ans'],
            self.responses_length: data['len_ans']
        }
        output_feed = [self.all_decoder_output]
        return session.run(output_feed, input_feed)

    # Acquire a batch of data used for training / test
    def gen_train_batched_data(self, data, config):
        """Pad a batch of {'query': [...], 'ans': [...]} token lists.

        Each sequence gets an '_EOS' appended and is padded with '_PAD' to
        the batch maximum (hence the +1 on lengths).  When
        config.direction == 0 the answer tokens are reversed.
        """
        len_query = [len(p['query']) + 1 for p in data]
        len_ans = [len(p['ans']) + 1 for p in data]
        # Hoist the batch maxima out of the per-example comprehensions.
        max_query, max_ans = max(len_query), max(len_ans)

        def padding(sent, l, is_query=False):
            # FIX: work on a copy.  The original called sent.reverse(),
            # reversing the *caller's* token list in place on every batch.
            if config.direction == 0 and not is_query:
                sent = list(reversed(sent))
            return sent + ['_EOS'] + ['_PAD'] * (l - len(sent) - 1)

        batched_query = [padding(p['query'], max_query, True) for p in data]
        batched_ans = [padding(p['ans'], max_ans) for p in data]
        batched_data = {
            'query': np.array(batched_query),
            'len_query': np.array(len_query, dtype=np.int32),
            'ans': np.array(batched_ans),
            'len_ans': np.array(len_ans, dtype=np.int32)
        }
        return batched_data