def build_attribute_blank_fill(self):
        """
        Build the attribute blank-fill head and register its loss/metrics.

        Each attribute "blank" token sequence is encoded with a GRU. The
        result is fused with the attribute-pooled visual feature by
        element-wise product, and the fused feature is classified over
        ``self.num_answer`` answers.

        Reads:
            self.mid_result['attribute_pooled_V_ft']
            self.batch['attr_blank_fill/blanks'], ['attr_blank_fill/blanks_len'],
                ['attr_blank_fill/fills'], ['attr_blank_fill/num']
        Writes:
            self.mid_result['attr_blank_fill/logit']
            self.losses['attr_blank_fill']
            self.report['attr_blank_fill_loss'], ['attr_blank_fill_acc'],
                ['attr_blank_fill_top_{TOP_K}_acc']
        """
        # NOTE(review): presumably [bs, n_attr_bf, feat_dim] so that it lines
        # up with blank_ft below -- confirm against the producer.
        pooled_V_ft = self.mid_result['attribute_pooled_V_ft']

        # Token ids -> word embeddings; the last two axes are (len, W_DIM).
        blank_embed = tf.nn.embedding_lookup(
            self.l_word_map, self.batch['attr_blank_fill/blanks'])
        blank_len = self.batch['attr_blank_fill/blanks_len']
        blank_maxlen = tf.shape(blank_embed)[-2]
        # Fold all leading axes into one batch axis so every blank is encoded
        # as an independent sequence, then restore [bs, n_attr_bf, L_DIM].
        flat_blank_ft = modules.encode_L(
            tf.reshape(blank_embed, [-1, blank_maxlen, W_DIM]),
            tf.reshape(blank_len, [-1]), L_DIM,
            scope='encode_L_blank', cell_type='GRU')
        blank_ft = tf.reshape(
            flat_blank_ft, [-1, self.data_cfg.n_attr_bf, L_DIM])

        v_linear_l = modules.fc_layer(
            pooled_V_ft, L_DIM, use_bias=True, use_bn=False, use_ln=True,
            activation_fn=tf.nn.relu, is_training=self.is_train,
            scope='pooled_linear_l')

        l_linear_l = modules.fc_layer(
            blank_ft, L_DIM, use_bias=True, use_bn=False, use_ln=True,
            activation_fn=tf.nn.relu, is_training=self.is_train,
            scope='q_linear_l')

        # Vision/language fusion by element-wise (Hadamard) product.
        joint = modules.fc_layer(
            v_linear_l * l_linear_l, L_DIM * 2,
            use_bias=True, use_bn=False, use_ln=True,
            activation_fn=tf.nn.relu, is_training=self.is_train, scope='joint_fc')
        # Dropout must only be active while training. The previous
        # tf.nn.dropout(joint, 0.5) (keep_prob=0.5) also dropped units at
        # evaluation time, which made predictions stochastic. rate=0.5 keeps
        # the training-time behaviour, and `training` accepts a Python bool or
        # a boolean tensor.
        joint = tf.layers.dropout(joint, rate=0.5, training=self.is_train)

        logit = modules.fc_layer(
            joint, self.num_answer,
            use_bias=True, use_bn=False, use_ln=False,
            activation_fn=None, is_training=self.is_train, scope='classifier')
        self.mid_result['attr_blank_fill/logit'] = logit  # [bs, #attr, #answer]

        with tf.name_scope('loss/attr_blank_fill'):
            onehot_gt = tf.one_hot(self.batch['attr_blank_fill/fills'],
                                   depth=self.num_answer)
            # Only the first `num` blanks of each example are real entries;
            # the rest are padding and are masked out of the loss/metrics.
            num_valid_entry = self.batch['attr_blank_fill/num']
            valid_mask = tf.sequence_mask(
                num_valid_entry, maxlen=self.data_cfg.n_attr_bf,
                dtype=tf.float32)
            loss, acc, top_k_acc = \
                self.n_way_classification_loss(logit, onehot_gt, valid_mask)
            self.losses['attr_blank_fill'] = loss
            self.report['attr_blank_fill_loss'] = loss
            self.report['attr_blank_fill_acc'] = acc
            self.report['attr_blank_fill_top_{}_acc'.format(TOP_K)] = top_k_acc
