Example 1
    def __init__(self, sess, config, api, log_dir, forward, scope=None):
        self.vocab = api.vocab
        self.rev_vocab = api.rev_vocab
        self.vocab_size = len(self.vocab)
        self.sess = sess
        self.scope = scope
        self.max_utt_len = config.max_utt_len
        self.go_id = self.rev_vocab["<s>"]
        self.eos_id = self.rev_vocab["</s>"]
        self.context_cell_size = config.cxt_cell_size
        self.sent_cell_size = config.sent_cell_size
        self.dec_cell_size = config.dec_cell_size
        self.num_topics = config.num_topics

        with tf.name_scope("io"):
            # all dialog context and known attributes
            self.input_contexts = tf.placeholder(dtype=tf.int32,
                                                 shape=(None, None,
                                                        self.max_utt_len),
                                                 name="dialog_context")
            self.floors = tf.placeholder(dtype=tf.float32,
                                         shape=(None, None),
                                         name="floor")  # TODO float
            self.floor_labels = tf.placeholder(dtype=tf.float32,
                                               shape=(None, 1),
                                               name="floor_labels")
            self.context_lens = tf.placeholder(dtype=tf.int32,
                                               shape=(None, ),
                                               name="context_lens")
            self.paragraph_topics = tf.placeholder(dtype=tf.float32,
                                                   shape=(None,
                                                          self.num_topics),
                                                   name="paragraph_topics")

            # target response given the dialog context
            self.output_tokens = tf.placeholder(dtype=tf.int32,
                                                shape=(None, None),
                                                name="output_token")
            self.output_lens = tf.placeholder(dtype=tf.int32,
                                              shape=(None, ),
                                              name="output_lens")
            self.output_das = tf.placeholder(dtype=tf.float32,
                                             shape=(None, self.num_topics),
                                             name="output_dialog_acts")

            # optimization related variables
            self.learning_rate = tf.Variable(float(config.init_lr),
                                             trainable=False,
                                             name="learning_rate")
            self.learning_rate_decay_op = self.learning_rate.assign(
                tf.multiply(self.learning_rate, config.lr_decay))
            self.global_t = tf.placeholder(dtype=tf.int32, name="global_t")
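            # A hypothetical feed for these placeholders (batch size 16 and
            # dialog length 10 are illustrative only; shapes follow the
            # definitions above):
            #
            #   feed = {
            #       self.input_contexts: np.zeros((16, 10, self.max_utt_len), np.int32),
            #       self.floors: np.zeros((16, 10), np.float32),
            #       self.context_lens: np.full((16,), 10, np.int32),
            #       self.paragraph_topics: np.zeros((16, self.num_topics), np.float32),
            #       self.output_tokens: np.zeros((16, 20), np.int32),
            #       self.output_lens: np.full((16,), 20, np.int32),
            #   }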

        max_dialog_len = array_ops.shape(self.input_contexts)[1]
        max_out_len = array_ops.shape(self.output_tokens)[1]
        batch_size = array_ops.shape(self.input_contexts)[0]

        with variable_scope.variable_scope("wordEmbedding"):
            self.embedding = tf.get_variable(
                "embedding", [self.vocab_size, config.embed_size],
                dtype=tf.float32)
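            # zero out row 0 of the embedding matrix so the PAD token (id 0)
            # always maps to an all-zero vector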
            embedding_mask = tf.constant(
                [0 if i == 0 else 1 for i in range(self.vocab_size)],
                dtype=tf.float32,
                shape=[self.vocab_size, 1])
            embedding = self.embedding * embedding_mask

            # embed the input
            input_embedding = embedding_ops.embedding_lookup(
                embedding, tf.reshape(self.input_contexts, [-1]))
            input_embedding = tf.reshape(
                input_embedding, [-1, self.max_utt_len, config.embed_size])

            # encode input using RNN w/GRU
            sent_cell = self.get_rnncell("gru", self.sent_cell_size,
                                         config.keep_prob, 1)
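            # `get_rnncell` is a repo helper (not shown); a minimal sketch of
            # the assumed signature, supporting GRU/LSTM cells with dropout and
            # multi-layer stacking:
            #
            #   def get_rnncell(self, cell_type, cell_size, keep_prob, num_layer):
            #       cell = (rnn_cell.LSTMCell(cell_size) if cell_type == 'lstm'
            #               else rnn_cell.GRUCell(cell_size))
            #       if keep_prob < 1.0:
            #           cell = rnn_cell.DropoutWrapper(cell,
            #                                          output_keep_prob=keep_prob)
            #       if num_layer > 1:
            #           cell = rnn_cell.MultiRNNCell([cell] * num_layer,
            #                                        state_is_tuple=True)
            #       return cell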
            input_embedding, sent_size = get_rnn_encode(input_embedding,
                                                        sent_cell,
                                                        scope="sent_rnn")

            # reshape input
            input_embedding = tf.reshape(input_embedding,
                                         [-1, max_dialog_len, sent_size])
            if config.keep_prob < 1.0:
                input_embedding = tf.nn.dropout(input_embedding,
                                                config.keep_prob)

            # floor = probability that the next sentence is the last
            # TODO do we want this?
            floor = tf.reshape(self.floors, [-1, max_dialog_len, 1])

            joint_embedding = tf.concat([input_embedding, floor], 2,
                                        "joint_embedding")

        with variable_scope.variable_scope("contextRNN"):
            enc_cell = self.get_rnncell(config.cell_type,
                                        self.context_cell_size,
                                        keep_prob=1.0,
                                        num_layer=config.num_layer)
            # enc_last_state is the state at each sequence's true last step
            # (given sequence_length)
            _, enc_last_state = tf.nn.dynamic_rnn(
                enc_cell,
                joint_embedding,
                dtype=tf.float32,
                sequence_length=self.context_lens)

            if config.num_layer > 1:
                if config.cell_type == 'lstm':
                    enc_last_state = [temp.h for temp in enc_last_state]

                enc_last_state = tf.concat(enc_last_state, 1)
            else:
                if config.cell_type == 'lstm':
                    enc_last_state = enc_last_state.h

        # Final output from the encoder
        encoded_list = [self.paragraph_topics, enc_last_state]
        encoded_embedding = tf.concat(encoded_list, 1)

        with variable_scope.variable_scope("generationNetwork"):

            # predict whether the next sentence is the last one
            # TODO do we want this?
            self.paragraph_end_logits = layers.fully_connected(
                encoded_embedding,
                1,
                activation_fn=tf.tanh,
                scope="paragraph_end_fc1")

            # Decoder
            if config.num_layer > 1:
                dec_init_state = []
                for i in range(config.num_layer):
                    temp_init = layers.fully_connected(encoded_embedding,
                                                       self.dec_cell_size,
                                                       activation_fn=None,
                                                       scope="init_state-%d" %
                                                       i)
                    if config.cell_type == 'lstm':
                        # LSTM state is a (c, h) tuple; use the same init for both
                        temp_init = rnn_cell.LSTMStateTuple(
                            temp_init, temp_init)

                    dec_init_state.append(temp_init)

                dec_init_state = tuple(dec_init_state)
            else:
                dec_init_state = layers.fully_connected(encoded_embedding,
                                                        self.dec_cell_size,
                                                        activation_fn=None,
                                                        scope="init_state")
                if config.cell_type == 'lstm':
                    dec_init_state = rnn_cell.LSTMStateTuple(
                        dec_init_state, dec_init_state)

        with variable_scope.variable_scope("decoder"):
            dec_cell = self.get_rnncell(config.cell_type, self.dec_cell_size,
                                        config.keep_prob, config.num_layer)
            # project decoder outputs to vocab-size logits. TODO no softmax?
            dec_cell = OutputProjectionWrapper(dec_cell, self.vocab_size)

            if forward:
                loop_func = decoder_fn_lib.context_decoder_fn_inference(
                    None,
                    dec_init_state,
                    embedding,
                    start_of_sequence_id=self.go_id,
                    end_of_sequence_id=self.eos_id,
                    maximum_length=self.max_utt_len,
                    num_decoder_symbols=self.vocab_size,
                    context_vector=None)
                dec_input_embedding = None
                dec_seq_lens = None
            else:
                loop_func = decoder_fn_lib.context_decoder_fn_train(
                    dec_init_state, None)
                dec_input_embedding = embedding_ops.embedding_lookup(
                    embedding, self.output_tokens)
                dec_input_embedding = dec_input_embedding[:, 0:-1, :]
                dec_seq_lens = self.output_lens - 1

                if config.keep_prob < 1.0:
                    dec_input_embedding = tf.nn.dropout(
                        dec_input_embedding, config.keep_prob)

                # apply word dropping. Set dropped word to 0
                if config.dec_keep_prob < 1.0:
                    # sample a keep/throw-away mask per word position
                    keep_mask = tf.less_equal(
                        tf.random_uniform((batch_size, max_out_len - 1),
                                          minval=0.0,
                                          maxval=1.0), config.dec_keep_prob)
                    keep_mask = tf.expand_dims(tf.to_float(keep_mask), 2)
                    dec_input_embedding = dec_input_embedding * keep_mask
                    dec_input_embedding = tf.reshape(
                        dec_input_embedding,
                        [-1, max_out_len - 1, config.embed_size])

            dec_outs, _, final_context_state = dynamic_rnn_decoder(
                dec_cell,
                loop_func,
                inputs=dec_input_embedding,
                sequence_length=dec_seq_lens,
                name='output_node')

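            # When running in inference mode, final_context_state appears to
            # hold the sampled token ids (reversed in time); the sign mask
            # zeroes out steps past each sequence's end.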
            if final_context_state is not None:
                final_context_state = final_context_state[
                    :, 0:array_ops.shape(dec_outs)[1]]
                mask = tf.to_int32(tf.sign(tf.reduce_max(dec_outs, axis=2)))
                self.dec_out_words = tf.multiply(
                    tf.reverse(final_context_state, axis=[1]), mask)
            else:
                self.dec_out_words = tf.argmax(dec_outs, 2)

        if not forward:
            with variable_scope.variable_scope("loss"):

                labels = self.output_tokens[:, 1:]  # correct word tokens
                label_mask = tf.to_float(tf.sign(labels))

                print("dec outs shape", dec_outs.get_shape())
                print("labels shape", labels.get_shape())

                # Loss between words
                rc_loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
                    logits=dec_outs, labels=labels)
                rc_loss = tf.reduce_sum(rc_loss * label_mask,
                                        reduction_indices=1)
                self.avg_rc_loss = tf.reduce_mean(rc_loss)
                # used only for perplexity calculation, not for optimization
                self.rc_ppl = tf.exp(
                    tf.reduce_sum(rc_loss) / tf.reduce_sum(label_mask))

                # Predict 0/1 (1 = last sentence in paragraph); with a single
                # output unit, sigmoid (not softmax) cross-entropy is required
                end_loss = tf.nn.sigmoid_cross_entropy_with_logits(
                    labels=self.floor_labels, logits=self.paragraph_end_logits)
                self.avg_end_loss = tf.reduce_mean(end_loss)
                print "size of end loss", self.avg_end_loss.get_shape()

                total_loss = self.avg_rc_loss + self.avg_end_loss

                tf.summary.scalar("rc_loss", self.avg_rc_loss)
                tf.summary.scalar("paragraph_end_loss", self.avg_end_loss)

                self.summary_op = tf.summary.merge_all()

            self.optimize(sess, config, total_loss, log_dir)

        self.saver = tf.train.Saver(tf.global_variables(),
                                    write_version=tf.train.SaverDef.V2)
Example 2
    def __init__(self, sess, config, api, log_dir, forward, scope=None):
        # forward=True builds the inference/decoding graph; forward=False
        # builds the training graph (losses and optimizer).
        self.vocab = api.vocab
        self.rev_vocab = api.rev_vocab
        self.vocab_size = len(self.vocab)
        self.seen_intent = api.seen_intent
        self.rev_seen_intent = api.rev_seen_intent
        self.seen_intent_size = len(self.rev_seen_intent)
        self.unseen_intent = api.unseen_intent
        self.rev_unseen_intent = api.rev_unseen_intent
        self.unseen_intent_size = len(self.rev_unseen_intent)
        self.sess = sess
        self.scope = scope
        self.max_utt_len = config.max_utt_len
        self.go_id = self.rev_vocab["<s>"]
        self.eos_id = self.rev_vocab["</s>"]
        self.sent_cell_size = config.sent_cell_size
        self.dec_cell_size = config.dec_cell_size
        self.label_embed_size = config.label_embed_size
        self.latent_size = config.latent_size

        self.seed = config.seed
        self.use_ot_label = config.use_ot_label
        self.use_rand_ot_label = config.use_rand_ot_label  # only valid if use_ot_label is true: use random other labels instead of all of them
        self.use_rand_fixed_ot_label = config.use_rand_fixed_ot_label  # valid when use_ot_label=true and use_rand_ot_label=true
        if self.use_ot_label:
            self.rand_ot_label_num = config.rand_ot_label_num  # valid when use_ot_label=true and use_rand_ot_label=true
        else:
            self.rand_ot_label_num = self.seen_intent_size - 1

        with tf.name_scope("io"):
            # all dialog context and known attributes
            self.labels = tf.placeholder(
                dtype=tf.int32, shape=(None, ),
                name="labels")  # one label per utterance, shape (batch_size,)
            self.ot_label_rand = tf.placeholder(dtype=tf.int32,
                                                shape=(None, None),
                                                name="ot_labels_rand")
            self.ot_labels_all = tf.placeholder(
                dtype=tf.int32, shape=(None, None),
                name="ot_labels_all")  #(batch_size, len(api.label_vocab)-1)

            # target response given the dialog context
            self.io_tokens = tf.placeholder(dtype=tf.int32,
                                            shape=(None, None),
                                            name="output_tokens")
            self.io_lens = tf.placeholder(dtype=tf.int32,
                                          shape=(None, ),
                                          name="output_lens")
            self.output_labels = tf.placeholder(dtype=tf.int32,
                                                shape=(None, ),
                                                name="output_labels")

            # optimization related variables
            self.learning_rate = tf.Variable(float(config.init_lr),
                                             trainable=False,
                                             name="learning_rate")
            self.learning_rate_decay_op = self.learning_rate.assign(
                tf.multiply(self.learning_rate, config.lr_decay))
            self.global_t = tf.placeholder(dtype=tf.int32, name="global_t")
            self.use_prior = tf.placeholder(
                dtype=tf.bool, name="use_prior")  # whether to sample z from the prior
            self.prior_mulogvar = tf.placeholder(
                dtype=tf.float32,
                shape=(None, config.latent_size * 2),
                name="prior_mulogvar")

            self.batch_size = tf.placeholder(dtype=tf.int32, name="batch_size")

        max_out_len = array_ops.shape(self.io_tokens)[1]
        # batch_size = array_ops.shape(self.io_tokens)[0]
        batch_size = self.batch_size

        with variable_scope.variable_scope("labelEmbedding",
                                           reuse=tf.AUTO_REUSE):
            self.la_embedding = tf.get_variable(
                "embedding", [self.seen_intent_size, config.label_embed_size],
                dtype=tf.float32)
            label_embedding = embedding_ops.embedding_lookup(
                self.la_embedding, self.output_labels)  # not used

        with variable_scope.variable_scope("wordEmbedding",
                                           reuse=tf.AUTO_REUSE):
            self.embedding = tf.get_variable(
                "embedding", [self.vocab_size, config.embed_size],
                dtype=tf.float32,
                trainable=False)
            embedding_mask = tf.constant(
                [0 if i == 0 else 1 for i in range(self.vocab_size)],
                dtype=tf.float32,
                shape=[self.vocab_size, 1])
            embedding = self.embedding * embedding_mask  # broadcast; zeroes out the PAD row (index 0)

            io_embedding = embedding_ops.embedding_lookup(
                embedding, self.io_tokens)  # (batch_size, utt_len, embed_size)

            if config.sent_type == "bow":
                io_embedding, _ = get_bow(io_embedding)

            elif config.sent_type == "rnn":
                sent_cell = self.get_rnncell("gru", self.sent_cell_size,
                                             config.keep_prob, 1)
                io_embedding, _ = get_rnn_encode(io_embedding,
                                                 sent_cell,
                                                 self.io_lens,
                                                 scope="sent_rnn",
                                                 reuse=tf.AUTO_REUSE)
            elif config.sent_type == "bi_rnn":
                fwd_sent_cell = self.get_rnncell("gru",
                                                 self.sent_cell_size,
                                                 keep_prob=1.0,
                                                 num_layer=1)
                bwd_sent_cell = self.get_rnncell("gru",
                                                 self.sent_cell_size,
                                                 keep_prob=1.0,
                                                 num_layer=1)
                io_embedding, _ = get_bi_rnn_encode(
                    io_embedding,
                    fwd_sent_cell,
                    bwd_sent_cell,
                    self.io_lens,
                    scope="sent_bi_rnn",
                    reuse=tf.AUTO_REUSE
                )  # equal to x of the graph, (batch_size, 300*2)
            else:
                raise ValueError(
                    "Unknown sent_type. Must be one of [bow, rnn, bi_rnn]")

            # print('==========================', io_embedding) # Tensor("models_2/wordEmbedding/sent_bi_rnn/concat:0", shape=(?, 600), dtype=float32)

            # convert label into 1 hot
            my_label_one_hot = tf.one_hot(tf.reshape(self.labels, [-1]),
                                          depth=self.seen_intent_size,
                                          dtype=tf.float32)  # (batch_size, seen_intent_size)
            if config.use_ot_label:
                if config.use_rand_ot_label:
                    ot_label_one_hot = tf.one_hot(
                        tf.reshape(self.ot_label_rand, [-1]),
                        depth=self.seen_intent_size,
                        dtype=tf.float32)
                    ot_label_one_hot = tf.reshape(
                        ot_label_one_hot,
                        [-1, self.seen_intent_size * self.rand_ot_label_num])
                else:
                    ot_label_one_hot = tf.one_hot(
                        tf.reshape(self.ot_labels_all, [-1]),
                        depth=self.seen_intent_size,
                        dtype=tf.float32)
                    ot_label_one_hot = tf.reshape(
                        ot_label_one_hot,
                        [-1, self.seen_intent_size * (self.seen_intent_size - 1)]
                    )  # (batch_size, len(api.label_vocab)*(len(api.label_vocab)-1))

        with variable_scope.variable_scope("recognitionNetwork",
                                           reuse=tf.AUTO_REUSE):
            recog_input = io_embedding
            self.recog_mulogvar = recog_mulogvar = layers.fully_connected(
                recog_input,
                config.latent_size * 2,
                activation_fn=None,
                scope="muvar")  # config.latent_size=200
            recog_mu, recog_logvar = tf.split(
                recog_mulogvar, 2, axis=1
            )  # recognition network output. (batch_size, config.latent_size)

        with variable_scope.variable_scope("priorNetwork",
                                           reuse=tf.AUTO_REUSE):
            # p(xyz) = p(z)p(x|z)p(y|xz)
            # prior network parameters; assume a standard normal distribution
            # prior_mulogvar = tf.constant([[1] * config.latent_size + [0] * config.latent_size]*batch_size,
            #                              dtype=tf.float32, name="muvar")  # cannot be built this way
            prior_mulogvar = self.prior_mulogvar
            prior_mu, prior_logvar = tf.split(prior_mulogvar, 2, axis=1)

            # use sampled Z or posterior Z
            latent_sample = tf.cond(
                self.use_prior,  # bool input
                lambda: sample_gaussian(prior_mu, prior_logvar),
                lambda: sample_gaussian(recog_mu, recog_logvar)
            )  # prior if use_prior else posterior; (batch_size, config.latent_size)
            self.z = latent_sample
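            # `sample_gaussian` (repo helper, not shown) presumably implements
            # the reparameterization trick; a minimal sketch under that
            # assumption:
            #
            #   def sample_gaussian(mu, logvar):
            #       epsilon = tf.random_normal(tf.shape(logvar))
            #       return mu + tf.exp(0.5 * logvar) * epsilon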

        with variable_scope.variable_scope("generationNetwork",
                                           reuse=tf.AUTO_REUSE):
            bow_loss_inputs = latent_sample  # (part of) response network input
            label_inputs = latent_sample
            dec_inputs = latent_sample

            # BOW loss
            if config.use_bow_loss:
                bow_fc1 = layers.fully_connected(
                    bow_loss_inputs,
                    400,
                    activation_fn=tf.tanh,
                    scope="bow_fc1")  # MLPb network fc layer
                # note: this layer once raised ValueError: "The last dimension
                # of the inputs to `Dense` should be defined. Found `None`."
                if config.keep_prob < 1.0:
                    bow_fc1 = tf.nn.dropout(bow_fc1, config.keep_prob)
                self.bow_logits = layers.fully_connected(
                    bow_fc1,
                    self.vocab_size,
                    activation_fn=None,
                    scope="bow_project")  # MLPb network fc output

            # Y loss, include the other y.
            my_label_fc1 = layers.fully_connected(label_inputs,
                                                  400,
                                                  activation_fn=tf.tanh,
                                                  scope="my_label_fc1")
            if config.keep_prob < 1.0:
                my_label_fc1 = tf.nn.dropout(my_label_fc1, config.keep_prob)

            # my_label_fc2 = layers.fully_connected(my_label_fc1, 400, activation_fn=tf.tanh, scope="my_label_fc2")
            # if config.keep_prob < 1.0:
            #     my_label_fc2 = tf.nn.dropout(my_label_fc2, config.keep_prob)

            self.my_label_logits = layers.fully_connected(
                my_label_fc1, self.seen_intent_size,
                scope="my_label_project")  # MLPy fc output
            my_label_prob = tf.nn.softmax(
                self.my_label_logits
            )  # softmax output, (batch_size, label_vocab_size)
            self.my_label_prob = my_label_prob
            pred_my_label_embedding = tf.matmul(
                my_label_prob, self.la_embedding
            )  # predicted my label y. (batch_size, label_embed_size)

            if config.use_ot_label:
                if config.use_rand_ot_label:  # use rand_ot_label_num random other labels
                    ot_label_fc1 = layers.fully_connected(
                        label_inputs,
                        400,
                        activation_fn=tf.tanh,
                        scope="ot_label_fc1")
                    if config.keep_prob < 1.0:
                        ot_label_fc1 = tf.nn.dropout(ot_label_fc1,
                                                     config.keep_prob)
                    self.ot_label_logits = layers.fully_connected(
                        ot_label_fc1,
                        self.rand_ot_label_num * self.seen_intent_size,
                        scope="ot_label_rand_project")
                    ot_label_logits_split = tf.reshape(
                        self.ot_label_logits,
                        [-1, self.rand_ot_label_num, self.seen_intent_size])
                    ot_label_prob_short = tf.nn.softmax(ot_label_logits_split)
                    ot_label_prob = tf.reshape(
                        ot_label_prob_short,
                        [-1, self.rand_ot_label_num * self.seen_intent_size]
                    )  # (batch_size, self.rand_ot_label_num*self.label_vocab_size)
                    pred_ot_label_embedding = tf.reshape(
                        tf.matmul(ot_label_prob_short, self.la_embedding),
                        [-1, self.label_embed_size * self.rand_ot_label_num]
                    )  # predicted other label y2, (batch_size, label_embed_size*rand_ot_label_num)
                else:
                    ot_label_fc1 = layers.fully_connected(
                        label_inputs,
                        400,
                        activation_fn=tf.tanh,
                        scope="ot_label_fc1")
                    if config.keep_prob < 1.0:
                        ot_label_fc1 = tf.nn.dropout(ot_label_fc1,
                                                     config.keep_prob)
                    self.ot_label_logits = layers.fully_connected(
                        ot_label_fc1,
                        self.seen_intent_size * (self.seen_intent_size - 1),
                        scope="ot_label_all_project")
                    ot_label_logits_split = tf.reshape(
                        self.ot_label_logits,
                        [-1, self.seen_intent_size - 1, self.seen_intent_size])
                    ot_label_prob_short = tf.nn.softmax(ot_label_logits_split)
                    ot_label_prob = tf.reshape(
                        ot_label_prob_short, [
                            -1, self.seen_intent_size *
                            (self.seen_intent_size - 1)
                        ]
                    )  # (batch_size, self.label_vocab_size*(self.label_vocab_size-1))
                    pred_ot_label_embedding = tf.reshape(
                        tf.matmul(ot_label_prob_short, self.la_embedding),
                        [-1, self.label_embed_size * (self.seen_intent_size - 1)]
                    )  # predicted other-label embeddings, (batch_size, label_embed_size*(label_vocab_size-1))
                    # note: matmul can compute (3, 4, 5) × (5, 4) = (3, 4, 4)
            else:  # only use label y.
                self.ot_label_logits = None
                pred_ot_label_embedding = None

            # Decoder, Response Network
            if config.num_layer > 1:
                dec_init_state = []
                for i in range(config.num_layer):
                    temp_init = layers.fully_connected(dec_inputs,
                                                       self.dec_cell_size,
                                                       activation_fn=None,
                                                       scope="init_state-%d" %
                                                       i)
                    if config.cell_type == 'lstm':
                        temp_init = rnn_cell.LSTMStateTuple(
                            temp_init, temp_init)

                    dec_init_state.append(temp_init)

                dec_init_state = tuple(dec_init_state)
            else:
                dec_init_state = layers.fully_connected(dec_inputs,
                                                        self.dec_cell_size,
                                                        activation_fn=None,
                                                        scope="init_state")
                if config.cell_type == 'lstm':
                    dec_init_state = rnn_cell.LSTMStateTuple(
                        dec_init_state, dec_init_state)

        with variable_scope.variable_scope("decoder", reuse=tf.AUTO_REUSE):
            dec_cell = self.get_rnncell(config.cell_type, self.dec_cell_size,
                                        config.keep_prob, config.num_layer)
            dec_cell = OutputProjectionWrapper(dec_cell, self.vocab_size)

            if forward:  # test
                loop_func = decoder_fn_lib.context_decoder_fn_inference(
                    None,
                    dec_init_state,
                    embedding,
                    start_of_sequence_id=self.go_id,
                    end_of_sequence_id=self.eos_id,
                    maximum_length=self.max_utt_len,
                    num_decoder_symbols=self.vocab_size,
                    context_vector=None)  # a function
                dec_input_embedding = None
                dec_seq_lens = None
            else:  # train
                loop_func = decoder_fn_lib.context_decoder_fn_train(
                    dec_init_state, None)
                dec_input_embedding = embedding_ops.embedding_lookup(
                    embedding, self.io_tokens
                )  # x 's embedding (batch_size, utt_len, embed_size)
                dec_input_embedding = dec_input_embedding[:, 0:-1, :]  # ignore the last </s>
                dec_seq_lens = self.io_lens - 1  # input placeholder

                if config.keep_prob < 1.0:
                    dec_input_embedding = tf.nn.dropout(
                        dec_input_embedding, config.keep_prob)

                # apply word dropping. Set dropped word to 0
                if config.dec_keep_prob < 1.0:
                    keep_mask = tf.less_equal(
                        tf.random_uniform((batch_size, max_out_len - 1),
                                          minval=0.0,
                                          maxval=1.0), config.dec_keep_prob)
                    keep_mask = tf.expand_dims(tf.to_float(keep_mask), 2)
                    dec_input_embedding = dec_input_embedding * keep_mask
                    dec_input_embedding = tf.reshape(
                        dec_input_embedding,
                        [-1, max_out_len - 1, config.embed_size])

                # print("=======", dec_input_embedding) # Tensor("models/decoder/strided_slice:0", shape=(?, ?, 200), dtype=float32)

            dec_outs, _, final_context_state = dynamic_rnn_decoder(
                dec_cell,
                loop_func,
                inputs=dec_input_embedding,
                sequence_length=dec_seq_lens
            )  # dec_outs [batch_size, seq, features]

            if final_context_state is not None:
                final_context_state = final_context_state[
                    :, 0:array_ops.shape(dec_outs)[1]]
                mask = tf.to_int32(tf.sign(tf.reduce_max(
                    dec_outs, axis=2)))  # 1 where the step produced output, 0 past the end
                self.dec_out_words = tf.multiply(
                    tf.reverse(final_context_state, axis=[1]), mask)
            else:
                self.dec_out_words = tf.argmax(
                    dec_outs,
                    2)  # (batch_size, utt_len), each element is index of word

        if not forward:
            with variable_scope.variable_scope("loss", reuse=tf.AUTO_REUSE):
                labels = self.io_tokens[:, 1:]  # exclude the first token <s>; (batch_size, utt_len-1)
                label_mask = tf.to_float(tf.sign(labels))

                labels = tf.one_hot(labels,
                                    depth=self.vocab_size,
                                    dtype=tf.float32)

                print(dec_outs)
                print(labels)
                # dec_outs: Tensor(..., shape=(?, ?, 892), dtype=float32)
                # labels before one-hot: Tensor(..., shape=(?, ?), dtype=int32)
                # the sparse variant below failed with a shape mismatch
                # (logits_size=[390,892] vs labels_size=[1170,892]):
                # rc_loss = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=dec_outs, labels=labels)
                rc_loss = tf.nn.softmax_cross_entropy_with_logits(
                    logits=dec_outs, labels=labels)  # response network loss
                rc_loss = tf.reduce_sum(
                    rc_loss * label_mask,
                    reduction_indices=1)  # (batch_size,); the mask zeros out padding positions
                self.avg_rc_loss = tf.reduce_mean(rc_loss)  # scalar
                # used only for perplexity calculation, not for optimization
                self.rc_ppl = tf.exp(
                    tf.reduce_sum(rc_loss) / tf.reduce_sum(label_mask))
                """ as n-trial multimodal distribution. """
                tile_bow_logits = tf.tile(
                    tf.expand_dims(self.bow_logits, 1),
                    [1, max_out_len - 1, 1
                     ])  # (batch_size, max_out_len-1, vocab_size)
                bow_loss = tf.nn.softmax_cross_entropy_with_logits(
                    logits=tile_bow_logits, labels=labels
                ) * label_mask  # (batch_size, max_out_len-1)
                bow_loss = tf.reduce_sum(bow_loss,
                                         reduction_indices=1)  # (batch_size, )
                self.avg_bow_loss = tf.reduce_mean(bow_loss)  # scalar

                # the label y
                my_label_loss = tf.nn.softmax_cross_entropy_with_logits(
                    logits=self.my_label_logits,  # pass raw logits, not softmax output
                    labels=my_label_one_hot)  # (batch_size,)
                self.avg_my_label_loss = tf.reduce_mean(my_label_loss)
                if config.use_ot_label:
                    # note: ot_label_prob has already been softmaxed per group
                    # yet is passed as `logits` here
                    ot_label_loss = -tf.nn.softmax_cross_entropy_with_logits(
                        logits=ot_label_prob, labels=ot_label_one_hot)
                    self.avg_ot_label_loss = tf.reduce_mean(ot_label_loss)
                else:
                    self.avg_ot_label_loss = 0.0

                kld = gaussian_kld(
                    recog_mu, recog_logvar, prior_mu,
                    prior_logvar)  # kl divergence, (batch_size,)
                self.avg_kld = tf.reduce_mean(kld)  # scalar
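                # `gaussian_kld` (repo helper, not shown) computes the KL
                # divergence between two diagonal Gaussians; a minimal sketch,
                # assuming mu/logvar inputs of shape (batch_size, latent_size):
                #
                #   def gaussian_kld(recog_mu, recog_logvar, prior_mu, prior_logvar):
                #       return -0.5 * tf.reduce_sum(
                #           1 + (recog_logvar - prior_logvar)
                #           - tf.pow(prior_mu - recog_mu, 2) / tf.exp(prior_logvar)
                #           - tf.exp(recog_logvar) / tf.exp(prior_logvar),
                #           reduction_indices=1)
                #
                # KL annealing below: ramp the KL weight linearly from 0 to 1
                # over config.full_kl_step steps to mitigate posterior collapse.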
                if log_dir is not None:
                    kl_weights = tf.minimum(
                        tf.to_float(self.global_t) / config.full_kl_step, 1.0)
                else:
                    kl_weights = tf.constant(1.0)

                self.kl_w = kl_weights
                self.elbo = self.avg_rc_loss + kl_weights * self.avg_kld  # reconstruction loss plus weighted KL divergence
                # ===== total loss =====
                # both branches currently apply the same weights; the
                # commented-out factors below are per-label normalizers
                if config.use_rand_ot_label:
                    aug_elbo = self.avg_bow_loss + 1000 * self.avg_my_label_loss + 10 * self.avg_ot_label_loss + self.elbo  # augmented loss
                    # (1/self.rand_ot_label_num)*
                else:
                    aug_elbo = self.avg_bow_loss + 1000 * self.avg_my_label_loss + 10 * self.avg_ot_label_loss + self.elbo  # augmented loss
                    # (1/(self.label_vocab_size-1))*

                tf.summary.scalar("rc_loss", self.avg_rc_loss)
                tf.summary.scalar("elbo", self.elbo)
                tf.summary.scalar("kld", self.avg_kld)
                tf.summary.scalar("bow_loss", self.avg_bow_loss)
                tf.summary.scalar("my_label_loss", self.avg_my_label_loss)
                tf.summary.scalar("ot_label_loss", self.avg_ot_label_loss)

                self.summary_op = tf.summary.merge_all()

                self.log_p_z = norm_log_liklihood(latent_sample, prior_mu,
                                                  prior_logvar)  # log-density under the prior
                self.log_q_z_xy = norm_log_liklihood(
                    latent_sample, recog_mu, recog_logvar)  # log-density under the posterior
                self.est_marginal = tf.reduce_mean(rc_loss + bow_loss -
                                                   self.log_p_z +
                                                   self.log_q_z_xy)
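                # est_marginal is a single-sample importance-weighted estimate
                # of the negative marginal log-likelihood:
                # E_q[rc_loss + bow_loss - log p(z) + log q(z|x,y)].
                # `norm_log_liklihood` (repo helper, not shown; spelling as in
                # the source) presumably returns the Gaussian log-density
                # summed over latent dimensions; a sketch under that assumption:
                #
                #   def norm_log_liklihood(x, mu, logvar):
                #       return -0.5 * tf.reduce_sum(
                #           tf.log(2 * np.pi) + logvar
                #           + tf.pow(x - mu, 2) / tf.exp(logvar),
                #           reduction_indices=1)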

            self.optimize(sess, config, aug_elbo, log_dir)

        self.saver = tf.train.Saver(tf.global_variables(),
                                    write_version=tf.train.SaverDef.V2)
        print('model build finished!')
Example 3
    def __init__(self, sess, config, api, log_dir, forward, scope=None):
        self.vocab = api.vocab
        self.rev_vocab = api.rev_vocab
        self.vocab_size = len(self.vocab)
        self.topic_vocab = api.topic_vocab
        self.topic_vocab_size = len(self.topic_vocab)

        self.sess = sess
        self.scope = scope
        self.max_utt_len = config.max_utt_len
        self.go_id = self.rev_vocab["<s>"]
        self.eos_id = self.rev_vocab["</s>"]
        self.context_cell_size = config.cxt_cell_size
        self.sent_cell_size = config.sent_cell_size
        self.dec_cell_size = config.dec_cell_size
        self.bow_weights = config.bow_weights

        with tf.name_scope("io"):
            # all dialog context and known attributes
            self.input_contexts = tf.placeholder(dtype=tf.int32,
                                                 shape=(None, None,
                                                        self.max_utt_len),
                                                 name="context")
            self.context_lens = tf.placeholder(dtype=tf.int32,
                                               shape=(None, ),
                                               name="context_lens")

            # target response given the dialog context
            self.output_tokens = tf.placeholder(dtype=tf.int32,
                                                shape=(None, None),
                                                name="output_token")
            self.output_lens = tf.placeholder(dtype=tf.int32,
                                              shape=(None, ),
                                              name="output_lens")
            self.output_topics = tf.placeholder(dtype=tf.int32,
                                                shape=(None, ),
                                                name="output_topic")

            # optimization related variables
            self.learning_rate = tf.Variable(float(config.init_lr),
                                             trainable=False,
                                             name="learning_rate")
            self.learning_rate_decay_op = self.learning_rate.assign(
                tf.multiply(self.learning_rate, config.lr_decay))
            self.global_t = tf.placeholder(dtype=tf.int32, name="global_t")
            self.use_prior = tf.placeholder(dtype=tf.bool, name="use_prior")

        max_context_len = array_ops.shape(self.input_contexts)[1]
        max_out_len = array_ops.shape(self.output_tokens)[1]
        batch_size = array_ops.shape(self.input_contexts)[0]

        if config.use_hcf:
            with variable_scope.variable_scope("topicEmbedding"):
                t_embedding = tf.get_variable(
                    "embedding",
                    [self.topic_vocab_size, config.topic_embed_size],
                    dtype=tf.float32)
                topic_embedding = embedding_ops.embedding_lookup(
                    t_embedding, self.output_topics)

        with variable_scope.variable_scope("wordEmbedding"):
            self.embedding = tf.get_variable(
                "embedding", [self.vocab_size, config.embed_size],
                dtype=tf.float32)
            embedding_mask = tf.constant(
                [0 if i == 0 else 1 for i in range(self.vocab_size)],
                dtype=tf.float32,
                shape=[self.vocab_size, 1])
            embedding = self.embedding * embedding_mask

            input_embedding = embedding_ops.embedding_lookup(
                embedding, tf.reshape(self.input_contexts, [-1]))
            input_embedding = tf.reshape(
                input_embedding, [-1, self.max_utt_len, config.embed_size])
            output_embedding = embedding_ops.embedding_lookup(
                embedding, self.output_tokens)

            # context nn
            if config.sent_type == "bow":
                input_embedding, sent_size = get_bow(input_embedding)
                output_embedding, _ = get_bow(output_embedding)

            elif config.sent_type == "rnn":
                sent_cell = self.get_rnncell("gru", self.sent_cell_size,
                                             config.keep_prob, 1)
                input_embedding, sent_size = get_rnn_encode(input_embedding,
                                                            sent_cell,
                                                            scope="sent_rnn")
                output_embedding, _ = get_rnn_encode(output_embedding,
                                                     sent_cell,
                                                     self.output_lens,
                                                     scope="sent_rnn",
                                                     reuse=True)
            elif config.sent_type == "bi_rnn":
                fwd_sent_cell = self.get_rnncell("gru",
                                                 self.sent_cell_size,
                                                 keep_prob=1.0,
                                                 num_layer=1)
                bwd_sent_cell = self.get_rnncell("gru",
                                                 self.sent_cell_size,
                                                 keep_prob=1.0,
                                                 num_layer=1)
                input_embedding, sent_size = get_bi_rnn_encode(
                    input_embedding,
                    fwd_sent_cell,
                    bwd_sent_cell,
                    scope="sent_bi_rnn")
                output_embedding, _ = get_bi_rnn_encode(output_embedding,
                                                        fwd_sent_cell,
                                                        bwd_sent_cell,
                                                        self.output_lens,
                                                        scope="sent_bi_rnn",
                                                        reuse=True)
            else:
                raise ValueError(
                    "Unknown sent_type. Must be one of [bow, rnn, bi_rnn]")

            # reshape input into dialogs
            input_embedding = tf.reshape(input_embedding,
                                         [-1, max_context_len, sent_size])
            if config.keep_prob < 1.0:
                input_embedding = tf.nn.dropout(input_embedding,
                                                config.keep_prob)

        with variable_scope.variable_scope("contextRNN"):
            enc_cell = self.get_rnncell(config.cell_type,
                                        self.context_cell_size,
                                        keep_prob=1.0,
                                        num_layer=config.num_layer)
            # enc_last_state is the state at each sequence's true last step
            # (given sequence_length)
            _, enc_last_state = tf.nn.dynamic_rnn(
                enc_cell,
                input_embedding,
                dtype=tf.float32,
                sequence_length=self.context_lens)

            if config.num_layer > 1:
                enc_last_state = tf.concat(enc_last_state, 1)

        # combine with other attributes
        if config.use_hcf:
            attribute_embedding = topic_embedding
            attribute_fc1 = layers.fully_connected(attribute_embedding,
                                                   30,
                                                   activation_fn=tf.tanh,
                                                   scope="attribute_fc1")

        cond_embedding = enc_last_state

        with variable_scope.variable_scope("recognitionNetwork"):
            if config.use_hcf:
                recog_input = tf.concat(
                    [cond_embedding, output_embedding, attribute_fc1], 1)
            else:
                recog_input = tf.concat([cond_embedding, output_embedding], 1)
            self.recog_mulogvar = recog_mulogvar = layers.fully_connected(
                recog_input,
                config.latent_size * 2,
                activation_fn=None,
                scope="muvar")
            recog_mu, recog_logvar = tf.split(recog_mulogvar, 2, axis=1)

        with variable_scope.variable_scope("priorNetwork"):
            prior_fc1 = layers.fully_connected(cond_embedding,
                                               np.maximum(
                                                   config.latent_size * 2,
                                                   100),
                                               activation_fn=tf.tanh,
                                               scope="fc1")
            prior_mulogvar = layers.fully_connected(prior_fc1,
                                                    config.latent_size * 2,
                                                    activation_fn=None,
                                                    scope="muvar")
            prior_mu, prior_logvar = tf.split(prior_mulogvar, 2, axis=1)

            # use sampled Z or posterior Z
            latent_sample = tf.cond(
                self.use_prior,
                lambda: sample_gaussian(prior_mu, prior_logvar),
                lambda: sample_gaussian(recog_mu, recog_logvar))

        with variable_scope.variable_scope("generationNetwork"):
            gen_inputs = tf.concat([cond_embedding, latent_sample], 1)

            # BOW loss
            bow_fc1 = layers.fully_connected(gen_inputs,
                                             400,
                                             activation_fn=tf.tanh,
                                             scope="bow_fc1")
            if config.keep_prob < 1.0:
                bow_fc1 = tf.nn.dropout(bow_fc1, config.keep_prob)
            self.bow_logits = layers.fully_connected(bow_fc1,
                                                     self.vocab_size,
                                                     activation_fn=None,
                                                     scope="bow_project")

            # Y loss
            if config.use_hcf:
                meta_fc1 = layers.fully_connected(latent_sample,
                                                  400,
                                                  activation_fn=tf.tanh,
                                                  scope="meta_fc1")
                if config.keep_prob < 1.0:
                    meta_fc1 = tf.nn.dropout(meta_fc1, config.keep_prob)
                self.topic_logits = layers.fully_connected(
                    meta_fc1, self.topic_vocab_size, scope="topic_project")
                topic_prob = tf.nn.softmax(self.topic_logits)
                #pred_attribute_embedding = tf.matmul(topic_prob, t_embedding)
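                # hard selection below: argmax over predicted topics, then an
                # embedding lookup (the soft expectation is the commented-out
                # alternative above)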
                pred_topic = tf.argmax(topic_prob, 1)
                pred_attribute_embedding = embedding_ops.embedding_lookup(
                    t_embedding, pred_topic)
                if forward:
                    selected_attribute_embedding = pred_attribute_embedding
                else:
                    selected_attribute_embedding = attribute_embedding
                dec_inputs = tf.concat(
                    [gen_inputs, selected_attribute_embedding], 1)
            else:
                self.topic_logits = tf.zeros(
                    (batch_size, self.topic_vocab_size))
                selected_attribute_embedding = None
                dec_inputs = gen_inputs

            # Decoder
            if config.num_layer > 1:
                dec_init_state = [
                    layers.fully_connected(dec_inputs,
                                           self.dec_cell_size,
                                           activation_fn=None,
                                           scope="init_state-%d" % i)
                    for i in range(config.num_layer)
                ]
                dec_init_state = tuple(dec_init_state)
            else:
                dec_init_state = layers.fully_connected(dec_inputs,
                                                        self.dec_cell_size,
                                                        activation_fn=None,
                                                        scope="init_state")

        with variable_scope.variable_scope("decoder"):
            dec_cell = self.get_rnncell(config.cell_type, self.dec_cell_size,
                                        config.keep_prob, config.num_layer)
            dec_cell = OutputProjectionWrapper(dec_cell, self.vocab_size)

            if forward:
                loop_func = decoder_fn_lib.context_decoder_fn_inference(
                    None,
                    dec_init_state,
                    embedding,
                    start_of_sequence_id=self.go_id,
                    end_of_sequence_id=self.eos_id,
                    maximum_length=self.max_utt_len,
                    num_decoder_symbols=self.vocab_size,
                    context_vector=selected_attribute_embedding)
                dec_input_embedding = None
                dec_seq_lens = None
            else:
                loop_func = decoder_fn_lib.context_decoder_fn_train(
                    dec_init_state, selected_attribute_embedding)
                dec_input_embedding = embedding_ops.embedding_lookup(
                    embedding, self.output_tokens)
                dec_input_embedding = dec_input_embedding[:, 0:-1, :]
                dec_seq_lens = self.output_lens - 1

                if config.keep_prob < 1.0:
                    dec_input_embedding = tf.nn.dropout(
                        dec_input_embedding, config.keep_prob)

                # apply word dropping. Set dropped word to 0
                if config.dec_keep_prob < 1.0:
                    keep_mask = tf.less_equal(
                        tf.random_uniform((batch_size, max_out_len - 1),
                                          minval=0.0,
                                          maxval=1.0), config.dec_keep_prob)
                    keep_mask = tf.expand_dims(tf.to_float(keep_mask), 2)
                    dec_input_embedding = dec_input_embedding * keep_mask
                    dec_input_embedding = tf.reshape(
                        dec_input_embedding,
                        [-1, max_out_len - 1, config.embed_size])

            dec_outs, _, final_context_state = dynamic_rnn_decoder(
                dec_cell,
                loop_func,
                inputs=dec_input_embedding,
                sequence_length=dec_seq_lens)
            if final_context_state is not None:
                final_context_state = final_context_state[
                    :, 0:array_ops.shape(dec_outs)[1]]
                mask = tf.to_int32(tf.sign(tf.reduce_max(dec_outs, axis=2)))
                self.dec_out_words = tf.multiply(
                    tf.reverse(final_context_state, axis=[1]), mask)
            else:
                self.dec_out_words = tf.argmax(dec_outs, 2)

        if not forward:
            with variable_scope.variable_scope("loss"):
                labels = self.output_tokens[:, 1:]
                label_mask = tf.to_float(tf.sign(labels))

                rc_loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
                    logits=dec_outs, labels=labels)
                rc_loss = tf.reduce_sum(rc_loss * label_mask,
                                        reduction_indices=1)
                self.avg_rc_loss = tf.reduce_mean(rc_loss)
                # used only for perplexity calculation, not for optimization
                self.rc_ppl = tf.exp(
                    tf.reduce_sum(rc_loss) / tf.reduce_sum(label_mask))
                """ as n-trial multimodal distribution. """
                tile_bow_logits = tf.tile(tf.expand_dims(self.bow_logits, 1),
                                          [1, max_out_len - 1, 1])
                bow_loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
                    logits=tile_bow_logits, labels=labels) * label_mask
                bow_loss = tf.reduce_sum(bow_loss, reduction_indices=1)
                self.avg_bow_loss = tf.reduce_mean(bow_loss)
                bow_weights = tf.to_float(self.bow_weights)

                # reconstruct the meta info about X
                if config.use_hcf:
                    topic_loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
                        logits=self.topic_logits, labels=self.output_topics)
                    self.avg_topic_loss = tf.reduce_mean(topic_loss)
                else:
                    self.avg_topic_loss = 0.0

                kld = gaussian_kld(recog_mu, recog_logvar, prior_mu,
                                   prior_logvar)
                self.avg_kld = tf.reduce_mean(kld)
                if log_dir is not None:
                    kl_weights = tf.minimum(
                        tf.to_float(self.global_t) / config.full_kl_step, 1.0)
                else:
                    kl_weights = tf.constant(1.0)

                self.kl_w = kl_weights
                self.elbo = self.avg_rc_loss + kl_weights * self.avg_kld
                aug_elbo = bow_weights * self.avg_bow_loss + self.avg_topic_loss + self.elbo

                tf.summary.scalar("topic_loss", self.avg_topic_loss)
                tf.summary.scalar("rc_loss", self.avg_rc_loss)
                tf.summary.scalar("elbo", self.elbo)
                tf.summary.scalar("kld", self.avg_kld)
                tf.summary.scalar("bow_loss", self.avg_bow_loss)

                self.summary_op = tf.summary.merge_all()

                self.log_p_z = norm_log_liklihood(latent_sample, prior_mu,
                                                  prior_logvar)
                self.log_q_z_xy = norm_log_liklihood(latent_sample, recog_mu,
                                                     recog_logvar)
                self.est_marginal = tf.reduce_mean(rc_loss + bow_loss -
                                                   self.log_p_z +
                                                   self.log_q_z_xy)

            self.optimize(sess, config, aug_elbo, log_dir)

        self.saver = tf.train.Saver(tf.global_variables(),
                                    write_version=tf.train.SaverDef.V2)
Example 4
    def __init__(self, sess, config, api, log_dir, forward, scope=None):
        self.vocab = api.vocab  # index2word
        self.rev_vocab = api.rev_vocab  # word2index
        self.vocab_size = len(self.vocab)  # vocab size
        self.emotion_vocab = api.emotion_vocab  # index2emotion
        self.emotion_vocab_size = len(self.emotion_vocab)
        # self.da_vocab = api.dialog_act_vocab
        # self.da_vocab_size = len(self.da_vocab)

        self.sess = sess
        self.scope = scope
        self.max_utt_len = config.max_utt_len
        self.go_id = self.rev_vocab["<s>"]
        self.eos_id = self.rev_vocab["</s>"]

        self.context_cell_size = config.cxt_cell_size  # don't need
        self.sent_cell_size = config.sent_cell_size  # for encoding
        self.dec_cell_size = config.dec_cell_size  # for decoding

        with tf.name_scope("io"):
            self.input_contexts = tf.placeholder(dtype=tf.int32,
                                                 shape=(None,
                                                        self.max_utt_len),
                                                 name="input_contexts")
            # self.floors = tf.placeholder(dtype=tf.int32, shape=(None, None), name="floor")
            self.input_lens = tf.placeholder(dtype=tf.int32,
                                             shape=(None, ),
                                             name="input_lens")
            self.input_emotions = tf.placeholder(dtype=tf.int32,
                                                 shape=(None, ),
                                                 name="input_emotions")
            # self.my_profile = tf.placeholder(dtype=tf.float32, shape=(None, 4), name="my_profile")
            # self.ot_profile = tf.placeholder(dtype=tf.float32, shape=(None, 4), name="ot_profile")

            # target response given the dialog context
            self.output_tokens = tf.placeholder(dtype=tf.int32,
                                                shape=(None, None),
                                                name="output_token")
            self.output_lens = tf.placeholder(dtype=tf.int32,
                                              shape=(None, ),
                                              name="output_lens")
            self.output_emotions = tf.placeholder(dtype=tf.int32,
                                                  shape=(None, ),
                                                  name="output_emotions")

            # optimization related variables
            self.learning_rate = tf.Variable(float(config.init_lr),
                                             trainable=False,
                                             name="learning_rate")
            self.learning_rate_decay_op = self.learning_rate.assign(
                tf.multiply(self.learning_rate, config.lr_decay))
            self.global_t = tf.placeholder(dtype=tf.int32, name="global_t")
            self.use_prior = tf.placeholder(dtype=tf.bool, name="use_prior")

        max_dialog_len = array_ops.shape(self.input_contexts)[1]
        max_out_len = array_ops.shape(self.output_tokens)[1]
        batch_size = array_ops.shape(self.input_contexts)[0]

        with variable_scope.variable_scope("emotionEmbedding"):
            t_embedding = tf.get_variable(
                "embedding",
                [self.emotion_vocab_size, config.topic_embed_size],
                dtype=tf.float32)
            inp_emotion_embedding = embedding_ops.embedding_lookup(
                t_embedding, self.input_emotions)
            outp_emotion_embedding = embedding_ops.embedding_lookup(
                t_embedding, self.output_emotions)

        with variable_scope.variable_scope("wordEmbedding"):
            self.embedding = tf.get_variable(
                "embedding", [self.vocab_size, config.embed_size],
                dtype=tf.float32)
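            # zero out row 0 (the PAD token) so padded positions contribute a
            # zero embedding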
            embedding_mask = tf.constant(
                [0 if i == 0 else 1 for i in range(self.vocab_size)],
                dtype=tf.float32,
                shape=[self.vocab_size, 1])
            embedding = self.embedding * embedding_mask

            input_embedding = embedding_ops.embedding_lookup(
                embedding, tf.reshape(self.input_contexts, [-1]))
            input_embedding = tf.reshape(
                input_embedding, [-1, self.max_utt_len, config.embed_size])
            output_embedding = embedding_ops.embedding_lookup(
                embedding, self.output_tokens)

            if config.sent_type == "rnn":
                enc_cell = self.get_rnncell(config.cell_type,
                                            self.context_cell_size,
                                            keep_prob=1.0,
                                            num_layer=config.num_layer)
                _, enc_last_state = tf.nn.dynamic_rnn(
                    enc_cell,
                    input_embedding,
                    dtype=tf.float32,
                    sequence_length=self.input_lens)

                sent_cell = self.get_rnncell("gru", self.sent_cell_size,
                                             config.keep_prob, 1)
                # input_embedding, sent_size = get_rnn_encode(input_embedding, sent_cell, scope="sent_rnn")

                output_embedding, _ = get_rnn_encode(output_embedding,
                                                     sent_cell,
                                                     self.output_lens,
                                                     scope="sent_rnn")

            elif config.sent_type == "bi_rnn":
                fwd_enc_cell = self.get_rnncell(config.cell_type,
                                                self.context_cell_size,
                                                keep_prob=1.0,
                                                num_layer=config.num_layer)
                bwd_enc_cell = self.get_rnncell(config.cell_type,
                                                self.context_cell_size,
                                                keep_prob=1.0,
                                                num_layer=config.num_layer)
                _, enc_last_state = tf.nn.bidirectional_dynamic_rnn(
                    fwd_enc_cell,
                    bwd_enc_cell,
                    input_embedding,
                    dtype=tf.float32,
                    sequence_length=self.input_lens)
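                # sum the forward and backward final states into a single
                # context vector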
                enc_last_state = enc_last_state[0] + enc_last_state[1]

                fwd_sent_cell = self.get_rnncell("gru",
                                                 self.sent_cell_size,
                                                 keep_prob=1.0,
                                                 num_layer=1)
                bwd_sent_cell = self.get_rnncell("gru",
                                                 self.sent_cell_size,
                                                 keep_prob=1.0,
                                                 num_layer=1)
                # input_embedding, sent_size = get_bi_rnn_encode(input_embedding, fwd_sent_cell, bwd_sent_cell, scope="sent_bi_rnn")
                output_embedding, _ = get_bi_rnn_encode(output_embedding,
                                                        fwd_sent_cell,
                                                        bwd_sent_cell,
                                                        self.output_lens,
                                                        scope="sent_bi_rnn")

            else:
                raise ValueError(
                    "Unknown sent_type. Must be one of [rnn, bi_rnn]")

            # reshape input into dialogs
            # input_embedding = tf.reshape(input_embedding, [-1, max_dialog_len, sent_size])
            # if config.keep_prob < 1.0:
            #     input_embedding = tf.nn.dropout(input_embedding, config.keep_prob)

            # convert floors into 1 hot
            # floor_one_hot = tf.one_hot(tf.reshape(self.floors, [-1]), depth=2, dtype=tf.float32)
            # floor_one_hot = tf.reshape(floor_one_hot, [-1, max_dialog_len, 2])

            # joint_embedding = tf.concat([input_embedding, floor_one_hot], 2, "joint_embedding")

        with variable_scope.variable_scope("contextRNN"):

            if config.num_layer > 1:
                if config.cell_type == 'lstm':
                    enc_last_state = [temp.h for temp in enc_last_state]

                enc_last_state = tf.concat(enc_last_state, 1)
            else:
                if config.cell_type == 'lstm':
                    enc_last_state = enc_last_state.h

            attribute_fc1 = layers.fully_connected(outp_emotion_embedding,
                                                   30,
                                                   activation_fn=tf.tanh,
                                                   scope="attribute_fc1")

            cond_embedding = tf.concat([inp_emotion_embedding, enc_last_state],
                                       1)

        with variable_scope.variable_scope("recognitionNetwork"):
            recog_input = tf.concat(
                [cond_embedding, output_embedding, attribute_fc1], 1)
            self.recog_mulogvar = recog_mulogvar = layers.fully_connected(
                recog_input,
                config.latent_size * 2,
                activation_fn=None,
                scope="muvar")
            recog_mu, recog_logvar = tf.split(recog_mulogvar, 2, axis=1)

        with variable_scope.variable_scope("priorNetwork"):
            prior_fc1 = layers.fully_connected(cond_embedding,
                                               np.maximum(
                                                   config.latent_size * 2,
                                                   100),
                                               activation_fn=tf.tanh,
                                               scope="fc1")
            prior_mulogvar = layers.fully_connected(prior_fc1,
                                                    config.latent_size * 2,
                                                    activation_fn=None,
                                                    scope="muvar")
            prior_mu, prior_logvar = tf.split(prior_mulogvar, 2, axis=1)

            # use sampled Z or posterior Z
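            # (sample_gaussian is assumed to apply the reparameterization trick:
            #  z = mu + exp(0.5 * logvar) * eps with eps ~ N(0, I))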
            latent_sample = tf.cond(
                self.use_prior,
                lambda: sample_gaussian(prior_mu, prior_logvar),
                lambda: sample_gaussian(recog_mu, recog_logvar))

        with variable_scope.variable_scope("generationNetwork"):
            gen_inputs = tf.concat([cond_embedding, latent_sample], 1)

            # BOW loss
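            # the bag-of-words head predicts every response word directly from
            # [c, z]; this auxiliary loss pushes content into the latent variable
            # and helps mitigate posterior collapse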
            bow_fc1 = layers.fully_connected(gen_inputs,
                                             400,
                                             activation_fn=tf.tanh,
                                             scope="bow_fc1")
            if config.keep_prob < 1.0:
                bow_fc1 = tf.nn.dropout(bow_fc1, config.keep_prob)
            self.bow_logits = layers.fully_connected(bow_fc1,
                                                     self.vocab_size,
                                                     activation_fn=None,
                                                     scope="bow_project")

            # Y loss
            meta_fc1 = layers.fully_connected(gen_inputs,
                                              400,
                                              activation_fn=tf.tanh,
                                              scope="meta_fc1")
            if config.keep_prob < 1.0:
                meta_fc1 = tf.nn.dropout(meta_fc1, config.keep_prob)
            self.da_logits = layers.fully_connected(meta_fc1,
                                                    self.emotion_vocab_size,
                                                    activation_fn=None,
                                                    scope="da_project")
            da_prob = tf.nn.softmax(self.da_logits)
            pred_attribute_embedding = tf.matmul(da_prob, t_embedding)
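            # inference (forward=True) feeds the predicted emotion embedding to
            # the decoder; training teacher-forces the ground-truth embedding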
            if forward:
                selected_attribute_embedding = pred_attribute_embedding
            else:
                selected_attribute_embedding = outp_emotion_embedding
            dec_inputs = tf.concat([gen_inputs, selected_attribute_embedding],
                                   1)

            # Decoder
            if config.num_layer > 1:
                dec_init_state = []
                for i in range(config.num_layer):
                    temp_init = layers.fully_connected(dec_inputs,
                                                       self.dec_cell_size,
                                                       activation_fn=None,
                                                       scope="init_state-%d" %
                                                       i)
                    if config.cell_type == 'lstm':
                        temp_init = rnn_cell.LSTMStateTuple(
                            temp_init, temp_init)

                    dec_init_state.append(temp_init)

                dec_init_state = tuple(dec_init_state)
            else:
                dec_init_state = layers.fully_connected(dec_inputs,
                                                        self.dec_cell_size,
                                                        activation_fn=None,
                                                        scope="init_state")
                if config.cell_type == 'lstm':
                    dec_init_state = rnn_cell.LSTMStateTuple(
                        dec_init_state, dec_init_state)

        with variable_scope.variable_scope("decoder"):
            dec_cell = self.get_rnncell(config.cell_type, self.dec_cell_size,
                                        config.keep_prob, config.num_layer)
            dec_cell = OutputProjectionWrapper(dec_cell, self.vocab_size)

            if forward:
                loop_func = decoder_fn_lib.context_decoder_fn_inference(
                    None,
                    dec_init_state,
                    embedding,
                    start_of_sequence_id=self.go_id,
                    end_of_sequence_id=self.eos_id,
                    maximum_length=self.max_utt_len,
                    num_decoder_symbols=self.vocab_size,
                    context_vector=selected_attribute_embedding)
                dec_input_embedding = None
                dec_seq_lens = None
            else:
                loop_func = decoder_fn_lib.context_decoder_fn_train(
                    dec_init_state, selected_attribute_embedding)
                dec_input_embedding = embedding_ops.embedding_lookup(
                    embedding, self.output_tokens)
                dec_input_embedding = dec_input_embedding[:, 0:-1, :]
                dec_seq_lens = self.output_lens - 1

                if config.keep_prob < 1.0:
                    dec_input_embedding = tf.nn.dropout(
                        dec_input_embedding, config.keep_prob)

                # apply word dropout: zero entire input word embeddings with
                # probability 1 - dec_keep_prob so the decoder cannot rely on
                # teacher forcing alone
                if config.dec_keep_prob < 1.0:
                    keep_mask = tf.less_equal(
                        tf.random_uniform((batch_size, max_out_len - 1),
                                          minval=0.0,
                                          maxval=1.0), config.dec_keep_prob)
                    keep_mask = tf.expand_dims(tf.to_float(keep_mask), 2)
                    dec_input_embedding = dec_input_embedding * keep_mask
                    dec_input_embedding = tf.reshape(
                        dec_input_embedding,
                        [-1, max_out_len - 1, config.embed_size])

            dec_outs, _, final_context_state = dynamic_rnn_decoder(
                dec_cell,
                loop_func,
                inputs=dec_input_embedding,
                sequence_length=dec_seq_lens)

            if final_context_state is not None:
                final_context_state = final_context_state[
                    :, 0:array_ops.shape(dec_outs)[1]]
                mask = tf.to_int32(tf.sign(tf.reduce_max(dec_outs, axis=2)))
                self.dec_out_words = tf.multiply(
                    tf.reverse(final_context_state, axis=[1]), mask)
            else:
                self.dec_out_words = tf.argmax(dec_outs, 2)

        if not forward:
            with variable_scope.variable_scope("loss"):
                labels = self.output_tokens[:, 1:]
                label_mask = tf.to_float(tf.sign(labels))

                rc_loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
                    logits=dec_outs, labels=labels)
                rc_loss = tf.reduce_sum(rc_loss * label_mask,
                                        reduction_indices=1)
                self.avg_rc_loss = tf.reduce_mean(rc_loss)
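                # perplexity = exp(total token NLL / number of non-pad tokens);
                # reported for evaluation only, not part of the training objective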
                self.rc_ppl = tf.exp(
                    tf.reduce_sum(rc_loss) / tf.reduce_sum(label_mask))

                tile_bow_logits = tf.tile(tf.expand_dims(self.bow_logits, 1),
                                          [1, max_out_len - 1, 1])
                bow_loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
                    logits=tile_bow_logits, labels=labels) * label_mask
                bow_loss = tf.reduce_sum(bow_loss, reduction_indices=1)
                self.avg_bow_loss = tf.reduce_mean(bow_loss)

                da_loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
                    logits=self.da_logits, labels=self.output_emotions)
                self.avg_da_loss = tf.reduce_mean(da_loss)

                kld = gaussian_kld(recog_mu, recog_logvar, prior_mu,
                                   prior_logvar)
                self.avg_kld = tf.reduce_mean(kld)
                if log_dir is not None:
                    kl_weights = tf.minimum(
                        tf.to_float(self.global_t) / config.full_kl_step, 1.0)
                else:
                    kl_weights = tf.constant(1.0)

                self.kl_w = kl_weights
                self.elbo = self.avg_rc_loss + kl_weights * self.avg_kld
                aug_elbo = self.avg_bow_loss + self.avg_da_loss + self.elbo

                tf.summary.scalar("da_loss", self.avg_da_loss)
                tf.summary.scalar("rc_loss", self.avg_rc_loss)
                tf.summary.scalar("elbo", self.elbo)
                tf.summary.scalar("kld", self.avg_kld)
                tf.summary.scalar("bow_loss", self.avg_bow_loss)

                self.summary_op = tf.summary.merge_all()

                self.log_p_z = norm_log_liklihood(latent_sample, prior_mu,
                                                  prior_logvar)
                self.log_q_z_xy = norm_log_liklihood(latent_sample, recog_mu,
                                                     recog_logvar)
                self.est_marginal = tf.reduce_mean(rc_loss + bow_loss -
                                                   self.log_p_z +
                                                   self.log_q_z_xy)

            self.optimize(sess, config, aug_elbo, log_dir)

        self.saver = tf.train.Saver(tf.global_variables(),
                                    write_version=tf.train.SaverDef.V2)
Example no. 5
    def __init__(self, sess, config, api, log_dir, forward, scope=None):
        self.vocab = api.vocab
        self.rev_vocab = api.rev_vocab
        self.vocab_size = len(self.vocab)
        self.sess = sess
        self.scope = scope
        self.max_utt_len = config.max_utt_len
        self.go_id = self.rev_vocab["<s>"]
        self.eos_id = self.rev_vocab["</s>"]
        self.context_cell_size = config.cxt_cell_size
        self.sent_cell_size = config.sent_cell_size
        self.dec_cell_size = config.dec_cell_size
        self.num_topics = config.num_topics

        with tf.name_scope("io"):
            # all dialog context and known attributes
            self.input_contexts = tf.placeholder(dtype=tf.int32,
                                                 shape=(None, None,
                                                        self.max_utt_len),
                                                 name="dialog_context")
            self.floors = tf.placeholder(dtype=tf.float32,
                                         shape=(None, None),
                                         name="floor")  # TODO float
            self.floor_labels = tf.placeholder(dtype=tf.float32,
                                               shape=(None, 1),
                                               name="floor_labels")
            self.context_lens = tf.placeholder(dtype=tf.int32,
                                               shape=(None, ),
                                               name="context_lens")
            self.paragraph_topics = tf.placeholder(dtype=tf.float32,
                                                   shape=(None,
                                                          self.num_topics),
                                                   name="paragraph_topics")

            # target response given the dialog context
            self.output_tokens = tf.placeholder(dtype=tf.int32,
                                                shape=(None, None),
                                                name="output_token")
            self.output_lens = tf.placeholder(dtype=tf.int32,
                                              shape=(None, ),
                                              name="output_lens")
            self.output_das = tf.placeholder(dtype=tf.float32,
                                             shape=(None, self.num_topics),
                                             name="output_dialog_acts")

            # optimization related variables
            self.learning_rate = tf.Variable(float(config.init_lr),
                                             trainable=False,
                                             name="learning_rate")
            self.learning_rate_decay_op = self.learning_rate.assign(
                tf.multiply(self.learning_rate, config.lr_decay))
            self.global_t = tf.placeholder(dtype=tf.int32, name="global_t")
            self.use_prior = tf.placeholder(dtype=tf.bool, name="use_prior")

        max_dialog_len = array_ops.shape(self.input_contexts)[1]
        max_out_len = array_ops.shape(self.output_tokens)[1]
        batch_size = array_ops.shape(self.input_contexts)[0]

        with variable_scope.variable_scope("wordEmbedding"):
            self.embedding = tf.get_variable(
                "embedding", [self.vocab_size, config.embed_size],
                dtype=tf.float32)
            embedding_mask = tf.constant(
                [0 if i == 0 else 1 for i in range(self.vocab_size)],
                dtype=tf.float32,
                shape=[self.vocab_size, 1])
            embedding = self.embedding * embedding_mask

            # embed the input
            input_embedding = embedding_ops.embedding_lookup(
                embedding, tf.reshape(self.input_contexts, [-1]))
            # reshape; -1 lets TensorFlow infer the leading dimension from the
            # fixed utterance length and embedding size
            input_embedding = tf.reshape(
                input_embedding, [-1, self.max_utt_len, config.embed_size])
            # embed the output so you can feed it into the VAE
            output_embedding = embedding_ops.embedding_lookup(
                embedding, self.output_tokens)

            # encode the input and output utterances according to config.sent_type
            if config.sent_type == "bow":
                input_embedding, sent_size = get_bow(input_embedding)
                output_embedding, _ = get_bow(output_embedding)

            elif config.sent_type == "rnn":
                sent_cell = self.get_rnncell("gru", self.sent_cell_size,
                                             config.keep_prob, 1)
                input_embedding, sent_size = get_rnn_encode(input_embedding,
                                                            sent_cell,
                                                            scope="sent_rnn")
                output_embedding, _ = get_rnn_encode(output_embedding,
                                                     sent_cell,
                                                     self.output_lens,
                                                     scope="sent_rnn",
                                                     reuse=True)
            elif config.sent_type == "bi_rnn":
                fwd_sent_cell = self.get_rnncell("gru",
                                                 self.sent_cell_size,
                                                 keep_prob=1.0,
                                                 num_layer=1)
                bwd_sent_cell = self.get_rnncell("gru",
                                                 self.sent_cell_size,
                                                 keep_prob=1.0,
                                                 num_layer=1)
                input_embedding, sent_size = get_bi_rnn_encode(
                    input_embedding,
                    fwd_sent_cell,
                    bwd_sent_cell,
                    scope="sent_bi_rnn")
                output_embedding, _ = get_bi_rnn_encode(output_embedding,
                                                        fwd_sent_cell,
                                                        bwd_sent_cell,
                                                        self.output_lens,
                                                        scope="sent_bi_rnn",
                                                        reuse=True)
            else:
                raise ValueError(
                    "Unknown sent_type. Must be one of [bow, rnn, bi_rnn]")

            # reshape input into dialogs
            input_embedding = tf.reshape(input_embedding,
                                         [-1, max_dialog_len, sent_size])
            if config.keep_prob < 1.0:
                input_embedding = tf.nn.dropout(input_embedding,
                                                config.keep_prob)

            # reshape floors
            floor = tf.reshape(self.floors, [-1, max_dialog_len, 1])

            joint_embedding = tf.concat([input_embedding, floor], 2,
                                        "joint_embedding")

        with variable_scope.variable_scope("contextRNN"):
            enc_cell = self.get_rnncell(config.cell_type,
                                        self.context_cell_size,
                                        keep_prob=1.0,
                                        num_layer=config.num_layer)
            # with sequence_length set, enc_last_state is the state at each
            # sequence's true last step
            _, enc_last_state = tf.nn.dynamic_rnn(
                enc_cell,
                joint_embedding,
                dtype=tf.float32,
                sequence_length=self.context_lens)

            if config.num_layer > 1:
                if config.cell_type == 'lstm':
                    enc_last_state = [temp.h for temp in enc_last_state]

                enc_last_state = tf.concat(enc_last_state, 1)
            else:
                if config.cell_type == 'lstm':
                    enc_last_state = enc_last_state.h

        # combine with other attributes
        if config.use_hcf:
            # TODO is this reshape ok?
            attribute_embedding = tf.reshape(
                self.output_das, [-1, self.num_topics])  # da_embedding
            attribute_fc1 = layers.fully_connected(attribute_embedding,
                                                   30,
                                                   activation_fn=tf.tanh,
                                                   scope="attribute_fc1")

        # the condition concatenates the paragraph topic with the context-RNN
        # summary of all previous utterances
        cond_list = [self.paragraph_topics, enc_last_state]
        cond_embedding = tf.concat(cond_list, 1)  #float32

        with variable_scope.variable_scope("recognitionNetwork"):
            if config.use_hcf:
                recog_input = tf.concat(
                    [cond_embedding, output_embedding, attribute_fc1], 1)
            else:
                recog_input = tf.concat([cond_embedding, output_embedding], 1)
            self.recog_mulogvar = recog_mulogvar = layers.fully_connected(
                recog_input,
                config.latent_size * 2,
                activation_fn=None,
                scope="muvar")
            # mu and logvar are both vectors of size latent_size
            recog_mu, recog_logvar = tf.split(recog_mulogvar, 2, axis=1)

        with variable_scope.variable_scope("priorNetwork"):
            # factorization p(x, y, z) = p(x) p(z|x) p(y|x, z); the prior network models p(z|x)
            prior_fc1 = layers.fully_connected(cond_embedding,
                                               np.maximum(
                                                   config.latent_size * 2,
                                                   100),
                                               activation_fn=tf.tanh,
                                               scope="fc1")
            prior_mulogvar = layers.fully_connected(prior_fc1,
                                                    config.latent_size * 2,
                                                    activation_fn=None,
                                                    scope="muvar")
            prior_mu, prior_logvar = tf.split(prior_mulogvar, 2, axis=1)

            latent_sample = tf.cond(
                self.use_prior,
                lambda: sample_gaussian(prior_mu, prior_logvar),
                lambda: sample_gaussian(recog_mu, recog_logvar))

        with variable_scope.variable_scope("generationNetwork"):
            gen_inputs = tf.concat([cond_embedding, latent_sample],
                                   1)  #float32

            # BOW loss
            bow_fc1 = layers.fully_connected(gen_inputs,
                                             400,
                                             activation_fn=tf.tanh,
                                             scope="bow_fc1")
            if config.keep_prob < 1.0:
                bow_fc1 = tf.nn.dropout(bow_fc1, config.keep_prob)
            self.bow_logits = layers.fully_connected(bow_fc1,
                                                     self.vocab_size,
                                                     activation_fn=None,
                                                     scope="bow_project")

            # Predicting Y (topic)
            if config.use_hcf:
                meta_fc1 = layers.fully_connected(gen_inputs,
                                                  400,
                                                  activation_fn=tf.tanh,
                                                  scope="meta_fc1")
                if config.keep_prob < 1.0:
                    meta_fc1 = tf.nn.dropout(meta_fc1, config.keep_prob)
                self.da_logits = layers.fully_connected(
                    meta_fc1, self.num_topics, activation_fn=None,
                    scope="da_project")  # float32

                da_prob = tf.nn.softmax(self.da_logits)
                pred_attribute_embedding = da_prob  # TODO change the name of this to predicted sentence topic
                # pred_attribute_embedding = tf.matmul(da_prob, d_embedding)

                if forward:
                    selected_attribute_embedding = pred_attribute_embedding
                else:
                    selected_attribute_embedding = attribute_embedding
                dec_inputs = tf.concat(
                    [gen_inputs, selected_attribute_embedding], 1)

            # without use_hcf the model does not predict Y (the sentence topic)
            else:
                self.da_logits = tf.zeros((batch_size, self.num_topics))
                dec_inputs = gen_inputs
                selected_attribute_embedding = None

            # Predict whether this sentence ends the paragraph
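            # a single tanh-bounded logit per example, trained with the sigmoid
            # cross-entropy in the loss section below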
            self.paragraph_end_logits = layers.fully_connected(
                gen_inputs,
                1,
                activation_fn=tf.tanh,
                scope="paragraph_end_fc1")  # float32

            # Decoder
            if config.num_layer > 1:
                dec_init_state = []
                for i in range(config.num_layer):
                    temp_init = layers.fully_connected(dec_inputs,
                                                       self.dec_cell_size,
                                                       activation_fn=None,
                                                       scope="init_state-%d" %
                                                       i)
                    if config.cell_type == 'lstm':
                        # LSTM state is an LSTMStateTuple of (c, h)
                        temp_init = rnn_cell.LSTMStateTuple(
                            temp_init, temp_init)

                    dec_init_state.append(temp_init)

                dec_init_state = tuple(dec_init_state)
            else:
                dec_init_state = layers.fully_connected(dec_inputs,
                                                        self.dec_cell_size,
                                                        activation_fn=None,
                                                        scope="init_state")
                if config.cell_type == 'lstm':
                    dec_init_state = rnn_cell.LSTMStateTuple(
                        dec_init_state, dec_init_state)

        with variable_scope.variable_scope("decoder"):
            dec_cell = self.get_rnncell(config.cell_type, self.dec_cell_size,
                                        config.keep_prob, config.num_layer)
            # project decoder outputs to vocabulary-size logits; the softmax is
            # applied inside the loss
            dec_cell = OutputProjectionWrapper(dec_cell, self.vocab_size)

            if forward:
                loop_func = decoder_fn_lib.context_decoder_fn_inference(
                    None,
                    dec_init_state,
                    embedding,
                    start_of_sequence_id=self.go_id,
                    end_of_sequence_id=self.eos_id,
                    maximum_length=self.max_utt_len,
                    num_decoder_symbols=self.vocab_size,
                    context_vector=selected_attribute_embedding)
                dec_input_embedding = None
                dec_seq_lens = None
            else:
                loop_func = decoder_fn_lib.context_decoder_fn_train(
                    dec_init_state, selected_attribute_embedding)
                dec_input_embedding = embedding_ops.embedding_lookup(
                    embedding, self.output_tokens)
                dec_input_embedding = dec_input_embedding[:, 0:-1, :]
                dec_seq_lens = self.output_lens - 1

                if config.keep_prob < 1.0:
                    dec_input_embedding = tf.nn.dropout(
                        dec_input_embedding, config.keep_prob)

                # apply word dropout: set dropped words to 0
                if config.dec_keep_prob < 1.0:
                    # build the keep/drop mask
                    keep_mask = tf.less_equal(
                        tf.random_uniform((batch_size, max_out_len - 1),
                                          minval=0.0,
                                          maxval=1.0), config.dec_keep_prob)
                    keep_mask = tf.expand_dims(tf.to_float(keep_mask), 2)
                    dec_input_embedding = dec_input_embedding * keep_mask
                    dec_input_embedding = tf.reshape(
                        dec_input_embedding,
                        [-1, max_out_len - 1, config.embed_size])

            dec_outs, _, final_context_state = dynamic_rnn_decoder(
                dec_cell,
                loop_func,
                inputs=dec_input_embedding,
                sequence_length=dec_seq_lens,
                name='output_node')

            if final_context_state is not None:
                final_context_state = final_context_state[
                    :, 0:array_ops.shape(dec_outs)[1]]
                mask = tf.to_int32(tf.sign(tf.reduce_max(dec_outs, axis=2)))
                self.dec_out_words = tf.multiply(
                    tf.reverse(final_context_state, axis=[1]), mask)
            else:
                self.dec_out_words = tf.argmax(dec_outs, 2)

        if not forward:
            with variable_scope.variable_scope("loss"):

                labels = self.output_tokens[:, 1:]  # correct word tokens
                label_mask = tf.to_float(tf.sign(labels))

                # Loss between words
                rc_loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
                    logits=dec_outs, labels=labels)
                rc_loss = tf.reduce_sum(rc_loss * label_mask,
                                        reduction_indices=1)
                self.avg_rc_loss = tf.reduce_mean(rc_loss)
                # used only for perplexity calculation, not for optimization
                self.rc_ppl = tf.exp(
                    tf.reduce_sum(rc_loss) / tf.reduce_sum(label_mask))

                # BOW loss
                tile_bow_logits = tf.tile(tf.expand_dims(self.bow_logits, 1),
                                          [1, max_out_len - 1, 1])
                bow_loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
                    logits=tile_bow_logits, labels=labels) * label_mask
                bow_loss = tf.reduce_sum(bow_loss, reduction_indices=1)
                self.avg_bow_loss = tf.reduce_mean(bow_loss)

                # Predict 0/1 (1 = last sentence in paragraph)
                # sigmoid, not softmax: there is a single logit per example, and
                # softmax over one class would make the loss constant
                end_loss = tf.nn.sigmoid_cross_entropy_with_logits(
                    labels=self.floor_labels, logits=self.paragraph_end_logits)
                self.avg_end_loss = tf.reduce_mean(end_loss)

                # Topic prediction loss
                if config.use_hcf:
                    div_prob = tf.divide(self.da_logits, self.output_das)
                    self.avg_da_loss = tf.reduce_mean(
                        -tf.nn.softmax_cross_entropy_with_logits(
                            logits=self.da_logits, labels=div_prob))

                else:
                    self.avg_da_loss = 0.0

                kld = gaussian_kld(recog_mu, recog_logvar, prior_mu,
                                   prior_logvar)
                self.avg_kld = tf.reduce_mean(kld)
                if log_dir is not None:
                    kl_weights = tf.minimum(
                        tf.to_float(self.global_t) / config.full_kl_step, 1.0)
                else:
                    kl_weights = tf.constant(1.0)

                self.kl_w = kl_weights
                self.elbo = self.avg_rc_loss + kl_weights * self.avg_kld
                aug_elbo = self.avg_bow_loss + self.avg_da_loss + self.elbo + self.avg_end_loss

                tf.summary.scalar("da_loss", self.avg_da_loss)
                tf.summary.scalar("rc_loss", self.avg_rc_loss)
                tf.summary.scalar("elbo", self.elbo)
                tf.summary.scalar("kld", self.avg_kld)
                tf.summary.scalar("bow_loss", self.avg_bow_loss)
                tf.summary.scalar("paragraph_end_loss", self.avg_end_loss)

                self.summary_op = tf.summary.merge_all()

                self.log_p_z = norm_log_liklihood(latent_sample, prior_mu,
                                                  prior_logvar)
                self.log_q_z_xy = norm_log_liklihood(latent_sample, recog_mu,
                                                     recog_logvar)
                self.est_marginal = tf.reduce_mean(rc_loss + bow_loss -
                                                   self.log_p_z +
                                                   self.log_q_z_xy)

            self.optimize(sess, config, aug_elbo, log_dir)

        self.saver = tf.train.Saver(tf.global_variables(),
                                    write_version=tf.train.SaverDef.V2)
Example no. 6
    def __init__(self, sess, config, api, log_dir, forward, scope=None):
        # self.self_label = tf.placeholder(dtype=tf.bool,shape=(None), name="self_label")
        self.self_label = False
        self.vocab = api.vocab
        self.rev_vocab = api.rev_vocab
        self.vocab_size = len(self.vocab)
        self.sess = sess
        self.scope = scope
        self.max_utt_len = config.max_utt_len
        self.go_id = self.rev_vocab["<s>"]
        self.eos_id = self.rev_vocab["</s>"]
        self.context_cell_size = config.cxt_cell_size
        self.sent_cell_size = config.sent_cell_size
        self.dec_cell_size = config.dec_cell_size

        with tf.name_scope("io"):
            self.input_contexts = tf.placeholder(dtype=tf.int32,
                                                 shape=(None, None),
                                                 name="dialog_context")
            self.context_lens = tf.placeholder(dtype=tf.int32,
                                               shape=(None, ),
                                               name="context_lens")

            self.output_tokens = tf.placeholder(dtype=tf.int32,
                                                shape=(None, None),
                                                name="output_token")
            self.output_lens = tf.placeholder(dtype=tf.int32,
                                              shape=(None, ),
                                              name="output_lens")

            # optimization related variables
            self.learning_rate = tf.Variable(float(config.init_lr),
                                             trainable=False,
                                             name="learning_rate")
            self.learning_rate_decay_op = self.learning_rate.assign(
                tf.multiply(self.learning_rate, config.lr_decay))
            self.global_t = tf.placeholder(dtype=tf.int32, name="global_t")
            self.use_prior = tf.placeholder(dtype=tf.bool, name="use_prior")

        max_input_len = array_ops.shape(self.input_contexts)[1]
        max_out_len = array_ops.shape(self.output_tokens)[1]
        batch_size = array_ops.shape(self.input_contexts)[0]

        with variable_scope.variable_scope("wordEmbedding"):
            self.embedding = tf.get_variable(
                "embedding", [self.vocab_size, config.embed_size],
                dtype=tf.float32)
            embedding_mask = tf.constant(
                [0 if i == 0 else 1 for i in range(self.vocab_size)],
                dtype=tf.float32,
                shape=[self.vocab_size, 1])
            embedding = self.embedding * embedding_mask

            input_embedding = embedding_ops.embedding_lookup(
                embedding, self.input_contexts)

            output_embedding = embedding_ops.embedding_lookup(
                embedding, self.output_tokens)

            if config.sent_type == "rnn":
                sent_cell = self.get_rnncell("gru", self.sent_cell_size,
                                             config.keep_prob, 1)
                input_embedding, sent_size = get_rnn_encode(input_embedding,
                                                            sent_cell,
                                                            scope="sent_rnn")
                output_embedding, _ = get_rnn_encode(output_embedding,
                                                     sent_cell,
                                                     self.output_lens,
                                                     scope="sent_rnn",
                                                     reuse=True)
            elif config.sent_type == "bi_rnn":
                fwd_sent_cell = self.get_rnncell("gru",
                                                 self.sent_cell_size,
                                                 keep_prob=1.0,
                                                 num_layer=1)
                bwd_sent_cell = self.get_rnncell("gru",
                                                 self.sent_cell_size,
                                                 keep_prob=1.0,
                                                 num_layer=1)
                input_embedding, sent_size = get_bi_rnn_encode(
                    input_embedding,
                    fwd_sent_cell,
                    bwd_sent_cell,
                    self.context_lens,
                    scope="sent_bi_rnn")
                output_embedding, _ = get_bi_rnn_encode(output_embedding,
                                                        fwd_sent_cell,
                                                        bwd_sent_cell,
                                                        self.output_lens,
                                                        scope="sent_bi_rnn",
                                                        reuse=True)
            else:
                raise ValueError(
                    "Unknown sent_type. Must be one of [bow, rnn, bi_rnn]")

            # apply dropout to the encoded input
            if config.keep_prob < 1.0:
                input_embedding = tf.nn.dropout(input_embedding,
                                                config.keep_prob)

        with variable_scope.variable_scope("contextRNN"):
            enc_cell = self.get_rnncell(config.cell_type,
                                        self.context_cell_size,
                                        keep_prob=1.0,
                                        num_layer=config.num_layer)
            # with sequence_length set, enc_last_state is the state at each
            # sequence's true last step
            input_embedding = tf.expand_dims(input_embedding, axis=2)
            _, enc_last_state = tf.nn.dynamic_rnn(
                enc_cell,
                input_embedding,
                dtype=tf.float32,
                sequence_length=self.context_lens)

            if config.num_layer > 1:
                if config.cell_type == 'lstm':
                    enc_last_state = [temp.h for temp in enc_last_state]

                enc_last_state = tf.concat(enc_last_state, 1)
            else:
                if config.cell_type == 'lstm':
                    enc_last_state = enc_last_state.h

        # recognition network: map [c, x] = [enc_last_state, output_embedding] to z
        with variable_scope.variable_scope("recognitionNetwork"):
            recog_input = tf.concat([enc_last_state, output_embedding], 1)
            self.recog_mulogvar = recog_mulogvar = layers.fully_connected(
                recog_input,
                config.latent_size * 2,
                activation_fn=None,
                scope="muvar")
            recog_mu, recog_logvar = tf.split(recog_mulogvar, 2, axis=1)

        with variable_scope.variable_scope("priorNetwork"):
            # factorization p(x, y, z) = p(x) p(z|x) p(y|x, z); the prior network models p(z|x)
            prior_fc1 = layers.fully_connected(enc_last_state,
                                               np.maximum(
                                                   config.latent_size * 2,
                                                   100),
                                               activation_fn=tf.tanh,
                                               scope="fc1")
            prior_mulogvar = layers.fully_connected(prior_fc1,
                                                    config.latent_size * 2,
                                                    activation_fn=None,
                                                    scope="muvar")
            prior_mu, prior_logvar = tf.split(prior_mulogvar, 2, axis=1)

            # use sampled Z or posterior Z
            latent_sample = tf.cond(
                self.use_prior,
                lambda: sample_gaussian(prior_mu, prior_logvar),
                lambda: sample_gaussian(recog_mu, recog_logvar))

        with variable_scope.variable_scope("label_encoder"):
            le_embedding_mask = tf.constant(
                [0 if i == 0 else 1 for i in range(self.vocab_size)],
                dtype=tf.float32,
                shape=[self.vocab_size, 1])
            le_embedding = self.embedding * le_embedding_mask

            le_input_embedding = embedding_ops.embedding_lookup(
                le_embedding, self.input_contexts)

            le_output_embedding = embedding_ops.embedding_lookup(
                le_embedding, self.output_tokens)

            if config.sent_type == "rnn":
                le_sent_cell = self.get_rnncell("gru", self.sent_cell_size,
                                                config.keep_prob, 1)
                le_input_embedding, le_sent_size = get_rnn_encode(
                    le_input_embedding, le_sent_cell, scope="sent_rnn")
                le_output_embedding, _ = get_rnn_encode(le_output_embedding,
                                                        le_sent_cell,
                                                        self.output_lens,
                                                        scope="sent_rnn",
                                                        reuse=True)
            elif config.sent_type == "bi_rnn":
                le_fwd_sent_cell = self.get_rnncell("gru",
                                                    self.sent_cell_size,
                                                    keep_prob=1.0,
                                                    num_layer=1)
                le_bwd_sent_cell = self.get_rnncell("gru",
                                                    self.sent_cell_size,
                                                    keep_prob=1.0,
                                                    num_layer=1)
                le_input_embedding, le_sent_size = get_bi_rnn_encode(
                    le_input_embedding,
                    le_fwd_sent_cell,
                    le_bwd_sent_cell,
                    self.context_lens,
                    scope="sent_bi_rnn")
                le_output_embedding, _ = get_bi_rnn_encode(le_output_embedding,
                                                           le_fwd_sent_cell,
                                                           le_bwd_sent_cell,
                                                           self.output_lens,
                                                           scope="sent_bi_rnn",
                                                           reuse=True)
            else:
                raise ValueError(
                    "Unknown sent_type. Must be one of [bow, rnn, bi_rnn]")

            # apply dropout to the encoded input
            if config.keep_prob < 1.0:
                le_input_embedding = tf.nn.dropout(le_input_embedding,
                                                   config.keep_prob)

        # [le_enc_last_state, le_output_embedding]
        with variable_scope.variable_scope("lecontextRNN"):
            enc_cell = self.get_rnncell(config.cell_type,
                                        self.context_cell_size,
                                        keep_prob=1.0,
                                        num_layer=config.num_layer)
            # with sequence_length set, le_enc_last_state is the state at each
            # sequence's true last step
            le_input_embedding = tf.expand_dims(le_input_embedding, axis=2)
            _, le_enc_last_state = tf.nn.dynamic_rnn(
                enc_cell,
                le_input_embedding,
                dtype=tf.float32,
                sequence_length=self.context_lens)

            if config.num_layer > 1:
                if config.cell_type == 'lstm':
                    le_enc_last_state = [temp.h for temp in le_enc_last_state]

                le_enc_last_state = tf.concat(le_enc_last_state, 1)
            else:
                if config.cell_type == 'lstm':
                    le_enc_last_state = le_enc_last_state.h
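            # pair the label-encoder context state with the encoded response; this
            # feeds the ggammaNet that produces zlabel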
            best_en = tf.concat([le_enc_last_state, le_output_embedding], 1)

        with variable_scope.variable_scope("ggammaNet"):
            enc_cell = self.get_rnncell(config.cell_type,
                                        200,
                                        keep_prob=1.0,
                                        num_layer=config.num_layer)
            # with sequence_length set, zlabel is the state at each sequence's
            # true last step
            input_embedding = tf.expand_dims(best_en, axis=2)
            _, zlabel = tf.nn.dynamic_rnn(enc_cell,
                                          input_embedding,
                                          dtype=tf.float32,
                                          sequence_length=self.context_lens)

            if config.num_layer > 1:
                if config.cell_type == 'lstm':
                    zlabel = [temp.h for temp in zlabel]

                zlabel = tf.concat(zlabel, 1)
            else:
                if config.cell_type == 'lstm':
                    zlabel = zlabel.h

        with variable_scope.variable_scope("generationNetwork"):
            gen_inputs = tf.concat([enc_last_state, latent_sample], 1)

            dec_inputs = gen_inputs
            selected_attribute_embedding = None

            # Decoder_init_state
            if config.num_layer > 1:
                dec_init_state = []
                for i in range(config.num_layer):
                    temp_init = layers.fully_connected(dec_inputs,
                                                       self.dec_cell_size,
                                                       activation_fn=None,
                                                       scope="init_state-%d" %
                                                       i)
                    if config.cell_type == 'lstm':
                        temp_init = rnn_cell.LSTMStateTuple(
                            temp_init, temp_init)

                    dec_init_state.append(temp_init)

                dec_init_state = tuple(dec_init_state)
            else:
                dec_init_state = layers.fully_connected(dec_inputs,
                                                        self.dec_cell_size,
                                                        activation_fn=None,
                                                        scope="init_state")
                if config.cell_type == 'lstm':
                    dec_init_state = rnn_cell.LSTMStateTuple(
                        dec_init_state, dec_init_state)

        with variable_scope.variable_scope("generationNetwork1"):
            gen_inputs_sl = tf.concat([le_enc_last_state, zlabel], 1)

            dec_inputs_sl = gen_inputs_sl
            selected_attribute_embedding = None

            # Decoder_init_state
            if config.num_layer > 1:
                dec_init_state_sl = []
                for i in range(config.num_layer):
                    temp_init = layers.fully_connected(dec_inputs_sl,
                                                       self.dec_cell_size,
                                                       activation_fn=None,
                                                       scope="init_state-%d" %
                                                       i)
                    if config.cell_type == 'lstm':
                        temp_init = rnn_cell.LSTMStateTuple(
                            temp_init, temp_init)

                    dec_init_state_sl.append(temp_init)

                dec_init_state_sl = tuple(dec_init_state_sl)
            else:
                dec_init_state_sl = layers.fully_connected(dec_inputs_sl,
                                                           self.dec_cell_size,
                                                           activation_fn=None,
                                                           scope="init_state")
                if config.cell_type == 'lstm':
                    dec_init_state_sl = rnn_cell.LSTMStateTuple(
                        dec_init_state_sl, dec_init_state_sl)

        with variable_scope.variable_scope("decoder"):
            dec_cell = self.get_rnncell(config.cell_type, self.dec_cell_size,
                                        config.keep_prob, config.num_layer)
            dec_cell = OutputProjectionWrapper(dec_cell, self.vocab_size)

            if forward:
                loop_func = decoder_fn_lib.context_decoder_fn_inference(
                    None,
                    dec_init_state,
                    embedding,
                    start_of_sequence_id=self.go_id,
                    end_of_sequence_id=self.eos_id,
                    maximum_length=self.max_utt_len,
                    num_decoder_symbols=self.vocab_size,
                    context_vector=selected_attribute_embedding)
                loop_func_sl = decoder_fn_lib.context_decoder_fn_inference(
                    None,
                    dec_init_state_sl,
                    le_embedding,
                    start_of_sequence_id=self.go_id,
                    end_of_sequence_id=self.eos_id,
                    maximum_length=self.max_utt_len,
                    num_decoder_symbols=self.vocab_size,
                    context_vector=selected_attribute_embedding)

                dec_input_embedding = None
                dec_input_embedding_sl = None
                dec_seq_lens = None
            else:
                loop_func = decoder_fn_lib.context_decoder_fn_train(
                    dec_init_state, selected_attribute_embedding)
                loop_func_sl = decoder_fn_lib.context_decoder_fn_train(
                    dec_init_state_sl, selected_attribute_embedding)

                dec_input_embedding = embedding_ops.embedding_lookup(
                    embedding, self.output_tokens)
                dec_input_embedding_sl = embedding_ops.embedding_lookup(
                    le_embedding, self.output_tokens)

                dec_input_embedding = dec_input_embedding[:, 0:-1, :]
                dec_input_embedding_sl = dec_input_embedding_sl[:, 0:-1, :]

                dec_seq_lens = self.output_lens - 1

                if config.keep_prob < 1.0:
                    dec_input_embedding = tf.nn.dropout(
                        dec_input_embedding, config.keep_prob)
                    dec_input_embedding_sl = tf.nn.dropout(
                        dec_input_embedding_sl, config.keep_prob)

                # apply word dropout: set dropped words to 0
                if config.dec_keep_prob < 1.0:
                    keep_mask = tf.less_equal(
                        tf.random_uniform((batch_size, max_out_len - 1),
                                          minval=0.0,
                                          maxval=1.0), config.dec_keep_prob)
                    keep_mask = tf.expand_dims(tf.to_float(keep_mask), 2)
                    dec_input_embedding = dec_input_embedding * keep_mask
                    dec_input_embedding_sl = dec_input_embedding_sl * keep_mask
                    dec_input_embedding = tf.reshape(
                        dec_input_embedding,
                        [-1, max_out_len - 1, config.embed_size])
                    dec_input_embedding_sl = tf.reshape(
                        dec_input_embedding_sl,
                        [-1, max_out_len - 1, config.embed_size])

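            # Both calls run the same dec_cell object, so the CVAE and SL
            # paths share RNN and output-projection weights.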
            dec_outs, _, final_context_state = dynamic_rnn_decoder(
                dec_cell,
                loop_func,
                inputs=dec_input_embedding,
                sequence_length=dec_seq_lens)

            dec_outs_sl, _, final_context_state_sl = dynamic_rnn_decoder(
                dec_cell,
                loop_func_sl,
                inputs=dec_input_embedding_sl,
                sequence_length=dec_seq_lens)

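            # The inference decoder_fn in this codebase appears to stash the
            # emitted word ids (in reverse time order) in final_context_state;
            # if present, recover them and zero out steps past each sequence
            # end. Otherwise fall back to an argmax over the output logits.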
            if final_context_state is not None:
                final_context_state = final_context_state[
                    :, 0:array_ops.shape(dec_outs)[1]]
                mask = tf.to_int32(tf.sign(tf.reduce_max(dec_outs, axis=2)))
                self.dec_out_words = tf.multiply(
                    tf.reverse(final_context_state, axis=[1]), mask)
            else:
                self.dec_out_words = tf.argmax(dec_outs, 2)

            if final_context_state_sl is not None:
                final_context_state_sl = final_context_state_sl[
                    :, 0:array_ops.shape(dec_outs_sl)[1]]
                mask_sl = tf.to_int32(
                    tf.sign(tf.reduce_max(dec_outs_sl, axis=2)))
                self.dec_out_words_sl = tf.multiply(
                    tf.reverse(final_context_state_sl, axis=[1]), mask_sl)
            else:
                self.dec_out_words_sl = tf.argmax(dec_outs_sl, 2)

        if not forward:
            with variable_scope.variable_scope("loss"):

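                # targets are the decoder inputs shifted left by one step;
                # padding tokens (id 0) are masked out of the loss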
                labels = self.output_tokens[:, 1:]
                label_mask = tf.to_float(tf.sign(labels))

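                # reconstruction loss: per-token cross-entropy, summed over
                # time with padding masked, then averaged over the batch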
                rc_loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
                    logits=dec_outs, labels=labels)

                rc_loss = tf.reduce_sum(rc_loss * label_mask, axis=1)
                self.avg_rc_loss = tf.reduce_mean(rc_loss)

                sl_rc_loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
                    logits=dec_outs_sl, labels=labels)
                sl_rc_loss = tf.reduce_sum(sl_rc_loss * label_mask, axis=1)
                self.sl_rc_loss = tf.reduce_mean(sl_rc_loss)
                # used only for perplexity calculation; not used for optimization
                self.rc_ppl = tf.exp(
                    tf.reduce_sum(rc_loss) / tf.reduce_sum(label_mask))
                """ as n-trial multimodal distribution. """

                kld = gaussian_kld(recog_mu, recog_logvar, prior_mu,
                                   prior_logvar)
                self.avg_kld = tf.reduce_mean(kld)
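                # KL annealing: ramp the KL weight linearly from 0 to 1 over
                # full_kl_step training steps to mitigate posterior collapse
                # (annealing is only applied when training with a log_dir)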
                if log_dir is not None:
                    kl_weights = tf.minimum(
                        tf.to_float(self.global_t) / config.full_kl_step, 1.0)
                else:
                    kl_weights = tf.constant(1.0)

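                # auxiliary label loss: sigmoid cross-entropy pushing the
                # zlabel logits (defined earlier) to reconstruct the sampled
                # latent vector dimension-wise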
                self.label_loss = tf.reduce_mean(
                    tf.nn.sigmoid_cross_entropy_with_logits(
                        labels=latent_sample, logits=zlabel))

                self.kl_w = kl_weights
                self.elbo = self.avg_rc_loss + kl_weights * self.avg_kld

                self.cvae_loss = self.elbo + 0.1 * self.label_loss
                self.sl_loss = self.sl_rc_loss

                tf.summary.scalar("rc_loss", self.avg_rc_loss)
                tf.summary.scalar("elbo", self.elbo)
                tf.summary.scalar("kld", self.avg_kld)

                self.summary_op = tf.summary.merge_all()

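                # single-sample importance estimate of the negative marginal
                # log-likelihood: E[rc_loss - log p(z) + log q(z|x,y)]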
                self.log_p_z = norm_log_liklihood(latent_sample, prior_mu,
                                                  prior_logvar)
                self.log_q_z_xy = norm_log_liklihood(latent_sample, recog_mu,
                                                     recog_logvar)
                self.est_marginal = tf.reduce_mean(rc_loss - self.log_p_z +
                                                   self.log_q_z_xy)

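            # two separate training ops: one optimizing only the SL decoder's
            # reconstruction loss, one optimizing the full CVAE objective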
            self.train_sl_ops = self.optimize(sess,
                                              config,
                                              self.sl_loss,
                                              log_dir,
                                              scope="SL")
            self.train_ops = self.optimize(sess,
                                           config,
                                           self.cvae_loss,
                                           log_dir,
                                           scope="CVAE")

        self.saver = tf.train.Saver(tf.global_variables(),
                                    write_version=tf.train.SaverDef.V2)