Example #1
    def forward(self,
                x,
                im_sizes,
                image_offset,
                gt_boxes=None,
                gt_classes=None,
                gt_rels=None,
                proposals=None,
                train_anchor_inds=None,
                return_fmap=False):
        """
        Forward pass for detection
        :param x: Images@[batch_size, 3, IM_SIZE, IM_SIZE]
        :param im_sizes: A numpy array of (h, w, scale) for each image.
        :param image_offset: Offset indicating which image we're on, used for multi-GPU training (0 if single GPU)
        :param gt_boxes:

        Training parameters:
        :param gt_boxes: [num_gt, 4] GT boxes over the batch.
        :param gt_classes: [num_gt, 2] gt classes where each one is (img_id, class)
        :param train_anchor_inds: a [num_train, 2] array of indices for the anchors that will
                                  be used to compute the training loss. Each (img_ind, fpn_idx)
        :return: If train:
            scores, boxdeltas, labels, boxes, boxtargets, rpnscores, rpnboxes, rellabels

            if test:
            prob dists, boxes, img inds, maxscores, classes

        """
        result = self.detector(x,
                               im_sizes,
                               image_offset,
                               gt_boxes,
                               gt_classes,
                               gt_rels,
                               proposals,
                               train_anchor_inds,
                               return_fmap=True)
        # rel_feat = self.relationship_feat.feature_map(x)

        if result.is_none():
            raise ValueError("Empty detection result")

        im_inds = result.im_inds - image_offset
        boxes = result.rm_box_priors

        if self.training and result.rel_labels is None:
            assert self.mode == 'sgdet'
            result.rel_labels = rel_assignments(im_inds.data,
                                                boxes.data,
                                                result.rm_obj_labels.data,
                                                gt_boxes.data,
                                                gt_classes.data,
                                                gt_rels.data,
                                                image_offset,
                                                filter_non_overlap=True,
                                                num_sample_per_gt=1)

        rel_inds = self.get_rel_inds(result.rel_labels, im_inds, boxes)
        spt_feats = self.get_boxes_encode(boxes, rel_inds)
        pair_inds = self.union_pairs(im_inds)

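        # When gradient hooks are enabled, use the ground-truth relation pairs as rel_inds and
        # keep fmap attached to the graph so its gradients can be saved via save_grad; otherwise
        # detach fmap so no gradients flow back into the detector backbone.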
        if self.hook_for_grad:
            rel_inds = gt_rels[:, :-1].data

        if self.hook_for_grad:
            fmap = result.fmap
            fmap.register_hook(self.save_grad)
        else:
            fmap = result.fmap.detach()

        rois = torch.cat((im_inds[:, None].float(), boxes), 1)

        result.obj_fmap = self.obj_feature_map(fmap, rois)
        # result.obj_dists_head = self.obj_classify_head(obj_fmap_rel)

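        # Soft label embeddings: weight each class embedding by the predicted class probabilities
        # (softmax over rm_obj_dists) instead of committing to a hard argmax label.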
        obj_embed = F.softmax(result.rm_obj_dists,
                              dim=1) @ self.obj_embed.weight
        obj_embed_lstm = F.softmax(result.rm_obj_dists,
                                   dim=1) @ self.embeddings4lstm.weight
        pos_embed = self.pos_embed(Variable(center_size(boxes.data)))
        obj_pre_rep = torch.cat((result.obj_fmap, obj_embed, pos_embed), 1)
        obj_feats = self.merge_obj_feats(obj_pre_rep)
        # obj_feats=self.trans(obj_feats)
        obj_feats_lstm = torch.cat(
            (obj_feats, obj_embed_lstm),
            -1).contiguous().view(1, obj_feats.size(0), -1)

        # obj_feats = F.relu(obj_feats)

        phr_ori = self.visual_rep(fmap, rois, pair_inds[:, 1:])
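        # For every relation pair in rel_inds, find its matching row in pair_inds (via
        # intersect_2d) so the corresponding union-box feature can be gathered from phr_ori.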
        vr_indices = torch.from_numpy(
            intersect_2d(rel_inds[:, 1:].cpu().numpy(),
                         pair_inds[:, 1:].cpu().numpy()).astype(
                             np.uint8)).cuda().max(-1)[1]
        vr = phr_ori[vr_indices]

        phr_feats_high = self.get_phr_feats(phr_ori)

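        # Two-stage refinement: an LSTM pass updates the object features, the decoder refines the
        # object distributions, and obj_mps1 passes messages with the phrase features; the second
        # pass repeats this using the refined distributions as the new label embeddings.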
        obj_feats_lstm_output, (obj_hidden_states,
                                obj_cell_states) = self.lstm(obj_feats_lstm)

        rm_obj_dists1 = result.rm_obj_dists + self.context.decoder_lin(
            obj_feats_lstm_output.squeeze())
        obj_feats_output = self.obj_mps1(obj_feats_lstm_output.view(-1, obj_feats_lstm_output.size(-1)), \
                            phr_feats_high, im_inds, pair_inds)

        obj_embed_lstm1 = F.softmax(rm_obj_dists1,
                                    dim=1) @ self.embeddings4lstm.weight

        obj_feats_lstm1 = torch.cat((obj_feats_output, obj_embed_lstm1), -1).contiguous().view(1, \
                            obj_feats_output.size(0), -1)
        obj_feats_lstm_output, _ = self.lstm(
            obj_feats_lstm1, (obj_hidden_states, obj_cell_states))

        rm_obj_dists2 = rm_obj_dists1 + self.context.decoder_lin(
            obj_feats_lstm_output.squeeze())
        obj_feats_output = self.obj_mps1(obj_feats_lstm_output.view(-1, obj_feats_lstm_output.size(-1)), \
                            phr_feats_high, im_inds, pair_inds)

        # Prevent gradients from flowing back into score_fc from elsewhere
        result.rm_obj_dists, result.obj_preds = self.context(
            rm_obj_dists2, obj_feats_output, result.rm_obj_labels
            if self.training or self.mode == 'predcls' else None, boxes.data,
            result.boxes_all)

        obj_dtype = result.obj_fmap.data.type()
        obj_preds_embeds = torch.index_select(self.ort_embedding, 0,
                                              result.obj_preds).type(obj_dtype)
        transferred_boxes = torch.stack(
            (boxes[:, 0] / IM_SCALE, boxes[:, 3] / IM_SCALE,
             boxes[:, 2] / IM_SCALE, boxes[:, 1] / IM_SCALE,
             ((boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1])) /
             (IM_SCALE**2)), -1).type(obj_dtype)
        obj_features = torch.cat(
            (result.obj_fmap, obj_preds_embeds, transferred_boxes), -1)
        obj_features_merge = self.merge_obj_low(
            obj_features) + self.merge_obj_high(obj_feats_output)

        # Split into subject and object representations
        result.subj_rep = self.post_emb_s(obj_features_merge)[rel_inds[:, 1]]
        result.obj_rep = self.post_emb_o(obj_features_merge)[rel_inds[:, 2]]
        prod_rep = result.subj_rep * result.obj_rep

        # obj_pools = self.visual_obj(result.fmap.detach(), rois, rel_inds[:, 1:])
        # rel_pools = self.relationship_feat.union_rel_pooling(rel_feat, rois, rel_inds[:, 1:])
        # context_pools = torch.cat([obj_pools, rel_pools], 1)
        # merge_pool = self.merge_feat(context_pools)
        # vr = self.roi_fmap(merge_pool)

        # vr = self.rel_refine(vr)

        prod_rep = prod_rep * vr

        if self.use_tanh:
            prod_rep = F.tanh(prod_rep)

        prod_rep = torch.cat((prod_rep, spt_feats), -1)
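        # Gate the frequency-bias contribution with a sigmoid computed from the fused
        # subject-object-spatial representation.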
        freq_gate = self.freq_gate(prod_rep)
        freq_gate = F.sigmoid(freq_gate)
        result.rel_dists = self.rel_compress(prod_rep)
        # result.rank_factor = self.ranking_module(prod_rep).view(-1)

        if self.use_bias:
            result.rel_dists = result.rel_dists + freq_gate * self.freq_bias.index_with_labels(
                torch.stack((
                    result.obj_preds[rel_inds[:, 1]],
                    result.obj_preds[rel_inds[:, 2]],
                ), 1))

        if self.training:
            return result

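        # Flatten the [num_boxes, num_classes] softmax and index it so each box keeps only the
        # score of its predicted class.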
        twod_inds = arange(
            result.obj_preds.data) * self.num_classes + result.obj_preds.data
        result.obj_scores = F.softmax(result.rm_obj_dists,
                                      dim=1).view(-1)[twod_inds]

        # Bbox regression
        if self.mode == 'sgdet':
            bboxes = result.boxes_all.view(-1, 4)[twod_inds].view(
                result.boxes_all.size(0), 4)
        else:
            # Boxes will get fixed by filter_dets function.
            bboxes = result.rm_box_priors

        rel_rep = F.softmax(result.rel_dists, dim=1)
        # rel_rep = smooth_one_hot(rel_rep)
        # rank_factor = F.sigmoid(result.rank_factor)

        return filter_dets(bboxes, result.obj_scores, result.obj_preds,
                           rel_inds[:, 1:], rel_rep)
Example #2
    def forward(self, x, im_sizes, image_offset,
                gt_boxes=None, gt_classes=None, gt_rels=None, proposals=None, train_anchor_inds=None,
                return_fmap=False):
        """
        Forward pass for detection
        :param x: Images@[batch_size, 3, IM_SIZE, IM_SIZE]
        :param im_sizes: A numpy array of (h, w, scale) for each image.
        :param image_offset: Offset indicating which image we're on, used for multi-GPU training (0 if single GPU)
        :param gt_boxes:

        Training parameters:
        :param gt_boxes: [num_gt, 4] GT boxes over the batch.
        :param gt_classes: [num_gt, 2] gt classes where each one is (img_id, class)
        :param train_anchor_inds: a [num_train, 2] array of indices for the anchors that will
                                  be used to compute the training loss. Each (img_ind, fpn_idx)
        :return: If train:
            scores, boxdeltas, labels, boxes, boxtargets, rpnscores, rpnboxes, rellabels
            
            if test:
            prob dists, boxes, img inds, maxscores, classes
            
        """

        # Detector
        result = self.detector(x, im_sizes, image_offset, gt_boxes, gt_classes, gt_rels, proposals,
                               train_anchor_inds, return_fmap=True)
        if result.is_none():
            raise ValueError("Empty detection result")
        im_inds = result.im_inds - image_offset
        # boxes: [#boxes, 4], without box deltas; this is where the narrow error comes from, so it should be .detach()ed
        boxes = result.rm_box_priors.detach()   

        if self.training and result.rel_labels is None:
            assert self.mode == 'sgdet' # sgcls's result.rel_labels is gt and not None
            # rel_labels: [num_rels, 4] (img ind, box0 ind, box1ind, rel type)
            result.rel_labels = rel_assignments(im_inds.data, boxes.data, result.rm_obj_labels.data,
                                                gt_boxes.data, gt_classes.data, gt_rels.data,
                                                image_offset, filter_non_overlap=True,
                                                num_sample_per_gt=1)
            rel_labels_neg = self.get_neg_examples(result.rel_labels)
            rel_inds_neg = rel_labels_neg[:,:3]

        #torch.cat((result.rel_labels[:,0].contiguous().view(236,1),result.rm_obj_labels[result.rel_labels[:,1]].view(236,1),result.rm_obj_labels[result.rel_labels[:,2]].view(236,1),result.rel_labels[:,3].contiguous().view(236,1)),-1)
        rel_inds = self.get_rel_inds(result.rel_labels, im_inds, boxes)  #[275,3], [im_inds, box1_inds, box2_inds]

        # rois: [#boxes, 5]
        rois = torch.cat((im_inds[:, None].float(), boxes), 1)
        # result.rm_obj_fmap: [384, 4096]
        #result.rm_obj_fmap = self.obj_feature_map(result.fmap.detach(), rois) # detach: prevent gradients from flowing backward
        result.rm_obj_fmap = self.obj_feature_map(result.fmap.detach(), rois.detach()) # detach: prevent gradients from flowing backward

        # BiLSTM
        result.rm_obj_dists, result.rm_obj_preds, edge_ctx = self.context(
            result.rm_obj_fmap,   # has been detached above
            # rm_obj_dists: [#boxes, 151]; Prevent gradients from flowing back into score_fc from elsewhere
            result.rm_obj_dists.detach(),  # .detach:Returns a new Variable, detached from the current graph
            im_inds, result.rm_obj_labels if self.training or self.mode == 'predcls' else None,
            boxes.data, result.boxes_all.detach() if self.mode == 'sgdet' else result.boxes_all)
        

        # Post Processing
        # nl_edge <= 0
        if edge_ctx is None:
            edge_rep = self.post_emb(result.rm_obj_preds)
        # nl_edge > 0
        else: 
            edge_rep = self.post_lstm(edge_ctx)  # [384, 4096*2]
     
        # Split into subject and object representations
        edge_rep = edge_rep.view(edge_rep.size(0), 2, self.pooling_dim)  #[384,2,4096]
        subj_rep = edge_rep[:, 0]  # [384,4096]
        obj_rep = edge_rep[:, 1]  # [384,4096]
        prod_rep = subj_rep[rel_inds[:, 1]] * obj_rep[rel_inds[:, 2]]  # prod_rep, rel_inds: [275,4096], [275,3]
    

        if self.use_vision: # True when sgdet
            # union rois: fmap.detach--RoIAlignFunction--roifmap--vr [275,4096]
            vr = self.visual_rep(result.fmap.detach(), rois, rel_inds[:, 1:])

            if self.limit_vision:  # False when sgdet
                # exact value TBD
                prod_rep = torch.cat((prod_rep[:,:2048] * vr[:,:2048], prod_rep[:,2048:]), 1) 
            else:
                prod_rep = prod_rep * vr  # [275,4096]
                if self.training:
                    vr_neg = self.visual_rep(result.fmap.detach(), rois, rel_inds_neg[:, 1:])
                    prod_rep_neg = subj_rep[rel_inds_neg[:, 1]].detach() * obj_rep[rel_inds_neg[:, 2]].detach() * vr_neg 
                    rel_dists_neg = self.rel_compress(prod_rep_neg)
                    

        if self.use_tanh:  # False when sgdet
            prod_rep = F.tanh(prod_rep)

        result.rel_dists = self.rel_compress(prod_rep)  # [275,51]

        if self.use_bias:  # True when sgdet
            result.rel_dists = result.rel_dists + self.freq_bias.index_with_labels(torch.stack((
                result.rm_obj_preds[rel_inds[:, 1]],
                result.rm_obj_preds[rel_inds[:, 2]],
            ), 1))


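        # Training: build a contrastive ranking signal. Each relation pair gets an overall score
        # (max relation softmax * subject score * object score); ground-truth (positive) and
        # sampled negative pairs are scored the same way, sorted together by score, and exposed
        # as result.pos / result.neg for a margin-style loss against a zero anchor.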
        if self.training:
            judge = result.rel_labels.data[:,3] != 0
            if judge.sum() != 0:  # at least one gt_rel exists in rel_inds
                select_rel_inds = torch.arange(rel_inds.size(0)).view(-1,1).long().cuda()[result.rel_labels.data[:,3] != 0]
                com_rel_inds = rel_inds[select_rel_inds]
                twod_inds = arange(result.rm_obj_preds.data) * self.num_classes + result.rm_obj_preds.data
                result.obj_scores = F.softmax(result.rm_obj_dists.detach(), dim=1).view(-1)[twod_inds]   # only about 1/4 of the 384 obj_dists will be updated, because only about 1/4 of the objects' labels are non-zero

                # positive overall score
                obj_scores0 = result.obj_scores[com_rel_inds[:,1]]
                obj_scores1 = result.obj_scores[com_rel_inds[:,2]]
                rel_rep = F.softmax(result.rel_dists[select_rel_inds], dim=1)    # result.rel_dists has grad
                _, pred_classes_argmax = rel_rep.data[:,:].max(1)  # all classes
                max_rel_score = rel_rep.gather(1, Variable(pred_classes_argmax.view(-1,1))).squeeze()  # SqueezeBackward, GatherBackward
                score_list = torch.cat((com_rel_inds[:,0].float().contiguous().view(-1,1), obj_scores0.data.view(-1,1), obj_scores1.data.view(-1,1), max_rel_score.data.view(-1,1)), 1)
                prob_score = max_rel_score * obj_scores0.detach() * obj_scores1.detach()
                #pos_prob[:,1][result.rel_labels.data[:,3] == 0] = 0  # treat most rel_labels as neg because their rel cls is 0 "unknown"  
                
                # negative overall score
                obj_scores0_neg = result.obj_scores[rel_inds_neg[:,1]]
                obj_scores1_neg = result.obj_scores[rel_inds_neg[:,2]]
                rel_rep_neg = F.softmax(rel_dists_neg, dim=1)   # rel_dists_neg has grad
                _, pred_classes_argmax_neg = rel_rep_neg.data[:,:].max(1)  # all classes
                max_rel_score_neg = rel_rep_neg.gather(1, Variable(pred_classes_argmax_neg.view(-1,1))).squeeze() # SqueezeBackward, GatherBackward
                score_list_neg = torch.cat((rel_inds_neg[:,0].float().contiguous().view(-1,1), obj_scores0_neg.data.view(-1,1), obj_scores1_neg.data.view(-1,1), max_rel_score_neg.data.view(-1,1)), 1)
                prob_score_neg = max_rel_score_neg * obj_scores0_neg.detach() * obj_scores1_neg.detach()

                # use all rel_inds; they are already independent of im_inds, which is only used to extract regions from the image and produce rel_inds
                # 384 boxes---(rel_inds)(rel_inds_neg)--->prob_score,prob_score_neg 
                all_rel_inds = torch.cat((result.rel_labels.data[select_rel_inds], rel_labels_neg), 0)  # [#pos_inds+#neg_inds, 4]
                flag = torch.cat((torch.ones(prob_score.size(0),1).cuda(),torch.zeros(prob_score_neg.size(0),1).cuda()),0)
                score_list_all = torch.cat((score_list,score_list_neg), 0) 
                all_prob = torch.cat((prob_score,prob_score_neg), 0)  # Variable, [#pos_inds+#neg_inds, 1]

                _, sort_prob_inds = torch.sort(all_prob.data, dim=0, descending=True)

                sorted_rel_inds = all_rel_inds[sort_prob_inds]
                sorted_flag = flag[sort_prob_inds].squeeze()  # can be used to check distribution of pos and neg
                sorted_score_list_all = score_list_all[sort_prob_inds]
                sorted_all_prob = all_prob[sort_prob_inds]  # Variable
                
                # positive triplet and score list
                pos_sorted_inds = sorted_rel_inds.masked_select(sorted_flag.view(-1,1).expand(-1,4).cuda() == 1).view(-1,4)
                pos_trips = torch.cat((pos_sorted_inds[:,0].contiguous().view(-1,1), result.rm_obj_labels.data.view(-1,1)[pos_sorted_inds[:,1]], result.rm_obj_labels.data.view(-1,1)[pos_sorted_inds[:,2]], pos_sorted_inds[:,3].contiguous().view(-1,1)), 1)
                pos_score_list = sorted_score_list_all.masked_select(sorted_flag.view(-1,1).expand(-1,4).cuda() == 1).view(-1,4)
                pos_exp = sorted_all_prob[sorted_flag == 1]  # Variable 

                # negative triplet and score list
                neg_sorted_inds = sorted_rel_inds.masked_select(sorted_flag.view(-1,1).expand(-1,4).cuda() == 0).view(-1,4)
                neg_trips = torch.cat((neg_sorted_inds[:,0].contiguous().view(-1,1), result.rm_obj_labels.data.view(-1,1)[neg_sorted_inds[:,1]], result.rm_obj_labels.data.view(-1,1)[neg_sorted_inds[:,2]], neg_sorted_inds[:,3].contiguous().view(-1,1)), 1)
                neg_score_list = sorted_score_list_all.masked_select(sorted_flag.view(-1,1).expand(-1,4).cuda() == 0).view(-1,4)
                neg_exp = sorted_all_prob[sorted_flag == 0]  # Variable
                
                
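                # Align positives with negatives: since negatives outnumber positives, repeat the
                # positives int_part times (plus a remainder), pairing the lowest-scoring
                # positives with the highest-scoring negatives.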
                int_part = neg_exp.size(0) // pos_exp.size(0)
                decimal_part = neg_exp.size(0) % pos_exp.size(0)
                int_inds = torch.arange(pos_exp.size(0))[:,None].expand_as(torch.Tensor(pos_exp.size(0), int_part)).contiguous().view(-1)
                int_part_inds = (int(pos_exp.size(0) -1) - int_inds).long().cuda() # pair the lowest-scoring positives with the highest-scoring negatives
                if decimal_part == 0:
                    expand_inds = int_part_inds
                else:
                    expand_inds = torch.cat((torch.arange(pos_exp.size(0))[(pos_exp.size(0) - decimal_part):].long().cuda(), int_part_inds), 0)  
                
                result.pos = pos_exp[expand_inds]
                result.neg = neg_exp
                result.anchor = Variable(torch.zeros(result.pos.size(0)).cuda())
                # some variables .register_hook(extract_grad)

                return result

            else:  # no gt_rel in rel_inds
                print("no gt_rel in rel_inds!!!!!!!!!!!!!!!!!!!!!!!!!!!!")
                twod_inds = arange(result.rm_obj_preds.data) * self.num_classes + result.rm_obj_preds.data
                result.obj_scores = F.softmax(result.rm_obj_dists.detach(), dim=1).view(-1)[twod_inds]

                # positive overall score
                obj_scores0 = result.obj_scores[rel_inds[:,1]]
                obj_scores1 = result.obj_scores[rel_inds[:,2]]
                rel_rep = F.softmax(result.rel_dists, dim=1)    # [275, 51]
                _, pred_classes_argmax = rel_rep.data[:,:].max(1)  # all classes
                max_rel_score = rel_rep.gather(1, Variable(pred_classes_argmax.view(-1,1))).squeeze() # SqueezeBackward, GatherBackward
                prob_score = max_rel_score * obj_scores0.detach() * obj_scores1.detach()
                #pos_prob[:,1][result.rel_labels.data[:,3] == 0] = 0  # treat most rel_labels as neg because their rel cls is 0 "unknown"  
                
                # negative overall score
                obj_scores0_neg = result.obj_scores[rel_inds_neg[:,1]]
                obj_scores1_neg = result.obj_scores[rel_inds_neg[:,2]]
                rel_rep_neg = F.softmax(rel_dists_neg, dim=1)   
                _, pred_classes_argmax_neg = rel_rep_neg.data[:,:].max(1)  # all classes
                max_rel_score_neg = rel_rep_neg.gather(1, Variable(pred_classes_argmax_neg.view(-1,1))).squeeze() # SqueezeBackward, GatherBackward
                prob_score_neg = max_rel_score_neg * obj_scores0_neg.detach() * obj_scores1_neg.detach()

                # use all rel_inds; they are already independent of im_inds, which is only used to extract regions from the image and produce rel_inds
                # 384 boxes---(rel_inds)(rel_inds_neg)--->prob_score,prob_score_neg 
                all_rel_inds = torch.cat((result.rel_labels.data, rel_labels_neg), 0)  # [#pos_inds+#neg_inds, 4]
                flag = torch.cat((torch.ones(prob_score.size(0),1).cuda(),torch.zeros(prob_score_neg.size(0),1).cuda()),0)
                all_prob = torch.cat((prob_score,prob_score_neg), 0)  # Variable, [#pos_inds+#neg_inds, 1]

                _, sort_prob_inds = torch.sort(all_prob.data, dim=0, descending=True)

                sorted_rel_inds = all_rel_inds[sort_prob_inds]
                sorted_flag = flag[sort_prob_inds].squeeze()  # can be used to check distribution of pos and neg
                sorted_all_prob = all_prob[sort_prob_inds]  # Variable

                pos_sorted_inds = sorted_rel_inds.masked_select(sorted_flag.view(-1,1).expand(-1,4).cuda() == 1).view(-1,4)
                neg_sorted_inds = sorted_rel_inds.masked_select(sorted_flag.view(-1,1).expand(-1,4).cuda() == 0).view(-1,4)
                pos_exp = sorted_all_prob[sorted_flag == 1]  # Variable  
                neg_exp = sorted_all_prob[sorted_flag == 0]  # Variable

                int_part = neg_exp.size(0) // pos_exp.size(0)
                decimal_part = neg_exp.size(0) % pos_exp.size(0)
                int_inds = torch.arange(pos_exp.data.size(0))[:,None].expand_as(torch.Tensor(pos_exp.data.size(0), int_part)).contiguous().view(-1)
                int_part_inds = (int(pos_exp.data.size(0) -1) - int_inds).long().cuda() # pair the lowest-scoring positives with the highest-scoring negatives
                if decimal_part == 0:
                    expand_inds = int_part_inds
                else:
                    expand_inds = torch.cat((torch.arange(pos_exp.size(0))[(pos_exp.size(0) - decimal_part):].long().cuda(), int_part_inds), 0)  
                
                result.pos = pos_exp[expand_inds]
                result.neg = neg_exp
                result.anchor = Variable(torch.zeros(result.pos.size(0)).cuda())

                return result
        ###################### Testing ###########################

        # extract corresponding scores according to the box's preds
        twod_inds = arange(result.rm_obj_preds.data) * self.num_classes + result.rm_obj_preds.data
        result.obj_scores = F.softmax(result.rm_obj_dists, dim=1).view(-1)[twod_inds]   # [384]

        # Bbox regression
        if self.mode == 'sgdet':
            bboxes = result.boxes_all.view(-1, 4)[twod_inds].view(result.boxes_all.size(0), 4)
        else:
            # Boxes will get fixed by filter_dets function.
            bboxes = result.rm_box_priors

        rel_rep = F.softmax(result.rel_dists, dim=1)    # [275, 51]
        
        # sort product of obj1 * obj2 * rel
        return filter_dets(bboxes, result.obj_scores,
                           result.rm_obj_preds, rel_inds[:, 1:],
                           rel_rep)
Example #3
    def forward(self,
                x,
                im_sizes,
                image_offset,
                gt_boxes=None,
                gt_classes=None,
                gt_rels=None,
                *args):
        """
        Forward pass for detection
        :param x: Images@[batch_size, 3, IM_SIZE, IM_SIZE]
        :param im_sizes: A numpy array of (h, w, scale) for each image.
        :param image_offset: Offset indicating which image we're on, used for multi-GPU training (0 if single GPU)
        :param gt_boxes:

        Training parameters:
        :param gt_boxes: [num_gt, 4] GT boxes over the batch.
        :param gt_classes: [num_gt, 2] gt classes where each one is (img_id, class)
        :param train_anchor_inds: a [num_train, 2] array of indices for the anchors that will
                                  be used to compute the training loss. Each (img_ind, fpn_idx)
        :return: If train:
            scores, boxdeltas, labels, boxes, boxtargets, rpnscores, rpnboxes, rellabels
            
            if test:
            prob dists, boxes, img inds, maxscores, classes
            
        """

        with torch.no_grad():  # do not update anything in the detector

            targets, x_lst, original_image_sizes = [], [], []
            device = self.rel_fc.weight.get_device(
            ) if self.rel_fc.weight.is_cuda else 'cpu'
            gt_boxes = gt_boxes.to(device)
            gt_classes = gt_classes.to(device)
            gt_rels = gt_rels.to(device)
            for i, s, e in enumerate_by_image(gt_classes[:, 0].long().data):
                targets.append({
                    'boxes': copy.deepcopy(gt_boxes[s:e]),
                    'labels': gt_classes[s:e, 1].long().to(device)
                })
                x_lst.append(x[i].to(device).squeeze())
                original_image_sizes.append(x[i].shape[-2:])

            images, targets = self.detector.transform(x_lst, targets)
            fmap_multiscale = self.detector.backbone(images.tensors)
            if self.mode != 'sgdet':
                rois, obj_labels, bbox_targets, rpn_scores, rpn_box_deltas, rel_labels = \
                    self.gt_boxes(None, im_sizes, image_offset, self.RELS_PER_IMG, gt_boxes,
                                   gt_classes, gt_rels, None, proposals=None,
                                   sample_factor=-1)
                rm_box_priors, rm_box_priors_org = [], []
                for i, s, e in enumerate_by_image(gt_classes[:,
                                                             0].long().data):
                    rm_box_priors.append(targets[i]['boxes'])
                    rm_box_priors_org.append(gt_boxes[s:e])

                result = Result(od_box_targets=bbox_targets,
                                rm_box_targets=bbox_targets,
                                od_obj_labels=obj_labels,
                                rm_box_priors=torch.cat(rm_box_priors),
                                rm_obj_labels=obj_labels,
                                rpn_scores=rpn_scores,
                                rpn_box_deltas=rpn_box_deltas,
                                rel_labels=rel_labels,
                                im_inds=rois[:, 0].long().contiguous() +
                                image_offset)
                result.rm_box_priors_org = torch.cat(rm_box_priors_org)

            else:

                if isinstance(fmap_multiscale, torch.Tensor):
                    fmap_multiscale = OrderedDict([(0, fmap_multiscale)])
                proposals, _ = self.detector.rpn(images, fmap_multiscale,
                                                 targets)
                detections, _ = self.detector.roi_heads(
                    fmap_multiscale, proposals, images.image_sizes, targets)
                boxes = copy.deepcopy(detections)
                boxes_all_dict = self.detector.transform.postprocess(
                    detections, images.image_sizes, original_image_sizes)
                rm_box_priors, rm_box_priors_org, im_inds, obj_labels = [], [], [], []
                for i in range(len(proposals)):
                    rm_box_priors.append(boxes[i]['boxes'])
                    rm_box_priors_org.append(boxes_all_dict[i]['boxes'])
                    obj_labels.append(boxes_all_dict[i]['labels'])
                    im_inds.append(
                        torch.zeros(len(detections[i]['boxes']),
                                    device=device).float() + i)
                im_inds = torch.cat(im_inds).view(-1, 1)

                result = Result(rm_obj_labels=torch.cat(obj_labels).view(-1),
                                rm_box_priors=torch.cat(rm_box_priors),
                                rel_labels=None,
                                im_inds=im_inds.view(-1).long().contiguous() +
                                image_offset)
                result.rm_box_priors_org = torch.cat(rm_box_priors_org)

                if len(result.rm_box_priors) <= 1:
                    raise ValueError(
                        'at least two objects must be detected to build relationships'
                    )

        if result.is_none():
            raise ValueError("Empty detection result")

        if self.detector_model == 'baseline':
            if self.slim > 0:
                result.fmap = self.fmap_reduce(result.fmap.detach())
            else:
                result.fmap = result.fmap.detach()

        im_inds = result.im_inds - image_offset
        boxes = result.rm_box_priors

        if not hasattr(result, 'rel_labels'):
            result.rel_labels = None

        if self.training and result.rel_labels is None:
            assert self.mode == 'sgdet'
            result.rel_labels = rel_assignments(im_inds.data,
                                                boxes.data,
                                                result.rm_obj_labels.data,
                                                gt_boxes.data,
                                                gt_classes.data,
                                                gt_rels.data,
                                                image_offset,
                                                filter_non_overlap=True,
                                                num_sample_per_gt=1)

        rel_inds = self.get_rel_inds(
            result.rel_labels if self.training else None, im_inds, boxes)
        rois = torch.cat((im_inds[:, None].float(), boxes), 1)

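        # Union box for each relation pair: same image index as the subject, top-left corner from
        # the element-wise min of the two boxes, bottom-right corner from the element-wise max.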
        union_rois = torch.cat((
            rois[:, 0][rel_inds[:, 1]][:, None],
            torch.min(rois[:, 1:3][rel_inds[:, 1]], rois[:, 1:3][rel_inds[:,
                                                                          2]]),
            torch.max(rois[:, 3:5][rel_inds[:, 1]], rois[:, 3:5][rel_inds[:,
                                                                          2]]),
        ), 1)

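        # Pool node features from the object boxes and edge features from the union boxes, using
        # the same multi-scale feature maps produced by the detector backbone.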
        node_feat = self.multiscale_roi_pool(fmap_multiscale, rm_box_priors,
                                             images.image_sizes)
        edge_feat = self.multiscale_roi_pool(fmap_multiscale,
                                             convert_roi_to_list(union_rois),
                                             images.image_sizes)

        result.rm_obj_dists, result.rel_dists = self.predict(
            node_feat, edge_feat, rel_inds, rois, images.image_sizes)

        if self.use_bias:

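            # Pick the most likely non-background class for each box (background column zeroed),
            # then add the corresponding subject-object frequency bias to the relation scores.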
            scores_nz = F.softmax(result.rm_obj_dists, dim=1).data
            scores_nz[:, 0] = 0.0
            _, score_ord = scores_nz[:, 1:].sort(dim=1, descending=True)
            result.obj_preds = score_ord[:, 0] + 1

            if self.mode == 'predcls':
                result.obj_preds = gt_classes.data[:, 1]

            freq_pred = self.freq_bias.index_with_labels(
                torch.stack((
                    result.obj_preds[rel_inds[:, 1]],
                    result.obj_preds[rel_inds[:, 2]],
                ), 1))
            # tune the weight for freq_bias
            if self.test_bias:
                result.rel_dists = freq_pred
            else:
                result.rel_dists = result.rel_dists + freq_pred

        if self.training:
            return result

        if self.mode == 'predcls':
            result.obj_scores = result.rm_obj_dists.data.new(
                gt_classes.size(0)).fill_(1)
            result.obj_preds = gt_classes.data[:, 1]
        elif self.mode in ['sgcls', 'sgdet']:
            scores_nz = F.softmax(result.rm_obj_dists, dim=1).data
            scores_nz[:, 0] = 0.0  # does not change actually anything
            result.obj_scores, score_ord = scores_nz[:,
                                                     1:].sort(dim=1,
                                                              descending=True)
            result.obj_preds = score_ord[:, 0] + 1
            result.obj_scores = result.obj_scores[:, 0]
        else:
            raise NotImplementedError(self.mode)

        result.obj_preds = Variable(result.obj_preds)
        result.obj_scores = Variable(result.obj_scores)

        # Boxes will get fixed by filter_dets function.
        if self.detector_model == 'mrcnn':
            bboxes = result.rm_box_priors_org
        else:
            bboxes = result.rm_box_priors

        rel_rep = F.softmax(result.rel_dists, dim=1)

        return filter_dets(bboxes, result.obj_scores, result.obj_preds,
                           rel_inds[:, 1:], rel_rep)
Example #4
    def forward(self,
                x,
                im_sizes,
                image_offset,
                gt_boxes=None,
                gt_classes=None,
                gt_rels=None,
                proposals=None,
                train_anchor_inds=None,
                return_fmap=False):
        """
        Forward pass for Relation detection
        Args:
            x: Images@[batch_size, 3, IM_SIZE, IM_SIZE]
            im_sizes: A numpy array of (h, w, scale) for each image.
            image_offset: Offset indicating which image we're on, used for multi-GPU training (0 if single GPU)

            parameters for training:
            gt_boxes: [num_gt, 4] GT boxes over the batch.
            gt_classes: [num_gt, 2] gt classes where each one is (img_id, class)
            gt_rels:
            proposals:
            train_anchor_inds: a [num_train, 2] array of indices for the anchors that will
                                  be used to compute the training loss. Each (img_ind, fpn_idx)
            return_fmap:

        Returns:
            If train:
                scores, boxdeltas, labels, boxes, boxtargets, rpnscores, rpnboxes, rellabels
            If test:
                prob dists, boxes, img inds, maxscores, classes
        """
        result = self.detector(x,
                               im_sizes,
                               image_offset,
                               gt_boxes,
                               gt_classes,
                               gt_rels,
                               proposals,
                               train_anchor_inds,
                               return_fmap=True)

        assert not result.is_none(), 'Empty detection result'

        # image_offset refers to Blob:
        # self.batch_size_per_gpu * index
        im_inds = result.im_inds - image_offset
        boxes = result.rm_box_priors
        obj_scores, box_classes = F.softmax(
            result.rm_obj_dists[:, 1:].contiguous(), dim=1).max(1)
        box_classes += 1

        num_img = im_inds[-1] + 1

        # embed(header='rel_model.py before rel_assignments')
        if self.training and result.rel_labels is None:
            assert self.mode == 'sgdet'

            # only in sgdet mode

            # shapes:
            # im_inds: (box_num,)
            # boxes: (box_num, 4)
            # rm_obj_labels: (box_num,)
            # gt_boxes: (box_num, 4)
            # gt_classes: (box_num, 2) maybe[im_ind, class_ind]
            # gt_rels: (rel_num, 4)
            # image_offset: integer
            result.rel_labels = rel_assignments(im_inds.data,
                                                boxes.data,
                                                result.rm_obj_labels.data,
                                                gt_boxes.data,
                                                gt_classes.data,
                                                gt_rels.data,
                                                image_offset,
                                                filter_non_overlap=True,
                                                num_sample_per_gt=1)
        rel_inds = self.get_rel_inds(result.rel_labels, im_inds, boxes)
        rois = torch.cat((im_inds[:, None].float(), boxes), 1)
        # union boxes feats (NumOfRels, obj_dim)
        union_box_feats = self.visual_rep(result.fmap.detach(), rois,
                                          rel_inds[:, 1:].contiguous())
        # single box feats (NumOfBoxes, feats)
        box_feats = self.obj_feature_map(result.fmap.detach(), rois)
        # box spatial feats (NumOfBox, 4)
        bboxes = Variable(center_size(boxes.data))
        sub_bboxes = bboxes[rel_inds[:, 1].contiguous()]
        obj_bboxes = bboxes[rel_inds[:, 2].contiguous()]

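        # Relative spatial encoding of the object w.r.t. the subject (boxes in center-size form):
        # center offset normalized by the subject size, and log of the width/height ratio.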
        obj_bboxes[:, :2] = obj_bboxes[:, :2].contiguous(
        ) - sub_bboxes[:, :2].contiguous()  # x-y
        obj_bboxes[:, 2:] = obj_bboxes[:, 2:].contiguous(
        ) / sub_bboxes[:, 2:].contiguous()  # w/h
        obj_bboxes[:, :2] /= sub_bboxes[:, 2:].contiguous()  # x-y/h
        obj_bboxes[:, 2:] = torch.log(obj_bboxes[:,
                                                 2:].contiguous())  # log(w/h)

        bbox_spatial_feats = self.spatial_fc(obj_bboxes)

        box_word = self.classes_word_embedding(box_classes)
        box_pair_word = torch.cat((box_word[rel_inds[:, 1].contiguous()],
                                   box_word[rel_inds[:, 2].contiguous()]), 1)
        box_word_feats = self.word_fc(box_pair_word)

        # (NumOfRels, DIM=)
        box_pair_feats = torch.cat(
            (union_box_feats, bbox_spatial_feats, box_word_feats), 1)

        box_pair_score = self.relpn_fc(box_pair_feats)
        #embed(header='filter_rel_labels')
        if self.training:
            pn_rel_label = list()
            pn_pair_score = list()
            #print(result.rel_labels.shape)
            #print(result.rel_labels[:, 0].contiguous().squeeze())
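            # Relation proposal sampling: per image, keep up to RELEVANT_PER_IM foreground
            # (labeled) pairs, fill up to EDGES_PER_IM with background pairs, and supervise the
            # pair scores with binary fg/bg labels.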
            for i, s, e in enumerate_by_image(
                    result.rel_labels[:, 0].data.contiguous()):
                im_i_rel_label = result.rel_labels[s:e].contiguous()
                im_i_box_pair_score = box_pair_score[s:e].contiguous()

                im_i_rel_fg_inds = torch.nonzero(
                    im_i_rel_label[:, -1].contiguous()).squeeze()
                im_i_rel_fg_inds = im_i_rel_fg_inds.data.cpu().numpy()
                im_i_fg_sample_num = min(RELEVANT_PER_IM,
                                         im_i_rel_fg_inds.shape[0])
                if im_i_rel_fg_inds.size > 0:
                    im_i_rel_fg_inds = np.random.choice(
                        im_i_rel_fg_inds,
                        size=im_i_fg_sample_num,
                        replace=False)

                im_i_rel_bg_inds = torch.nonzero(
                    im_i_rel_label[:, -1].contiguous() == 0).squeeze()
                im_i_rel_bg_inds = im_i_rel_bg_inds.data.cpu().numpy()
                im_i_bg_sample_num = min(EDGES_PER_IM - im_i_fg_sample_num,
                                         im_i_rel_bg_inds.shape[0])
                if im_i_rel_bg_inds.size > 0:
                    im_i_rel_bg_inds = np.random.choice(
                        im_i_rel_bg_inds,
                        size=im_i_bg_sample_num,
                        replace=False)

                #print('{}/{} fg/bg in image {}'.format(im_i_fg_sample_num, im_i_bg_sample_num, i))
                result.rel_sample_pos = torch.Tensor(
                    [im_i_fg_sample_num]).cuda(im_i_rel_label.get_device())
                result.rel_sample_neg = torch.Tensor(
                    [im_i_bg_sample_num]).cuda(im_i_rel_label.get_device())

                im_i_keep_inds = np.append(im_i_rel_fg_inds, im_i_rel_bg_inds)
                im_i_pair_score = im_i_box_pair_score[
                    im_i_keep_inds.tolist()].contiguous()

                im_i_rel_pn_labels = Variable(
                    torch.zeros(im_i_fg_sample_num + im_i_bg_sample_num).type(
                        torch.LongTensor).cuda(x.get_device()))
                im_i_rel_pn_labels[:im_i_fg_sample_num] = 1

                pn_rel_label.append(im_i_rel_pn_labels)
                pn_pair_score.append(im_i_pair_score)

            result.rel_pn_dists = torch.cat(pn_pair_score, 0)
            result.rel_pn_labels = torch.cat(pn_rel_label, 0)

        box_pair_relevant = F.softmax(box_pair_score, dim=1)
        box_pos_pair_ind = torch.nonzero(box_pair_relevant[:, 1].contiguous(
        ) > box_pair_relevant[:, 0].contiguous()).squeeze()

        if box_pos_pair_ind.data.shape == torch.Size([]):
            return None
        #print('{}/{} trim edges'.format(box_pos_pair_ind.size(0), rel_inds.size(0)))
        result.rel_trim_pos = torch.Tensor([box_pos_pair_ind.size(0)]).cuda(
            box_pos_pair_ind.get_device())
        result.rel_trim_total = torch.Tensor([rel_inds.size(0)
                                              ]).cuda(rel_inds.get_device())

        # filtering relations
        filter_rel_inds = rel_inds[box_pos_pair_ind.data]
        filter_box_pair_feats = box_pair_feats[box_pos_pair_ind.data]
        if self.training:
            filter_rel_labels = result.rel_labels[box_pos_pair_ind.data]
            result.rel_labels = filter_rel_labels

        # message passing between boxes and relations
        #embed(header='mp')
        for _ in range(self.mp_iter_num):
            box_feats = self.message_passing(box_feats, filter_box_pair_feats,
                                             filter_rel_inds)
        box_cls_scores = self.cls_fc(box_feats)
        result.rm_obj_dists = box_cls_scores
        obj_scores, box_classes = F.softmax(box_cls_scores[:, 1:].contiguous(),
                                            dim=1).max(1)
        box_classes += 1  # skip background

        # TODO: add memory module
        # filter_box_pair_feats is to be added to memory
        # fbiilter_box_pair_feats = self.memory_()


        # RelationCNN
        filter_box_pair_feats_fc1 = self.relcnn_fc1(filter_box_pair_feats)
        filter_box_pair_score = self.relcnn_fc2(filter_box_pair_feats_fc1)
        if not self.graph_cons:
            filter_box_pair_score = filter_box_pair_score.view(
                -1, 2, self.num_rels)
        result.rel_dists = filter_box_pair_score

        if self.training:
            return result

        pred_scores = F.softmax(result.rel_dists, dim=1)
        """
        filter_dets
        boxes: bbox regression else [num_box, 4]
        obj_scores: [num_box] probabilities for the scores
        obj_classes: [num_box] class labels integer
        rel_inds: [num_rel, 2] TENSOR consisting of (im_ind0, im_ind1)
        pred_scores: [num_rel, num_predicates] including irrelevant class(#relclass + 1)
        """
        return filter_dets(boxes, obj_scores, box_classes,
                           filter_rel_inds[:, 1:].contiguous(), pred_scores)
Example #5
    def forward(self, x, im_sizes, image_offset,
                gt_boxes=None, gt_classes=None, gt_rels=None, proposals=None, train_anchor_inds=None,
                return_fmap=False):
        """
        Forward pass for detection
        :param x: Images@[batch_size, 3, IM_SIZE, IM_SIZE]
        :param im_sizes: A numpy array of (h, w, scale) for each image.
        :param image_offset: Offset indicating which image we're on, used for multi-GPU training (0 if single GPU)
        :param gt_boxes:

        Training parameters:
        :param gt_boxes: [num_gt, 4] GT boxes over the batch.
        :param gt_classes: [num_gt, 2] gt classes where each one is (img_id, class)
        :param train_anchor_inds: a [num_train, 2] array of indices for the anchors that will
                                  be used to compute the training loss. Each (img_ind, fpn_idx)
        :return: If train:
            scores, boxdeltas, labels, boxes, boxtargets, rpnscores, rpnboxes, rellabels
            
            if test:
            prob dists, boxes, img inds, maxscores, classes
            
        """

        # Detector
        result = self.detector(x, im_sizes, image_offset, gt_boxes, gt_classes, gt_rels, proposals,
                               train_anchor_inds, return_fmap=True)
        if result.is_none():
            raise ValueError("Empty detection result")
        im_inds = result.im_inds - image_offset
        # boxes: [#boxes, 4], without box deltas; this is where the narrow error comes from, so it should be .detach()ed
        boxes = result.rm_box_priors.detach()   




        if self.training and result.rel_labels is None:
            assert self.mode == 'sgdet' # sgcls's result.rel_labels is gt and not None
            result.rel_labels = rel_assignments(im_inds.data, boxes.data, result.rm_obj_labels.data,
                                                gt_boxes.data, gt_classes.data, gt_rels.data,
                                                image_offset, filter_non_overlap=True,
                                                num_sample_per_gt=1)
            rel_labels_neg = self.get_neg_examples(result.rel_labels)
            rel_inds_neg = rel_labels_neg[:,:3]

        rel_inds = self.get_rel_inds(result.rel_labels, im_inds, boxes)  #[275,3], [im_inds, box1_inds, box2_inds]
        
        # rois: [#boxes, 5]
        rois = torch.cat((im_inds[:, None].float(), boxes), 1)
        # result.rm_obj_fmap: [384, 4096]
        #result.rm_obj_fmap = self.obj_feature_map(result.fmap.detach(), rois) # detach: prevent gradients from flowing backward
        result.rm_obj_fmap = self.obj_feature_map(result.fmap.detach(), rois.detach()) # detach: prevent gradients from flowing backward

        ############### Box Loss in BiLSTM ################
        #result.lstm_box_deltas = self.bbox_fc(result.rm_obj_fmap).view(-1, len(self.classes), 4)
        ############### Box Loss in BiLSTM ################


        # BiLSTM
        result.rm_obj_dists, result.rm_obj_preds, edge_ctx = self.context(
            result.rm_obj_fmap,   # has been detached above
            # rm_obj_dists: [#boxes, 151]; Prevent gradients from flowing back into score_fc from elsewhere
            result.rm_obj_dists.detach(),  # .detach:Returns a new Variable, detached from the current graph
            im_inds, result.rm_obj_labels if self.training or self.mode == 'predcls' else None,
            boxes.data, result.boxes_all.detach() if self.mode == 'sgdet' else result.boxes_all)
        

        # Post Processing
        # nl_edge <= 0
        if edge_ctx is None:
            edge_rep = self.post_emb(result.rm_obj_preds)
        # nl_edge > 0
        else: 
            edge_rep = self.post_lstm(edge_ctx)  # [384, 4096*2]
     
        # Split into subject and object representations
        edge_rep = edge_rep.view(edge_rep.size(0), 2, self.pooling_dim)  #[384,2,4096]
        subj_rep = edge_rep[:, 0]  # [384,4096]
        obj_rep = edge_rep[:, 1]  # [384,4096]
        prod_rep = subj_rep[rel_inds[:, 1]] * obj_rep[rel_inds[:, 2]]  # prod_rep, rel_inds: [275,4096], [275,3]
    
        obj1 = self.obj1_fc(subj_rep[rel_inds[:, 1]])
        obj1 = obj1.view(obj1.size(0), self.num_classes, self.embdim)  # (275, 151, 10)
        obj2 = self.obj2_fc(obj_rep[rel_inds[:, 2]])
        obj2 = obj2.view(obj2.size(0), self.num_classes, self.embdim)  # (275, 151, 10)

        if self.training:
            prod_rep_neg = subj_rep[rel_inds_neg[:, 1]] * obj_rep[rel_inds_neg[:, 2]]
            obj1_neg = self.obj1_fc(subj_rep[rel_inds_neg[:, 1]])
            obj1_neg = obj1_neg.view(obj1_neg.size(0), self.num_classes, self.embdim)  # (275*self.neg_num, 151, 10)
            obj2_neg = self.obj2_fc(obj_rep[rel_inds_neg[:, 2]])
            obj2_neg = obj2_neg.view(obj2_neg.size(0), self.num_classes, self.embdim)  # (275*self.neg_num, 151, 10)

        if self.use_vision: # True when sgdet
            # union rois: fmap.detach--RoIAlignFunction--roifmap--vr [275,4096]
            vr = self.visual_rep(result.fmap.detach(), rois, rel_inds[:, 1:])
            #vr = self.visual_rep(result.fmap.detach(), im_bboxes.detach(), rel_inds[:, 1:])

            if self.limit_vision:  # False when sgdet
                # exact value TBD
                prod_rep = torch.cat((prod_rep[:,:2048] * vr[:,:2048], prod_rep[:,2048:]), 1) 
            else:
                prod_rep = prod_rep * vr  # [275,4096]
                rel_emb = self.rel_seq(prod_rep)
                rel_emb = rel_emb.view(rel_emb.size(0), self.num_rels, self.embdim)  # (275, 51, 10)
                if self.training:
                    vr_neg = self.visual_rep(result.fmap.detach(), rois, rel_inds_neg[:, 1:])
                    prod_rep_neg = prod_rep_neg * vr_neg if self.training else None # [275*self.neg_num, 4096]
                    rel_emb_neg = self.rel_seq(prod_rep_neg)
                    

        if self.use_tanh:  # False when sgdet
            prod_rep = F.tanh(prod_rep)

        result.rel_dists = self.rel_compress(prod_rep)  # [275,51]

        if self.use_bias:  # True when sgdet
            result.rel_dists = result.rel_dists + self.freq_bias.index_with_labels(torch.stack((
                result.rm_obj_preds[rel_inds[:, 1]],
                result.rm_obj_preds[rel_inds[:, 2]],
            ), 1))


        if self.training:
            # pos_exp: [275, 100] * self.neg_num
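            # Translation-style triplet score: gather the embeddings for the predicted object
            # classes and the relation label, then compute subject + relation - object, which
            # should stay close to the zero anchor for positives and far from it for negatives.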
            twod_inds1 = arange(rel_inds[:, 1]) * self.num_classes + result.rm_obj_preds.data[rel_inds[:, 1]]
            twod_inds2 = arange(rel_inds[:, 2]) * self.num_classes + result.rm_obj_preds.data[rel_inds[:, 2]]
            rel_type = result.rel_labels[:, 3].data # [275]
            twod_inds_r = arange(rel_type) * self.num_rels + rel_type
            
            twod_inds1 = twod_inds1[:,None].expand_as(torch.Tensor(twod_inds1.size(0), self.neg_num)).contiguous().view(-1)
            twod_inds2 = twod_inds2[:,None].expand_as(torch.Tensor(twod_inds2.size(0), self.neg_num)).contiguous().view(-1)
            twod_inds_r = twod_inds_r[:,None].expand_as(torch.Tensor(twod_inds_r.size(0), self.neg_num)).contiguous().view(-1)
            result.pos = obj1.view(-1,self.embdim)[twod_inds1] + rel_emb.view(-1,self.embdim)[twod_inds_r] - obj2.view(-1,self.embdim)[twod_inds2]

            # neg_exp: [275 * self.neg_num, 100]
            twod_inds1_neg = arange(rel_inds_neg[:, 1]) * self.num_classes + result.rm_obj_preds.data[rel_inds_neg[:, 1]]
            twod_inds2_neg = arange(rel_inds_neg[:, 2]) * self.num_classes + result.rm_obj_preds.data[rel_inds_neg[:, 2]]
            rel_type_neg = rel_labels_neg[:, 3]  # [275 * neg_num]
            twod_inds_r_neg = arange(rel_type_neg) * self.num_rels + rel_type_neg
            result.neg = obj1_neg.view(-1,self.embdim)[twod_inds1_neg] + rel_emb_neg.view(-1,self.embdim)[twod_inds_r_neg] - obj2_neg.view(-1,self.embdim)[twod_inds2_neg]
            
            result.anchor = Variable(torch.zeros(result.pos.size(0), self.embdim).cuda())
            return result
        
        ###################### Testing ###########################

        # extract corresponding scores according to the box's preds
        twod_inds = arange(result.rm_obj_preds.data) * self.num_classes + result.rm_obj_preds.data
        result.obj_scores = F.softmax(result.rm_obj_dists, dim=1).view(-1)[twod_inds]   # [384]

        # Bbox regression
        if self.mode == 'sgdet':
            bboxes = result.boxes_all.view(-1, 4)[twod_inds].view(result.boxes_all.size(0), 4)
        else:
            # Boxes will get fixed by filter_dets function.
            bboxes = result.rm_box_priors

        rel_rep = F.softmax(result.rel_dists, dim=1)    # [275, 51]

        # sort product of obj1 * obj2 * rel
        return filter_dets(bboxes, result.obj_scores,
                           result.rm_obj_preds, rel_inds[:, 1:],
                           rel_rep, obj1, obj2, rel_emb)
Example #6
    def forward(self,
                x,
                im_sizes,
                image_offset,
                gt_boxes=None,
                gt_classes=None,
                gt_rels=None,
                proposals=None,
                train_anchor_inds=None,
                return_fmap=False,
                depth_imgs=None):
        """
        Forward pass for relation detection
        :param x: Images@[batch_size, 3, IM_SIZE, IM_SIZE]
        :param im_sizes: a numpy array of (h, w, scale) for each image.
        :param image_offset: Offset indicating which image we're on, used for multi-GPU training (0 if single GPU)
        :param gt_boxes: [num_gt, 4] GT boxes over the batch.
        :param gt_classes: [num_gt, 2] gt classes where each one is (img_id, class)
        :param gt_rels: [] gt relations
        :param proposals: region proposals retrieved from file
        :param train_anchor_inds: a [num_train, 2] array of indices for the anchors that will
                                  be used to compute the training loss. Each (img_ind, fpn_idx)
        :param return_fmap: if the object detector must return the extracted feature maps
        :param depth_imgs: depth images [batch_size, 1, IM_SIZE, IM_SIZE]
        """

        # -- Get prior `result` object (instead of calling faster-rcnn-detector)
        result = self.get_prior_results(image_offset, gt_boxes, gt_classes,
                                        gt_rels)

        # -- Get RoI and relations
        rois, rel_inds = self.get_rois_and_rels(result, image_offset, gt_boxes,
                                                gt_classes, gt_rels)

        # -- Determine subject and object indices
        subj_inds = rel_inds[:, 1]
        obj_inds = rel_inds[:, 2]

        # -- Extract features from depth backbone
        depth_features = self.depth_backbone(depth_imgs)

        # -- Prevent the gradients from flowing back to depth backbone (Pre-trained mode)
        if self.pretrained_depth:
            depth_features = depth_features.detach()

        # -- Extract RoI features for relation detection
        depth_rois_features = self.get_roi_features_depth(depth_features, rois)

        # -- Create a pairwise relation vector out of location features
        rel_depth = torch.cat(
            (depth_rois_features[subj_inds], depth_rois_features[obj_inds]), 1)
        rel_depth_fc = self.depth_rel_hlayer(rel_depth)

        # -- Predict relation distances
        result.rel_dists = self.depth_rel_out(rel_depth_fc)

        # --- *** END OF ARCHITECTURE *** ---#

        # -- Prepare object predictions vector (PredCLS)
        # Assuming it's predcls
        obj_labels = result.rm_obj_labels if self.training or self.mode == 'predcls' else None
        # One hot vector of objects
        result.rm_obj_dists = Variable(
            to_onehot(obj_labels.data, self.num_classes))
        # Indexed vector
        result.obj_preds = obj_labels if obj_labels is not None else result.rm_obj_dists[:, 1:].max(
            1)[1] + 1

        if self.training:
            return result

        twod_inds = arange(
            result.obj_preds.data) * self.num_classes + result.obj_preds.data
        result.obj_scores = F.softmax(result.rm_obj_dists,
                                      dim=1).view(-1)[twod_inds]

        # Boxes will get fixed by filter_dets function.
        bboxes = result.rm_box_priors

        rel_rep = F.softmax(result.rel_dists, dim=1)
        # Filtering: Subject_Score * Pred_score * Obj_score, sorted and ranked
        return filter_dets(bboxes, result.obj_scores, result.obj_preds,
                           rel_inds[:, 1:], rel_rep)
Example #7
    def forward(self, x, im_sizes, image_offset,
                gt_boxes=None, gt_classes=None, gt_rels=None, proposals=None, train_anchor_inds=None,
                return_fmap=False):
        """
        Forward pass for detection
        :param x: Images@[batch_size, 3, IM_SIZE, IM_SIZE]
        :param im_sizes: A numpy array of (h, w, scale) for each image.
        :param image_offset: Offset indicating which image we're on, used for multi-GPU training (0 if single GPU)
        :param gt_boxes:

        Training parameters:
        :param gt_boxes: [num_gt, 4] GT boxes over the batch.
        :param gt_classes: [num_gt, 2] gt classes where each one is (img_id, class)
        :param train_anchor_inds: a [num_train, 2] array of indices for the anchors that will
                                  be used to compute the training loss. Each (img_ind, fpn_idx)
        :return: If train:
            scores, boxdeltas, labels, boxes, boxtargets, rpnscores, rpnboxes, rellabels
            
            if test:
            prob dists, boxes, img inds, maxscores, classes
            
        """
        result = self.detector(x, im_sizes, image_offset, gt_boxes, gt_classes, gt_rels, proposals,
                               train_anchor_inds, return_fmap=True)
        if result.is_none():
            raise ValueError("heck")

        im_inds = result.im_inds - image_offset
        boxes = result.rm_box_priors

        if self.training and result.rel_labels is None:
            assert self.mode == 'sgdet'
            result.rel_labels = rel_assignments(im_inds.data, boxes.data, result.rm_obj_labels.data,
                                                gt_boxes.data, gt_classes.data, gt_rels.data,
                                                image_offset, filter_non_overlap=True,
                                                num_sample_per_gt=1)

        rel_inds = self.get_rel_inds(result.rel_labels, im_inds, boxes)

        rois = torch.cat((im_inds[:, None].float(), boxes), 1)

        result.obj_fmap = self.obj_feature_map(result.fmap.detach(), rois)

        # Prevent gradients from flowing back into score_fc from elsewhere
        result.rm_obj_dists, result.obj_preds, edge_ctx = self.context(
            result.obj_fmap,
            result.rm_obj_dists.detach(),
            im_inds, result.rm_obj_labels if self.training or self.mode == 'predcls' else None,
            boxes.data, result.boxes_all)

        if edge_ctx is None:
            edge_rep = self.post_emb(result.obj_preds)
        else:
            edge_rep = self.post_lstm(edge_ctx)

        # Split into subject and object representations
        edge_rep = edge_rep.view(edge_rep.size(0), 2, self.pooling_dim)

        subj_rep = edge_rep[:, 0]
        obj_rep = edge_rep[:, 1]

        prod_rep = subj_rep[rel_inds[:, 1]] * obj_rep[rel_inds[:, 2]]

        if self.use_vision:
            vr = self.visual_rep(result.fmap.detach(), rois, rel_inds[:, 1:])
            if self.limit_vision:
                # exact value TBD
                prod_rep = torch.cat((prod_rep[:,:2048] * vr[:,:2048], prod_rep[:,2048:]), 1)
            else:
                prod_rep = prod_rep * vr

        if self.use_tanh:
            prod_rep = F.tanh(prod_rep)

        result.rel_dists = self.rel_compress(prod_rep)

        if self.use_bias:
            result.rel_dists = result.rel_dists + self.freq_bias.index_with_labels(torch.stack((
                result.obj_preds[rel_inds[:, 1]],
                result.obj_preds[rel_inds[:, 2]],
            ), 1))

        if self.training:
            return result

        twod_inds = arange(result.obj_preds.data) * self.num_classes + result.obj_preds.data
        result.obj_scores = F.softmax(result.rm_obj_dists, dim=1).view(-1)[twod_inds]

        # Bbox regression
        if self.mode == 'sgdet':
            bboxes = result.boxes_all.view(-1, 4)[twod_inds].view(result.boxes_all.size(0), 4)
        else:
            # Boxes will get fixed by filter_dets function.
            bboxes = result.rm_box_priors

        rel_rep = F.softmax(result.rel_dists, dim=1)
        return filter_dets(bboxes, result.obj_scores,
                           result.obj_preds, rel_inds[:, 1:], rel_rep)
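The `freq_bias.index_with_labels(...)` term above adds a learned prior over predicates for each (subject class, object class) pair. A rough sketch of one way such a bias table can be implemented; `PairwiseBias` is an illustrative name, not the repository's `FrequencyBias` class, which in this family of models is usually built from dataset co-occurrence statistics rather than learned from scratch:

import torch
import torch.nn as nn

class PairwiseBias(nn.Module):  # hypothetical name, not the repository's FrequencyBias
    def __init__(self, num_classes, num_rels):
        super().__init__()
        self.num_classes = num_classes
        # one bias vector over predicates for every (subject class, object class) pair
        self.table = nn.Embedding(num_classes * num_classes, num_rels)

    def index_with_labels(self, pair_labels):
        # pair_labels: [num_pairs, 2] of (subject class, object class)
        flat = pair_labels[:, 0] * self.num_classes + pair_labels[:, 1]
        return self.table(flat)  # [num_pairs, num_rels]

bias = PairwiseBias(num_classes=151, num_rels=51)
pairs = torch.tensor([[57, 20], [1, 149]])
print(bias.index_with_labels(pairs).shape)  # torch.Size([2, 51])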
Example #8
0
    def forward(self,
                x,
                im_sizes,
                image_offset,
                gt_boxes=None,
                gt_classes=None,
                gt_rels=None,
                proposals=None,
                train_anchor_inds=None,
                return_fmap=False):
        """
        Forward pass for detection
        :param x: Images@[batch_size, 3, IM_SIZE, IM_SIZE]
        :param im_sizes: A numpy array of (h, w, scale) for each image.
        :param image_offset: Offset onto what image we're on for MGPU training (if single GPU this is 0)
        :param gt_boxes:

        Training parameters:
        :param gt_boxes: [num_gt, 4] GT boxes over the batch.
        :param gt_classes: [num_gt, 2] gt classes where each one is (img_id, class)
        :param train_anchor_inds: a [num_train, 2] array of indices for the anchors that will
                                  be used to compute the training loss. Each (img_ind, fpn_idx)
        :return: If train:
            scores, boxdeltas, labels, boxes, boxtargets, rpnscores, rpnboxes, rellabels
            
            if test:
            prob dists, boxes, img inds, maxscores, classes
            
        """

        # Detector
        result = self.detector(x,
                               im_sizes,
                               image_offset,
                               gt_boxes,
                               gt_classes,
                               gt_rels,
                               proposals,
                               train_anchor_inds,
                               return_fmap=True)
        if result.is_none():
            raise ValueError("heck")

        #rcnn_pred = result.rm_obj_dists[:, 1:].max(1)[1] + 1  # +1: because the index is in 150-d but truth is 151-d
        #rcnn_ap = torch.mean((rcnn_pred == result.rm_obj_labels).float().cpu())

        im_inds = result.im_inds - image_offset
        # boxes: [#boxes, 4], without box deltas; this is where the "narrow" error
        # comes from, so the priors are detached here
        boxes = result.rm_box_priors.detach()

        if self.training and result.rel_labels is None:
            assert self.mode == 'sgdet'  # sgcls's result.rel_labels is gt and not None
            result.rel_labels = rel_assignments(im_inds.data,
                                                boxes.data,
                                                result.rm_obj_labels.data,
                                                gt_boxes.data,
                                                gt_classes.data,
                                                gt_rels.data,
                                                image_offset,
                                                filter_non_overlap=True,
                                                num_sample_per_gt=1)

        rel_inds = self.get_rel_inds(
            result.rel_labels, im_inds,
            boxes)  #[275,3], [im_inds, box1_inds, box2_inds]

        # rois: [#boxes, 5]
        rois = torch.cat((im_inds[:, None].float(), boxes), 1)
        # result.rm_obj_fmap: [384, 4096]
        #result.rm_obj_fmap = self.obj_feature_map(result.fmap.detach(), rois) # detach: prevent backward flow
        result.rm_obj_fmap = self.obj_feature_map(
            result.fmap.detach(),
            rois.detach())  # detach: prevent backward flow into the detector

        # BiLSTM
        result.rm_obj_dists, result.rm_obj_preds, edge_ctx = self.context(
            result.rm_obj_fmap,  # has been detached above
            # rm_obj_dists: [#boxes, 151]; Prevent gradients from flowing back into score_fc from elsewhere
            result.rm_obj_dists.detach(),  # .detach(): returns a new Variable, detached from the current graph
            im_inds,
            result.rm_obj_labels
            if self.training or self.mode == 'predcls' else None,
            boxes.data,
            result.boxes_all.detach()
            if self.mode == 'sgdet' else result.boxes_all)

        #lstm_ap = torch.mean((result.rm_obj_preds == result.rm_obj_labels).float().cpu())
        #fg_ratio = result.rm_obj_labels.nonzero().size(0) / result.rm_obj_labels.size(0)
        #lst = [rcnn_ap.data.numpy(), lstm_ap.data.numpy(), fg_ratio]
        #a = torch.stack((rel_inds[:64, 1], rel_inds[:64, 2]), 0)
        #b = np.unique(a.cpu().numpy())
        #print(len(b) / 64)

        # Post Processing
        # nl_edge <= 0
        if edge_ctx is None:
            edge_rep = self.post_emb(result.rm_obj_preds)
        # nl_edge > 0
        else:
            edge_rep = self.post_lstm(edge_ctx)  # [384, 4096*2]

        # Split into subject and object representations
        edge_rep = edge_rep.view(edge_rep.size(0), 2,
                                 self.pooling_dim)  #[384,2,4096]

        subj_rep = edge_rep[:, 0]  # [384,4096]
        obj_rep = edge_rep[:, 1]  # [384,4096]

        # prod_rep, rel_inds: [275,4096], [275,3]
        prod_rep = subj_rep[rel_inds[:, 1]] * obj_rep[rel_inds[:, 2]]

        if self.use_vision:  # True when sgdet
            # union rois: fmap.detach--RoIAlignFunction--roifmap--vr [275,4096]
            vr = self.visual_rep(result.fmap.detach(), rois, rel_inds[:, 1:])

            if self.limit_vision:  # False when sgdet
                # exact value TBD
                prod_rep = torch.cat(
                    (prod_rep[:, :2048] * vr[:, :2048], prod_rep[:, 2048:]), 1)
            else:
                prod_rep = prod_rep * vr  # [275,4096]

        if self.use_tanh:  # False when sgdet
            prod_rep = F.tanh(prod_rep)

        result.rel_dists = self.rel_compress(prod_rep)  # [275,51]

        if self.use_bias:  # True when sgdet
            result.rel_dists = result.rel_dists + self.freq_bias.index_with_labels(
                torch.stack((
                    result.rm_obj_preds[rel_inds[:, 1]],
                    result.rm_obj_preds[rel_inds[:, 2]],
                ), 1))

        if self.training:
            return result

        ###################### Testing ###########################

        # extract corresponding scores according to each box's predicted class
        twod_inds = arange(result.rm_obj_preds.data
                           ) * self.num_classes + result.rm_obj_preds.data
        result.obj_scores = F.softmax(result.rm_obj_dists,
                                      dim=1).view(-1)[twod_inds]  # [384]

        # Bbox regression
        if self.mode == 'sgdet':
            bboxes = result.boxes_all.view(-1, 4)[twod_inds].view(
                result.boxes_all.size(0), 4)
        else:
            # Boxes will get fixed by filter_dets function.
            bboxes = result.rm_box_priors

        rel_rep = F.softmax(result.rel_dists, dim=1)  # [275, 51]

        # sort product of obj1 * obj2 * rel
        return filter_dets(bboxes, result.obj_scores, result.rm_obj_preds,
                           rel_inds[:, 1:], rel_rep)
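For reference, the subject/object split used in this example (and the previous one) boils down to a few tensor operations: the context output is reshaped into per-box subject and object halves, and `rel_inds` selects the pair of boxes for each candidate relation. Toy sizes below, not the repository's:

import torch

pooling_dim = 8
num_boxes = 4
edge_rep = torch.randn(num_boxes, 2 * pooling_dim)   # output of post_lstm / post_emb

# reshape into per-box subject and object halves
edge_rep = edge_rep.view(num_boxes, 2, pooling_dim)
subj_rep, obj_rep = edge_rep[:, 0], edge_rep[:, 1]

# rel_inds rows are (img_ind, subject box ind, object box ind)
rel_inds = torch.tensor([[0, 0, 1],
                         [0, 1, 2],
                         [0, 3, 0]])
prod_rep = subj_rep[rel_inds[:, 1]] * obj_rep[rel_inds[:, 2]]
print(prod_rep.shape)  # torch.Size([3, 8])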
Example #9
0
    def forward(self, x, im_sizes, image_offset,
                gt_boxes=None, gt_classes=None, gt_rels=None, proposals=None, train_anchor_inds=None,
                return_fmap=False):
        """
        Forward pass for detection
        :param x: Images@[batch_size, 3, IM_SIZE, IM_SIZE]
        :param im_sizes: A numpy array of (h, w, scale) for each image.
        :param image_offset: Offset onto what image we're on for MGPU training (if single GPU this is 0)
        :param gt_boxes:

        Training parameters:
        :param gt_boxes: [num_gt, 4] GT boxes over the batch.
        :param gt_classes: [num_gt, 2] gt classes where each one is (img_id, class)
        :param train_anchor_inds: a [num_train, 2] array of indices for the anchors that will
                                  be used to compute the training loss. Each (img_ind, fpn_idx)
        :return: If train:
            scores, boxdeltas, labels, boxes, boxtargets, rpnscores, rpnboxes, rellabels
            
            if test:
            prob dists, boxes, img inds, maxscores, classes
            
        """

        # Detector
        result = self.detector(x, im_sizes, image_offset, gt_boxes, gt_classes, gt_rels, proposals,
                               train_anchor_inds, return_fmap=True)
        if result.is_none():
            raise ValueError("heck")
        im_inds = result.im_inds - image_offset
        # boxes: [#boxes, 4], without box deltas; this is where the "narrow" error
        # comes from, so it should be .detach()ed
        boxes = result.rm_box_priors    # .detach()

        if self.training and result.rel_labels is None:
            assert self.mode == 'sgdet' # sgcls's result.rel_labels is gt and not None
            # rel_labels: [num_rels, 4] (img ind, box0 ind, box1 ind, rel type)
            result.rel_labels = rel_assignments(im_inds.data, boxes.data, result.rm_obj_labels.data,
                                                gt_boxes.data, gt_classes.data, gt_rels.data,
                                                image_offset, filter_non_overlap=True,
                                                num_sample_per_gt=1)

        #torch.cat((result.rel_labels[:,0].contiguous().view(rel_inds.size(0),1),result.rm_obj_labels[result.rel_labels[:,1]].view(rel_inds.size(0),1),result.rm_obj_labels[result.rel_labels[:,2]].view(rel_inds.size(0),1),result.rel_labels[:,3].contiguous().view(rel_inds.size(0),1)),-1)
        #bbox_overlaps(boxes.data[55:57].contiguous().view(-1,1), boxes.data[8].contiguous().view(-1,1))
        rel_inds = self.get_rel_inds(result.rel_labels, im_inds, boxes)  #[275,3], [im_inds, box1_inds, box2_inds]
        
        # rois: [#boxes, 5]
        rois = torch.cat((im_inds[:, None].float(), boxes), 1)
        # result.rm_obj_fmap: [384, 4096]
        #result.rm_obj_fmap = self.obj_feature_map(result.fmap.detach(), rois) # detach: prevent backward flow
        result.rm_obj_fmap = self.obj_feature_map(result.fmap.detach(), rois) # detach: prevent backward flow into the detector

        # BiLSTM
        result.rm_obj_dists, result.rm_obj_preds, edge_ctx = self.context(
            result.rm_obj_fmap,   # has been detached above
            # rm_obj_dists: [#boxes, 151]; Prevent gradients from flowing back into score_fc from elsewhere
            result.rm_obj_dists.detach(),  # .detach(): returns a new Variable, detached from the current graph
            im_inds, result.rm_obj_labels if self.training or self.mode == 'predcls' else None,
            boxes.data, result.boxes_all if self.mode == 'sgdet' else result.boxes_all)
        

        # Post Processing
        # nl_edge <= 0
        if edge_ctx is None:
            edge_rep = self.post_emb(result.rm_obj_preds)
        # nl_edge > 0
        else: 
            edge_rep = self.post_lstm(edge_ctx)  # [384, 4096*2]
     
        # Split into subject and object representations
        edge_rep = edge_rep.view(edge_rep.size(0), 2, self.pooling_dim)  #[384,2,4096]
        subj_rep = edge_rep[:, 0]  # [384,4096]
        obj_rep = edge_rep[:, 1]  # [384,4096]
        prod_rep = subj_rep[rel_inds[:, 1]] * obj_rep[rel_inds[:, 2]]  # prod_rep, rel_inds: [275,4096], [275,3]
    

        if self.use_vision: # True when sgdet
            # union rois: fmap.detach--RoIAlignFunction--roifmap--vr [275,4096]
            vr = self.visual_rep(result.fmap.detach(), rois, rel_inds[:, 1:])

            if self.limit_vision:  # False when sgdet
                # exact value TBD
                prod_rep = torch.cat((prod_rep[:,:2048] * vr[:,:2048], prod_rep[:,2048:]), 1) 
            else:
                prod_rep = prod_rep * vr  # [275,4096]


        if self.use_tanh:  # False when sgdet
            prod_rep = F.tanh(prod_rep)

        result.rel_dists = self.rel_compress(prod_rep)  # [275,51]

        if self.use_bias:  # True when sgdet
            result.rel_dists = result.rel_dists + self.freq_bias.index_with_labels(torch.stack((
                result.rm_obj_preds[rel_inds[:, 1]],
                result.rm_obj_preds[rel_inds[:, 2]],
            ), 1))

        # Note: positives should use rm_obj_labels/rel_labels for obj/rel scores; negatives should use rm_obj_preds/max_rel_score for obj/rel scores
        if self.training: 
            judge = result.rel_labels.data[:,3] != 0
            if judge.sum() != 0:  # a gt_rel exists in rel_inds
                # positive overall score
                select_rel_inds = torch.arange(rel_inds.size(0)).view(-1,1).long().cuda()[result.rel_labels.data[:,3] != 0]
                com_rel_inds = rel_inds[select_rel_inds]
                twod_inds = arange(result.rm_obj_labels.data) * self.num_classes + result.rm_obj_labels.data  # dist: [-10,10]
                result.obj_scores = F.softmax(result.rm_obj_dists, dim=1).view(-1)[twod_inds]   # only ~1/4 of the 384 obj_dists will be updated, because only that fraction of objects has a non-zero label
              
                obj_scores0 = result.obj_scores[com_rel_inds[:,1]]
                obj_scores1 = result.obj_scores[com_rel_inds[:,2]]
                rel_rep = F.softmax(result.rel_dists[select_rel_inds], dim=1)    # result.rel_dists has grad
                rel_score = rel_rep.gather(1, result.rel_labels[select_rel_inds][:,3].contiguous().view(-1,1)).view(-1)  # not use squeeze(); SqueezeBackward, GatherBackward
                prob_score = rel_score * obj_scores0 * obj_scores1

                # negative overall score
                rel_cands = im_inds.data[:, None] == im_inds.data[None]
                rel_cands.view(-1)[diagonal_inds(rel_cands)] = 0   # self relation = 0
                if self.require_overlap:     
                    rel_cands = rel_cands & (bbox_overlaps(boxes.data, boxes.data) > 0)   # Require overlap for detection
                rel_cands = rel_cands.nonzero()  # [#, 2]
                if rel_cands.dim() == 0:
                    print("rel_cands.dim() == 0!!!!!!!!!!!!!!!!!!!!!!!!!!!!")
                    rel_cands = im_inds.data.new(1, 2).fill_(0) # shaped: [1,2], [0, 0]
                rel_cands = torch.cat((im_inds.data[rel_cands[:, 0]][:, None], rel_cands), 1) # rel_cands' value should be [0, 384]
                rel_inds_neg = rel_cands

                vr_neg = self.visual_rep(result.fmap.detach(), rois, rel_inds_neg[:, 1:])
                subj_obj = subj_rep[rel_inds_neg[:, 1]] * obj_rep[rel_inds_neg[:, 2]]
                prod_rep_neg =  subj_obj * vr_neg
                rel_dists_neg = self.rel_compress(prod_rep_neg)
                all_rel_rep_neg = F.softmax(rel_dists_neg, dim=1)
                _, pred_classes_argmax_neg = all_rel_rep_neg.data[:,1:].max(1)
                pred_classes_argmax_neg = pred_classes_argmax_neg + 1
                all_rel_pred_neg = torch.cat((rel_inds_neg, pred_classes_argmax_neg.view(-1,1)), 1)
                ind_old = torch.ones(all_rel_pred_neg.size(0)).byte().cuda()
                for i in range(com_rel_inds.size(0)):    # drop box pairs with the same rel type as a positive triplet
                    ind_i = ((all_rel_pred_neg[:, 0] == com_rel_inds[i, 0]) &
                             (all_rel_pred_neg[:, 1] == com_rel_inds[i, 1]) &
                             (result.rm_obj_preds.data[all_rel_pred_neg[:, 1]] == result.rm_obj_labels.data[com_rel_inds[i, 1]]) &
                             (all_rel_pred_neg[:, 2] == com_rel_inds[i, 2]) &
                             (result.rm_obj_preds.data[all_rel_pred_neg[:, 2]] == result.rm_obj_labels.data[com_rel_inds[i, 2]]) &
                             (all_rel_pred_neg[:, 3] == result.rel_labels.data[select_rel_inds][i, 3]))
                    ind_i = (1 - ind_i).byte()
                    ind_old = ind_i & ind_old

                rel_inds_neg = rel_inds_neg.masked_select(ind_old.view(-1,1).expand(-1,3) == 1).view(-1,3)
                rel_rep_neg = all_rel_rep_neg.masked_select(Variable(ind_old.view(-1,1).expand(-1,51)) == 1).view(-1,51)
                pred_classes_argmax_neg = pred_classes_argmax_neg.view(-1,1)[ind_old.view(-1,1) == 1]
                rel_labels_pred_neg = all_rel_pred_neg.masked_select(ind_old.view(-1,1).expand(-1,4) == 1).view(-1,4)

                max_rel_score_neg = rel_rep_neg.gather(1, Variable(pred_classes_argmax_neg.view(-1,1))).view(-1)  # not use squeeze()
                twod_inds_neg = arange(result.rm_obj_preds.data) * self.num_classes + result.rm_obj_preds.data
                obj_scores_neg = F.softmax(result.rm_obj_dists, dim=1).view(-1)[twod_inds_neg] 
                obj_scores0_neg = Variable(obj_scores_neg.data[rel_inds_neg[:,1]])
                obj_scores1_neg = Variable(obj_scores_neg.data[rel_inds_neg[:,2]])
                all_score_neg = max_rel_score_neg * obj_scores0_neg * obj_scores1_neg
                # keep only negatives whose score exceeds the lowest positive score (if any do)
                keep_neg = all_score_neg.data > prob_score.data.min()
                prob_score_neg = all_score_neg[keep_neg] if keep_neg.sum() != 0 else all_score_neg


                # use all rel_inds; at this point they no longer depend on im_inds, which is only used to extract regions from the image and to produce rel_inds
                # 384 boxes---(rel_inds)(rel_inds_neg)--->prob_score,prob_score_neg 
                flag = torch.cat((torch.ones(prob_score.size(0),1).cuda(),torch.zeros(prob_score_neg.size(0),1).cuda()),0)
                all_prob = torch.cat((prob_score,prob_score_neg), 0)  # Variable, [#pos_inds+#neg_inds, 1]

                _, sort_prob_inds = torch.sort(all_prob.data, dim=0, descending=True)

                sorted_flag = flag[sort_prob_inds].view(-1)  # can be used to check distribution of pos and neg
                sorted_all_prob = all_prob[sort_prob_inds]  # Variable
                
                # positive triplet score
                pos_exp = sorted_all_prob[sorted_flag == 1]  # Variable 
                # negative triplet score
                neg_exp = sorted_all_prob[sorted_flag == 0]  # Variable

                # determine how many rows will be updated in rel_dists_neg
                pos_repeat = torch.zeros(1, 1)
                neg_repeat = torch.zeros(1, 1)
                for i in range(pos_exp.size(0)):
                    if ( neg_exp.data > pos_exp.data[i] ).sum() != 0:
                        int_part = (neg_exp.data > pos_exp.data[i]).sum()
                        temp_pos_inds = torch.ones(int_part) * i
                        pos_repeat =  torch.cat((pos_repeat, temp_pos_inds.view(-1,1)), 0)
                        temp_neg_inds = torch.arange(int_part)
                        neg_repeat = torch.cat((neg_repeat, temp_neg_inds.view(-1,1)), 0)
                    else:
                        temp_pos_inds = torch.ones(1)* i
                        pos_repeat =  torch.cat((pos_repeat, temp_pos_inds.view(-1,1)), 0)
                        temp_neg_inds = torch.arange(1)
                        neg_repeat = torch.cat((neg_repeat, temp_neg_inds.view(-1,1)), 0)

                """
                int_part = neg_exp.size(0) // pos_exp.size(0)
                decimal_part = neg_exp.size(0) % pos_exp.size(0)
                int_inds = torch.arange(pos_exp.size(0))[:,None].expand_as(torch.Tensor(pos_exp.size(0), int_part)).contiguous().view(-1)
                int_part_inds = (int(pos_exp.size(0) -1) - int_inds).long().cuda() # use minimum pos to correspond maximum negative
                if decimal_part == 0:
                    expand_inds = int_part_inds
                else:
                    expand_inds = torch.cat((torch.arange(pos_exp.size(0))[(pos_exp.size(0) - decimal_part):].long().cuda(), int_part_inds), 0)  
                
                result.pos = pos_exp[expand_inds]
                result.neg = neg_exp
                result.anchor = Variable(torch.zeros(result.pos.size(0)).cuda())
                """
                result.pos = pos_exp[pos_repeat.cuda().long().view(-1)]
                result.neg = neg_exp[neg_repeat.cuda().long().view(-1)]
                result.anchor = Variable(torch.zeros(result.pos.size(0)).cuda())
                

                result.ratio = torch.ones(3).cuda()
                result.ratio[0] = result.ratio[0] * (sorted_flag.nonzero().min() / (prob_score.size(0) + all_score_neg.size(0)))
                result.ratio[1] = result.ratio[1] * (sorted_flag.nonzero().max() / (prob_score.size(0) + all_score_neg.size(0)))
                result.ratio[2] = result.ratio[2] * (prob_score.size(0) + all_score_neg.size(0))

                return result

            else:  # no gt_rel in rel_inds
                print("no gt_rel in rel_inds!!!!!!!!!!!!!!!!!!!!!!!!!!!!")
                ipdb.set_trace()
                # testing triplet proposal
                rel_cands = im_inds.data[:, None] == im_inds.data[None]
                # self relation = 0
                rel_cands.view(-1)[diagonal_inds(rel_cands)] = 0
                # Require overlap for detection
                if self.require_overlap:
                    rel_cands = rel_cands & (bbox_overlaps(boxes.data, boxes.data) > 0)
                rel_cands = rel_cands.nonzero()
                if rel_cands.dim() == 0:
                    print("rel_cands.dim() == 0!!!!!!!!!!!!!!!!!!!!!!!!!!!!")
                    rel_cands = im_inds.data.new(1, 2).fill_(0)
                rel_cands = torch.cat((im_inds.data[rel_cands[:, 0]][:, None], rel_cands), 1)
                rel_labels_neg = rel_cands
                rel_inds_neg = rel_cands

                twod_inds_neg = arange(result.rm_obj_preds.data) * self.num_classes + result.rm_obj_preds.data
                obj_scores_neg = F.softmax(result.rm_obj_dists, dim=1).view(-1)[twod_inds_neg]
                vr_neg = self.visual_rep(result.fmap.detach(), rois, rel_inds_neg[:, 1:])
                subj_obj = subj_rep[rel_inds_neg[:, 1]] * obj_rep[rel_inds_neg[:, 2]]
                prod_rep_neg = subj_obj * vr_neg
                rel_dists_neg = self.rel_compress(prod_rep_neg)
                # negative overall score
                obj_scores0_neg = Variable(obj_scores_neg.data[rel_inds_neg[:,1]])
                obj_scores1_neg = Variable(obj_scores_neg.data[rel_inds_neg[:,2]])
                rel_rep_neg = F.softmax(rel_dists_neg, dim=1)
                _, pred_classes_argmax_neg = rel_rep_neg.data[:,1:].max(1)
                pred_classes_argmax_neg = pred_classes_argmax_neg + 1

                max_rel_score_neg = rel_rep_neg.gather(1, Variable(pred_classes_argmax_neg.view(-1,1))).view(-1)  # not use squeeze()
                prob_score_neg = max_rel_score_neg * obj_scores0_neg * obj_scores1_neg

                result.pos = Variable(torch.zeros(prob_score_neg.size(0)).cuda())
                result.neg = prob_score_neg
                result.anchor = Variable(torch.zeros(prob_score_neg.size(0)).cuda())

                result.ratio = torch.ones(3,1).cuda()

                return result
        ###################### Testing ###########################

        # extract corresponding scores according to each box's predicted class
        twod_inds = arange(result.rm_obj_preds.data) * self.num_classes + result.rm_obj_preds.data
        result.obj_scores = F.softmax(result.rm_obj_dists, dim=1).view(-1)[twod_inds]   # [384]

        # Bbox regression
        if self.mode == 'sgdet':
            bboxes = result.boxes_all.view(-1, 4)[twod_inds].view(result.boxes_all.size(0), 4)
        else:
            # Boxes will get fixed by filter_dets function.
            bboxes = result.rm_box_priors

        rel_rep = F.softmax(result.rel_dists, dim=1)    # [275, 51]
        
        # sort product of obj1 * obj2 * rel
        return filter_dets(bboxes, result.obj_scores,
                           result.rm_obj_preds, rel_inds[:, 1:],
                           rel_rep)
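The training branch above ends by pairing each positive triplet score (`result.pos`) with the negatives that outrank it (`result.neg`); the loss that consumes these pairs lives outside this snippet. Purely as an illustration, one common way to train on such pairs is a margin ranking loss; the sketch below uses toy scores and is an assumption, not the repository's criterion:

import torch
import torch.nn.functional as F

# repeated positive scores paired with the negatives that outrank them
pos = torch.tensor([0.7, 0.7, 0.4], requires_grad=True)
neg = torch.tensor([0.9, 0.8, 0.5])

# target = +1 means "the first argument should be ranked higher than the second"
target = torch.ones_like(pos)
loss = F.margin_ranking_loss(pos, neg, target, margin=0.2)
loss.backward()
print(loss.item())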
Example #11
0
    def forward(self,
                x,
                im_sizes,
                image_offset,
                gt_boxes=None,
                gt_classes=None,
                gt_rels=None,
                proposals=None,
                train_anchor_inds=None,
                return_fmap=False):
        """Forward pass for detection
        Args:
            x: Images@[batch_size, 3, IM_SIZE, IM_SIZE]
            im_sizes: A numpy array of (h, w, scale) for each image.
            image_offset: Offset onto what image we're on for MGPU training (if single GPU this is 0)

            Training parameters:
            gt_boxes: [num_gt, 4] GT boxes over the batch.
            gt_classes: [num_gt, 2] gt classes where each one is (img_id, class)
            gt_rels:
            proposals:
            train_anchor_inds: a [num_train, 2] array of indices for the anchors that will
                                  be used to compute the training loss. Each (img_ind, fpn_idx)
            return_fmap:

        Returns:
            If train:
                scores, boxdeltas, labels, boxes, boxtargets, rpnscores, rpnboxes, rellabels
            If test:
                prob dists, boxes, img inds, maxscores, classes
            
        """
        result = self.detector(x,
                               im_sizes,
                               image_offset,
                               gt_boxes,
                               gt_classes,
                               gt_rels,
                               proposals,
                               train_anchor_inds,
                               return_fmap=True)
        """
        Results attributes:
            od_obj_dists: digits after score_fc in RCNN
            rm_obj_dists: od_obj_dists after nms
            obj_scores: nmn 
            obj_preds=None, 
            obj_fmap=None,
            od_box_deltas=None, 
            rm_box_deltas=None,
            od_box_targets=None, 
            rm_box_targets=None, 
            od_box_priors: proposal before nms
            rm_box_priors: proposal after nms
            boxes_assigned=None, 
            boxes_all=None, 
            od_obj_labels=None, 
            rm_obj_labels=None,
            rpn_scores=None, 
            rpn_box_deltas=None, 
            rel_labels=None,
            im_inds: image index of every proposal
            fmap=None, 
            rel_dists=None, 
            rel_inds=None, 
            rel_rep=None
            
            one example (sgcls task):
            result.fmap: torch.Size([6, 512, 37, 37])
            result.im_inds: torch.Size([44])
            result.obj_fmap: torch.Size([44, 4096])
            result.od_box_priors: torch.Size([44, 4])
            result.od_obj_dists: torch.Size([44, 151])
            result.od_obj_labels: torch.Size([44])
            result.rel_labels: torch.Size([316, 4])
            result.rm_box_priors: torch.Size([44, 4])
            result.rm_obj_dists: torch.Size([44, 151])
            result.rm_obj_labels: torch.Size([44])
        """
        if result.is_none():
            raise ValueError("heck")

        # image_offset refers to the Blob: self.batch_size_per_gpu * index
        im_inds = result.im_inds - image_offset
        boxes = result.rm_box_priors

        #embed(header='rel_model.py before rel_assignments')
        if self.training and result.rel_labels is None:
            assert self.mode == 'sgdet'

            # only in sgdet mode

            # shapes:
            # im_inds: (box_num,)
            # boxes: (box_num, 4)
            # rm_obj_labels: (box_num,)
            # gt_boxes: (box_num, 4)
            # gt_classes: (box_num, 2), probably [im_ind, class_ind]
            # gt_rels: (rel_num, 4)
            # image_offset: integer
            result.rel_labels = rel_assignments(im_inds.data,
                                                boxes.data,
                                                result.rm_obj_labels.data,
                                                gt_boxes.data,
                                                gt_classes.data,
                                                gt_rels.data,
                                                image_offset,
                                                filter_non_overlap=True,
                                                num_sample_per_gt=1)
        #embed(header='rel_model.py after rel_assignments')

        # rel_labels[:, :3] if sgcls
        rel_inds = self.get_rel_inds(result.rel_labels, im_inds, boxes)

        rois = torch.cat((im_inds[:, None].float(), boxes), 1)

        # obj_fmap: (NumOfRoI, 4096)
        # RoIAlign
        result.obj_fmap = self.obj_feature_map(result.fmap.detach(), rois)

        # Prevent gradients from flowing back into score_fc from elsewhere
        result.rm_obj_dists, result.obj_preds, edge_ctx = self.context(
            result.obj_fmap, result.rm_obj_dists.detach(), im_inds,
            result.rm_obj_labels if self.training or self.mode == 'predcls'
            else None, boxes.data, result.boxes_all)

        if edge_ctx is None:
            edge_rep = self.post_emb(result.obj_preds)
        else:
            edge_rep = self.post_lstm(edge_ctx)

        # Split into subject and object representations
        edge_rep = edge_rep.view(edge_rep.size(0), 2, self.pooling_dim)

        subj_rep = edge_rep[:, 0]
        obj_rep = edge_rep[:, 1]

        prod_rep = subj_rep[rel_inds[:, 1]] * obj_rep[rel_inds[:, 2]]
        # embed(header='rel_model.py prod_rep')

        if self.use_vision:
            vr = self.visual_rep(result.fmap.detach(), rois, rel_inds[:, 1:])
            if self.limit_vision:
                # exact value TBD
                prod_rep = torch.cat(
                    (prod_rep[:, :2048] * vr[:, :2048], prod_rep[:, 2048:]), 1)
            else:
                prod_rep = prod_rep * vr

        if self.use_tanh:
            prod_rep = F.tanh(prod_rep)

        result.rel_dists = self.rel_compress(prod_rep)

        if self.use_bias:
            result.rel_dists = result.rel_dists + self.freq_bias.index_with_labels(
                torch.stack((
                    result.obj_preds[rel_inds[:, 1]],
                    result.obj_preds[rel_inds[:, 2]],
                ), 1))

        #embed(header='rel model return ')
        if self.training:
            # embed(header='rel_model.py before return')
            # what will be useful:
            # rm_obj_dists, rm_obj_labels
            # rel_labels, rel_dists
            return result

        twod_inds = arange(
            result.obj_preds.data) * self.num_classes + result.obj_preds.data
        result.obj_scores = F.softmax(result.rm_obj_dists,
                                      dim=1).view(-1)[twod_inds]

        # Bbox regression
        if self.mode == 'sgdet':
            bboxes = result.boxes_all.view(-1, 4)[twod_inds].view(
                result.boxes_all.size(0), 4)
        else:
            # Boxes will get fixed by filter_dets function.
            bboxes = result.rm_box_priors

        rel_rep = F.softmax(result.rel_dists, dim=1)
        #embed(header='rel_model.py before return')
        return filter_dets(bboxes, result.obj_scores, result.obj_preds,
                           rel_inds[:, 1:], rel_rep)
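At test time, `get_rel_inds` in these examples falls back to enumerating candidate pairs, the same construction that appears in the negative-sampling branch of Example #9: every ordered pair of boxes from the same image, minus self-pairs, optionally filtered by box overlap. A condensed sketch without the overlap filter:

import torch

im_inds = torch.tensor([0, 0, 0, 1, 1])            # image index of each box

same_img = im_inds[:, None] == im_inds[None]       # [5, 5] same-image mask
same_img = same_img & ~torch.eye(im_inds.size(0), dtype=torch.bool)  # drop self-pairs
pairs = same_img.nonzero(as_tuple=False)           # [num_pairs, 2] of (subj, obj)

# prepend the image index of the subject box: rows become (img, subj, obj)
rel_inds = torch.cat((im_inds[pairs[:, 0]][:, None], pairs), 1)
print(rel_inds.shape)                              # torch.Size([8, 3])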
Example #12
0
File: rel_model3.py    Project: ht014/lsbr
    def forward(self,
                x,
                im_sizes,
                image_offset,
                gt_boxes=None,
                gt_classes=None,
                gt_rels=None,
                proposals=None,
                train_anchor_inds=None,
                return_fmap=False):

        result = self.detector(x,
                               im_sizes,
                               image_offset,
                               gt_boxes,
                               gt_classes,
                               gt_rels,
                               proposals,
                               train_anchor_inds,
                               return_fmap=True)
        if result.is_none():
            raise ValueError("heck")

        im_inds = result.im_inds - image_offset
        boxes = result.rm_box_priors

        if self.training and result.rel_labels is None:
            assert self.mode == 'sgdet'
            result.rel_labels = rel_assignments(im_inds.data,
                                                boxes.data,
                                                result.rm_obj_labels.data,
                                                gt_boxes.data,
                                                gt_classes.data,
                                                gt_rels.data,
                                                image_offset,
                                                filter_non_overlap=True,
                                                num_sample_per_gt=1)

        rel_inds = self.get_rel_inds(result.rel_labels, im_inds, boxes)

        rois = torch.cat((im_inds[:, None].float(), boxes), 1)

        result.obj_fmap = self.obj_feature_map(result.fmap.detach(), rois)

        # Prevent gradients from flowing back into score_fc from elsewhere
        result.rm_obj_dists, result.obj_preds, node_rep0 = self.context(
            result.obj_fmap, result.rm_obj_dists.detach(), im_inds,
            result.rm_obj_labels if self.training or self.mode == 'predcls'
            else None, boxes.data, result.boxes_all)

        edge_rep = node_rep0.repeat(1, 2)
        edge_rep = edge_rep.view(edge_rep.size(0), 2, -1)

        global_feature = self.global_embedding(result.fmap.detach())
        result.global_dists = self.global_logist(global_feature)
        one_hot_multi = torch.zeros(
            (result.global_dists.shape[0], self.num_classes))

        one_hot_multi[im_inds, result.rm_obj_labels] = 1.0
        result.multi_hot = one_hot_multi.float().cuda()

        subj_global_additive_attention = F.relu(
            self.global_sub_additive(edge_rep[:, 0] + global_feature[im_inds]))
        obj_global_additive_attention = F.relu(
            self.global_obj_additive(edge_rep[:, 1] + global_feature[im_inds]))

        subj_rep = edge_rep[:,
                            0] + subj_global_additive_attention * global_feature[
                                im_inds]
        obj_rep = edge_rep[:,
                           1] + obj_global_additive_attention * global_feature[
                               im_inds]

        if self.training:
            self.centroids = self.disc_center.centroids.data

        # if edge_ctx is None:
        #     edge_rep = self.post_emb(result.obj_preds)
        # else:
        #     edge_rep = self.post_lstm(edge_ctx)

        # Split into subject and object representations
        # edge_rep = edge_rep.view(edge_rep.size(0), 2, self.pooling_dim)
        #
        # subj_rep = edge_rep[:, 0]
        # obj_rep = edge_rep[:, 1]

        prod_rep = subj_rep[rel_inds[:, 1]] * obj_rep[rel_inds[:, 2]]

        vr = self.visual_rep(result.fmap.detach(), rois, rel_inds[:, 1:])

        prod_rep = prod_rep * vr

        prod_rep = F.tanh(prod_rep)

        logits, self.direct_memory_feature = self.meta_classify(
            prod_rep, self.centroids)
        # result.rel_dists = self.rel_compress(prod_rep)
        result.rel_dists = logits
        result.rel_dists2 = self.direct_memory_feature[-1]
        # result.hallucinate_logits = self.direct_memory_feature[-1]
        if self.training:
            result.center_loss = self.disc_center(
                prod_rep, result.rel_labels[:, -1]) * 0.01

        if self.use_bias:
            result.rel_dists = result.rel_dists + 1.0 * self.freq_bias.index_with_labels(
                torch.stack((
                    result.obj_preds[rel_inds[:, 1]],
                    result.obj_preds[rel_inds[:, 2]],
                ), 1))

        if self.training:
            return result

        twod_inds = arange(
            result.obj_preds.data) * self.num_classes + result.obj_preds.data
        result.obj_scores = F.softmax(result.rm_obj_dists,
                                      dim=1).view(-1)[twod_inds]

        # Bbox regression
        if self.mode == 'sgdet':
            bboxes = result.boxes_all.view(-1, 4)[twod_inds].view(
                result.boxes_all.size(0), 4)
        else:
            # Boxes will get fixed by filter_dets function.
            bboxes = result.rm_box_priors

        rel_rep = F.softmax(result.rel_dists, dim=1)
        return filter_dets(bboxes, result.obj_scores, result.obj_preds,
                           rel_inds[:, 1:], rel_rep)
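The additive-attention step in this example gates how much of the image-level feature is mixed back into each box representation before the pairwise product. A toy sketch of that gating, with illustrative layer names and sizes rather than the repository's modules:

import torch
import torch.nn as nn
import torch.nn.functional as F

dim = 16
local_feat = torch.randn(4, dim)        # per-box features (edge_rep[:, 0] above)
global_feat = torch.randn(2, dim)       # one feature vector per image
im_inds = torch.tensor([0, 0, 1, 1])    # image of each box

gate_fc = nn.Linear(dim, dim)           # plays the role of global_sub_additive
gate = F.relu(gate_fc(local_feat + global_feat[im_inds]))
subj_rep = local_feat + gate * global_feat[im_inds]
print(subj_rep.shape)                   # torch.Size([4, 16])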
Example #13
0
    def forward(self, x, im_sizes, image_offset,
                gt_boxes=None, gt_classes=None, gt_rels=None, proposals=None, train_anchor_inds=None,
                return_fmap=False):
        """
        Forward pass for detection
        :param x: Images@[batch_size, 3, IM_SIZE, IM_SIZE]
        :param im_sizes: A numpy array of (h, w, scale) for each image.
        :param image_offset: Offset onto what image we're on for MGPU training (if single GPU this is 0)
        :param gt_boxes:

        Training parameters:
        :param gt_boxes: [num_gt, 4] GT boxes over the batch.
        :param gt_classes: [num_gt, 2] gt classes where each one is (img_id, class)
        :param train_anchor_inds: a [num_train, 2] array of indices for the anchors that will
                                  be used to compute the training loss. Each (img_ind, fpn_idx)
        :return: If train:
            scores, boxdeltas, labels, boxes, boxtargets, rpnscores, rpnboxes, rellabels

            if test:
            prob dists, boxes, img inds, maxscores, classes

        """
        # --------- gt_rel process ----------

        batch_size = x.shape[0]

        result = self.detector(x, im_sizes, image_offset, gt_boxes, gt_classes, gt_rels, proposals,
                               train_anchor_inds, return_fmap=True)
        if result.is_none():
            raise ValueError("heck")


        # Prevent gradients from flowing back into score_fc from elsewhere (the author notes the last 3 arguments are junk)
        if self.mode == 'sgdet':
            im_inds = result.im_inds - image_offset  # all indices
            boxes = result.od_box_priors  # all boxes
            rois = torch.cat((im_inds[:, None].float(), boxes), 1)
            result.obj_fmap = self.obj_feature_map(result.fmap.detach(), rois)
            (result.rm_obj_dists, result.obj_preds, im_inds, result.rm_box_priors,
             result.rm_obj_labels, rois, result.boxes_all) = self.context(
                result.obj_fmap,
                result.rm_obj_dists.detach(),
                im_inds, result.rm_obj_labels if self.training or self.mode == 'predcls' else None,
                boxes.data, result.boxes_all, batch_size,
                rois, result.od_box_deltas.detach(), im_sizes, image_offset, gt_classes, gt_boxes)

            boxes = result.rm_box_priors
        else:
            #sgcls and predcls
            im_inds = result.im_inds - image_offset
            boxes = result.rm_box_priors
            rois = torch.cat((im_inds[:, None].float(), boxes), 1)
            result.obj_fmap = self.obj_feature_map(result.fmap.detach(), rois)

            result.rm_obj_dists, result.obj_preds = self.context(
                result.obj_fmap,
                result.rm_obj_dists.detach(),
                im_inds, result.rm_obj_labels if self.training or self.mode == 'predcls' else None,
                boxes.data, result.boxes_all, batch_size,
                None, None, None, None, None, None)

        if self.training and result.rel_labels is None:
            assert self.mode == 'sgdet'
            result.rel_labels = rel_assignments(im_inds.data, boxes.data, result.rm_obj_labels.data,
                                                gt_boxes.data, gt_classes.data, gt_rels.data,
                                                image_offset, filter_non_overlap=True,
                                                num_sample_per_gt=1)

        rel_inds = self.get_rel_inds(result.rel_labels, im_inds, boxes)

        # visual part
        obj_pooling = self.obj_avg_pool(result.fmap.detach(), rois).view(-1, 512)
        subj_rep = obj_pooling[rel_inds[:, 1]]
        obj_rep = obj_pooling[rel_inds[:, 2]]
        vr, union_rois = self.visual_rep(result.fmap.detach(), rois, rel_inds[:, 1:])
        vr = vr.view(-1, 512)
        x_visual = torch.cat((subj_rep, obj_rep, vr), 1)
        # semantic part
        subj_class = result.obj_preds[rel_inds[:, 1]]
        obj_class = result.obj_preds[rel_inds[:, 2]]
        subj_emb = self.obj_embed(subj_class)
        obj_emb = self.obj_embed2(obj_class)
        x_semantic = torch.cat((subj_emb, obj_emb), 1)

        # padding
        perm, inv_perm, ls_transposed = self.sort_rois(rel_inds[:, 0].data, None, union_rois[:, 1:])
        x_visual_rep = x_visual[perm].contiguous()
        x_semantic_rep = x_semantic[perm].contiguous()

        visual_input = PackedSequence(x_visual_rep, torch.tensor(ls_transposed))
        inputs1, lengths1 = pad_packed_sequence(visual_input, batch_first=False)
        semantic_input = PackedSequence(x_semantic_rep, torch.tensor(ls_transposed))
        inputs2, lengths2 = pad_packed_sequence(semantic_input, batch_first=False)


        self.hidden_state_visual = self.init_hidden(batch_size, bidirectional=False)
        self.hidden_state_semantic = self.init_hidden(batch_size, bidirectional=False)

        output1, self.hidden_state_visual = self.lstm_visual(inputs1, self.hidden_state_visual)
        output2, self.hidden_state_semantic = self.lstm_semantic(inputs2, self.hidden_state_semantic)        
        inputs = torch.cat((output1, output2), 2)


        x_fusion = self.odeBlock(inputs, batch_size)
        x_fusion = x_fusion[1]

        x_fusion, _ = pack_padded_sequence(x_fusion, lengths1, batch_first=False)
        x_out = self.fc_predicate(x_fusion)
        result.rel_dists = x_out[inv_perm]  # for evaluation and crossentropy

        if self.training:
            return result

        twod_inds = arange(result.obj_preds.data) * self.num_classes + result.obj_preds.data
        result.obj_scores = F.softmax(result.rm_obj_dists, dim=1).view(-1)[twod_inds]

        # Bbox regression
        if self.mode == 'sgdet':
            bboxes = result.boxes_all.view(-1, 4)[twod_inds].view(result.boxes_all.size(0), 4)
        else:
            # Boxes will get fixed by filter_dets function.
            bboxes = result.rm_box_priors

        rel_rep = F.softmax(result.rel_dists, dim=1)
        return filter_dets(bboxes, result.obj_scores,
                           result.obj_preds, rel_inds[:, 1:], rel_rep)
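This example relies on PyTorch's packed-sequence utilities to run LSTMs over a variable number of relations per image. A simplified sketch of that machinery (pad, pack, run the LSTM, unpack) with toy sizes; the ROI sorting and inverse permutation used above are omitted:

import torch
import torch.nn as nn
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence

feat_dim, hidden = 8, 6
lengths = torch.tensor([4, 2])                      # relations per image, longest first
padded = torch.zeros(2, 4, feat_dim)                # [batch, max_len, feat]
padded[0, :4] = torch.randn(4, feat_dim)
padded[1, :2] = torch.randn(2, feat_dim)

packed = pack_padded_sequence(padded, lengths, batch_first=True, enforce_sorted=True)
lstm = nn.LSTM(feat_dim, hidden, batch_first=True)
packed_out, _ = lstm(packed)                        # the LSTM skips the padding
out, out_lengths = pad_packed_sequence(packed_out, batch_first=True)
print(out.shape, out_lengths)                       # torch.Size([2, 4, 6]) tensor([4, 2])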
Example #14
0
File: rel_model.py    Project: galsina/lml
    def forward(self,
                x,
                im_sizes,
                image_offset,
                gt_boxes=None,
                gt_classes=None,
                gt_rels=None,
                proposals=None,
                train_anchor_inds=None,
                return_fmap=False):
        """
        Forward pass for detection
        :param x: Images@[batch_size, 3, IM_SIZE, IM_SIZE]
        :param im_sizes: A numpy array of (h, w, scale) for each image.
        :param image_offset: Offset onto what image we're on for MGPU training (if single GPU this is 0)
        :param gt_boxes:

        Training parameters:
        :param gt_boxes: [num_gt, 4] GT boxes over the batch.
        :param gt_classes: [num_gt, 2] gt classes where each one is (img_id, class)
        :param train_anchor_inds: a [num_train, 2] array of indices for the anchors that will
                                  be used to compute the training loss. Each (img_ind, fpn_idx)
        :return:
            (scores, boxdeltas, labels, boxes, boxtargets,
                rpnscores, rpnboxes, rellabels)
            (prob dists, boxes, img inds, maxscores, classes)

        """

        result = self.detector(x,
                               im_sizes,
                               image_offset,
                               gt_boxes,
                               gt_classes,
                               gt_rels,
                               proposals,
                               train_anchor_inds,
                               return_fmap=True)
        if result.is_none():
            raise ValueError("heck")

        im_inds = result.im_inds - image_offset
        boxes = result.rm_box_priors

        if self.training and result.rel_labels is None:
            assert self.mode == 'sgdet'
            result.rel_labels = rel_assignments(im_inds.data,
                                                boxes.data,
                                                result.rm_obj_labels.data,
                                                gt_boxes.data,
                                                gt_classes.data,
                                                gt_rels.data,
                                                image_offset,
                                                filter_non_overlap=True,
                                                num_sample_per_gt=1)

        rel_inds = self.get_rel_inds(result.rel_labels, im_inds, boxes)

        rois = torch.cat((im_inds[:, None].float(), boxes), 1)

        result.obj_fmap = self.obj_feature_map(result.fmap.detach(), rois)

        # Prevent gradients from flowing back into score_fc from elsewhere
        result.rm_obj_dists, result.obj_preds, edge_ctx = self.context(
            result.obj_fmap, result.rm_obj_dists.detach(), im_inds,
            result.rm_obj_labels if self.training or self.mode == 'predcls'
            else None, boxes.data, result.boxes_all)

        if edge_ctx is None:
            edge_rep = self.post_emb(result.obj_preds)
        else:
            edge_rep = self.post_lstm(edge_ctx)

        # Split into subject and object representations
        edge_rep = edge_rep.view(edge_rep.size(0), 2, self.pooling_dim)

        subj_rep = edge_rep[:, 0]
        obj_rep = edge_rep[:, 1]

        result.subj_rep_rel = subj_rep[rel_inds[:, 1]]
        result.obj_rep_rel = obj_rep[rel_inds[:, 2]]
        result.prod_rep = result.subj_rep_rel * result.obj_rep_rel

        if self.use_vision:
            result.vr = self.visual_rep(result.fmap.detach(), rois,
                                        rel_inds[:, 1:])
            if self.limit_vision:
                # exact value TBD
                import ipdb
                ipdb.set_trace()
                result.prod_rep = torch.cat(
                    (result.prod_rep[:, :2048] * result.vr[:, :2048],
                     result.prod_rep[:, 2048:]), 1)
            else:
                result.prod_rep = result.prod_rep * result.vr

        if self.use_tanh:
            result.prod_rep = F.tanh(result.prod_rep)

        result.rel_dists = self.rel_compress(result.prod_rep)

        result.obj_from = result.obj_preds[rel_inds[:, 1]]
        result.obj_to = result.obj_preds[rel_inds[:, 2]]

        if self.use_bias:
            result.rel_dists = result.rel_dists + \
                self.freq_bias.index_with_labels(
                    torch.stack((result.obj_from, result.obj_to), 1)
                )

        # if self.training:
        #     return result

        idxs, n_objs = np.unique(im_inds.data.cpu().numpy(),
                                 return_counts=True)

        if not self.training:
            assert len(n_objs) == 1
            n_rels = [len(rel_inds)]
        else:
            _, n_rels = np.unique(result.rel_labels.data[:, 0].cpu().numpy(),
                                  return_counts=True)
            n_rels = n_rels.tolist()

        preds = []
        obj_start = 0
        rel_start = 0
        result.obj_scores = []
        result.rel_reps = []
        for idx, n_obj, n_rel in zip(idxs, n_objs, n_rels):
            obj_end = obj_start + n_obj
            rel_end = rel_start + n_rel

            obj_preds_i = result.obj_preds.data[obj_start:obj_end]
            rm_obj_dists_i = result.rm_obj_dists[obj_start:obj_end]

            twod_inds = arange(obj_preds_i) * self.num_classes + obj_preds_i
            result.obj_scores.append(
                F.softmax(rm_obj_dists_i, dim=1).view(-1)[twod_inds])

            # Bbox regression
            if self.mode == 'sgdet':
                import ipdb
                ipdb.set_trace()  # TODO
                bboxes = result.boxes_all.view(-1, 4)[twod_inds].view(
                    result.boxes_all.size(0), 4)
            else:
                # Boxes will get fixed by filter_dets function.
                bboxes_i = result.rm_box_priors[obj_start:obj_end]

            if self.mode == 'predcls':
                # NOTE: this expression is a no-op; its result is discarded
                ((rm_obj_dists_i / 1000) + 1) / 2

            rel_dists_i = result.rel_dists[rel_start:rel_end]
            obj_preds_i = result.obj_preds[obj_start:obj_end]
            obj_scores_i = result.obj_scores[-1]
            if self.lml_topk is not None and self.lml_topk:
                if self.lml_softmax:
                    rel_rep = LML(N=self.lml_topk, branch=1000)(F.softmax(
                        rel_dists_i,
                        dim=1)[:, 1:].contiguous().view(-1)).view(n_rel, -1)
                else:
                    rel_rep = LML(N=self.lml_topk, branch=1000)(
                        rel_dists_i[:,
                                    1:].contiguous().view(-1)).view(n_rel, -1)

                rel_rep = torch.cat((
                    Variable(torch.zeros(n_rel, 1).type_as(rel_dists_i.data)),
                    rel_rep,
                ), 1)
            elif (self.entr_topk is not None
                  and self.entr_topk) or self.ml_loss:
                # Hack to ignore the background.
                rel_rep = torch.cat((Variable(
                    -1e10 * torch.ones(n_rel, 1).type_as(rel_dists_i.data)),
                                     rel_dists_i[:, 1:]), 1)
            else:
                rel_rep = F.softmax(rel_dists_i, dim=1)
            result.rel_reps.append(rel_rep)

            rel_inds_i = rel_inds[rel_start:rel_end, 1:].clone()
            rel_inds_i = rel_inds_i - rel_inds_i.min()  # Very hacky fix...

            # For debugging:
            # bboxes_i = bboxes_i.cpu()
            # obj_scores_i = obj_scores_i.cpu()
            # obj_preds_i = obj_preds_i.cpu()
            # rel_inds_i = rel_inds_i.cpu()
            # rel_rep = rel_rep.cpu()

            pred = filter_dets(bboxes_i, obj_scores_i, obj_preds_i, rel_inds_i,
                               rel_rep)
            preds.append(pred)

            obj_start += n_obj
            rel_start += n_rel

        assert obj_start == result.obj_preds.shape[0]
        assert rel_start == rel_inds.shape[0]
        return result, preds
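Example #14 splits the flat batch back into per-image chunks by counting how many boxes and relations each image owns. A small sketch of that bookkeeping with toy data:

import numpy as np
import torch

im_inds = np.array([0, 0, 0, 1, 1, 2])   # image index per box (already grouped)
obj_scores = torch.randn(6)

idxs, n_objs = np.unique(im_inds, return_counts=True)

start = 0
for idx, n_obj in zip(idxs, n_objs):
    chunk = obj_scores[start:start + n_obj]
    print(f"image {idx}: {n_obj} boxes -> chunk of shape {tuple(chunk.shape)}")
    start += n_obj
assert start == obj_scores.shape[0]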
Example #15
0
    def forward(self, x, im_sizes, image_offset,
                gt_boxes=None, gt_classes=None, gt_rels=None, proposals=None, train_anchor_inds=None,
                return_fmap=False):
        """
        Forward pass for detection
        :param x: Images@[batch_size, 3, IM_SIZE, IM_SIZE]
        :param im_sizes: A numpy array of (h, w, scale) for each image.
        :param image_offset: Offset onto what image we're on for MGPU training (if single GPU this is 0)
        :param gt_boxes:

        Training parameters:
        :param gt_boxes: [num_gt, 4] GT boxes over the batch.
        :param gt_classes: [num_gt, 2] gt classes where each one is (img_id, class)
        :param train_anchor_inds: a [num_train, 2] array of indices for the anchors that will
                                  be used to compute the training loss. Each (img_ind, fpn_idx)
        :return: If train:
            scores, boxdeltas, labels, boxes, boxtargets, rpnscores, rpnboxes, rellabels
            
            if test:
            prob dists, boxes, img inds, maxscores, classes
            
        """
        result = self.detector(x, im_sizes, image_offset, gt_boxes, gt_classes, gt_rels, proposals,
                               train_anchor_inds, return_fmap=True)

        if result.is_none():
            raise ValueError("heck")

        im_inds = result.im_inds - image_offset
        boxes = result.rm_box_priors

        if self.training and result.rel_labels is None:
            assert self.mode == 'sgdet'
            result.rel_labels = rel_assignments(im_inds.data, boxes.data, result.rm_obj_labels.data,
                                                gt_boxes.data, gt_classes.data, gt_rels.data,
                                                image_offset, filter_non_overlap=True, num_sample_per_gt=1)
        rel_inds = self.get_rel_inds(result.rel_labels, im_inds, boxes)
        rois = torch.cat((im_inds[:, None].float(), boxes), 1)
        visual_rep = self.visual_rep(result.fmap, rois, rel_inds[:, 1:])

        result.obj_fmap = self.obj_feature_map(result.fmap.detach(), rois)

        # Now do the approximation WHEREVER THERE'S A VALID RELATIONSHIP.
        result.rm_obj_dists, result.rel_dists = self.message_pass(
            F.relu(self.edge_unary(visual_rep)), self.obj_unary(result.obj_fmap), rel_inds[:, 1:])

        # result.box_deltas_update = box_deltas

        if self.training:
            return result

        # Decode here ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
        if self.mode == 'predcls':
            # Hack to get the GT object labels
            result.obj_scores = result.rm_obj_dists.data.new(gt_classes.size(0)).fill_(1)
            result.obj_preds = gt_classes.data[:, 1]
        elif self.mode == 'sgdet':
            order, obj_scores, obj_preds= filter_det(F.softmax(result.rm_obj_dists),
                                                              result.boxes_all,
                                                              start_ind=0,
                                                              max_per_img=100,
                                                              thresh=0.00,
                                                              pre_nms_topn=6000,
                                                              post_nms_topn=300,
                                                              nms_thresh=0.3,
                                                              nms_filter_duplicates=True)
            idx, perm = torch.sort(order)
            result.obj_preds = rel_inds.new(result.rm_obj_dists.size(0)).fill_(1)
            result.obj_scores = result.rm_obj_dists.data.new(result.rm_obj_dists.size(0)).fill_(0)
            result.obj_scores[idx] = obj_scores.data[perm]
            result.obj_preds[idx] = obj_preds.data[perm]
        else:
            scores_nz = F.softmax(result.rm_obj_dists, dim=1).data
            scores_nz[:, 0] = 0.0
            result.obj_scores, score_ord = scores_nz[:, 1:].sort(dim=1, descending=True)
            result.obj_preds = score_ord[:,0] + 1
            result.obj_scores = result.obj_scores[:,0]

        result.obj_preds = Variable(result.obj_preds)
        result.obj_scores = Variable(result.obj_scores)

        # Set result's bounding boxes to be size [num_boxes, topk, 4] instead of
        # considering every single object assignment (see the gather sketch after
        # this example).
        twod_inds = arange(result.obj_preds.data) * self.num_classes + result.obj_preds.data

        if self.mode == 'sgdet':
            bboxes = result.boxes_all.view(-1, 4)[twod_inds].view(result.boxes_all.size(0), 4)
        else:
            # Boxes will get fixed by filter_dets function.
            bboxes = result.rm_box_priors
        rel_rep = F.softmax(result.rel_dists, dim=1)

        return filter_dets(bboxes, result.obj_scores,
                           result.obj_preds, rel_inds[:, 1:], rel_rep)
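
The sgdet decode above relies on a flat-indexing trick: arange(num_box) * num_classes + obj_preds turns per-box class predictions into indices into a flattened [num_box, num_classes, ...] tensor, so each box keeps only the score and regressed box of its predicted class. A minimal sketch of that gather with made-up sizes (3 boxes, 5 classes), independent of the model's real tensors:

import torch

num_box, num_classes = 3, 5                        # toy sizes, not the model's
obj_dists = torch.randn(num_box, num_classes)      # per-box class logits
boxes_all = torch.randn(num_box, num_classes, 4)   # one regressed box per class

obj_preds = obj_dists[:, 1:].argmax(dim=1) + 1     # skip the background class
twod_inds = torch.arange(num_box) * num_classes + obj_preds

# Flattening lets one 1-D index pull out, for every box, the entry that
# belongs to its predicted class.
obj_scores = torch.softmax(obj_dists, dim=1).view(-1)[twod_inds]   # [num_box]
bboxes = boxes_all.view(-1, 4)[twod_inds]                          # [num_box, 4]
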
Example #16
    def forward(self, batch):
        """
        Forward pass for detection

        Training parameters:
        :param gt_boxes: [num_gt, 4] GT boxes over the batch.
        :param gt_classes: [num_gt, 2] gt classes where each one is (img_id, class)
        :param gt_rels: [num_gt_rels, 4] gt relationships where each one is (img_id, subj_id, obj_id, class)

        """

        assert len(batch) == 1, ('only single-GPU batches are supported in this code',
                                 len(batch))

        x, gt_boxes, gt_classes, gt_rels = (batch[0][0], batch[0][3],
                                            batch[0][4], batch[0][5])

        with NO_GRAD():  # do not update anything in the detector
            if self.backbone == 'vgg16_old':
                raise NotImplementedError('%s is not supported any more' %
                                          self.backbone)
            else:
                result = self.faster_rcnn(x, gt_boxes, gt_classes, gt_rels)

        result.fmap = result.fmap.detach()  # do not update the detector

        im_inds = result.im_inds
        boxes = result.rm_box_priors

        if self.training and result.rel_labels is None:
            assert self.mode == 'sgdet'
            result.rel_labels = rel_assignments(im_inds.data,
                                                boxes.data,
                                                result.rm_obj_labels.data,
                                                gt_boxes.data,
                                                gt_classes.data,
                                                gt_rels.data,
                                                0,
                                                filter_non_overlap=True,
                                                num_sample_per_gt=1)
        elif not hasattr(result, 'rel_labels'):
            result.rel_labels = None

        rel_inds = self.get_rel_inds(
            result.rel_labels if self.training else None, im_inds, boxes)
        result.rel_inds = rel_inds
        rois = torch.cat((im_inds[:, None].float(), boxes), 1)

        result.node_feat, result.edge_feat = self.node_edge_features(
            result.fmap, rois, rel_inds[:, 1:], im_sizes=result.im_sizes)

        result.rm_obj_dists, result.rel_dists = self.predict(
            result.node_feat,
            result.edge_feat,
            rel_inds,
            rois=rois,
            im_sizes=result.im_sizes)

        if self.use_bias:

            scores_nz = F.softmax(result.rm_obj_dists, dim=1).data
            scores_nz[:, 0] = 0.0
            _, score_ord = scores_nz[:, 1:].sort(dim=1, descending=True)
            result.obj_preds = score_ord[:, 0] + 1

            if self.mode == 'predcls':
                result.obj_preds = gt_classes.data[:, 1]

            freq_pred = self.freq_bias.index_with_labels(
                torch.stack((
                    result.obj_preds[rel_inds[:, 1]],
                    result.obj_preds[rel_inds[:, 2]],
                ), 1))
            # tune the weight for freq_bias (see the bias-lookup sketch after this example)
            if self.test_bias:
                result.rel_dists = freq_pred
            else:
                result.rel_dists = result.rel_dists + freq_pred

        if self.training:
            result.rois = rois
            return result

        if self.mode == 'predcls':
            result.obj_scores = result.rm_obj_dists.data.new(
                gt_classes.shape[0]).fill_(1)
            result.obj_preds = gt_classes.data[:, 1]
        elif self.mode in ['sgcls', 'sgdet']:
            scores_nz = F.softmax(result.rm_obj_dists, dim=1).data
            scores_nz[:, 0] = 0.0  # does not actually change anything
            result.obj_scores, score_ord = scores_nz[:, 1:].sort(
                dim=1, descending=True)
            result.obj_preds = score_ord[:, 0] + 1
            result.obj_scores = result.obj_scores[:, 0]
        else:
            raise NotImplementedError(self.mode)

        result.obj_preds = Variable(result.obj_preds)
        result.obj_scores = Variable(result.obj_scores)

        # Boxes will get fixed by filter_dets function.
        if self.backbone != 'vgg16_old':
            bboxes = result.rm_box_priors_org
        else:
            bboxes = result.rm_box_priors

        rel_rep = F.softmax(result.rel_dists, dim=1)

        return filter_dets(bboxes, result.obj_scores, result.obj_preds,
                           rel_inds[:, 1:], rel_rep)
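
The use_bias branch above adds a learned prior over predicates for every (subject class, object class) pair through freq_bias.index_with_labels. A rough, self-contained stand-in for that lookup, assuming it is backed by an embedding table keyed by subj * num_classes + obj (the real FrequencyBias module may differ in detail):

import torch
import torch.nn as nn

num_classes, num_predicates = 5, 7                 # toy sizes
bias_table = nn.Embedding(num_classes * num_classes, num_predicates)

def index_with_labels(pairs):
    # pairs: [num_rel, 2] of (subject class, object class)
    return bias_table(pairs[:, 0] * num_classes + pairs[:, 1])

obj_preds = torch.tensor([1, 3, 2])                # predicted class per box
rel_inds = torch.tensor([[0, 0, 1],                # (img_ind, subj_box, obj_box)
                         [0, 1, 2]])
pairs = torch.stack((obj_preds[rel_inds[:, 1]], obj_preds[rel_inds[:, 2]]), 1)

rel_dists = torch.randn(rel_inds.size(0), num_predicates)
rel_dists = rel_dists + index_with_labels(pairs)   # add the frequency prior
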
Example #17
    def forward(self,
                x,
                im_sizes,
                image_offset,
                gt_boxes=None,
                gt_classes=None,
                gt_rels=None,
                proposals=None,
                train_anchor_inds=None,
                return_fmap=False):
        """
        Forward pass for detection
        :param x: Images@[batch_size, 3, IM_SIZE, IM_SIZE]
        :param im_sizes: A numpy array of (h, w, scale) for each image.
        :param image_offset: Offset onto what image we're on for MGPU training (if single GPU this is 0)
        :param gt_boxes:

        Training parameters:
        :param gt_boxes: [num_gt, 4] GT boxes over the batch.
        :param gt_classes: [num_gt, 2] gt classes where each one is (img_id, class)
        :param train_anchor_inds: a [num_train, 2] array of indices for the anchors that will
                                  be used to compute the training loss. Each (img_ind, fpn_idx)
        :return: If train:
            scores, boxdeltas, labels, boxes, boxtargets, rpnscores, rpnboxes, rellabels
            
            if test:
            prob dists, boxes, img inds, maxscores, classes

            (A sketch of dummy inputs with these shapes follows this example.)
        """

        result = self.detector(x,
                               im_sizes,
                               image_offset,
                               gt_boxes,
                               gt_classes,
                               gt_rels,
                               proposals,
                               train_anchor_inds,
                               return_fmap=True)
        if result.is_none():
            raise ValueError("Empty detection result")

        im_inds = result.im_inds - image_offset
        boxes = result.rm_box_priors

        if self.training and result.rel_labels is None:
            assert self.mode == 'sgdet'
            result.rel_labels = rel_assignments(im_inds.data,
                                                boxes.data,
                                                result.rm_obj_labels.data,
                                                gt_boxes.data,
                                                gt_classes.data,
                                                gt_rels.data,
                                                image_offset,
                                                filter_non_overlap=True,
                                                num_sample_per_gt=1)

        rel_inds = self.get_rel_inds(result.rel_labels, im_inds, boxes)
        rois = torch.cat((im_inds[:, None].float(), boxes), 1)

        result.obj_fmap = self.obj_feature_map(result.fmap.detach(), rois)

        if self.use_ggnn_obj:
            result.rm_obj_dists = self.ggnn_obj_reason(
                im_inds, result.obj_fmap, result.rm_obj_labels
                if self.training or self.mode == 'predcls' else None)

        vr = self.visual_rep(result.fmap.detach(), rois, rel_inds[:, 1:])

        if self.use_ggnn_rel:
            result.rm_obj_dists, result.obj_preds, result.rel_dists = self.ggnn_rel_reason(
                obj_fmaps=result.obj_fmap,
                obj_logits=result.rm_obj_dists,
                vr=vr,
                rel_inds=rel_inds,
                obj_labels=result.rm_obj_labels
                if self.training or self.mode == 'predcls' else None,
                boxes_per_cls=result.boxes_all)
        else:
            result.rm_obj_dists, result.obj_preds, result.rel_dists = self.vr_fc_cls(
                obj_logits=result.rm_obj_dists,
                vr=vr,
                obj_labels=result.rm_obj_labels
                if self.training or self.mode == 'predcls' else None,
                boxes_per_cls=result.boxes_all)

        if self.training:
            return result

        twod_inds = arange(
            result.obj_preds.data) * self.num_classes + result.obj_preds.data
        result.obj_scores = F.softmax(result.rm_obj_dists,
                                      dim=1).view(-1)[twod_inds]

        # Bbox regression
        if self.mode == 'sgdet':
            bboxes = result.boxes_all.view(-1, 4)[twod_inds].view(
                result.boxes_all.size(0), 4)
        else:
            # Boxes will get fixed by filter_dets function.
            bboxes = result.rm_box_priors

        rel_rep = F.softmax(result.rel_dists, dim=1)

        return filter_dets(bboxes, result.obj_scores, result.obj_preds,
                           rel_inds[:, 1:], rel_rep)
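
The docstring above pins down the shapes the forward pass expects. The sketch below only builds dummy inputs with those shapes so the layout is concrete; the sizes (2 images, 3 GT boxes, 2 GT relations) and the IM_SIZE constant are invented, and nothing here calls the actual model:

import numpy as np
import torch

IM_SIZE = 592                                       # assumed input resolution
batch_size = 2

x = torch.randn(batch_size, 3, IM_SIZE, IM_SIZE)    # images
im_sizes = np.array([[IM_SIZE, IM_SIZE, 1.0]] * batch_size)   # (h, w, scale) per image
image_offset = 0                                    # single-GPU case

gt_boxes = torch.tensor([[10., 20., 200., 180.],    # [num_gt, 4] box corners
                         [50., 60., 300., 400.],
                         [ 5.,  5., 100., 100.]])
gt_classes = torch.tensor([[0, 5], [0, 12], [1, 7]])           # [num_gt, 2] (img_id, class)
gt_rels = torch.tensor([[0, 0, 1, 3],                          # [num_gt_rels, 4]
                        [0, 1, 0, 2]])                         # (img_id, subj_id, obj_id, rel)
train_anchor_inds = torch.tensor([[0, 100], [1, 200]])         # [num_train, 2] (img_ind, fpn_idx)
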
Example #18
    def forward(self,
                x,
                im_sizes,
                image_offset,
                gt_boxes=None,
                gt_classes=None,
                gt_rels=None,
                proposals=None,
                train_anchor_inds=None,
                return_fmap=False):
        """
        Forward pass for Relation detection
        Args:
            x: Images@[batch_size, 3, IM_SIZE, IM_SIZE]
            im_sizes: A numpy array of (h, w, scale) for each image.
            image_offset: Offset onto what image we're on for MGPU training (if single GPU this is 0)

            parameters for training:
            gt_boxes: [num_gt, 4] GT boxes over the batch.
            gt_classes: [num_gt, 2] gt classes where each one is (img_id, class)
            gt_rels:
            proposals:
            train_anchor_inds: a [num_train, 2] array of indices for the anchors that will
                                  be used to compute the training loss. Each (img_ind, fpn_idx)
            return_fmap:

        Returns:
            If train:
                scores, boxdeltas, labels, boxes, boxtargets, rpnscores, rpnboxes, rellabels
            If test:
                prob dists, boxes, img inds, maxscores, classes
        """
        s_t = time.time()
        verbose = False

        def check(stage, end, start=s_t):
            # lightweight timing probe; prints only when verbose is True
            if verbose:
                print('{}: {}'.format(stage, end - start))

        result = self.detector(x,
                               im_sizes,
                               image_offset,
                               gt_boxes,
                               gt_classes,
                               gt_rels,
                               proposals,
                               train_anchor_inds,
                               return_fmap=True)
        check('detector', tt())

        assert not result.is_none(), 'Empty detection result'

        # image_offset refers to the blob offset:
        # self.batch_size_per_gpu * index
        im_inds = result.im_inds - image_offset
        boxes = result.rm_box_priors
        obj_scores, box_classes = F.softmax(
            result.rm_obj_dists[:, 1:].contiguous(), dim=1).max(1)
        box_classes += 1
        # TODO: predcls implementation obj_scores and box_classes

        num_img = im_inds[-1] + 1

        # embed(header='rel_model.py before rel_assignments')
        if self.training and result.rel_labels is None:
            assert self.mode == 'sgdet'

            # only in sgdet mode

            # shapes:
            # im_inds: (box_num,)
            # boxes: (box_num, 4)
            # rm_obj_labels: (box_num,)
            # gt_boxes: (box_num, 4)
            # gt_classes: (box_num, 2) maybe[im_ind, class_ind]
            # gt_rels: (rel_num, 4)
            # image_offset: integer
            result.rel_labels = rel_assignments(im_inds.data,
                                                boxes.data,
                                                result.rm_obj_labels.data,
                                                gt_boxes.data,
                                                gt_classes.data,
                                                gt_rels.data,
                                                image_offset,
                                                filter_non_overlap=True,
                                                num_sample_per_gt=1)
        rel_inds = self.get_rel_inds(result.rel_labels, im_inds, boxes)
        rois = torch.cat((im_inds[:, None].float(), boxes), 1)
        # union boxes feats (NumOfRels, obj_dim)
        union_box_feats = self.visual_rep(result.fmap.detach(), rois,
                                          rel_inds[:, 1:].contiguous())
        # single box feats (NumOfBoxes, feats)
        box_feats = self.obj_feature_map(result.fmap.detach(), rois)
        # box spatial feats (NumOfBox, 4)

        box_pair_feats = self.fuse_message(union_box_feats, boxes, box_classes,
                                           rel_inds)
        box_pair_score = self.relpn_fc(box_pair_feats)

        if self.training:
            # sampling pos and neg relations here for training
            rel_sample_pos, rel_sample_neg = 0, 0
            pn_rel_label, pn_pair_score = list(), list()
            for i, s, e in enumerate_by_image(
                    result.rel_labels[:, 0].data.contiguous()):
                im_i_rel_label = result.rel_labels[s:e].contiguous()
                im_i_box_pair_score = box_pair_score[s:e].contiguous()

                im_i_rel_fg_inds = torch.nonzero(
                    im_i_rel_label[:, -1].contiguous()).squeeze()
                im_i_rel_fg_inds = im_i_rel_fg_inds.data.cpu().numpy()
                im_i_fg_sample_num = min(RELEVANT_PER_IM,
                                         im_i_rel_fg_inds.shape[0])
                if im_i_rel_fg_inds.size > 0:
                    im_i_rel_fg_inds = np.random.choice(
                        im_i_rel_fg_inds,
                        size=im_i_fg_sample_num,
                        replace=False)

                im_i_rel_bg_inds = torch.nonzero(
                    im_i_rel_label[:, -1].contiguous() == 0).squeeze()
                im_i_rel_bg_inds = im_i_rel_bg_inds.data.cpu().numpy()
                im_i_bg_sample_num = min(EDGES_PER_IM - im_i_fg_sample_num,
                                         im_i_rel_bg_inds.shape[0])
                if im_i_rel_bg_inds.size > 0:
                    im_i_rel_bg_inds = np.random.choice(
                        im_i_rel_bg_inds,
                        size=im_i_bg_sample_num,
                        replace=False)

                #print('{}/{} fg/bg in image {}'.format(im_i_fg_sample_num, im_i_bg_sample_num, i))
                rel_sample_pos += im_i_fg_sample_num
                rel_sample_neg += im_i_bg_sample_num

                im_i_keep_inds = np.append(im_i_rel_fg_inds, im_i_rel_bg_inds)
                im_i_pair_score = im_i_box_pair_score[
                    im_i_keep_inds.tolist()].contiguous()

                im_i_rel_pn_labels = Variable(
                    torch.zeros(im_i_fg_sample_num + im_i_bg_sample_num).type(
                        torch.LongTensor).cuda(x.get_device()))
                im_i_rel_pn_labels[:im_i_fg_sample_num] = 1

                pn_rel_label.append(im_i_rel_pn_labels)
                pn_pair_score.append(im_i_pair_score)

            result.rel_pn_dists = torch.cat(pn_pair_score, 0)
            result.rel_pn_labels = torch.cat(pn_rel_label, 0)
            result.rel_sample_pos = torch.Tensor([rel_sample_pos]).cuda(
                im_i_rel_label.get_device())
            result.rel_sample_neg = torch.Tensor([rel_sample_neg]).cuda(
                im_i_rel_label.get_device())

        box_pair_relevant = F.softmax(box_pair_score, dim=1)
        box_pos_pair_ind = torch.nonzero(
            box_pair_relevant[:, 1].contiguous() >
            box_pair_relevant[:, 0].contiguous()).squeeze()

        if box_pos_pair_ind.data.shape == torch.Size([]):
            return None
        #print('{}/{} trim edges'.format(box_pos_pair_ind.size(0), rel_inds.size(0)))
        result.rel_trim_pos = torch.Tensor([box_pos_pair_ind.size(0)]).cuda(
            box_pos_pair_ind.get_device())
        result.rel_trim_total = torch.Tensor(
            [rel_inds.size(0)]).cuda(rel_inds.get_device())

        if self.trim_graph:
            # filtering relations
            filter_rel_inds = rel_inds[box_pos_pair_ind.data]
            filter_box_pair_feats = box_pair_feats[box_pos_pair_ind.data]
        else:
            filter_rel_inds = rel_inds
            filter_box_pair_feats = box_pair_feats
        if self.training:
            if self.trim_graph:
                filter_rel_labels = result.rel_labels[box_pos_pair_ind.data]
            else:
                filter_rel_labels = result.rel_labels
            num_gt_filtered = torch.nonzero(filter_rel_labels[:, -1])
            if num_gt_filtered.shape == torch.Size([]):
                num_gt_filtered = 0
            else:
                num_gt_filtered = num_gt_filtered.size(0)
            num_gt_original = torch.nonzero(result.rel_labels[:, -1]).size(0)
            result.rel_pn_recall = torch.Tensor(
                [num_gt_filtered / num_gt_original]).cuda(x.get_device())
            result.rel_labels = filter_rel_labels
        check('trim', tt())

        # message passing between boxes and relations
        if self.mode in ('sgcls', 'sgdet'):
            for _ in range(self.mp_iter_num):
                box_feats = self.message_passing(box_feats,
                                                 filter_box_pair_feats,
                                                 filter_rel_inds)
            box_cls_scores = self.cls_fc(box_feats)
            result.rm_obj_dists = box_cls_scores
            obj_scores, box_classes = F.softmax(
                box_cls_scores[:, 1:].contiguous(), dim=1).max(1)
            box_classes += 1  # skip background
        check('mp', tt())

        # RelationCNN
        filter_box_pair_feats_fc1 = self.relcnn_fc1(filter_box_pair_feats)
        filter_box_pair_score = self.relcnn_fc2(filter_box_pair_feats_fc1)

        result.rel_dists = filter_box_pair_score
        pred_scores_stage_one = F.softmax(result.rel_dists, dim=1).data

        # filter_box_pair_feats is to be added to memory
        if self.training:
            padded_filter_feats, pack_lengths, re_filter_rel_inds, padded_rel_labels = \
                self.pad_sequence(
                    filter_rel_inds,
                    filter_box_pair_feats_fc1,
                    rel_labels=result.rel_labels
                )
        else:
            padded_filter_feats, pack_lengths, re_filter_rel_inds, padded_rel_inds = \
                self.pad_sequence(
                    filter_rel_inds,
                    filter_box_pair_feats_fc1
                )

        # trim zero-length entries so images without any relation do not break
        # the packing below (see the packing sketch after this example)
        trim_pack_lengths = np.trim_zeros(pack_lengths)
        trim_padded_filter_feats = padded_filter_feats[:trim_pack_lengths.shape[0]]
        packed_filter_feats = pack_padded_sequence(trim_padded_filter_feats,
                                                   trim_pack_lengths,
                                                   batch_first=True)
        if self.training:
            trim_padded_rel_labels = padded_rel_labels[:trim_pack_lengths.shape[0]]
            packed_rel_labels = pack_padded_sequence(trim_padded_rel_labels,
                                                     trim_pack_lengths,
                                                     batch_first=True)
            rel_mem_dists = self.mem_module(inputs=packed_filter_feats,
                                            rel_labels=packed_rel_labels)
            rel_mem_dists = self.re_order_packed_seq(rel_mem_dists,
                                                     filter_rel_inds,
                                                     re_filter_rel_inds)
            result.rel_mem_dists = rel_mem_dists
        else:
            trim_padded_rel_inds = padded_rel_inds[:trim_pack_lengths.shape[0]]
            packed_rel_inds = pack_padded_sequence(trim_padded_rel_inds,
                                                   trim_pack_lengths,
                                                   batch_first=True)
            rel_mem_dists = self.mem_module(inputs=packed_filter_feats,
                                            rel_inds=packed_rel_inds,
                                            obj_classes=box_classes)
            rel_mem_probs = self.re_order_packed_seq(rel_mem_dists,
                                                     filter_rel_inds,
                                                     re_filter_rel_inds)
            rel_mem_probs = rel_mem_probs.data

        check('mem', tt())
        if self.training:
            return result

        # fall back to the stage-one scores for any relation whose memory output sums to zero
        for rel_i in range(rel_mem_probs.size(0)):
            rel_i_probs = rel_mem_probs[rel_i]
            if rel_i_probs.sum() == 0:
                rel_mem_probs[rel_i] = pred_scores_stage_one[rel_i]
        """
        filter_dets
        boxes: bbox regression else [num_box, 4]
        obj_scores: [num_box] probabilities for the scores
        obj_classes: [num_box] class labels integer
        rel_inds: [num_rel, 2] TENSOR consisting of (im_ind0, im_ind1)
        pred_scores: [num_rel, num_predicates] including irrelevant class(#relclass + 1)
        """
        check('mem processing', tt())
        return filter_dets(boxes, obj_scores, box_classes,
                           filter_rel_inds[:, 1:].contiguous(), rel_mem_probs)
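
The memory stage above pads per-image relation features, trims images that ended up with zero relations, and packs the rest before feeding the memory module. A small self-contained sketch of that trim-and-pack step with made-up sizes; the real pad_sequence helper also re-orders relation indices, which is omitted here:

import numpy as np
import torch
from torch.nn.utils.rnn import pack_padded_sequence

pack_lengths = np.array([4, 2, 0])          # relations per image, sorted descending
padded_feats = torch.randn(3, 4, 8)         # [num_img, max_rels, feat_dim]

trim_lengths = np.trim_zeros(pack_lengths)  # drop the image with zero relations
trim_feats = padded_feats[:trim_lengths.shape[0]]
packed = pack_padded_sequence(trim_feats, trim_lengths.tolist(), batch_first=True)

print(packed.data.shape)                    # torch.Size([6, 8]) -> 4 + 2 rows
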
Example #19
File: rel_model2.py  Project: ht014/lsbr
    def forward(self,
                x,
                im_sizes,
                image_offset,
                gt_boxes=None,
                gt_classes=None,
                gt_rels=None,
                proposals=None,
                train_anchor_inds=None,
                return_fmap=False):

        result = self.detector(x,
                               im_sizes,
                               image_offset,
                               gt_boxes,
                               gt_classes,
                               gt_rels,
                               proposals,
                               train_anchor_inds,
                               return_fmap=True)

        if result.is_none():
            raise ValueError("Empty detection result")

        im_inds = result.im_inds - image_offset
        boxes = result.rm_box_priors

        if self.training and result.rel_labels is None:
            assert self.mode == 'sgdet'
            result.rel_labels = rel_assignments(im_inds.data,
                                                boxes.data,
                                                result.rm_obj_labels.data,
                                                gt_boxes.data,
                                                gt_classes.data,
                                                gt_rels.data,
                                                image_offset,
                                                filter_non_overlap=True,
                                                num_sample_per_gt=1)

        rel_inds = self.get_rel_inds(result.rel_labels, im_inds, boxes)

        rois = torch.cat((im_inds[:, None].float(), boxes), 1)

        global_feature = self.global_embedding(result.fmap.detach())
        result.global_dists = self.global_logist(global_feature)
        # print(result.global_dists)
        # result.global_rel_dists = F.sigmoid(self.global_rel_logist(global_feature))

        result.obj_fmap = self.obj_feature_map(result.fmap.detach(), rois)

        # Prevent gradients from flowing back into score_fc from elsewhere
        result.rm_obj_dists, result.obj_preds, node_rep0 = self.context(
            result.obj_fmap, result.rm_obj_dists.detach(), im_inds,
            result.rm_obj_labels if self.training or self.mode == 'predcls'
            else None, boxes.data, result.boxes_all)

        one_hot_multi = torch.zeros(
            (result.global_dists.shape[0], self.num_classes))

        one_hot_multi[im_inds, result.rm_obj_labels] = 1.0
        result.multi_hot = one_hot_multi.float().cuda()
        edge_rep = node_rep0.repeat(1, 2)

        edge_rep = edge_rep.view(edge_rep.size(0), 2, -1)
        global_feature_re = global_feature[im_inds]
        # additive gating: mix each node representation with its image-level
        # (global) feature (see the gating sketch after this example)
        subj_global_additive_attention = F.relu(
            self.global_sub_additive(edge_rep[:, 0] + global_feature_re))
        obj_global_additive_attention = F.relu(
            torch.sigmoid(
                self.global_obj_additive(edge_rep[:, 1] + global_feature_re)))

        subj_rep = edge_rep[:, 0] + subj_global_additive_attention * global_feature_re
        obj_rep = edge_rep[:, 1] + obj_global_additive_attention * global_feature_re

        edge_of_coordinate_rep = self.coordinate_feats(boxes.data, rel_inds)

        e_ij_coordinate_rep = self.edge_coordinate_embedding(
            edge_of_coordinate_rep)

        union_rep = self.visual_rep(result.fmap.detach(), rois, rel_inds[:, 1:])
        edge_feat_init = union_rep

        prod_rep = subj_rep[rel_inds[:, 1]] * obj_rep[rel_inds[:, 2]] * edge_feat_init
        prod_rep = torch.cat((prod_rep, e_ij_coordinate_rep), 1)

        if self.use_tanh:
            prod_rep = torch.tanh(prod_rep)

        result.rel_dists = self.rel_compress(prod_rep)

        if self.use_bias:
            result.rel_dists = result.rel_dists + self.freq_bias.index_with_labels(
                torch.stack((
                    result.obj_preds[rel_inds[:, 1]],
                    result.obj_preds[rel_inds[:, 2]],
                ), 1))

        if self.training:
            return result

        twod_inds = arange(
            result.obj_preds.data) * self.num_classes + result.obj_preds.data
        result.obj_scores = F.softmax(result.rm_obj_dists,
                                      dim=1).view(-1)[twod_inds]

        # Bbox regression
        if self.mode == 'sgdet':
            bboxes = result.boxes_all.view(-1, 4)[twod_inds].view(
                result.boxes_all.size(0), 4)
        else:
            # Boxes will get fixed by filter_dets function.
            bboxes = result.rm_box_priors

        rel_rep = F.softmax(result.rel_dists, dim=1)

        return filter_dets(bboxes, result.obj_scores, result.obj_preds,
                           rel_inds[:, 1:], rel_rep)
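
The subject/object branches above gate each box representation with its image-level feature via a small additive-attention layer before the product with the union features. A minimal sketch of that gating under assumed dimensions (layer names and sizes are illustrative, not the model's):

import torch
import torch.nn as nn
import torch.nn.functional as F

dim = 16                                     # assumed feature size
global_additive = nn.Linear(dim, dim)        # stand-in for global_sub_additive

node_rep = torch.randn(6, dim)               # one row per box
global_feat = torch.randn(2, dim)            # one row per image
im_inds = torch.tensor([0, 0, 0, 1, 1, 1])   # which image each box came from

global_re = global_feat[im_inds]                         # broadcast to boxes
gate = F.relu(global_additive(node_rep + global_re))     # additive attention
mixed_rep = node_rep + gate * global_re                  # gated mixture
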
    def forward(self, x, im_sizes, image_offset,
                gt_boxes=None, gt_classes=None, gt_rels=None, proposals=None, train_anchor_inds=None,
                return_fmap=False):
        """
        Forward pass for detection
        :param x: Images@[batch_size, 3, IM_SIZE, IM_SIZE]
        :param im_sizes: A numpy array of (h, w, scale) for each image.
        :param image_offset: Offset onto what image we're on for MGPU training (if single GPU this is 0)
        :param gt_boxes:

        Training parameters:
        :param gt_boxes: [num_gt, 4] GT boxes over the batch.
        :param gt_classes: [num_gt, 2] gt classes where each one is (img_id, class)
        :param train_anchor_inds: a [num_train, 2] array of indices for the anchors that will
                                  be used to compute the training loss. Each (img_ind, fpn_idx)
        :return: If train:
            scores, boxdeltas, labels, boxes, boxtargets, rpnscores, rpnboxes, rellabels
            
            if test:
            prob dists, boxes, img inds, maxscores, classes
            
        """
        result = self.detector(x, im_sizes, image_offset, gt_boxes, gt_classes, gt_rels, proposals,
                               train_anchor_inds, return_fmap=True)

        if result.is_none():
            raise ValueError("Empty detection result")

        im_inds = result.im_inds - image_offset
        boxes = result.rm_box_priors

        if self.training and result.rel_labels is None:
            assert self.mode == 'sgdet'
            result.rel_labels, fg_rel_labels = rel_assignments(im_inds.data, boxes.data, result.rm_obj_labels.data,
                                                gt_boxes.data, gt_classes.data, gt_rels.data,
                                                image_offset, filter_non_overlap=True,
                                                num_sample_per_gt=1)

        #if self.training and (not self.use_rl_tree):
            # generate arbitrary forest according to graph
        #    arbitrary_forest = graph_to_trees(self.co_occour, result.rel_labels, gt_classes)
        #else:
        arbitrary_forest = None

        rel_inds = self.get_rel_inds(result.rel_labels, im_inds, boxes)

        if self.use_rl_tree:
            result.rel_label_tkh = self.get_rel_label(im_inds, gt_rels, rel_inds)

        rois = torch.cat((im_inds[:, None].float(), boxes), 1)

        result.obj_fmap = self.obj_feature_map(result.fmap.detach(), rois)

        # whole image feature, used for virtual node
        batch_size = result.fmap.shape[0]
        image_rois = Variable(torch.zeros(batch_size, 5).cuda())
        for i in range(batch_size):
            image_rois[i, 0] = i
            image_rois[i, 1] = 0
            image_rois[i, 2] = 0
            image_rois[i, 3] = IM_SCALE
            image_rois[i, 4] = IM_SCALE
        image_fmap = self.obj_feature_map(result.fmap.detach(), image_rois)

        if self.mode != 'sgdet' and self.training:
            fg_rel_labels = result.rel_labels

        # Prevent gradients from flowing back into score_fc from elsewhere
        result.rm_obj_dists, result.obj_preds, edge_ctx, result.gen_tree_loss, result.entropy_loss, result.pair_gate, result.pair_gt = self.context(
            result.obj_fmap,
            result.rm_obj_dists.detach(),
            im_inds, result.rm_obj_labels if self.training or self.mode == 'predcls' else None,
            boxes.data, result.boxes_all, 
            arbitrary_forest,
            image_rois,
            image_fmap,
            self.co_occour,
            fg_rel_labels if self.training else None,
            x)

        if edge_ctx is None:
            edge_rep = self.post_emb(result.obj_preds)
        else:
            edge_rep = self.post_lstm(edge_ctx)

        # Split into subject and object representations
        edge_rep = edge_rep.view(edge_rep.size(0), 2, self.hidden_dim)

        subj_rep = edge_rep[:, 0]
        obj_rep = edge_rep[:, 1]

        prod_rep = torch.cat((subj_rep[rel_inds[:, 1]], obj_rep[rel_inds[:, 2]]), 1)
        prod_rep = self.post_cat(prod_rep)

        if self.use_encoded_box:
            # encode spatial info
            assert(boxes.shape[1] == 4)
            # encoded_boxes: [box_num, (x1, y1, x2, y2, cx, cy, w, h)]
            # (see the box-encoding sketch after this example)
            encoded_boxes = tree_utils.get_box_info(boxes)
            # encoded_boxes_pair: [batch_size, (box1, box2, union box, intersection box)]
            encoded_boxes_pair = tree_utils.get_box_pair_info(encoded_boxes[rel_inds[:, 1]], encoded_boxes[rel_inds[:, 2]])
            # encoded_spatial_rep
            spatial_rep = F.relu(self.encode_spatial_2(F.relu(self.encode_spatial_1(encoded_boxes_pair))))
            # element-wise multiply with prod_rep
            prod_rep = prod_rep * spatial_rep

        if self.use_vision:
            vr = self.visual_rep(result.fmap.detach(), rois, rel_inds[:, 1:])
            if self.limit_vision:
                # exact value TBD
                prod_rep = torch.cat((prod_rep[:,:2048] * vr[:,:2048], prod_rep[:,2048:]), 1)
            else:
                prod_rep = prod_rep * vr

        if self.use_tanh:
            prod_rep = torch.tanh(prod_rep)

        result.rel_dists = self.rel_compress(prod_rep)

        if self.use_bias:
            result.rel_dists = result.rel_dists + self.freq_bias.index_with_labels(torch.stack((
                result.obj_preds[rel_inds[:, 1]],
                result.obj_preds[rel_inds[:, 2]],
            ), 1))

        if self.training and (not self.rl_train):
            return result

        twod_inds = arange(result.obj_preds.data) * self.num_classes + result.obj_preds.data
        result.obj_scores = F.softmax(result.rm_obj_dists, dim=1).view(-1)[twod_inds]

        # Bbox regression
        if self.mode == 'sgdet':
            bboxes = result.boxes_all.view(-1, 4)[twod_inds].view(result.boxes_all.size(0), 4)
        else:
            # Boxes will get fixed by filter_dets function.
            bboxes = result.rm_box_priors

        rel_rep = F.softmax(result.rel_dists, dim=1)

        if not self.rl_train:
            return filter_dets(bboxes, result.obj_scores,
                           result.obj_preds, rel_inds[:, 1:], rel_rep, gt_boxes, gt_classes, gt_rels)
        else:
            return result, filter_dets(bboxes, result.obj_scores, result.obj_preds, rel_inds[:, 1:], rel_rep, gt_boxes, gt_classes, gt_rels)
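
The use_encoded_box branch above depends on tree_utils.get_box_info to expand each (x1, y1, x2, y2) box into an 8-dimensional spatial code (corners plus center, width and height). A plausible sketch of that encoding; the real helper may additionally normalize by the image size:

import torch

def box_info(boxes):
    # boxes: [num_box, 4] as (x1, y1, x2, y2)
    # returns: [num_box, 8] as (x1, y1, x2, y2, cx, cy, w, h)
    x1, y1, x2, y2 = boxes.unbind(dim=1)
    w, h = x2 - x1, y2 - y1
    cx, cy = x1 + 0.5 * w, y1 + 0.5 * h
    return torch.stack((x1, y1, x2, y2, cx, cy, w, h), dim=1)

print(box_info(torch.tensor([[0., 0., 10., 20.]])))
# tensor([[ 0.,  0., 10., 20.,  5., 10., 10., 20.]])
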
    def forward(self,
                x,
                im_sizes,
                image_offset,
                gt_boxes=None,
                gt_classes=None,
                gt_rels=None,
                proposals=None,
                train_anchor_inds=None,
                return_fmap=False,
                depth_imgs=None):
        """
        Forward pass for relation detection
        :param x: Images@[batch_size, 3, IM_SIZE, IM_SIZE]
        :param im_sizes: a numpy array of (h, w, scale) for each image.
        :param image_offset: offset onto what image we're on for MGPU training (if single GPU this is 0)
        :param gt_boxes: [num_gt, 4] GT boxes over the batch.
        :param gt_classes: [num_gt, 2] gt classes where each one is (img_id, class)
        :param gt_rels: [] gt relations
        :param proposals: region proposals retrieved from file
        :param train_anchor_inds: a [num_train, 2] array of indices for the anchors that will
                                  be used to compute the training loss. Each (img_ind, fpn_idx)
        :param return_fmap: if the object detector must return the extracted feature maps
        :param depth_imgs: depth images [batch_size, 1, IM_SIZE, IM_SIZE]
        :return: If train:
            scores, boxdeltas, labels, boxes, boxtargets, rpnscores, rpnboxes, rellabels
            
            if test:
            prob dists, boxes, img inds, maxscores, classes
            
        """

        if self.has_visual:
            # -- Feed forward the rgb images to Faster-RCNN
            result = self.detector(x,
                                   im_sizes,
                                   image_offset,
                                   gt_boxes,
                                   gt_classes,
                                   gt_rels,
                                   proposals,
                                   train_anchor_inds,
                                   return_fmap=True)
        else:
            # -- Get prior `result` object (instead of calling faster-rcnn's detector)
            result = self.get_prior_results(image_offset, gt_boxes, gt_classes,
                                            gt_rels)

        # -- Get RoI and relations
        rois, rel_inds = self.get_rois_and_rels(result, image_offset, gt_boxes,
                                                gt_classes, gt_rels)
        boxes = result.rm_box_priors

        # -- Determine subject and object indices
        subj_inds = rel_inds[:, 1]
        obj_inds = rel_inds[:, 2]

        # -- Prepare object predictions vector (PredCLS)
        # replace predictions with the ground-truth labels
        result.obj_preds = result.rm_obj_labels
        # replace distributions with a one-hot encoding of the ground-truth labels,
        # stretched into extreme logits (see the sketch after this example)
        result.rm_obj_dists = F.one_hot(result.rm_obj_labels.data,
                                        self.num_classes).float()
        obj_cls = result.rm_obj_dists
        result.rm_obj_dists = result.rm_obj_dists * 1000 + (
            1 - result.rm_obj_dists) * (-1000)

        rel_features = []
        # -- Extract RGB features
        if self.has_visual:
            # Feed the extracted features from first conv layers to the last 'classifier' layers (VGG)
            # Here, only the last 3 layers of VGG are being trained. Everything else (in self.detector)
            # is frozen.
            result.obj_fmap = self.get_roi_features(result.fmap.detach(), rois)

            # -- Create a pairwise relation vector out of visual features
            rel_visual = torch.cat(
                (result.obj_fmap[subj_inds], result.obj_fmap[obj_inds]), 1)
            rel_visual_fc = self.visual_hlayer(rel_visual)
            rel_visual_scale = self.visual_scale(rel_visual_fc)
            rel_features.append(rel_visual_scale)

        # -- Extract Location features
        if self.has_loc:
            # -- Create a pairwise relation vector out of location features
            rel_location = self.get_loc_features(boxes, subj_inds, obj_inds)
            rel_location_fc = self.location_hlayer(rel_location)
            rel_location_scale = self.location_scale(rel_location_fc)
            rel_features.append(rel_location_scale)

        # -- Extract Class features
        if self.has_class:
            if self.use_embed:
                obj_cls = obj_cls @ self.obj_embed.weight
            # -- Create a pairwise relation vector out of class features
            rel_classme = torch.cat((obj_cls[subj_inds], obj_cls[obj_inds]), 1)
            rel_classme_fc = self.classme_hlayer(rel_classme)
            rel_classme_scale = self.classme_scale(rel_classme_fc)
            rel_features.append(rel_classme_scale)

        # -- Extract Depth features
        if self.has_depth:
            # -- Extract features from depth backbone
            depth_features = self.depth_backbone(depth_imgs)
            depth_rois_features = self.get_roi_features_depth(
                depth_features, rois)

            # -- Create a pairwise relation vector out of location features
            rel_depth = torch.cat((depth_rois_features[subj_inds],
                                   depth_rois_features[obj_inds]), 1)
            rel_depth_fc = self.depth_rel_hlayer(rel_depth)
            rel_depth_scale = self.depth_scale(rel_depth_fc)
            rel_features.append(rel_depth_scale)

        # -- Create concatenated feature vector
        rel_fusion = torch.cat(rel_features, 1)

        # -- Extract relation embeddings (penultimate layer)
        rel_embeddings = self.fusion_hlayer(rel_fusion)

        # -- Mix relation embeddings with UoBB features
        if self.has_visual and self.use_vision:
            uobb_features = self.get_union_features(result.fmap.detach(), rois,
                                                    rel_inds[:, 1:])
            if self.limit_vision:
                # exact value TBD
                uobb_limit = int(self.hidden_dim / 2)
                rel_embeddings = torch.cat((rel_embeddings[:, :uobb_limit] *
                                            uobb_features[:, :uobb_limit],
                                            rel_embeddings[:, uobb_limit:]), 1)
            else:
                rel_embeddings = rel_embeddings * uobb_features

        # -- Predict relation distances
        result.rel_dists = self.rel_out(rel_embeddings)

        # -- Frequency bias
        if self.use_bias:
            result.rel_dists = result.rel_dists + self.freq_bias.index_with_labels(
                torch.stack((
                    result.obj_preds[rel_inds[:, 1]],
                    result.obj_preds[rel_inds[:, 2]],
                ), 1))

        if self.training:
            return result

        # --- *** END OF ARCHITECTURE *** ---#

        twod_inds = arange(
            result.obj_preds.data) * self.num_classes + result.obj_preds.data
        result.obj_scores = F.softmax(result.rm_obj_dists,
                                      dim=1).view(-1)[twod_inds]

        # Bbox regression
        if self.mode == 'sgdet':
            bboxes = result.boxes_all.view(-1, 4)[twod_inds].view(
                result.boxes_all.size(0), 4)
        else:
            # Boxes will get fixed by filter_dets function.
            bboxes = result.rm_box_priors

        rel_rep = F.softmax(result.rel_dists, dim=1)
        # Filtering: Subject_Score * Pred_score * Obj_score, sorted and ranked
        return filter_dets(bboxes, result.obj_scores, result.obj_preds,
                           rel_inds[:, 1:], rel_rep)
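
The PredCLS preparation above swaps the detector's class distributions for the ground-truth labels by building a one-hot matrix and stretching it into +/-1000 pseudo-logits, so a later softmax is effectively one-hot. A small sketch of that trick with toy labels:

import torch
import torch.nn.functional as F

num_classes = 5
gt_labels = torch.tensor([2, 4, 1])                   # toy ground-truth classes

one_hot = F.one_hot(gt_labels, num_classes).float()   # [3, 5]
logits = one_hot * 1000 + (1 - one_hot) * (-1000)     # +/-1000 pseudo-logits
probs = F.softmax(logits, dim=1)                      # ~1.0 at the GT class

print(probs.argmax(dim=1))                            # tensor([2, 4, 1])
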
Example #22
    def forward(self, x, im_sizes, image_offset,
                gt_boxes=None, gt_classes=None, gt_rels=None, proposals=None, train_anchor_inds=None,
                return_fmap=False):

        result = self.detector(x, im_sizes, image_offset, gt_boxes, gt_classes, gt_rels, proposals,
                               train_anchor_inds, return_fmap=True)

        if result.is_none():
            raise ValueError("Empty detection result")

        im_inds = result.im_inds - image_offset
        boxes = result.rm_box_priors
        # boxes = result.boxes_assigned
        boxes_deltas = result.rm_box_deltas  # None in sgcls mode
        boxes_all = result.boxes_all  # None in sgcls mode

        if self.training and result.rel_labels is None:
            assert self.mode == 'sgdet'
            result.rel_labels = rel_assignments(im_inds.data, boxes.data, result.rm_obj_labels.data, result.rm_obj_dists.data,
                                                gt_boxes.data, gt_classes.data, gt_rels.data,
                                                image_offset, filter_non_overlap=True,
                                                num_sample_per_gt=1)

        rel_inds = self.get_rel_inds(result.rel_labels, im_inds, boxes, result.rm_obj_dists.data)

        reward_rel_inds = None
        if self.mode == 'sgdet':
            msg_rel_inds = self.get_msg_rel_inds(im_inds, boxes, result.rm_obj_dists.data)
            reward_rel_inds = self.get_reward_rel_inds(im_inds, boxes, result.rm_obj_dists.data)


        if self.mode == 'sgdet':
            result.rm_obj_dists_list, result.obj_preds_list, result.rel_dists_list, result.bbox_list, result.offset_list, \
                result.rel_dists, result.obj_preds, result.boxes_all, result.all_rel_logits = self.context(
                                            result.fmap.detach(), result.rm_obj_dists.detach(), im_inds, rel_inds, msg_rel_inds, 
                                            reward_rel_inds, im_sizes, boxes.detach(), boxes_deltas.detach(), boxes_all.detach(),
                                            result.rm_obj_labels if self.training or self.mode == 'predcls' else None)

        elif self.mode in ['sgcls', 'predcls']:
            result.obj_preds, result.rm_obj_logits, result.rel_logits = self.context(
                                            result.fmap.detach(), result.rm_obj_dists.detach(),
                                            im_inds, rel_inds, None, None, im_sizes, boxes.detach(), None, None,
                                            result.rm_obj_labels if self.training or self.mode == 'predcls' else None)
        else:
            raise NotImplementedError

        # result.rm_obj_dists = result.rm_obj_dists_list[-1]

        if self.training:
            return result

        if self.mode == 'predcls':
            result.obj_preds = result.rm_obj_labels
            result.obj_scores = Variable(torch.from_numpy(np.ones(result.obj_preds.shape[0],)).float().cuda())
        else:
            twod_inds = arange(result.obj_preds.data) * self.num_classes + result.obj_preds.data
            result.obj_scores = F.softmax(result.rm_obj_logits, dim=1).view(-1)[twod_inds]

        # Bbox regression
        if self.mode == 'sgdet':
            if conf.use_postprocess:
                bboxes = result.boxes_all.view(-1, 4)[twod_inds].view(result.boxes_all.size(0), 4)
            else:
                bboxes = result.rm_box_priors
        else:
            bboxes = result.rm_box_priors

        rel_scores = torch.sigmoid(result.rel_logits)

        return filter_dets(bboxes, result.obj_scores,
                           result.obj_preds, rel_inds[:, 1:], rel_scores)