def forward(self, features, all_phrase_ids, targets, precomp_boxes, precomp_score,
                precomp_det_label, image_scale, all_sent_sgs, all_sentences, image_unique_id, det_label_embedding):

        """
        :param obj_proposals: proposals from each image
        :param features: feature maps from the backbone
        :param targets: ground-truth relation labels
        :param object_vocab, object_vocab_len: [[xxx,xxx],[xxx],[xxx]], [2,1,1]
        :param sent_sg: sentence scene graph
        :return: prediction, loss

        Note that the first dimension indexes images.
        """
        img_num_per_gpu = len(features)

        batch_decode_logits = []
        batch_topk_decoder_logits = []
        batch_pred_similarity = []
        batch_precomp_boxes = []
        batch_topk_precomp_boxes = []
        batch_pred_boxes = []
        batch_topk_pred_boxes = []
        batch_topk_fusion_pred_boxes = []
        batch_topk_pred_similarity = []
        batch_topk_fusion_similarity = []
        batch_boxes_targets = []
        batch_ctx_embed = []
        batch_ctx_s1_embed = []

        batch_pred_targets = []
        batch_topk_pred_targets = []


        """ Language Embedding"""
        batch_phrase_ids, batch_phrase_types, batch_phrase_embed, batch_phrase_len, \
        batch_phrase_dec_ids, batch_phrase_mask, batch_decoder_word_embed, batch_phrase_glove_embed, batch_rel_phrase_embed, batch_relation_conn, batch_sent_embed,\
        batch_decoder_rel_word_embed, batch_rel_mask, batch_rel_dec_idx = self.phrase_embed(all_sentences, all_phrase_ids, all_sent_sgs)

        h, w = features.shape[-2:]

        self.storage = get_event_storage()


        for bid in range(img_num_per_gpu):

            """ Visual Embedding """
            precomp_boxes_bid = precomp_boxes[bid].to(self.device)  ## 100*4

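            # Align the ground-truth boxes with the phrase order returned by
            # phrase_embed, which may filter or re-order the raw phrase ids.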
            order = []
            for phr_ids in batch_phrase_ids[bid]:
                order.append(all_phrase_ids[bid].index(phr_ids))
            target_filter = targets[bid][np.array(order)]
            batch_boxes_targets.append(target_filter.to(self.device))
            batch_precomp_boxes.append(precomp_boxes_bid)

            img_feat_bid = features[[bid]]
            visual_features_bid = self.rcnn_top(self.det_roi_pooler([img_feat_bid], [precomp_boxes_bid])).mean(dim=[2, 3]).contiguous()
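            # Optionally concatenate pooled coordinate ("meshgrid") features as
            # an explicit spatial cue for each box.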
            if cfg.MODEL.VG.SPATIAL_FEAT:
                spa_feat = meshgrid_generation(h, w)
                spa_feat = self.det_roi_pooler([spa_feat], [precomp_boxes_bid]).view(visual_features_bid.shape[0], -1)
                spa_feat = self.spatial_trans(spa_feat)
                visual_features_bid = torch.cat((visual_features_bid, spa_feat), dim=1)

            visual_features_bid = self.visual_embedding(visual_features_bid)
            visual_features_bid = self.vis_batchnorm(visual_features_bid)

            """ Noun Phrase embedding """
            phrase_embed_bid = batch_phrase_embed[bid]
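            # BatchNorm1d needs more than one sample in training mode, so a
            # lone phrase is duplicated before normalisation and the duplicate
            # row is discarded afterwards.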
            if phrase_embed_bid.shape[0] == 1 and self.training:
                phrase_embed_bid = self.phr_batchnorm(phrase_embed_bid.repeat(2, 1))[[0]]
            else:
                phrase_embed_bid = self.phr_batchnorm(phrase_embed_bid)


            """ Similarity and attention prediction """
            num_box = precomp_boxes_bid.tensor.size(0)
            num_phrase = phrase_embed_bid.size(0)
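            # Enumerate every (phrase, box) index pair so the similarity and
            # regression heads can score all pairs in a single batched pass.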
            phr_inds, obj_inds = self.make_pair(num_phrase, num_box)
            pred_similarity_bid, pred_targets_bid = self.similarity(visual_features_bid, phrase_embed_bid, obj_inds, phr_inds)
            pred_similarity_bid = pred_similarity_bid.reshape(num_phrase, num_box)
            pred_targets_bid = pred_targets_bid.reshape(num_phrase, num_box, 4)
            batch_pred_targets.append(pred_targets_bid)


            if cfg.MODEL.VG.USING_DET_KNOWLEDGE :
                det_label_embedding_bid = det_label_embedding[bid].to(self.device)
                sim = self.cal_det_label_sim_max(det_label_embedding_bid, batch_phrase_glove_embed[bid])
                pred_similarity_bid = pred_similarity_bid * sim
                sim_mask = (sim > 0).float()
                atten_bid = numerical_stability_masked_softmax(pred_similarity_bid, sim_mask, dim=1)
            else:
                atten_bid = F.softmax(pred_similarity_bid, dim=1)

            ## reconstruction visual features
            visual_reconst_bid = torch.mm(atten_bid, visual_features_bid)
            decode_phr_logits = self.phrase_decoder(visual_reconst_bid, batch_decoder_word_embed[bid])
            batch_decode_logits.append(decode_phr_logits)

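            # Stage 2: select the top-k boxes per phrase under the stage-1
            # attention; these candidates are re-scored below.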
            atten_score_topk, atten_ranking_topk = torch.topk(atten_bid, dim=1, k=self.s2_topk) ## (N, 10)
            ind_phr_topk = np.arange(num_phrase).repeat(self.s2_topk)


            ## -----------------------------------------------------##
            ## crop stage-2 features
            ## -----------------------------------------------------##

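            # Before REG_START_ITER the top-k boxes reuse the stage-1 features;
            # afterwards they are first refined with the stage-1 regression
            # deltas, and features are re-pooled from the refined boxes.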
            if self.storage.iter <= cfg.SOLVER.REG_START_ITER:
                visual_features_topk_bid = visual_features_bid[atten_ranking_topk.reshape(-1)]
                precomp_boxes_topk_bid = precomp_boxes_bid[atten_ranking_topk.reshape(-1)]
                batch_topk_precomp_boxes.append(precomp_boxes_topk_bid)
            else:
                topk_box_ids = atten_ranking_topk.reshape(-1) + torch.as_tensor(ind_phr_topk, dtype=torch.long).to(self.device) * num_box
                precomp_boxes_tensor, box_size = precomp_boxes_bid.tensor, precomp_boxes_bid.size
                precomp_boxes_topk_tensor = precomp_boxes_tensor[atten_ranking_topk.reshape(-1)]  ## (N*10, 4)
                pred_targets_s0 = pred_targets_bid.view(-1, 4)[topk_box_ids]
                precomp_boxes_topk_bid = self.box2box_translation.apply_deltas(pred_targets_s0, precomp_boxes_topk_tensor)
                precomp_boxes_topk_bid = Boxes(precomp_boxes_topk_bid, box_size)
                precomp_boxes_topk_bid.clip()
                batch_topk_precomp_boxes.append(precomp_boxes_topk_bid)
                visual_features_topk_bid = self.rcnn_top(self.det_roi_pooler([img_feat_bid], [precomp_boxes_topk_bid])).mean(dim=[2, 3]).contiguous()

                if cfg.MODEL.VG.SPATIAL_FEAT:
                    spa_feat = meshgrid_generation(h, w)
                    spa_feat = self.det_roi_pooler([spa_feat], [precomp_boxes_topk_bid]).view(visual_features_topk_bid.shape[0], -1)
                    spa_feat = self.spatial_trans(spa_feat)
                    visual_features_topk_bid = torch.cat((visual_features_topk_bid, spa_feat), dim=1)

                visual_features_topk_bid = self.visual_embedding(visual_features_topk_bid)## (N*10, 1024)
                visual_features_topk_bid = self.vis_batchnorm(visual_features_topk_bid)


            pred_similarity_topk_bid, pred_targets_topk_bid = self.similarity_topk(visual_features_topk_bid, phrase_embed_bid, ind_phr_topk)
            pred_similarity_topk_bid = pred_similarity_topk_bid.reshape(num_phrase, self.s2_topk)
            pred_targets_topk_bid = pred_targets_topk_bid.reshape(num_phrase, self.s2_topk, 4)
            batch_topk_pred_targets.append(pred_targets_topk_bid)


            if cfg.MODEL.VG.USING_DET_KNOWLEDGE:
                sim_topk = torch.gather(sim, dim=1, index=atten_ranking_topk.long())
                sim_mask = (sim_topk > 0).float()
                pred_similarity_topk_bid = pred_similarity_topk_bid * sim_topk
                atten_topk_bid = numerical_stability_masked_softmax(pred_similarity_topk_bid, sim_mask, dim=1)
            else:
                atten_topk_bid = F.softmax(pred_similarity_topk_bid, dim=1)

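            # Fuse both stages: weight each stage-2 attention score by the
            # stage-1 attention score of the same top-k candidate.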
            atten_fusion = atten_topk_bid * atten_score_topk  ## N*10
            visual_features_topk_bid = visual_features_topk_bid.view(num_phrase, self.s2_topk, -1)
            visual_reconst_topk_bid = (atten_fusion.unsqueeze(2)*visual_features_topk_bid).sum(1) ## N*1024
            decoder_phr_topk_logits = self.phrase_decoder(visual_reconst_topk_bid, batch_decoder_word_embed[bid])
            batch_topk_decoder_logits.append(decoder_phr_topk_logits)


            ## construct the discriminative loss
            batch_ctx_s1_embed.append(self.visual_mlp(visual_reconst_bid.mean(0, keepdim=True)))
            batch_ctx_embed.append(self.visual_mlp(visual_reconst_topk_bid.mean(0, keepdim=True)))


            batch_pred_similarity.append(atten_bid)
            batch_topk_pred_similarity.append(atten_topk_bid)
            batch_topk_fusion_similarity.append(atten_fusion)

            ### transform boxes for stage-1
            num_phrase_indices = torch.arange(num_phrase).long().to(self.device)
            max_box_ind = atten_bid.detach().cpu().numpy().argmax(1)
            precomp_boxes_delta_max = pred_targets_bid[num_phrase_indices, max_box_ind] ## numPhrase*4

            max_topk_id = torch.topk(atten_topk_bid, dim=1, k=1)[1].long().squeeze(1)
            precomp_boxes_delta_max_topk = pred_targets_topk_bid[num_phrase_indices, max_topk_id]  ## num_phrase*4
            precomp_boxes_topk_bid_tensor = precomp_boxes_topk_bid.tensor.reshape(-1, self.s2_topk, 4)

            max_fusion_topk_id = torch.topk(atten_fusion, dim=1, k=1)[1].long().squeeze(1)
            precomp_boxes_delta_max_topk_fusion = pred_targets_topk_bid[num_phrase_indices, max_fusion_topk_id]  ## num_phrase*4

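            # Offsets into the flattened (num_phrase * s2_topk) top-k box list,
            # so each phrase indexes within its own top-k slice.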
            phr_index = torch.arange(num_phrase).to(self.device) * self.s2_topk

            if self.storage.iter <= cfg.SOLVER.REG_START_ITER:
                max_select_boxes = precomp_boxes_bid[max_box_ind]
                max_precomp_boxes = precomp_boxes_topk_bid[max_topk_id + phr_index]
                max_fusion_precomp_boxes = precomp_boxes_topk_bid[max_fusion_topk_id + phr_index]
            else:
                max_select_boxes = Boxes(self.box2box_translation.apply_deltas(precomp_boxes_delta_max, precomp_boxes_bid[max_box_ind].tensor), precomp_boxes_bid.size)
                max_precomp_boxes = Boxes(self.box2box_translation.apply_deltas(precomp_boxes_delta_max_topk, precomp_boxes_topk_bid_tensor[num_phrase_indices, max_topk_id]), precomp_boxes_bid.size)
                max_fusion_precomp_boxes = Boxes(self.box2box_translation.apply_deltas(precomp_boxes_delta_max_topk_fusion, precomp_boxes_topk_bid_tensor[num_phrase_indices, max_fusion_topk_id]), precomp_boxes_bid.size)

            batch_pred_boxes.append(max_select_boxes)
            batch_topk_pred_boxes.append(max_precomp_boxes)
            batch_topk_fusion_pred_boxes.append(max_fusion_precomp_boxes)


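        # Image-sentence similarities for the discriminative (contrastive) losses.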
        batch_ctx_sim, batch_ctx_sim_s1 = self.generate_image_sent_discriminative(batch_sent_embed, batch_ctx_embed, batch_ctx_s1_embed)

        noun_reconst_loss, noun_topk_reconst_loss, disc_img_sent_loss_s1, disc_img_sent_loss_s2, reg_loss_s1, \
        reg_loss_s2 = self.VGLoss(batch_phrase_mask, batch_decode_logits, batch_topk_decoder_logits, batch_phrase_dec_ids,
                                  batch_ctx_sim, batch_ctx_sim_s1, batch_pred_similarity, batch_topk_pred_similarity, batch_boxes_targets, batch_precomp_boxes,
                                  batch_pred_targets, batch_topk_pred_targets,
                                  batch_topk_precomp_boxes)

        all_loss = dict(noun_reconst_loss=noun_reconst_loss, noun_topk_reconst_loss=noun_topk_reconst_loss, disc_img_sent_loss_s1=disc_img_sent_loss_s1,
                        disc_img_sent_loss_s2=disc_img_sent_loss_s2, reg_loss_s1=reg_loss_s1, reg_loss_s2=reg_loss_s2)


        if self.training:
            return all_loss, None
        else:
            return all_loss, (batch_phrase_ids, batch_phrase_types, move2cpu(batch_pred_boxes), move2cpu(batch_pred_similarity),
                              move2cpu(batch_boxes_targets), move2cpu(batch_precomp_boxes), image_unique_id, move2cpu(batch_topk_pred_similarity),
                              move2cpu(batch_topk_fusion_similarity), move2cpu(batch_topk_pred_boxes), move2cpu(batch_topk_fusion_pred_boxes),
                              move2cpu(batch_topk_precomp_boxes), move2cpu(batch_topk_pred_targets), move2cpu(batch_pred_targets))
Example 2
    def forward(self, features, all_phrase_ids, targets, precomp_boxes, precomp_score,
                precomp_det_label, image_scale, all_sent_sgs, all_sentences, image_unique_id, det_label_embedding):

        """
        :param obj_proposals: proposals from each image
        :param features: feature maps from the backbone
        :param targets: ground-truth relation labels
        :param object_vocab, object_vocab_len: [[xxx,xxx],[xxx],[xxx]], [2,1,1]
        :param sent_sg: sentence scene graph
        :return: prediction, loss

        Note that the first dimension indexes images.
        """
        img_num_per_gpu = len(features)
        batch_decode_logits = []
        batch_topk_decoder_logits = []
        batch_pred_similarity = []
        batch_precomp_boxes = []
        batch_pred_boxes = []
        batch_topk_pred_boxes = []
        batch_topk_fusion_pred_boxes = []
        batch_topk_pred_similarity = []
        batch_topk_fusion_similarity = []
        batch_boxes_targets = []
        batch_ctx_embed = []
        batch_ctx_s1_embed = []


        """ Language Embedding"""
        batch_phrase_ids, batch_phrase_types, batch_phrase_embed, batch_phrase_len, \
        batch_phrase_dec_ids, batch_phrase_mask, batch_decoder_word_embed, batch_phrase_glove_embed, batch_rel_phrase_embed, batch_relation_conn, batch_sent_embed,\
        batch_decoder_rel_word_embed, batch_rel_mask, batch_rel_dec_idx = self.phrase_embed(all_sentences, all_phrase_ids, all_sent_sgs)

        h, w = features.shape[-2:]

        for bid in range(img_num_per_gpu):

            """ Visual Embedding """
            precomp_boxes_bid = precomp_boxes[bid].to(self.device)  ## 100*4

            order = []
            for phr_ids in batch_phrase_ids[bid]:
                order.append(all_phrase_ids[bid].index(phr_ids))
            target_filter = targets[bid][np.array(order)]
            batch_boxes_targets.append(target_filter.to(self.device))

            batch_precomp_boxes.append(precomp_boxes_bid)

            img_feat_bid = features[[bid]]

            visual_features_bid = self.rcnn_top(self.det_roi_pooler([img_feat_bid], [precomp_boxes_bid])).mean(dim=[2, 3]).contiguous()

            if cfg.MODEL.VG.SPATIAL_FEAT:
                spa_feat = meshgrid_generation(h, w)
                spa_feat = self.det_roi_pooler([spa_feat], [precomp_boxes_bid]).view(visual_features_bid.shape[0], -1)
                spa_feat = self.spatial_trans(spa_feat)
                visual_features_bid = torch.cat((visual_features_bid, spa_feat), dim=1)


            visual_features_bid = self.visual_embedding(visual_features_bid)
            visual_features_bid = self.vis_batchnorm(visual_features_bid)

            """ Noun Phrase embedding """
            phrase_embed_bid = batch_phrase_embed[bid]
            if phrase_embed_bid.shape[0] == 1 and self.training:
                phrase_embed_bid = self.phr_batchnorm(phrase_embed_bid.repeat(2, 1))[[0]]
            else:
                phrase_embed_bid = self.phr_batchnorm(phrase_embed_bid)


            """ Similarity and attention prediction """
            num_box = precomp_boxes_bid.tensor.size(0)
            num_phrase = phrase_embed_bid.size(0)
            phr_inds, obj_inds = self.make_pair(num_phrase, num_box)
            pred_similarity_bid = self.similarity(visual_features_bid, phrase_embed_bid, obj_inds, phr_inds)
            pred_similarity_bid = pred_similarity_bid.reshape(num_phrase, num_box)


            if cfg.MODEL.VG.USING_DET_KNOWLEDGE:
                det_label_embedding_bid = det_label_embedding[bid].to(self.device)
                sim = self.cal_det_label_sim_max(det_label_embedding_bid, batch_phrase_glove_embed[bid], precomp_score[bid])
                pred_similarity_bid = pred_similarity_bid * sim
                sim_mask = (sim > 0).float()
                atten_bid = numerical_stability_masked_softmax(pred_similarity_bid, sim_mask, dim=1)
            else:
                atten_bid = F.softmax(pred_similarity_bid, dim=1)

            ## reconstruction visual features
            visual_reconst_bid = torch.mm(atten_bid, visual_features_bid)
            decode_phr_logits = self.phrase_decoder(visual_reconst_bid, batch_decoder_word_embed[bid])
            batch_decode_logits.append(decode_phr_logits)

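            # In this variant the stage-2 branch only re-scores the stage-1
            # top-k boxes; there is no box regression or feature re-pooling.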
            atten_score_topk, atten_ranking_topk = torch.topk(atten_bid, dim=1, k=self.s2_topk) ## (N, 10)
            ind_phr_topk = np.arange(num_phrase).repeat(self.s2_topk)

            visual_features_topk_bid = visual_features_bid[atten_ranking_topk.reshape(-1)] ## (N*10, 1024)
            pred_similarity_topk_bid = self.similarity_topk(visual_features_topk_bid, phrase_embed_bid, ind_phr_topk)
            pred_similarity_topk_bid = pred_similarity_topk_bid.reshape(num_phrase, self.s2_topk)


            if cfg.MODEL.VG.USING_DET_KNOWLEDGE:
                sim_topk = torch.gather(sim, dim=1, index=atten_ranking_topk.long())
                sim_mask = (sim_topk > 0).float()
                pred_similarity_topk_bid = pred_similarity_topk_bid * sim_topk
                atten_topk_bid = numerical_stability_masked_softmax(pred_similarity_topk_bid, sim_mask, dim=1)
            else:
                atten_topk_bid = F.softmax(pred_similarity_topk_bid, dim=1)

            atten_fusion = atten_topk_bid * atten_score_topk  ## N*10
            visual_features_topk_bid = visual_features_topk_bid.view(num_phrase, self.s2_topk, -1)
            visual_reconst_topk_bid = (atten_fusion.unsqueeze(2)*visual_features_topk_bid).sum(1) # N*1024
            decoder_phr_topk_logits = self.phrase_decoder(visual_reconst_topk_bid, batch_decoder_word_embed[bid])
            batch_topk_decoder_logits.append(decoder_phr_topk_logits)


            batch_ctx_s1_embed.append(self.visual_mlp(visual_reconst_bid.mean(0, keepdim=True)))
            batch_ctx_embed.append(self.visual_mlp(visual_reconst_topk_bid.mean(0, keepdim=True)))

            batch_pred_similarity.append(atten_bid)
            batch_topk_pred_similarity.append(atten_topk_bid)
            batch_topk_fusion_similarity.append(atten_fusion)

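            # Final predictions: for each phrase keep the box with the highest
            # stage-1, stage-2, and fused attention score, respectively.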
            max_box_ind = atten_bid.detach().cpu().numpy().argmax(1)
            batch_pred_boxes.append(precomp_boxes_bid[max_box_ind])

            max_topk_id = torch.topk(atten_topk_bid, dim=1, k=1)[1].long()
            max_topk_box_ind = torch.gather(atten_ranking_topk, dim=1, index=max_topk_id).squeeze(1).cpu().numpy()
            batch_topk_pred_boxes.append(precomp_boxes_bid[max_topk_box_ind])

            max_fusion_topk_id = torch.topk(atten_fusion, dim=1, k=1)[1].long()
            max_fusion_box_ind = torch.gather(atten_ranking_topk, dim=1, index=max_fusion_topk_id).squeeze(1).cpu().numpy()
            batch_topk_fusion_pred_boxes.append(precomp_boxes_bid[max_fusion_box_ind])


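        # Project context and sentence embeddings into a shared 512-d space and
        # compare them with a dot product scaled by sqrt(512).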
        batch_sent_embed = torch.cat(batch_sent_embed, dim=0)  ## b*1024
        batch_sent_embed = self.sent_mlp(batch_sent_embed)  ## b*512
        batch_ctx_embed = torch.cat(batch_ctx_embed, dim=0)  ## b*512
        batch_ctx_s1_embed = torch.cat(batch_ctx_s1_embed, dim=0)  ## b*512
        batch_ctx_sim = torch.mm(batch_ctx_embed, batch_sent_embed.permute(1, 0)) / 512 ** 0.5
        batch_ctx_sim_s1 = torch.mm(batch_ctx_s1_embed, batch_sent_embed.permute(1, 0)) / 512 ** 0.5


        noun_reconst_loss, noun_topk_reconst_loss, disc_img_sent_loss_s1, \
        disc_img_sent_loss_s2 = self.VGLoss(batch_phrase_mask, batch_decode_logits, batch_topk_decoder_logits, batch_phrase_dec_ids,
                                            batch_ctx_sim, batch_ctx_sim_s1, batch_pred_similarity, batch_topk_pred_similarity, batch_boxes_targets, batch_precomp_boxes)

        all_loss = dict(noun_reconst_loss=noun_reconst_loss, noun_topk_reconst_loss=noun_topk_reconst_loss, disc_img_sent_loss_s1=disc_img_sent_loss_s1, disc_img_sent_loss_s2=disc_img_sent_loss_s2)

        if self.training:
            return all_loss, None
        else:
            return all_loss, (batch_phrase_ids, batch_phrase_types, move2cpu(batch_pred_boxes), move2cpu(batch_pred_similarity), move2cpu(batch_boxes_targets),
                              move2cpu(batch_precomp_boxes), image_unique_id, move2cpu(batch_topk_pred_similarity),
                              move2cpu(batch_topk_fusion_similarity), move2cpu(batch_topk_pred_boxes), move2cpu(batch_topk_fusion_pred_boxes))
Example 3
    def forward(self, features, all_phrase_ids, targets, precomp_boxes,
                precomp_score, precomp_det_label, image_scale, all_sent_sgs,
                all_sentences, image_unique_id, det_label_embedding):
        """
        :param obj_proposals: proposals from each image
        :param features: feature maps from the backbone
        :param targets: ground-truth relation labels
        :param object_vocab, object_vocab_len: [[xxx,xxx],[xxx],[xxx]], [2,1,1]
        :param sent_sg: sentence scene graph
        :return: prediction, loss

        Note that the first dimension indexes images.
        """
        img_num_per_gpu = len(features)

        batch_decode_logits = []
        batch_pred_similarity = []
        batch_precomp_boxes = []
        batch_pred_boxes = []
        batch_boxes_targets = []
        """ Language Embedding"""
        batch_phrase_ids, batch_phrase_types, batch_phrase_embed, batch_phrase_len, \
        batch_phrase_dec_ids, batch_phrase_mask, batch_decoder_word_embed, batch_phrase_glove_embed, batch_cst_qid, batch_max_len = \
            self.phrase_embed(all_sentences, all_phrase_ids, all_sent_sgs)

        h, w = features.shape[-2:]

        for bid in range(img_num_per_gpu):
            """ Visual Embedding """
            precomp_boxes_bid = precomp_boxes[bid].to(self.device)  ## 100*4

            order = []
            for phr_ids in batch_phrase_ids[bid]:
                order.append(all_phrase_ids[bid].index(phr_ids))
            target_filter = targets[bid][np.array(order)]
            batch_boxes_targets.append(target_filter.to(self.device))

            batch_precomp_boxes.append(precomp_boxes_bid)

            img_feat_bid = features[[bid]]
            visual_features_bid = self.rcnn_top(
                self.det_roi_pooler(
                    tuple([img_feat_bid]),
                    [precomp_boxes_bid])).mean(dim=[2, 3]).contiguous()

            if cfg.MODEL.VG.SPATIAL_FEAT:
                spa_feat = meshgrid_generation(h, w)
                spa_feat = self.det_roi_pooler(
                    tuple([spa_feat]),
                    [precomp_boxes_bid]).view(visual_features_bid.shape[0], -1)
                spa_feat = self.spatial_trans(spa_feat)
                visual_features_bid = torch.cat(
                    (visual_features_bid, spa_feat), dim=1)

            visual_features_bid = self.visual_embedding(visual_features_bid)
            visual_features_bid = self.vis_batchnorm(visual_features_bid)
            """ Noun Phrase embedding """
            phrase_embed_bid = batch_phrase_embed[bid]
            if phrase_embed_bid.shape[0] == 1 and self.training:
                phrase_embed_bid = self.phr_batchnorm(
                    phrase_embed_bid.repeat(2, 1))[[0]]
            else:
                phrase_embed_bid = self.phr_batchnorm(phrase_embed_bid)
            """ Similarity and attention prediction """
            # num_box = precomp_boxes_bid.tensor.size(0)
            # num_phrase = phrase_embed_bid.size(0)
            # phr_inds, obj_inds = self.make_pair(num_phrase, num_box)

            # pred_similarity_bid = self.similarity(visual_features_bid, phrase_embed_bid, obj_inds, phr_inds)
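            # This variant replaces the learned similarity head with a plain
            # scaled dot product between phrase and visual embeddings.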
            pred_similarity_bid = torch.mm(
                phrase_embed_bid, visual_features_bid.permute(
                    1, 0)) / self.phrase_embed_dim**0.5
            # pred_similarity_bid = pred_similarity_bid.reshape(num_phrase, num_box)

            if cfg.MODEL.VG.USING_DET_KNOWLEDGE:
                det_label_embedding_bid = det_label_embedding[bid].to(
                    self.device)
                sim = self.cal_det_label_sim(det_label_embedding_bid,
                                             batch_phrase_glove_embed[bid],
                                             precomp_score[bid])
                pred_similarity_bid = pred_similarity_bid * sim
                sim_mask = (sim > 0).float()
                atten_bid = numerical_stability_masked_softmax(
                    pred_similarity_bid, sim_mask, dim=1)
            else:
                atten_bid = F.softmax(pred_similarity_bid, dim=1)

            ## reconstruction visual features
            visual_reconst_bid = torch.mm(atten_bid, visual_features_bid)

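            # NOTE: cst_phr and num_cst (and the batch_cst_* tensors used in
            # the loss call below) are defined elsewhere in the original class
            # and do not appear in this excerpt.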
            decode_phr_logits, decode_phr_logits_cst = self.phrase_decoder(
                visual_reconst_bid, batch_decoder_word_embed[bid], cst_phr,
                num_cst)
            batch_decode_logits.append(decode_phr_logits)

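            # At inference time, record per-phrase attention and the argmax box.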
            if not self.training:
                batch_pred_similarity.append(atten_bid)
                max_box_ind = atten_bid.detach().cpu().numpy().argmax(1)
                batch_pred_boxes.append(precomp_boxes_bid[max_box_ind])

        noun_reconst_loss, noun_cst_loss = self.VGLoss(
            batch_phrase_mask, batch_decode_logits, batch_phrase_dec_ids,
            batch_cst_mask, batch_cst_decoder_logits, batch_cst_dec_ids)
        all_loss = dict(noun_reconst_loss=noun_reconst_loss,
                        noun_cst_loss=noun_cst_loss)

        if self.training:
            return all_loss, None
        else:
            return all_loss, (batch_phrase_ids, batch_phrase_types,
                              batch_pred_boxes, batch_pred_similarity,
                              batch_boxes_targets, batch_precomp_boxes,
                              image_unique_id)