Exemplo n.º 2
0
    def build(self, is_train=True):
        """
        Build the VQA graph and register its loss, metrics and summaries.

        The graph uses question-guided attention pooling over region
        features, then scores every candidate answer against the pooled
        visual feature and the question feature.

        Args:
            is_train: forwarded as ``is_train`` / ``is_training`` to the
                ``modules`` layers created here.

        Returns:
            ``self.loss``, which is ``self.losses['answer']`` (summed
            sigmoid cross-entropy over answers, averaged over the batch).
        """
        # -- Visual features: gathered on the CPU from the in-memory
        #    ``self.features`` array through a Python op.
        with tf.device('/cpu:0'):
            def _rows_for(indices):
                return np.take(self.features, indices, axis=0)

            image_idx = self.batch['image_idx']
            region_ft = tf.py_func(
                _rows_for, inp=[image_idx], Tout=tf.float32,
                name='sample_features')
            region_ft.set_shape([None, self.max_box_num, self.vfeat_dim])
            box_count = tf.gather(self.num_boxes, image_idx,
                                  name='gather_num_V_ft', axis=0)
            self.mid_result['num_V_ft'] = box_count
            self.mid_result['normal_boxes'] = tf.gather(
                self.normal_boxes, image_idx,
                name='gather_normal_boxes', axis=0)

        # -- Question: GloVe embedding -> sentence encoding [bs, L_DIM],
        #    then mapped into the visual space [bs, V_DIM].
        question_ft = modules.encode_L(
            tf.nn.embedding_lookup(self.glove_map, self.batch['q_intseq']),
            self.batch['q_intseq_len'], L_DIM)
        question_in_V = modules.L2V(question_ft, MAP_DIM, V_DIM,
                                    is_train=is_train)

        # -- Attention over regions, then attention-weighted pooling mapped
        #    back to the language space [bs, L_DIM].
        scores = modules.attention(region_ft, box_count, question_in_V)
        self.mid_result['att_score'] = scores
        pooled_in_L, _unused = modules.V2L(
            modules.attention_pooling(region_ft, scores), MAP_DIM, L_DIM,
            is_train=is_train)

        # -- Candidate answers: [num_answer, L_DIM]
        answer_ft = modules.encode_L(
            tf.nn.embedding_lookup(self.glove_map, self.answer_intseq),
            self.answer_intseq_len, L_DIM)

        # -- One hidden layer shared by all (example, answer) pairs.
        with tf.variable_scope('reasoning') as scope:
            log.warning(scope.name)

            def _linear(inputs, bias, name):
                return modules.fc_layer(
                    inputs, L_DIM, use_bias=bias, use_bn=False,
                    activation_fn=None, is_training=is_train, scope=name)

            # answers -> [1, num_answer, L_DIM]; pooled / question ->
            # [bs, 1, L_DIM]; their sum broadcasts to [bs, num_answer, L_DIM].
            answer_term = tf.expand_dims(
                _linear(answer_ft, False, 'answer_layer1'), axis=0)
            pooled_term = tf.expand_dims(
                _linear(pooled_in_L, False, 'pooled_layer1'), axis=1)
            question_term = tf.expand_dims(
                _linear(question_ft, True, 'q_layer1'), axis=1)
            hidden = tf.tanh(answer_term + pooled_term + question_term)

            # One score per (example, answer) pair -> [bs, num_answer]
            logit = tf.squeeze(
                modules.fc_layer(
                    hidden, 1, use_bias=True, use_bn=False,
                    activation_fn=None, is_training=is_train,
                    scope='classifier'),
                axis=-1)

        # -- Loss and (soft) accuracy of the arg-max answer.
        with tf.name_scope('loss'):
            target = self.batch['answer_target']
            per_answer = tf.nn.sigmoid_cross_entropy_with_logits(
                labels=target, logits=logit)
            answer_loss = tf.reduce_mean(tf.reduce_sum(per_answer, axis=-1))
            pred = tf.cast(tf.argmax(logit, axis=-1), dtype=tf.int32)
            hit = tf.one_hot(pred, depth=self.num_answer,
                             dtype=tf.float32) * target
            accuracy = tf.reduce_mean(tf.reduce_sum(hit, axis=-1))

            self.mid_result['pred'] = pred
            self.losses['answer'] = answer_loss
            self.report['answer_loss'] = answer_loss
            self.report['answer_accuracy'] = accuracy

        # Attention visualisation (self.visualize_vqa_result ->
        # self.vis_image['image_attention_qa']) is currently disabled.

        self.loss = self.losses['answer']

        # Every report scalar goes to the train / val / testval collections.
        scalar_dest = (
            ('train/{}', ['heavy_train', 'train']),
            ('val/{}', ['heavy_val', 'val']),
            ('testval/{}', ['heavy_testval', 'testval']),
        )
        for name, value in self.report.items():
            for template, colls in scalar_dest:
                tf.summary.scalar(template.format(name), value,
                                  collections=colls)

        # Image summaries only go to the "heavy" collections.
        image_dest = (
            ('train-{}', 'heavy_train'),
            ('val-{}', 'heavy_val'),
            ('testval-{}', 'heavy_testval'),
        )
        for name, value in self.vis_image.items():
            for template, coll in image_dest:
                tf.summary.image(template.format(name), value, max_outputs=10,
                                 collections=[coll])

        return self.loss
