Code Example #1
    def __init__(self, tfFLAGS, embed=None):
        self.vocab_size = tfFLAGS.vocab_size
        self.embed_size = tfFLAGS.embed_size
        self.num_units = tfFLAGS.num_units
        self.num_layers = tfFLAGS.num_layers
        self.beam_width = tfFLAGS.beam_width
        self.use_lstm = tfFLAGS.use_lstm
        self.attn_mode = tfFLAGS.attn_mode
        self.train_keep_prob = tfFLAGS.keep_prob
        self.max_decode_len = tfFLAGS.max_decode_len
        self.bi_encode = tfFLAGS.bi_encode
        self.recog_hidden_units = tfFLAGS.recog_hidden_units
        self.prior_hidden_units = tfFLAGS.prior_hidden_units
        self.z_dim = tfFLAGS.z_dim
        self.full_kl_step = tfFLAGS.full_kl_step

        self.global_step = tf.Variable(0, name="global_step", trainable=False)
        self.max_gradient_norm = 5.0
        if tfFLAGS.opt == 'SGD':
            self.learning_rate = tf.Variable(float(tfFLAGS.learning_rate),
                                             trainable=False,
                                             dtype=tf.float32)
            self.learning_rate_decay_op = self.learning_rate.assign(
                self.learning_rate * tfFLAGS.learning_rate_decay_factor)
            self.opt = tf.train.GradientDescentOptimizer(self.learning_rate)
        elif tfFLAGS.opt == 'Momentum':
            self.opt = tf.train.MomentumOptimizer(
                learning_rate=tfFLAGS.learning_rate, momentum=tfFLAGS.momentum)
        else:
            self.learning_rate = tfFLAGS.learning_rate
            # AdamOptimizer() is built with its default learning rate;
            # tfFLAGS.learning_rate is stored but not passed in here.
            self.opt = tf.train.AdamOptimizer()

        self._make_input(embed)

        with tf.variable_scope("output_layer"):
            self.output_layer = Dense(
                self.vocab_size,
                kernel_initializer=tf.truncated_normal_initializer(stddev=0.1))

        with tf.variable_scope("encoders",
                               initializer=tf.orthogonal_initializer()):
            self.enc_post_outputs, self.enc_post_state = self._build_encoder(
                scope='post_encoder',
                inputs=self.enc_post,
                sequence_length=self.post_len)
            self.enc_ref_outputs, self.enc_ref_state = self._build_encoder(
                scope='ref_encoder',
                inputs=self.enc_ref,
                sequence_length=self.ref_len)
            self.enc_response_outputs, self.enc_response_state = self._build_encoder(
                scope='resp_encoder',
                inputs=self.enc_response,
                sequence_length=self.response_len)

            self.post_state = self._get_representation_from_enc_state(
                self.enc_post_state)
            self.ref_state = self._get_representation_from_enc_state(
                self.enc_ref_state)
            self.response_state = self._get_representation_from_enc_state(
                self.enc_response_state)
            self.cond_embed = tf.concat([self.post_state, self.ref_state],
                                        axis=-1)

        with tf.variable_scope("RecognitionNetwork"):
            recog_input = tf.concat([self.cond_embed, self.response_state],
                                    axis=-1)
            recog_hidden = tf.layers.dense(inputs=recog_input,
                                           units=self.recog_hidden_units,
                                           activation=tf.nn.tanh)
            recog_mulogvar = tf.layers.dense(inputs=recog_hidden,
                                             units=self.z_dim * 2,
                                             activation=None)
            # recog_mulogvar = tf.layers.dense(inputs=recog_input, units=self.z_dim * 2, activation=None)
            recog_mu, recog_logvar = tf.split(recog_mulogvar, 2, axis=-1)

        with tf.variable_scope("PriorNetwork"):
            prior_input = self.cond_embed
            prior_hidden = tf.layers.dense(inputs=prior_input,
                                           units=self.prior_hidden_units,
                                           activation=tf.nn.tanh)
            prior_mulogvar = tf.layers.dense(inputs=prior_hidden,
                                             units=self.z_dim * 2,
                                             activation=None)
            prior_mu, prior_logvar = tf.split(prior_mulogvar, 2, axis=-1)

        with tf.variable_scope("GenerationNetwork"):
            latent_sample = tf.cond(
                self.use_prior,
                lambda: sample_gaussian(prior_mu, prior_logvar),
                lambda: sample_gaussian(recog_mu, recog_logvar),
                name='latent_sample')

            gen_input = tf.concat([self.cond_embed, latent_sample], axis=-1)
            if self.use_lstm:
                self.dec_init_state = tuple([
                    tf.contrib.rnn.LSTMStateTuple(
                        c=tf.layers.dense(inputs=gen_input,
                                          units=self.num_units,
                                          activation=None),
                        h=tf.layers.dense(inputs=gen_input,
                                          units=self.num_units,
                                          activation=None))
                    for _ in range(self.num_layers)
                ])
                print(self.dec_init_state)
            else:
                self.dec_init_state = tuple([
                    tf.layers.dense(inputs=gen_input,
                                    units=self.num_units,
                                    activation=None)
                    for _ in range(self.num_layers)
                ])

            kld = gaussian_kld(recog_mu, recog_logvar, prior_mu, prior_logvar)
            self.avg_kld = tf.reduce_mean(kld)
            self.kl_weights = tf.minimum(
                tf.to_float(self.global_step) / self.full_kl_step, 1.0)
            self.kl_loss = self.kl_weights * self.avg_kld

        self._build_decoder()
        self.saver = tf.train.Saver(write_version=tf.train.SaverDef.V2,
                                    max_to_keep=1,
                                    pad_step_number=True,
                                    keep_checkpoint_every_n_hours=1.0)
        for var in tf.trainable_variables():
            print(var)
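
Note: the examples in this collection call helper functions that are not defined in the excerpts. A minimal sketch consistent with how Code Example #1 uses sample_gaussian and gaussian_kld — reparameterized sampling from a diagonal Gaussian, and the KL divergence between the recognition and prior Gaussians — might look as follows; the original implementations may differ in detail.

import tensorflow as tf

def sample_gaussian(mu, logvar):
    # Reparameterization trick: z = mu + sigma * eps with eps ~ N(0, I),
    # so gradients flow through mu and logvar.
    epsilon = tf.random_normal(tf.shape(logvar), name="epsilon")
    return mu + tf.exp(0.5 * logvar) * epsilon

def gaussian_kld(recog_mu, recog_logvar, prior_mu, prior_logvar):
    # KL( N(recog_mu, exp(recog_logvar)) || N(prior_mu, exp(prior_logvar)) )
    # for diagonal Gaussians, summed over latent dims -> shape (batch_size,).
    return -0.5 * tf.reduce_sum(
        1.0 + (recog_logvar - prior_logvar)
        - tf.square(prior_mu - recog_mu) / tf.exp(prior_logvar)
        - tf.exp(recog_logvar) / tf.exp(prior_logvar),
        axis=1)

Summing over axis 1 yields one KL value per example, which matches the tf.reduce_mean(kld) in the code above.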
Code Example #2
    def __init__(self, sess, config, api, log_dir, forward, scope=None):
        self.vocab = api.vocab
        self.rev_vocab = api.rev_vocab
        self.vocab_size = len(self.vocab)
        self.topic_vocab = api.topic_vocab
        self.topic_vocab_size = len(self.topic_vocab)

        self.sess = sess
        self.scope = scope
        self.max_utt_len = config.max_utt_len
        self.go_id = self.rev_vocab["<s>"]
        self.eos_id = self.rev_vocab["</s>"]
        self.context_cell_size = config.cxt_cell_size
        self.sent_cell_size = config.sent_cell_size
        self.dec_cell_size = config.dec_cell_size
        self.bow_weights = config.bow_weights

        with tf.name_scope("io"):
            # all dialog context and known attributes
            self.input_contexts = tf.placeholder(dtype=tf.int32,
                                                 shape=(None, None,
                                                        self.max_utt_len),
                                                 name="context")
            self.context_lens = tf.placeholder(dtype=tf.int32,
                                               shape=(None, ),
                                               name="context_lens")

            # target response given the dialog context
            self.output_tokens = tf.placeholder(dtype=tf.int32,
                                                shape=(None, None),
                                                name="output_token")
            self.output_lens = tf.placeholder(dtype=tf.int32,
                                              shape=(None, ),
                                              name="output_lens")
            self.output_topics = tf.placeholder(dtype=tf.int32,
                                                shape=(None, ),
                                                name="output_topic")

            # optimization related variables
            self.learning_rate = tf.Variable(float(config.init_lr),
                                             trainable=False,
                                             name="learning_rate")
            self.learning_rate_decay_op = self.learning_rate.assign(
                tf.multiply(self.learning_rate, config.lr_decay))
            self.global_t = tf.placeholder(dtype=tf.int32, name="global_t")
            self.use_prior = tf.placeholder(dtype=tf.bool, name="use_prior")

        max_context_len = array_ops.shape(self.input_contexts)[1]
        max_out_len = array_ops.shape(self.output_tokens)[1]
        batch_size = array_ops.shape(self.input_contexts)[0]

        if config.use_hcf:
            with variable_scope.variable_scope("topicEmbedding"):
                t_embedding = tf.get_variable(
                    "embedding",
                    [self.topic_vocab_size, config.topic_embed_size],
                    dtype=tf.float32)
                topic_embedding = embedding_ops.embedding_lookup(
                    t_embedding, self.output_topics)

        with variable_scope.variable_scope("wordEmbedding"):
            self.embedding = tf.get_variable(
                "embedding", [self.vocab_size, config.embed_size],
                dtype=tf.float32)
            embedding_mask = tf.constant(
                [0 if i == 0 else 1 for i in range(self.vocab_size)],
                dtype=tf.float32,
                shape=[self.vocab_size, 1])
            embedding = self.embedding * embedding_mask

            input_embedding = embedding_ops.embedding_lookup(
                embedding, tf.reshape(self.input_contexts, [-1]))
            input_embedding = tf.reshape(
                input_embedding, [-1, self.max_utt_len, config.embed_size])
            output_embedding = embedding_ops.embedding_lookup(
                embedding, self.output_tokens)

            # context nn
            if config.sent_type == "bow":
                input_embedding, sent_size = get_bow(input_embedding)
                output_embedding, _ = get_bow(output_embedding)

            elif config.sent_type == "rnn":
                sent_cell = self.get_rnncell("gru", self.sent_cell_size,
                                             config.keep_prob, 1)
                input_embedding, sent_size = get_rnn_encode(input_embedding,
                                                            sent_cell,
                                                            scope="sent_rnn")
                output_embedding, _ = get_rnn_encode(output_embedding,
                                                     sent_cell,
                                                     self.output_lens,
                                                     scope="sent_rnn",
                                                     reuse=True)
            elif config.sent_type == "bi_rnn":
                fwd_sent_cell = self.get_rnncell("gru",
                                                 self.sent_cell_size,
                                                 keep_prob=1.0,
                                                 num_layer=1)
                bwd_sent_cell = self.get_rnncell("gru",
                                                 self.sent_cell_size,
                                                 keep_prob=1.0,
                                                 num_layer=1)
                input_embedding, sent_size = get_bi_rnn_encode(
                    input_embedding,
                    fwd_sent_cell,
                    bwd_sent_cell,
                    scope="sent_bi_rnn")
                output_embedding, _ = get_bi_rnn_encode(output_embedding,
                                                        fwd_sent_cell,
                                                        bwd_sent_cell,
                                                        self.output_lens,
                                                        scope="sent_bi_rnn",
                                                        reuse=True)
            else:
                raise ValueError(
                    "Unknown sent_type. Must be one of [bow, rnn, bi_rnn]")

            # reshape input into dialogs
            input_embedding = tf.reshape(input_embedding,
                                         [-1, max_context_len, sent_size])
            if config.keep_prob < 1.0:
                input_embedding = tf.nn.dropout(input_embedding,
                                                config.keep_prob)

        with variable_scope.variable_scope("contextRNN"):
            enc_cell = self.get_rnncell(config.cell_type,
                                        self.context_cell_size,
                                        keep_prob=1.0,
                                        num_layer=config.num_layer)
            # because sequence_length is given, outputs past each sequence's end
            # are zero and enc_last_state is the true last state
            _, enc_last_state = tf.nn.dynamic_rnn(
                enc_cell,
                input_embedding,
                dtype=tf.float32,
                sequence_length=self.context_lens)

            if config.num_layer > 1:
                enc_last_state = tf.concat(enc_last_state, 1)

        # combine with other attributes
        if config.use_hcf:
            attribute_embedding = topic_embedding
            attribute_fc1 = layers.fully_connected(attribute_embedding,
                                                   30,
                                                   activation_fn=tf.tanh,
                                                   scope="attribute_fc1")

        cond_embedding = enc_last_state

        with variable_scope.variable_scope("recognitionNetwork"):
            if config.use_hcf:
                recog_input = tf.concat(
                    [cond_embedding, output_embedding, attribute_fc1], 1)
            else:
                recog_input = tf.concat([cond_embedding, output_embedding], 1)
            self.recog_mulogvar = recog_mulogvar = layers.fully_connected(
                recog_input,
                config.latent_size * 2,
                activation_fn=None,
                scope="muvar")
            recog_mu, recog_logvar = tf.split(recog_mulogvar, 2, axis=1)

        with variable_scope.variable_scope("priorNetwork"):
            prior_fc1 = layers.fully_connected(cond_embedding,
                                               np.maximum(
                                                   config.latent_size * 2,
                                                   100),
                                               activation_fn=tf.tanh,
                                               scope="fc1")
            prior_mulogvar = layers.fully_connected(prior_fc1,
                                                    config.latent_size * 2,
                                                    activation_fn=None,
                                                    scope="muvar")
            prior_mu, prior_logvar = tf.split(prior_mulogvar, 2, axis=1)

            # use sampled Z or posterior Z
            latent_sample = tf.cond(
                self.use_prior,
                lambda: sample_gaussian(prior_mu, prior_logvar),
                lambda: sample_gaussian(recog_mu, recog_logvar))

        with variable_scope.variable_scope("generationNetwork"):
            gen_inputs = tf.concat([cond_embedding, latent_sample], 1)

            # BOW loss
            bow_fc1 = layers.fully_connected(gen_inputs,
                                             400,
                                             activation_fn=tf.tanh,
                                             scope="bow_fc1")
            if config.keep_prob < 1.0:
                bow_fc1 = tf.nn.dropout(bow_fc1, config.keep_prob)
            self.bow_logits = layers.fully_connected(bow_fc1,
                                                     self.vocab_size,
                                                     activation_fn=None,
                                                     scope="bow_project")

            # Y loss
            if config.use_hcf:
                meta_fc1 = layers.fully_connected(latent_sample,
                                                  400,
                                                  activation_fn=tf.tanh,
                                                  scope="meta_fc1")
                if config.keep_prob < 1.0:
                    meta_fc1 = tf.nn.dropout(meta_fc1, config.keep_prob)
                self.topic_logits = layers.fully_connected(
                    meta_fc1, self.topic_vocab_size, scope="topic_project")
                topic_prob = tf.nn.softmax(self.topic_logits)
                #pred_attribute_embedding = tf.matmul(topic_prob, t_embedding)
                pred_topic = tf.argmax(topic_prob, 1)
                pred_attribute_embedding = embedding_ops.embedding_lookup(
                    t_embedding, pred_topic)
                if forward:
                    selected_attribute_embedding = pred_attribute_embedding
                else:
                    selected_attribute_embedding = attribute_embedding
                dec_inputs = tf.concat(
                    [gen_inputs, selected_attribute_embedding], 1)
            else:
                self.topic_logits = tf.zeros(
                    (batch_size, self.topic_vocab_size))
                selected_attribute_embedding = None
                dec_inputs = gen_inputs

            # Decoder
            if config.num_layer > 1:
                dec_init_state = [
                    layers.fully_connected(dec_inputs,
                                           self.dec_cell_size,
                                           activation_fn=None,
                                           scope="init_state-%d" % i)
                    for i in range(config.num_layer)
                ]
                dec_init_state = tuple(dec_init_state)
            else:
                dec_init_state = layers.fully_connected(dec_inputs,
                                                        self.dec_cell_size,
                                                        activation_fn=None,
                                                        scope="init_state")

        with variable_scope.variable_scope("decoder"):
            dec_cell = self.get_rnncell(config.cell_type, self.dec_cell_size,
                                        config.keep_prob, config.num_layer)
            dec_cell = OutputProjectionWrapper(dec_cell, self.vocab_size)

            if forward:
                loop_func = decoder_fn_lib.context_decoder_fn_inference(
                    None,
                    dec_init_state,
                    embedding,
                    start_of_sequence_id=self.go_id,
                    end_of_sequence_id=self.eos_id,
                    maximum_length=self.max_utt_len,
                    num_decoder_symbols=self.vocab_size,
                    context_vector=selected_attribute_embedding)
                dec_input_embedding = None
                dec_seq_lens = None
            else:
                loop_func = decoder_fn_lib.context_decoder_fn_train(
                    dec_init_state, selected_attribute_embedding)
                dec_input_embedding = embedding_ops.embedding_lookup(
                    embedding, self.output_tokens)
                dec_input_embedding = dec_input_embedding[:, 0:-1, :]
                dec_seq_lens = self.output_lens - 1

                if config.keep_prob < 1.0:
                    dec_input_embedding = tf.nn.dropout(
                        dec_input_embedding, config.keep_prob)

                # apply word dropping. Set dropped word to 0
                if config.dec_keep_prob < 1.0:
                    keep_mask = tf.less_equal(
                        tf.random_uniform((batch_size, max_out_len - 1),
                                          minval=0.0,
                                          maxval=1.0), config.dec_keep_prob)
                    keep_mask = tf.expand_dims(tf.to_float(keep_mask), 2)
                    dec_input_embedding = dec_input_embedding * keep_mask
                    dec_input_embedding = tf.reshape(
                        dec_input_embedding,
                        [-1, max_out_len - 1, config.embed_size])

            dec_outs, _, final_context_state = dynamic_rnn_decoder(
                dec_cell,
                loop_func,
                inputs=dec_input_embedding,
                sequence_length=dec_seq_lens)
            if final_context_state is not None:
                final_context_state = final_context_state[:, 0:array_ops.
                                                          shape(dec_outs)[1]]
                mask = tf.to_int32(tf.sign(tf.reduce_max(dec_outs, axis=2)))
                self.dec_out_words = tf.multiply(
                    tf.reverse(final_context_state, axis=[1]), mask)
            else:
                self.dec_out_words = tf.argmax(dec_outs, 2)

        if not forward:
            with variable_scope.variable_scope("loss"):
                labels = self.output_tokens[:, 1:]
                label_mask = tf.to_float(tf.sign(labels))

                rc_loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
                    logits=dec_outs, labels=labels)
                rc_loss = tf.reduce_sum(rc_loss * label_mask,
                                        reduction_indices=1)
                self.avg_rc_loss = tf.reduce_mean(rc_loss)
                # used only for perplexity calculation; not used for optimization
                self.rc_ppl = tf.exp(
                    tf.reduce_sum(rc_loss) / tf.reduce_sum(label_mask))
                """ as n-trial multimodal distribution. """
                tile_bow_logits = tf.tile(tf.expand_dims(self.bow_logits, 1),
                                          [1, max_out_len - 1, 1])
                bow_loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
                    logits=tile_bow_logits, labels=labels) * label_mask
                bow_loss = tf.reduce_sum(bow_loss, reduction_indices=1)
                self.avg_bow_loss = tf.reduce_mean(bow_loss)
                bow_weights = tf.to_float(self.bow_weights)

                # reconstruct the meta info about X
                if config.use_hcf:
                    topic_loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
                        logits=self.topic_logits, labels=self.output_topics)
                    self.avg_topic_loss = tf.reduce_mean(topic_loss)
                else:
                    self.avg_topic_loss = 0.0

                kld = gaussian_kld(recog_mu, recog_logvar, prior_mu,
                                   prior_logvar)
                self.avg_kld = tf.reduce_mean(kld)
                if log_dir is not None:
                    kl_weights = tf.minimum(
                        tf.to_float(self.global_t) / config.full_kl_step, 1.0)
                else:
                    kl_weights = tf.constant(1.0)

                self.kl_w = kl_weights
                self.elbo = self.avg_rc_loss + kl_weights * self.avg_kld
                aug_elbo = bow_weights * self.avg_bow_loss + self.avg_topic_loss + self.elbo

                tf.summary.scalar("topic_loss", self.avg_topic_loss)
                tf.summary.scalar("rc_loss", self.avg_rc_loss)
                tf.summary.scalar("elbo", self.elbo)
                tf.summary.scalar("kld", self.avg_kld)
                tf.summary.scalar("bow_loss", self.avg_bow_loss)

                self.summary_op = tf.summary.merge_all()

                self.log_p_z = norm_log_liklihood(latent_sample, prior_mu,
                                                  prior_logvar)
                self.log_q_z_xy = norm_log_liklihood(latent_sample, recog_mu,
                                                     recog_logvar)
                self.est_marginal = tf.reduce_mean(rc_loss + bow_loss -
                                                   self.log_p_z +
                                                   self.log_q_z_xy)

            self.optimize(sess, config, aug_elbo, log_dir)

        self.saver = tf.train.Saver(tf.global_variables(),
                                    write_version=tf.train.SaverDef.V2)
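
Note: Code Examples #2 and #3 additionally assume norm_log_liklihood (spelling as in the calling code) and get_bow; the RNN encoders (get_rnn_encode, get_bi_rnn_encode) are assumed similarly but omitted here. A plausible sketch matching the shapes used above — the log-density of the latent sample under a diagonal Gaussian, and a bag-of-words sentence encoding — might be:

import math
import tensorflow as tf

def norm_log_liklihood(x, mu, logvar):
    # Log-density of x under N(mu, exp(logvar)) with diagonal covariance,
    # summed over latent dims -> shape (batch_size,).
    return -0.5 * tf.reduce_sum(
        tf.log(2.0 * math.pi) + logvar + tf.square(x - mu) / tf.exp(logvar),
        axis=1)

def get_bow(embedding, avg=False):
    # Bag-of-words sentence encoding: collapse the time axis by sum (or mean).
    # Returns the encoding and its feature size, as the callers above expect.
    embedding_size = embedding.get_shape()[-1].value
    if avg:
        return tf.reduce_mean(embedding, axis=-2), embedding_size
    return tf.reduce_sum(embedding, axis=-2), embedding_size

With these definitions, est_marginal above is the Monte Carlo estimate rc_loss + bow_loss - log p(z) + log q(z|x,y) averaged over the batch.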
Code Example #3
    def __init__(self,
                 sess,
                 config,
                 api,
                 log_dir,
                 forward,
                 scope=None):  # forward: True at inference (decoding), False for training
        self.vocab = api.vocab
        self.rev_vocab = api.rev_vocab
        self.vocab_size = len(self.vocab)
        self.seen_intent = api.seen_intent
        self.rev_seen_intent = api.rev_seen_intent
        self.seen_intent_size = len(self.rev_seen_intent)
        self.unseen_intent = api.unseen_intent
        self.rev_unseen_intent = api.rev_unseen_intent
        self.unseen_intent_size = len(self.rev_unseen_intent)
        self.sess = sess
        self.scope = scope
        self.max_utt_len = config.max_utt_len
        self.go_id = self.rev_vocab["<s>"]
        self.eos_id = self.rev_vocab["</s>"]
        self.sent_cell_size = config.sent_cell_size
        self.dec_cell_size = config.dec_cell_size
        self.label_embed_size = config.label_embed_size
        self.latent_size = config.latent_size

        self.seed = config.seed
        self.use_ot_label = config.use_ot_label
        self.use_rand_ot_label = config.use_rand_ot_label  # only valid when use_ot_label is True: whether to use a random subset of the other labels instead of all of them
        self.use_rand_fixed_ot_label = config.use_rand_fixed_ot_label  # valid when use_ot_label=True and use_rand_ot_label=True
        if self.use_ot_label:
            self.rand_ot_label_num = config.rand_ot_label_num  # valid when use_ot_label=true and use_rand_ot_label=true
        else:
            self.rand_ot_label_num = self.seen_intent_size - 1

        with tf.name_scope("io"):
            # all dialog context and known attributes
            self.labels = tf.placeholder(
                dtype=tf.int32, shape=(None, ),
                name="labels")  # each utterance have a label, [batch_size,]
            self.ot_label_rand = tf.placeholder(dtype=tf.int32,
                                                shape=(None, None),
                                                name="ot_labels_rand")
            self.ot_labels_all = tf.placeholder(
                dtype=tf.int32, shape=(None, None),
                name="ot_labels_all")  #(batch_size, len(api.label_vocab)-1)

            # target response given the dialog context
            self.io_tokens = tf.placeholder(dtype=tf.int32,
                                            shape=(None, None),
                                            name="output_tokens")
            self.io_lens = tf.placeholder(dtype=tf.int32,
                                          shape=(None, ),
                                          name="output_lens")
            self.output_labels = tf.placeholder(dtype=tf.int32,
                                                shape=(None, ),
                                                name="output_labels")

            # optimization related variables
            self.learning_rate = tf.Variable(float(config.init_lr),
                                             trainable=False,
                                             name="learning_rate")
            self.learning_rate_decay_op = self.learning_rate.assign(
                tf.multiply(self.learning_rate, config.lr_decay))
            self.global_t = tf.placeholder(dtype=tf.int32, name="global_t")
            self.use_prior = tf.placeholder(
                dtype=tf.bool, name="use_prior")  # whether use prior
            self.prior_mulogvar = tf.placeholder(
                dtype=tf.float32,
                shape=(None, config.latent_size * 2),
                name="prior_mulogvar")

            self.batch_size = tf.placeholder(dtype=tf.int32, name="batch_size")

        max_out_len = array_ops.shape(self.io_tokens)[1]
        # batch_size = array_ops.shape(self.io_tokens)[0]
        batch_size = self.batch_size

        with variable_scope.variable_scope("labelEmbedding",
                                           reuse=tf.AUTO_REUSE):
            self.la_embedding = tf.get_variable(
                "embedding", [self.seen_intent_size, config.label_embed_size],
                dtype=tf.float32)
            label_embedding = embedding_ops.embedding_lookup(
                self.la_embedding, self.output_labels)  # not used

        with variable_scope.variable_scope("wordEmbedding",
                                           reuse=tf.AUTO_REUSE):
            self.embedding = tf.get_variable(
                "embedding", [self.vocab_size, config.embed_size],
                dtype=tf.float32,
                trainable=False)
            embedding_mask = tf.constant(
                [0 if i == 0 else 1 for i in range(self.vocab_size)],
                dtype=tf.float32,
                shape=[self.vocab_size, 1])
            embedding = self.embedding * embedding_mask  # broadcast: zero out the first row (padding id 0)

            io_embedding = embedding_ops.embedding_lookup(
                embedding, self.io_tokens)  # 3 dim

            if config.sent_type == "bow":
                io_embedding, _ = get_bow(io_embedding)

            elif config.sent_type == "rnn":
                sent_cell = self.get_rnncell("gru", self.sent_cell_size,
                                             config.keep_prob, 1)
                io_embedding, _ = get_rnn_encode(io_embedding,
                                                 sent_cell,
                                                 self.io_lens,
                                                 scope="sent_rnn",
                                                 reuse=tf.AUTO_REUSE)
            elif config.sent_type == "bi_rnn":
                fwd_sent_cell = self.get_rnncell("gru",
                                                 self.sent_cell_size,
                                                 keep_prob=1.0,
                                                 num_layer=1)
                bwd_sent_cell = self.get_rnncell("gru",
                                                 self.sent_cell_size,
                                                 keep_prob=1.0,
                                                 num_layer=1)
                io_embedding, _ = get_bi_rnn_encode(
                    io_embedding,
                    fwd_sent_cell,
                    bwd_sent_cell,
                    self.io_lens,
                    scope="sent_bi_rnn",
                    reuse=tf.AUTO_REUSE
                )  # equal to x of the graph, (batch_size, 300*2)
            else:
                raise ValueError(
                    "Unknown sent_type. Must be one of [bow, rnn, bi_rnn]")

            # print('==========================', io_embedding) # Tensor("models_2/wordEmbedding/sent_bi_rnn/concat:0", shape=(?, 600), dtype=float32)

            # convert label into 1 hot
            my_label_one_hot = tf.one_hot(tf.reshape(self.labels, [-1]),
                                          depth=self.seen_intent_size,
                                          dtype=tf.float32)  # 2 dim
            if config.use_ot_label:
                if config.use_rand_ot_label:
                    ot_label_one_hot = tf.one_hot(tf.reshape(
                        self.ot_label_rand, [-1]),
                                                  depth=self.seen_intent_size,
                                                  dtype=tf.float32)
                    ot_label_one_hot = tf.reshape(
                        ot_label_one_hot,
                        [-1, self.seen_intent_size * self.rand_ot_label_num])
                else:
                    ot_label_one_hot = tf.one_hot(tf.reshape(
                        self.ot_labels_all, [-1]),
                                                  depth=self.seen_intent_size,
                                                  dtype=tf.float32)
                    ot_label_one_hot = tf.reshape(
                        ot_label_one_hot, [
                            -1, self.seen_intent_size *
                            (self.seen_intent_size - 1)
                        ]
                    )  # (batch_size, len(api.label_vocab)*(len(api.label_vocab)-1))

        with variable_scope.variable_scope("recognitionNetwork",
                                           reuse=tf.AUTO_REUSE):
            recog_input = io_embedding
            self.recog_mulogvar = recog_mulogvar = layers.fully_connected(
                recog_input,
                config.latent_size * 2,
                activation_fn=None,
                scope="muvar")  # config.latent_size=200
            recog_mu, recog_logvar = tf.split(
                recog_mulogvar, 2, axis=1
            )  # recognition network output. (batch_size, config.latent_size)

        with variable_scope.variable_scope("priorNetwork",
                                           reuse=tf.AUTO_REUSE):
            # p(xyz) = p(z)p(x|z)p(y|xz)
            # prior network parameters; assume a normal distribution
            # prior_mulogvar = tf.constant([[1] * config.latent_size + [0] * config.latent_size]*batch_size,
            #                              dtype=tf.float32, name="muvar") # cannot be built this way: batch_size is dynamic
            prior_mulogvar = self.prior_mulogvar
            prior_mu, prior_logvar = tf.split(prior_mulogvar, 2, axis=1)

            # use sampled Z or posterior Z
            latent_sample = tf.cond(
                self.use_prior,  # bool input
                lambda: sample_gaussian(prior_mu, prior_logvar
                                        ),  # equal to shape(prior_logvar)
                lambda: sample_gaussian(recog_mu, recog_logvar)
            )  # if ... else ..., (batch_size, config.latent_size)
            self.z = latent_sample

        with variable_scope.variable_scope("generationNetwork",
                                           reuse=tf.AUTO_REUSE):
            bow_loss_inputs = latent_sample  # (part of) response network input
            label_inputs = latent_sample
            dec_inputs = latent_sample

            # BOW loss
            if config.use_bow_loss:
                bow_fc1 = layers.fully_connected(
                    bow_loss_inputs,
                    400,
                    activation_fn=tf.tanh,
                    scope="bow_fc1")  # MLPb network fc layer
                # note: earlier raised "ValueError: The last dimension of the inputs to `Dense` should be defined. Found `None`."
                if config.keep_prob < 1.0:
                    bow_fc1 = tf.nn.dropout(bow_fc1, config.keep_prob)
                self.bow_logits = layers.fully_connected(
                    bow_fc1,
                    self.vocab_size,
                    activation_fn=None,
                    scope="bow_project")  # MLPb network fc output

            # Y loss, include the other y.
            my_label_fc1 = layers.fully_connected(label_inputs,
                                                  400,
                                                  activation_fn=tf.tanh,
                                                  scope="my_label_fc1")
            if config.keep_prob < 1.0:
                my_label_fc1 = tf.nn.dropout(my_label_fc1, config.keep_prob)

            # my_label_fc2 = layers.fully_connected(my_label_fc1, 400, activation_fn=tf.tanh, scope="my_label_fc2")
            # if config.keep_prob < 1.0:
            #     my_label_fc2 = tf.nn.dropout(my_label_fc2, config.keep_prob)

            self.my_label_logits = layers.fully_connected(
                my_label_fc1, self.seen_intent_size,
                scope="my_label_project")  # MLPy fc output
            my_label_prob = tf.nn.softmax(
                self.my_label_logits
            )  # softmax output, (batch_size, label_vocab_size)
            self.my_label_prob = my_label_prob
            pred_my_label_embedding = tf.matmul(
                my_label_prob, self.la_embedding
            )  # predicted my label y. (batch_size, label_embed_size)

            if config.use_ot_label:
                if config.use_rand_ot_label:  # use randomly sampled other labels
                    ot_label_fc1 = layers.fully_connected(
                        label_inputs,
                        400,
                        activation_fn=tf.tanh,
                        scope="ot_label_fc1")
                    if config.keep_prob < 1.0:
                        ot_label_fc1 = tf.nn.dropout(ot_label_fc1,
                                                     config.keep_prob)
                    self.ot_label_logits = layers.fully_connected(
                        ot_label_fc1,
                        self.rand_ot_label_num * self.seen_intent_size,
                        scope="ot_label_rand_project")
                    ot_label_logits_split = tf.reshape(
                        self.ot_label_logits,
                        [-1, self.rand_ot_label_num, self.seen_intent_size])
                    ot_label_prob_short = tf.nn.softmax(ot_label_logits_split)
                    ot_label_prob = tf.reshape(
                        ot_label_prob_short,
                        [-1, self.rand_ot_label_num * self.seen_intent_size]
                    )  # (batch_size, self.rand_ot_label_num*self.label_vocab_size)
                    pred_ot_label_embedding = tf.reshape(
                        tf.matmul(ot_label_prob_short, self.la_embedding),
                        [-1, self.label_embed_size * self.rand_ot_label_num
                         ])  # predicted other labels y2, (batch_size, label_embed_size*rand_ot_label_num)
                else:
                    ot_label_fc1 = layers.fully_connected(
                        label_inputs,
                        400,
                        activation_fn=tf.tanh,
                        scope="ot_label_fc1")
                    if config.keep_prob < 1.0:
                        ot_label_fc1 = tf.nn.dropout(ot_label_fc1,
                                                     config.keep_prob)
                    self.ot_label_logits = layers.fully_connected(
                        ot_label_fc1,
                        self.seen_intent_size * (self.seen_intent_size - 1),
                        scope="ot_label_all_project")
                    ot_label_logits_split = tf.reshape(
                        self.ot_label_logits,
                        [-1, self.seen_intent_size - 1, self.seen_intent_size])
                    ot_label_prob_short = tf.nn.softmax(ot_label_logits_split)
                    ot_label_prob = tf.reshape(
                        ot_label_prob_short, [
                            -1, self.seen_intent_size *
                            (self.seen_intent_size - 1)
                        ]
                    )  # (batch_size, self.label_vocab_size*(self.label_vocab_size-1))
                    pred_ot_label_embedding = tf.reshape(
                        tf.matmul(ot_label_prob_short, self.la_embedding),
                        [-1, self.label_embed_size * (self.seen_intent_size - 1)]
                    )  # predicted embeddings of all other labels, (batch_size, label_embed_size*(seen_intent_size-1))
                    # note: matmul can compute (3, 4, 5) × (5, 4) = (3, 4, 4)
            else:  # only use label y.
                self.ot_label_logits = None
                pred_ot_label_embedding = None

            # Decoder, Response Network
            if config.num_layer > 1:
                dec_init_state = []
                for i in range(config.num_layer):
                    temp_init = layers.fully_connected(dec_inputs,
                                                       self.dec_cell_size,
                                                       activation_fn=None,
                                                       scope="init_state-%d" %
                                                       i)
                    if config.cell_type == 'lstm':
                        temp_init = rnn_cell.LSTMStateTuple(
                            temp_init, temp_init)

                    dec_init_state.append(temp_init)

                dec_init_state = tuple(dec_init_state)
            else:
                dec_init_state = layers.fully_connected(dec_inputs,
                                                        self.dec_cell_size,
                                                        activation_fn=None,
                                                        scope="init_state")
                if config.cell_type == 'lstm':
                    dec_init_state = rnn_cell.LSTMStateTuple(
                        dec_init_state, dec_init_state)

        with variable_scope.variable_scope("decoder", reuse=tf.AUTO_REUSE):
            dec_cell = self.get_rnncell(config.cell_type, self.dec_cell_size,
                                        config.keep_prob, config.num_layer)
            dec_cell = OutputProjectionWrapper(dec_cell, self.vocab_size)

            if forward:  # test
                loop_func = decoder_fn_lib.context_decoder_fn_inference(
                    None,
                    dec_init_state,
                    embedding,
                    start_of_sequence_id=self.go_id,
                    end_of_sequence_id=self.eos_id,
                    maximum_length=self.max_utt_len,
                    num_decoder_symbols=self.vocab_size,
                    context_vector=None)  # a function
                dec_input_embedding = None
                dec_seq_lens = None
            else:  # train
                loop_func = decoder_fn_lib.context_decoder_fn_train(
                    dec_init_state, None)
                dec_input_embedding = embedding_ops.embedding_lookup(
                    embedding, self.io_tokens
                )  # x 's embedding (batch_size, utt_len, embed_size)
                dec_input_embedding = dec_input_embedding[:, 0:
                                                          -1, :]  # ignore the last </s>
                dec_seq_lens = self.io_lens - 1  # input placeholder

                if config.keep_prob < 1.0:
                    dec_input_embedding = tf.nn.dropout(
                        dec_input_embedding, config.keep_prob)

                # apply word dropping. Set dropped word to 0
                if config.dec_keep_prob < 1.0:
                    keep_mask = tf.less_equal(
                        tf.random_uniform((batch_size, max_out_len - 1),
                                          minval=0.0,
                                          maxval=1.0), config.dec_keep_prob)
                    keep_mask = tf.expand_dims(tf.to_float(keep_mask), 2)
                    dec_input_embedding = dec_input_embedding * keep_mask
                    dec_input_embedding = tf.reshape(
                        dec_input_embedding,
                        [-1, max_out_len - 1, config.embed_size])

                # print("=======", dec_input_embedding) # Tensor("models/decoder/strided_slice:0", shape=(?, ?, 200), dtype=float32)

            dec_outs, _, final_context_state = dynamic_rnn_decoder(
                dec_cell,
                loop_func,
                inputs=dec_input_embedding,
                sequence_length=dec_seq_lens
            )  # dec_outs [batch_size, seq, features]

            if final_context_state is not None:
                final_context_state = final_context_state[:, 0:array_ops.
                                                          shape(dec_outs)[1]]
                mask = tf.to_int32(tf.sign(tf.reduce_max(
                    dec_outs, axis=2)))  # get softmax vec's max index
                self.dec_out_words = tf.multiply(
                    tf.reverse(final_context_state, axis=[1]), mask)
            else:
                self.dec_out_words = tf.argmax(
                    dec_outs,
                    2)  # (batch_size, utt_len), each element is index of word

        if not forward:
            with variable_scope.variable_scope("loss", reuse=tf.AUTO_REUSE):
                labels = self.io_tokens[:, 1:]  # exclude the first word <s>; shape (batch_size, utt_len-1)
                label_mask = tf.to_float(tf.sign(labels))

                labels = tf.one_hot(labels,
                                    depth=self.vocab_size,
                                    dtype=tf.float32)

                print(dec_outs)
                print(labels)
                # Tensor("models_1/decoder/dynamic_rnn_decoder/transpose_1:0", shape=(?, ?, 892), dtype=float32)
                # Tensor("models_1/loss/strided_slice:0", shape=(?, ?), dtype=int32)
                # rc_loss = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=dec_outs, labels=labels) # response network loss
                rc_loss = tf.nn.softmax_cross_entropy_with_logits(
                    logits=dec_outs, labels=labels)  # response network loss
                # logits_size=[390,892] labels_size=[1170,892]
                rc_loss = tf.reduce_sum(
                    rc_loss * label_mask,
                    reduction_indices=1)  # (batch_size,); positions with label id 0 are masked out
                self.avg_rc_loss = tf.reduce_mean(rc_loss)  # scalar
                # used only for perplexity calculation; not used for optimization
                self.rc_ppl = tf.exp(
                    tf.reduce_sum(rc_loss) / tf.reduce_sum(label_mask))
                """ as n-trial multimodal distribution. """
                tile_bow_logits = tf.tile(
                    tf.expand_dims(self.bow_logits, 1),
                    [1, max_out_len - 1, 1
                     ])  # (batch_size, max_out_len-1, vocab_size)
                bow_loss = tf.nn.softmax_cross_entropy_with_logits(
                    logits=tile_bow_logits, labels=labels
                ) * label_mask  # labels shape less than logits shape, (batch_size, max_out_len-1)
                bow_loss = tf.reduce_sum(bow_loss,
                                         reduction_indices=1)  # (batch_size, )
                self.avg_bow_loss = tf.reduce_mean(bow_loss)  # scalar

                # the label y
                # cross-entropy expects the raw logits, not the softmax output
                my_label_loss = tf.nn.softmax_cross_entropy_with_logits(
                    logits=self.my_label_logits,
                    labels=my_label_one_hot)  # (batch_size,)
                self.avg_my_label_loss = tf.reduce_mean(my_label_loss)
                if config.use_ot_label:
                    # negated cross-entropy pushes the prediction away from the other labels
                    ot_label_loss = -tf.nn.softmax_cross_entropy_with_logits(
                        logits=self.ot_label_logits, labels=ot_label_one_hot)
                    self.avg_ot_label_loss = tf.reduce_mean(ot_label_loss)
                else:
                    self.avg_ot_label_loss = 0.0

                kld = gaussian_kld(
                    recog_mu, recog_logvar, prior_mu,
                    prior_logvar)  # kl divergence, (batch_size,)
                self.avg_kld = tf.reduce_mean(kld)  # scalar
                if log_dir is not None:
                    kl_weights = tf.minimum(
                        tf.to_float(self.global_t) / config.full_kl_step, 1.0)
                else:
                    kl_weights = tf.constant(1.0)

                self.kl_w = kl_weights
                self.elbo = self.avg_rc_loss + kl_weights * self.avg_kld  # Restructure loss and kl divergence
                # ----- total loss -----
                if config.use_rand_ot_label:
                    aug_elbo = self.avg_bow_loss + 1000 * self.avg_my_label_loss + 10 * self.avg_ot_label_loss + self.elbo  # augmented loss
                    # a scale of (1/self.rand_ot_label_num) on the ot term was considered
                else:
                    # currently identical to the branch above; only the considered
                    # scale, (1/(self.seen_intent_size-1)), differs
                    aug_elbo = self.avg_bow_loss + 1000 * self.avg_my_label_loss + 10 * self.avg_ot_label_loss + self.elbo  # augmented loss

                tf.summary.scalar("rc_loss", self.avg_rc_loss)
                tf.summary.scalar("elbo", self.elbo)
                tf.summary.scalar("kld", self.avg_kld)
                tf.summary.scalar("bow_loss", self.avg_bow_loss)
                tf.summary.scalar("my_label_loss", self.avg_my_label_loss)
                tf.summary.scalar("ot_label_loss", self.avg_ot_label_loss)

                self.summary_op = tf.summary.merge_all()

                self.log_p_z = norm_log_liklihood(latent_sample, prior_mu,
                                                  prior_logvar)  # log-probability under the prior
                self.log_q_z_xy = norm_log_liklihood(
                    latent_sample, recog_mu, recog_logvar)  # log-probability under the posterior
                self.est_marginal = tf.reduce_mean(rc_loss + bow_loss -
                                                   self.log_p_z +
                                                   self.log_q_z_xy)

            self.optimize(sess, config, aug_elbo, log_dir)

        self.saver = tf.train.Saver(tf.global_variables(),
                                    write_version=tf.train.SaverDef.V2)
        print('model build finished!')
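
Note: in Code Example #3 the prior is not a learned network; prior_mulogvar is fed through a placeholder, and the in-code comment assumes a normal distribution. A hypothetical feed for a standard normal prior (mu = 0, logvar = 0, i.e. unit variance) could look like the following, where model, config, and n are illustrative names:

import numpy as np

n = 32  # batch size for this step (illustrative)
# The first latent_size columns are mu, the rest are logvar; zeros give N(0, I),
# matching the placeholder shape (None, config.latent_size * 2).
prior_feed = np.zeros((n, config.latent_size * 2), dtype=np.float32)
feed_dict = {
    model.prior_mulogvar: prior_feed,
    model.use_prior: True,   # sample z from the prior at test time
    model.batch_size: n,
}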
Code Example #4
    def __init__(self, sess, config, api, log_dir, forward, scope=None):
        self.vocab = api.vocab  # index2word
        self.rev_vocab = api.rev_vocab  # word2index
        self.vocab_size = len(self.vocab)  # vocab size
        self.emotion_vocab = api.emotion_vocab  # index2emotion
        self.emotion_vocab_size = len(self.emotion_vocab)
        # self.da_vocab = api.dialog_act_vocab
        # self.da_vocab_size = len(self.da_vocab)

        self.sess = sess
        self.scope = scope
        self.max_utt_len = config.max_utt_len
        self.go_id = self.rev_vocab["<s>"]
        self.eos_id = self.rev_vocab["</s>"]

        self.context_cell_size = config.cxt_cell_size  # used by the encoder below
        self.sent_cell_size = config.sent_cell_size  # for encode
        self.dec_cell_size = config.dec_cell_size  # for decode

        with tf.name_scope("io"):
            self.input_contexts = tf.placeholder(dtype=tf.int32,
                                                 shape=(None,
                                                        self.max_utt_len),
                                                 name="input_contexts")
            # self.floors = tf.placeholder(dtype=tf.int32, shape=(None, None), name="floor")
            self.input_lens = tf.placeholder(dtype=tf.int32,
                                             shape=(None, ),
                                             name="input_lens")
            self.input_emotions = tf.placeholder(dtype=tf.int32,
                                                 shape=(None, ),
                                                 name="input_emotions")
            # self.my_profile = tf.placeholder(dtype=tf.float32, shape=(None, 4), name="my_profile")
            # self.ot_profile = tf.placeholder(dtype=tf.float32, shape=(None, 4), name="ot_profile")

            # target response given the dialog context
            self.output_tokens = tf.placeholder(dtype=tf.int32,
                                                shape=(None, None),
                                                name="output_token")
            self.output_lens = tf.placeholder(dtype=tf.int32,
                                              shape=(None, ),
                                              name="output_lens")
            self.output_emotions = tf.placeholder(dtype=tf.int32,
                                                  shape=(None, ),
                                                  name="output_emotions")

            # optimization related variables
            self.learning_rate = tf.Variable(float(config.init_lr),
                                             trainable=False,
                                             name="learning_rate")
            self.learning_rate_decay_op = self.learning_rate.assign(
                tf.multiply(self.learning_rate, config.lr_decay))
            self.global_t = tf.placeholder(dtype=tf.int32, name="global_t")
            self.use_prior = tf.placeholder(dtype=tf.bool, name="use_prior")

        max_dialog_len = array_ops.shape(self.input_contexts)[1]
        max_out_len = array_ops.shape(self.output_tokens)[1]
        batch_size = array_ops.shape(self.input_contexts)[0]

        with variable_scope.variable_scope("emotionEmbedding"):
            t_embedding = tf.get_variable(
                "embedding",
                [self.emotion_vocab_size, config.topic_embed_size],
                dtype=tf.float32)
            inp_emotion_embedding = embedding_ops.embedding_lookup(
                t_embedding, self.input_emotions)
            outp_emotion_embedding = embedding_ops.embedding_lookup(
                t_embedding, self.output_emotions)

        with variable_scope.variable_scope("wordEmbedding"):
            self.embedding = tf.get_variable(
                "embedding", [self.vocab_size, config.embed_size],
                dtype=tf.float32)
            embedding_mask = tf.constant(
                [0 if i == 0 else 1 for i in range(self.vocab_size)],
                dtype=tf.float32,
                shape=[self.vocab_size, 1])
            embedding = self.embedding * embedding_mask

            input_embedding = embedding_ops.embedding_lookup(
                embedding, tf.reshape(self.input_contexts, [-1]))
            input_embedding = tf.reshape(
                input_embedding, [-1, self.max_utt_len, config.embed_size])
            output_embedding = embedding_ops.embedding_lookup(
                embedding, self.output_tokens)

            if config.sent_type == "rnn":
                enc_cell = self.get_rnncell(config.cell_type,
                                            self.context_cell_size,
                                            keep_prob=1.0,
                                            num_layer=config.num_layer)
                _, enc_last_state = tf.nn.dynamic_rnn(
                    enc_cell,
                    input_embedding,
                    dtype=tf.float32,
                    sequence_length=self.input_lens)

                sent_cell = self.get_rnncell("gru", self.sent_cell_size,
                                             config.keep_prob, 1)
                # input_embedding, sent_size = get_rnn_encode(input_embedding, sent_cell, scope="sent_rnn")

                output_embedding, _ = get_rnn_encode(output_embedding,
                                                     sent_cell,
                                                     self.output_lens,
                                                     scope="sent_rnn")

            elif config.sent_type == "bi_rnn":
                fwd_enc_cell = self.get_rnncell(config.cell_type,
                                                self.context_cell_size,
                                                keep_prob=1.0,
                                                num_layer=config.num_layer)
                bwd_enc_cell = self.get_rnncell(config.cell_type,
                                                self.context_cell_size,
                                                keep_prob=1.0,
                                                num_layer=config.num_layer)
                _, enc_last_state = tf.nn.bidirectional_dynamic_rnn(
                    fwd_enc_cell,
                    bwd_enc_cell,
                    input_embedding,
                    dtype=tf.float32,
                    sequence_length=self.input_lens)
                enc_last_state = enc_last_state[0] + enc_last_state[1]

                fwd_sent_cell = self.get_rnncell("gru",
                                                 self.sent_cell_size,
                                                 keep_prob=1.0,
                                                 num_layer=1)
                bwd_sent_cell = self.get_rnncell("gru",
                                                 self.sent_cell_size,
                                                 keep_prob=1.0,
                                                 num_layer=1)
                # input_embedding, sent_size = get_bi_rnn_encode(input_embedding, fwd_sent_cell, bwd_sent_cell, scope="sent_bi_rnn")
                output_embedding, _ = get_bi_rnn_encode(output_embedding,
                                                        fwd_sent_cell,
                                                        bwd_sent_cell,
                                                        self.output_lens,
                                                        scope="sent_bi_rnn")

            else:
                raise ValueError(
                    "Unknown sent_type. Must be one of [rnn, bi_rnn]")

            # reshape input into dialogs
            # input_embedding = tf.reshape(input_embedding, [-1, max_dialog_len, sent_size])
            # if config.keep_prob < 1.0:
            #     input_embedding = tf.nn.dropout(input_embedding, config.keep_prob)

            # convert floors into 1 hot
            # floor_one_hot = tf.one_hot(tf.reshape(self.floors, [-1]), depth=2, dtype=tf.float32)
            # floor_one_hot = tf.reshape(floor_one_hot, [-1, max_dialog_len, 2])

            # joint_embedding = tf.concat([input_embedding, floor_one_hot], 2, "joint_embedding")

        with variable_scope.variable_scope("contextRNN"):

            if config.num_layer > 1:
                if config.cell_type == 'lstm':
                    enc_last_state = [temp.h for temp in enc_last_state]

                enc_last_state = tf.concat(enc_last_state, 1)
            else:
                if config.cell_type == 'lstm':
                    enc_last_state = enc_last_state.h
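            # inp_emotion_embedding / outp_emotion_embedding are assumed to
            # be emotion-label embeddings looked up earlier in this model;
            # they are not defined in the snippet shown here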

            attribute_fc1 = layers.fully_connected(outp_emotion_embedding,
                                                   30,
                                                   activation_fn=tf.tanh,
                                                   scope="attribute_fc1")

            cond_embedding = tf.concat([inp_emotion_embedding, enc_last_state],
                                       1)

        with variable_scope.variable_scope("recognitionNetwork"):
            recog_input = tf.concat(
                [cond_embedding, output_embedding, attribute_fc1], 1)
            self.recog_mulogvar = recog_mulogvar = layers.fully_connected(
                recog_input,
                config.latent_size * 2,
                activation_fn=None,
                scope="muvar")
            recog_mu, recog_logvar = tf.split(recog_mulogvar, 2, axis=1)

        with variable_scope.variable_scope("priorNetwork"):
            prior_fc1 = layers.fully_connected(cond_embedding,
                                               np.maximum(
                                                   config.latent_size * 2,
                                                   100),
                                               activation_fn=tf.tanh,
                                               scope="fc1")
            prior_mulogvar = layers.fully_connected(prior_fc1,
                                                    config.latent_size * 2,
                                                    activation_fn=None,
                                                    scope="muvar")
            prior_mu, prior_logvar = tf.split(prior_mulogvar, 2, axis=1)

            # use sampled Z or posterior Z
            latent_sample = tf.cond(
                self.use_prior,
                lambda: sample_gaussian(prior_mu, prior_logvar),
                lambda: sample_gaussian(recog_mu, recog_logvar))
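            # reparameterized sampling: the recognition (posterior) sample is
            # used during training, the prior sample at inference time.
            # sample_gaussian is not shown here; a minimal sketch would be:
            #   def sample_gaussian(mu, logvar):
            #       eps = tf.random_normal(tf.shape(logvar))
            #       return mu + tf.exp(0.5 * logvar) * eps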

        with variable_scope.variable_scope("generationNetwork"):
            gen_inputs = tf.concat([cond_embedding, latent_sample], 1)

            # BOW loss
            bow_fc1 = layers.fully_connected(gen_inputs,
                                             400,
                                             activation_fn=tf.tanh,
                                             scope="bow_fc1")
            if config.keep_prob < 1.0:
                bow_fc1 = tf.nn.dropout(bow_fc1, config.keep_prob)
            self.bow_logits = layers.fully_connected(bow_fc1,
                                                     self.vocab_size,
                                                     activation_fn=None,
                                                     scope="bow_project")

            # Y loss
            meta_fc1 = layers.fully_connected(gen_inputs,
                                              400,
                                              activation_fn=tf.tanh,
                                              scope="meta_fc1")
            if config.keep_prob < 1.0:
                meta_fc1 = tf.nn.dropout(meta_fc1, config.keep_prob)
            self.da_logits = layers.fully_connected(meta_fc1,
                                                    self.emotion_vocab_size,
                                                    scope="da_project")
            da_prob = tf.nn.softmax(self.da_logits)
            pred_attribute_embedding = tf.matmul(da_prob, t_embedding)
            if forward:
                selected_attribute_embedding = pred_attribute_embedding
            else:
                selected_attribute_embedding = outp_emotion_embedding
            dec_inputs = tf.concat([gen_inputs, selected_attribute_embedding],
                                   1)

            # Decoder
            if config.num_layer > 1:
                dec_init_state = []
                for i in range(config.num_layer):
                    temp_init = layers.fully_connected(
                        dec_inputs,
                        self.dec_cell_size,
                        activation_fn=None,
                        scope="init_state-%d" % i)
                    if config.cell_type == 'lstm':
                        temp_init = rnn_cell.LSTMStateTuple(
                            temp_init, temp_init)

                    dec_init_state.append(temp_init)

                dec_init_state = tuple(dec_init_state)
            else:
                dec_init_state = layers.fully_connected(dec_inputs,
                                                        self.dec_cell_size,
                                                        activation_fn=None,
                                                        scope="init_state")
                if config.cell_type == 'lstm':
                    dec_init_state = rnn_cell.LSTMStateTuple(
                        dec_init_state, dec_init_state)

        with variable_scope.variable_scope("decoder"):
            dec_cell = self.get_rnncell(config.cell_type, self.dec_cell_size,
                                        config.keep_prob, config.num_layer)
            dec_cell = OutputProjectionWrapper(dec_cell, self.vocab_size)

            if forward:
                loop_func = decoder_fn_lib.context_decoder_fn_inference(
                    None,
                    dec_init_state,
                    embedding,
                    start_of_sequence_id=self.go_id,
                    end_of_sequence_id=self.eos_id,
                    maximum_length=self.max_utt_len,
                    num_decoder_symbols=self.vocab_size,
                    context_vector=selected_attribute_embedding)
                dec_input_embedding = None
                dec_seq_lens = None
            else:
                loop_func = decoder_fn_lib.context_decoder_fn_train(
                    dec_init_state, selected_attribute_embedding)
                dec_input_embedding = embedding_ops.embedding_lookup(
                    embedding, self.output_tokens)
                dec_input_embedding = dec_input_embedding[:, 0:-1, :]
                dec_seq_lens = self.output_lens - 1

                if config.keep_prob < 1.0:
                    dec_input_embedding = tf.nn.dropout(
                        dec_input_embedding, config.keep_prob)

                # apply word dropping. Set dropped word to 0
                if config.dec_keep_prob < 1.0:
                    keep_mask = tf.less_equal(
                        tf.random_uniform((batch_size, max_out_len - 1),
                                          minval=0.0,
                                          maxval=1.0), config.dec_keep_prob)
                    keep_mask = tf.expand_dims(tf.to_float(keep_mask), 2)
                    dec_input_embedding = dec_input_embedding * keep_mask
                    dec_input_embedding = tf.reshape(
                        dec_input_embedding,
                        [-1, max_out_len - 1, config.embed_size])
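                    # word dropout zeroes whole input embeddings at random so
                    # the decoder cannot rely purely on teacher forcing; the
                    # reshape re-establishes the static embed_size dimension
                    # after multiplication with the dynamic mask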

            dec_outs, _, final_context_state = dynamic_rnn_decoder(
                dec_cell,
                loop_func,
                inputs=dec_input_embedding,
                sequence_length=dec_seq_lens)

            if final_context_state is not None:
                final_context_state = final_context_state[
                    :, 0:array_ops.shape(dec_outs)[1]]
                mask = tf.to_int32(tf.sign(tf.reduce_max(dec_outs, axis=2)))
                self.dec_out_words = tf.multiply(
                    tf.reverse(final_context_state, axis=[1]), mask)
            else:
                self.dec_out_words = tf.argmax(dec_outs, 2)

        if not forward:
            with variable_scope.variable_scope("loss"):
                labels = self.output_tokens[:, 1:]
                label_mask = tf.to_float(tf.sign(labels))

                rc_loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
                    logits=dec_outs, labels=labels)
                rc_loss = tf.reduce_sum(rc_loss * label_mask,
                                        reduction_indices=1)
                self.avg_rc_loss = tf.reduce_mean(rc_loss)
                self.rc_ppl = tf.exp(
                    tf.reduce_sum(rc_loss) / tf.reduce_sum(label_mask))

                tile_bow_logits = tf.tile(tf.expand_dims(self.bow_logits, 1),
                                          [1, max_out_len - 1, 1])
                bow_loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
                    logits=tile_bow_logits, labels=labels) * label_mask
                bow_loss = tf.reduce_sum(bow_loss, reduction_indices=1)
                self.avg_bow_loss = tf.reduce_mean(bow_loss)
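                # the bag-of-words loss asks z (plus the condition) to
                # predict every word of the response at once, which pushes
                # useful information into the latent variable and counters
                # KL collapse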

                da_loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
                    logits=self.da_logits, labels=self.output_emotions)
                self.avg_da_loss = tf.reduce_mean(da_loss)

                kld = gaussian_kld(recog_mu, recog_logvar, prior_mu,
                                   prior_logvar)
                self.avg_kld = tf.reduce_mean(kld)
                if log_dir is not None:
                    kl_weights = tf.minimum(
                        tf.to_float(self.global_t) / config.full_kl_step, 1.0)
                else:
                    kl_weights = tf.constant(1.0)
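                # linear KL annealing: the KL term is ramped from 0 to 1 over
                # full_kl_step training steps; without a log_dir (evaluation)
                # the full weight is used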

                self.kl_w = kl_weights
                self.elbo = self.avg_rc_loss + kl_weights * self.avg_kld
                aug_elbo = self.avg_bow_loss + self.avg_da_loss + self.elbo

                tf.summary.scalar("da_loss", self.avg_da_loss)
                tf.summary.scalar("rc_loss", self.avg_rc_loss)
                tf.summary.scalar("elbo", self.elbo)
                tf.summary.scalar("kld", self.avg_kld)
                tf.summary.scalar("bow_loss", self.avg_bow_loss)

                self.summary_op = tf.summary.merge_all()

                self.log_p_z = norm_log_liklihood(latent_sample, prior_mu,
                                                  prior_logvar)
                self.log_q_z_xy = norm_log_liklihood(latent_sample, recog_mu,
                                                     recog_logvar)
                self.est_marginal = tf.reduce_mean(rc_loss + bow_loss -
                                                   self.log_p_z +
                                                   self.log_q_z_xy)

            self.optimize(sess, config, aug_elbo, log_dir)

        self.saver = tf.train.Saver(tf.global_variables(),
                                    write_version=tf.train.SaverDef.V2)
Code example #5
    def __init__(self,
            num_symbols,
            num_embed_units,
            num_units,
            is_train,
            vocab=None,
            content_pos=None,
            rhetoric_pos=None,
            embed=None,
            learning_rate=0.1,
            learning_rate_decay_factor=0.9995,
            max_gradient_norm=5.0,
            max_length=30,
            latent_size=128,
            use_lstm=False,
            num_classes=3,
            full_kl_step=80000,
            mem_slot_num=4,
            mem_size=128):
        
        self.ori_sents = tf.placeholder(tf.string, shape=(None, None))
        self.ori_sents_length = tf.placeholder(tf.int32, shape=(None,))
        self.rep_sents = tf.placeholder(tf.string, shape=(None, None))
        self.rep_sents_length = tf.placeholder(tf.int32, shape=(None,))
        self.labels = tf.placeholder(tf.float32, shape=(None, num_classes))
        self.use_prior = tf.placeholder(tf.bool)
        self.global_t = tf.placeholder(tf.int32)
        self.content_mask = tf.reduce_sum(
            tf.one_hot(content_pos, num_symbols, 1.0, 0.0), axis=0)
        self.rhetoric_mask = tf.reduce_sum(
            tf.one_hot(rhetoric_pos, num_symbols, 1.0, 0.0), axis=0)

        # tf.zeros cannot take a None dimension, so the batch size is taken
        # dynamically from the input placeholder
        topic_memory = tf.zeros(name="topic_memory", dtype=tf.float32,
                                shape=[tf.shape(self.ori_sents)[0],
                                       mem_slot_num, mem_size])

        w_topic_memory = tf.get_variable(
            name="w_topic_memory", dtype=tf.float32,
            initializer=tf.random_uniform([mem_size, mem_size], -0.1, 0.1))

        # build the vocab table (string to index)
        if is_train:
            self.symbols = tf.Variable(vocab, trainable=False, name="symbols")
        else:
            self.symbols = tf.Variable(np.array(['.']*num_symbols), name="symbols")
        self.symbol2index = HashTable(KeyValueTensorInitializer(self.symbols, 
            tf.Variable(np.array([i for i in range(num_symbols)], dtype=np.int32), False)), 
            default_value=UNK_ID, name="symbol2index")

        self.ori_sents_input = self.symbol2index.lookup(self.ori_sents)
        self.rep_sents_target = self.symbol2index.lookup(self.rep_sents)
        batch_size, decoder_len = tf.shape(self.rep_sents)[0], tf.shape(self.rep_sents)[1]
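        # teacher forcing: prepend GO_ID and drop the last target token to
        # form the decoder input; decoder_mask is 1.0 up to each true
        # response length (a reversed cumsum over the one-hot end positions)
        # and 0.0 afterwards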
        self.rep_sents_input = tf.concat([tf.ones([batch_size, 1], dtype=tf.int32)*GO_ID,
            tf.split(self.rep_sents_target, [decoder_len-1, 1], 1)[0]], 1)
        self.decoder_mask = tf.reshape(tf.cumsum(tf.one_hot(self.rep_sents_length-1,
            decoder_len), reverse=True, axis=1), [-1, decoder_len])        
        
        # build the embedding table (index to vector)
        if embed is None:
            # initialize the embedding randomly
            self.embed = tf.get_variable('embed', [num_symbols, num_embed_units], tf.float32)
        else:
            # initialize the embedding by pre-trained word vectors
            self.embed = tf.get_variable('embed', dtype=tf.float32, initializer=embed)

        self.pattern_embed = tf.get_variable('pattern_embed', [num_classes, num_embed_units], tf.float32)
        
        self.encoder_input = tf.nn.embedding_lookup(self.embed, self.ori_sents_input)
        self.decoder_input = tf.nn.embedding_lookup(self.embed, self.rep_sents_input)

        if use_lstm:
            cell_fw = LSTMCell(num_units)
            cell_bw = LSTMCell(num_units)
            cell_dec = LSTMCell(2*num_units)
        else:
            cell_fw = GRUCell(num_units)
            cell_bw = GRUCell(num_units)
            cell_dec = GRUCell(2*num_units)

        # origin sentence encoder
        with variable_scope.variable_scope("encoder"):
            encoder_output, encoder_state = tf.nn.bidirectional_dynamic_rnn(cell_fw, cell_bw, self.encoder_input, 
                self.ori_sents_length, dtype=tf.float32)
            post_sum_state = tf.concat(encoder_state, 1)
            encoder_output = tf.concat(encoder_output, 2)

        # response sentence encoder
        with variable_scope.variable_scope("encoder", reuse = True):
            decoder_state, decoder_last_state = tf.nn.bidirectional_dynamic_rnn(cell_fw, cell_bw, self.decoder_input, 
                self.rep_sents_length, dtype=tf.float32)
            response_sum_state = tf.concat(decoder_last_state, 1)

        # recognition network
        with variable_scope.variable_scope("recog_net"):
            recog_input = tf.concat([post_sum_state, response_sum_state], 1)
            recog_mulogvar = tf.contrib.layers.fully_connected(recog_input, latent_size * 2, activation_fn=None, scope="muvar")
            recog_mu, recog_logvar = tf.split(recog_mulogvar, 2, axis=1)

        # prior network
        with variable_scope.variable_scope("prior_net"):
            prior_fc1 = tf.contrib.layers.fully_connected(post_sum_state, latent_size * 2, activation_fn=tf.tanh, scope="fc1")
            prior_mulogvar = tf.contrib.layers.fully_connected(prior_fc1, latent_size * 2, activation_fn=None, scope="muvar")
            prior_mu, prior_logvar = tf.split(prior_mulogvar, 2, axis=1)

        latent_sample = tf.cond(self.use_prior,
                                lambda: sample_gaussian(prior_mu, prior_logvar),
                                lambda: sample_gaussian(recog_mu, recog_logvar))
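        # use_prior is fed False during training (sample from the recognition
        # network) and True at generation time (sample from the prior)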


        # classifier
        with variable_scope.variable_scope("classifier"):
            classifier_input = latent_sample
            pattern_fc1 = tf.contrib.layers.fully_connected(classifier_input, latent_size, activation_fn=tf.tanh, scope="pattern_fc1")
            self.pattern_logits = tf.contrib.layers.fully_connected(pattern_fc1, num_classes, activation_fn=None, scope="pattern_logits")

        self.label_embedding = tf.matmul(self.labels, self.pattern_embed)

        output_fn, my_sequence_loss = output_projection_layer(2*num_units, num_symbols, latent_size, num_embed_units, self.content_mask, self.rhetoric_mask)

        attention_keys, attention_values, attention_score_fn, attention_construct_fn = my_attention_decoder_fn.prepare_attention(encoder_output, 'luong', 2*num_units)

        with variable_scope.variable_scope("dec_start"):
            temp_start = tf.concat([post_sum_state, self.label_embedding, latent_sample], 1)
            dec_fc1 = tf.contrib.layers.fully_connected(temp_start, 2*num_units, activation_fn=tf.tanh, scope="dec_start_fc1")
            dec_fc2 = tf.contrib.layers.fully_connected(dec_fc1, 2*num_units, activation_fn=None, scope="dec_start_fc2")

        if is_train:
            # rnn decoder
            topic_memory = self.update_memory(topic_memory, encoder_output)
            extra_info = tf.concat([self.label_embedding, latent_sample, topic_memory], 1)

            decoder_fn_train = my_attention_decoder_fn.attention_decoder_fn_train(dec_fc2, 
                attention_keys, attention_values, attention_score_fn, attention_construct_fn, extra_info)
            self.decoder_output, _, _ = my_seq2seq.dynamic_rnn_decoder(
                cell_dec, decoder_fn_train, self.decoder_input,
                self.rep_sents_length, scope="decoder")

            # calculate the loss
            self.decoder_loss = my_loss.sequence_loss(
                logits=self.decoder_output, targets=self.rep_sents_target,
                weights=self.decoder_mask, extra_information=latent_sample,
                label_embedding=self.label_embedding,
                softmax_loss_function=my_sequence_loss)
            temp_klloss = tf.reduce_mean(gaussian_kld(recog_mu, recog_logvar, prior_mu, prior_logvar))
            self.kl_weight = tf.minimum(tf.to_float(self.global_t)/full_kl_step, 1.0)
            self.klloss = self.kl_weight * temp_klloss
            temp_labels = tf.argmax(self.labels, 1)
            self.classifierloss = tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(logits=self.pattern_logits, labels=temp_labels))
            self.loss = self.decoder_loss + self.klloss + self.classifierloss  # need to anneal the kl_weight
            
            # building graph finished and get all parameters
            self.params = tf.trainable_variables()
        
            # initialize the training process
            self.learning_rate = tf.Variable(float(learning_rate), trainable=False, dtype=tf.float32)
            self.learning_rate_decay_op = self.learning_rate.assign(self.learning_rate * learning_rate_decay_factor)
            self.global_step = tf.Variable(0, trainable=False)
            
            # calculate the gradient of parameters
            opt = tf.train.MomentumOptimizer(self.learning_rate, 0.9)
            gradients = tf.gradients(self.loss, self.params)
            clipped_gradients, self.gradient_norm = tf.clip_by_global_norm(gradients, 
                    max_gradient_norm)
            self.update = opt.apply_gradients(zip(clipped_gradients, self.params), 
                    global_step=self.global_step)

        else:
            # rnn decoder
            topic_memory = self.update_memory(topic_memory, encoder_output)
            extra_info = tf.concat([self.label_embedding, latent_sample, topic_memory], 1)
            decoder_fn_inference = my_attention_decoder_fn.attention_decoder_fn_inference(output_fn, 
                dec_fc2, attention_keys, attention_values, attention_score_fn, 
                attention_construct_fn, self.embed, GO_ID, EOS_ID, max_length, num_symbols, extra_info)
            self.decoder_distribution, _, _ = my_seq2seq.dynamic_rnn_decoder(cell_dec, decoder_fn_inference, scope="decoder")
            self.generation_index = tf.argmax(tf.split(self.decoder_distribution,
                [2, num_symbols-2], 2)[1], 2) + 2 # for removing UNK
            self.generation = tf.nn.embedding_lookup(self.symbols, self.generation_index)
            
            self.params = tf.trainable_variables()

        self.saver = tf.train.Saver(tf.global_variables(), write_version=tf.train.SaverDef.V2, 
                max_to_keep=3, pad_step_number=True, keep_checkpoint_every_n_hours=1.0)
Code example #6
File: cvae.py Project: mtian95/NeuralDialog-CVAE
    def __init__(self, sess, config, api, log_dir, forward, scope=None):
        self.vocab = api.vocab
        self.rev_vocab = api.rev_vocab
        self.vocab_size = len(self.vocab)
        self.sess = sess
        self.scope = scope
        self.max_utt_len = config.max_utt_len
        self.go_id = self.rev_vocab["<s>"]
        self.eos_id = self.rev_vocab["</s>"]
        self.context_cell_size = config.cxt_cell_size
        self.sent_cell_size = config.sent_cell_size
        self.dec_cell_size = config.dec_cell_size
        self.num_topics = config.num_topics

        with tf.name_scope("io"):
            # all dialog context and known attributes
            self.input_contexts = tf.placeholder(dtype=tf.int32,
                                                 shape=(None, None,
                                                        self.max_utt_len),
                                                 name="dialog_context")
            self.floors = tf.placeholder(dtype=tf.float32,
                                         shape=(None, None),
                                         name="floor")  # TODO float
            self.floor_labels = tf.placeholder(dtype=tf.float32,
                                               shape=(None, 1),
                                               name="floor_labels")
            self.context_lens = tf.placeholder(dtype=tf.int32,
                                               shape=(None, ),
                                               name="context_lens")
            self.paragraph_topics = tf.placeholder(dtype=tf.float32,
                                                   shape=(None,
                                                          self.num_topics),
                                                   name="paragraph_topics")

            # target response given the dialog context
            self.output_tokens = tf.placeholder(dtype=tf.int32,
                                                shape=(None, None),
                                                name="output_token")
            self.output_lens = tf.placeholder(dtype=tf.int32,
                                              shape=(None, ),
                                              name="output_lens")
            self.output_das = tf.placeholder(dtype=tf.float32,
                                             shape=(None, self.num_topics),
                                             name="output_dialog_acts")

            # optimization related variables
            self.learning_rate = tf.Variable(float(config.init_lr),
                                             trainable=False,
                                             name="learning_rate")
            self.learning_rate_decay_op = self.learning_rate.assign(
                tf.multiply(self.learning_rate, config.lr_decay))
            self.global_t = tf.placeholder(dtype=tf.int32, name="global_t")
            self.use_prior = tf.placeholder(dtype=tf.bool, name="use_prior")

        max_dialog_len = array_ops.shape(self.input_contexts)[1]
        max_out_len = array_ops.shape(self.output_tokens)[1]
        batch_size = array_ops.shape(self.input_contexts)[0]

        with variable_scope.variable_scope("wordEmbedding"):
            self.embedding = tf.get_variable(
                "embedding", [self.vocab_size, config.embed_size],
                dtype=tf.float32)
            embedding_mask = tf.constant(
                [0 if i == 0 else 1 for i in range(self.vocab_size)],
                dtype=tf.float32,
                shape=[self.vocab_size, 1])
            embedding = self.embedding * embedding_mask

            # embed the input
            input_embedding = embedding_ops.embedding_lookup(
                embedding, tf.reshape(self.input_contexts, [-1]))
            # reshape: -1 lets the first dimension be inferred so the other
            # two dimensions match the data
            input_embedding = tf.reshape(
                input_embedding, [-1, self.max_utt_len, config.embed_size])
            # embed the output so you can feed it into the VAE
            output_embedding = embedding_ops.embedding_lookup(
                embedding, self.output_tokens)

            # encode each utterance according to the configured sent_type
            if config.sent_type == "bow":
                input_embedding, sent_size = get_bow(input_embedding)
                output_embedding, _ = get_bow(output_embedding)

            elif config.sent_type == "rnn":
                sent_cell = self.get_rnncell("gru", self.sent_cell_size,
                                             config.keep_prob, 1)
                input_embedding, sent_size = get_rnn_encode(input_embedding,
                                                            sent_cell,
                                                            scope="sent_rnn")
                output_embedding, _ = get_rnn_encode(output_embedding,
                                                     sent_cell,
                                                     self.output_lens,
                                                     scope="sent_rnn",
                                                     reuse=True)
            elif config.sent_type == "bi_rnn":
                fwd_sent_cell = self.get_rnncell("gru",
                                                 self.sent_cell_size,
                                                 keep_prob=1.0,
                                                 num_layer=1)
                bwd_sent_cell = self.get_rnncell("gru",
                                                 self.sent_cell_size,
                                                 keep_prob=1.0,
                                                 num_layer=1)
                input_embedding, sent_size = get_bi_rnn_encode(
                    input_embedding,
                    fwd_sent_cell,
                    bwd_sent_cell,
                    scope="sent_bi_rnn")
                output_embedding, _ = get_bi_rnn_encode(output_embedding,
                                                        fwd_sent_cell,
                                                        bwd_sent_cell,
                                                        self.output_lens,
                                                        scope="sent_bi_rnn",
                                                        reuse=True)
            else:
                raise ValueError(
                    "Unknown sent_type. Must be one of [bow, rnn, bi_rnn]")

            # reshape input into dialogs
            input_embedding = tf.reshape(input_embedding,
                                         [-1, max_dialog_len, sent_size])
            if config.keep_prob < 1.0:
                input_embedding = tf.nn.dropout(input_embedding,
                                                config.keep_prob)

            # reshape floors
            floor = tf.reshape(self.floors, [-1, max_dialog_len, 1])

            joint_embedding = tf.concat([input_embedding, floor], 2,
                                        "joint_embedding")

        with variable_scope.variable_scope("contextRNN"):
            enc_cell = self.get_rnncell(config.cell_type,
                                        self.context_cell_size,
                                        keep_prob=1.0,
                                        num_layer=config.num_layer)
            # with sequence_length supplied, enc_last_state is the state at
            # each dialog's true last step (padding is ignored)
            _, enc_last_state = tf.nn.dynamic_rnn(
                enc_cell,
                joint_embedding,
                dtype=tf.float32,
                sequence_length=self.context_lens)

            if config.num_layer > 1:
                if config.cell_type == 'lstm':
                    enc_last_state = [temp.h for temp in enc_last_state]

                enc_last_state = tf.concat(enc_last_state, 1)
            else:
                if config.cell_type == 'lstm':
                    enc_last_state = enc_last_state.h

        # combine with other attributes
        if config.use_hcf:
            # TODO is this reshape ok?
            attribute_embedding = tf.reshape(
                self.output_das, [-1, self.num_topics])  # da_embedding
            attribute_fc1 = layers.fully_connected(attribute_embedding,
                                                   30,
                                                   activation_fn=tf.tanh,
                                                   scope="attribute_fc1")

        # the condition combines the paragraph topic with the context-RNN
        # summary of all previous bi-RNN utterance encodings
        cond_list = [self.paragraph_topics, enc_last_state]
        cond_embedding = tf.concat(cond_list, 1)  #float32

        with variable_scope.variable_scope("recognitionNetwork"):
            if config.use_hcf:
                recog_input = tf.concat(
                    [cond_embedding, output_embedding, attribute_fc1], 1)
            else:
                recog_input = tf.concat([cond_embedding, output_embedding], 1)
            self.recog_mulogvar = recog_mulogvar = layers.fully_connected(
                recog_input,
                config.latent_size * 2,
                activation_fn=None,
                scope="muvar")
            # mu and logvar are both vectors of size latent_size
            recog_mu, recog_logvar = tf.split(recog_mulogvar, 2, axis=1)

        with variable_scope.variable_scope("priorNetwork"):
            # P(XYZ)=P(Z|X)P(X)P(Y|X,Z)
            prior_fc1 = layers.fully_connected(cond_embedding,
                                               np.maximum(
                                                   config.latent_size * 2,
                                                   100),
                                               activation_fn=tf.tanh,
                                               scope="fc1")
            prior_mulogvar = layers.fully_connected(prior_fc1,
                                                    config.latent_size * 2,
                                                    activation_fn=None,
                                                    scope="muvar")
            prior_mu, prior_logvar = tf.split(prior_mulogvar, 2, axis=1)

            latent_sample = tf.cond(
                self.use_prior,
                lambda: sample_gaussian(prior_mu, prior_logvar),
                lambda: sample_gaussian(recog_mu, recog_logvar))

        with variable_scope.variable_scope("generationNetwork"):
            gen_inputs = tf.concat([cond_embedding, latent_sample],
                                   1)  #float32

            # BOW loss
            bow_fc1 = layers.fully_connected(gen_inputs,
                                             400,
                                             activation_fn=tf.tanh,
                                             scope="bow_fc1")
            if config.keep_prob < 1.0:
                bow_fc1 = tf.nn.dropout(bow_fc1, config.keep_prob)
            self.bow_logits = layers.fully_connected(bow_fc1,
                                                     self.vocab_size,
                                                     activation_fn=None,
                                                     scope="bow_project")

            # Predicting Y (topic)
            if config.use_hcf:
                meta_fc1 = layers.fully_connected(gen_inputs,
                                                  400,
                                                  activation_fn=tf.tanh,
                                                  scope="meta_fc1")
                if config.keep_prob < 1.0:
                    meta_fc1 = tf.nn.dropout(meta_fc1, config.keep_prob)
                self.da_logits = layers.fully_connected(
                    meta_fc1, self.num_topics, scope="da_project")  # float32

                da_prob = tf.nn.softmax(self.da_logits)
                pred_attribute_embedding = da_prob  # TODO change the name of this to predicted sentence topic
                # pred_attribute_embedding = tf.matmul(da_prob, d_embedding)

                if forward:
                    selected_attribute_embedding = pred_attribute_embedding
                else:
                    selected_attribute_embedding = attribute_embedding
                dec_inputs = tf.concat(
                    [gen_inputs, selected_attribute_embedding], 1)

            # if use_hcf not on, the model won't predict the Y
            else:
                self.da_logits = tf.zeros((batch_size, self.num_topics))
                dec_inputs = gen_inputs
                selected_attribute_embedding = None

            # Predicting whether or not end of paragraph
            self.paragraph_end_logits = layers.fully_connected(
                gen_inputs,
                1,
                activation_fn=tf.tanh,
                scope="paragraph_end_fc1")  # float32

            # Decoder
            if config.num_layer > 1:
                dec_init_state = []
                for i in range(config.num_layer):
                    temp_init = layers.fully_connected(
                        dec_inputs,
                        self.dec_cell_size,
                        activation_fn=None,
                        scope="init_state-%d" % i)
                    if config.cell_type == 'lstm':
                        # LSTM state is an LSTMStateTuple; use the same
                        # projection for both c and h
                        temp_init = rnn_cell.LSTMStateTuple(
                            temp_init, temp_init)

                    dec_init_state.append(temp_init)

                dec_init_state = tuple(dec_init_state)
            else:
                dec_init_state = layers.fully_connected(dec_inputs,
                                                        self.dec_cell_size,
                                                        activation_fn=None,
                                                        scope="init_state")
                if config.cell_type == 'lstm':
                    dec_init_state = rnn_cell.LSTMStateTuple(
                        dec_init_state, dec_init_state)

        with variable_scope.variable_scope("decoder"):
            dec_cell = self.get_rnncell(config.cell_type, self.dec_cell_size,
                                        config.keep_prob, config.num_layer)
            # project decoder outputs to vocab-size logits; the softmax is
            # applied later inside the cross-entropy loss
            dec_cell = OutputProjectionWrapper(dec_cell, self.vocab_size)

            if forward:
                loop_func = decoder_fn_lib.context_decoder_fn_inference(
                    None,
                    dec_init_state,
                    embedding,
                    start_of_sequence_id=self.go_id,
                    end_of_sequence_id=self.eos_id,
                    maximum_length=self.max_utt_len,
                    num_decoder_symbols=self.vocab_size,
                    context_vector=selected_attribute_embedding)
                dec_input_embedding = None
                dec_seq_lens = None
            else:
                loop_func = decoder_fn_lib.context_decoder_fn_train(
                    dec_init_state, selected_attribute_embedding)
                dec_input_embedding = embedding_ops.embedding_lookup(
                    embedding, self.output_tokens)
                dec_input_embedding = dec_input_embedding[:, 0:-1, :]
                dec_seq_lens = self.output_lens - 1

                if config.keep_prob < 1.0:
                    dec_input_embedding = tf.nn.dropout(
                        dec_input_embedding, config.keep_prob)

                # apply word dropping. Set dropped word to 0
                if config.dec_keep_prob < 1.0:
                    # build the keep/drop mask for word dropout
                    keep_mask = tf.less_equal(
                        tf.random_uniform((batch_size, max_out_len - 1),
                                          minval=0.0,
                                          maxval=1.0), config.dec_keep_prob)
                    keep_mask = tf.expand_dims(tf.to_float(keep_mask), 2)
                    dec_input_embedding = dec_input_embedding * keep_mask
                    dec_input_embedding = tf.reshape(
                        dec_input_embedding,
                        [-1, max_out_len - 1, config.embed_size])

            dec_outs, _, final_context_state = dynamic_rnn_decoder(
                dec_cell,
                loop_func,
                inputs=dec_input_embedding,
                sequence_length=dec_seq_lens,
                name='output_node')

            if final_context_state is not None:
                final_context_state = final_context_state[
                    :, 0:array_ops.shape(dec_outs)[1]]
                mask = tf.to_int32(tf.sign(tf.reduce_max(dec_outs, axis=2)))
                self.dec_out_words = tf.multiply(
                    tf.reverse(final_context_state, axis=[1]), mask)
            else:
                self.dec_out_words = tf.argmax(dec_outs, 2)

        if not forward:
            with variable_scope.variable_scope("loss"):

                labels = self.output_tokens[:, 1:]  # correct word tokens
                label_mask = tf.to_float(tf.sign(labels))

                # Loss between words
                rc_loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
                    logits=dec_outs, labels=labels)
                rc_loss = tf.reduce_sum(rc_loss * label_mask,
                                        reduction_indices=1)
                self.avg_rc_loss = tf.reduce_mean(rc_loss)
                # used only for perplexity calculation, not for optimization
                self.rc_ppl = tf.exp(
                    tf.reduce_sum(rc_loss) / tf.reduce_sum(label_mask))

                # BOW loss
                tile_bow_logits = tf.tile(tf.expand_dims(self.bow_logits, 1),
                                          [1, max_out_len - 1, 1])
                bow_loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
                    logits=tile_bow_logits, labels=labels) * label_mask
                bow_loss = tf.reduce_sum(bow_loss, reduction_indices=1)
                self.avg_bow_loss = tf.reduce_mean(bow_loss)

                # Predict 0/1 (1 = last sentence in paragraph)
                end_loss = tf.nn.softmax_cross_entropy_with_logits(
                    labels=self.floor_labels, logits=self.paragraph_end_logits)
                self.avg_end_loss = tf.reduce_mean(end_loss)

                # Topic prediction loss
                if config.use_hcf:
                    div_prob = tf.divide(self.da_logits, self.output_das)
                    self.avg_da_loss = tf.reduce_mean(
                        -tf.nn.softmax_cross_entropy_with_logits(
                            logits=self.da_logits, labels=div_prob))

                else:
                    self.avg_da_loss = 0.0

                kld = gaussian_kld(recog_mu, recog_logvar, prior_mu,
                                   prior_logvar)
                self.avg_kld = tf.reduce_mean(kld)
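                # gaussian_kld is the closed-form KL between the recognition
                # and prior Gaussians; it is not shown in this snippet, but a
                # minimal sketch would be:
                #   def gaussian_kld(r_mu, r_logvar, p_mu, p_logvar):
                #       return -0.5 * tf.reduce_sum(
                #           1 + (r_logvar - p_logvar)
                #           - tf.square(r_mu - p_mu) / tf.exp(p_logvar)
                #           - tf.exp(r_logvar) / tf.exp(p_logvar), axis=1)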
                if log_dir is not None:
                    kl_weights = tf.minimum(
                        tf.to_float(self.global_t) / config.full_kl_step, 1.0)
                else:
                    kl_weights = tf.constant(1.0)

                self.kl_w = kl_weights
                self.elbo = self.avg_rc_loss + kl_weights * self.avg_kld
                aug_elbo = self.avg_bow_loss + self.avg_da_loss + self.elbo + self.avg_end_loss

                tf.summary.scalar("da_loss", self.avg_da_loss)
                tf.summary.scalar("rc_loss", self.avg_rc_loss)
                tf.summary.scalar("elbo", self.elbo)
                tf.summary.scalar("kld", self.avg_kld)
                tf.summary.scalar("bow_loss", self.avg_bow_loss)
                tf.summary.scalar("paragraph_end_loss", self.avg_end_loss)

                self.summary_op = tf.summary.merge_all()

                self.log_p_z = norm_log_liklihood(latent_sample, prior_mu,
                                                  prior_logvar)
                self.log_q_z_xy = norm_log_liklihood(latent_sample, recog_mu,
                                                     recog_logvar)
                self.est_marginal = tf.reduce_mean(rc_loss + bow_loss -
                                                   self.log_p_z +
                                                   self.log_q_z_xy)
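                # single-sample importance estimate of the (augmented)
                # negative marginal log-likelihood: reconstruction + BOW
                # loss, corrected by log q(z|x,y) - log p(z|x)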

            self.optimize(sess, config, aug_elbo, log_dir)

        self.saver = tf.train.Saver(tf.global_variables(),
                                    write_version=tf.train.SaverDef.V2)
Code example #7
File: model.py Project: songhaoyu/PerCVAE
    def __init__(self,
                 sess,
                 config,
                 api,
                 log_dir,
                 forward,
                 scope=None,
                 name=None):
        self.vocab = api.vocab
        self.rev_vocab = api.rev_vocab
        self.vocab_size = len(self.vocab)
        self.idf = api.index2idf
        self.gen_vocab_size = api.gen_vocab_size
        self.topic_vocab = api.topic_vocab
        self.topic_vocab_size = len(self.topic_vocab)
        self.da_vocab = api.dialog_act_vocab
        self.da_vocab_size = len(self.da_vocab)
        self.sess = sess
        self.scope = scope
        self.max_utt_len = config.max_utt_len
        self.max_per_len = config.max_per_len
        self.max_per_line = config.max_per_line
        self.max_per_words = config.max_per_words
        self.go_id = self.rev_vocab["<s>"]
        self.eos_id = self.rev_vocab["</s>"]
        self.context_cell_size = config.cxt_cell_size
        self.sent_cell_size = config.sent_cell_size
        self.memory_cell_size = config.memory_cell_size
        self.dec_cell_size = config.dec_cell_size
        self.hops = config.hops
        self.batch_size = config.batch_size
        self.test_samples = config.test_samples
        self.balance_factor = config.balance_factor

        with tf.name_scope("io"):
            self.first_dimension_size = self.batch_size
            self.input_contexts = tf.placeholder(
                dtype=tf.int32,
                shape=(self.first_dimension_size, None, self.max_utt_len),
                name="dialog_context")
            self.floors = tf.placeholder(dtype=tf.int32,
                                         shape=(self.first_dimension_size,
                                                None),
                                         name="floor")
            self.context_lens = tf.placeholder(
                dtype=tf.int32,
                shape=(self.first_dimension_size, ),
                name="context_lens")
            self.topics = tf.placeholder(dtype=tf.int32,
                                         shape=(self.first_dimension_size, ),
                                         name="topics")
            self.personas = tf.placeholder(dtype=tf.int32,
                                           shape=(self.first_dimension_size,
                                                  self.max_per_line,
                                                  self.max_per_len),
                                           name="personas")
            self.persona_words = tf.placeholder(
                dtype=tf.int32,
                shape=(self.first_dimension_size, self.max_per_line,
                       self.max_per_len),
                name="persona_words")
            self.persona_position = tf.placeholder(
                dtype=tf.int32,
                shape=(self.first_dimension_size, None),
                name="persona_position")
            self.selected_persona = tf.placeholder(
                dtype=tf.int32,
                shape=(self.first_dimension_size, 1),
                name="selected_persona")

            self.query = tf.placeholder(dtype=tf.int32,
                                        shape=(self.first_dimension_size,
                                               self.max_utt_len),
                                        name="query")

            # target response given the dialog context
            self.output_tokens = tf.placeholder(
                dtype=tf.int32,
                shape=(self.first_dimension_size, None),
                name="output_token")
            self.output_lens = tf.placeholder(
                dtype=tf.int32,
                shape=(self.first_dimension_size, ),
                name="output_lens")

            # optimization related variables
            self.learning_rate = tf.Variable(float(config.init_lr),
                                             trainable=False,
                                             name="learning_rate")
            self.learning_rate_decay_op = self.learning_rate.assign(
                tf.multiply(self.learning_rate, config.lr_decay))
            self.global_t = tf.placeholder(dtype=tf.int32, name="global_t")
            self.use_prior = tf.placeholder(dtype=tf.bool, name="use_prior")

        max_context_lines = array_ops.shape(self.input_contexts)[1]
        max_out_len = array_ops.shape(self.output_tokens)[1]
        batch_size = array_ops.shape(self.input_contexts)[0]

        with variable_scope.variable_scope("wordEmbedding"):
            self.embedding = tf.get_variable(
                "embedding", [self.vocab_size, config.embed_size],
                dtype=tf.float32)
            embedding_mask = tf.constant(
                [0 if i == 0 else 1 for i in range(self.vocab_size)],
                dtype=tf.float32,
                shape=[self.vocab_size, 1])
            embedding = self.embedding * embedding_mask
            input_embedding = embedding_ops.embedding_lookup(
                embedding, tf.reshape(self.input_contexts, [-1]))
            input_embedding = tf.reshape(
                input_embedding, [-1, self.max_utt_len, config.embed_size])
            output_embedding = embedding_ops.embedding_lookup(
                embedding, self.output_tokens)
            persona_input_embedding = embedding_ops.embedding_lookup(
                embedding, tf.reshape(self.personas, [-1]))
            persona_input_embedding = tf.reshape(
                persona_input_embedding,
                [-1, self.max_per_len, config.embed_size])
            if config.sent_type == "bow":
                input_embedding, sent_size = get_bow(input_embedding)
                output_embedding, _ = get_bow(output_embedding)
                persona_input_embedding, _ = get_bow(persona_input_embedding)

            elif config.sent_type == "rnn":
                sent_cell = self.get_rnncell("gru", self.sent_cell_size,
                                             config.keep_prob, 1)
                _, input_embedding, sent_size = get_rnn_encode(
                    input_embedding, sent_cell, scope="sent_rnn")
                _, output_embedding, _ = get_rnn_encode(output_embedding,
                                                        sent_cell,
                                                        self.output_lens,
                                                        scope="sent_rnn",
                                                        reuse=True)
                _, persona_input_embedding, _ = get_rnn_encode(
                    persona_input_embedding,
                    sent_cell,
                    scope="sent_rnn",
                    reuse=True)
            elif config.sent_type == "bi_rnn":
                fwd_sent_cell = self.get_rnncell("gru",
                                                 self.sent_cell_size,
                                                 keep_prob=1.0,
                                                 num_layer=1)
                bwd_sent_cell = self.get_rnncell("gru",
                                                 self.sent_cell_size,
                                                 keep_prob=1.0,
                                                 num_layer=1)
                input_step_embedding, input_embedding, sent_size = get_bi_rnn_encode(
                    input_embedding,
                    fwd_sent_cell,
                    bwd_sent_cell,
                    scope="sent_bi_rnn")
                _, output_embedding, _ = get_bi_rnn_encode(output_embedding,
                                                           fwd_sent_cell,
                                                           bwd_sent_cell,
                                                           self.output_lens,
                                                           scope="sent_bi_rnn",
                                                           reuse=True)
                _, persona_input_embedding, _ = get_bi_rnn_encode(
                    persona_input_embedding,
                    fwd_sent_cell,
                    bwd_sent_cell,
                    scope="sent_bi_rnn",
                    reuse=True)
            else:
                raise ValueError(
                    "Unknown sent_type. Must be one of [bow, rnn, bi_rnn]")
            # reshape input into dialogs
            input_embedding = tf.reshape(input_embedding,
                                         [-1, max_context_lines, sent_size])
            self.input_step_embedding = input_step_embedding
            self.encoder_state_size = sent_size
            if config.keep_prob < 1.0:
                input_embedding = tf.nn.dropout(input_embedding,
                                                config.keep_prob)

        with variable_scope.variable_scope("personaMemory"):
            embedding_mask = tf.constant(
                [0 if i == 0 else 1 for i in range(self.vocab_size)],
                dtype=tf.float32,
                shape=[self.vocab_size, 1])
            A = tf.get_variable("persona_embedding_A",
                                [self.vocab_size, self.memory_cell_size],
                                dtype=tf.float32)
            A = A * embedding_mask
            C = []
            for hopn in range(self.hops):
                C.append(
                    tf.get_variable("persona_embedding_C_hop_{}".format(hopn),
                                    [self.vocab_size, self.memory_cell_size],
                                    dtype=tf.float32) * embedding_mask)

            q_emb = tf.nn.embedding_lookup(A, self.query)
            u_0 = tf.reduce_sum(q_emb, 1)
            u = [u_0]
            for hopn in range(self.hops):
                if hopn == 0:
                    m_emb_A = tf.nn.embedding_lookup(A, self.personas)
                    m_A = tf.reshape(m_emb_A, [
                        -1, self.max_per_len * self.max_per_line,
                        self.memory_cell_size
                    ])
                else:
                    with tf.variable_scope('persona_hop_{}'.format(hopn)):
                        m_emb_A = tf.nn.embedding_lookup(
                            C[hopn - 1], self.personas)
                        m_A = tf.reshape(m_emb_A, [
                            -1, self.max_per_len * self.max_per_line,
                            self.memory_cell_size
                        ])
                u_temp = tf.transpose(tf.expand_dims(u[-1], -1), [0, 2, 1])
                dotted = tf.reduce_sum(m_A * u_temp, 2)
                probs = tf.nn.softmax(dotted)
                probs_temp = tf.transpose(tf.expand_dims(probs, -1), [0, 2, 1])
                with tf.variable_scope('persona_hop_{}'.format(hopn)):
                    m_emb_C = tf.nn.embedding_lookup(
                        C[hopn],
                        tf.reshape(self.personas,
                                   [-1, self.max_per_len * self.max_per_line]))
                    m_emb_C = tf.expand_dims(m_emb_C, -2)
                    m_C = tf.reduce_sum(m_emb_C, axis=2)
                c_temp = tf.transpose(m_C, [0, 2, 1])
                o_k = tf.reduce_sum(c_temp * probs_temp, axis=2)
                u_k = u[-1] + o_k
                u.append(u_k)
            persona_memory = u[-1]
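            # end-to-end memory network over the persona sentences: the
            # query embedding u is refined over `hops` rounds of attention
            # (probs) against the persona memory, and the final hop's state
            # u[-1] summarizes the persona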

        with variable_scope.variable_scope("contextEmbedding"):
            context_layers = 2
            enc_cell = self.get_rnncell(config.cell_type,
                                        self.context_cell_size,
                                        keep_prob=1.0,
                                        num_layer=context_layers)
            _, enc_last_state = tf.nn.dynamic_rnn(
                enc_cell,
                input_embedding,
                dtype=tf.float32,
                sequence_length=self.context_lens)

            if context_layers > 1:
                if config.cell_type == 'lstm':
                    enc_last_state = [temp.h for temp in enc_last_state]

                enc_last_state = tf.concat(enc_last_state, 1)
            else:
                if config.cell_type == 'lstm':
                    enc_last_state = enc_last_state.h

        cond_embedding = tf.concat([persona_memory, enc_last_state], 1)

        with variable_scope.variable_scope("recognitionNetwork"):
            recog_input = tf.concat(
                [cond_embedding, output_embedding, persona_memory], 1)
            self.recog_mulogvar = recog_mulogvar = layers.fully_connected(
                recog_input,
                config.latent_size * 2,
                activation_fn=None,
                scope="muvar")
            recog_mu, recog_logvar = tf.split(recog_mulogvar, 2, axis=1)

        with variable_scope.variable_scope("priorNetwork"):
            prior_fc1 = layers.fully_connected(cond_embedding,
                                               np.maximum(
                                                   config.latent_size * 2,
                                                   100),
                                               activation_fn=tf.tanh,
                                               scope="fc1")
            prior_mulogvar = layers.fully_connected(prior_fc1,
                                                    config.latent_size * 2,
                                                    activation_fn=None,
                                                    scope="muvar")
            prior_mu, prior_logvar = tf.split(prior_mulogvar, 2, axis=1)
            latent_sample = tf.cond(
                self.use_prior,
                lambda: sample_gaussian(prior_mu, prior_logvar),
                lambda: sample_gaussian(recog_mu, recog_logvar))

        with variable_scope.variable_scope("personaSelecting"):
            condition = tf.concat([persona_memory, latent_sample], 1)

            self.persona_dist = tf.nn.log_softmax(
                layers.fully_connected(condition,
                                       self.max_per_line,
                                       activation_fn=tf.tanh,
                                       scope="persona_dist"))
            select_temp = tf.expand_dims(
                tf.argmax(self.persona_dist, 1, output_type=tf.int32), 1)
            index_temp = tf.expand_dims(
                tf.range(0, self.first_dimension_size, dtype=tf.int32), 1)
            persona_select = tf.concat([index_temp, select_temp], 1)
            selected_words_ordered = tf.reshape(
                tf.gather_nd(self.persona_words, persona_select),
                [self.max_per_len * self.first_dimension_size])
            self.selected_words = tf.gather_nd(self.persona_words,
                                               persona_select)
            label = tf.reshape(
                selected_words_ordered,
                [self.max_per_len * self.first_dimension_size, 1])
            index = tf.reshape(
                tf.range(self.first_dimension_size, dtype=tf.int32),
                [self.first_dimension_size, 1])
            index = tf.reshape(
                tf.tile(index, [1, self.max_per_len]),
                [self.max_per_len * self.first_dimension_size, 1])

            concated = tf.concat([index, label], 1)
            true_labels = tf.where(selected_words_ordered > 0)
            concated = tf.gather_nd(concated, true_labels)
            self.persona_word_mask = tf.sparse_to_dense(
                concated, [self.first_dimension_size, self.vocab_size],
                config.perw_weight, 0.0)
            self.other_word_mask = tf.sparse_to_dense(
                concated, [self.first_dimension_size, self.vocab_size], 0.0,
                config.othw_weight)
            self.persona_word_mask = self.persona_word_mask * self.idf
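            # The two sparse_to_dense calls build complementary vocab masks
            # per batch row: persona_word_mask places perw_weight on the
            # words of the selected persona line (further scaled by idf),
            # while other_word_mask places othw_weight on every remaining
            # word. Both are handed to the inference decoder below to bias
            # decoding toward persona words.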

        with variable_scope.variable_scope("generationNetwork"):
            gen_inputs = tf.concat([cond_embedding, latent_sample], 1)

            # BOW loss
            bow_fc1 = layers.fully_connected(gen_inputs,
                                             400,
                                             activation_fn=tf.tanh,
                                             scope="bow_fc1")
            if config.keep_prob < 1.0:
                bow_fc1 = tf.nn.dropout(bow_fc1, config.keep_prob)
            self.bow_logits = layers.fully_connected(bow_fc1,
                                                     self.vocab_size,
                                                     activation_fn=None,
                                                     scope="bow_project")

            # Y loss
            dec_inputs = gen_inputs
            selected_attribute_embedding = None
            self.da_logits = tf.zeros((batch_size, self.da_vocab_size))

            # Decoder
            if config.num_layer > 1:
                dec_init_state = []
                for i in range(config.num_layer):
                    temp_init = layers.fully_connected(dec_inputs,
                                                       self.dec_cell_size,
                                                       activation_fn=None,
                                                       scope="init_state-%d" %
                                                       i)
                    if config.cell_type == 'lstm':
                        temp_init = rnn_cell.LSTMStateTuple(
                            temp_init, temp_init)

                    dec_init_state.append(temp_init)

                dec_init_state = tuple(dec_init_state)
            else:
                dec_init_state = layers.fully_connected(dec_inputs,
                                                        self.dec_cell_size,
                                                        activation_fn=None,
                                                        scope="init_state")
                if config.cell_type == 'lstm':
                    dec_init_state = rnn_cell.LSTMStateTuple(
                        dec_init_state, dec_init_state)

        with variable_scope.variable_scope("decoder"):
            dec_cell = self.get_rnncell(config.cell_type, self.dec_cell_size,
                                        config.keep_prob, config.num_layer)
            dec_cell = OutputProjectionWrapper(dec_cell, self.vocab_size)

            pos_cell = self.get_rnncell(config.cell_type, self.dec_cell_size,
                                        config.keep_prob, config.num_layer)
            pos_cell = OutputProjectionWrapper(pos_cell, self.vocab_size)

            with variable_scope.variable_scope("position"):
                self.pos_w_1 = tf.get_variable("pos_w_1",
                                               [self.dec_cell_size, 2],
                                               dtype=tf.float32)
                self.pos_b_1 = tf.get_variable("pos_b_1", [2],
                                               dtype=tf.float32)

            def position_function(states, logp=False):
                states = tf.reshape(states, [-1, self.dec_cell_size])
                if logp:
                    return tf.reshape(
                        tf.nn.log_softmax(
                            tf.matmul(states, self.pos_w_1) + self.pos_b_1),
                        [self.first_dimension_size, -1, 2])
                return tf.reshape(
                    tf.nn.softmax(
                        tf.matmul(states, self.pos_w_1) + self.pos_b_1),
                    [self.first_dimension_size, -1, 2])
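            # position_function projects decoder hidden states to a 2-way
            # (persona-word vs. ordinary-word) distribution per time step;
            # with logp=True it returns log-probabilities, which feed the
            # persona-position loss further down.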

            if forward:
                loop_func = self.context_decoder_fn_inference(
                    position_function,
                    self.persona_word_mask,
                    self.other_word_mask,
                    None,
                    dec_init_state,
                    embedding,
                    start_of_sequence_id=self.go_id,
                    end_of_sequence_id=self.eos_id,
                    maximum_length=self.max_utt_len,
                    num_decoder_symbols=self.vocab_size,
                    context_vector=selected_attribute_embedding,
                )
                dec_input_embedding = None
                dec_seq_lens = None
            else:
                loop_func = self.context_decoder_fn_train(
                    dec_init_state, selected_attribute_embedding)
                dec_input_embedding = embedding_ops.embedding_lookup(
                    embedding, self.output_tokens)
                dec_input_embedding = dec_input_embedding[:, 0:-1, :]
                dec_seq_lens = self.output_lens - 1
                if config.keep_prob < 1.0:
                    dec_input_embedding = tf.nn.dropout(
                        dec_input_embedding, config.keep_prob)
                if config.dec_keep_prob < 1.0:
                    keep_mask = tf.less_equal(
                        tf.random_uniform((batch_size, max_out_len - 1),
                                          minval=0.0,
                                          maxval=1.0), config.dec_keep_prob)
                    keep_mask = tf.expand_dims(tf.to_float(keep_mask), 2)
                    dec_input_embedding = dec_input_embedding * keep_mask
                    dec_input_embedding = tf.reshape(
                        dec_input_embedding,
                        [-1, max_out_len - 1, config.embed_size])
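                # Word dropout: whole input word embeddings are zeroed with
                # probability 1 - dec_keep_prob, weakening teacher forcing
                # so the decoder has to rely on the latent variable.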

            with variable_scope.variable_scope("dec_state"):
                dec_outs, _, final_context_state, rnn_states = dynamic_rnn_decoder(
                    dec_cell,
                    loop_func,
                    inputs=dec_input_embedding,
                    sequence_length=dec_seq_lens)
            with variable_scope.variable_scope("pos_state"):
                _, _, _, pos_states = dynamic_rnn_decoder(
                    pos_cell,
                    loop_func,
                    inputs=dec_input_embedding,
                    sequence_length=dec_seq_lens)

            self.position_dist = position_function(pos_states, logp=True)

            if final_context_state is not None:
                final_context_state = final_context_state[
                    :, 0:array_ops.shape(dec_outs)[1]]
                mask = tf.to_int32(tf.sign(tf.reduce_max(dec_outs, axis=2)))
                self.dec_out_words = tf.multiply(
                    tf.reverse(final_context_state, axis=[1]), mask)
            else:
                self.dec_out_words = tf.argmax(dec_outs, 2)
        if not forward:
            with variable_scope.variable_scope("loss"):
                labels = self.output_tokens[:, 1:]
                label_mask = tf.to_float(tf.sign(labels))
                rc_loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
                    logits=dec_outs, labels=labels)
                rc_loss = tf.reduce_sum(rc_loss * label_mask,
                                        reduction_indices=1)
                self.avg_rc_loss = tf.reduce_mean(rc_loss)
                self.rc_ppl = tf.exp(
                    tf.reduce_sum(rc_loss) / tf.reduce_sum(label_mask))
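                # Per-token perplexity: exp(total masked NLL / number of
                # real, non-padding target tokens).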
                per_select_loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
                    logits=tf.reshape(self.persona_dist,
                                      [self.first_dimension_size, 1, -1]),
                    labels=self.selected_persona)
                per_select_loss = tf.reduce_sum(per_select_loss,
                                                reduction_indices=1)
                self.avg_per_select_loss = tf.reduce_mean(per_select_loss)
                position_labels = self.persona_position[:, 1:]
                per_pos_loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
                    logits=self.position_dist, labels=position_labels)
                per_pos_loss = tf.reduce_sum(per_pos_loss, reduction_indices=1)
                self.avg_per_pos_loss = tf.reduce_mean(per_pos_loss)

                tile_bow_logits = tf.tile(tf.expand_dims(self.bow_logits, 1),
                                          [1, max_out_len - 1, 1])
                bow_loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
                    logits=tile_bow_logits, labels=labels) * label_mask
                bow_loss = tf.reduce_sum(bow_loss, reduction_indices=1)
                self.avg_bow_loss = tf.reduce_mean(bow_loss)
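                # BOW loss: the same bag-of-words logits are tiled across
                # all max_out_len - 1 target positions, so every real token
                # must be predictable from (condition, z) alone. This is the
                # usual auxiliary loss for fighting KL collapse.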
                kld = gaussian_kld(recog_mu, recog_logvar, prior_mu,
                                   prior_logvar)
                self.avg_kld = tf.reduce_mean(kld)

                if log_dir is not None:
                    kl_weights = tf.minimum(
                        tf.to_float(self.global_t) / config.full_kl_step, 1.0)
                else:
                    kl_weights = tf.constant(1.0)

                self.kl_w = kl_weights
                self.elbo = self.avg_rc_loss + kl_weights * self.avg_kld
                aug_elbo = self.elbo + self.avg_bow_loss + 0.1 * self.avg_per_select_loss + 0.05 * self.avg_per_pos_loss
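                # Training objective: the KL-annealed ELBO augmented with
                # the BOW loss plus the persona-selection (0.1 *) and
                # persona-position (0.05 *) auxiliary terms.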

                tf.summary.scalar("rc_loss", self.avg_rc_loss)
                tf.summary.scalar("elbo", self.elbo)
                tf.summary.scalar("kld", self.avg_kld)
                tf.summary.scalar("per_pos_loss", self.avg_per_pos_loss)

                self.summary_op = tf.summary.merge_all()

                self.log_p_z = norm_log_liklihood(latent_sample, prior_mu,
                                                  prior_logvar)
                self.log_q_z_xy = norm_log_liklihood(latent_sample, recog_mu,
                                                     recog_logvar)
                self.est_marginal = tf.reduce_mean(rc_loss + bow_loss -
                                                   self.log_p_z +
                                                   self.log_q_z_xy)
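                # A single-sample importance-style estimate of the marginal:
                # the reconstruction terms corrected by
                # log q(z|x,y) - log p(z|x) for the drawn sample z.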

            self.optimize(sess, config, aug_elbo, log_dir)

        self.saver = tf.train.Saver(tf.global_variables(),
                                    write_version=tf.train.SaverDef.V2)
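Both examples rely on Gaussian helpers (sample_gaussian, gaussian_kld, norm_log_liklihood) that are not shown in this listing. Below is a minimal sketch of what they presumably compute, assuming the standard CVAE formulation; treat it as illustration rather than the repository's exact code.

import numpy as np
import tensorflow as tf

def sample_gaussian(mu, logvar):
    # Reparameterization trick: z = mu + sigma * eps, eps ~ N(0, I).
    epsilon = tf.random_normal(tf.shape(logvar))
    return mu + tf.exp(0.5 * logvar) * epsilon

def gaussian_kld(recog_mu, recog_logvar, prior_mu, prior_logvar):
    # KL( N(recog_mu, recog_var) || N(prior_mu, prior_var) ),
    # summed over the latent dimensions.
    return -0.5 * tf.reduce_sum(
        1 + (recog_logvar - prior_logvar)
        - tf.pow(prior_mu - recog_mu, 2) / tf.exp(prior_logvar)
        - tf.exp(recog_logvar) / tf.exp(prior_logvar),
        axis=1)

def norm_log_liklihood(x, mu, logvar):
    # log N(x; mu, diag(exp(logvar))), summed over the latent dimensions.
    return -0.5 * tf.reduce_sum(
        np.log(2 * np.pi) + logvar + tf.pow(x - mu, 2) / tf.exp(logvar),
        axis=1)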
Code example #8
    def __init__(self, sess, config, api, log_dir, forward, scope=None):
        # self.self_label = tf.placeholder(dtype=tf.bool,shape=(None), name="self_label")
        self.self_label = False
        self.vocab = api.vocab
        self.rev_vocab = api.rev_vocab
        self.vocab_size = len(self.vocab)
        self.sess = sess
        self.scope = scope
        self.max_utt_len = config.max_utt_len
        self.go_id = self.rev_vocab["<s>"]
        self.eos_id = self.rev_vocab["</s>"]
        self.context_cell_size = config.cxt_cell_size
        self.sent_cell_size = config.sent_cell_size
        self.dec_cell_size = config.dec_cell_size

        with tf.name_scope("io"):
            self.input_contexts = tf.placeholder(dtype=tf.int32,
                                                 shape=(None, None),
                                                 name="dialog_context")
            self.context_lens = tf.placeholder(dtype=tf.int32,
                                               shape=(None, ),
                                               name="context_lens")

            self.output_tokens = tf.placeholder(dtype=tf.int32,
                                                shape=(None, None),
                                                name="output_token")
            self.output_lens = tf.placeholder(dtype=tf.int32,
                                              shape=(None, ),
                                              name="output_lens")

            # optimization related variables
            self.learning_rate = tf.Variable(float(config.init_lr),
                                             trainable=False,
                                             name="learning_rate")
            self.learning_rate_decay_op = self.learning_rate.assign(
                tf.multiply(self.learning_rate, config.lr_decay))
            self.global_t = tf.placeholder(dtype=tf.int32, name="global_t")
            self.use_prior = tf.placeholder(dtype=tf.bool, name="use_prior")

        max_input_len = array_ops.shape(self.input_contexts)[1]
        max_out_len = array_ops.shape(self.output_tokens)[1]
        batch_size = array_ops.shape(self.input_contexts)[0]

        with variable_scope.variable_scope("wordEmbedding"):
            self.embedding = tf.get_variable(
                "embedding", [self.vocab_size, config.embed_size],
                dtype=tf.float32)
            embedding_mask = tf.constant(
                [0 if i == 0 else 1 for i in range(self.vocab_size)],
                dtype=tf.float32,
                shape=[self.vocab_size, 1])
            embedding = self.embedding * embedding_mask
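            # Row 0 of the embedding matrix (the padding id) is zeroed out,
            # so <pad> tokens contribute nothing downstream.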

            input_embedding = embedding_ops.embedding_lookup(
                embedding, self.input_contexts)

            output_embedding = embedding_ops.embedding_lookup(
                embedding, self.output_tokens)

            if config.sent_type == "rnn":
                sent_cell = self.get_rnncell("gru", self.sent_cell_size,
                                             config.keep_prob, 1)
                input_embedding, sent_size = get_rnn_encode(input_embedding,
                                                            sent_cell,
                                                            scope="sent_rnn")
                output_embedding, _ = get_rnn_encode(output_embedding,
                                                     sent_cell,
                                                     self.output_lens,
                                                     scope="sent_rnn",
                                                     reuse=True)
            elif config.sent_type == "bi_rnn":
                fwd_sent_cell = self.get_rnncell("gru",
                                                 self.sent_cell_size,
                                                 keep_prob=1.0,
                                                 num_layer=1)
                bwd_sent_cell = self.get_rnncell("gru",
                                                 self.sent_cell_size,
                                                 keep_prob=1.0,
                                                 num_layer=1)
                input_embedding, sent_size = get_bi_rnn_encode(
                    input_embedding,
                    fwd_sent_cell,
                    bwd_sent_cell,
                    self.context_lens,
                    scope="sent_bi_rnn")
                output_embedding, _ = get_bi_rnn_encode(output_embedding,
                                                        fwd_sent_cell,
                                                        bwd_sent_cell,
                                                        self.output_lens,
                                                        scope="sent_bi_rnn",
                                                        reuse=True)
            else:
                raise ValueError(
                    "Unknown sent_type. Must be one of [rnn, bi_rnn]")

            # reshape input into dialogs
            if config.keep_prob < 1.0:
                input_embedding = tf.nn.dropout(input_embedding,
                                                config.keep_prob)

        with variable_scope.variable_scope("contextRNN"):
            enc_cell = self.get_rnncell(config.cell_type,
                                        self.context_cell_size,
                                        keep_prob=1.0,
                                        num_layer=config.num_layer)
            # sequence_length masks the padded steps, so enc_last_state is
            # the true last state
            input_embedding = tf.expand_dims(input_embedding, axis=2)
            _, enc_last_state = tf.nn.dynamic_rnn(
                enc_cell,
                input_embedding,
                dtype=tf.float32,
                sequence_length=self.context_lens)

            if config.num_layer > 1:
                if config.cell_type == 'lstm':
                    enc_last_state = [temp.h for temp in enc_last_state]

                enc_last_state = tf.concat(enc_last_state, 1)
            else:
                if config.cell_type == 'lstm':
                    enc_last_state = enc_last_state.h

        # input [enc_last_state, output_embedding] -- [c, x] --->z
        with variable_scope.variable_scope("recognitionNetwork"):
            recog_input = tf.concat([enc_last_state, output_embedding], 1)
            self.recog_mulogvar = recog_mulogvar = layers.fully_connected(
                recog_input,
                config.latent_size * 2,
                activation_fn=None,
                scope="muvar")
            recog_mu, recog_logvar = tf.split(recog_mulogvar, 2, axis=1)

        with variable_scope.variable_scope("priorNetwork"):
            # P(XYZ)=P(Z|X)P(X)P(Y|X,Z)
            prior_fc1 = layers.fully_connected(enc_last_state,
                                               np.maximum(
                                                   config.latent_size * 2,
                                                   100),
                                               activation_fn=tf.tanh,
                                               scope="fc1")
            prior_mulogvar = layers.fully_connected(prior_fc1,
                                                    config.latent_size * 2,
                                                    activation_fn=None,
                                                    scope="muvar")
            prior_mu, prior_logvar = tf.split(prior_mulogvar, 2, axis=1)

            # use sampled Z or posterior Z
            latent_sample = tf.cond(
                self.use_prior,
                lambda: sample_gaussian(prior_mu, prior_logvar),
                lambda: sample_gaussian(recog_mu, recog_logvar))

        with variable_scope.variable_scope("label_encoder"):
            le_embedding_mask = tf.constant(
                [0 if i == 0 else 1 for i in range(self.vocab_size)],
                dtype=tf.float32,
                shape=[self.vocab_size, 1])
            le_embedding = self.embedding * le_embedding_mask

            le_input_embedding = embedding_ops.embedding_lookup(
                le_embedding, self.input_contexts)

            le_output_embedding = embedding_ops.embedding_lookup(
                le_embedding, self.output_tokens)

            if config.sent_type == "rnn":
                le_sent_cell = self.get_rnncell("gru", self.sent_cell_size,
                                                config.keep_prob, 1)
                le_input_embedding, le_sent_size = get_rnn_encode(
                    le_input_embedding, le_sent_cell, scope="sent_rnn")
                le_output_embedding, _ = get_rnn_encode(le_output_embedding,
                                                        le_sent_cell,
                                                        self.output_lens,
                                                        scope="sent_rnn",
                                                        reuse=True)
            elif config.sent_type == "bi_rnn":
                le_fwd_sent_cell = self.get_rnncell("gru",
                                                    self.sent_cell_size,
                                                    keep_prob=1.0,
                                                    num_layer=1)
                le_bwd_sent_cell = self.get_rnncell("gru",
                                                    self.sent_cell_size,
                                                    keep_prob=1.0,
                                                    num_layer=1)
                le_input_embedding, le_sent_size = get_bi_rnn_encode(
                    le_input_embedding,
                    le_fwd_sent_cell,
                    le_bwd_sent_cell,
                    self.context_lens,
                    scope="sent_bi_rnn")
                le_output_embedding, _ = get_bi_rnn_encode(le_output_embedding,
                                                           le_fwd_sent_cell,
                                                           le_bwd_sent_cell,
                                                           self.output_lens,
                                                           scope="sent_bi_rnn",
                                                           reuse=True)
            else:
                raise ValueError(
                    "Unknown sent_type. Must be one of [rnn, bi_rnn]")

            # reshape input into dialogs
            if config.keep_prob < 1.0:
                le_input_embedding = tf.nn.dropout(le_input_embedding,
                                                   config.keep_prob)

        # [le_enc_last_state, le_output_embedding]
        with variable_scope.variable_scope("lecontextRNN"):
            enc_cell = self.get_rnncell(config.cell_type,
                                        self.context_cell_size,
                                        keep_prob=1.0,
                                        num_layer=config.num_layer)
            # sequence_length masks the padded steps, so le_enc_last_state
            # is the true last state
            le_input_embedding = tf.expand_dims(le_input_embedding, axis=2)
            _, le_enc_last_state = tf.nn.dynamic_rnn(
                enc_cell,
                le_input_embedding,
                dtype=tf.float32,
                sequence_length=self.context_lens)

            if config.num_layer > 1:
                if config.cell_type == 'lstm':
                    le_enc_last_state = [temp.h for temp in le_enc_last_state]

                le_enc_last_state = tf.concat(le_enc_last_state, 1)
            else:
                if config.cell_type == 'lstm':
                    le_enc_last_state = le_enc_last_state.h
            best_en = tf.concat([le_enc_last_state, le_output_embedding], 1)

        with variable_scope.variable_scope("ggammaNet"):
            enc_cell = self.get_rnncell(config.cell_type,
                                        200,
                                        keep_prob=1.0,
                                        num_layer=config.num_layer)
            # sequence_length masks the padded steps, so zlabel is the true
            # last state
            input_embedding = tf.expand_dims(best_en, axis=2)
            _, zlabel = tf.nn.dynamic_rnn(enc_cell,
                                          input_embedding,
                                          dtype=tf.float32,
                                          sequence_length=self.context_lens)

            if config.num_layer > 1:
                if config.cell_type == 'lstm':
                    # bug fix: iterate over this RNN's own state (zlabel),
                    # not the contextRNN's enc_last_state
                    zlabel = [temp.h for temp in zlabel]

                zlabel = tf.concat(zlabel, 1)
            else:
                if config.cell_type == 'lstm':
                    zlabel = zlabel.h
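            # zlabel is this "gamma" network's prediction of the latent
            # code from the label-encoder features; label_loss below ties
            # it to latent_sample for the self-labeling objective.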

        with variable_scope.variable_scope("generationNetwork"):
            gen_inputs = tf.concat([enc_last_state, latent_sample], 1)

            dec_inputs = gen_inputs
            selected_attribute_embedding = None

            # Decoder_init_state
            if config.num_layer > 1:
                dec_init_state = []
                for i in range(config.num_layer):
                    temp_init = layers.fully_connected(dec_inputs,
                                                       self.dec_cell_size,
                                                       activation_fn=None,
                                                       scope="init_state-%d" %
                                                       i)
                    if config.cell_type == 'lstm':
                        temp_init = rnn_cell.LSTMStateTuple(
                            temp_init, temp_init)

                    dec_init_state.append(temp_init)

                dec_init_state = tuple(dec_init_state)
            else:
                dec_init_state = layers.fully_connected(dec_inputs,
                                                        self.dec_cell_size,
                                                        activation_fn=None,
                                                        scope="init_state")
                if config.cell_type == 'lstm':
                    dec_init_state = rnn_cell.LSTMStateTuple(
                        dec_init_state, dec_init_state)

        with variable_scope.variable_scope("generationNetwork1"):
            gen_inputs_sl = tf.concat([le_enc_last_state, zlabel], 1)

            dec_inputs_sl = gen_inputs_sl
            selected_attribute_embedding = None

            # Decoder_init_state
            if config.num_layer > 1:
                dec_init_state_sl = []
                for i in range(config.num_layer):
                    temp_init = layers.fully_connected(dec_inputs_sl,
                                                       self.dec_cell_size,
                                                       activation_fn=None,
                                                       scope="init_state-%d" %
                                                       i)
                    if config.cell_type == 'lstm':
                        temp_init = rnn_cell.LSTMStateTuple(
                            temp_init, temp_init)

                    dec_init_state_sl.append(temp_init)

                dec_init_state_sl = tuple(dec_init_state_sl)
            else:
                dec_init_state_sl = layers.fully_connected(dec_inputs_sl,
                                                           self.dec_cell_size,
                                                           activation_fn=None,
                                                           scope="init_state")
                if config.cell_type == 'lstm':
                    dec_init_state_sl = rnn_cell.LSTMStateTuple(
                        dec_init_state_sl, dec_init_state_sl)

        with variable_scope.variable_scope("decoder"):
            dec_cell = self.get_rnncell(config.cell_type, self.dec_cell_size,
                                        config.keep_prob, config.num_layer)
            dec_cell = OutputProjectionWrapper(dec_cell, self.vocab_size)

            if forward:
                loop_func = decoder_fn_lib.context_decoder_fn_inference(
                    None,
                    dec_init_state,
                    embedding,
                    start_of_sequence_id=self.go_id,
                    end_of_sequence_id=self.eos_id,
                    maximum_length=self.max_utt_len,
                    num_decoder_symbols=self.vocab_size,
                    context_vector=selected_attribute_embedding)
                loop_func_sl = decoder_fn_lib.context_decoder_fn_inference(
                    None,
                    dec_init_state_sl,
                    le_embedding,
                    start_of_sequence_id=self.go_id,
                    end_of_sequence_id=self.eos_id,
                    maximum_length=self.max_utt_len,
                    num_decoder_symbols=self.vocab_size,
                    context_vector=selected_attribute_embedding)

                dec_input_embedding = None
                dec_input_embedding_sl = None
                dec_seq_lens = None
            else:
                loop_func = decoder_fn_lib.context_decoder_fn_train(
                    dec_init_state, selected_attribute_embedding)
                loop_func_sl = decoder_fn_lib.context_decoder_fn_train(
                    dec_init_state_sl, selected_attribute_embedding)

                dec_input_embedding = embedding_ops.embedding_lookup(
                    embedding, self.output_tokens)
                dec_input_embedding_sl = embedding_ops.embedding_lookup(
                    le_embedding, self.output_tokens)

                dec_input_embedding = dec_input_embedding[:, 0:-1, :]
                dec_input_embedding_sl = dec_input_embedding_sl[:, 0:-1, :]

                dec_seq_lens = self.output_lens - 1

                if config.keep_prob < 1.0:
                    dec_input_embedding = tf.nn.dropout(
                        dec_input_embedding, config.keep_prob)
                    dec_input_embedding_sl = tf.nn.dropout(
                        dec_input_embedding_sl, config.keep_prob)

                # apply word dropping. Set dropped word to 0
                if config.dec_keep_prob < 1.0:
                    keep_mask = tf.less_equal(
                        tf.random_uniform((batch_size, max_out_len - 1),
                                          minval=0.0,
                                          maxval=1.0), config.dec_keep_prob)
                    keep_mask = tf.expand_dims(tf.to_float(keep_mask), 2)
                    dec_input_embedding = dec_input_embedding * keep_mask
                    dec_input_embedding_sl = dec_input_embedding_sl * keep_mask
                    dec_input_embedding = tf.reshape(
                        dec_input_embedding,
                        [-1, max_out_len - 1, config.embed_size])
                    dec_input_embedding_sl = tf.reshape(
                        dec_input_embedding_sl,
                        [-1, max_out_len - 1, config.embed_size])

            dec_outs, _, final_context_state = dynamic_rnn_decoder(
                dec_cell,
                loop_func,
                inputs=dec_input_embedding,
                sequence_length=dec_seq_lens)

            dec_outs_sl, _, final_context_state_sl = dynamic_rnn_decoder(
                dec_cell,
                loop_func_sl,
                inputs=dec_input_embedding_sl,
                sequence_length=dec_seq_lens)

            if final_context_state is not None:
                final_context_state = final_context_state[
                    :, 0:array_ops.shape(dec_outs)[1]]
                mask = tf.to_int32(tf.sign(tf.reduce_max(dec_outs, axis=2)))
                self.dec_out_words = tf.multiply(
                    tf.reverse(final_context_state, axis=[1]), mask)
            else:
                self.dec_out_words = tf.argmax(dec_outs, 2)

            if final_context_state_sl is not None:
                final_context_state_sl = final_context_state_sl[
                    :, 0:array_ops.shape(dec_outs_sl)[1]]
                mask_sl = tf.to_int32(
                    tf.sign(tf.reduce_max(dec_outs_sl, axis=2)))
                self.dec_out_words_sl = tf.multiply(
                    tf.reverse(final_context_state_sl, axis=[1]), mask_sl)
            else:
                self.dec_out_words_sl = tf.argmax(dec_outs_sl, 2)

        if not forward:
            with variable_scope.variable_scope("loss"):

                labels = self.output_tokens[:, 1:]
                label_mask = tf.to_float(tf.sign(labels))

                rc_loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
                    logits=dec_outs, labels=labels)

                rc_loss = tf.reduce_sum(rc_loss * label_mask,
                                        reduction_indices=1)
                self.avg_rc_loss = tf.reduce_mean(rc_loss)

                sl_rc_loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
                    logits=dec_outs_sl, labels=labels)
                sl_rc_loss = tf.reduce_sum(sl_rc_loss * label_mask,
                                           reduction_indices=1)
                self.sl_rc_loss = tf.reduce_mean(sl_rc_loss)
                # used only for perplexity calculation. Not used for optimization
                self.rc_ppl = tf.exp(
                    tf.reduce_sum(rc_loss) / tf.reduce_sum(label_mask))
                """ as n-trial multimodal distribution. """

                kld = gaussian_kld(recog_mu, recog_logvar, prior_mu,
                                   prior_logvar)
                self.avg_kld = tf.reduce_mean(kld)
                if log_dir is not None:
                    kl_weights = tf.minimum(
                        tf.to_float(self.global_t) / config.full_kl_step, 1.0)
                else:
                    kl_weights = tf.constant(1.0)

                self.label_loss = tf.reduce_mean(
                    tf.nn.sigmoid_cross_entropy_with_logits(
                        labels=latent_sample, logits=zlabel))

                self.kl_w = kl_weights
                self.elbo = self.avg_rc_loss + kl_weights * self.avg_kld

                self.cvae_loss = self.elbo + 0.1 * self.label_loss
                self.sl_loss = self.sl_rc_loss
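                # Two objectives are kept separate: cvae_loss (annealed
                # ELBO + 0.1 * label_loss) trains the CVAE path, while
                # sl_loss trains the self-labeling decoder; each gets its
                # own train op below.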

                tf.summary.scalar("rc_loss", self.avg_rc_loss)
                tf.summary.scalar("elbo", self.elbo)
                tf.summary.scalar("kld", self.avg_kld)

                self.summary_op = tf.summary.merge_all()

                self.log_p_z = norm_log_liklihood(latent_sample, prior_mu,
                                                  prior_logvar)
                self.log_q_z_xy = norm_log_liklihood(latent_sample, recog_mu,
                                                     recog_logvar)
                self.est_marginal = tf.reduce_mean(rc_loss - self.log_p_z +
                                                   self.log_q_z_xy)

            self.train_sl_ops = self.optimize(sess,
                                              config,
                                              self.sl_loss,
                                              log_dir,
                                              scope="SL")
            self.train_ops = self.optimize(sess,
                                           config,
                                           self.cvae_loss,
                                           log_dir,
                                           scope="CVAE")

        self.saver = tf.train.Saver(tf.global_variables(),
                                    write_version=tf.train.SaverDef.V2)
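The KL term in both models is annealed linearly, kl_w = min(global_t / full_kl_step, 1), with global_t fed in each step; when no log_dir is supplied the weight is pinned to 1.0. A self-contained sketch of that schedule (plain Python, illustrative names only):

def kl_weight(global_t, full_kl_step):
    # Linear KL annealing: ramps from 0 to 1 over the first
    # full_kl_step training steps, then stays at 1.
    return min(float(global_t) / full_kl_step, 1.0)

assert kl_weight(0, 10000) == 0.0
assert kl_weight(5000, 10000) == 0.5
assert kl_weight(20000, 10000) == 1.0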