Example #1
    def __init__(self, config, num_words, device='', reuse=False):
        AbstractNetwork.__init__(self, "oracle", device=device)

        with tf.variable_scope(self.scope_name, reuse=reuse):
            embeddings = []
            self.batch_size = None

            # QUESTION
            self._is_training = tf.placeholder(tf.bool, name="is_training")
            self._question = tf.placeholder(tf.int32, [self.batch_size, None],
                                            name='question')
            self._seq_length = tf.placeholder(tf.int32, [self.batch_size],
                                              name='seq_length')

            word_emb = utils.get_embedding(
                self._question,
                n_words=num_words,
                n_dim=int(config['model']['question']["embedding_dim"]),
                scope="word_embedding")

            lstm_states, _ = rnn.variable_length_LSTM(
                word_emb,
                num_hidden=int(config['model']['question']["no_LSTM_hiddens"]),
                seq_length=self._seq_length)
            embeddings.append(lstm_states)

            # CATEGORY
            if config['inputs']['category']:
                self._category = tf.placeholder(tf.int32, [self.batch_size],
                                                name='category')

                cat_emb = utils.get_embedding(
                    self._category,
                    int(config['model']['category']["n_categories"]) + 1,  # we add the unknown category
                    int(config['model']['category']["embedding_dim"]),
                    scope="cat_embedding")
                embeddings.append(cat_emb)
                print("Input: Category")

            # SPATIAL
            if config['inputs']['spatial']:
                self._spatial = tf.placeholder(tf.float32,
                                               [self.batch_size, 8],
                                               name='spatial')
                embeddings.append(self._spatial)
                print("Input: Spatial")

            # IMAGE
            if config['inputs']['image']:
                self._image = tf.placeholder(tf.float32, [self.batch_size] +
                                             config['model']['image']["dim"],
                                             name='image')

                if len(config['model']["image"]["dim"]) == 1:
                    self.image_out = self._image
                else:
                    self.image_out = attention.attention_factory(
                        self._image, lstm_states,
                        config['model']["image"]["attention"])

                embeddings.append(self.image_out)
                print("Input: Image")

            # CROP
            if config['inputs']['crop']:
                self._crop = tf.placeholder(tf.float32, [self.batch_size] +
                                            config['model']['crop']["dim"],
                                            name='crop')

                if len(config['model']["crop"]["dim"]) == 1:
                    self.crop_out = self._crop
                else:
                    self.crop_out = attention.attention_factory(
                        self._crop, lstm_states,
                        config['model']['crop']["attention"])

                embeddings.append(self.crop_out)
                print("Input: Crop")

            # Compute the final embedding
            emb = tf.concat(embeddings, axis=1)

            # OUTPUT
            num_classes = 3
            self._answer = tf.placeholder(tf.float32,
                                          [self.batch_size, num_classes],
                                          name='answer')

            with tf.variable_scope('mlp'):
                num_hiddens = config['model']['MLP']['num_hiddens']
                l1 = utils.fully_connected(emb,
                                           num_hiddens,
                                           activation='relu',
                                           scope='l1')

                self.pred = utils.fully_connected(l1,
                                                  num_classes,
                                                  activation='softmax',
                                                  scope='softmax')
                self.best_pred = tf.argmax(self.pred, axis=1)

            self.loss = tf.reduce_mean(
                utils.cross_entropy(self.pred, self._answer))
            self.error = tf.reduce_mean(utils.error(self.pred, self._answer))

            print('Model... Oracle built!')
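
The oracle constructor above only dereferences a handful of configuration keys. As a quick reference, here is a minimal, hypothetical config dictionary that satisfies every lookup made in this __init__; the concrete values are illustrative, not taken from the original project:

# Hypothetical config covering the keys read by the oracle __init__ above.
oracle_config = {
    'inputs': {
        'category': True,   # enables the category-embedding branch
        'spatial': True,    # enables the 8-dim spatial-feature branch
        'image': False,     # image/crop branches are optional
        'crop': False,
    },
    'model': {
        'question': {'embedding_dim': 300, 'no_LSTM_hiddens': 1024},
        'category': {'n_categories': 90, 'embedding_dim': 256},
        'image': {'dim': [1000], 'attention': {}},  # 1-D features skip the attention module
        'crop': {'dim': [1000], 'attention': {}},
        'MLP': {'num_hiddens': 512},
    },
}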
Example #2
    def __init__(self, config, num_words, policy_gradient, device='', reuse=False):
        AbstractNetwork.__init__(self, "qgen", device=device)

        # Create the scope for this graph
        with tf.variable_scope(self.scope_name, reuse=reuse):

            mini_batch_size = None

            # Picture
            self.images = tf.placeholder(tf.float32, [mini_batch_size] + config['image']["dim"], name='images')

            # Question
            self.dialogues = tf.placeholder(tf.int32, [mini_batch_size, None], name='dialogues')
            self.answer_mask = tf.placeholder(tf.float32, [mini_batch_size, None], name='answer_mask')  # 1 if kept; the pattern (1 q/a 1) corresponds to (START q/a STOP)
            self.padding_mask = tf.placeholder(tf.float32, [mini_batch_size, None], name='padding_mask')
            self.seq_length = tf.placeholder(tf.int32, [mini_batch_size], name='seq_length')

            # Rewards
            self.cum_rewards = tf.placeholder(tf.float32, shape=[mini_batch_size, None], name='cum_reward')

            # DECODER Hidden state (for beam search)
            zero_state = tf.zeros([1, config['num_lstm_units']])  # default LSTM state is a zero-vector
            zero_state = tf.tile(zero_state, [tf.shape(self.images)[0], 1])  # trick to build a zero state with a dynamic batch size

            self.decoder_zero_state_c = tf.placeholder_with_default(zero_state, [mini_batch_size, config['num_lstm_units']], name="state_c")
            self.decoder_zero_state_h = tf.placeholder_with_default(zero_state, [mini_batch_size, config['num_lstm_units']], name="state_h")
            decoder_initial_state = tf.contrib.rnn.LSTMStateTuple(c=self.decoder_zero_state_c, h=self.decoder_zero_state_h)

            # Misc
            self.is_training = tf.placeholder(tf.bool, name='is_training')
            self.greedy = tf.placeholder_with_default(False, shape=(), name="greedy") # use for graph
            self.samples = None

            # remove last token
            input_dialogues = self.dialogues[:, :-1]
            input_seq_length = self.seq_length - 1

            # remove first token (= start token)
            rewards = self.cum_rewards[:, 1:]
            target_words = self.dialogues[:, 1:]

            # to understand the padding:
            # input
            #   <start>  is   it   a    blue   <?>   <yes>   is   it  a    car  <?>   <no>   <stop_dialogue>
            # target
            #    is      it   a   blue   <?>    -      is    it   a   car  <?>   -   <stop_dialogue>  -



            # image processing
            if len(config["image"]["dim"]) == 1:
                self.image_out = self.images
            else:
                self.image_out = get_attention(self.images, None, "none") #TODO: improve by using the previous lstm state?


            # Reduce the embedding size of the image
            with tf.variable_scope('picture_embedding'):
                self.picture_emb = utils.fully_connected(self.image_out,
                                                    config['picture_embedding_size'])
                picture_emb = tf.expand_dims(self.picture_emb, 1)
                picture_emb = tf.tile(picture_emb, [1, tf.shape(input_dialogues)[1], 1])

            # Compute the question embedding
            input_words = utils.get_embedding(
                input_dialogues,
                n_words=num_words,
                n_dim=config['word_embedding_size'],
                scope="word_embedding")

            # concat word embedding and picture embedding
            decoder_input = tf.concat([input_words, picture_emb], axis=2, name="concat_full_embedding")


            # encode one word+picture
            decoder_lstm_cell = tf.contrib.rnn.LayerNormBasicLSTMCell(
                    config['num_lstm_units'],
                    layer_norm=False,
                    dropout_keep_prob=1.0,
                    reuse=reuse)


            self.decoder_output, self.decoder_state = tf.nn.dynamic_rnn(
                cell=decoder_lstm_cell,
                inputs=decoder_input,
                dtype=tf.float32,
                initial_state=decoder_initial_state,
                sequence_length=input_seq_length,
                scope="word_decoder")  # TODO: use multi-layer RNN

            max_sequence = tf.reduce_max(self.seq_length)

            # compute the softmax for evaluation
            with tf.variable_scope('decoder_output'):
                flat_decoder_output = tf.reshape(self.decoder_output, [-1, decoder_lstm_cell.output_size])
                flat_mlp_output = utils.fully_connected(flat_decoder_output, num_words)

                # retrieve the batch/dialogue format
                mlp_output = tf.reshape(flat_mlp_output, [tf.shape(self.seq_length)[0], max_sequence - 1, num_words])  # Ignore the STOP token

                self.softmax_output = tf.nn.softmax(mlp_output, name="softmax")
                self.argmax_output = tf.argmax(mlp_output, axis=2)

                self.cross_entropy_loss = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=mlp_output, labels=target_words)

            # compute the maximum likelihood loss
            with tf.variable_scope('ml_loss'):

                ml_loss = tf.identity(self.cross_entropy_loss)
                ml_loss *= self.answer_mask[:, 1:]  # remove answers (ignore the <stop> token)
                ml_loss *= self.padding_mask[:, 1:]  # remove padding (ignore the <start> token)

                # Count number of unmask elements
                count = tf.reduce_sum(self.padding_mask) - tf.reduce_sum(1 - self.answer_mask[:, :-1]) - 1  # no_unpad - no_qa - START token

                ml_loss = tf.reduce_sum(ml_loss, axis=1)  # reduce over dialogue dimension
                ml_loss = tf.reduce_sum(ml_loss, axis=0)  # reduce over minibatch dimension
                self.ml_loss = ml_loss / count  # Normalize

                self.loss = self.ml_loss

            # Compute policy gradient
            if policy_gradient:

                with tf.variable_scope('rl_baseline'):
                    decoder_out = tf.stop_gradient(self.decoder_output)  # take the LSTM output (and stop the gradient!)

                    flat_decoder_output = tf.reshape(decoder_out, [-1, decoder_lstm_cell.output_size])  #
                    flat_h1 = utils.fully_connected(flat_decoder_output, n_out=100, activation='relu', scope='baseline_hidden')
                    flat_baseline = utils.fully_connected(flat_h1, 1, activation='relu', scope='baseline_out')

                    self.baseline = tf.reshape(flat_baseline, [tf.shape(self.seq_length)[0], max_sequence-1])
                    self.baseline *= self.answer_mask[:, 1:]
                    self.baseline *= self.padding_mask[:, 1:]


                with tf.variable_scope('policy_gradient_loss'):

                    # Compute log_prob
                    self.log_of_policy = tf.identity(self.cross_entropy_loss)
                    self.log_of_policy *= self.answer_mask[:, 1:]  # remove answers (<=> predicted answer has maximum reward) (ignore the START token in the mask)
                    # No need to use padding mask as the discounted_reward is already zero once the episode terminated

                    # Policy gradient loss
                    rewards *= self.answer_mask[:, 1:]
                    self.score_function = tf.multiply(self.log_of_policy, rewards - self.baseline)  # score function

                    self.baseline_loss = tf.reduce_sum(tf.square(rewards - self.baseline))

                    self.policy_gradient_loss = tf.reduce_sum(self.score_function, axis=1)  # sum over the dialogue trajectory
                    self.policy_gradient_loss = tf.reduce_mean(self.policy_gradient_loss, axis=0)  # reduce over minibatch dimension

                    self.loss = self.policy_gradient_loss
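
For reference, this question-generator constructor reads only four top-level keys from its config. A minimal, hypothetical config covering those lookups (values are illustrative only):

# Hypothetical config covering the keys read by the qgen __init__ above.
qgen_config = {
    'image': {'dim': [1000]},        # 1-D image features, so the attention branch is skipped
    'num_lstm_units': 1024,          # decoder LSTM size (also sizes the beam-search state placeholders)
    'picture_embedding_size': 512,   # projection applied to the image before tiling it over time steps
    'word_embedding_size': 300,
}
# Passing policy_gradient=True additionally builds the baseline and the
# REINFORCE-style loss defined at the end of the constructor.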
Example #3
    def __init__(self, config, num_words, device='', reuse=False):
        AbstractNetwork.__init__(self, "guesser", device=device)

        mini_batch_size = None

        with tf.variable_scope(self.scope_name, reuse=reuse):

            # Dialogues
            self.dialogues = tf.placeholder(tf.int32, [mini_batch_size, None],
                                            name='dialogues')
            self.seq_length = tf.placeholder(tf.int32, [mini_batch_size],
                                             name='seq_length')

            # Objects
            self.obj_mask = tf.placeholder(tf.float32, [mini_batch_size, None],
                                           name='obj_mask')
            self.obj_cats = tf.placeholder(tf.int32, [mini_batch_size, None],
                                           name='obj_cats')
            self.obj_spats = tf.placeholder(
                tf.float32, [mini_batch_size, None, config['spat_dim']],
                name='obj_spats')

            # Targets
            self.targets = tf.placeholder(tf.int32, [mini_batch_size],
                                          name="targets_index")

            self.object_cats_emb = utils.get_embedding(
                self.obj_cats,
                config['no_categories'] + 1,
                config['cat_emb_dim'],
                scope='cat_embedding')

            self.objects_input = tf.concat(
                [self.object_cats_emb, self.obj_spats], axis=2)
            self.flat_objects_inp = tf.reshape(
                self.objects_input,
                [-1, config['cat_emb_dim'] + config['spat_dim']])

            with tf.variable_scope('obj_mlp'):
                h1 = utils.fully_connected(self.flat_objects_inp,
                                           n_out=config['obj_mlp_units'],
                                           activation='relu',
                                           scope='l1')
                h2 = utils.fully_connected(h1,
                                           n_out=config['dialog_emb_dim'],
                                           activation='relu',
                                           scope='l2')

            obj_embs = tf.reshape(
                h2,
                [-1, tf.shape(self.obj_cats)[1], config['dialog_emb_dim']])

            # Compute the word embedding
            input_words = utils.get_embedding(self.dialogues,
                                              n_words=num_words,
                                              n_dim=config['word_emb_dim'],
                                              scope="input_word_embedding")

            last_states, _ = rnn.variable_length_LSTM(
                input_words,
                num_hidden=config['num_lstm_units'],
                seq_length=self.seq_length)

            last_states = tf.reshape(last_states,
                                     [-1, config['num_lstm_units'], 1])
            scores = tf.matmul(obj_embs, last_states)
            scores = tf.reshape(scores, [-1, tf.shape(self.obj_cats)[1]])

            def masked_softmax(scores, mask):
                # subtract max for stability
                scores = scores - tf.tile(
                    tf.reduce_max(scores, axis=(1, ), keep_dims=True),
                    [1, tf.shape(scores)[1]])
                # compute padded softmax
                exp_scores = tf.exp(scores)
                exp_scores *= mask
                exp_sum_scores = tf.reduce_sum(exp_scores,
                                               axis=1,
                                               keep_dims=True)
                return exp_scores / tf.tile(exp_sum_scores,
                                            [1, tf.shape(exp_scores)[1]])

            self.softmax = masked_softmax(scores, self.obj_mask)
            self.selected_object = tf.argmax(self.softmax, axis=1)

            self.loss = tf.reduce_mean(
                utils.cross_entropy(self.softmax, self.targets))
            self.error = tf.reduce_mean(utils.error(self.softmax,
                                                    self.targets))
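
The masked_softmax helper is the key trick of this guesser: padded object slots must receive exactly zero probability. A small NumPy sketch of the same computation, with hypothetical scores and mask, makes the behaviour explicit:

import numpy as np

def masked_softmax_np(scores, mask):
    # subtract the row-wise max for numerical stability, as in the graph above
    scores = scores - scores.max(axis=1, keepdims=True)
    exp_scores = np.exp(scores) * mask  # padded objects contribute nothing
    return exp_scores / exp_scores.sum(axis=1, keepdims=True)

scores = np.array([[2.0, 1.0, 0.5, -1.0]])  # dot products between object and dialogue embeddings
mask = np.array([[1.0, 1.0, 1.0, 0.0]])     # the fourth slot is padding
print(masked_softmax_np(scores, mask))      # ~[[0.63, 0.23, 0.14, 0.0]]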
Example #4
    def __init__(self, config, num_words, device='', reuse=False):
        AbstractNetwork.__init__(self, "guesser", device=device)

        batch_size = None

        with tf.variable_scope(self.scope_name, reuse=reuse):

            self._is_training = tf.placeholder(tf.bool, name="is_training")

            dropout_keep_scalar = float(config["dropout_keep_prob"])
            dropout_keep = tf.cond(self._is_training,
                                   lambda: tf.constant(dropout_keep_scalar),
                                   lambda: tf.constant(1.0))

            #####################
            #   DIALOGUE
            #####################

            self._dialogue = tf.placeholder(tf.int32, [batch_size, None],
                                            name='dialogue')
            self._seq_length = tf.placeholder(tf.int32, [batch_size],
                                              name='seq_length_dialogue')

            word_emb = tfc_layers.embed_sequence(
                ids=self._dialogue,
                vocab_size=num_words,
                embed_dim=config["question"]["word_embedding_dim"],
                scope="word_embedding",
                reuse=reuse)

            if config["question"]['glove']:
                self._glove = tf.placeholder(tf.float32, [None, None, 300],
                                             name="glove")
                word_emb = tf.concat([word_emb, self._glove], axis=2)

            _, self.dialogue_embedding = rnn.rnn_factory(
                inputs=word_emb,
                seq_length=self._seq_length,
                cell=config['question']["cell"],
                num_hidden=config['question']["rnn_units"],
                bidirectional=config["question"]["bidirectional"],
                max_pool=config["question"]["max_pool"],
                layer_norm=config["question"]["layer_norm"],
                reuse=reuse)

            #####################
            #   IMAGE
            #####################

            self.img_embedding = None
            if config['inputs']['image']:

                self._image = tf.placeholder(tf.float32, [batch_size] +
                                             config['image']["dim"],
                                             name='image')

                # get image
                self.img_embedding = get_image_features(
                    image=self._image,
                    is_training=self._is_training,
                    config=config['image'])

                # pool image feature if needed
                if len(self.img_embedding.get_shape()) > 2:
                    with tf.variable_scope("image_pooling"):
                        self.img_embedding = get_attention(
                            self.img_embedding,
                            self.dialogue_embedding,
                            is_training=self._is_training,
                            config=config["pooling"],
                            dropout_keep=dropout_keep,
                            reuse=reuse)

                # fuse vision/language
                self.visdiag_embedding = get_fusion_mechanism(
                    input1=self.dialogue_embedding,
                    input2=self.img_embedding,
                    config=config["fusion"],
                    dropout_keep=dropout_keep)
            else:
                self.visdiag_embedding = self.dialogue_embedding

            visdiag_dim = int(self.visdiag_embedding.get_shape()[-1])

            #####################
            #   OBJECTS
            #####################

            self._num_object = tf.placeholder(tf.int32, [batch_size],
                                              name='obj_seq_length')
            self._obj_cats = tf.placeholder(tf.int32, [batch_size, None],
                                            name='obj_cat')
            self._obj_spats = tf.placeholder(tf.float32, [batch_size, None, 8],
                                             name='obj_spat')

            cats_emb = tfc_layers.embed_sequence(
                ids=self._obj_cats,
                vocab_size=config['category']["n_categories"] + 1,  # we add the unknown category
                embed_dim=config['category']["embedding_dim"],
                scope="cat_embedding",
                reuse=reuse)
            '''
            spatial_emb = tfc_layers.fully_connected(self._obj_spats,
                                                     num_outputs=config["spatial"]["no_mlp_units"],
                                                     activation_fn=tf.nn.relu,
                                                     reuse=reuse,
                                                     scope="spatial_upsampling")
            '''
            spatial_emb = self._obj_spats

            self.objects_input = tf.concat([cats_emb, spatial_emb], axis=2)
            # self.objects_input = tf.nn.dropout(self.objects_input, dropout_keep)

            with tf.variable_scope('obj_mlp'):
                h1 = tfc_layers.fully_connected(
                    self.objects_input,
                    num_outputs=config["object"]['no_mlp_units'],
                    activation_fn=tf.nn.relu,
                    scope='l1')
                # h1 = tf.nn.dropout(h1, dropout_keep)

                obj_embeddings = tfc_layers.fully_connected(
                    h1,
                    num_outputs=visdiag_dim,
                    activation_fn=tf.nn.relu,
                    scope='l2')

            #####################
            #   SCORES
            #####################

            self.scores = obj_embeddings * tf.expand_dims(
                self.visdiag_embedding, axis=1)
            self.scores = tf.reduce_sum(self.scores, axis=2)

            # remove max for stability (trick)
            self.scores -= tf.reduce_max(self.scores, axis=1, keep_dims=True)

            with tf.variable_scope('object_mask', reuse=reuse):

                object_mask = tf.sequence_mask(self._num_object)
                score_mask_values = float("-inf") * tf.ones_like(self.scores)

                self.score_masked = tf.where(object_mask, self.scores,
                                             score_mask_values)

            #####################
            #   LOSS
            #####################

            # Targets
            self._targets = tf.placeholder(tf.int32, [batch_size],
                                           name="target_index")

            self.loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
                labels=self._targets, logits=self.score_masked)
            self.loss = tf.reduce_mean(self.loss)

            self.selected_object = tf.argmax(self.score_masked, axis=1)
            self.softmax = tf.nn.softmax(self.score_masked)

            with tf.variable_scope('accuracy'):
                self.accuracy = tf.equal(self.selected_object,
                                         tf.cast(self._targets, tf.int64))
                self.accuracy = tf.reduce_mean(
                    tf.cast(self.accuracy, tf.float32))
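
This variant replaces the multiplicative mask of the previous guesser with an additive one: scores of padded objects are set to -inf before the softmax, which again forces their probability to exactly zero. A small NumPy sketch with hypothetical scores (two real objects out of three):

import numpy as np

num_object = np.array([2])                 # only the first two objects are real
scores = np.array([[1.5, 0.2, 0.9]])
object_mask = np.arange(scores.shape[1]) < num_object[:, None]  # same idea as tf.sequence_mask
score_masked = np.where(object_mask, scores, -np.inf)

probs = np.exp(score_masked - score_masked.max(axis=1, keepdims=True))
probs /= probs.sum(axis=1, keepdims=True)
print(probs)  # ~[[0.79, 0.21, 0.0]] -- the padded object can never be selected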
Example #5
    def __init__(self, config, num_words, device='', reuse=False):
        AbstractNetwork.__init__(self, "guesser", device=device)

        mini_batch_size = None

        with tf.variable_scope(self.scope_name, reuse=reuse):
            # Misc
            self._is_training = tf.placeholder(tf.bool, name='is_training')
            self._is_dynamic = tf.placeholder(tf.bool, name='is_dynamic')
            batch_size = None

            dropout_keep_scalar = float(config.get("dropout_keep_prob", 1.0))
            dropout_keep = tf.cond(self._is_training,
                                   lambda: tf.constant(dropout_keep_scalar),
                                   lambda: tf.constant(1.0))
            # Objects
            self._num_object = tf.placeholder(tf.int32, [mini_batch_size],
                                              name='obj_seq_length')
            self.obj_mask = tf.sequence_mask(self._num_object,
                                             dtype=tf.float32)
            # self.obj_mask = tf.sequence_mask(self._num_object, maxlen=20, dtype=tf.float32)
            # self.obj_mask = tf.placeholder(tf.float32, [mini_batch_size, None], name='obj_mask')
            self.obj_cats = tf.placeholder(tf.int32, [mini_batch_size, None],
                                           name='obj_cat')
            self.obj_spats = tf.placeholder(
                tf.float32, [mini_batch_size, None, config['spat_dim']],
                name='obj_spat')

            # Targets
            self.targets = tf.placeholder(tf.int32, [mini_batch_size],
                                          name="target_index")

            self.object_cats_emb = utils_v1.get_embedding(
                self.obj_cats,
                config['no_categories'] + 1,
                config['cat_emb_dim'],
                scope='cat_embedding')

            self.objects_input = tf.concat(
                [self.object_cats_emb, self.obj_spats], axis=2)
            self.flat_objects_inp = tf.reshape(
                self.objects_input,
                [-1, config['cat_emb_dim'] + config['spat_dim']])

            with tf.variable_scope('obj_mlp'):
                h1 = utils_v1.fully_connected(self.flat_objects_inp,
                                              n_out=config['obj_mlp_units'],
                                              activation='relu',
                                              scope='l1')
                h2 = utils_v1.fully_connected(h1,
                                              n_out=config['dialog_emb_dim'],
                                              activation='relu',
                                              scope='l2')

            obj_embs = tf.reshape(
                h2,
                [-1, tf.shape(self.obj_cats)[1], config['dialog_emb_dim']])

            #####################
            #   UAQRAH PART
            #####################

            #####################
            #   WORD EMBEDDING
            #####################

            with tf.variable_scope('word_embedding', reuse=reuse):
                self.dialogue_emb_weights = tf.get_variable(
                    "dialogue_embedding_encoder",
                    shape=[
                        num_words, config["dialogue"]["word_embedding_dim"]
                    ],
                    initializer=tf.random_uniform_initializer(-0.08, 0.08))

            #####################
            #   DIALOGUE
            #####################

            self._q_his = tf.placeholder(tf.int32, [batch_size, None, None],
                                         name='q_his')
            # self._q_his_mask = tf.placeholder(tf.float32, [batch_size, None, None], name='q_his_mask')
            self._a_his = tf.placeholder(tf.int32, [batch_size, None, None],
                                         name='a_his')
            self._q_his_lengths = tf.placeholder(tf.int32, [batch_size, None],
                                                 name='q_his_lengths')
            self._q_turn = tf.placeholder(tf.int32, [batch_size],
                                          name='q_turn')
            self._max_turn = tf.placeholder(tf.int32, None, name='max_turn')

            bs = tf.shape(self._q_his)[0]

            self.rnn_cell_word = rnn.create_cell(
                config['dialogue']["rnn_word_units"],
                layer_norm=config["dialogue"]["layer_norm"],
                reuse=tf.AUTO_REUSE,
                cell=config['dialogue']["cell"],
                scope="forward_word")

            self.rnn_cell_context = rnn.create_cell(
                config['dialogue']["rnn_context_units"],
                layer_norm=config["dialogue"]["layer_norm"],
                reuse=tf.AUTO_REUSE,
                cell=config['dialogue']["cell"],
                scope="forward_context")
            self.rnn_cell_pair = rnn.create_cell(
                config['dialogue']["rnn_context_units"],
                layer_norm=config["dialogue"]["layer_norm"],
                reuse=tf.AUTO_REUSE,
                cell=config['dialogue']["cell"],
                scope="forward_pair")

            # ini
            dialogue_last_states_ini = self.rnn_cell_context.zero_state(
                bs, dtype=tf.float32)
            pair_states_ini = self.rnn_cell_pair.zero_state(bs,
                                                            dtype=tf.float32)

            self.max_turn_loop = tf.cond(
                self._is_dynamic, lambda: self._max_turn,
                lambda: tf.constant(5, dtype=tf.int32))
            self.turn_0 = tf.constant(0, dtype=tf.int32)
            self.m_num = tf.Variable(1.)
            self.a_yes_token = tf.constant(0, dtype=tf.int32)
            self.a_no_token = tf.constant(1, dtype=tf.int32)
            self.a_na_token = tf.constant(2, dtype=tf.int32)

            #####################
            #   IMAGES
            #####################

            self._image = tf.placeholder(tf.float32,
                                         [batch_size] + config['image']["dim"],
                                         name='image')  # B, 36, 2048
            if config['image'].get('normalize', False):
                self.img_feature = tf.nn.l2_normalize(self._image,
                                                      dim=-1,
                                                      name="img_normalization")
            else:
                self.img_feature = self._image
            self.vis_dif = OD_compute(self.img_feature)

            self.att_ini = tf.ones([bs, 36])

            def compute_v_dif(att):
                max_obj = tf.argmax(att, axis=1)
                obj_oneh = tf.one_hot(max_obj, depth=36)  # 64,36
                vis_dif_select = tf.boolean_mask(
                    self.vis_dif, obj_oneh)  # 64,36,36*2048 to 64,36*2048
                vis_dif_select = tf.reshape(vis_dif_select, [bs, 36, 2048])
                vis_dif_weighted = tf.reduce_sum(vis_dif_select *
                                                 tf.expand_dims(att, -1),
                                                 axis=1)
                vis_dif_weighted = tf.nn.l2_normalize(
                    vis_dif_weighted, dim=-1, name="v_dif_normalization")
                return vis_dif_weighted

            turn_i = tf.constant(0, dtype=tf.int32)

            def uaqra_att_ini(vis_fea, q_fea, dialogue_state, config,
                              is_training):
                with tf.variable_scope("q_guide_image_pooling"):
                    att_q = compute_current_att(vis_fea,
                                                q_fea,
                                                config,
                                                is_training,
                                                reuse=reuse)
                    att_q = tf.nn.softmax(att_q, axis=-1)

                with tf.variable_scope("h_guide_image_pooling"):
                    att_h = compute_current_att(self.img_feature,
                                                dialogue_state,
                                                config,
                                                is_training,
                                                reuse=reuse)
                    att_h = tf.nn.softmax(att_h, axis=-1)
                att = att_q + att_h
                return att

            def uaqra_att(vis_fea, q_fea, dialogue_state, att_prev, config,
                          answer, m_num, is_training):
                with tf.variable_scope("q_guide_image_pooling"):
                    att_q = compute_current_att(vis_fea,
                                                q_fea,
                                                config,
                                                is_training,
                                                reuse=True)
                f_att_q = cond_gumbel_softmax(is_training, att_q)
                a_list = tf.reshape(answer, shape=[-1, 1])  # is it ok?
                a_list = a_list - 5
                att_na = att_prev
                att_yes = f_att_q * att_prev
                att_no = (1 - f_att_q) * att_prev
                # tf.equal gives the intended element-wise comparison (a Python `==` on tensors does not)
                att_select_na = tf.where(tf.equal(a_list, self.a_na_token), att_na, att_no)
                att_select_yes = tf.where(tf.equal(a_list, self.a_yes_token), att_yes, att_select_na)
                att_select = att_select_yes
                att_norm = tf.nn.l2_normalize(att_select,
                                              dim=-1,
                                              name="att_normalization")
                att_enlarged = att_norm * m_num
                att_mask = tf.greater(att_enlarged, 0.)
                att_new = maskedSoftmax(att_enlarged, att_mask)
                with tf.variable_scope("h_guide_image_pooling"):
                    att_h = compute_current_att(self.img_feature,
                                                dialogue_state,
                                                config,
                                                is_training,
                                                reuse=True)
                    att_h = tf.nn.softmax(att_h, axis=-1)
                att = att_new + att_h

                return att

            att_list_ini = tf.expand_dims(self.att_ini, 0)
            hpair_list_ini = tf.expand_dims(pair_states_ini, 0)

            def cond_loop(cur_turn, att_prev, dialogue_state, att_list,
                          hpair_list):
                return tf.less(cur_turn, self.max_turn_loop)

            def dialog_flow(cur_turn, att_prev, dialogue_state, att_list,
                            hpair_list):

                #####################
                #   ENCODE CUR_TURN
                #####################

                self._question = self._q_his[:, cur_turn, :]
                self._answer = self._a_his[:, cur_turn]
                self._seq_length_question = self._q_his_lengths[:, cur_turn]

                self.word_emb_question = tf.nn.embedding_lookup(
                    params=self.dialogue_emb_weights, ids=self._question)
                self.word_emb_question = tf.nn.dropout(self.word_emb_question,
                                                       dropout_keep)

                self.word_emb_answer = tf.nn.embedding_lookup(
                    params=self.dialogue_emb_weights, ids=self._answer)
                self.word_emb_answer = tf.nn.dropout(self.word_emb_answer,
                                                     dropout_keep)
                ''' Update the dialog state '''
                seq_mask = tf.cast(tf.greater(self._seq_length_question, 1),
                                   self._seq_length_question.dtype)

                self.outputs_wq, self.h_wq = tf.nn.dynamic_rnn(
                    cell=self.rnn_cell_word,
                    inputs=self.word_emb_question,
                    dtype=tf.float32,
                    sequence_length=self._seq_length_question - 1,
                    scope="forward_word")
                _, self.h_wa = tf.nn.dynamic_rnn(cell=self.rnn_cell_word,
                                                 inputs=self.word_emb_answer,
                                                 dtype=tf.float32,
                                                 sequence_length=seq_mask,
                                                 scope="forward_word")
                _, self.h_c1 = tf.nn.dynamic_rnn(cell=self.rnn_cell_context,
                                                 inputs=tf.expand_dims(
                                                     self.h_wq, 1),
                                                 initial_state=dialogue_state,
                                                 dtype=tf.float32,
                                                 sequence_length=seq_mask,
                                                 scope="forward_context")
                _, dialogue_state = tf.nn.dynamic_rnn(
                    cell=self.rnn_cell_context,
                    inputs=tf.expand_dims(self.h_wa, 1),
                    initial_state=self.h_c1,
                    dtype=tf.float32,
                    sequence_length=seq_mask,
                    scope="forward_context")
                _, self.h_pq = tf.nn.dynamic_rnn(cell=self.rnn_cell_pair,
                                                 inputs=tf.expand_dims(
                                                     self.h_wq, 1),
                                                 dtype=tf.float32,
                                                 scope="forward_pair")
                _, h_pair = tf.nn.dynamic_rnn(cell=self.rnn_cell_pair,
                                              inputs=tf.expand_dims(
                                                  self.h_wa, 1),
                                              initial_state=self.h_pq,
                                              dtype=tf.float32,
                                              scope="forward_pair")

                q_att = compute_q_att(self.outputs_wq, dropout_keep)

                att = tf.cond(
                    tf.equal(cur_turn, self.turn_0),
                    lambda: uaqra_att_ini(self.img_feature, q_att, dialogue_state,
                                          config["pooling"], self._is_training),
                    lambda: uaqra_att(self.img_feature, q_att, dialogue_state,
                                      att_prev, config["pooling"], self._answer,
                                      self.m_num, self._is_training))

                att_list = tf.cond(
                    tf.equal(cur_turn, self.turn_0),
                    lambda: tf.expand_dims(att, 0),
                    lambda: tf.concat([att_list, tf.expand_dims(att, 0)], 0))
                hpair_list = tf.cond(
                    tf.equal(cur_turn, self.turn_0),
                    lambda: tf.expand_dims(h_pair, 0),
                    lambda: tf.concat([hpair_list, tf.expand_dims(h_pair, 0)], 0))

                cur_turn = tf.add(cur_turn, 1)

                return cur_turn, att, dialogue_state, att_list, hpair_list

            _, _, self.dialogue_last_states, self.att_list, self.hpair_list = tf.while_loop(
                cond_loop,
                dialog_flow, [
                    turn_i, self.att_ini, dialogue_last_states_ini,
                    att_list_ini, hpair_list_ini
                ],
                shape_invariants=[
                    turn_i.get_shape(),
                    self.att_ini.get_shape(),
                    dialogue_last_states_ini.get_shape(),
                    tf.TensorShape([None, None, 36]),
                    tf.TensorShape([None, None, 1200])
                ])
            att_list = tf.transpose(self.att_list, perm=[1, 0,
                                                         2])  # 64,max_turn,36
            att_oneh = tf.one_hot(self._q_turn - 1,
                                  depth=self.max_turn_loop)  # 64,max_turn
            self.att = tf.boolean_mask(att_list,
                                       att_oneh)  # 64,max_turn,36 to 64,36
            hpair_list = tf.transpose(self.hpair_list,
                                      [1, 0, 2])  # 64,max_turn,36
            self.h_pair = tf.boolean_mask(hpair_list,
                                          att_oneh)  # 64,max_turn,36 to 64,36
            visual_features = tf.reduce_sum(self.img_feature *
                                            tf.expand_dims(self.att, -1),
                                            axis=1)
            visual_features = tf.nn.l2_normalize(visual_features,
                                                 dim=-1,
                                                 name="v_fea_normalization")
            vis_dif_weighted = compute_v_dif(self.att)

            with tf.variable_scope("compute_beta"):
                self.h_pair = tf.nn.dropout(self.h_pair, dropout_keep)  # taking the tanh activation into account
                beta = tfc_layers.fully_connected(self.h_pair,
                                                  num_outputs=2,
                                                  activation_fn=tf.nn.softmax,
                                                  reuse=reuse,
                                                  scope="beta_computation")

            beta_0 = tf.tile(tf.expand_dims(tf.gather(beta, 0, axis=1), 1),
                             [1, 2048])
            beta_1 = tf.tile(tf.expand_dims(tf.gather(beta, 1, axis=1), 1),
                             [1, 2048])
            self.v_final = beta_0 * visual_features + beta_1 * vis_dif_weighted

            with tf.variable_scope("multimodal_fusion"):
                # concat
                self.visdiag_embedding = tfc_layers.fully_connected(
                    tf.concat([self.dialogue_last_states, self.v_final],
                              axis=-1),
                    num_outputs=config['fusion']['projection_size'],
                    activation_fn=tf.nn.tanh,
                    reuse=reuse,
                    scope="visdiag_projection")

            scores = tf.matmul(obj_embs,
                               tf.expand_dims(self.visdiag_embedding, axis=-1))
            scores = tf.reshape(scores, [-1, tf.shape(self.obj_cats)[1]])

            def masked_softmax(scores, mask):
                # subtract max for stability
                scores = scores - tf.tile(
                    tf.reduce_max(scores, axis=(1, ), keep_dims=True),
                    [1, tf.shape(scores)[1]])
                # compute padded softmax
                exp_scores = tf.exp(scores)
                exp_scores *= mask
                exp_sum_scores = tf.reduce_sum(exp_scores,
                                               axis=1,
                                               keep_dims=True)
                return exp_scores / tf.tile(exp_sum_scores,
                                            [1, tf.shape(exp_scores)[1]])

            self.softmax = masked_softmax(scores, self.obj_mask)
            self.selected_object = tf.argmax(self.softmax, axis=1)

            self.loss = tf.reduce_mean(
                utils_v1.cross_entropy(self.softmax, self.targets))
            self.error = tf.reduce_mean(
                utils_v1.error(self.softmax, self.targets))
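
The heart of this guesser is the answer-conditioned attention update inside uaqra_att: a "yes" answer keeps the regions the current question attended to, a "no" answer suppresses them, and an "n/a" answer leaves the map unchanged. A simplified NumPy sketch of that update rule follows; the values are hypothetical, and the Gumbel-softmax sampling and masked renormalisation of the real graph are omitted:

import numpy as np

def update_attention(att_prev, f_att_q, answer, yes_token=0, no_token=1):
    # simplified answer-conditioned update, mirroring uaqra_att
    if answer == yes_token:
        att = f_att_q * att_prev          # keep what the question attended to
    elif answer == no_token:
        att = (1.0 - f_att_q) * att_prev  # remove what the question attended to
    else:                                 # "n/a": keep the previous map
        att = att_prev
    return att / (np.linalg.norm(att) + 1e-8)  # l2-normalise, as in the graph

att_prev = np.array([0.5, 0.3, 0.2])
f_att_q = np.array([0.9, 0.1, 0.0])  # the question was mostly about region 0
print(update_attention(att_prev, f_att_q, answer=1))  # "no": attention mass moves away from region 0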
Example #6
    def __init__(self,
                 config,
                 num_words,
                 rl_module,
                 device='',
                 reuse=tf.AUTO_REUSE):
        AbstractNetwork.__init__(self, "qgen", device=device)

        self.rl_module = rl_module

        # Create the scope for this graph
        with tf.variable_scope(self.scope_name, reuse=reuse):

            # Misc
            self._is_training = tf.placeholder(tf.bool, name='is_training')
            self._is_dynamic = tf.placeholder(tf.bool, name='is_dynamic')
            batch_size = None

            dropout_keep_scalar = float(config.get("dropout_keep_prob", 1.0))
            dropout_keep = tf.cond(self._is_training,
                                   lambda: tf.constant(dropout_keep_scalar),
                                   lambda: tf.constant(1.0))
            # dropout_keep = tf.constant(1.0)

            #####################
            #   WORD EMBEDDING
            #####################

            with tf.variable_scope('word_embedding', reuse=reuse):
                self.dialogue_emb_weights = tf.get_variable(
                    "dialogue_embedding_encoder",
                    shape=[
                        num_words, config["dialogue"]["word_embedding_dim"]
                    ])

            #####################
            #   DIALOGUE
            #####################

            if self.rl_module is not None:
                self._q_flag = tf.placeholder(tf.float32,
                                              shape=[None, batch_size],
                                              name='q_flag')
                self._cum_rewards = tf.placeholder(
                    tf.float32,
                    shape=[batch_size, None, None],
                    name='cum_reward')
                # self._skewness = tf.placeholder(tf.float32, shape=[None, batch_size], name='skewness')
                # self._softmax = tf.placeholder(tf.float32, shape=[None, batch_size, None], name='softmax')
                # self._guess_softmax = tf.transpose(self._softmax, [1, 0, 2])  # batch, turn, o_n
            self._q_his = tf.placeholder(tf.int32, [batch_size, None, None],
                                         name='q_his')
            # self._q_his_mask = tf.placeholder(tf.float32, [batch_size, None, None], name='q_his_mask')
            self._a_his = tf.placeholder(tf.int32, [batch_size, None, None],
                                         name='a_his')
            self._q_his_lengths = tf.placeholder(tf.int32, [batch_size, None],
                                                 name='q_his_lengths')
            self._q_turn = tf.placeholder(tf.int32, None, name='q_turn')

            bs = tf.shape(self._q_his)[0]

            self.rnn_cell_word = rnn.create_cell(
                config['dialogue']["rnn_word_units"],
                layer_norm=config["dialogue"]["layer_norm"],
                reuse=tf.AUTO_REUSE,
                cell=config['dialogue']["cell"],
                scope="forward_word")
            # self.word_init_states = rnn_cell_word.zero_state(batch_size, dtype=tf.float32)

            self.rnn_cell_context = rnn.create_cell(
                config['dialogue']["rnn_context_units"],
                layer_norm=config["dialogue"]["layer_norm"],
                reuse=tf.AUTO_REUSE,
                cell=config['dialogue']["cell"],
                scope="forward_context")
            self.rnn_cell_pair = rnn.create_cell(
                config['dialogue']["rnn_context_units"],
                layer_norm=config["dialogue"]["layer_norm"],
                reuse=tf.AUTO_REUSE,
                cell=config['dialogue']["cell"],
                scope="forward_pair")

            # ini
            dialogue_last_states_ini = self.rnn_cell_context.zero_state(
                bs, dtype=tf.float32)
            pair_states_ini = self.rnn_cell_pair.zero_state(bs,
                                                            dtype=tf.float32)
            q_ini = tf.nn.embedding_lookup(params=self.dialogue_emb_weights,
                                           ids=tf.fill([bs], num_words))
            outputs_w_ini, _ = tf.nn.dynamic_rnn(cell=self.rnn_cell_word,
                                                 inputs=tf.expand_dims(q_ini,
                                                                       axis=1),
                                                 dtype=tf.float32,
                                                 scope="forward_word")
            qatt_ini = compute_q_att(outputs_w_ini, dropout_keep)
            answer_ini = self._a_his[:, 0]

            assert config['decoder']["cell"] != "lstm", "LSTMs are not yet supported for the decoder"
            self.decoder_cell = rnn.create_cell(
                cell=config['decoder']["cell"],
                num_units=config['fusion']['projection_size'],
                layer_norm=config["decoder"]["layer_norm"],
                reuse=reuse)

            self.decoder_projection_layer = tf.layers.Dense(num_words - 1)

            loss_ini = 0.

            self.q_turn_loop = tf.cond(self._is_dynamic, lambda: self._q_turn,
                                       lambda: tf.constant(5, dtype=tf.int32))
            self.turn_0 = tf.constant(0, dtype=tf.int32)
            self.m_num = tf.Variable(1.)
            self.a_yes_token = tf.constant(0, dtype=tf.int32)
            self.a_no_token = tf.constant(1, dtype=tf.int32)
            self.a_na_token = tf.constant(2, dtype=tf.int32)

            #####################
            #   IMAGES
            #####################

            self._image = tf.placeholder(tf.float32,
                                         [batch_size] + config['image']["dim"],
                                         name='image')  # B, 36, 2048
            if config['image'].get('normalize', False):
                self.img_feature = tf.nn.l2_normalize(self._image,
                                                      dim=-1,
                                                      name="fc_normalization")
            else:
                self.img_feature = self._image  # keep the raw features when normalization is disabled
            self.vis_dif = OD_compute(self._image)

            # Pool Image Features
            with tf.variable_scope("q_guide_image_pooling"):
                att_ini_q = compute_current_att(self.img_feature,
                                                qatt_ini,
                                                config["pooling"],
                                                self._is_training,
                                                reuse=reuse)
                att_ini_q = tf.nn.softmax(att_ini_q, axis=-1)

            with tf.variable_scope("h_guide_image_pooling"):
                att_ini_h = compute_current_att(self.img_feature,
                                                dialogue_last_states_ini,
                                                config["pooling"],
                                                self._is_training,
                                                reuse=reuse)
                att_ini_h = tf.nn.softmax(att_ini_h, axis=-1)

            self.att_ini = att_ini_h + att_ini_q
            visual_features_ini = tf.reduce_sum(
                self.img_feature * tf.expand_dims(self.att_ini, -1), axis=1)
            visual_features_ini = tf.nn.l2_normalize(
                visual_features_ini, dim=-1, name="v_fea_normalization")

            def compute_v_dif(att):
                max_obj = tf.argmax(att, axis=1)
                obj_oneh = tf.one_hot(max_obj, depth=36)  # 64,36
                vis_dif_select = tf.boolean_mask(
                    self.vis_dif, obj_oneh)  # 64,36,36*2048 to 64,36*2048
                vis_dif_select = tf.reshape(vis_dif_select, [bs, 36, 2048])
                vis_dif_weighted = tf.reduce_sum(vis_dif_select *
                                                 tf.expand_dims(att, -1),
                                                 axis=1)
                vis_dif_weighted = tf.nn.l2_normalize(
                    vis_dif_weighted, dim=-1, name="v_dif_normalization")
                return vis_dif_weighted

            vis_dif_weighted_ini = compute_v_dif(self.att_ini)

            with tf.variable_scope("compute_beta"):
                pair_states_ini = tf.nn.dropout(pair_states_ini, dropout_keep)
                beta_ini = tfc_layers.fully_connected(
                    pair_states_ini,
                    num_outputs=2,
                    activation_fn=tf.nn.softmax,
                    reuse=reuse,
                    scope="beta_computation")
            beta_0_ini = tf.tile(
                tf.expand_dims(tf.gather(beta_ini, 0, axis=1), 1), [1, 2048])
            beta_1_ini = tf.tile(
                tf.expand_dims(tf.gather(beta_ini, 1, axis=1), 1), [1, 2048])

            v_final_ini = beta_0_ini * visual_features_ini + beta_1_ini * vis_dif_weighted_ini

            turn_i = 0
            totals_ini = 0.

            with tf.variable_scope("multimodal_fusion"):

                visdiag_embedding_ini = tfc_layers.fully_connected(
                    tf.concat([dialogue_last_states_ini, v_final_ini],
                              axis=-1),
                    num_outputs=config['fusion']['projection_size'],
                    activation_fn=tf.nn.tanh,
                    reuse=reuse,
                    scope="visdiag_projection")

            def uaqra_att(vis_fea, q_fea, dialogue_state, h_pair, att_prev,
                          config, answer, m_num, is_training):
                with tf.variable_scope("q_guide_image_pooling"):
                    att_q = compute_current_att(vis_fea,
                                                q_fea,
                                                config,
                                                is_training,
                                                reuse=True)
                f_att_q = cond_gumbel_softmax(is_training, att_q)
                a_list = tf.reshape(answer, shape=[-1, 1])  # is it ok?
                a_list = a_list - 5
                att_na = att_prev
                att_yes = f_att_q * att_prev
                att_no = (1 - f_att_q) * att_prev
                # tf.equal gives the intended element-wise comparison (a Python `==` on tensors does not)
                att_select_na = tf.where(tf.equal(a_list, self.a_na_token), att_na, att_no)
                att_select_yes = tf.where(tf.equal(a_list, self.a_yes_token), att_yes, att_select_na)
                att_select = att_select_yes
                att_norm = tf.nn.l2_normalize(att_select,
                                              dim=-1,
                                              name="att_normalization")
                att_enlarged = att_norm * m_num
                att_mask = tf.greater(att_enlarged, 0.)
                att_new = maskedSoftmax(att_enlarged, att_mask)
                with tf.variable_scope("h_guide_image_pooling"):
                    att_h = compute_current_att(self.img_feature,
                                                dialogue_state,
                                                config,
                                                is_training,
                                                reuse=True)
                    att_h = tf.nn.softmax(att_h, axis=-1)
                att = att_new + att_h

                visual_features = tf.reduce_sum(self.img_feature *
                                                tf.expand_dims(att, -1),
                                                axis=1)
                visual_features = tf.nn.l2_normalize(
                    visual_features, dim=-1, name="v_fea_normalization")
                vis_dif_weighted = compute_v_dif(att)

                with tf.variable_scope("compute_beta"):
                    h_pair = tf.nn.dropout(h_pair, dropout_keep)  # taking the tanh activation into account
                    beta = tfc_layers.fully_connected(
                        h_pair,
                        num_outputs=2,
                        activation_fn=tf.nn.softmax,
                        reuse=True,
                        scope="beta_computation")

                beta_0 = tf.tile(tf.expand_dims(tf.gather(beta, 0, axis=1), 1),
                                 [1, 2048])
                beta_1 = tf.tile(tf.expand_dims(tf.gather(beta, 1, axis=1), 1),
                                 [1, 2048])
                v_final = beta_0 * visual_features + beta_1 * vis_dif_weighted

                return att, v_final, beta

            def cond_loop(cur_turn, loss, totals, q_att, h_pair, answer,
                          att_prev, dialogue_state, visual_features,
                          visdiag_embedding, beta):
                # print_info = tf.Print(cur_turn, [cur_turn], "x:")
                # cur_turn = cur_turn + print_info
                return tf.less(cur_turn, self.q_turn_loop)

            def dialog_flow(cur_turn, loss, totals, q_att, h_pair, answer,
                            att_prev, dialogue_state, visual_features,
                            visdiag_embedding, beta):

                att, v_final, beta = tf.cond(
                    tf.equal(cur_turn, self.turn_0),
                    lambda: (self.att_ini, visual_features, beta_ini),
                    lambda: uaqra_att(self.img_feature, q_att, dialogue_state,
                                      h_pair, att_prev, config["pooling"],
                                      answer, self.m_num, self._is_training))
                repeat_v_final = tf.tile(tf.expand_dims(v_final, axis=1),
                                         [1, 12, 1])

                with tf.variable_scope("multimodal_fusion"):
                    # concat
                    visdiag_embedding = tfc_layers.fully_connected(
                        tf.concat([dialogue_state, v_final], axis=-1),
                        num_outputs=config['fusion']['projection_size'],
                        activation_fn=tf.nn.tanh,
                        reuse=True,
                        scope="visdiag_projection")

                #####################
                #   TARGET QUESTION
                #####################

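                # Slice out the current turn's question and answer and build the
                # target-token mask.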
                self._question = self._q_his[:, cur_turn, :]
                self._answer = self._a_his[:, cur_turn]
                self._seq_length_question = self._q_his_lengths[:, cur_turn]
                self._mask = tf.sequence_mask(
                    lengths=self._seq_length_question - 1,
                    maxlen=12,
                    dtype=tf.float32)

                self.word_emb_question = tf.nn.embedding_lookup(
                    params=self.dialogue_emb_weights, ids=self._question)
                self.word_emb_question = tf.nn.dropout(self.word_emb_question,
                                                       dropout_keep)

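                # Teacher forcing: drop the last token for the decoder input and
                # the first token for the encoder inputs / loss targets.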
                self.word_emb_question_input = self.word_emb_question[:, :-1, :]
                self.word_emb_question_encode = self.word_emb_question[:, 1:, :]

                self.word_emb_answer = tf.nn.embedding_lookup(
                    params=self.dialogue_emb_weights, ids=self._answer)
                self.word_emb_answer = tf.nn.dropout(self.word_emb_answer,
                                                     dropout_keep)

                #####################
                #   DECODER
                #####################

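                # Decode the target question: each step sees the previous word
                # embedding concatenated with the visual features, starting from
                # the fused visdiag embedding.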
                self.decoder_states, _ = tf.nn.dynamic_rnn(
                    cell=self.decoder_cell,
                    inputs=tf.concat(
                        [self.word_emb_question_input, repeat_v_final],
                        axis=-1),
                    dtype=tf.float32,
                    initial_state=visdiag_embedding,
                    sequence_length=self._seq_length_question - 1,
                    scope="decoder")

                self.decoder_outputs = self.decoder_projection_layer(
                    self.decoder_states)

                #####################
                #   LOSS
                #####################

                # compute the softmax for evaluation
                ''' Compute policy gradient '''
                if self.rl_module is not None:

                    # Step 1: compute the state-value function
                    self._cum_rewards_current = self._cum_rewards[:, cur_turn,
                                                                  1:]
                    self._cum_rewards_current *= self._mask

                    # q cost
                    # self._q_flag_cur = self._q_flag[cur_turn, :]

                    # guess softmax
                    # self._guess_cur = self._guess_softmax[:, cur_turn+1, :]
                    # self._guess_pre = self._guess_softmax[:, cur_turn, :]
                    # skewness
                    # self._skewness_cur = self._skewness[cur_turn, :]
                    # Step 2: compute the state-value function
                    value_state = self.decoder_states
                    if self.rl_module.stop_gradient:
                        value_state = tf.stop_gradient(self.decoder_states)
                    # value_state = tf.nn.dropout(value_state, dropout_keep)
                    v_num_hidden_units = int(value_state.get_shape()[-1]) // 4

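                    # Value head: a small MLP that predicts a per-timestep
                    # baseline for the policy-gradient update.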
                    with tf.variable_scope('value_function'):
                        self.value_function = tf.keras.models.Sequential()
                        # self.value_function.add(tf.keras.layers.Dropout(rate=1 - dropout_keep))
                        self.value_function.add(
                            tf.keras.layers.Dense(
                                units=v_num_hidden_units,
                                activation=tf.nn.relu,
                                input_shape=(int(
                                    value_state.get_shape()[-1]), ),
                                name="value_function_hidden"))
                        # self.value_function.add(tf.keras.layers.Dropout(rate=1 - dropout_keep))
                        self.value_function.add(
                            tf.keras.layers.Dense(units=1,
                                                  activation=None,
                                                  name="value_function"))
                        self.value_function.add(
                            tf.keras.layers.Reshape((-1, )))

                    # Step 3: compute the RL loss (reinforce, A3C, PPO etc.)
                    loss_i = rl_module(
                        cum_rewards=self._cum_rewards_current,
                        value_function=self.value_function(value_state),
                        policy_state=self.decoder_outputs,
                        actions=self._question[:, 1:],
                        action_mask=self._mask)

                    # q_flag=self._q_flag_cur,
                    # pre_g=self._guess_pre,
                    # cur_g=self._guess_cur,
                    # skew=self._skewness_cur)

                    loss += loss_i
                    totals = 1.0

                else:
                    '''supervised loss'''
                    with tf.variable_scope('ml_loss'):
                        cr_loss = tfc_seq.sequence_loss(
                            logits=self.decoder_outputs,
                            targets=self._question[:, 1:],
                            weights=self._mask,
                            average_across_timesteps=False,
                            average_across_batch=False)

                        # cr_loss = tf.nn.softmax_cross_entropy_with_logits(logits=self.decoder_outputs,
                        #                                                          labels=self._question[:, 1:])
                        ml_loss = tf.identity(cr_loss)
                        ml_loss *= self._mask
                        ml_loss = tf.reduce_sum(
                            ml_loss, axis=1)  # reduce over question dimension
                        ml_loss = tf.reduce_sum(
                            ml_loss, axis=0)  # reduce over minibatch dimension
                        count = tf.reduce_sum(self._mask)
                        self.softmax_output = tf.nn.softmax(
                            self.decoder_outputs, name="softmax")
                        self.argmax_output = tf.argmax(self.decoder_outputs,
                                                       axis=2)
                        loss += ml_loss
                        totals += count
                ''' Update the dialog state '''
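                # Hierarchical dialogue-state update: word-level RNNs encode the
                # question and the answer, a context RNN folds both encodings
                # into the dialogue state, and a pair RNN summarizes the QA pair.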
                self.outputs_wq, self.h_wq = tf.nn.dynamic_rnn(
                    cell=self.rnn_cell_word,
                    inputs=self.word_emb_question_encode,
                    dtype=tf.float32,
                    sequence_length=self._seq_length_question - 1,
                    scope="forward_word")
                _, self.h_wa = tf.nn.dynamic_rnn(cell=self.rnn_cell_word,
                                                 inputs=self.word_emb_answer,
                                                 dtype=tf.float32,
                                                 scope="forward_word")
                _, self.h_c1 = tf.nn.dynamic_rnn(cell=self.rnn_cell_context,
                                                 inputs=tf.expand_dims(
                                                     self.h_wq, 1),
                                                 initial_state=dialogue_state,
                                                 dtype=tf.float32,
                                                 scope="forward_context")
                _, dialogue_state = tf.nn.dynamic_rnn(
                    cell=self.rnn_cell_context,
                    inputs=tf.expand_dims(self.h_wa, 1),
                    initial_state=self.h_c1,
                    dtype=tf.float32,
                    scope="forward_context")
                _, self.h_pq = tf.nn.dynamic_rnn(cell=self.rnn_cell_pair,
                                                 inputs=tf.expand_dims(
                                                     self.h_wq, 1),
                                                 dtype=tf.float32,
                                                 scope="forward_pair")
                _, h_pair = tf.nn.dynamic_rnn(cell=self.rnn_cell_pair,
                                              inputs=tf.expand_dims(
                                                  self.h_wa, 1),
                                              initial_state=self.h_pq,
                                              dtype=tf.float32,
                                              scope="forward_pair")

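                # Refresh the question self-attention from the word-level encoder
                # outputs for the next turn.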
                q_att = compute_q_att(self.outputs_wq, dropout_keep)

                cur_turn = tf.add(cur_turn, 1)

                return (cur_turn, loss, totals, q_att, h_pair, self._answer,
                        att, dialogue_state, v_final, visdiag_embedding, beta)

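            # Unroll the dialogue with tf.while_loop, threading the loss,
            # attention, and dialogue state across turns.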
            (_, loss_f, totals_f, _, _, _, self.att, self.dialogue_last_states,
             self.visual_features, self.visdiag_embedding,
             self.beta) = tf.while_loop(
                cond_loop, dialog_flow, [
                    turn_i, loss_ini, totals_ini, qatt_ini, pair_states_ini,
                    answer_ini, self.att_ini, dialogue_last_states_ini,
                    visual_features_ini, visdiag_embedding_ini, beta_ini
                ])
            self.loss = loss_f / totals_f
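            # Minimal usage sketch (an assumption, not part of the original
            # graph): the normalized loss would typically be minimized with a
            # standard TF1 optimizer, e.g.
            #   optimizer = tf.train.AdamOptimizer(learning_rate)  # hypothetical name
            #   train_op = optimizer.minimize(self.loss)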