Exemplo n.º 3
0
    def build(self):
        """
        Build the attention-based VQA graph and register its losses and
        summaries. Answers are scored with ``modules.WordWeightAnswer``.

        Pipeline:
            1. Gather per-image region features on the CPU via ``tf.py_func``.
            2. Encode the question with a GRU.
            3. Hadamard attention of the question over projected regions,
               then attention-pool the *raw* region features.
            4. Fuse the pooled visual and question features by element-wise
               product, apply dropout (training only), and score answers.
            5. Sigmoid cross-entropy loss plus several accuracy reports.

        Side effects: fills ``self.mid_result``, ``self.losses`` and
        ``self.report``; registers scalar/image summaries in the
        'train'/'val'/'testval' (and 'heavy_*') collections; sets
        ``self.loss``.

        Returns:
            ``self.loss``, the sum of every tensor in ``self.losses``
            (``0`` if it is empty).
        """
        # ---- Visual features (gathered on the CPU) ----
        with tf.device('/cpu:0'):

            def load_feature(image_idx):
                selected_features = np.take(self.features, image_idx, axis=0)
                return selected_features

            V_ft = tf.py_func(load_feature,
                              inp=[self.batch['image_idx']],
                              Tout=tf.float32,
                              name='sample_features')
            # py_func output has no static shape; restore it.
            V_ft.set_shape([None, self.max_box_num, self.vfeat_dim])
            num_V_ft = tf.gather(self.num_boxes, self.batch['image_idx'],
                                 name='gather_num_V_ft', axis=0)
            self.mid_result['num_V_ft'] = num_V_ft
            normal_boxes = tf.gather(self.normal_boxes,
                                     self.batch['image_idx'],
                                     name='gather_normal_boxes', axis=0)
            self.mid_result['normal_boxes'] = normal_boxes

        log.warning('v_linear_v')
        v_linear_v = modules.fc_layer(V_ft, V_DIM, use_bias=True,
                                      use_bn=False, use_ln=True,
                                      activation_fn=tf.nn.relu,
                                      is_training=self.is_train,
                                      scope='v_linear_v')

        # ---- Encode question ----
        q_embed = tf.nn.embedding_lookup(self.glove_map,
                                         self.batch['q_intseq'])
        q_L_ft = modules.encode_L(q_embed, self.batch['q_intseq_len'], L_DIM,
                                  cell_type='GRU')  # [bs, L_DIM]
        q_L_mean = modules.fc_layer(q_L_ft, L_DIM, use_bias=True,
                                    use_bn=False, use_ln=False,
                                    activation_fn=None,
                                    is_training=self.is_train,
                                    scope='q_L_mean')

        log.warning('q_linear_v')
        q_linear_v = modules.fc_layer(q_L_ft, V_DIM, use_bias=True,
                                      use_bn=False, use_ln=True,
                                      activation_fn=tf.nn.relu,
                                      is_training=self.is_train,
                                      scope='q_linear_v')  # [bs, V_DIM]
        self.mid_result['q_linear_v'] = q_linear_v

        # ---- Attention ----
        att_score = modules.hadamard_attention(v_linear_v, num_V_ft,
                                               q_linear_v, use_ln=False,
                                               is_train=self.is_train)
        self.mid_result['att_score'] = att_score
        # The scores come from the projected features, but pooling is applied
        # to the raw region features.
        pooled_V_ft = modules.attention_pooling(V_ft, att_score)
        self.mid_result['pooled_V_ft'] = pooled_V_ft

        # ---- Answer classification ----
        log.warning('pooled_linear_l')
        pooled_linear_l = modules.fc_layer(pooled_V_ft, L_DIM, use_bias=True,
                                           use_bn=False, use_ln=True,
                                           activation_fn=tf.nn.relu,
                                           is_training=self.is_train,
                                           scope='pooled_linear_l')
        self.mid_result['pooled_linear_l'] = pooled_linear_l

        log.warning('q_linear_l')
        l_linear_l = modules.fc_layer(q_L_mean, L_DIM, use_bias=True,
                                      use_bn=False, use_ln=True,
                                      activation_fn=tf.nn.relu,
                                      is_training=self.is_train,
                                      scope='q_linear_l')
        self.mid_result['l_linear_l'] = l_linear_l

        joint = modules.fc_layer(pooled_linear_l * l_linear_l, L_DIM * 2,
                                 use_bias=True, use_bn=False, use_ln=True,
                                 activation_fn=tf.nn.relu,
                                 is_training=self.is_train,
                                 scope='joint_fc')
        # Dropout must only be active while training. The previous
        # tf.nn.dropout(joint, 0.5) (keep_prob=0.5) also dropped units at
        # evaluation time. rate=0.5 is the same training-time setting.
        joint = tf.layers.dropout(joint, rate=0.5, training=self.is_train)
        self.mid_result['joint'] = joint

        logit = modules.WordWeightAnswer(joint, self.answer_dict,
                                         self.word_weight_dir, use_bias=True,
                                         is_training=self.is_train,
                                         scope='WordWeightAnswer')
        self.mid_result['logit'] = logit

        # ---- Loss and accuracy ----
        with tf.name_scope('loss'):
            answer_target = self.batch['answer_target']
            loss = tf.nn.sigmoid_cross_entropy_with_logits(
                labels=answer_target, logits=logit)

            # Only the optimised loss is masked by train_answer_mask; the
            # reported loss covers every answer.
            train_loss = tf.reduce_mean(
                tf.reduce_sum(loss * self.train_answer_mask, axis=-1))
            report_loss = tf.reduce_mean(tf.reduce_sum(loss, axis=-1))
            pred = tf.cast(tf.argmax(logit, axis=-1), dtype=tf.int32)
            one_hot_pred = tf.one_hot(pred, depth=self.num_answer,
                                      dtype=tf.float32)
            acc = tf.reduce_mean(
                tf.reduce_sum(one_hot_pred * answer_target, axis=-1))
            exist_acc = tf.reduce_mean(
                tf.reduce_sum(one_hot_pred * answer_target *
                              self.answer_exist_mask, axis=-1))
            test_acc = tf.reduce_mean(
                tf.reduce_sum(one_hot_pred * answer_target *
                              self.test_answer_mask, axis=-1))
            # Upper bounds: the best achievable score within each mask.
            max_exist_answer_acc = tf.reduce_mean(
                tf.reduce_max(answer_target * self.answer_exist_mask, axis=-1))
            test_max_answer_acc = tf.reduce_mean(
                tf.reduce_max(answer_target * self.test_answer_mask, axis=-1))
            test_max_exist_answer_acc = tf.reduce_mean(
                tf.reduce_max(answer_target * self.answer_exist_mask *
                              self.test_answer_mask, axis=-1))
            # Test accuracy normalised by its upper bound. When the bound is 0
            # the selected branch is the bound itself (0), which avoids
            # reporting 0/0.
            normal_test_acc = tf.where(tf.equal(test_max_answer_acc, 0),
                                       test_max_answer_acc,
                                       test_acc / test_max_answer_acc)

            self.mid_result['pred'] = pred

            self.losses['answer'] = train_loss
            self.report['answer_train_loss'] = train_loss
            self.report['answer_report_loss'] = report_loss
            self.report['answer_accuracy'] = acc
            self.report['exist_answer_accuracy'] = exist_acc
            self.report['test_answer_accuracy'] = test_acc
            self.report['normal_test_answer_accuracy'] = normal_test_acc
            self.report['max_exist_answer_accuracy'] = max_exist_answer_acc
            self.report['test_max_answer_accuracy'] = test_max_answer_acc
            self.report[
                'test_max_exist_answer_accuracy'] = test_max_exist_answer_acc

        # Attention visualisation (self.visualize_vqa_result ->
        # self.vis_image['image_attention_qa']) is currently disabled.

        # Total objective: the sum of every registered loss term
        # (0 + l1 + l2 + ..., identical to an explicit accumulation loop).
        self.loss = sum(self.losses.values())

        # scalar summary
        for key, val in self.report.items():
            tf.summary.scalar('train/{}'.format(key), val,
                              collections=['heavy_train', 'train'])
            tf.summary.scalar('val/{}'.format(key), val,
                              collections=['heavy_val', 'val'])
            tf.summary.scalar('testval/{}'.format(key), val,
                              collections=['heavy_testval', 'testval'])

        # image summary
        for key, val in self.vis_image.items():
            tf.summary.image('train-{}'.format(key), val, max_outputs=10,
                             collections=['heavy_train'])
            tf.summary.image('val-{}'.format(key), val, max_outputs=10,
                             collections=['heavy_val'])
            tf.summary.image('testval-{}'.format(key), val, max_outputs=10,
                             collections=['heavy_testval'])

        return self.loss
    def build_object_enwiki(self):
        """
        Build the object "enwiki" head: two independent classifiers for the
        object blank-fill answer.

        One classifier predicts from the object-pooled visual feature only.
        The other predicts from a GRU encoding of a text context only
        (``obj_blank_fill/enwiki_context``, presumably English-Wikipedia text
        -- confirm). The two branches do not share features.

        Reads:
            self.mid_result['object_pooled_V_ft']
            self.batch['obj_blank_fill/enwiki_context'],
                ['obj_blank_fill/enwiki_context_len'],
                ['obj_blank_fill/fills'], ['obj_blank_fill/num']
        Writes:
            self.losses['obj_enwiki_v'], self.losses['obj_enwiki_l']
            self.report['obj_enwiki_{v,l}_loss'], ['obj_enwiki_{v,l}_acc'],
                ['obj_enwiki_{v,l}_top_{TOP_K}_acc']
        """
        pooled_V_ft = self.mid_result['object_pooled_V_ft']

        # Token ids -> embeddings; the last two axes are (len, W_DIM).
        enwiki_embed = tf.nn.embedding_lookup(
            self.enwiki_map, self.batch['obj_blank_fill/enwiki_context'])
        enwiki_len = self.batch['obj_blank_fill/enwiki_context_len']
        enwiki_maxlen = tf.shape(enwiki_embed)[-2]
        # Encode every context as an independent sequence, then restore
        # [bs, n_obj_bf, L_DIM].
        flat_enwiki_ft = modules.encode_L(
            tf.reshape(enwiki_embed, [-1, enwiki_maxlen, W_DIM]),
            tf.reshape(enwiki_len, [-1]),
            L_DIM,
            scope='encode_L_enwiki',
            cell_type='GRU')
        enwiki_ft = tf.reshape(flat_enwiki_ft,
                               [-1, self.data_cfg.n_obj_bf, L_DIM])

        v_linear_l = modules.fc_layer(pooled_V_ft, L_DIM, use_bias=True,
                                      use_bn=False, use_ln=True,
                                      activation_fn=tf.nn.relu,
                                      is_training=self.is_train,
                                      scope='pooled_linear_l')

        l_linear_l = modules.fc_layer(enwiki_ft, L_DIM, use_bias=True,
                                      use_bn=False, use_ln=True,
                                      activation_fn=tf.nn.relu,
                                      is_training=self.is_train,
                                      scope='q_linear_l')

        # Dropout must only be active while training. The previous
        # tf.nn.dropout(x, 0.5) (keep_prob=0.5) also dropped units at
        # evaluation time. rate=0.5 is the same training-time setting.
        v_joint = modules.fc_layer(v_linear_l, L_DIM * 2, use_bias=True,
                                   use_bn=False, use_ln=True,
                                   activation_fn=tf.nn.relu,
                                   is_training=self.is_train,
                                   scope='joint_v')
        v_joint = tf.layers.dropout(v_joint, rate=0.5,
                                    training=self.is_train)

        l_joint = modules.fc_layer(l_linear_l, L_DIM * 2, use_bias=True,
                                   use_bn=False, use_ln=True,
                                   activation_fn=tf.nn.relu,
                                   is_training=self.is_train,
                                   scope='joint_l')
        l_joint = tf.layers.dropout(l_joint, rate=0.5,
                                    training=self.is_train)

        v_logit = modules.fc_layer(v_joint, self.num_answer, use_bias=True,
                                   use_bn=False, use_ln=False,
                                   activation_fn=None,
                                   is_training=self.is_train,
                                   scope='classifier_v')

        l_logit = modules.fc_layer(l_joint, self.num_answer, use_bias=True,
                                   use_bn=False, use_ln=False,
                                   activation_fn=None,
                                   is_training=self.is_train,
                                   scope='classifier_l')

        with tf.name_scope('loss/obj_enwiki'):
            onehot_gt = tf.one_hot(self.batch['obj_blank_fill/fills'],
                                   depth=self.num_answer)
            # Only the first `num` entries per example are real; the rest
            # are padding and are masked out.
            num_valid_entry = self.batch['obj_blank_fill/num']
            valid_mask = tf.sequence_mask(num_valid_entry,
                                          maxlen=self.data_cfg.n_obj_bf,
                                          dtype=tf.float32)

            v_loss, v_acc, v_top_k_acc = \
                self.n_way_classification_loss(v_logit, onehot_gt, valid_mask)
            l_loss, l_acc, l_top_k_acc = \
                self.n_way_classification_loss(l_logit, onehot_gt, valid_mask)
            self.losses['obj_enwiki_v'] = v_loss
            self.losses['obj_enwiki_l'] = l_loss
            self.report['obj_enwiki_v_loss'] = v_loss
            self.report['obj_enwiki_l_loss'] = l_loss
            self.report['obj_enwiki_v_acc'] = v_acc
            self.report['obj_enwiki_l_acc'] = l_acc
            self.report['obj_enwiki_v_top_{}_acc'.format(TOP_K)] = v_top_k_acc
            self.report['obj_enwiki_l_top_{}_acc'.format(TOP_K)] = l_top_k_acc
    def build_attribute_blank_fill(self):
        """
        Build the latent-variable attribute blank-fill head and its losses.

        The visual feature is a weighted sum of proposal features
        (``batch['attr_blank_fill/weights'] @ batch['image_ft']``). On the
        language side, each blank is encoded with a GRU and gated by a tanh
        projection of the fill-word embedding. The result is mapped to a
        Gaussian with mean ``fill_vec`` and log-variance
        ``fill_log_sigma_sq``, from which a reparameterised sample
        ``fill_vec + eps * sigma`` (eps ~ N(0, 1)) is drawn. The sample and
        the visual feature are fused by element-wise product and classified
        over ``self.num_answer`` answers.

        Writes:
            self.mid_result['attr_blank_fill/logit']
            self.vis_hist['attr_blank_fill/fill_vec'],
                ['attr_blank_fill/fill_sigma']
            self.losses['attr_blank_fill'], ['attr_blank_fill_latent']
                (the latter scaled by self.latent_loss_weight)
            self.report['attr_blank_fill_loss'], ['..._latent_loss'],
                ['..._train_latent_loss'], ['..._acc'], ['..._top_{TOP_K}_acc']
        """
        # [#obj, #proposal] x [#proposal x feat_dim] -> [#obj,feat_dim]
        V_ft = tf.matmul(self.batch['attr_blank_fill/weights'],
                         self.batch['image_ft'])
        v_linear_l = modules.fc_layer(V_ft, L_DIM, use_bias=True,
                                      use_bn=False, use_ln=True,
                                      activation_fn=tf.nn.relu,
                                      is_training=self.is_train,
                                      scope='pooled_linear_l')

        # Token ids -> embeddings; the last two axes are (len, W_DIM).
        blank_embed = tf.nn.embedding_lookup(
            self.l_word_map, self.batch['attr_blank_fill/blanks'])
        blank_len = self.batch['attr_blank_fill/blanks_len']
        blank_maxlen = tf.shape(blank_embed)[-2]
        flat_blank_ft = modules.encode_L(  # [bs * n_attr_bf, L_DIM]
            tf.reshape(blank_embed, [-1, blank_maxlen, W_DIM]),
            tf.reshape(blank_len, [-1]),
            L_DIM,
            scope='encode_L_blank',
            cell_type='GRU')
        # Restore [bs, n_attr_bf, L_DIM]. This previously used
        # data_cfg.n_obj_bf (copied from the object head). That mis-shapes the
        # tensor whenever n_obj_bf != n_attr_bf, while the loss mask below
        # is built with n_attr_bf.
        blank_ft = tf.reshape(flat_blank_ft,
                              [-1, self.data_cfg.n_attr_bf, L_DIM])

        # NOTE(review): the ground-truth fills (the same tensor used as the
        # classification target below) are embedded here as an input. Confirm
        # that this conditioning is intended at evaluation time.
        fill_embed = tf.nn.embedding_lookup(
            self.l_answer_word_map, self.batch['attr_blank_fill/fills'])
        fill_embed2 = modules.fc_layer(fill_embed, L_DIM, use_bias=True,
                                       use_bn=False, use_ln=True,
                                       activation_fn=tf.nn.tanh,
                                       is_training=self.is_train,
                                       scope='attr_blank_fill/fill_embed2')
        blank_fill_ft = blank_ft * fill_embed2

        # Gaussian parameters: mean and log-variance.
        fill_vec = modules.fc_layer(blank_fill_ft, L_DIM, use_bias=True,
                                    use_bn=False, use_ln=False,
                                    activation_fn=None,
                                    is_training=self.is_train,
                                    scope='attr_blank_fill/fill_vec')
        fill_log_sigma_sq = modules.fc_layer(
            blank_fill_ft,
            L_DIM,
            use_bias=True,
            use_bn=False,
            use_ln=False,
            activation_fn=None,
            is_training=self.is_train,
            scope='attr_blank_fill/fill_log_sigma_sq')
        fill_sigma = tf.sqrt(tf.exp(fill_log_sigma_sq))
        # Reparameterisation trick. NOTE(review): the noise is sampled in
        # every mode, not only during training.
        noise = tf.random_normal(tf.shape(fill_vec), mean=0, stddev=1,
                                 seed=123)
        fill_vec_noise = fill_vec + noise * fill_sigma
        self.vis_hist['attr_blank_fill/fill_vec'] = fill_vec
        self.vis_hist['attr_blank_fill/fill_sigma'] = fill_sigma

        l_linear_l = modules.fc_layer(fill_vec_noise, L_DIM, use_bias=True,
                                      use_bn=False, use_ln=True,
                                      activation_fn=tf.nn.relu,
                                      is_training=self.is_train,
                                      scope='q_linear_l')

        joint = modules.fc_layer(v_linear_l * l_linear_l, L_DIM * 2,
                                 use_bias=True, use_bn=False, use_ln=True,
                                 activation_fn=tf.nn.relu,
                                 is_training=self.is_train,
                                 scope='joint_fc')
        # Dropout must only be active while training. The previous
        # tf.nn.dropout(joint, 0.5) (keep_prob=0.5) also dropped units at
        # evaluation time. rate=0.5 is the same training-time setting.
        joint = tf.layers.dropout(joint, rate=0.5, training=self.is_train)

        logit = modules.fc_layer(joint, self.num_answer, use_bias=True,
                                 use_bn=False, use_ln=False,
                                 activation_fn=None,
                                 is_training=self.is_train,
                                 scope='classifier')
        self.mid_result[
            'attr_blank_fill/logit'] = logit  # [bs, #attr, #answer]

        with tf.name_scope('loss/attr_blank_fill'):
            onehot_gt = tf.one_hot(self.batch['attr_blank_fill/fills'],
                                   depth=self.num_answer)
            # Only the first `num` blanks per example are real entries.
            num_valid_entry = self.batch['attr_blank_fill/num']
            valid_mask = tf.sequence_mask(num_valid_entry,
                                          maxlen=self.data_cfg.n_attr_bf,
                                          dtype=tf.float32)
            loss, acc, top_k_acc = \
                self.n_way_classification_loss(logit, onehot_gt, valid_mask)
            latent_loss = self.latent_loss(fill_vec, fill_log_sigma_sq)
            self.losses['attr_blank_fill'] = loss
            self.losses[
                'attr_blank_fill_latent'] = self.latent_loss_weight * latent_loss
            self.report['attr_blank_fill_loss'] = loss
            self.report['attr_blank_fill_latent_loss'] = latent_loss
            self.report['attr_blank_fill_train_latent_loss'] = self.losses[
                'attr_blank_fill_latent']
            self.report['attr_blank_fill_acc'] = acc
            self.report['attr_blank_fill_top_{}_acc'.format(TOP_K)] = top_k_acc
    def build(self, is_train=True):
        """
        build network architecture and loss
        """

        """
        Visual features
        """
        # feat_V
        enc_I = modules.encode_I_block3(self.batch['image'],
                                 is_train=self.ft_enc_I)
        if not self.ft_enc_I: enc_I = tf.stop_gradient(enc_I)

        I_lowdim = modules.I_reduce_dim(enc_I, V_DIM, scope='I_reduce_dim',
                                        is_train=is_train)

        roi_ft = modules.roi_pool(I_lowdim, self.batch['normal_box'],
                                  ROI_SZ, ROI_SZ)
        visbox_flat = tf.reshape(self.batch['box'], [-1, 4])  # box for visualization

        # _flat regards "num box" dimension as a batch
        roi_ft_flat = tf.reshape(roi_ft, [-1, ROI_SZ, ROI_SZ, V_DIM])

        V_ft_flat = modules.I2V(roi_ft_flat, V_DIM, V_DIM, is_train=is_train)

        """
        Classification: object, attribute, relationship
        """
        for key in self.target_entry:
            log.info('Language: {}'.format(key))
            num_b = self.data_cfg.num_entry_box[key]
            num_k = self.data_cfg.num_k

            token = self.batch['{}_candidate'.format(key)]
            token_len = self.batch['{}_candidate_len'.format(key)]
            token_maxlen = tf.shape(token)[-1]
            # encode_L
            embed_seq = tf.nn.embedding_lookup(self.glove_map, token)
            enc_L_flat = modules.encode_L(
                tf.reshape(embed_seq, [-1, token_maxlen, W_DIM]),
                tf.reshape(token_len, [-1]), L_DIM)
            if self.no_V_grad_enc_L: enc_L_flat = tf.stop_gradient(enc_L_flat)
            enc_L = tf.reshape(enc_L_flat, [-1, num_b, num_k, L_DIM])

            # L2V mapping
            map_V = modules.L2V(enc_L, MAP_DIM, V_DIM, is_train=is_train)

            # gather target V_ft
            box_idx = modules.batch_box(self.batch['{}_box_idx'.format(key)],
                                        offset=self.data_cfg.num_box)
            box_V_ft_flat = tf.gather(V_ft_flat, tf.reshape(box_idx, [-1]))
            box_V_ft = tf.reshape(box_V_ft_flat, [-1, num_b, V_DIM])
            self.mid_result['{}_visbox'.format(key)] = tf.reshape(tf.gather(
                visbox_flat, tf.reshape(box_idx, [-1])), [-1, num_b, 4])

            # classification
            logits = modules.batch_word_classifier(box_V_ft, map_V)
            self.mid_result['{}_logits'.format(key)] = logits

            with tf.name_scope('{}_classification_loss'.format(key)):
                gt = self.batch['{}_selection_gt'.format(key)]
                num_used_box = self.batch['{}_num_used_box'.format(key)]
                used_mask = tf.sequence_mask(num_used_box, maxlen=num_b,
                                             dtype=tf.float32)
                self.mid_result['{}_used_mask'.format(key)] = used_mask

                if key == 'attribute':
                    loss, acc, recall, precision = \
                        self.binary_classification_loss(logits, gt, used_mask)
                    self.losses[key] = loss
                    self.report['{}_loss'.format(key)] = loss
                    self.report['{}_acc'.format(key)] = acc
                    self.report['{}_recall'.format(key)] = recall
                    self.report['{}_precision'.format(key)] = precision
                else:
                    loss, acc, top_k_acc = \
                        self.n_way_classification_loss(logits, gt, used_mask)
                    self.losses[key] = loss
                    self.report['{}_loss'.format(key)] = loss
                    self.report['{}_acc'.format(key)] = acc
                    self.report['{}_top_k_acc'.format(key)] = top_k_acc

        """
        Region description
        """
        # Select V_ft for descriptions
        num_desc_box = self.data_cfg.num_entry_box['region']
        desc_box_idx = modules.batch_box(self.batch['desc_box_idx'],
                                         offset=self.data_cfg.num_box)
        desc_box_V_ft_flat = tf.gather(V_ft_flat, tf.reshape(desc_box_idx, [-1]))
        desc_box_V_ft = tf.reshape(desc_box_V_ft_flat, [-1, num_desc_box, V_DIM])
        self.mid_result['region_visbox'] = tf.reshape(tf.gather(
            visbox_flat, tf.reshape(desc_box_idx, [-1])), [-1, num_desc_box, 4])

        # Metric learning
        desc = self.batch['desc']
        desc_len = self.batch['desc_len']
        desc_maxlen = tf.shape(desc)[-1]

        # encode desc_map_V
        desc_embed_seq = tf.nn.embedding_lookup(self.glove_map, desc)
        desc_L_flat = modules.encode_L(
            tf.reshape(desc_embed_seq, [-1, desc_maxlen, W_DIM]),
            tf.reshape(desc_len, [-1]), L_DIM)
        if self.no_V_grad_enc_L: desc_L_flat = tf.stop_gradient(desc_L_flat)

        desc_map_V_flat = modules.L2V(desc_L_flat, MAP_DIM, V_DIM,
                                      is_train=is_train)
        desc_map_V = tf.reshape(desc_map_V_flat, [-1, num_desc_box, V_DIM])

        # Language retrieval - for each region - classification over descriptions
        lr_num_k = self.data_cfg.lr_num_k
        lr_desc_idx_flat = modules.batch_box(
            tf.reshape(self.batch['lr_desc_idx'], [-1, num_desc_box * lr_num_k]),
            offset=num_desc_box)
        lr_map_V_flat = tf.gather(desc_map_V_flat,
                                  tf.reshape(lr_desc_idx_flat, [-1]))
        lr_map_V = tf.reshape(lr_map_V_flat, [-1, num_desc_box, lr_num_k, V_DIM])
        lr_gt = self.batch['lr_gt']

        if self.num_aug_retrieval > 0:
            lr_map_V, lr_gt = self.aug_retrieval(
                lr_map_V, lr_gt, self.num_aug_retrieval, num_desc_box,
                lr_num_k, V_DIM, scope='aug_LR')

        # Language Retrieval Classifier
        lr_logits = modules.batch_word_classifier(desc_box_V_ft, lr_map_V)
        self.mid_result['lr_logits'] = lr_logits
        self.mid_result['aug_lr_gt'] = lr_gt

        with tf.name_scope('LR_classification_loss'):
            num_used_desc = self.batch['num_used_desc']
            used_desc_mask = tf.sequence_mask(
                num_used_desc, maxlen=num_desc_box, dtype=tf.float32)
            self.mid_result['lr_used_desc_mask'] = used_desc_mask

            loss, acc, top_k_acc = self.n_way_classification_loss(
                lr_logits, lr_gt, used_desc_mask)
            self.losses['retrieval_L'] = loss
            self.report['retrieval_L_loss'] = loss
            self.report['retrieval_L_acc'] = acc
            self.report['retrieval_L_top_k_acc'] = top_k_acc

        # Image retrieval - for each description - classification over images
        ir_num_k = self.data_cfg.ir_num_k
        ir_box_idx_flat = modules.batch_box(
            tf.reshape(self.batch['ir_box_idx'], [-1, num_desc_box * ir_num_k]),
            offset=self.data_cfg.num_box)
        ir_box_V_ft_flat = tf.gather(V_ft_flat,
                                     tf.reshape(ir_box_idx_flat, [-1]))
        ir_box_V = tf.reshape(ir_box_V_ft_flat,
                              [-1, num_desc_box, ir_num_k, V_DIM])
        ir_gt = self.batch['ir_gt']
        self.mid_result['retrieval_I_visbox'] = tf.reshape(tf.gather(
            visbox_flat, tf.reshape(ir_box_idx_flat, [-1])),
            [-1, num_desc_box, ir_num_k, 4])

        if self.num_aug_retrieval > 0:
            ir_box_V, ir_gt = self.aug_retrieval(
                ir_box_V, ir_gt, self.num_aug_retrieval, num_desc_box,
                ir_num_k, V_DIM, scope='aug_IR')

        # Image Retrieval Classifier
        ir_logits = modules.batch_word_classifier(desc_map_V, ir_box_V)
        self.mid_result['ir_logits'] = ir_logits
        self.mid_result['aug_ir_gt'] = ir_gt

        with tf.name_scope('IR_classification_loss'):
            num_used_desc = self.batch['num_used_desc']
            used_desc_mask = tf.sequence_mask(
                num_used_desc, maxlen=num_desc_box, dtype=tf.float32)
            self.mid_result['ir_used_desc_mask'] = used_desc_mask

            loss, acc, top_k_acc = self.n_way_classification_loss(
                ir_logits, ir_gt, used_desc_mask)
            self.losses['retrieval_I'] = loss
            self.report['retrieval_I_loss'] = loss
            self.report['retrieval_I_acc'] = acc
            self.report['retrieval_I_top_k_acc'] = top_k_acc

        # Description / blank-fill task

        # V2L mapping
        desc_box_map_L, V2L_hidden = modules.V2L(
            desc_box_V_ft, MAP_DIM, L_DIM, is_train=is_train)
        in_L = desc_box_map_L  # language feature used for the decoding

        # Add blank-fill feature to mapped language for decoding
        if self.description_task == 'blank-fill':
            blank_desc = self.batch['blank_desc']
            blank_desc_len = self.batch['blank_desc_len']
            blank_max_len = tf.shape(blank_desc)[-1]

            blank_embed_seq = tf.nn.embedding_lookup(self.glove_map, blank_desc)
            blank_L_flat = modules.encode_L(
                tf.reshape(blank_embed_seq, [-1, blank_max_len, W_DIM]),
                tf.reshape(blank_desc_len, [-1]), L_DIM)
            blank_L = tf.reshape(blank_L_flat, [-1, num_desc_box, L_DIM])
            in_L = in_L + blank_L

        # Decode
        in_L_flat = tf.reshape(in_L, [-1, L_DIM])
        desc_flat = tf.reshape(desc, [-1, desc_maxlen])
        desc_len_flat = tf.reshape(desc_len, [-1])

        logits_flat, pred_flat, pred_len_flat = modules.decode_L(
            in_L_flat, self.decoder_dim, self.decoder_embed_map,
            self.vocab['dict']['<s>'], unroll_type='teacher_forcing',
            seq=desc_flat, seq_len=desc_len_flat + 1,
            output_layer=self.word_predictor, is_train=is_train)
        self.mid_result['pred'] = tf.reshape(
            pred_flat, [-1, num_desc_box, tf.shape(pred_flat)[-1]])
        self.mid_result['pred_len'] = tf.reshape(
            pred_len_flat, [-1, num_desc_box])

        _, greedy_flat, greedy_len_flat = modules.decode_L(
            in_L_flat, self.decoder_dim, self.decoder_embed_map,
            self.vocab['dict']['<s>'], unroll_type='greedy',
            end_token=self.vocab['dict']['<e>'],
            max_seq_len=self.data_cfg.max_len['region'] + 1,
            output_layer=self.word_predictor, is_train=is_train)
        self.mid_result['greedy'] = tf.reshape(
            greedy_flat, [-1, num_desc_box, tf.shape(greedy_flat)[-1]])
        self.mid_result['greedy_len'] = tf.reshape(
            greedy_len_flat, [-1, num_desc_box])

        with tf.name_scope('description_loss'):
            desc_used_mask = tf.sequence_mask(
                num_used_desc, maxlen=num_desc_box, dtype=tf.float32)
            self.mid_result['desc_used_mask'] = desc_used_mask
            loss = self.flat_description_loss(
                logits_flat, desc_flat, desc_len_flat + 1, desc_used_mask)
            pred_token_acc, pred_seq_acc = self.flat_description_accuracy(
                pred_flat, pred_len_flat, desc_flat, desc_len_flat + 1,
                desc_used_mask)
            greedy_token_acc, greedy_seq_acc = self.flat_description_accuracy(
                greedy_flat, greedy_len_flat, desc_flat, desc_len_flat + 1,
                desc_used_mask)

            loss *= self.decoder_loss_weight

            self.losses['description'] = loss
            self.report['description_loss'] = loss
            self.report['pred_token_acc'] = pred_token_acc
            self.report['pred_seq_acc'] = pred_seq_acc
            self.report['greedy_token_acc'] = greedy_token_acc
            self.report['greedy_seq_acc'] = greedy_seq_acc

        with tf.name_scope('prepare_summary'):
            for key in self.target_entry:
                if key == 'attribute':
                    self.vis_image['{}_classification'.format(key)] =\
                        self.vis_binary_image_classification(
                            self.batch['image'],
                            self.mid_result['{}_visbox'.format(key)],
                            self.mid_result['{}_logits'.format(key)],
                            self.batch['{}_selection_gt'.format(key)],
                            self.batch['{}_candidate'.format(key)],
                            self.batch['{}_candidate_len'.format(key)],
                            self.batch['{}_candidate_name'.format(key)],
                            self.mid_result['{}_used_mask'.format(key)],
                            vis_numbox=VIS_NUMBOX, line_width=LINE_WIDTH)
                else:
                    self.vis_image['{}_classification'.format(key)] =\
                        self.vis_n_way_image_classification(
                            self.batch['image'],
                            self.mid_result['{}_visbox'.format(key)],
                            self.mid_result['{}_logits'.format(key)],
                            self.batch['{}_selection_gt'.format(key)],
                            self.batch['{}_candidate'.format(key)],
                            self.batch['{}_candidate_len'.format(key)],
                            self.batch['{}_candidate_name'.format(key)],
                            self.mid_result['{}_used_mask'.format(key)],
                            vis_numbox=VIS_NUMBOX, line_width=LINE_WIDTH)

            self.vis_image['retrieval_L'] = self.vis_retrieval_L(
                self.batch['image'],
                self.mid_result['region_visbox'],
                self.batch['desc'],
                self.batch['desc_len'],
                self.batch['lr_desc_idx'],
                self.data_cfg.lr_num_k,
                self.mid_result['lr_logits'],
                self.mid_result['aug_lr_gt'],
                self.mid_result['lr_used_desc_mask'],
                vis_numbox=VIS_NUMBOX, line_width=LINE_WIDTH)

            self.vis_image['retrieval_I'] = self.vis_retrieval_I(
                self.batch['image'],
                self.mid_result['retrieval_I_visbox'],
                self.batch['desc'],
                self.batch['desc_len'],
                self.num_aug_retrieval,
                self.data_cfg.ir_num_k,
                self.mid_result['ir_logits'],
                self.mid_result['aug_ir_gt'],
                self.mid_result['ir_used_desc_mask'],
                vis_numbox=VIS_NUMBOX, line_width=LINE_WIDTH)

            self.vis_image['description'] = self.vis_description(
                self.batch['image'],
                self.mid_result['region_visbox'],
                self.batch['desc'],
                self.batch['desc_len'],
                self.batch['blank_desc'],
                self.batch['blank_desc_len'],
                self.description_task == 'blank-fill',
                self.mid_result['pred'],
                self.mid_result['pred_len'],
                self.mid_result['greedy'],
                self.mid_result['greedy_len'],
                self.mid_result['desc_used_mask'],
                vis_numbox=VIS_NUMBOX, line_width=LINE_WIDTH)
        # loss
        self.v_loss = 0
        self.l_loss = 0

        for key in self.target_entry:
            self.v_loss += self.losses[key]
        self.v_loss += self.losses['retrieval_L']
        self.v_loss += self.losses['retrieval_I']

        self.l_loss += self.losses['description']

        self.loss = self.v_loss + self.l_loss

        self.report['total_loss'] = self.loss
        self.report['total_v_loss'] = self.v_loss
        self.report['total_l_loss'] = self.l_loss

        # scalar summary
        for key, val in self.report.items():
            tf.summary.scalar('train/{}'.format(key), val,
                              collections=['heavy_train', 'train'])
            tf.summary.scalar('val/{}'.format(key), val,
                              collections=['heavy_val', 'val'])

        # image summary
        for key, val in self.vis_image.items():
            tf.summary.image('train-{}'.format(key), val, max_outputs=10,
                             collections=['heavy_train'])
            tf.summary.image('val-{}'.format(key), val, max_outputs=10,
                             collections=['heavy_val'])

        return self.loss