    def build_attribute_V_ft(self):
        V_ft = self.batch['image_ft']  # [bs,  #proposal, #feat_dim]
        V_ft = tf.expand_dims(V_ft, axis=1)  # [bs, 1, #proposal, #feat_dim]
        V_ft = tf.tile(V_ft, [1, self.data_cfg.n_attr_bf, 1, 1])  # [bs, #attr, #proposal, #feat_dim]
        V_ft = tf.reshape(
            V_ft, [-1, self.data_cfg.max_box_num, self.data_cfg.vfeat_dim])  # [bs * #attr, #proposal, #feat_dim]
        spat_ft = self.batch['spatial_ft']
        spat_ft = tf.expand_dims(spat_ft, axis=1)
        spat_ft = tf.tile(spat_ft, [1, self.data_cfg.n_attr_bf, 1, 1])
        spat_ft = tf.reshape(
            spat_ft, [-1, self.data_cfg.max_box_num, 6])
        num_V_ft = self.batch['num_boxes']  # [bs]
        num_V_ft = tf.expand_dims(num_V_ft, axis=1)  # [bs, 1]
        num_V_ft = tf.tile(num_V_ft, [1, self.data_cfg.n_attr_bf])  # [bs, #attr]
        num_V_ft = tf.reshape(num_V_ft, [-1])  # [bs * #attr]

        key_spat_ft = self.batch['attr_blank_fill/normal_boxes']
        key_spat_ft = tf.concat(
            [key_spat_ft,
             tf.expand_dims(key_spat_ft[:, :, 2] - key_spat_ft[:, :, 0], axis=-1),
             tf.expand_dims(key_spat_ft[:, :, 3] - key_spat_ft[:, :, 1], axis=-1)],
            axis=-1)
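        # The concat above appends box width and height, giving a 6-d key
        # spatial feature [x1, y1, x2, y2, w, h] that matches spat_ft's 6 dims.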

        v_linear_v = modules.fc_layer(  # [bs * #attr, #proposal, V_DIM]
            spat_ft, V_DIM, use_bias=True, use_bn=False, use_ln=True,
            activation_fn=tf.nn.relu, is_training=self.is_train,
            scope='spat_v_linear_v')

        q_linear_v = modules.fc_layer(  # [bs, #attr, V_DIM]
            key_spat_ft, V_DIM, use_bias=True, use_bn=False, use_ln=True,
            activation_fn=tf.nn.relu, is_training=self.is_train,
            scope='spat_q_linear_v')
        flat_q_linear_v = tf.reshape(q_linear_v, [-1, V_DIM])  # [bs * #attr, V_DIM]

        att_score = modules.hadamard_attention(  # [bs * #attr, len]
            v_linear_v, num_V_ft, flat_q_linear_v,
            use_ln=False, is_train=self.is_train, scope='spat_att')
        flat_pooled_V_ft = modules.attention_pooling(V_ft, att_score)  # [bs * #attr, vfeat_dim]
        pooled_V_ft = tf.reshape(
            flat_pooled_V_ft, [-1, self.data_cfg.n_attr_bf, self.data_cfg.vfeat_dim])

        self.mid_result['attribute_pooled_V_ft'] = pooled_V_ft
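
A minimal sketch of what modules.attention_pooling is assumed to compute throughout these snippets (a hypothetical reimplementation, not the project's actual module): an attention-weighted sum over proposal features.

import tensorflow as tf

def attention_pooling_sketch(V_ft, att_score):
    # V_ft:      [N, #proposal, feat_dim] proposal features
    # att_score: [N, #proposal] weights, normalized over valid proposals
    # returns:   [N, feat_dim] attention-weighted sum over proposals
    return tf.reduce_sum(V_ft * tf.expand_dims(att_score, axis=-1), axis=1)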
Example #2
    def build(self, is_train=True):
        """
        build network architecture and loss
        """

        """
        Visual features
        """
        with tf.device('/cpu:0'):
            def load_feature(image_idx):
                selected_features = np.take(self.features, image_idx, axis=0)
                return selected_features
            V_ft = tf.py_func(
                load_feature, inp=[self.batch['image_idx']], Tout=tf.float32,
                name='sample_features')
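            # py_func outputs lose static shape information; restore it so
            # downstream ops see [batch, max_box_num, vfeat_dim].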
            V_ft.set_shape([None, self.max_box_num, self.vfeat_dim])
            num_V_ft = tf.gather(self.num_boxes, self.batch['image_idx'],
                                 name='gather_num_V_ft', axis=0)
            self.mid_result['num_V_ft'] = num_V_ft
            normal_boxes = tf.gather(self.normal_boxes, self.batch['image_idx'],
                                     name='gather_normal_boxes', axis=0)
            self.mid_result['normal_boxes'] = normal_boxes

        """
        Encode question
        """
        q_embed = tf.nn.embedding_lookup(self.glove_map, self.batch['q_intseq'])
        # [bs, L_DIM]
        q_L_ft = modules.encode_L(q_embed, self.batch['q_intseq_len'], L_DIM)

        # [bs, V_DIM]
        q_map_V = modules.L2V(q_L_ft, MAP_DIM, V_DIM, is_train=is_train)

        """
        Perform attention
        """
        att_score = modules.attention(V_ft, num_V_ft, q_map_V)
        self.mid_result['att_score'] = att_score
        pooled_V_ft = modules.attention_pooling(V_ft, att_score)
        # [bs, L_DIM]
        pooled_map_L, _ = modules.V2L(pooled_V_ft, MAP_DIM, L_DIM,
                                      is_train=is_train)
        """
        Answer classification
        """
        answer_embed = tf.nn.embedding_lookup(self.glove_map, self.answer_intseq)
        # [num_answer, L_DIM]
        answer_ft = modules.encode_L(answer_embed, self.answer_intseq_len, L_DIM)

        # perform two layer feature encoding and predict output
        with tf.variable_scope('reasoning') as scope:
            log.warning(scope.name)
            # layer 1
            # answer_layer1: [1, num_answer, L_DIM]
            # pooled_layer1: [bs, 1, L_DIM]
            # q_layer1: [bs, 1, L_DIM]
            # layer1: [bs, num_answer, L_DIM]
            answer_layer1 = modules.fc_layer(
                answer_ft, L_DIM, use_bias=False, use_bn=False,
                activation_fn=None, is_training=is_train, scope='answer_layer1')
            answer_layer1 = tf.expand_dims(answer_layer1, axis=0)
            pooled_layer1 = modules.fc_layer(
                pooled_map_L, L_DIM, use_bias=False, use_bn=False,
                activation_fn=None, is_training=is_train, scope='pooled_layer1')
            pooled_layer1 = tf.expand_dims(pooled_layer1, axis=1)
            q_layer1 = modules.fc_layer(
                q_L_ft, L_DIM, use_bias=True, use_bn=False,
                activation_fn=None, is_training=is_train, scope='q_layer1')
            q_layer1 = tf.expand_dims(q_layer1, axis=1)
            layer1 = tf.tanh(answer_layer1 + pooled_layer1 + q_layer1)

            logit = modules.fc_layer(
                layer1, 1, use_bias=True, use_bn=False,
                activation_fn=None, is_training=is_train, scope='classifier')
            logit = tf.squeeze(logit, axis=-1)  # [bs, num_answer]

        """
        Compute loss and accuracy
        """
        with tf.name_scope('loss'):
            answer_target = self.batch['answer_target']
            loss = tf.nn.sigmoid_cross_entropy_with_logits(
                labels=answer_target, logits=logit)
            loss = tf.reduce_mean(tf.reduce_sum(loss, axis=-1))
            pred = tf.cast(tf.argmax(logit, axis=-1), dtype=tf.int32)
            one_hot_pred = tf.one_hot(pred, depth=self.num_answer,
                                      dtype=tf.float32)
            acc = tf.reduce_mean(
                tf.reduce_sum(one_hot_pred * answer_target, axis=-1))

            self.mid_result['pred'] = pred

            self.losses['answer'] = loss
            self.report['answer_loss'] = loss
            self.report['answer_accuracy'] = acc

        """
        Prepare image summary
        """
        """
        with tf.name_scope('prepare_summary'):
            self.vis_image['image_attention_qa'] = self.visualize_vqa_result(
                self.batch['image_id'], self.batch['box'], self.batch['num_box'],
                self.mid_result['att_score'],
                self.batch['q_intseq'], self.batch['q_intseq_len'],
                self.batch['answer_id'], self.mid_result['pred'],
                line_width=2)
        """

        self.loss = self.losses['answer']

        # scalar summary
        for key, val in self.report.items():
            tf.summary.scalar('train/{}'.format(key), val,
                              collections=['heavy_train', 'train'])
            tf.summary.scalar('val/{}'.format(key), val,
                              collections=['heavy_val', 'val'])
            tf.summary.scalar('testval/{}'.format(key), val,
                              collections=['heavy_testval', 'testval'])

        # image summary
        for key, val in self.vis_image.items():
            tf.summary.image('train-{}'.format(key), val, max_outputs=10,
                             collections=['heavy_train'])
            tf.summary.image('val-{}'.format(key), val, max_outputs=10,
                             collections=['heavy_val'])
            tf.summary.image('testval-{}'.format(key), val, max_outputs=10,
                             collections=['heavy_testval'])

        return self.loss
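
A quick shape check of the broadcasting in the reasoning block above (hypothetical sizes with bs=3, num_answer=7, L_DIM=16; only the shapes matter): the three layer-1 terms broadcast to one tensor that scores every (question, answer) pair.

import numpy as np

answer_layer1 = np.zeros([1, 7, 16])  # [1, num_answer, L_DIM]
pooled_layer1 = np.zeros([3, 1, 16])  # [bs, 1, L_DIM]
q_layer1 = np.zeros([3, 1, 16])       # [bs, 1, L_DIM]
layer1 = np.tanh(answer_layer1 + pooled_layer1 + q_layer1)
print(layer1.shape)  # (3, 7, 16) -> [bs, num_answer, L_DIM]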
Example #3
    def build(self):
        """
        build network architecture and loss
        """
        """
        Visual features
        """
        with tf.device('/cpu:0'):

            def load_feature(image_idx):
                selected_features = np.take(self.features, image_idx, axis=0)
                return selected_features

            V_ft = tf.py_func(load_feature,
                              inp=[self.batch['image_idx']],
                              Tout=tf.float32,
                              name='sample_features')
            V_ft.set_shape([None, self.max_box_num, self.vfeat_dim])
            num_V_ft = tf.gather(self.num_boxes,
                                 self.batch['image_idx'],
                                 name='gather_num_V_ft',
                                 axis=0)
            self.mid_result['num_V_ft'] = num_V_ft
            normal_boxes = tf.gather(self.normal_boxes,
                                     self.batch['image_idx'],
                                     name='gather_normal_boxes',
                                     axis=0)
            self.mid_result['normal_boxes'] = normal_boxes

        log.warning('v_linear_v')
        v_linear_v = modules.fc_layer(V_ft,
                                      V_DIM,
                                      use_bias=True,
                                      use_bn=False,
                                      use_ln=True,
                                      activation_fn=tf.nn.relu,
                                      is_training=self.is_train,
                                      scope='v_linear_v')
        """
        Encode question
        """
        q_embed = tf.nn.embedding_lookup(self.glove_map,
                                         self.batch['q_intseq'])
        # [bs, L_DIM]
        q_L_ft = modules.encode_L(q_embed,
                                  self.batch['q_intseq_len'],
                                  L_DIM,
                                  cell_type='GRU')
        q_L_mean = modules.fc_layer(q_L_ft,
                                    L_DIM,
                                    use_bias=True,
                                    use_bn=False,
                                    use_ln=False,
                                    activation_fn=None,
                                    is_training=self.is_train,
                                    scope='q_L_mean')

        # [bs, V_DIM]
        log.warning('q_linear_v')
        q_linear_v = modules.fc_layer(q_L_ft,
                                      V_DIM,
                                      use_bias=True,
                                      use_bn=False,
                                      use_ln=True,
                                      activation_fn=tf.nn.relu,
                                      is_training=self.is_train,
                                      scope='q_linear_v')
        self.mid_result['q_linear_v'] = q_linear_v
        """
        Perform attention
        """
        att_score = modules.hadamard_attention(v_linear_v,
                                               num_V_ft,
                                               q_linear_v,
                                               use_ln=False,
                                               is_train=self.is_train)
        self.mid_result['att_score'] = att_score
        pooled_V_ft = modules.attention_pooling(V_ft, att_score)
        self.mid_result['pooled_V_ft'] = pooled_V_ft
        """
        Answer classification
        """
        log.warning('pooled_linear_l')
        pooled_linear_l = modules.fc_layer(pooled_V_ft,
                                           L_DIM,
                                           use_bias=True,
                                           use_bn=False,
                                           use_ln=True,
                                           activation_fn=tf.nn.relu,
                                           is_training=self.is_train,
                                           scope='pooled_linear_l')
        self.mid_result['pooled_linear_l'] = pooled_linear_l

        log.warning('q_linear_l')
        l_linear_l = modules.fc_layer(q_L_mean,
                                      L_DIM,
                                      use_bias=True,
                                      use_bn=False,
                                      use_ln=True,
                                      activation_fn=tf.nn.relu,
                                      is_training=self.is_train,
                                      scope='q_linear_l')
        self.mid_result['l_linear_l'] = l_linear_l

        joint = modules.fc_layer(pooled_linear_l * l_linear_l,
                                 L_DIM * 2,
                                 use_bias=True,
                                 use_bn=False,
                                 use_ln=True,
                                 activation_fn=tf.nn.relu,
                                 is_training=self.is_train,
                                 scope='joint_fc')
        joint = tf.nn.dropout(joint, 0.5)
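        # NOTE: keep_prob is fixed at 0.5, so this dropout stays active even
        # at eval time; gating on self.is_train would disable it there.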
        self.mid_result['joint'] = joint

        logit = modules.WordWeightAnswer(joint,
                                         self.answer_dict,
                                         self.word_weight_dir,
                                         use_bias=True,
                                         is_training=self.is_train,
                                         scope='WordWeightAnswer')
        self.mid_result['logit'] = logit
        """
        Compute loss and accuracy
        """
        with tf.name_scope('loss'):
            answer_target = self.batch['answer_target']
            loss = tf.nn.sigmoid_cross_entropy_with_logits(
                labels=answer_target, logits=logit)

            train_loss = tf.reduce_mean(
                tf.reduce_sum(loss * self.train_answer_mask, axis=-1))
            report_loss = tf.reduce_mean(tf.reduce_sum(loss, axis=-1))
            pred = tf.cast(tf.argmax(logit, axis=-1), dtype=tf.int32)
            one_hot_pred = tf.one_hot(pred,
                                      depth=self.num_answer,
                                      dtype=tf.float32)
            acc = tf.reduce_mean(
                tf.reduce_sum(one_hot_pred * answer_target, axis=-1))
            exist_acc = tf.reduce_mean(
                tf.reduce_sum(one_hot_pred * answer_target *
                              self.answer_exist_mask,
                              axis=-1))
            test_acc = tf.reduce_mean(
                tf.reduce_sum(one_hot_pred * answer_target *
                              self.test_answer_mask,
                              axis=-1))
            max_exist_answer_acc = tf.reduce_mean(
                tf.reduce_max(answer_target * self.answer_exist_mask, axis=-1))
            test_max_answer_acc = tf.reduce_mean(
                tf.reduce_max(answer_target * self.test_answer_mask, axis=-1))
            test_max_exist_answer_acc = tf.reduce_mean(
                tf.reduce_max(answer_target * self.answer_exist_mask *
                              self.test_answer_mask,
                              axis=-1))
            normal_test_acc = tf.where(
                tf.equal(test_max_answer_acc, 0), test_max_answer_acc,
                test_acc / test_max_answer_acc)

            self.mid_result['pred'] = pred

            self.losses['answer'] = train_loss
            self.report['answer_train_loss'] = train_loss
            self.report['answer_report_loss'] = report_loss
            self.report['answer_accuracy'] = acc
            self.report['exist_answer_accuracy'] = exist_acc
            self.report['test_answer_accuracy'] = test_acc
            self.report['normal_test_answer_accuracy'] = normal_test_acc
            self.report['max_exist_answer_accuracy'] = max_exist_answer_acc
            self.report['test_max_answer_accuracy'] = test_max_answer_acc
            self.report[
                'test_max_exist_answer_accuracy'] = test_max_exist_answer_acc
        """
        Prepare image summary
        """
        """
        with tf.name_scope('prepare_summary'):
            self.vis_image['image_attention_qa'] = self.visualize_vqa_result(
                self.batch['image_id'],
                self.mid_result['normal_boxes'], self.mid_result['num_V_ft'],
                self.mid_result['att_score'],
                self.batch['q_intseq'], self.batch['q_intseq_len'],
                self.batch['answer_target'], self.mid_result['pred'],
                max_batch_num=20, line_width=2)
        """

        self.loss = 0
        for key, loss in self.losses.items():
            self.loss = self.loss + loss

        # scalar summary
        for key, val in self.report.items():
            tf.summary.scalar('train/{}'.format(key),
                              val,
                              collections=['heavy_train', 'train'])
            tf.summary.scalar('val/{}'.format(key),
                              val,
                              collections=['heavy_val', 'val'])
            tf.summary.scalar('testval/{}'.format(key),
                              val,
                              collections=['heavy_testval', 'testval'])

        # image summary
        for key, val in self.vis_image.items():
            tf.summary.image('train-{}'.format(key),
                             val,
                             max_outputs=10,
                             collections=['heavy_train'])
            tf.summary.image('val-{}'.format(key),
                             val,
                             max_outputs=10,
                             collections=['heavy_val'])
            tf.summary.image('testval-{}'.format(key),
                             val,
                             max_outputs=10,
                             collections=['heavy_testval'])

        return self.loss
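
For reference, a minimal sketch of the assumed behavior of modules.hadamard_attention (hypothetical; the real module also takes use_ln and scope arguments and may use a different scoring head): an element-wise product with the query, a scalar score per proposal, and a softmax restricted to the valid boxes.

import tensorflow as tf

def hadamard_attention_sketch(v_linear_v, num_V_ft, q_linear_v):
    # v_linear_v: [bs, #proposal, V_DIM] projected visual features
    # num_V_ft:   [bs] number of valid proposals per example
    # q_linear_v: [bs, V_DIM] projected query features
    h = v_linear_v * tf.expand_dims(q_linear_v, axis=1)  # Hadamard product
    score = tf.squeeze(tf.layers.dense(h, 1), axis=-1)   # [bs, #proposal]
    mask = tf.sequence_mask(num_V_ft, maxlen=tf.shape(score)[1],
                            dtype=tf.float32)
    score = score + (1.0 - mask) * -1e9  # exclude padded proposals
    return tf.nn.softmax(score, axis=-1)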
Example #4
    def build_attribute_blank_fill(self):
        """
        attribute_blank_fill
        """
        V_ft = self.mid_result['attribute_V_ft']
        num_V_ft = self.mid_result['attribute_num_V_ft']

        v_linear_v = modules.fc_layer(  # [bs * #attr, #proposal, V_DIM]
            V_ft,
            V_DIM,
            use_bias=True,
            use_bn=False,
            use_ln=True,
            activation_fn=tf.nn.relu,
            is_training=self.is_train,
            scope='bf_v_linear_v')

        blank_embed = tf.nn.embedding_lookup(  # [bs, #attr, len, W_DIM]
            self.l_word_map, self.batch['attr_blank_fill/blanks'])
        blank_len = self.batch['attr_blank_fill/blanks_len']
        blank_maxlen = tf.shape(blank_embed)[-2]
        flat_blank_ft = modules.encode_L(  # [bs * #attr, L_DIM]
            tf.reshape(blank_embed, [-1, blank_maxlen, W_DIM]),
            tf.reshape(blank_len, [-1]),
            L_DIM,
            scope='encode_L_blank',
            cell_type='GRU')
        blank_ft = tf.reshape(flat_blank_ft,
                              [-1, self.data_cfg.n_attr_bf, L_DIM])

        q_linear_v = modules.fc_layer(  # [bs, #attr, V_DIM]
            blank_ft,
            V_DIM,
            use_bias=True,
            use_bn=False,
            use_ln=True,
            activation_fn=tf.nn.relu,
            is_training=self.is_train,
            scope='bf_q_linear_v')
        flat_q_linear_v = tf.reshape(q_linear_v,
                                     [-1, V_DIM])  # [bs * #attr, V_DIM]

        att_score = modules.hadamard_attention(  # [bs * #attr, len]
            v_linear_v,
            num_V_ft,
            flat_q_linear_v,
            use_ln=False,
            is_train=self.is_train,
            scope='bf_att')
        flat_pooled_V_ft = modules.attention_pooling(
            V_ft, att_score)  # [bs * #attr, vfeat_dim]
        pooled_V_ft = tf.reshape(
            flat_pooled_V_ft,
            [-1, self.data_cfg.n_attr_bf, self.data_cfg.vfeat_dim])

        v_linear_l = modules.fc_layer(pooled_V_ft,
                                      L_DIM,
                                      use_bias=True,
                                      use_bn=False,
                                      use_ln=True,
                                      activation_fn=tf.nn.relu,
                                      is_training=self.is_train,
                                      scope='pooled_linear_l')

        l_linear_l = modules.fc_layer(blank_ft,
                                      L_DIM,
                                      use_bias=True,
                                      use_bn=False,
                                      use_ln=True,
                                      activation_fn=tf.nn.relu,
                                      is_training=self.is_train,
                                      scope='q_linear_l')

        joint = modules.fc_layer(v_linear_l * l_linear_l,
                                 L_DIM * 2,
                                 use_bias=True,
                                 use_bn=False,
                                 use_ln=True,
                                 activation_fn=tf.nn.relu,
                                 is_training=self.is_train,
                                 scope='joint_fc')
        joint = tf.nn.dropout(joint, 0.5)

        logit = modules.fc_layer(joint,
                                 self.num_answer,
                                 use_bias=True,
                                 use_bn=False,
                                 use_ln=False,
                                 activation_fn=None,
                                 is_training=self.is_train,
                                 scope='classifier')
        self.mid_result[
            'attr_blank_fill/logit'] = logit  # [bs, #attr, #answer]

        with tf.name_scope('loss/attr_blank_fill'):
            onehot_gt = tf.one_hot(self.batch['attr_blank_fill/fills'],
                                   depth=self.num_answer)
            num_valid_entry = self.batch['attr_blank_fill/num']
            valid_mask = tf.sequence_mask(num_valid_entry,
                                          maxlen=self.data_cfg.n_attr_bf,
                                          dtype=tf.float32)
            loss, acc, top_k_acc = \
                self.n_way_classification_loss(logit, onehot_gt, valid_mask)
            self.losses['attr_blank_fill'] = loss
            self.report['attr_blank_fill_loss'] = loss
            self.report['attr_blank_fill_acc'] = acc
            self.report['attr_blank_fill_top_{}_acc'.format(TOP_K)] = top_k_acc
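
A minimal sketch of the assumed contract of self.n_way_classification_loss (a hypothetical reimplementation; the top-k accuracy it also returns is omitted for brevity): masked softmax cross-entropy and accuracy averaged over the valid entries only.

import tensorflow as tf

def n_way_classification_loss_sketch(logit, onehot_gt, valid_mask):
    # logit, onehot_gt: [bs, n, num_answer]; valid_mask: [bs, n]
    xent = tf.nn.softmax_cross_entropy_with_logits_v2(
        labels=onehot_gt, logits=logit)  # [bs, n]
    n_valid = tf.maximum(tf.reduce_sum(valid_mask), 1.0)
    loss = tf.reduce_sum(xent * valid_mask) / n_valid
    correct = tf.cast(
        tf.equal(tf.argmax(logit, axis=-1), tf.argmax(onehot_gt, axis=-1)),
        tf.float32)
    acc = tf.reduce_sum(correct * valid_mask) / n_valid
    return loss, acc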
Example #5
    def build_object_wordset(self):
        """
        object_wordset
        """
        V_ft = self.mid_result['object_V_ft']
        num_V_ft = self.mid_result['object_num_V_ft']

        v_linear_v = modules.fc_layer(  # [bs * #obj, #proposal, V_DIM]
            V_ft,
            V_DIM,
            use_bias=True,
            use_bn=False,
            use_ln=True,
            activation_fn=tf.nn.relu,
            is_training=self.is_train,
            scope='wordset_v_linear_v')

        wordset_embed = tf.tanh(
            tf.nn.embedding_lookup(  # [bs, #obj, W_DIM]
                self.wordset_map, self.batch['obj_blank_fill/wordsets']))
        wordset_ft = modules.fc_layer(  # [bs, #obj, L_DIM]
            wordset_embed,
            L_DIM,
            use_bias=True,
            use_bn=False,
            use_ln=True,
            activation_fn=tf.tanh,
            is_training=self.is_train,
            scope='wordset_ft')

        q_linear_v = modules.fc_layer(  # [bs, #obj, V_DIM]
            wordset_ft,
            V_DIM,
            use_bias=True,
            use_bn=False,
            use_ln=True,
            activation_fn=tf.nn.relu,
            is_training=self.is_train,
            scope='wordset_q_linear_v')
        flat_q_linear_v = tf.reshape(q_linear_v,
                                     [-1, V_DIM])  # [bs * #obj, V_DIM]

        att_score = modules.hadamard_attention(  # [bs * #obj, len]
            v_linear_v,
            num_V_ft,
            flat_q_linear_v,
            use_ln=False,
            is_train=self.is_train,
            scope='wordset_att')
        flat_pooled_V_ft = modules.attention_pooling(
            V_ft, att_score)  # [bs * #obj, vfeat_dim]
        pooled_V_ft = tf.reshape(
            flat_pooled_V_ft,
            [-1, self.data_cfg.n_obj_bf, self.data_cfg.vfeat_dim])

        v_linear_l = modules.fc_layer(pooled_V_ft,
                                      L_DIM,
                                      use_bias=True,
                                      use_bn=False,
                                      use_ln=True,
                                      activation_fn=tf.nn.relu,
                                      is_training=self.is_train,
                                      scope='pooled_linear_l')

        l_linear_l = modules.fc_layer(wordset_ft,
                                      L_DIM,
                                      use_bias=True,
                                      use_bn=False,
                                      use_ln=True,
                                      activation_fn=tf.nn.relu,
                                      is_training=self.is_train,
                                      scope='q_linear_l')

        joint = modules.fc_layer(v_linear_l * l_linear_l,
                                 L_DIM * 2,
                                 use_bias=True,
                                 use_bn=False,
                                 use_ln=True,
                                 activation_fn=tf.nn.relu,
                                 is_training=self.is_train,
                                 scope='joint_fc')
        joint = tf.nn.dropout(joint, 0.5)

        logit = modules.fc_layer(joint,
                                 self.num_answer,
                                 use_bias=True,
                                 use_bn=False,
                                 use_ln=False,
                                 activation_fn=None,
                                 is_training=self.is_train,
                                 scope='classifier')
        self.mid_result['obj_blank_fill/logit'] = logit  # [bs, #obj, #answer]

        with tf.name_scope('loss/obj_wordset'):
            onehot_gt = tf.one_hot(self.batch['obj_blank_fill/fills'],
                                   depth=self.num_answer)
            num_valid_entry = self.batch['obj_blank_fill/num']
            valid_mask = tf.sequence_mask(num_valid_entry,
                                          maxlen=self.data_cfg.n_obj_bf,
                                          dtype=tf.float32)
            loss, acc, top_k_acc = \
                self.n_way_classification_loss(logit, onehot_gt, valid_mask)
            self.losses['obj_wordset'] = loss
            self.report['obj_wordset_loss'] = loss
            self.report['obj_wordset_acc'] = acc
            self.report['obj_wordset_top_{}_acc'.format(TOP_K)] = top_k_acc

    def build_object_predict(self):
        """
        object_predict
        """
        V_ft = self.batch['image_ft']  # [bs,  #proposal, #feat_dim]
        V_ft = tf.expand_dims(V_ft, axis=1)  # [bs, 1, #proposal, #feat_dim]
        V_ft = tf.tile(V_ft, [1, self.data_cfg.n_obj_pred, 1, 1])  # [bs, #obj, #proposal, #feat_dim]
        V_ft = tf.reshape(
            V_ft, [-1, self.data_cfg.max_box_num, self.data_cfg.vfeat_dim])  # [bs * #obj, #proposal, #feat_dim]
        spat_ft = self.batch['spatial_ft']
        spat_ft = tf.expand_dims(spat_ft, axis=1)
        spat_ft = tf.tile(spat_ft, [1, self.data_cfg.n_obj_pred, 1, 1])
        spat_ft = tf.reshape(spat_ft, [-1, self.data_cfg.max_box_num, 6])
        num_V_ft = self.batch['num_boxes']  # [bs]
        num_V_ft = tf.expand_dims(num_V_ft, axis=1)  # [bs, 1]
        num_V_ft = tf.tile(num_V_ft,
                           [1, self.data_cfg.n_obj_pred])  # [bs, #obj]
        num_V_ft = tf.reshape(num_V_ft, [-1])  # [bs * #obj]

        key_spat_ft = self.batch['obj_pred/normal_boxes']
        key_spat_ft = tf.concat(
            [key_spat_ft,
             tf.expand_dims(key_spat_ft[:, :, 2] - key_spat_ft[:, :, 0], axis=-1),
             tf.expand_dims(key_spat_ft[:, :, 3] - key_spat_ft[:, :, 1], axis=-1)],
            axis=-1)

        v_linear_v = modules.fc_layer(  # [bs * #obj, #proposal, V_DIM]
            spat_ft,
            V_DIM,
            use_bias=True,
            use_bn=False,
            use_ln=True,
            activation_fn=tf.nn.relu,
            is_training=self.is_train,
            scope='spat_v_linear_v')

        q_linear_v = modules.fc_layer(  # [bs, #obj, V_DIM]
            key_spat_ft,
            V_DIM,
            use_bias=True,
            use_bn=False,
            use_ln=True,
            activation_fn=tf.nn.relu,
            is_training=self.is_train,
            scope='spat_q_linear_v')
        flat_q_linear_v = tf.reshape(q_linear_v,
                                     [-1, V_DIM])  # [bs * #obj, V_DIM]

        att_score = modules.hadamard_attention(  # [bs * #obj, len]
            v_linear_v,
            num_V_ft,
            flat_q_linear_v,
            use_ln=False,
            is_train=self.is_train,
            scope='spat_att')
        flat_pooled_V_ft = modules.attention_pooling(
            V_ft, att_score)  # [bs * #obj, vfeat_dim]
        pooled_V_ft = tf.reshape(
            flat_pooled_V_ft,
            [-1, self.data_cfg.n_obj_pred, self.data_cfg.vfeat_dim])

        wordset_embed = tf.tanh(
            tf.nn.embedding_lookup(  # [bs, #obj, W_DIM]
                self.wordset_map, self.batch['obj_pred/wordsets']))
        wordset_ft = modules.fc_layer(  # [bs, #obj, L_DIM]
            wordset_embed,
            L_DIM,
            use_bias=True,
            use_bn=False,
            use_ln=True,
            activation_fn=tf.tanh,
            is_training=self.is_train,
            scope='wordset_ft')

        v_linear_l = modules.fc_layer(pooled_V_ft,
                                      L_DIM,
                                      use_bias=True,
                                      use_bn=False,
                                      use_ln=True,
                                      activation_fn=tf.nn.relu,
                                      is_training=self.is_train,
                                      scope='pooled_linear_l')

        l_linear_l = modules.fc_layer(wordset_ft,
                                      L_DIM,
                                      use_bias=True,
                                      use_bn=False,
                                      use_ln=True,
                                      activation_fn=tf.nn.relu,
                                      is_training=self.is_train,
                                      scope='q_linear_l')

        joint = modules.fc_layer(v_linear_l * l_linear_l,
                                 L_DIM * 2,
                                 use_bias=True,
                                 use_bn=False,
                                 use_ln=True,
                                 activation_fn=tf.nn.relu,
                                 is_training=self.is_train,
                                 scope='joint_fc')
        joint = tf.nn.dropout(joint, 0.5)

        logit = modules.fc_layer(joint,
                                 self.num_answer,
                                 use_bias=True,
                                 use_bn=False,
                                 use_ln=False,
                                 activation_fn=None,
                                 is_training=self.is_train,
                                 scope='classifier')
        self.mid_result['obj_pred/logit'] = logit  # [bs, #obj, #answer]

        with tf.name_scope('loss/object_predict'):
            onehot_gt = tf.one_hot(self.batch['obj_pred/labels'],
                                   depth=self.num_answer)
            num_valid_entry = self.batch['obj_pred/num']
            valid_mask = tf.sequence_mask(num_valid_entry,
                                          maxlen=self.data_cfg.n_obj_pred,
                                          dtype=tf.float32)
            loss, acc, top_k_acc = \
                self.n_way_classification_loss(logit, onehot_gt, valid_mask)
            self.losses['object_pred'] = loss
            self.report['object_pred_loss'] = loss
            self.report['object_pred_acc'] = acc
            self.report['object_pred_top_{}_acc'.format(TOP_K)] = top_k_acc
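
A runnable shape check of the tile-and-reshape fan-out used in the first snippet and Example #5 (hypothetical sizes bs=2, n_obj=3, max_box_num=4, vfeat_dim=5): each image's proposal features are repeated once per target entry so attention can run on a flat [bs * n_obj] batch.

import tensorflow as tf

bs, n_obj, n_box, dim = 2, 3, 4, 5
V = tf.reshape(tf.range(bs * n_box * dim, dtype=tf.float32), [bs, n_box, dim])
V_tiled = tf.tile(tf.expand_dims(V, axis=1), [1, n_obj, 1, 1])  # [bs, n_obj, n_box, dim]
V_flat = tf.reshape(V_tiled, [-1, n_box, dim])                  # [bs * n_obj, n_box, dim]
with tf.Session() as sess:
    print(sess.run(tf.shape(V_flat)))  # [6 4 5]; rows 0-2 all view image 0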