Example #1
    def edge_ctx(self, obj_feats, obj_dists, im_inds, obj_preds, box_priors=None):
        """
        Object context and object classification.
        :param obj_feats: [num_obj, img_dim + obj embed dim]
        :param obj_dists: [num_obj, #classes]
        :param im_inds: [num_obj] the indices of the images
        :param obj_preds: [num_obj] hard class predictions used for the embedding lookup
        :param box_priors: [num_obj, 4] box priors, only used when sorting the RoIs
        :return: edge_ctx: [num_obj, #feats] For later!
        """

        # Only use hard embeddings
        obj_embed2 = self.obj_embed2(obj_preds)
        # obj_embed3 = F.softmax(obj_dists, dim=1) @ self.obj_embed3.weight
        inp_feats = torch.cat((obj_embed2, obj_feats), 1)

        # Sort by the confidence of the maximum detection.
        confidence = F.softmax(obj_dists, dim=1).data.view(-1)[
            obj_preds.data + arange(obj_preds.data) * self.num_classes]
        perm, inv_perm, ls_transposed = self.sort_rois(im_inds.data, confidence, box_priors)

        edge_input_packed = PackedSequence(inp_feats[perm], ls_transposed)
        edge_reps = self.edge_ctx_rnn(edge_input_packed)[0][0]

        # now we're good! unperm
        edge_ctx = edge_reps[inv_perm]
        return edge_ctx
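
The confidence lookup above flattens the [num_obj, #classes] softmax and indexes it at row * num_classes + predicted_class. A minimal sketch of that trick, with a hypothetical arange helper standing in for the one these snippets import, and a gather-based equivalent for comparison:

import torch
import torch.nn.functional as F

def arange(x):
    # hypothetical stand-in for the arange helper used above: 0..len(x)-1 on x's device
    return torch.arange(x.size(0), device=x.device)

num_obj, num_classes = 4, 6
obj_dists = torch.randn(num_obj, num_classes)   # raw class logits
obj_preds = obj_dists.argmax(dim=1)             # hard class predictions

probs = F.softmax(obj_dists, dim=1)
# flat-index trick from edge_ctx: element (i, obj_preds[i]) of the softmax
confidence = probs.view(-1)[obj_preds + arange(obj_preds) * num_classes]

# equivalent, more explicit form
confidence_ref = probs.gather(1, obj_preds[:, None]).squeeze(1)
assert torch.allclose(confidence, confidence_ref)
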
Example #2
    def forward(self, last_outputs, obj_dists, rel_inds, im_inds, rois, boxes):

        twod_inds = arange(last_outputs.obj_preds.data
                           ) * self.num_classes + last_outputs.obj_preds.data
        obj_scores = F.softmax(last_outputs.rm_obj_dists,
                               dim=1).view(-1)[twod_inds]

        rel_rep, _ = F.softmax(last_outputs.rel_dists, dim=1)[:, 1:].max(1)
        rel_scores_argmaxed = rel_rep * obj_scores[
            rel_inds[:, 1]] * obj_scores[rel_inds[:, 2]]
        _, rel_scores_idx = torch.sort(rel_scores_argmaxed.view(-1),
                                       dim=0,
                                       descending=True)
        rel_scores_idx = rel_scores_idx[:100]

        filtered_rel_inds = rel_inds[rel_scores_idx.data]

        obj_fmap = self.obj_feature_map(last_outputs.fmap.detach(), rois)

        rm_obj_dists, obj_preds = self.context(
            obj_fmap, obj_dists.detach(), im_inds,
            last_outputs.rm_obj_labels if self.mode == 'predcls' else None,
            boxes.data, last_outputs.boxes_all)

        obj_dtype = obj_fmap.data.type()
        obj_preds_embeds = torch.index_select(self.ort_embedding, 0,
                                              obj_preds).type(obj_dtype)
        transferred_boxes = torch.stack(
            (boxes[:, 0] / IM_SCALE, boxes[:, 3] / IM_SCALE,
             boxes[:, 2] / IM_SCALE, boxes[:, 1] / IM_SCALE,
             ((boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1])) /
             (IM_SCALE**2)), -1).type(obj_dtype)
        obj_features = torch.cat(
            (obj_fmap, obj_preds_embeds, transferred_boxes), -1)
        edge_rep = self.post_emb(obj_features)

        edge_rep = edge_rep.view(edge_rep.size(0), 2, self.pooling_dim)

        subj_rep = edge_rep[:, 0][filtered_rel_inds[:, 1]]
        obj_rep = edge_rep[:, 1][filtered_rel_inds[:, 2]]

        prod_rep = subj_rep * obj_rep

        vr = self.visual_rep(last_outputs.fmap.detach(), rois,
                             filtered_rel_inds[:, 1:])

        prod_rep = prod_rep * vr

        rel_dists = self.rel_compress(prod_rep)

        if self.use_bias:
            rel_dists = rel_dists + self.freq_bias.index_with_labels(
                torch.stack((
                    obj_preds[filtered_rel_inds[:, 1]],
                    obj_preds[filtered_rel_inds[:, 2]],
                ), 1))

        return filtered_rel_inds, rm_obj_dists, obj_preds, rel_dists
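
The top of this forward pass ranks candidate relations by the product of the best non-background predicate probability and the two object scores, then keeps the 100 highest. A small self-contained sketch of that ranking step, using randomly generated tensors in place of the model outputs:

import torch
import torch.nn.functional as F

num_obj, num_rel, num_predicates = 8, 20, 5
obj_scores = torch.rand(num_obj)                    # per-box confidence
rel_dists = torch.randn(num_rel, num_predicates)    # predicate logits, column 0 = background
rel_inds = torch.randint(0, num_obj, (num_rel, 2))  # (subject_idx, object_idx) per candidate

# best non-background predicate probability for each candidate pair
rel_rep, _ = F.softmax(rel_dists, dim=1)[:, 1:].max(1)
# overall triple score = predicate * subject * object
triple_score = rel_rep * obj_scores[rel_inds[:, 0]] * obj_scores[rel_inds[:, 1]]

_, order = torch.sort(triple_score, descending=True)
keep = order[:100]                                  # top-100 candidates, as in the snippet
filtered_rel_inds = rel_inds[keep]
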
Example #3
    def forward(self, x, im_sizes, image_offset,
                gt_boxes=None, gt_classes=None, gt_rels=None, proposals=None, train_anchor_inds=None,
                return_fmap=False):
        """
        Forward pass for detection
        :param x: Images@[batch_size, 3, IM_SIZE, IM_SIZE]
        :param im_sizes: A numpy array of (h, w, scale) for each image.
        :param image_offset: Offset onto what image we're on for MGPU training (if single GPU this is 0)

        Training parameters:
        :param gt_boxes: [num_gt, 4] GT boxes over the batch.
        :param gt_classes: [num_gt, 2] gt boxes where each one is (img_id, class)
        :param train_anchor_inds: a [num_train, 2] array of indices for the anchors that will
                                  be used to compute the training loss. Each (img_ind, fpn_idx)
        :return: If train:
            scores, boxdeltas, labels, boxes, boxtargets, rpnscores, rpnboxes, rellabels
            
            if test:
            prob dists, boxes, img inds, maxscores, classes
            
        """
        result = self.detector(x, im_sizes, image_offset, gt_boxes, gt_classes, gt_rels, proposals,
                               train_anchor_inds, return_fmap=True)

        if result.is_none():
            raise ValueError("heck")

        im_inds = result.im_inds - image_offset
        boxes = result.rm_box_priors

        if self.training and result.rel_labels is None:
            assert self.mode == 'sgdet'
            result.rel_labels = rel_assignments(im_inds.data, boxes.data, result.rm_obj_labels.data,
                                                gt_boxes.data, gt_classes.data, gt_rels.data,
                                                image_offset, filter_non_overlap=True, num_sample_per_gt=1)
        rel_inds = self.get_rel_inds(result.rel_labels, im_inds, boxes)
        rois = torch.cat((im_inds[:, None].float(), boxes), 1)
        visual_rep = self.visual_rep(result.fmap, rois, rel_inds[:, 1:])

        result.obj_fmap = self.obj_feature_map(result.fmap.detach(), rois)

        # Now do the approximation WHEREVER THERES A VALID RELATIONSHIP.
        result.rm_obj_dists, result.rel_dists = self.message_pass(
            F.relu(self.edge_unary(visual_rep)), self.obj_unary(result.obj_fmap), rel_inds[:, 1:])

        # result.box_deltas_update = box_deltas

        if self.training:
            return result

        # Decode here ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
        if self.mode == 'predcls':
            # Hack to get the GT object labels
            result.obj_scores = result.rm_obj_dists.data.new(gt_classes.size(0)).fill_(1)
            result.obj_preds = gt_classes.data[:, 1]
        elif self.mode == 'sgdet':
            order, obj_scores, obj_preds = filter_det(
                F.softmax(result.rm_obj_dists, dim=1), result.boxes_all,
                start_ind=0, max_per_img=100, thresh=0.00,
                pre_nms_topn=6000, post_nms_topn=300, nms_thresh=0.3,
                nms_filter_duplicates=True)
            idx, perm = torch.sort(order)
            result.obj_preds = rel_inds.new(result.rm_obj_dists.size(0)).fill_(1)
            result.obj_scores = result.rm_obj_dists.data.new(result.rm_obj_dists.size(0)).fill_(0)
            result.obj_scores[idx] = obj_scores.data[perm]
            result.obj_preds[idx] = obj_preds.data[perm]
        else:
            scores_nz = F.softmax(result.rm_obj_dists, dim=1).data
            scores_nz[:, 0] = 0.0
            result.obj_scores, score_ord = scores_nz[:, 1:].sort(dim=1, descending=True)
            result.obj_preds = score_ord[:,0] + 1
            result.obj_scores = result.obj_scores[:,0]

        result.obj_preds = Variable(result.obj_preds)
        result.obj_scores = Variable(result.obj_scores)

        # Pick one box per object (the box for its predicted class)
        # instead of keeping every per-class assignment.
        twod_inds = arange(result.obj_preds.data) * self.num_classes + result.obj_preds.data

        if self.mode == 'sgdet':
            bboxes = result.boxes_all.view(-1, 4)[twod_inds].view(result.boxes_all.size(0), 4)
        else:
            # Boxes will get fixed by filter_dets function.
            bboxes = result.rm_box_priors
        rel_rep = F.softmax(result.rel_dists, dim=1)

        return filter_dets(bboxes, result.obj_scores,
                           result.obj_preds, rel_inds[:, 1:], rel_rep)
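
In sgdet mode, boxes_all holds one regressed box per class ([num_box, #classes, 4]); the twod_inds trick picks out the box belonging to each object's predicted class. A minimal sketch with dummy tensors:

import torch

num_box, num_classes = 5, 151
boxes_all = torch.rand(num_box, num_classes, 4)         # one regressed box per class
obj_preds = torch.randint(1, num_classes, (num_box,))   # predicted class per box

twod_inds = torch.arange(num_box) * num_classes + obj_preds
# keep only the box belonging to each object's predicted class
bboxes = boxes_all.view(-1, 4)[twod_inds]               # -> [num_box, 4]

# equivalent explicit form
bboxes_ref = boxes_all[torch.arange(num_box), obj_preds]
assert torch.equal(bboxes, bboxes_ref)
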
Example #4
    def forward(self,
                x,
                im_sizes,
                image_offset,
                gt_boxes=None,
                gt_classes=None,
                gt_rels=None,
                proposals=None,
                train_anchor_inds=None,
                return_fmap=False):
        """
        Forward pass for detection
        :param x: Images@[batch_size, 3, IM_SIZE, IM_SIZE]
        :param im_sizes: A numpy array of (h, w, scale) for each image.
        :param image_offset: Offset onto what image we're on for MGPU training (if single GPU this is 0)

        Training parameters:
        :param gt_boxes: [num_gt, 4] GT boxes over the batch.
        :param gt_classes: [num_gt, 2] gt boxes where each one is (img_id, class)
        :param train_anchor_inds: a [num_train, 2] array of indices for the anchors that will
                                  be used to compute the training loss. Each (img_ind, fpn_idx)
        :return: If train:
            scores, boxdeltas, labels, boxes, boxtargets, rpnscores, rpnboxes, rellabels

            if test:
            prob dists, boxes, img inds, maxscores, classes

        """
        result = self.detector(x,
                               im_sizes,
                               image_offset,
                               gt_boxes,
                               gt_classes,
                               gt_rels,
                               proposals,
                               train_anchor_inds,
                               return_fmap=True)
        # rel_feat = self.relationship_feat.feature_map(x)

        if result.is_none():
            raise ValueError("heck")

        im_inds = result.im_inds - image_offset
        boxes = result.rm_box_priors

        if self.training and result.rel_labels is None:
            assert self.mode == 'sgdet'
            result.rel_labels = rel_assignments(im_inds.data,
                                                boxes.data,
                                                result.rm_obj_labels.data,
                                                gt_boxes.data,
                                                gt_classes.data,
                                                gt_rels.data,
                                                image_offset,
                                                filter_non_overlap=True,
                                                num_sample_per_gt=1)

        rel_inds = self.get_rel_inds(result.rel_labels, im_inds, boxes)
        spt_feats = self.get_boxes_encode(boxes, rel_inds)
        pair_inds = self.union_pairs(im_inds)

        if self.hook_for_grad:
            rel_inds = gt_rels[:, :-1].data

        if self.hook_for_grad:
            fmap = result.fmap
            fmap.register_hook(self.save_grad)
        else:
            fmap = result.fmap.detach()

        rois = torch.cat((im_inds[:, None].float(), boxes), 1)

        result.obj_fmap = self.obj_feature_map(fmap, rois)
        # result.obj_dists_head = self.obj_classify_head(obj_fmap_rel)

        obj_embed = F.softmax(result.rm_obj_dists,
                              dim=1) @ self.obj_embed.weight
        obj_embed_lstm = F.softmax(result.rm_obj_dists,
                                   dim=1) @ self.embeddings4lstm.weight
        pos_embed = self.pos_embed(Variable(center_size(boxes.data)))
        obj_pre_rep = torch.cat((result.obj_fmap, obj_embed, pos_embed), 1)
        obj_feats = self.merge_obj_feats(obj_pre_rep)
        # obj_feats=self.trans(obj_feats)
        obj_feats_lstm = torch.cat(
            (obj_feats, obj_embed_lstm),
            -1).contiguous().view(1, obj_feats.size(0), -1)

        # obj_feats = F.relu(obj_feats)

        phr_ori = self.visual_rep(fmap, rois, pair_inds[:, 1:])
        vr_indices = torch.from_numpy(
            intersect_2d(rel_inds[:, 1:].cpu().numpy(),
                         pair_inds[:, 1:].cpu().numpy()).astype(
                             np.uint8)).cuda().max(-1)[1]
        vr = phr_ori[vr_indices]

        phr_feats_high = self.get_phr_feats(phr_ori)

        obj_feats_lstm_output, (obj_hidden_states,
                                obj_cell_states) = self.lstm(obj_feats_lstm)

        rm_obj_dists1 = result.rm_obj_dists + self.context.decoder_lin(
            obj_feats_lstm_output.squeeze())
        obj_feats_output = self.obj_mps1(obj_feats_lstm_output.view(-1, obj_feats_lstm_output.size(-1)), \
                            phr_feats_high, im_inds, pair_inds)

        obj_embed_lstm1 = F.softmax(rm_obj_dists1,
                                    dim=1) @ self.embeddings4lstm.weight

        obj_feats_lstm1 = torch.cat((obj_feats_output, obj_embed_lstm1), -1).contiguous().view(1, \
                            obj_feats_output.size(0), -1)
        obj_feats_lstm_output, _ = self.lstm(
            obj_feats_lstm1, (obj_hidden_states, obj_cell_states))

        rm_obj_dists2 = rm_obj_dists1 + self.context.decoder_lin(
            obj_feats_lstm_output.squeeze())
        obj_feats_output = self.obj_mps1(obj_feats_lstm_output.view(-1, obj_feats_lstm_output.size(-1)), \
                            phr_feats_high, im_inds, pair_inds)

        # Prevent gradients from flowing back into score_fc from elsewhere
        result.rm_obj_dists, result.obj_preds = self.context(
            rm_obj_dists2, obj_feats_output, result.rm_obj_labels
            if self.training or self.mode == 'predcls' else None, boxes.data,
            result.boxes_all)

        obj_dtype = result.obj_fmap.data.type()
        obj_preds_embeds = torch.index_select(self.ort_embedding, 0,
                                              result.obj_preds).type(obj_dtype)
        transferred_boxes = torch.stack(
            (boxes[:, 0] / IM_SCALE, boxes[:, 3] / IM_SCALE,
             boxes[:, 2] / IM_SCALE, boxes[:, 1] / IM_SCALE,
             ((boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1])) /
             (IM_SCALE**2)), -1).type(obj_dtype)
        obj_features = torch.cat(
            (result.obj_fmap, obj_preds_embeds, transferred_boxes), -1)
        obj_features_merge = self.merge_obj_low(
            obj_features) + self.merge_obj_high(obj_feats_output)

        # Split into subject and object representations
        result.subj_rep = self.post_emb_s(obj_features_merge)[rel_inds[:, 1]]
        result.obj_rep = self.post_emb_o(obj_features_merge)[rel_inds[:, 2]]
        prod_rep = result.subj_rep * result.obj_rep

        # obj_pools = self.visual_obj(result.fmap.detach(), rois, rel_inds[:, 1:])
        # rel_pools = self.relationship_feat.union_rel_pooling(rel_feat, rois, rel_inds[:, 1:])
        # context_pools = torch.cat([obj_pools, rel_pools], 1)
        # merge_pool = self.merge_feat(context_pools)
        # vr = self.roi_fmap(merge_pool)

        # vr = self.rel_refine(vr)

        prod_rep = prod_rep * vr

        if self.use_tanh:
            prod_rep = F.tanh(prod_rep)

        prod_rep = torch.cat((prod_rep, spt_feats), -1)
        freq_gate = self.freq_gate(prod_rep)
        freq_gate = F.sigmoid(freq_gate)
        result.rel_dists = self.rel_compress(prod_rep)
        # result.rank_factor = self.ranking_module(prod_rep).view(-1)

        if self.use_bias:
            result.rel_dists = result.rel_dists + freq_gate * self.freq_bias.index_with_labels(
                torch.stack((
                    result.obj_preds[rel_inds[:, 1]],
                    result.obj_preds[rel_inds[:, 2]],
                ), 1))

        if self.training:
            return result

        twod_inds = arange(
            result.obj_preds.data) * self.num_classes + result.obj_preds.data
        result.obj_scores = F.softmax(result.rm_obj_dists,
                                      dim=1).view(-1)[twod_inds]

        # Bbox regression
        if self.mode == 'sgdet':
            bboxes = result.boxes_all.view(-1, 4)[twod_inds].view(
                result.boxes_all.size(0), 4)
        else:
            # Boxes will get fixed by filter_dets function.
            bboxes = result.rm_box_priors

        rel_rep = F.softmax(result.rel_dists, dim=1)
        # rel_rep = smooth_one_hot(rel_rep)
        # rank_factor = F.sigmoid(result.rank_factor)

        return filter_dets(bboxes, result.obj_scores, result.obj_preds,
                           rel_inds[:, 1:], rel_rep)
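
The gated frequency bias above adds a per-predicate bias looked up from the (subject class, object class) pair, scaled by a sigmoid gate computed from the pair representation. A rough sketch of that pattern, assuming the bias table behaves like an embedding over class pairs (the real FrequencyBias module may differ):

import torch
import torch.nn as nn

num_classes, num_predicates, hidden, num_rel = 151, 51, 32, 10

# hypothetical bias table: one row of predicate logits per (subject, object) class pair
freq_bias = nn.Embedding(num_classes * num_classes, num_predicates)
freq_gate = nn.Linear(hidden, num_predicates)
rel_compress = nn.Linear(hidden, num_predicates)

prod_rep = torch.randn(num_rel, hidden)                 # fused pair representation
subj_cls = torch.randint(1, num_classes, (num_rel,))
obj_cls = torch.randint(1, num_classes, (num_rel,))

rel_dists = rel_compress(prod_rep)
gate = torch.sigmoid(freq_gate(prod_rep))               # per-predicate gate in [0, 1]
bias = freq_bias(subj_cls * num_classes + obj_cls)      # lookup by class pair
rel_dists = rel_dists + gate * bias                     # gated bias, as in the snippet
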
Example #5
    def forward(self, x, im_sizes, image_offset,
                gt_boxes=None, gt_classes=None, gt_rels=None, proposals=None, train_anchor_inds=None,
                return_fmap=False):
        """
        Forward pass for detection
        :param x: Images@[batch_size, 3, IM_SIZE, IM_SIZE]
        :param im_sizes: A numpy array of (h, w, scale) for each image.
        :param image_offset: Offset onto what image we're on for MGPU training (if single GPU this is 0)

        Training parameters:
        :param gt_boxes: [num_gt, 4] GT boxes over the batch.
        :param gt_classes: [num_gt, 2] gt boxes where each one is (img_id, class)
        :param train_anchor_inds: a [num_train, 2] array of indices for the anchors that will
                                  be used to compute the training loss. Each (img_ind, fpn_idx)
        :return: If train:
            scores, boxdeltas, labels, boxes, boxtargets, rpnscores, rpnboxes, rellabels
            
            if test:
            prob dists, boxes, img inds, maxscores, classes
            
        """

        # Detector
        result = self.detector(x, im_sizes, image_offset, gt_boxes, gt_classes, gt_rels, proposals,
                               train_anchor_inds, return_fmap=True)
        if result.is_none():
            raise ValueError("heck")
        im_inds = result.im_inds - image_offset
        # boxes: [#boxes, 4], without box deltas; .detach() here avoids the narrow() backward error
        boxes = result.rm_box_priors.detach()

        if self.training and result.rel_labels is None:
            assert self.mode == 'sgdet' # sgcls's result.rel_labels is gt and not None
            # rel_labels: [num_rels, 4] (img ind, box0 ind, box1ind, rel type)
            result.rel_labels = rel_assignments(im_inds.data, boxes.data, result.rm_obj_labels.data,
                                                gt_boxes.data, gt_classes.data, gt_rels.data,
                                                image_offset, filter_non_overlap=True,
                                                num_sample_per_gt=1)
            rel_labels_neg = self.get_neg_examples(result.rel_labels)
            rel_inds_neg = rel_labels_neg[:,:3]

        #torch.cat((result.rel_labels[:,0].contiguous().view(236,1),result.rm_obj_labels[result.rel_labels[:,1]].view(236,1),result.rm_obj_labels[result.rel_labels[:,2]].view(236,1),result.rel_labels[:,3].contiguous().view(236,1)),-1)
        rel_inds = self.get_rel_inds(result.rel_labels, im_inds, boxes)  #[275,3], [im_inds, box1_inds, box2_inds]

        # rois: [#boxes, 5]
        rois = torch.cat((im_inds[:, None].float(), boxes), 1)
        # result.rm_obj_fmap: [384, 4096]
        #result.rm_obj_fmap = self.obj_feature_map(result.fmap.detach(), rois) # detach: prevent backforward flowing
        result.rm_obj_fmap = self.obj_feature_map(result.fmap.detach(), rois.detach())  # detach: prevent gradients from flowing backward

        # BiLSTM
        result.rm_obj_dists, result.rm_obj_preds, edge_ctx = self.context(
            result.rm_obj_fmap,   # has been detached above
            # rm_obj_dists: [#boxes, 151]; Prevent gradients from flowing back into score_fc from elsewhere
            result.rm_obj_dists.detach(),  # .detach:Returns a new Variable, detached from the current graph
            im_inds, result.rm_obj_labels if self.training or self.mode == 'predcls' else None,
            boxes.data, result.boxes_all.detach() if self.mode == 'sgdet' else result.boxes_all)
        

        # Post Processing
        # nl_edge <= 0
        if edge_ctx is None:
            edge_rep = self.post_emb(result.rm_obj_preds)
        # nl_edge > 0
        else: 
            edge_rep = self.post_lstm(edge_ctx)  # [384, 4096*2]
     
        # Split into subject and object representations
        edge_rep = edge_rep.view(edge_rep.size(0), 2, self.pooling_dim)  #[384,2,4096]
        subj_rep = edge_rep[:, 0]  # [384,4096]
        obj_rep = edge_rep[:, 1]  # [384,4096]
        prod_rep = subj_rep[rel_inds[:, 1]] * obj_rep[rel_inds[:, 2]]  # prod_rep, rel_inds: [275,4096], [275,3]
    

        if self.use_vision: # True when sgdet
            # union rois: fmap.detach--RoIAlignFunction--roifmap--vr [275,4096]
            vr = self.visual_rep(result.fmap.detach(), rois, rel_inds[:, 1:])

            if self.limit_vision:  # False when sgdet
                # exact value TBD
                prod_rep = torch.cat((prod_rep[:,:2048] * vr[:,:2048], prod_rep[:,2048:]), 1) 
            else:
                prod_rep = prod_rep * vr  # [275,4096]
                if self.training:
                    vr_neg = self.visual_rep(result.fmap.detach(), rois, rel_inds_neg[:, 1:])
                    prod_rep_neg = subj_rep[rel_inds_neg[:, 1]].detach() * obj_rep[rel_inds_neg[:, 2]].detach() * vr_neg 
                    rel_dists_neg = self.rel_compress(prod_rep_neg)
                    

        if self.use_tanh:  # False when sgdet
            prod_rep = F.tanh(prod_rep)

        result.rel_dists = self.rel_compress(prod_rep)  # [275,51]

        if self.use_bias:  # True when sgdet
            result.rel_dists = result.rel_dists + self.freq_bias.index_with_labels(torch.stack((
                result.rm_obj_preds[rel_inds[:, 1]],
                result.rm_obj_preds[rel_inds[:, 2]],
            ), 1))


        if self.training:
            judge = result.rel_labels.data[:,3] != 0
            if judge.sum() != 0:  # gt rels exist in rel_inds
                select_rel_inds = torch.arange(rel_inds.size(0)).view(-1,1).long().cuda()[result.rel_labels.data[:,3] != 0]
                com_rel_inds = rel_inds[select_rel_inds]
                twod_inds = arange(result.rm_obj_preds.data) * self.num_classes + result.rm_obj_preds.data
                result.obj_scores = F.softmax(result.rm_obj_dists.detach(), dim=1).view(-1)[twod_inds]  # only ~1/4 of the 384 obj_dists are updated, because only ~1/4 of the object labels are non-zero

                # positive overall score
                obj_scores0 = result.obj_scores[com_rel_inds[:,1]]
                obj_scores1 = result.obj_scores[com_rel_inds[:,2]]
                rel_rep = F.softmax(result.rel_dists[select_rel_inds], dim=1)    # result.rel_dists has grad
                _, pred_classes_argmax = rel_rep.data[:,:].max(1)  # all classes
                max_rel_score = rel_rep.gather(1, Variable(pred_classes_argmax.view(-1,1))).squeeze()  # SqueezeBackward, GatherBackward
                score_list = torch.cat((com_rel_inds[:,0].float().contiguous().view(-1,1), obj_scores0.data.view(-1,1), obj_scores1.data.view(-1,1), max_rel_score.data.view(-1,1)), 1)
                prob_score = max_rel_score * obj_scores0.detach() * obj_scores1.detach()
                #pos_prob[:,1][result.rel_labels.data[:,3] == 0] = 0  # treat most rel_labels as neg because their rel cls is 0 "unknown"  
                
                # negative overall score
                obj_scores0_neg = result.obj_scores[rel_inds_neg[:,1]]
                obj_scores1_neg = result.obj_scores[rel_inds_neg[:,2]]
                rel_rep_neg = F.softmax(rel_dists_neg, dim=1)   # rel_dists_neg has grad
                _, pred_classes_argmax_neg = rel_rep_neg.data[:,:].max(1)  # all classes
                max_rel_score_neg = rel_rep_neg.gather(1, Variable(pred_classes_argmax_neg.view(-1,1))).squeeze() # SqueezeBackward, GatherBackward
                score_list_neg = torch.cat((rel_inds_neg[:,0].float().contiguous().view(-1,1), obj_scores0_neg.data.view(-1,1), obj_scores1_neg.data.view(-1,1), max_rel_score_neg.data.view(-1,1)), 1)
                prob_score_neg = max_rel_score_neg * obj_scores0_neg.detach() * obj_scores1_neg.detach()

                # use all rel_inds; they are already independent of im_inds, which is only used to extract regions from the image and produce rel_inds
                # 384 boxes---(rel_inds)(rel_inds_neg)--->prob_score,prob_score_neg 
                all_rel_inds = torch.cat((result.rel_labels.data[select_rel_inds], rel_labels_neg), 0)  # [#pos_inds+#neg_inds, 4]
                flag = torch.cat((torch.ones(prob_score.size(0),1).cuda(),torch.zeros(prob_score_neg.size(0),1).cuda()),0)
                score_list_all = torch.cat((score_list,score_list_neg), 0) 
                all_prob = torch.cat((prob_score,prob_score_neg), 0)  # Variable, [#pos_inds+#neg_inds, 1]

                _, sort_prob_inds = torch.sort(all_prob.data, dim=0, descending=True)

                sorted_rel_inds = all_rel_inds[sort_prob_inds]
                sorted_flag = flag[sort_prob_inds].squeeze()  # can be used to check distribution of pos and neg
                sorted_score_list_all = score_list_all[sort_prob_inds]
                sorted_all_prob = all_prob[sort_prob_inds]  # Variable
                
                # positive triplet and score list
                pos_sorted_inds = sorted_rel_inds.masked_select(sorted_flag.view(-1,1).expand(-1,4).cuda() == 1).view(-1,4)
                pos_trips = torch.cat((pos_sorted_inds[:,0].contiguous().view(-1,1), result.rm_obj_labels.data.view(-1,1)[pos_sorted_inds[:,1]], result.rm_obj_labels.data.view(-1,1)[pos_sorted_inds[:,2]], pos_sorted_inds[:,3].contiguous().view(-1,1)), 1)
                pos_score_list = sorted_score_list_all.masked_select(sorted_flag.view(-1,1).expand(-1,4).cuda() == 1).view(-1,4)
                pos_exp = sorted_all_prob[sorted_flag == 1]  # Variable 

                # negative triplet and score list
                neg_sorted_inds = sorted_rel_inds.masked_select(sorted_flag.view(-1,1).expand(-1,4).cuda() == 0).view(-1,4)
                neg_trips = torch.cat((neg_sorted_inds[:,0].contiguous().view(-1,1), result.rm_obj_labels.data.view(-1,1)[neg_sorted_inds[:,1]], result.rm_obj_labels.data.view(-1,1)[neg_sorted_inds[:,2]], neg_sorted_inds[:,3].contiguous().view(-1,1)), 1)
                neg_score_list = sorted_score_list_all.masked_select(sorted_flag.view(-1,1).expand(-1,4).cuda() == 0).view(-1,4)
                neg_exp = sorted_all_prob[sorted_flag == 0]  # Variable
                
                
                int_part = neg_exp.size(0) // pos_exp.size(0)
                decimal_part = neg_exp.size(0) % pos_exp.size(0)
                int_inds = torch.arange(pos_exp.size(0))[:,None].expand_as(torch.Tensor(pos_exp.size(0), int_part)).contiguous().view(-1)
                int_part_inds = (int(pos_exp.size(0) - 1) - int_inds).long().cuda()  # pair the lowest-scoring positives with the highest-scoring negatives
                if decimal_part == 0:
                    expand_inds = int_part_inds
                else:
                    expand_inds = torch.cat((torch.arange(pos_exp.size(0))[(pos_exp.size(0) - decimal_part):].long().cuda(), int_part_inds), 0)  
                
                result.pos = pos_exp[expand_inds]
                result.neg = neg_exp
                result.anchor = Variable(torch.zeros(result.pos.size(0)).cuda())
                # some variables .register_hook(extract_grad)

                return result

            else:  # no gt_rel in rel_inds
                print("no gt_rel in rel_inds!!!!!!!!!!!!!!!!!!!!!!!!!!!!")
                twod_inds = arange(result.rm_obj_preds.data) * self.num_classes + result.rm_obj_preds.data
                result.obj_scores = F.softmax(result.rm_obj_dists.detach(), dim=1).view(-1)[twod_inds]

                # positive overall score
                obj_scores0 = result.obj_scores[rel_inds[:,1]]
                obj_scores1 = result.obj_scores[rel_inds[:,2]]
                rel_rep = F.softmax(result.rel_dists, dim=1)    # [275, 51]
                _, pred_classes_argmax = rel_rep.data[:,:].max(1)  # all classes
                max_rel_score = rel_rep.gather(1, Variable(pred_classes_argmax.view(-1,1))).squeeze() # SqueezeBackward, GatherBackward
                prob_score = max_rel_score * obj_scores0.detach() * obj_scores1.detach()
                #pos_prob[:,1][result.rel_labels.data[:,3] == 0] = 0  # treat most rel_labels as neg because their rel cls is 0 "unknown"  
                
                # negative overall score
                obj_scores0_neg = result.obj_scores[rel_inds_neg[:,1]]
                obj_scores1_neg = result.obj_scores[rel_inds_neg[:,2]]
                rel_rep_neg = F.softmax(rel_dists_neg, dim=1)   
                _, pred_classes_argmax_neg = rel_rep_neg.data[:,:].max(1)  # all classes
                max_rel_score_neg = rel_rep_neg.gather(1, Variable(pred_classes_argmax_neg.view(-1,1))).squeeze() # SqueezeBackward, GatherBackward
                prob_score_neg = max_rel_score_neg * obj_scores0_neg.detach() * obj_scores1_neg.detach()

                # use all rel_inds; they are already independent of im_inds, which is only used to extract regions from the image and produce rel_inds
                # 384 boxes---(rel_inds)(rel_inds_neg)--->prob_score,prob_score_neg 
                all_rel_inds = torch.cat((result.rel_labels.data, rel_labels_neg), 0)  # [#pos_inds+#neg_inds, 4]
                flag = torch.cat((torch.ones(prob_score.size(0),1).cuda(),torch.zeros(prob_score_neg.size(0),1).cuda()),0)
                all_prob = torch.cat((prob_score,prob_score_neg), 0)  # Variable, [#pos_inds+#neg_inds, 1]

                _, sort_prob_inds = torch.sort(all_prob.data, dim=0, descending=True)

                sorted_rel_inds = all_rel_inds[sort_prob_inds]
                sorted_flag = flag[sort_prob_inds].squeeze()  # can be used to check distribution of pos and neg
                sorted_all_prob = all_prob[sort_prob_inds]  # Variable

                pos_sorted_inds = sorted_rel_inds.masked_select(sorted_flag.view(-1,1).expand(-1,4).cuda() == 1).view(-1,4)
                neg_sorted_inds = sorted_rel_inds.masked_select(sorted_flag.view(-1,1).expand(-1,4).cuda() == 0).view(-1,4)
                pos_exp = sorted_all_prob[sorted_flag == 1]  # Variable  
                neg_exp = sorted_all_prob[sorted_flag == 0]  # Variable

                int_part = neg_exp.size(0) // pos_exp.size(0)
                decimal_part = neg_exp.size(0) % pos_exp.size(0)
                int_inds = torch.arange(pos_exp.data.size(0))[:,None].expand_as(torch.Tensor(pos_exp.data.size(0), int_part)).contiguous().view(-1)
                int_part_inds = (int(pos_exp.data.size(0) - 1) - int_inds).long().cuda()  # pair the lowest-scoring positives with the highest-scoring negatives
                if decimal_part == 0:
                    expand_inds = int_part_inds
                else:
                    expand_inds = torch.cat((torch.arange(pos_exp.size(0))[(pos_exp.size(0) - decimal_part):].long().cuda(), int_part_inds), 0)  
                
                result.pos = pos_exp[expand_inds]
                result.neg = neg_exp
                result.anchor = Variable(torch.zeros(result.pos.size(0)).cuda())

                return result
        ###################### Testing ###########################

        # extract the corresponding scores according to each box's predicted class
        twod_inds = arange(result.rm_obj_preds.data) * self.num_classes + result.rm_obj_preds.data
        result.obj_scores = F.softmax(result.rm_obj_dists, dim=1).view(-1)[twod_inds]   # [384]

        # Bbox regression
        if self.mode == 'sgdet':
            bboxes = result.boxes_all.view(-1, 4)[twod_inds].view(result.boxes_all.size(0), 4)
        else:
            # Boxes will get fixed by filter_dets function.
            bboxes = result.rm_box_priors

        rel_rep = F.softmax(result.rel_dists, dim=1)    # [275, 51]
        
        # sort product of obj1 * obj2 * rel
        return filter_dets(bboxes, result.obj_scores,
                           result.rm_obj_preds, rel_inds[:, 1:],
                           rel_rep)
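
Before the ranking loss, the training branch above repeats the positive scores so that the positive and negative lists have the same length, pairing the lowest-scoring positives with the highest-scoring negatives. A small sketch of that index expansion (using repeat_interleave in place of the expand_as construction in the snippet):

import torch

pos = torch.rand(3)   # positive pair scores
neg = torch.rand(7)   # negative pair scores

int_part = neg.size(0) // pos.size(0)       # 2
decimal_part = neg.size(0) % pos.size(0)    # 1

# repeat each positive index int_part times, reversed so the lowest-scoring
# positives line up with the highest-scoring negatives
int_inds = torch.arange(pos.size(0)).repeat_interleave(int_part)
int_part_inds = (pos.size(0) - 1) - int_inds
if decimal_part == 0:
    expand_inds = int_part_inds
else:
    expand_inds = torch.cat(
        (torch.arange(pos.size(0))[pos.size(0) - decimal_part:], int_part_inds), 0)

pos_expanded = pos[expand_inds]             # same length as neg
assert pos_expanded.size(0) == neg.size(0)
# a margin/ranking loss can now compare pos_expanded against neg element-wise
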
Example #6
    def forward(self, x, im_sizes, image_offset,
                gt_boxes=None, gt_classes=None, gt_rels=None, proposals=None, train_anchor_inds=None,
                return_fmap=False):

        result = self.detector(x, im_sizes, image_offset, gt_boxes, gt_classes, gt_rels, proposals,
                               train_anchor_inds, return_fmap=True)

        if result.is_none():
            raise ValueError("heck")

        im_inds = result.im_inds - image_offset
        boxes = result.rm_box_priors
        # boxes = result.boxes_assigned
        boxes_deltas = result.rm_box_deltas  # None in sgcls mode
        boxes_all = result.boxes_all  # None in sgcls mode

        if (self.training) and (result.rel_labels is None):
            assert self.mode == 'sgdet'
            result.rel_labels = rel_assignments(im_inds.data, boxes.data, result.rm_obj_labels.data, result.rm_obj_dists.data,
                                                gt_boxes.data, gt_classes.data, gt_rels.data,
                                                image_offset, filter_non_overlap=True,
                                                num_sample_per_gt=1)

        rel_inds = self.get_rel_inds(result.rel_labels, im_inds, boxes, result.rm_obj_dists.data)

        reward_rel_inds = None
        if self.mode == 'sgdet':
            msg_rel_inds = self.get_msg_rel_inds(im_inds, boxes, result.rm_obj_dists.data)
            reward_rel_inds = self.get_reward_rel_inds(im_inds, boxes, result.rm_obj_dists.data)


        if self.mode == 'sgdet':
            result.rm_obj_dists_list, result.obj_preds_list, result.rel_dists_list, result.bbox_list, result.offset_list, \
                result.rel_dists, result.obj_preds, result.boxes_all, result.all_rel_logits = self.context(
                                            result.fmap.detach(), result.rm_obj_dists.detach(), im_inds, rel_inds, msg_rel_inds, 
                                            reward_rel_inds, im_sizes, boxes.detach(), boxes_deltas.detach(), boxes_all.detach(),
                                            result.rm_obj_labels if self.training or self.mode == 'predcls' else None)

        elif self.mode in ['sgcls', 'predcls']:
            result.obj_preds, result.rm_obj_logits, result.rel_logits = self.context(
                                            result.fmap.detach(), result.rm_obj_dists.detach(),
                                            im_inds, rel_inds, None, None, im_sizes, boxes.detach(), None, None,
                                            result.rm_obj_labels if self.training or self.mode == 'predcls' else None)
        else:
            raise NotImplementedError

        # result.rm_obj_dists = result.rm_obj_dists_list[-1]

        if self.training:
            return result

        if self.mode == 'predcls':
            result.obj_preds = result.rm_obj_labels
            result.obj_scores = Variable(torch.from_numpy(np.ones(result.obj_preds.shape[0],)).float().cuda())
        else:
            twod_inds = arange(result.obj_preds.data) * self.num_classes + result.obj_preds.data
            result.obj_scores = F.softmax(result.rm_obj_logits, dim=1).view(-1)[twod_inds]

        # Bbox regression
        if self.mode == 'sgdet':
            if conf.use_postprocess:
                bboxes = result.boxes_all.view(-1, 4)[twod_inds].view(result.boxes_all.size(0), 4)
            else:
                bboxes = result.rm_box_priors
        else:
            bboxes = result.rm_box_priors

        rel_scores = F.sigmoid(result.rel_logits)

        return filter_dets(bboxes, result.obj_scores,
                           result.obj_preds, rel_inds[:, 1:], rel_scores)
Example #7
def filter_dets(boxes, obj_scores, obj_classes, rel_inds, pred_scores, obj1,
                obj2, rel_emb):
    """
    Filters detections according to the score product: obj1 * obj2 * rel
    :param boxes: [num_box, topk, 4] if bbox regression else [num_box, 4]
    :param obj_scores: [num_box] probabilities for the scores
    :param obj_classes: [num_box] class labels for the topk
    :param rel_inds: [num_rel, 2] TENSOR consisting of (im_ind0, im_ind1)
    :param pred_scores: [topk, topk, num_rel, num_predicates]
    :param obj1: [num_rel, num_classes, emb_dim] per-class subject embeddings
    :param obj2: [num_rel, num_classes, emb_dim] per-class object embeddings
    :param rel_emb: [num_rel, num_predicates, emb_dim] predicate embeddings
    :return: boxes_out, objs_np, obj_scores_np, rels, pred_scores_sorted, sorted_rel_pred

    """
    if boxes.dim() != 2:
        raise ValueError("Boxes needs to be [num_box, 4] but its {}".format(
            boxes.size()))

    num_box = boxes.size(0)  # 64
    assert obj_scores.size(0) == num_box  # 64

    assert obj_classes.size() == obj_scores.size()  # 64
    num_rel = rel_inds.size(0)  # 275
    assert rel_inds.size(1) == 2
    #assert pred_scores.size(0) == num_rel

    #obj_scores0 = obj_scores.data[rel_inds[:,0]]
    #obj_scores1 = obj_scores.data[rel_inds[:,1]]

    #pred_scores_max, pred_classes_argmax = pred_scores.data[:,1:].max(1)
    #pred_classes_argmax = pred_classes_argmax + 1
    # get the maximum score among the 150/50 classes for each object and relation separately, then take the product and sort

    num_trip = rel_inds.size(0)
    num_classes = obj1.size(1)
    num_rels = rel_emb.size(1)
    embdim = rel_emb.size(2)

    two_inds1 = arange(obj_classes.data[
        rel_inds[:, 0]]) * num_classes + obj_classes.data[rel_inds[:, 0]]
    two_inds2 = arange(obj_classes.data[
        rel_inds[:, 1]]) * num_classes + obj_classes.data[rel_inds[:, 1]]
    obj1emb = obj1.view(-1, embdim)[two_inds1]  # (275,151,10) -> (275, 10)
    obj2emb = obj2.view(-1, embdim)[two_inds2]  # (275,151,10) -> (275, 10)

    d = obj1emb - obj2emb  # (275, 10)
    d = d[:, None, :].expand_as(rel_emb)  # (275, 51, 10), copy 51 times
    d = d + rel_emb  # (275, 51, 10)
    d = d.view(-1, embdim)  # (275*51, 10)
    d = torch.squeeze(torch.sqrt(d.pow(2).sum(1)))  # (275*51, 1) -> (275*51)

    rel_surgery = []

    for i in range(num_trip):
        start = num_rels * i
        end = num_rels * (i + 1)
        min_d, min_ind = d[start + 1:end].min(0)  # ignore "unknown" ind = 0
        rel_surgery.append([min_ind.data + 1, min_d.data
                            ])  # restore the relationship class value

    rel_surgery = np.array(rel_surgery)
    rel_pred = torch.from_numpy(rel_surgery[:,
                                            0])  # double tensor, float element
    distance = torch.from_numpy(rel_surgery[:,
                                            1])  # double tensor, float element

    #rel_scores_argmaxed = pred_scores_max * obj_scores0 * obj_scores1
    #rel_scores_vs, rel_scores_idx = torch.sort(rel_scores_argmaxed.view(-1), dim=0, descending=True)
    # rel_scores_vs: sorted distance, double tensor; rel_scores_idx: long tensor
    rel_scores_vs, rel_scores_idx = torch.sort(distance.contiguous().view(-1),
                                               dim=0,
                                               descending=False)

    # boxes_out: rois incorporated deltas
    # objs_np: rm_obj_preds from decoder rnn
    # obj_scores_np: rm_obj_dists from decoder rnn
    # rels: rel_inds from boxes overlapped; after surgery, sorted by overall_score / distance
    # pred_scores_sorted: extracted by max() among 50 cls, then rel scores sorted by overall scores
    rels = rel_inds[
        rel_scores_idx.cuda()].cpu().numpy()  # rel_inds is "cuda" long tensor
    sorted_rel_pred = rel_pred[rel_scores_idx].cpu().numpy().astype(
        int
    )  # if it's float, when column stack it with rels, the entity will be float
    #pred_scores_sorted = pred_scores[rel_scores_idx].data.cpu().numpy()
    pred_scores_sorted = None
    obj_scores_np = obj_scores.data.cpu().numpy()
    objs_np = obj_classes.data.cpu().numpy()
    boxes_out = boxes.data.cpu().numpy()

    return boxes_out, objs_np, obj_scores_np, rels, pred_scores_sorted, sorted_rel_pred
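
This variant of filter_dets classifies each pair by a translation-style embedding distance, ||(subj_emb - obj_emb) + rel_emb||, taking the closest non-background predicate per pair. A minimal sketch of that distance computation with dummy embeddings:

import torch

num_pairs, num_predicates, emb_dim = 6, 51, 10
subj_emb = torch.randn(num_pairs, emb_dim)                 # embedding of each pair's subject class
obj_emb = torch.randn(num_pairs, emb_dim)                  # embedding of each pair's object class
rel_emb = torch.randn(num_pairs, num_predicates, emb_dim)  # one embedding per candidate predicate

d = (subj_emb - obj_emb)[:, None, :] + rel_emb             # [num_pairs, num_predicates, emb_dim]
dist = d.pow(2).sum(-1).sqrt()                             # [num_pairs, num_predicates]

# skip predicate 0 ("unknown"), pick the closest predicate, and restore the +1 offset
min_d, min_ind = dist[:, 1:].min(1)
rel_pred = min_ind + 1
order = torch.argsort(min_d)                               # rank pairs by distance, ascending
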
Example #8
    def forward(self,
                x,
                im_sizes,
                image_offset,
                gt_boxes=None,
                gt_classes=None,
                gt_rels=None,
                proposals=None,
                train_anchor_inds=None,
                return_fmap=False,
                depth_imgs=None):
        """
        Forward pass for relation detection
        :param x: Images@[batch_size, 3, IM_SIZE, IM_SIZE]
        :param im_sizes: a numpy array of (h, w, scale) for each image.
        :param image_offset: Offset onto what image we're on for MGPU training (if single GPU this is 0)
        :param gt_boxes: [num_gt, 4] GT boxes over the batch.
        :param gt_classes: [num_gt, 2] gt boxes where each one is (img_id, class)
        :param gt_rels: [] gt relations
        :param proposals: region proposals retrieved from file
        :param train_anchor_inds: a [num_train, 2] array of indices for the anchors that will
                                  be used to compute the training loss. Each (img_ind, fpn_idx)
        :param return_fmap: if the object detector must return the extracted feature maps
        :param depth_imgs: depth images [batch_size, 1, IM_SIZE, IM_SIZE]
        """

        # -- Get prior `result` object (instead of calling faster-rcnn-detector)
        result = self.get_prior_results(image_offset, gt_boxes, gt_classes,
                                        gt_rels)

        # -- Get RoI and relations
        rois, rel_inds = self.get_rois_and_rels(result, image_offset, gt_boxes,
                                                gt_classes, gt_rels)

        # -- Determine subject and object indices
        subj_inds = rel_inds[:, 1]
        obj_inds = rel_inds[:, 2]

        # -- Extract features from depth backbone
        depth_features = self.depth_backbone(depth_imgs)

        # -- Prevent the gradients from flowing back to depth backbone (Pre-trained mode)
        if self.pretrained_depth:
            depth_features = depth_features.detach()

        # -- Extract RoI features for relation detection
        depth_rois_features = self.get_roi_features_depth(depth_features, rois)

        # -- Create a pairwise relation vector out of location features
        rel_depth = torch.cat(
            (depth_rois_features[subj_inds], depth_rois_features[obj_inds]), 1)
        rel_depth_fc = self.depth_rel_hlayer(rel_depth)

        # -- Predict relation distances
        result.rel_dists = self.depth_rel_out(rel_depth_fc)

        # --- *** END OF ARCHITECTURE *** ---#

        # -- Prepare object predictions vector (PredCLS)
        # Assuming it's predcls
        obj_labels = result.rm_obj_labels if self.training or self.mode == 'predcls' else None
        # One hot vector of objects
        result.rm_obj_dists = Variable(
            to_onehot(obj_labels.data, self.num_classes))
        # Indexed vector
        result.obj_preds = obj_labels if obj_labels is not None else result.rm_obj_dists[:, 1:].max(
            1)[1] + 1

        if self.training:
            return result

        twod_inds = arange(
            result.obj_preds.data) * self.num_classes + result.obj_preds.data
        result.obj_scores = F.softmax(result.rm_obj_dists,
                                      dim=1).view(-1)[twod_inds]

        # Boxes will get fixed by filter_dets function.
        bboxes = result.rm_box_priors

        rel_rep = F.softmax(result.rel_dists, dim=1)
        # Filtering: Subject_Score * Pred_score * Obj_score, sorted and ranked
        return filter_dets(bboxes, result.obj_scores, result.obj_preds,
                           rel_inds[:, 1:], rel_rep)
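
In the PredCLS branch above, the object distribution is a one-hot encoding of the ground-truth labels and the predictions copy those labels (or fall back to an argmax over the foreground columns, offset by 1 to skip background). A tiny sketch, with F.one_hot standing in for the to_onehot helper used in the snippet:

import torch
import torch.nn.functional as F

num_classes = 151
obj_labels = torch.tensor([3, 17, 42])                     # ground-truth classes (PredCLS)

# one-hot "distribution" per object; stands in for the to_onehot helper used above
rm_obj_dists = F.one_hot(obj_labels, num_classes).float()

# without labels, predictions would come from the argmax over foreground columns
# (1:), adding 1 to undo the offset introduced by dropping the background column
obj_preds = rm_obj_dists[:, 1:].max(1)[1] + 1
assert torch.equal(obj_preds, obj_labels)                  # holds here because the dists are one-hot
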
Example #9
    def forward(self,
                x,
                im_sizes,
                image_offset,
                gt_boxes=None,
                gt_classes=None,
                gt_rels=None,
                proposals=None,
                train_anchor_inds=None,
                return_fmap=False):
        """
        Forward pass for detection
        :param x: Images@[batch_size, 3, IM_SIZE, IM_SIZE]
        :param im_sizes: A numpy array of (h, w, scale) for each image.
        :param image_offset: Offset onto what image we're on for MGPU training (if single GPU this is 0)

        Training parameters:
        :param gt_boxes: [num_gt, 4] GT boxes over the batch.
        :param gt_classes: [num_gt, 2] gt boxes where each one is (img_id, class)
        :param train_anchor_inds: a [num_train, 2] array of indices for the anchors that will
                                  be used to compute the training loss. Each (img_ind, fpn_idx)
        :return: If train:
            scores, boxdeltas, labels, boxes, boxtargets, rpnscores, rpnboxes, rellabels
            
            if test:
            prob dists, boxes, img inds, maxscores, classes
            
        """

        # Detector
        result = self.detector(x,
                               im_sizes,
                               image_offset,
                               gt_boxes,
                               gt_classes,
                               gt_rels,
                               proposals,
                               train_anchor_inds,
                               return_fmap=True)
        if result.is_none():
            raise ValueError("heck")

        #rcnn_pred = result.rm_obj_dists[:, 1:].max(1)[1] + 1  # +1: because the index is in 150-d but truth is 151-d
        #rcnn_ap = torch.mean((rcnn_pred == result.rm_obj_labels).float().cpu())

        im_inds = result.im_inds - image_offset
        # boxes: [#boxes, 4], without box deltas; .detach() here avoids the narrow() backward error
        boxes = result.rm_box_priors.detach()

        if self.training and result.rel_labels is None:
            assert self.mode == 'sgdet'  # sgcls's result.rel_labels is gt and not None
            result.rel_labels = rel_assignments(im_inds.data,
                                                boxes.data,
                                                result.rm_obj_labels.data,
                                                gt_boxes.data,
                                                gt_classes.data,
                                                gt_rels.data,
                                                image_offset,
                                                filter_non_overlap=True,
                                                num_sample_per_gt=1)

        rel_inds = self.get_rel_inds(
            result.rel_labels, im_inds,
            boxes)  #[275,3], [im_inds, box1_inds, box2_inds]

        # rois: [#boxes, 5]
        rois = torch.cat((im_inds[:, None].float(), boxes), 1)
        # result.rm_obj_fmap: [384, 4096]
        #result.rm_obj_fmap = self.obj_feature_map(result.fmap.detach(), rois) # detach: prevent backforward flowing
        result.rm_obj_fmap = self.obj_feature_map(
            result.fmap.detach(),
            rois.detach())  # detach: prevent gradients from flowing backward

        # BiLSTM
        result.rm_obj_dists, result.rm_obj_preds, edge_ctx = self.context(
            result.rm_obj_fmap,  # has been detached above
            # rm_obj_dists: [#boxes, 151]; Prevent gradients from flowing back into score_fc from elsewhere
            result.rm_obj_dists.detach(
            ),  # .detach:Returns a new Variable, detached from the current graph
            im_inds,
            result.rm_obj_labels
            if self.training or self.mode == 'predcls' else None,
            boxes.data,
            result.boxes_all.detach()
            if self.mode == 'sgdet' else result.boxes_all)

        #lstm_ap = torch.mean((result.rm_obj_preds == result.rm_obj_labels).float().cpu())
        #fg_ratio = result.rm_obj_labels.nonzero().size(0) / result.rm_obj_labels.size(0)
        #lst = [rcnn_ap.data.numpy(), lstm_ap.data.numpy(), fg_ratio]
        #a = torch.stack((rel_inds[:64, 1], rel_inds[:64, 2]), 0)
        #b = np.unique(a.cpu().numpy())
        #print(len(b) / 64)

        # Post Processing
        # nl_edge <= 0
        if edge_ctx is None:
            edge_rep = self.post_emb(result.rm_obj_preds)
        # nl_edge > 0
        else:
            edge_rep = self.post_lstm(edge_ctx)  # [384, 4096*2]

        # Split into subject and object representations
        edge_rep = edge_rep.view(edge_rep.size(0), 2,
                                 self.pooling_dim)  #[384,2,4096]

        subj_rep = edge_rep[:, 0]  # [384,4096]
        obj_rep = edge_rep[:, 1]  # [384,4096]

        # prod_rep, rel_inds: [275,4096], [275,3]
        prod_rep = subj_rep[rel_inds[:, 1]] * obj_rep[rel_inds[:, 2]]

        if self.use_vision:  # True when sgdet
            # union rois: fmap.detach--RoIAlignFunction--roifmap--vr [275,4096]
            vr = self.visual_rep(result.fmap.detach(), rois, rel_inds[:, 1:])

            if self.limit_vision:  # False when sgdet
                # exact value TBD
                prod_rep = torch.cat(
                    (prod_rep[:, :2048] * vr[:, :2048], prod_rep[:, 2048:]), 1)
            else:
                prod_rep = prod_rep * vr  # [275,4096]

        if self.use_tanh:  # False when sgdet
            prod_rep = F.tanh(prod_rep)

        result.rel_dists = self.rel_compress(prod_rep)  # [275,51]

        if self.use_bias:  # True when sgdet
            result.rel_dists = result.rel_dists + self.freq_bias.index_with_labels(
                torch.stack((
                    result.rm_obj_preds[rel_inds[:, 1]],
                    result.rm_obj_preds[rel_inds[:, 2]],
                ), 1))

        if self.training:
            return result

        ###################### Testing ###########################

        # extract the corresponding scores according to each box's predicted class
        twod_inds = arange(result.rm_obj_preds.data
                           ) * self.num_classes + result.rm_obj_preds.data
        result.obj_scores = F.softmax(result.rm_obj_dists,
                                      dim=1).view(-1)[twod_inds]  # [384]

        # Bbox regression
        if self.mode == 'sgdet':
            bboxes = result.boxes_all.view(-1, 4)[twod_inds].view(
                result.boxes_all.size(0), 4)
        else:
            # Boxes will get fixed by filter_dets function.
            bboxes = result.rm_box_priors

        rel_rep = F.softmax(result.rel_dists, dim=1)  # [275, 51]

        # sort product of obj1 * obj2 * rel
        return filter_dets(bboxes, result.obj_scores, result.rm_obj_preds,
                           rel_inds[:, 1:], rel_rep)
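
When limit_vision is enabled, only the first 2048 channels of the pair representation are modulated by the union-box visual feature and the remaining channels pass through unchanged. A minimal sketch of that split fusion:

import torch

pooling_dim, num_rel = 4096, 275
prod_rep = torch.randn(num_rel, pooling_dim)   # subject * object representation per relation
vr = torch.randn(num_rel, pooling_dim)         # union-box visual representation per relation

limit_vision = True
if limit_vision:
    # only the first 2048 channels are modulated by the visual feature
    prod_rep = torch.cat((prod_rep[:, :2048] * vr[:, :2048], prod_rep[:, 2048:]), 1)
else:
    prod_rep = prod_rep * vr                   # full element-wise fusion

assert prod_rep.shape == (num_rel, pooling_dim)
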
Example #10
    def forward(self,
                x,
                im_sizes,
                image_offset,
                gt_boxes=None,
                gt_classes=None,
                gt_rels=None,
                proposals=None,
                train_anchor_inds=None,
                return_fmap=False):
        """Forward pass for detection
        Args:
            x: Images@[batch_size, 3, IM_SIZE, IM_SIZE]
            im_sizes: A numpy array of (h, w, scale) for each image.
            image_offset: Offset onto what image we're on for MGPU training (if single GPU this is 0)

            Training parameters:
            gt_boxes: [num_gt, 4] GT boxes over the batch.
            gt_classes: [num_gt, 2] gt boxes where each one is (img_id, class)
            gt_rels:
            proposals:
            train_anchor_inds: a [num_train, 2] array of indices for the anchors that will
                                  be used to compute the training loss. Each (img_ind, fpn_idx)
            return_fmap:

        Returns:
            If train:
                scores, boxdeltas, labels, boxes, boxtargets, rpnscores, rpnboxes, rellabels
            If test:
                prob dists, boxes, img inds, maxscores, classes
            
        """
        result = self.detector(x,
                               im_sizes,
                               image_offset,
                               gt_boxes,
                               gt_classes,
                               gt_rels,
                               proposals,
                               train_anchor_inds,
                               return_fmap=True)
        """
        Results attributes:
            od_obj_dists: logits after score_fc in the RCNN
            rm_obj_dists: od_obj_dists after nms
            obj_scores: nmn 
            obj_preds=None, 
            obj_fmap=None,
            od_box_deltas=None, 
            rm_box_deltas=None,
            od_box_targets=None, 
            rm_box_targets=None, 
            od_box_priors: proposal before nms
            rm_box_priors: proposal after nms
            boxes_assigned=None, 
            boxes_all=None, 
            od_obj_labels=None, 
            rm_obj_labels=None,
            rpn_scores=None, 
            rpn_box_deltas=None, 
            rel_labels=None,
            im_inds: image index of every proposals
            fmap=None, 
            rel_dists=None, 
            rel_inds=None, 
            rel_rep=None
            
            one example (sgcls task):
            result.fmap: torch.Size([6, 512, 37, 37])
            result.im_inds: torch.Size([44])
            result.obj_fmap: torch.Size([44, 4096])
            result.od_box_priors: torch.Size([44, 4])
            result.od_obj_dists: torch.Size([44, 151])
            result.od_obj_labels: torch.Size([44])
            result.rel_labels: torch.Size([316, 4])
            result.rm_box_priors: torch.Size([44, 4])
            result.rm_obj_dists: torch.Size([44, 151])
            result.rm_obj_labels: torch.Size([44])
        """
        if result.is_none():
            raise ValueError("heck")

        # image_offset refer to Blob
        # self.batch_size_per_gpu * index
        im_inds = result.im_inds - image_offset
        boxes = result.rm_box_priors

        #embed(header='rel_model.py before rel_assignments')
        if self.training and result.rel_labels is None:
            assert self.mode == 'sgdet'

            # only in sgdet mode

            # shapes:
            # im_inds: (box_num,)
            # boxes: (box_num, 4)
            # rm_obj_labels: (box_num,)
            # gt_boxes: (box_num, 4)
            # gt_classes: (box_num, 2) maybe[im_ind, class_ind]
            # gt_rels: (rel_num, 4)
            # image_offset: integer
            result.rel_labels = rel_assignments(im_inds.data,
                                                boxes.data,
                                                result.rm_obj_labels.data,
                                                gt_boxes.data,
                                                gt_classes.data,
                                                gt_rels.data,
                                                image_offset,
                                                filter_non_overlap=True,
                                                num_sample_per_gt=1)
        #embed(header='rel_model.py after rel_assignments')

        # rel_labels[:, :3] if sgcls
        rel_inds = self.get_rel_inds(result.rel_labels, im_inds, boxes)

        rois = torch.cat((im_inds[:, None].float(), boxes), 1)

        # obj_fmap: (NumOfRoI, 4096)
        # RoIAlign
        result.obj_fmap = self.obj_feature_map(result.fmap.detach(), rois)

        # Prevent gradients from flowing back into score_fc from elsewhere
        result.rm_obj_dists, result.obj_preds, edge_ctx = self.context(
            result.obj_fmap, result.rm_obj_dists.detach(), im_inds,
            result.rm_obj_labels if self.training or self.mode == 'predcls'
            else None, boxes.data, result.boxes_all)

        if edge_ctx is None:
            edge_rep = self.post_emb(result.obj_preds)
        else:
            edge_rep = self.post_lstm(edge_ctx)

        # Split into subject and object representations
        edge_rep = edge_rep.view(edge_rep.size(0), 2, self.pooling_dim)

        subj_rep = edge_rep[:, 0]
        obj_rep = edge_rep[:, 1]

        prod_rep = subj_rep[rel_inds[:, 1]] * obj_rep[rel_inds[:, 2]]
        # embed(header='rel_model.py prod_rep')

        if self.use_vision:
            vr = self.visual_rep(result.fmap.detach(), rois, rel_inds[:, 1:])
            if self.limit_vision:
                # exact value TBD
                prod_rep = torch.cat(
                    (prod_rep[:, :2048] * vr[:, :2048], prod_rep[:, 2048:]), 1)
            else:
                prod_rep = prod_rep * vr

        if self.use_tanh:
            prod_rep = F.tanh(prod_rep)

        result.rel_dists = self.rel_compress(prod_rep)

        if self.use_bias:
            result.rel_dists = result.rel_dists + self.freq_bias.index_with_labels(
                torch.stack((
                    result.obj_preds[rel_inds[:, 1]],
                    result.obj_preds[rel_inds[:, 2]],
                ), 1))

        #embed(header='rel model return ')
        if self.training:
            # embed(header='rel_model.py before return')
            # what will be useful:
            # rm_obj_dists, rm_obj_labels
            # rel_labels, rel_dists
            return result

        twod_inds = arange(
            result.obj_preds.data) * self.num_classes + result.obj_preds.data
        result.obj_scores = F.softmax(result.rm_obj_dists,
                                      dim=1).view(-1)[twod_inds]

        # Bbox regression
        if self.mode == 'sgdet':
            bboxes = result.boxes_all.view(-1, 4)[twod_inds].view(
                result.boxes_all.size(0), 4)
        else:
            # Boxes will get fixed by filter_dets function.
            bboxes = result.rm_box_priors

        rel_rep = F.softmax(result.rel_dists, dim=1)
        #embed(header='rel_model.py before return')
        return filter_dets(bboxes, result.obj_scores, result.obj_preds,
                           rel_inds[:, 1:], rel_rep)
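The `twod_inds` gather used just above (and in several of the following examples) flattens the per-box softmax and picks, for every box, the probability of its own predicted class. A minimal, self-contained sketch of the same indexing trick, with made-up shapes and `torch.arange` standing in for the project's `arange` helper:

import torch
import torch.nn.functional as F

# Hypothetical shapes: 4 boxes, 6 classes.
obj_dists = torch.randn(4, 6)            # raw class logits per box
obj_preds = obj_dists.argmax(dim=1)      # predicted class index per box
num_classes = obj_dists.size(1)

# Flatten the (4, 6) softmax into a length-24 vector and index it at
# row_offset + predicted_class, i.e. i * num_classes + obj_preds[i].
twod_inds = torch.arange(obj_preds.size(0)) * num_classes + obj_preds
obj_scores = F.softmax(obj_dists, dim=1).view(-1)[twod_inds]

# The same thing written as an explicit gather, for comparison:
expected = F.softmax(obj_dists, dim=1).gather(1, obj_preds[:, None]).squeeze(1)
assert torch.allclose(obj_scores, expected)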
Example #11
0
File: rel_model3.py  Project: ht014/lsbr
    def forward(self,
                x,
                im_sizes,
                image_offset,
                gt_boxes=None,
                gt_classes=None,
                gt_rels=None,
                proposals=None,
                train_anchor_inds=None,
                return_fmap=False):

        result = self.detector(x,
                               im_sizes,
                               image_offset,
                               gt_boxes,
                               gt_classes,
                               gt_rels,
                               proposals,
                               train_anchor_inds,
                               return_fmap=True)
        if result.is_none():
            return ValueError("heck")

        im_inds = result.im_inds - image_offset
        boxes = result.rm_box_priors

        if self.training and result.rel_labels is None:
            assert self.mode == 'sgdet'
            result.rel_labels = rel_assignments(im_inds.data,
                                                boxes.data,
                                                result.rm_obj_labels.data,
                                                gt_boxes.data,
                                                gt_classes.data,
                                                gt_rels.data,
                                                image_offset,
                                                filter_non_overlap=True,
                                                num_sample_per_gt=1)

        rel_inds = self.get_rel_inds(result.rel_labels, im_inds, boxes)

        rois = torch.cat((im_inds[:, None].float(), boxes), 1)

        result.obj_fmap = self.obj_feature_map(result.fmap.detach(), rois)

        # Prevent gradients from flowing back into score_fc from elsewhere
        result.rm_obj_dists, result.obj_preds, node_rep0 = self.context(
            result.obj_fmap, result.rm_obj_dists.detach(), im_inds,
            result.rm_obj_labels if self.training or self.mode == 'predcls'
            else None, boxes.data, result.boxes_all)

        edge_rep = node_rep0.repeat(1, 2)
        edge_rep = edge_rep.view(edge_rep.size(0), 2, -1)

        global_feature = self.global_embedding(result.fmap.detach())
        result.global_dists = self.global_logist(global_feature)
        one_hot_multi = torch.zeros(
            (result.global_dists.shape[0], self.num_classes))

        one_hot_multi[im_inds, result.rm_obj_labels] = 1.0
        result.multi_hot = one_hot_multi.float().cuda()

        subj_global_additive_attention = F.relu(
            self.global_sub_additive(edge_rep[:, 0] + global_feature[im_inds]))
        obj_global_additive_attention = F.relu(
            self.global_obj_additive(edge_rep[:, 1] + global_feature[im_inds]))

        subj_rep = edge_rep[:,
                            0] + subj_global_additive_attention * global_feature[
                                im_inds]
        obj_rep = edge_rep[:,
                           1] + obj_global_additive_attention * global_feature[
                               im_inds]

        if self.training:
            self.centroids = self.disc_center.centroids.data

        # if edge_ctx is None:
        #     edge_rep = self.post_emb(result.obj_preds)
        # else:
        #     edge_rep = self.post_lstm(edge_ctx)

        # Split into subject and object representations
        # edge_rep = edge_rep.view(edge_rep.size(0), 2, self.pooling_dim)
        #
        # subj_rep = edge_rep[:, 0]
        # obj_rep = edge_rep[:, 1]

        prod_rep = subj_rep[rel_inds[:, 1]] * obj_rep[rel_inds[:, 2]]

        vr = self.visual_rep(result.fmap.detach(), rois, rel_inds[:, 1:])

        prod_rep = prod_rep * vr

        prod_rep = F.tanh(prod_rep)

        logits, self.direct_memory_feature = self.meta_classify(
            prod_rep, self.centroids)
        # result.rel_dists = self.rel_compress(prod_rep)
        result.rel_dists = logits
        result.rel_dists2 = self.direct_memory_feature[-1]
        # result.hallucinate_logits = self.direct_memory_feature[-1]
        if self.training:
            result.center_loss = self.disc_center(
                prod_rep, result.rel_labels[:, -1]) * 0.01

        if self.use_bias:
            result.rel_dists = result.rel_dists + 1.0 * self.freq_bias.index_with_labels(
                torch.stack((
                    result.obj_preds[rel_inds[:, 1]],
                    result.obj_preds[rel_inds[:, 2]],
                ), 1))

        if self.training:
            return result

        twod_inds = arange(
            result.obj_preds.data) * self.num_classes + result.obj_preds.data
        result.obj_scores = F.softmax(result.rm_obj_dists,
                                      dim=1).view(-1)[twod_inds]

        # Bbox regression
        if self.mode == 'sgdet':
            bboxes = result.boxes_all.view(-1, 4)[twod_inds].view(
                result.boxes_all.size(0), 4)
        else:
            # Boxes will get fixed by filter_dets function.
            bboxes = result.rm_box_priors

        rel_rep = F.softmax(result.rel_dists, dim=1)
        return filter_dets(bboxes, result.obj_scores, result.obj_preds,
                           rel_inds[:, 1:], rel_rep)
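Example #11 adds `freq_bias.index_with_labels(...)` to the relation logits, a prior looked up by (subject class, object class). A minimal sketch of such a pairwise lookup, assuming the bias is stored as a flat `nn.Embedding` indexed by `subj * num_obj_classes + obj` (the project's frequency-bias module may store it differently):

import torch
import torch.nn as nn

class PairwiseBias(nn.Module):
    """Hypothetical (subject class, object class) -> relation-logit bias table."""
    def __init__(self, num_obj_classes, num_rel_classes):
        super().__init__()
        self.num_obj_classes = num_obj_classes
        self.table = nn.Embedding(num_obj_classes * num_obj_classes, num_rel_classes)
        nn.init.zeros_(self.table.weight)   # would normally hold empirical relation statistics

    def index_with_labels(self, pair_labels):
        # pair_labels: (num_rels, 2) with columns (subject class, object class)
        flat = pair_labels[:, 0] * self.num_obj_classes + pair_labels[:, 1]
        return self.table(flat)             # (num_rels, num_rel_classes)

bias = PairwiseBias(num_obj_classes=151, num_rel_classes=51)
pairs = torch.tensor([[5, 20], [7, 7]])     # e.g. obj_preds[rel_inds[:, 1]] / [:, 2]
rel_dists = torch.randn(2, 51) + bias.index_with_labels(pairs)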
Example #12
0
    def forward(self, x, im_sizes, image_offset,
                gt_boxes=None, gt_classes=None, gt_rels=None, proposals=None, train_anchor_inds=None,
                return_fmap=False):
        """
        Forward pass for detection
        :param x: Images@[batch_size, 3, IM_SIZE, IM_SIZE]
        :param im_sizes: A numpy array of (h, w, scale) for each image.
        :param image_offset: Offset onto what image we're on for MGPU training (if single GPU this is 0)
        :param gt_boxes:

        Training parameters:
        :param gt_boxes: [num_gt, 4] GT boxes over the batch.
        :param gt_classes: [num_gt, 2] gt boxes where each one is (img_id, class)
        :param train_anchor_inds: a [num_train, 2] array of indices for the anchors that will
                                  be used to compute the training loss. Each (img_ind, fpn_idx)
        :return: If train:
            scores, boxdeltas, labels, boxes, boxtargets, rpnscores, rpnboxes, rellabels

            if test:
            prob dists, boxes, img inds, maxscores, classes

        """
        # --------- gt_rel process ----------

        batch_size = x.shape[0]

        result = self.detector(x, im_sizes, image_offset, gt_boxes, gt_classes, gt_rels, proposals,
                               train_anchor_inds, return_fmap=True)
        if result.is_none():
            return ValueError("heck")


        # Prevent gradients from flowing back into score_fc from elsewhere (the last three values are throwaway)
        if self.mode == 'sgdet':
            im_inds = result.im_inds - image_offset  # all indices
            boxes = result.od_box_priors  # all boxes
            rois = torch.cat((im_inds[:, None].float(), boxes), 1)
            result.obj_fmap = self.obj_feature_map(result.fmap.detach(), rois)
            result.rm_obj_dists, result.obj_preds, im_inds, result.rm_box_priors, \
                result.rm_obj_labels, rois, result.boxes_all = self.context(
                result.obj_fmap,
                result.rm_obj_dists.detach(),
                im_inds, result.rm_obj_labels if self.training or self.mode == 'predcls' else None,
                boxes.data, result.boxes_all, batch_size,
                rois, result.od_box_deltas.detach(), im_sizes, image_offset, gt_classes, gt_boxes)

            boxes = result.rm_box_priors
        else:
            #sgcls and predcls
            im_inds = result.im_inds - image_offset
            boxes = result.rm_box_priors
            rois = torch.cat((im_inds[:, None].float(), boxes), 1)
            result.obj_fmap = self.obj_feature_map(result.fmap.detach(), rois)

            result.rm_obj_dists, result.obj_preds = self.context(
                result.obj_fmap,
                result.rm_obj_dists.detach(),
                im_inds, result.rm_obj_labels if self.training or self.mode == 'predcls' else None,
                boxes.data, result.boxes_all, batch_size,
                None, None, None, None, None, None)

        if self.training and result.rel_labels is None:
            assert self.mode == 'sgdet'
            result.rel_labels = rel_assignments(im_inds.data, boxes.data, result.rm_obj_labels.data,
                                                gt_boxes.data, gt_classes.data, gt_rels.data,
                                                image_offset, filter_non_overlap=True,
                                                num_sample_per_gt=1)

        rel_inds = self.get_rel_inds(result.rel_labels, im_inds, boxes)

        # visual part
        obj_pooling = self.obj_avg_pool(result.fmap.detach(), rois).view(-1, 512)
        subj_rep = obj_pooling[rel_inds[:, 1]]
        obj_rep = obj_pooling[rel_inds[:, 2]]
        vr, union_rois = self.visual_rep(result.fmap.detach(), rois, rel_inds[:, 1:])
        vr = vr.view(-1, 512)
        x_visual = torch.cat((subj_rep, obj_rep, vr), 1)
        # semantic part
        subj_class = result.obj_preds[rel_inds[:, 1]]
        obj_class = result.obj_preds[rel_inds[:, 2]]
        subj_emb = self.obj_embed(subj_class)
        obj_emb = self.obj_embed2(obj_class)
        x_semantic = torch.cat((subj_emb, obj_emb), 1)

        # padding
        perm, inv_perm, ls_transposed = self.sort_rois(rel_inds[:, 0].data, None, union_rois[:, 1:])
        x_visual_rep = x_visual[perm].contiguous()
        x_semantic_rep = x_semantic[perm].contiguous()

        visual_input = PackedSequence(x_visual_rep, torch.tensor(ls_transposed))
        inputs1, lengths1 = pad_packed_sequence(visual_input, batch_first=False)
        semantic_input = PackedSequence(x_semantic_rep, torch.tensor(ls_transposed))
        inputs2, lengths2 = pad_packed_sequence(semantic_input, batch_first=False)


        self.hidden_state_visual = self.init_hidden(batch_size, bidirectional=False)
        self.hidden_state_semantic = self.init_hidden(batch_size, bidirectional=False)

        output1, self.hidden_state_visual = self.lstm_visual(inputs1, self.hidden_state_visual)
        output2, self.hidden_state_semantic = self.lstm_semantic(inputs2, self.hidden_state_semantic)        
        inputs = torch.cat((output1, output2), 2)


        x_fusion = self.odeBlock(inputs, batch_size)
        x_fusion = x_fusion[1]

        x_fusion, _ = pack_padded_sequence(x_fusion, lengths1, batch_first=False)
        x_out = self.fc_predicate(x_fusion)
        result.rel_dists = x_out[inv_perm]  # for evaluation and crossentropy

        if self.training:
            return result

        twod_inds = arange(result.obj_preds.data) * self.num_classes + result.obj_preds.data
        result.obj_scores = F.softmax(result.rm_obj_dists, dim=1).view(-1)[twod_inds]

        # Bbox regression
        if self.mode == 'sgdet':
            bboxes = result.boxes_all.view(-1, 4)[twod_inds].view(result.boxes_all.size(0), 4)
        else:
            # Boxes will get fixed by filter_dets function.
            bboxes = result.rm_box_priors

        rel_rep = F.softmax(result.rel_dists, dim=1)
        return filter_dets(bboxes, result.obj_scores,
                           result.obj_preds, rel_inds[:, 1:], rel_rep)
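Example #12 sorts the per-image relation features, packs them, feeds two LSTMs, and unpads afterwards. A minimal sketch of the pack/unpack round trip it relies on, with hypothetical lengths and feature sizes:

import torch
from torch.nn.utils.rnn import pad_sequence, pack_padded_sequence, pad_packed_sequence

# Hypothetical: three images with 4, 2 and 1 relations, 8-dim features each
# (lengths must be descending for the default enforce_sorted=True).
lengths = torch.tensor([4, 2, 1])
feats = [torch.randn(int(n), 8) for n in lengths]

padded = pad_sequence(feats, batch_first=False)                 # (max_len=4, batch=3, 8)
packed = pack_padded_sequence(padded, lengths, batch_first=False)

lstm = torch.nn.LSTM(input_size=8, hidden_size=16)
packed_out, _ = lstm(packed)

# Back to a padded tensor plus the original per-image lengths.
out, out_lengths = pad_packed_sequence(packed_out, batch_first=False)   # (4, 3, 16)
assert torch.equal(out_lengths, lengths)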
Example #13
0
File: rel_model.py  Project: galsina/lml
    def forward(self,
                x,
                im_sizes,
                image_offset,
                gt_boxes=None,
                gt_classes=None,
                gt_rels=None,
                proposals=None,
                train_anchor_inds=None,
                return_fmap=False):
        """
        Forward pass for detection
        :param x: Images@[batch_size, 3, IM_SIZE, IM_SIZE]
        :param im_sizes: A numpy array of (h, w, scale) for each image.
        :param image_offset: Offset onto what image we're on for MGPU training (if single GPU this is 0)
        :param gt_boxes:

        Training parameters:
        :param gt_boxes: [num_gt, 4] GT boxes over the batch.
        :param gt_classes: [num_gt, 2] gt boxes where each one is (img_id, class)
        :param train_anchor_inds: a [num_train, 2] array of indices for the anchors that will
                                  be used to compute the training loss. Each (img_ind, fpn_idx)
        :return:
            (scores, boxdeltas, labels, boxes, boxtargets,
                rpnscores, rpnboxes, rellabels)
            (prob dists, boxes, img inds, maxscores, classes)

        """

        result = self.detector(x,
                               im_sizes,
                               image_offset,
                               gt_boxes,
                               gt_classes,
                               gt_rels,
                               proposals,
                               train_anchor_inds,
                               return_fmap=True)
        if result.is_none():
            return ValueError("heck")

        im_inds = result.im_inds - image_offset
        boxes = result.rm_box_priors

        # if self.training and result.rel_labels is None:
        if self.training and result.rel_labels is None:
            assert self.mode == 'sgdet'
            result.rel_labels = rel_assignments(im_inds.data,
                                                boxes.data,
                                                result.rm_obj_labels.data,
                                                gt_boxes.data,
                                                gt_classes.data,
                                                gt_rels.data,
                                                image_offset,
                                                filter_non_overlap=True,
                                                num_sample_per_gt=1)

        rel_inds = self.get_rel_inds(result.rel_labels, im_inds, boxes)

        rois = torch.cat((im_inds[:, None].float(), boxes), 1)

        result.obj_fmap = self.obj_feature_map(result.fmap.detach(), rois)

        # Prevent gradients from flowing back into score_fc from elsewhere
        result.rm_obj_dists, result.obj_preds, edge_ctx = self.context(
            result.obj_fmap, result.rm_obj_dists.detach(), im_inds,
            result.rm_obj_labels if self.training or self.mode == 'predcls'
            else None, boxes.data, result.boxes_all)

        if edge_ctx is None:
            edge_rep = self.post_emb(result.obj_preds)
        else:
            edge_rep = self.post_lstm(edge_ctx)

        # Split into subject and object representations
        edge_rep = edge_rep.view(edge_rep.size(0), 2, self.pooling_dim)

        subj_rep = edge_rep[:, 0]
        obj_rep = edge_rep[:, 1]

        result.subj_rep_rel = subj_rep[rel_inds[:, 1]]
        result.obj_rep_rel = obj_rep[rel_inds[:, 2]]
        result.prod_rep = result.subj_rep_rel * result.obj_rep_rel

        if self.use_vision:
            result.vr = self.visual_rep(result.fmap.detach(), rois,
                                        rel_inds[:, 1:])
            if self.limit_vision:
                # exact value TBD
                import ipdb
                ipdb.set_trace()
                result.prod_rep = torch.cat(
                    (result.prod_rep[:, :2048] * result.vr[:, :2048],
                     result.prod_rep[:, 2048:]), 1)
            else:
                result.prod_rep = result.prod_rep * result.vr

        if self.use_tanh:
            result.prod_rep = F.tanh(result.prod_rep)

        result.rel_dists = self.rel_compress(result.prod_rep)

        result.obj_from = result.obj_preds[rel_inds[:, 1]]
        result.obj_to = result.obj_preds[rel_inds[:, 2]]

        if self.use_bias:
            result.rel_dists = result.rel_dists + \
                self.freq_bias.index_with_labels(
                    torch.stack((result.obj_from, result.obj_to), 1)
                )

        # if self.training:
        #     return result

        idxs, n_objs = np.unique(im_inds.data.cpu().numpy(),
                                 return_counts=True)

        if not self.training:
            assert len(n_objs) == 1
            n_rels = [len(rel_inds)]
        else:
            _, n_rels = np.unique(result.rel_labels.data[:, 0].cpu().numpy(),
                                  return_counts=True)
            n_rels = n_rels.tolist()

        preds = []
        obj_start = 0
        rel_start = 0
        result.obj_scores = []
        result.rel_reps = []
        for idx, n_obj, n_rel in zip(idxs, n_objs, n_rels):
            obj_end = obj_start + n_obj
            rel_end = rel_start + n_rel

            obj_preds_i = result.obj_preds.data[obj_start:obj_end]
            rm_obj_dists_i = result.rm_obj_dists[obj_start:obj_end]

            twod_inds = arange(obj_preds_i) * self.num_classes + obj_preds_i
            result.obj_scores.append(
                F.softmax(rm_obj_dists_i, dim=1).view(-1)[twod_inds])

            # Bbox regression
            if self.mode == 'sgdet':
                import ipdb
                ipdb.set_trace()  # TODO
                bboxes = result.boxes_all.view(-1, 4)[twod_inds].view(
                    result.boxes_all.size(0), 4)
            else:
                # Boxes will get fixed by filter_dets function.
                bboxes_i = result.rm_box_priors[obj_start:obj_end]

            if self.mode == 'predcls':
                ((rm_obj_dists_i / 1000) + 1) / 2

            rel_dists_i = result.rel_dists[rel_start:rel_end]
            obj_preds_i = result.obj_preds[obj_start:obj_end]
            obj_scores_i = result.obj_scores[-1]
            if self.lml_topk is not None and self.lml_topk:
                if self.lml_softmax:
                    rel_rep = LML(N=self.lml_topk, branch=1000)(F.softmax(
                        rel_dists_i,
                        dim=1)[:, 1:].contiguous().view(-1)).view(n_rel, -1)
                else:
                    rel_rep = LML(N=self.lml_topk, branch=1000)(
                        rel_dists_i[:,
                                    1:].contiguous().view(-1)).view(n_rel, -1)

                rel_rep = torch.cat((
                    Variable(torch.zeros(n_rel, 1).type_as(rel_dists_i.data)),
                    rel_rep,
                ), 1)
            elif (self.entr_topk is not None
                  and self.entr_topk) or self.ml_loss:
                # Hack to ignore the background.
                rel_rep = torch.cat((Variable(
                    -1e10 * torch.ones(n_rel, 1).type_as(rel_dists_i.data)),
                                     rel_dists_i[:, 1:]), 1)
            else:
                rel_rep = F.softmax(rel_dists_i, dim=1)
            result.rel_reps.append(rel_rep)

            rel_inds_i = rel_inds[rel_start:rel_end, 1:].clone()
            rel_inds_i = rel_inds_i - rel_inds_i.min()  # Very hacky fix...

            # For debugging:
            # bboxes_i = bboxes_i.cpu()
            # obj_scores_i = obj_scores_i.cpu()
            # obj_preds_i = obj_preds_i.cpu()
            # rel_inds_i = rel_inds_i.cpu()
            # rel_rep = rel_rep.cpu()

            pred = filter_dets(bboxes_i, obj_scores_i, obj_preds_i, rel_inds_i,
                               rel_rep)
            preds.append(pred)

            obj_start += n_obj
            rel_start += n_rel

        assert obj_start == result.obj_preds.shape[0]
        assert rel_start == rel_inds.shape[0]
        return result, preds
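Example #13 walks the flat batch image by image using `np.unique(..., return_counts=True)` and slices out each image's boxes and relations. A minimal sketch of that segmentation pattern, assuming the tensors are already sorted by image index:

import numpy as np
import torch

# Hypothetical flat batch: the image index of every box, already sorted by image.
im_inds = torch.tensor([0, 0, 0, 1, 1, 2])
obj_scores = torch.rand(6)

idxs, n_objs = np.unique(im_inds.numpy(), return_counts=True)   # [0 1 2], [3 2 1]

obj_start = 0
per_image_scores = []
for idx, n_obj in zip(idxs, n_objs):
    obj_end = obj_start + int(n_obj)
    per_image_scores.append(obj_scores[obj_start:obj_end])      # boxes of image `idx`
    obj_start = obj_end

assert obj_start == obj_scores.shape[0]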
Example #14
0
    def forward(self,
                x,
                im_sizes,
                image_offset,
                gt_boxes=None,
                gt_classes=None,
                gt_rels=None,
                proposals=None,
                train_anchor_inds=None,
                return_fmap=False):
        """
        Forward pass for detection
        :param x: Images@[batch_size, 3, IM_SIZE, IM_SIZE]
        :param im_sizes: A numpy array of (h, w, scale) for each image.
        :param image_offset: Offset onto what image we're on for MGPU training (if single GPU this is 0)
        :param gt_boxes:

        Training parameters:
        :param gt_boxes: [num_gt, 4] GT boxes over the batch.
        :param gt_classes: [num_gt, 2] gt boxes where each one is (img_id, class)
        :param train_anchor_inds: a [num_train, 2] array of indices for the anchors that will
                                  be used to compute the training loss. Each (img_ind, fpn_idx)
        :return: If train:
            scores, boxdeltas, labels, boxes, boxtargets, rpnscores, rpnboxes, rellabels
            
            if test:
            prob dists, boxes, img inds, maxscores, classes
            
        """

        # Detector
        result = self.detector(x,
                               im_sizes,
                               image_offset,
                               gt_boxes,
                               gt_classes,
                               gt_rels,
                               proposals,
                               train_anchor_inds,
                               return_fmap=True)
        if result.is_none():
            return ValueError("heck")

        #rcnn_pred = result.rm_obj_dists[:, 1:].max(1)[1] + 1  # +1: because the index is in 150-d but truth is 151-d
        #rcnn_ap = torch.mean((rcnn_pred == result.rm_obj_labels).float().cpu())

        im_inds = result.im_inds - image_offset
        # boxes: [#boxes, 4], without box deltas; this is where the narrowing error comes from, hence the .detach()
        boxes = result.rm_box_priors.detach()

        # Box and obj_dists average precision (AP)
        obj_scores = F.softmax(result.rm_obj_dists, dim=1)
        result.rm_obj_preds = obj_scores.data[:, 1:].max(1)[1]
        result.rm_obj_preds = result.rm_obj_preds + 1
        twod_inds = arange(
            result.rm_obj_preds) * self.num_classes + result.rm_obj_preds
        bboxes = result.boxes_all.view(-1, 4)[twod_inds].view(
            result.boxes_all.size(0), 4)
        pred_to_gtbox = bbox_overlaps(bboxes.data, gt_boxes.data)
        im_inds = result.im_inds
        pred_to_gtbox[im_inds.data[:, None] != gt_classes.data[None, :,
                                                               0]] = 0.0
        max_overlaps, argmax_overlaps = pred_to_gtbox.max(1)
        labels = gt_classes[:, 1][argmax_overlaps]
        labels[max_overlaps < 0.5] = 0
        labels[result.rm_obj_preds != result.rm_obj_labels.data] = 0
        result.ratio = torch.nonzero(labels).size(0) / labels.size(0)
        return result.ratio
        """
Example #15
0
    def forward(self, x, im_sizes, image_offset,
                gt_boxes=None, gt_masks=None, gt_classes=None, gt_rels=None, pred_boxes=None, pred_masks=None,
                pred_fmaps=None, pred_dists=None):
        # pred_boxes: (#num, 5) im_ind box
        pred_im_inds = pred_boxes[:, 0].long() - image_offset
        pred_boxes = pred_boxes[:, 1:]
        gt_img_inds = gt_classes[:, 0] - image_offset
        rel_targets = None
        if self.training:
            # Assume that the GT boxes  and  pred_boxes
            # are already sorted in terms of image id
            num_images = int(pred_im_inds[-1]) + 1

            cls_targets = []
            bbox_targets = []

            for im_ind in range(num_images):
                g_inds = (gt_img_inds == im_ind).nonzero()
                if g_inds.dim() == 0:
                    continue
                g_inds = g_inds.squeeze(1)
                g_start = int(g_inds[0])
                g_end = int(g_inds[-1] + 1)

                t_inds = (pred_im_inds == im_ind).nonzero().squeeze(1)
                t_start = int(t_inds[0])
                t_end = int(t_inds[-1] + 1)

                # Max overlaps: for each predicted box, get the max ROI
                # Get the indices into the GT boxes too (must offset by the box start)
                ious = bbox_overlaps(pred_boxes[t_start:t_end], gt_boxes[g_start:g_end])
                max_overlaps, gt_assignment = ious.max(1)
                gt_assignment += g_start
                cls_target_ = gt_classes[:, 1][gt_assignment]
                bbox_target_ = gt_boxes[gt_assignment]
                cls_targets.append(cls_target_)
                bbox_targets.append(bbox_target_)
            cls_targets = torch.cat(cls_targets, 0)

            # based on mask
            # Compute IoUs between the predicted masks and the GT masks, then convert gt_rels
            # (whose indices refer to GT masks) into rel_targets (whose indices refer to the predicted masks).
            # e.g. gt_rels (2, 3, 5): GT masks 2 and 3 are the related pair; find the predicted masks that
            # match GT masks 2 and 3 and build the new relation pair from their indices.
            # rel_targets: [num_rels, 4] (img ind, box0 ind, box1ind, rel type)
            rel_targets = rel_assignments_with_mask(pred_im_inds.data, pred_masks.data, cls_targets.data,
                                                gt_masks.data, gt_classes.data, gt_rels.data,
                                                image_offset, filter_non_overlap=True,
                                                num_sample_per_gt=1)

        # rel_inds: [num_rels, 3] (img ind, box0 ind, box1 ind)
        rel_inds = self.get_rel_inds(rel_targets, pred_im_inds, pred_boxes)
        masks = F.avg_pool2d(pred_masks, 4, 4)
        pred_fmaps = pred_fmaps * masks[:, None, :, :]
        pred_fmaps = self.fc67(pred_fmaps.view(pred_boxes.size(0), -1))

        # Prevent gradients from flowing back into score_fc from elsewhere
        pred_dists, pred_classes, edge_ctx = self.context(
            pred_fmaps,
            pred_dists,
            pred_im_inds, cls_targets if self.training or self.mode == 'predcls' else None,
            pred_boxes.data, None)
        if edge_ctx is None:
            edge_rep = self.post_emb(pred_classes)
        else:
            edge_rep = self.post_lstm(edge_ctx)
        # Split into subject and object representations
        edge_rep = edge_rep.view(edge_rep.size(0), 2, self.pooling_dim)
        subj_rep = edge_rep[:, 0]
        obj_rep = edge_rep[:, 1]
        prod_rep = subj_rep[rel_inds[:, 1]] * obj_rep[rel_inds[:, 2]]

        if self.use_tanh:
            prod_rep = F.tanh(prod_rep)

        rel_dists = self.rel_compress(prod_rep)

        if self.training:
            return pred_dists, cls_targets, rel_dists, rel_targets

        twod_inds = arange(pred_classes.data) * self.num_classes + pred_classes.data
        pred_scores = F.softmax(pred_dists, dim=1).view(-1)[twod_inds]

        rel_rep = F.softmax(rel_dists, dim=1)

        return filter_dets_mask(pred_boxes, pred_masks, pred_scores,
                                pred_classes, rel_inds[:, 1:], rel_rep)
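Example #15 downsamples the instance masks and multiplies them into the per-ROI feature maps before the fully connected layers. A minimal sketch of that mask gating, with hypothetical shapes:

import torch
import torch.nn.functional as F

# Hypothetical: 3 ROIs, 256-channel feature maps at 7x7, soft masks at 28x28.
pred_fmaps = torch.randn(3, 256, 7, 7)
pred_masks = torch.rand(3, 28, 28)

# Bring the masks down to the feature-map resolution (28 / 4 = 7) ...
masks = F.avg_pool2d(pred_masks, 4, 4)              # (3, 7, 7)
# ... and gate every channel of every ROI feature map with its own soft mask.
masked_fmaps = pred_fmaps * masks[:, None, :, :]    # broadcast over the channel dim
flat = masked_fmaps.view(pred_fmaps.size(0), -1)    # ready for the fc layers (cf. self.fc67)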
Example #16
0
    def forward(self,
                x,
                im_sizes,
                image_offset,
                gt_boxes=None,
                gt_classes=None,
                gt_rels=None,
                proposals=None,
                train_anchor_inds=None,
                return_fmap=False):
        """
        Forward pass for detection
        :param x: Images@[batch_size, 3, IM_SIZE, IM_SIZE]
        :param im_sizes: A numpy array of (h, w, scale) for each image.
        :param image_offset: Offset onto what image we're on for MGPU training (if single GPU this is 0)
        :param gt_boxes:

        Training parameters:
        :param gt_boxes: [num_gt, 4] GT boxes over the batch.
        :param gt_classes: [num_gt, 2] gt boxes where each one is (img_id, class)
        :param train_anchor_inds: a [num_train, 2] array of indices for the anchors that will
                                  be used to compute the training loss. Each (img_ind, fpn_idx)
        :return: If train:
            scores, boxdeltas, labels, boxes, boxtargets, rpnscores, rpnboxes, rellabels
            
            if test:
            prob dists, boxes, img inds, maxscores, classes
            
        """

        result = self.detector(x,
                               im_sizes,
                               image_offset,
                               gt_boxes,
                               gt_classes,
                               gt_rels,
                               proposals,
                               train_anchor_inds,
                               return_fmap=True)
        if result.is_none():
            return ValueError("heck")

        im_inds = result.im_inds - image_offset
        boxes = result.rm_box_priors

        if self.training and result.rel_labels is None:
            assert self.mode == 'sgdet'
            result.rel_labels = rel_assignments(im_inds.data,
                                                boxes.data,
                                                result.rm_obj_labels.data,
                                                gt_boxes.data,
                                                gt_classes.data,
                                                gt_rels.data,
                                                image_offset,
                                                filter_non_overlap=True,
                                                num_sample_per_gt=1)

        rel_inds = self.get_rel_inds(result.rel_labels, im_inds, boxes)
        rois = torch.cat((im_inds[:, None].float(), boxes), 1)

        result.obj_fmap = self.obj_feature_map(result.fmap.detach(), rois)

        if self.use_ggnn_obj:
            result.rm_obj_dists = self.ggnn_obj_reason(
                im_inds, result.obj_fmap, result.rm_obj_labels
                if self.training or self.mode == 'predcls' else None)

        vr = self.visual_rep(result.fmap.detach(), rois, rel_inds[:, 1:])

        if self.use_ggnn_rel:
            result.rm_obj_dists, result.obj_preds, result.rel_dists = self.ggnn_rel_reason(
                obj_fmaps=result.obj_fmap,
                obj_logits=result.rm_obj_dists,
                vr=vr,
                rel_inds=rel_inds,
                obj_labels=result.rm_obj_labels
                if self.training or self.mode == 'predcls' else None,
                boxes_per_cls=result.boxes_all)
        else:
            result.rm_obj_dists, result.obj_preds, result.rel_dists = self.vr_fc_cls(
                obj_logits=result.rm_obj_dists,
                vr=vr,
                obj_labels=result.rm_obj_labels
                if self.training or self.mode == 'predcls' else None,
                boxes_per_cls=result.boxes_all)

        if self.training:
            return result

        twod_inds = arange(
            result.obj_preds.data) * self.num_classes + result.obj_preds.data
        result.obj_scores = F.softmax(result.rm_obj_dists,
                                      dim=1).view(-1)[twod_inds]

        # Bbox regression
        if self.mode == 'sgdet':
            bboxes = result.boxes_all.view(-1, 4)[twod_inds].view(
                result.boxes_all.size(0), 4)
        else:
            # Boxes will get fixed by filter_dets function.
            bboxes = result.rm_box_priors

        rel_rep = F.softmax(result.rel_dists, dim=1)

        return filter_dets(bboxes, result.obj_scores, result.obj_preds,
                           rel_inds[:, 1:], rel_rep)
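Example #16 (like several others here) selects, for every box, the class-specific regression output of its predicted class from `boxes_all` of shape (num_boxes, num_classes, 4), reusing the same flattened-index trick as for the scores. A minimal sketch with made-up sizes:

import torch

# Hypothetical: 5 boxes, 151 classes, class-specific regressed boxes per box.
num_classes = 151
boxes_all = torch.randn(5, num_classes, 4)
obj_preds = torch.randint(0, num_classes, (5,))

# Same flattening trick as for the scores: row offset + predicted class.
twod_inds = torch.arange(5) * num_classes + obj_preds
bboxes = boxes_all.view(-1, 4)[twod_inds].view(boxes_all.size(0), 4)

# Equivalent explicit advanced indexing, for comparison:
assert torch.equal(bboxes, boxes_all[torch.arange(5), obj_preds])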
Example #17
0
    def forward(self, x, im_sizes, image_offset,
                gt_boxes=None, gt_classes=None, gt_rels=None, proposals=None, train_anchor_inds=None,
                return_fmap=False):
        """
        Forward pass for detection
        :param x: Images@[batch_size, 3, IM_SIZE, IM_SIZE]
        :param im_sizes: A numpy array of (h, w, scale) for each image.
        :param image_offset: Offset onto what image we're on for MGPU training (if single GPU this is 0)
        :param gt_boxes:

        Training parameters:
        :param gt_boxes: [num_gt, 4] GT boxes over the batch.
        :param gt_classes: [num_gt, 2] gt boxes where each one is (img_id, class)
        :param train_anchor_inds: a [num_train, 2] array of indices for the anchors that will
                                  be used to compute the training loss. Each (img_ind, fpn_idx)
        :return: If train:
            scores, boxdeltas, labels, boxes, boxtargets, rpnscores, rpnboxes, rellabels
            
            if test:
            prob dists, boxes, img inds, maxscores, classes
            
        """

        # Detector
        result = self.detector(x, im_sizes, image_offset, gt_boxes, gt_classes, gt_rels, proposals,
                               train_anchor_inds, return_fmap=True)
        if result.is_none():
            return ValueError("heck")
        im_inds = result.im_inds - image_offset
        # boxes: [#boxes, 4], without box deltas; where narrow error comes from, should .detach()
        boxes = result.rm_box_priors    # .detach()   

        if self.training and result.rel_labels is None:
            assert self.mode == 'sgdet'  # in sgcls, result.rel_labels comes from GT and is not None
            # rel_labels: [num_rels, 4] (img ind, box0 ind, box1ind, rel type)
            result.rel_labels = rel_assignments(im_inds.data, boxes.data, result.rm_obj_labels.data,
                                                gt_boxes.data, gt_classes.data, gt_rels.data,
                                                image_offset, filter_non_overlap=True,
                                                num_sample_per_gt=1)

        #torch.cat((result.rel_labels[:,0].contiguous().view(rel_inds.size(0),1),result.rm_obj_labels[result.rel_labels[:,1]].view(rel_inds.size(0),1),result.rm_obj_labels[result.rel_labels[:,2]].view(rel_inds.size(0),1),result.rel_labels[:,3].contiguous().view(rel_inds.size(0),1)),-1)
        #bbox_overlaps(boxes.data[55:57].contiguous().view(-1,1), boxes.data[8].contiguous().view(-1,1))
        rel_inds = self.get_rel_inds(result.rel_labels, im_inds, boxes)  #[275,3], [im_inds, box1_inds, box2_inds]
        
        # rois: [#boxes, 5]
        rois = torch.cat((im_inds[:, None].float(), boxes), 1)
        # result.rm_obj_fmap: [384, 4096]
        result.rm_obj_fmap = self.obj_feature_map(result.fmap.detach(), rois)  # detach: prevent gradients from flowing back

        # BiLSTM
        result.rm_obj_dists, result.rm_obj_preds, edge_ctx = self.context(
            result.rm_obj_fmap,   # has been detached above
            # rm_obj_dists: [#boxes, 151]; Prevent gradients from flowing back into score_fc from elsewhere
            result.rm_obj_dists.detach(),  # .detach:Returns a new Variable, detached from the current graph
            im_inds, result.rm_obj_labels if self.training or self.mode == 'predcls' else None,
            boxes.data, result.boxes_all if self.mode == 'sgdet' else result.boxes_all)
        

        # Post Processing
        # nl_egde <= 0
        if edge_ctx is None:
            edge_rep = self.post_emb(result.rm_obj_preds)
        # nl_edge > 0
        else: 
            edge_rep = self.post_lstm(edge_ctx)  # [384, 4096*2]
     
        # Split into subject and object representations
        edge_rep = edge_rep.view(edge_rep.size(0), 2, self.pooling_dim)  #[384,2,4096]
        subj_rep = edge_rep[:, 0]  # [384,4096]
        obj_rep = edge_rep[:, 1]  # [384,4096]
        prod_rep = subj_rep[rel_inds[:, 1]] * obj_rep[rel_inds[:, 2]]  # prod_rep, rel_inds: [275,4096], [275,3]
    

        if self.use_vision: # True when sgdet
            # union rois: fmap.detach--RoIAlignFunction--roifmap--vr [275,4096]
            vr = self.visual_rep(result.fmap.detach(), rois, rel_inds[:, 1:])

            if self.limit_vision:  # False when sgdet
                # exact value TBD
                prod_rep = torch.cat((prod_rep[:,:2048] * vr[:,:2048], prod_rep[:,2048:]), 1) 
            else:
                prod_rep = prod_rep * vr  # [275,4096]


        if self.use_tanh:  # False when sgdet
            prod_rep = F.tanh(prod_rep)

        result.rel_dists = self.rel_compress(prod_rep)  # [275,51]

        if self.use_bias:  # True when sgdet
            result.rel_dists = result.rel_dists + self.freq_bias.index_with_labels(torch.stack((
                result.rm_obj_preds[rel_inds[:, 1]],
                result.rm_obj_preds[rel_inds[:, 2]],
            ), 1))

        # Note: positives should use rm_obj_labels/rel_labels for obj/rel scores; negatives should use rm_obj_preds/max_rel_score for obj/rel scores
        if self.training: 
            judge = result.rel_labels.data[:,3] != 0
            if judge.sum() != 0:  # gt_rels exist in rel_inds
                # positive overall score
                select_rel_inds = torch.arange(rel_inds.size(0)).view(-1,1).long().cuda()[result.rel_labels.data[:,3] != 0]
                com_rel_inds = rel_inds[select_rel_inds]
                twod_inds = arange(result.rm_obj_labels.data) * self.num_classes + result.rm_obj_labels.data  # dist: [-10,10]
                result.obj_scores = F.softmax(result.rm_obj_dists, dim=1).view(-1)[twod_inds]   # only ~1/4 of the 384 obj_dists will be updated, because only ~1/4 of the object labels are non-zero
              
                obj_scores0 = result.obj_scores[com_rel_inds[:,1]]
                obj_scores1 = result.obj_scores[com_rel_inds[:,2]]
                rel_rep = F.softmax(result.rel_dists[select_rel_inds], dim=1)    # result.rel_dists has grad
                rel_score = rel_rep.gather(1, result.rel_labels[select_rel_inds][:,3].contiguous().view(-1,1)).view(-1)  # not use squeeze(); SqueezeBackward, GatherBackward
                prob_score = rel_score * obj_scores0 * obj_scores1

                # negative overall score
                rel_cands = im_inds.data[:, None] == im_inds.data[None]
                rel_cands.view(-1)[diagonal_inds(rel_cands)] = 0   # self relation = 0
                if self.require_overlap:     
                    rel_cands = rel_cands & (bbox_overlaps(boxes.data, boxes.data) > 0)   # Require overlap for detection
                rel_cands = rel_cands.nonzero()  # [#, 2]
                if rel_cands.dim() == 0:
                    print("rel_cands.dim() == 0!!!!!!!!!!!!!!!!!!!!!!!!!!!!")
                    rel_cands = im_inds.data.new(1, 2).fill_(0) # shaped: [1,2], [0, 0]
                rel_cands = torch.cat((im_inds.data[rel_cands[:, 0]][:, None], rel_cands), 1) # rel_cands' value should be [0, 384]
                rel_inds_neg = rel_cands

                vr_neg = self.visual_rep(result.fmap.detach(), rois, rel_inds_neg[:, 1:])
                subj_obj = subj_rep[rel_inds_neg[:, 1]] * obj_rep[rel_inds_neg[:, 2]]
                prod_rep_neg =  subj_obj * vr_neg
                rel_dists_neg = self.rel_compress(prod_rep_neg)
                all_rel_rep_neg = F.softmax(rel_dists_neg, dim=1)
                _, pred_classes_argmax_neg = all_rel_rep_neg.data[:,1:].max(1)
                pred_classes_argmax_neg = pred_classes_argmax_neg + 1
                all_rel_pred_neg = torch.cat((rel_inds_neg, pred_classes_argmax_neg.view(-1,1)), 1)
                ind_old = torch.ones(all_rel_pred_neg.size(0)).byte().cuda()
                for i in range(com_rel_inds.size(0)):    # delete those box pair with same rel type as pos triplets
                    ind_i = (all_rel_pred_neg[:,0] == com_rel_inds[i, 0]) & (all_rel_pred_neg[:,1] == com_rel_inds[i, 1]) & (result.rm_obj_preds.data[all_rel_pred_neg[:,1]] == result.rm_obj_labels.data[com_rel_inds[i, 1]]) & (all_rel_pred_neg[:,2] == com_rel_inds[i, 2]) & (result.rm_obj_preds.data[all_rel_pred_neg[:,2]] == result.rm_obj_labels.data[com_rel_inds[i, 2]]) & (all_rel_pred_neg[:,3] == result.rel_labels.data[select_rel_inds][i,3]) 
                    ind_i = (1 - ind_i).byte()
                    ind_old = ind_i & ind_old

                rel_inds_neg = rel_inds_neg.masked_select(ind_old.view(-1,1).expand(-1,3) == 1).view(-1,3)
                rel_rep_neg = all_rel_rep_neg.masked_select(Variable(ind_old.view(-1,1).expand(-1,51)) == 1).view(-1,51)
                pred_classes_argmax_neg = pred_classes_argmax_neg.view(-1,1)[ind_old.view(-1,1) == 1]
                rel_labels_pred_neg = all_rel_pred_neg.masked_select(ind_old.view(-1,1).expand(-1,4) == 1).view(-1,4)

                max_rel_score_neg = rel_rep_neg.gather(1, Variable(pred_classes_argmax_neg.view(-1,1))).view(-1)  # not use squeeze()
                twod_inds_neg = arange(result.rm_obj_preds.data) * self.num_classes + result.rm_obj_preds.data
                obj_scores_neg = F.softmax(result.rm_obj_dists, dim=1).view(-1)[twod_inds_neg] 
                obj_scores0_neg = Variable(obj_scores_neg.data[rel_inds_neg[:,1]])
                obj_scores1_neg = Variable(obj_scores_neg.data[rel_inds_neg[:,2]])
                all_score_neg = max_rel_score_neg * obj_scores0_neg * obj_scores1_neg
                # delete those triplet whose score is lower than pos triplets
                prob_score_neg = all_score_neg[all_score_neg.data > prob_score.data.min()] if (all_score_neg.data > prob_score.data.min()).sum() != 0 else all_score_neg


                # use all rel_inds; they are already independent of im_inds, which is only used to extract regions from the image and to produce rel_inds
                # 384 boxes ---(rel_inds)(rel_inds_neg)---> prob_score, prob_score_neg
                flag = torch.cat((torch.ones(prob_score.size(0),1).cuda(),torch.zeros(prob_score_neg.size(0),1).cuda()),0)
                all_prob = torch.cat((prob_score,prob_score_neg), 0)  # Variable, [#pos_inds+#neg_inds, 1]

                _, sort_prob_inds = torch.sort(all_prob.data, dim=0, descending=True)

                sorted_flag = flag[sort_prob_inds].view(-1)  # can be used to check distribution of pos and neg
                sorted_all_prob = all_prob[sort_prob_inds]  # Variable
                
                # positive triplet score
                pos_exp = sorted_all_prob[sorted_flag == 1]  # Variable 
                # negative triplet score
                neg_exp = sorted_all_prob[sorted_flag == 0]  # Variable

                # determine how many rows will be updated in rel_dists_neg
                pos_repeat = torch.zeros(1, 1)
                neg_repeat = torch.zeros(1, 1)
                for i in range(pos_exp.size(0)):
                    if ( neg_exp.data > pos_exp.data[i] ).sum() != 0:
                        int_part = (neg_exp.data > pos_exp.data[i]).sum()
                        temp_pos_inds = torch.ones(int_part) * i
                        pos_repeat =  torch.cat((pos_repeat, temp_pos_inds.view(-1,1)), 0)
                        temp_neg_inds = torch.arange(int_part)
                        neg_repeat = torch.cat((neg_repeat, temp_neg_inds.view(-1,1)), 0)
                    else:
                        temp_pos_inds = torch.ones(1)* i
                        pos_repeat =  torch.cat((pos_repeat, temp_pos_inds.view(-1,1)), 0)
                        temp_neg_inds = torch.arange(1)
                        neg_repeat = torch.cat((neg_repeat, temp_neg_inds.view(-1,1)), 0)

                """
                int_part = neg_exp.size(0) // pos_exp.size(0)
                decimal_part = neg_exp.size(0) % pos_exp.size(0)
                int_inds = torch.arange(pos_exp.size(0))[:,None].expand_as(torch.Tensor(pos_exp.size(0), int_part)).contiguous().view(-1)
                int_part_inds = (int(pos_exp.size(0) -1) - int_inds).long().cuda() # use minimum pos to correspond maximum negative
                if decimal_part == 0:
                    expand_inds = int_part_inds
                else:
                    expand_inds = torch.cat((torch.arange(pos_exp.size(0))[(pos_exp.size(0) - decimal_part):].long().cuda(), int_part_inds), 0)  
                
                result.pos = pos_exp[expand_inds]
                result.neg = neg_exp
                result.anchor = Variable(torch.zeros(result.pos.size(0)).cuda())
                """
                result.pos = pos_exp[pos_repeat.cuda().long().view(-1)]
                result.neg = neg_exp[neg_repeat.cuda().long().view(-1)]
                result.anchor = Variable(torch.zeros(result.pos.size(0)).cuda())
                

                result.ratio = torch.ones(3).cuda()
                result.ratio[0] = result.ratio[0] * (sorted_flag.nonzero().min() / (prob_score.size(0) + all_score_neg.size(0)))
                result.ratio[1] = result.ratio[1] * (sorted_flag.nonzero().max() / (prob_score.size(0) + all_score_neg.size(0)))
                result.ratio[2] = result.ratio[2] * (prob_score.size(0) + all_score_neg.size(0))

                return result

            else:  # no gt_rel in rel_inds
                print("no gt_rel in rel_inds!!!!!!!!!!!!!!!!!!!!!!!!!!!!")
                ipdb.set_trace()
                # testing triplet proposal
                rel_cands = im_inds.data[:, None] == im_inds.data[None]
                # self relation = 0
                rel_cands.view(-1)[diagonal_inds(rel_cands)] = 0
                # Require overlap for detection
                if self.require_overlap:
                    rel_cands = rel_cands & (bbox_overlaps(boxes.data, boxes.data) > 0)
                rel_cands = rel_cands.nonzero()
                if rel_cands.dim() == 0:
                    print("rel_cands.dim() == 0!!!!!!!!!!!!!!!!!!!!!!!!!!!!")
                    rel_cands = im_inds.data.new(1, 2).fill_(0)
                rel_cands = torch.cat((im_inds.data[rel_cands[:, 0]][:, None], rel_cands), 1)
                rel_labels_neg = rel_cands
                rel_inds_neg = rel_cands

                twod_inds_neg = arange(result.rm_obj_preds.data) * self.num_classes + result.rm_obj_preds.data
                obj_scores_neg = F.softmax(result.rm_obj_dists, dim=1).view(-1)[twod_inds_neg]
                vr_neg = self.visual_rep(result.fmap.detach(), rois, rel_inds_neg[:, 1:])
                subj_obj = subj_rep[rel_inds_neg[:, 1]] * obj_rep[rel_inds_neg[:, 2]]
                prod_rep_neg = subj_obj * vr_neg
                rel_dists_neg = self.rel_compress(prod_rep_neg)
                # negative overall score
                obj_scores0_neg = Variable(obj_scores_neg.data[rel_inds_neg[:,1]])
                obj_scores1_neg = Variable(obj_scores_neg.data[rel_inds_neg[:,2]])
                rel_rep_neg = F.softmax(rel_dists_neg, dim=1)
                _, pred_classes_argmax_neg = rel_rep_neg.data[:,1:].max(1)
                pred_classes_argmax_neg = pred_classes_argmax_neg + 1

                max_rel_score_neg = rel_rep_neg.gather(1, Variable(pred_classes_argmax_neg.view(-1,1))).view(-1)  # not use squeeze()
                prob_score_neg = max_rel_score_neg * obj_scores0_neg * obj_scores1_neg

                result.pos = Variable(torch.zeros(prob_score_neg.size(0)).cuda())
                result.neg = prob_score_neg
                result.anchor = Variable(torch.zeros(prob_score_neg.size(0)).cuda())

                result.ratio = torch.ones(3,1).cuda()

                return result
        ###################### Testing ###########################

        # extract corresponding scores according to each box's predicted class
        twod_inds = arange(result.rm_obj_preds.data) * self.num_classes + result.rm_obj_preds.data
        result.obj_scores = F.softmax(result.rm_obj_dists, dim=1).view(-1)[twod_inds]   # [384]

        # Bbox regression
        if self.mode == 'sgdet':
            bboxes = result.boxes_all.view(-1, 4)[twod_inds].view(result.boxes_all.size(0), 4)
        else:
            # Boxes will get fixed by filter_dets function.
            bboxes = result.rm_box_priors

        rel_rep = F.softmax(result.rel_dists, dim=1)    # [275, 51]
        
        # sort product of obj1 * obj2 * rel
        return filter_dets(bboxes, result.obj_scores,
                           result.rm_obj_preds, rel_inds[:, 1:],
                           rel_rep)
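Example #17 builds its negative candidate pairs from all ordered box pairs that share an image, minus self-pairs (and, in sgdet, minus non-overlapping pairs). A minimal sketch of that candidate construction:

import torch

# Hypothetical: 5 boxes spread over 2 images.
im_inds = torch.tensor([0, 0, 0, 1, 1])

# Candidate matrix: a pair is valid only if both boxes come from the same image ...
rel_cands = im_inds[:, None] == im_inds[None]
# ... and a box cannot relate to itself (cf. diagonal_inds in the example).
rel_cands.fill_diagonal_(False)
# (The example additionally requires bbox_overlaps(boxes, boxes) > 0 when require_overlap is set.)

pairs = rel_cands.nonzero(as_tuple=False)                        # (num_pairs, 2)
# Prepend the image index of the subject box: (img_ind, box0_ind, box1_ind).
rel_inds_neg = torch.cat((im_inds[pairs[:, 0]][:, None], pairs), 1)
print(rel_inds_neg.shape)   # torch.Size([8, 3]): 3*2 + 2*1 ordered pairs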
Example #18
0
File: rel_model2.py  Project: ht014/lsbr
    def forward(self,
                x,
                im_sizes,
                image_offset,
                gt_boxes=None,
                gt_classes=None,
                gt_rels=None,
                proposals=None,
                train_anchor_inds=None,
                return_fmap=False):

        result = self.detector(x,
                               im_sizes,
                               image_offset,
                               gt_boxes,
                               gt_classes,
                               gt_rels,
                               proposals,
                               train_anchor_inds,
                               return_fmap=True)

        if result.is_none():
            return ValueError("heck")

        im_inds = result.im_inds - image_offset
        boxes = result.rm_box_priors

        if self.training and result.rel_labels is None:
            assert self.mode == 'sgdet'
            result.rel_labels = rel_assignments(im_inds.data,
                                                boxes.data,
                                                result.rm_obj_labels.data,
                                                gt_boxes.data,
                                                gt_classes.data,
                                                gt_rels.data,
                                                image_offset,
                                                filter_non_overlap=True,
                                                num_sample_per_gt=1)

        rel_inds = self.get_rel_inds(result.rel_labels, im_inds, boxes)

        rois = torch.cat((im_inds[:, None].float(), boxes), 1)

        global_feature = self.global_embedding(result.fmap.detach())
        result.global_dists = self.global_logist(global_feature)
        # print(result.global_dists)
        # result.global_rel_dists = F.sigmoid(self.global_rel_logist(global_feature))

        result.obj_fmap = self.obj_feature_map(result.fmap.detach(), rois)

        # Prevent gradients from flowing back into score_fc from elsewhere
        result.rm_obj_dists, result.obj_preds, node_rep0 = self.context(
            result.obj_fmap, result.rm_obj_dists.detach(), im_inds,
            result.rm_obj_labels if self.training or self.mode == 'predcls'
            else None, boxes.data, result.boxes_all)

        one_hot_multi = torch.zeros(
            (result.global_dists.shape[0], self.num_classes))

        one_hot_multi[im_inds, result.rm_obj_labels] = 1.0
        result.multi_hot = one_hot_multi.float().cuda()
        edge_rep = node_rep0.repeat(1, 2)

        edge_rep = edge_rep.view(edge_rep.size(0), 2, -1)
        global_feature_re = global_feature[im_inds]
        subj_global_additive_attention = F.relu(
            self.global_sub_additive(edge_rep[:, 0] + global_feature_re))
        obj_global_additive_attention = F.relu(
            torch.sigmoid(
                self.global_obj_additive(edge_rep[:, 1] + global_feature_re)))

        subj_rep = edge_rep[:,
                            0] + subj_global_additive_attention * global_feature_re
        obj_rep = edge_rep[:,
                           1] + obj_global_additive_attention * global_feature_re

        edge_of_coordinate_rep = self.coordinate_feats(boxes.data, rel_inds)

        e_ij_coordinate_rep = self.edge_coordinate_embedding(
            edge_of_coordinate_rep)

        union_rep = self.visual_rep(result.fmap.detach(), rois, rel_inds[:,
                                                                         1:])
        edge_feat_init = union_rep

        prod_rep = subj_rep[rel_inds[:, 1]] * obj_rep[
            rel_inds[:, 2]] * edge_feat_init
        prod_rep = torch.cat((prod_rep, e_ij_coordinate_rep), 1)

        if self.use_tanh:
            prod_rep = F.tanh(prod_rep)

        result.rel_dists = self.rel_compress(prod_rep)

        if self.use_bias:
            result.rel_dists = result.rel_dists + self.freq_bias.index_with_labels(
                torch.stack((
                    result.obj_preds[rel_inds[:, 1]],
                    result.obj_preds[rel_inds[:, 2]],
                ), 1))

        if self.training:
            return result

        twod_inds = arange(
            result.obj_preds.data) * self.num_classes + result.obj_preds.data
        result.obj_scores = F.softmax(result.rm_obj_dists,
                                      dim=1).view(-1)[twod_inds]

        # Bbox regression
        if self.mode == 'sgdet':
            bboxes = result.boxes_all.view(-1, 4)[twod_inds].view(
                result.boxes_all.size(0), 4)
        else:
            # Boxes will get fixed by filter_dets function.
            bboxes = result.rm_box_priors

        rel_rep = F.softmax(result.rel_dists, dim=1)

        return filter_dets(bboxes, result.obj_scores, result.obj_preds,
                           rel_inds[:, 1:], rel_rep)
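Examples #11 and #18 condition each node representation on an image-level feature through an additive gate: `gate = relu(W(node + global))`, then `node + gate * global`. A minimal sketch of that gating, with `gate_fc` standing in for the project's `global_sub_additive`/`global_obj_additive` layers:

import torch
import torch.nn as nn
import torch.nn.functional as F

# Hypothetical dims: 7 boxes over 2 images, 512-dim node and global features.
dim = 512
node_rep = torch.randn(7, dim)
global_feature = torch.randn(2, dim)                  # one vector per image
im_inds = torch.tensor([0, 0, 0, 0, 1, 1, 1])

gate_fc = nn.Linear(dim, dim)                         # stands in for global_sub_additive

g = global_feature[im_inds]                           # broadcast each image feature to its boxes
gate = F.relu(gate_fc(node_rep + g))                  # additive attention weights, one per feature
node_rep_global = node_rep + gate * g                 # residual, globally conditioned node rep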
Example #19
0
    def forward(self, x, im_sizes, image_offset,
                gt_boxes=None, gt_classes=None, gt_rels=None, proposals=None, train_anchor_inds=None,
                return_fmap=False):
        """
        Forward pass for detection
        :param x: Images@[batch_size, 3, IM_SIZE, IM_SIZE]
        :param im_sizes: A numpy array of (h, w, scale) for each image.
        :param image_offset: Offset onto what image we're on for MGPU training (if single GPU this is 0)
        :param gt_boxes:

        Training parameters:
        :param gt_boxes: [num_gt, 4] GT boxes over the batch.
        :param gt_classes: [num_gt, 2] gt boxes where each one is (img_id, class)
        :param train_anchor_inds: a [num_train, 2] array of indices for the anchors that will
                                  be used to compute the training loss. Each (img_ind, fpn_idx)
        :return: If train:
            scores, boxdeltas, labels, boxes, boxtargets, rpnscores, rpnboxes, rellabels
            
            if test:
            prob dists, boxes, img inds, maxscores, classes
            
        """
        result = self.detector(x, im_sizes, image_offset, gt_boxes, gt_classes, gt_rels, proposals,
                               train_anchor_inds, return_fmap=True)

        if result.is_none():
            return ValueError("heck")

        im_inds = result.im_inds - image_offset
        boxes = result.rm_box_priors

        if self.training and result.rel_labels is None:
            assert self.mode == 'sgdet'
            result.rel_labels, fg_rel_labels = rel_assignments(im_inds.data, boxes.data, result.rm_obj_labels.data,
                                                gt_boxes.data, gt_classes.data, gt_rels.data,
                                                image_offset, filter_non_overlap=True,
                                                num_sample_per_gt=1)

        #if self.training and (not self.use_rl_tree):
            # generate arbitrary forest according to graph
        #    arbitrary_forest = graph_to_trees(self.co_occour, result.rel_labels, gt_classes)
        #else:
        arbitrary_forest = None

        rel_inds = self.get_rel_inds(result.rel_labels, im_inds, boxes)

        if self.use_rl_tree:
            result.rel_label_tkh = self.get_rel_label(im_inds, gt_rels, rel_inds)

        rois = torch.cat((im_inds[:, None].float(), boxes), 1)

        result.obj_fmap = self.obj_feature_map(result.fmap.detach(), rois)

        # whole image feature, used for virtual node
        batch_size = result.fmap.shape[0]
        image_rois = Variable(torch.randn(batch_size, 5).fill_(0).cuda())
        for i in range(batch_size):
            image_rois[i, 0] = i
            image_rois[i, 1] = 0
            image_rois[i, 2] = 0
            image_rois[i, 3] = IM_SCALE
            image_rois[i, 4] = IM_SCALE
        image_fmap = self.obj_feature_map(result.fmap.detach(), image_rois)

        if self.mode != 'sgdet' and self.training:
            fg_rel_labels = result.rel_labels

        # Prevent gradients from flowing back into score_fc from elsewhere
        result.rm_obj_dists, result.obj_preds, edge_ctx, result.gen_tree_loss, result.entropy_loss, result.pair_gate, result.pair_gt = self.context(
            result.obj_fmap,
            result.rm_obj_dists.detach(),
            im_inds, result.rm_obj_labels if self.training or self.mode == 'predcls' else None,
            boxes.data, result.boxes_all, 
            arbitrary_forest,
            image_rois,
            image_fmap,
            self.co_occour,
            fg_rel_labels if self.training else None,
            x)

        if edge_ctx is None:
            edge_rep = self.post_emb(result.obj_preds)
        else:
            edge_rep = self.post_lstm(edge_ctx)

        # Split into subject and object representations
        edge_rep = edge_rep.view(edge_rep.size(0), 2, self.hidden_dim)

        subj_rep = edge_rep[:, 0]
        obj_rep = edge_rep[:, 1]

        prod_rep = torch.cat((subj_rep[rel_inds[:, 1]], obj_rep[rel_inds[:, 2]]), 1)
        prod_rep = self.post_cat(prod_rep)
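        # Note: this variant fuses subject and object context by concatenation followed by the
        # post_cat projection, rather than the element-wise product used in the later examples.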

        if self.use_encoded_box:
            # encode spatial info
            assert(boxes.shape[1] == 4)
            # encoded_boxes: [box_num, (x1,y1,x2,y2,cx,cy,w,h)]
            encoded_boxes = tree_utils.get_box_info(boxes)
            # encoded_boxes_pair: [batch_size, (box1, box2, unionbox, intersectionbox)]
            encoded_boxes_pair = tree_utils.get_box_pair_info(encoded_boxes[rel_inds[:, 1]], encoded_boxes[rel_inds[:, 2]])
            # encoded_spatial_rep
            spatial_rep = F.relu(self.encode_spatial_2(F.relu(self.encode_spatial_1(encoded_boxes_pair))))
            # element-wise multiply with prod_rep
            prod_rep = prod_rep * spatial_rep
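            # spatial_rep is shaped to match prod_rep, so this element-wise multiply gates the
            # pairwise feature with the relative-geometry encoding of the box pair.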

        if self.use_vision:
            vr = self.visual_rep(result.fmap.detach(), rois, rel_inds[:, 1:])
            if self.limit_vision:
                # exact value TBD
                prod_rep = torch.cat((prod_rep[:,:2048] * vr[:,:2048], prod_rep[:,2048:]), 1)
            else:
                prod_rep = prod_rep * vr

        if self.use_tanh:
            prod_rep = F.tanh(prod_rep)

        result.rel_dists = self.rel_compress(prod_rep)

        if self.use_bias:
            result.rel_dists = result.rel_dists + self.freq_bias.index_with_labels(torch.stack((
                result.obj_preds[rel_inds[:, 1]],
                result.obj_preds[rel_inds[:, 2]],
            ), 1))
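        # index_with_labels looks up a bias over predicates for each (subject class, object class)
        # pair, the empirical predicate statistics used as a frequency baseline in MotifNet-style
        # models, and adds it to the learned relation logits.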

        if self.training and (not self.rl_train):
            return result

        twod_inds = arange(result.obj_preds.data) * self.num_classes + result.obj_preds.data
        result.obj_scores = F.softmax(result.rm_obj_dists, dim=1).view(-1)[twod_inds]
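        # The flattened indexing above selects, for each box i, softmax(dists)[i, pred_i].
        # A sketch of an equivalent formulation, assuming obj_preds holds one class index per box:
        #     probs = F.softmax(result.rm_obj_dists, dim=1)
        #     result.obj_scores = probs.gather(1, result.obj_preds.view(-1, 1)).view(-1)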

        # Bbox regression
        if self.mode == 'sgdet':
            bboxes = result.boxes_all.view(-1, 4)[twod_inds].view(result.boxes_all.size(0), 4)
        else:
            # Boxes will get fixed by filter_dets function.
            bboxes = result.rm_box_priors

        rel_rep = F.softmax(result.rel_dists, dim=1)

        if not self.rl_train:
            return filter_dets(bboxes, result.obj_scores,
                           result.obj_preds, rel_inds[:, 1:], rel_rep, gt_boxes, gt_classes, gt_rels)
        else:
            return result, filter_dets(bboxes, result.obj_scores, result.obj_preds, rel_inds[:, 1:], rel_rep, gt_boxes, gt_classes, gt_rels)
Example #20
    def forward(self, x, im_sizes, image_offset,
                gt_boxes=None, gt_classes=None, gt_rels=None, proposals=None, train_anchor_inds=None,
                return_fmap=False):
        """
        Forward pass for detection
        :param x: Images@[batch_size, 3, IM_SIZE, IM_SIZE]
        :param im_sizes: A numpy array of (h, w, scale) for each image.
        :param image_offset: Offset onto what image we're on for MGPU training (if single GPU this is 0)

        Training parameters:
        :param gt_boxes: [num_gt, 4] GT boxes over the batch.
        :param gt_classes: [num_gt, 2] gt boxes where each one is (img_id, class)
        :param train_anchor_inds: a [num_train, 2] array of indices for the anchors that will
                                  be used to compute the training loss. Each (img_ind, fpn_idx)
        :return: If train:
            scores, boxdeltas, labels, boxes, boxtargets, rpnscores, rpnboxes, rellabels
            
            if test:
            prob dists, boxes, img inds, maxscores, classes
            
        """
        result = self.detector(x, im_sizes, image_offset, gt_boxes, gt_classes, gt_rels, proposals,
                               train_anchor_inds, return_fmap=True)
        if result.is_none():
            raise ValueError("Detector returned no result")

        im_inds = result.im_inds - image_offset
        boxes = result.rm_box_priors

        if self.training and result.rel_labels is None:
            assert self.mode == 'sgdet'
            result.rel_labels = rel_assignments(im_inds.data, boxes.data, result.rm_obj_labels.data,
                                                gt_boxes.data, gt_classes.data, gt_rels.data,
                                                image_offset, filter_non_overlap=True,
                                                num_sample_per_gt=1)

        rel_inds = self.get_rel_inds(result.rel_labels, im_inds, boxes)

        rois = torch.cat((im_inds[:, None].float(), boxes), 1)

        result.obj_fmap = self.obj_feature_map(result.fmap.detach(), rois)

        # Prevent gradients from flowing back into score_fc from elsewhere
        result.rm_obj_dists, result.obj_preds, edge_ctx = self.context(
            result.obj_fmap,
            result.rm_obj_dists.detach(),
            im_inds, result.rm_obj_labels if self.training or self.mode == 'predcls' else None,
            boxes.data, result.boxes_all)

        if edge_ctx is None:
            edge_rep = self.post_emb(result.obj_preds)
        else:
            edge_rep = self.post_lstm(edge_ctx)

        # Split into subject and object representations
        edge_rep = edge_rep.view(edge_rep.size(0), 2, self.pooling_dim)

        subj_rep = edge_rep[:, 0]
        obj_rep = edge_rep[:, 1]

        prod_rep = subj_rep[rel_inds[:, 1]] * obj_rep[rel_inds[:, 2]]

        if self.use_vision:
            vr = self.visual_rep(result.fmap.detach(), rois, rel_inds[:, 1:])
            if self.limit_vision:
                # exact value TBD
                prod_rep = torch.cat((prod_rep[:,:2048] * vr[:,:2048], prod_rep[:,2048:]), 1)
            else:
                prod_rep = prod_rep * vr

        if self.use_tanh:
            prod_rep = F.tanh(prod_rep)

        result.rel_dists = self.rel_compress(prod_rep)

        if self.use_bias:
            result.rel_dists = result.rel_dists + self.freq_bias.index_with_labels(torch.stack((
                result.obj_preds[rel_inds[:, 1]],
                result.obj_preds[rel_inds[:, 2]],
            ), 1))

        if self.training:
            return result

        twod_inds = arange(result.obj_preds.data) * self.num_classes + result.obj_preds.data
        result.obj_scores = F.softmax(result.rm_obj_dists, dim=1).view(-1)[twod_inds]

        # Bbox regression
        if self.mode == 'sgdet':
            bboxes = result.boxes_all.view(-1, 4)[twod_inds].view(result.boxes_all.size(0), 4)
        else:
            # Boxes will get fixed by filter_dets function.
            bboxes = result.rm_box_priors

        rel_rep = F.softmax(result.rel_dists, dim=1)
        return filter_dets(bboxes, result.obj_scores,
                           result.obj_preds, rel_inds[:, 1:], rel_rep)
Example #21
    def forward(self, x, im_sizes, image_offset,
                gt_boxes=None, gt_classes=None, gt_rels=None, proposals=None, train_anchor_inds=None,
                return_fmap=False):
        """
        Forward pass for detection
        :param x: Images@[batch_size, 3, IM_SIZE, IM_SIZE]
        :param im_sizes: A numpy array of (h, w, scale) for each image.
        :param image_offset: Offset onto what image we're on for MGPU training (if single GPU this is 0)

        Training parameters:
        :param gt_boxes: [num_gt, 4] GT boxes over the batch.
        :param gt_classes: [num_gt, 2] gt boxes where each one is (img_id, class)
        :param train_anchor_inds: a [num_train, 2] array of indices for the anchors that will
                                  be used to compute the training loss. Each (img_ind, fpn_idx)
        :return: If train:
            scores, boxdeltas, labels, boxes, boxtargets, rpnscores, rpnboxes, rellabels
            
            if test:
            prob dists, boxes, img inds, maxscores, classes
            
        """

        # Detector
        result = self.detector(x, im_sizes, image_offset, gt_boxes, gt_classes, gt_rels, proposals,
                               train_anchor_inds, return_fmap=True)
        if result.is_none():
            raise ValueError("Detector returned no result")
        im_inds = result.im_inds - image_offset
        # boxes: [#boxes, 4], box priors without deltas; detached so no gradients flow
        # back through them (this tensor was the source of the earlier `narrow` error).
        boxes = result.rm_box_priors.detach()

        if self.training and result.rel_labels is None:
            assert self.mode == 'sgdet' # sgcls's result.rel_labels is gt and not None
            result.rel_labels = rel_assignments(im_inds.data, boxes.data, result.rm_obj_labels.data,
                                                gt_boxes.data, gt_classes.data, gt_rels.data,
                                                image_offset, filter_non_overlap=True,
                                                num_sample_per_gt=1)
            rel_labels_neg = self.get_neg_examples(result.rel_labels)
            rel_inds_neg = rel_labels_neg[:,:3]
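            # get_neg_examples presumably samples corrupted (negative) subject/object pairs for
            # each positive relation; rel_inds_neg keeps their (img_ind, subj_ind, obj_ind)
            # columns for the contrastive terms built further below.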

        rel_inds = self.get_rel_inds(result.rel_labels, im_inds, boxes)  #[275,3], [im_inds, box1_inds, box2_inds]
        
        # rois: [#boxes, 5]
        rois = torch.cat((im_inds[:, None].float(), boxes), 1)
        # result.rm_obj_fmap: [384, 4096]
        #result.rm_obj_fmap = self.obj_feature_map(result.fmap.detach(), rois) # detach: prevent backward flow
        result.rm_obj_fmap = self.obj_feature_map(result.fmap.detach(), rois.detach()) # detach: prevent backward flow

        ############### Box Loss in BiLSTM ################
        #result.lstm_box_deltas = self.bbox_fc(result.rm_obj_fmap).view(-1, len(self.classes), 4)
        ############### Box Loss in BiLSTM ################


        # BiLSTM
        result.rm_obj_dists, result.rm_obj_preds, edge_ctx = self.context(
            result.rm_obj_fmap,   # has been detached above
            # rm_obj_dists: [#boxes, 151]; Prevent gradients from flowing back into score_fc from elsewhere
            result.rm_obj_dists.detach(),  # .detach(): returns a new Variable, detached from the current graph
            im_inds, result.rm_obj_labels if self.training or self.mode == 'predcls' else None,
            boxes.data, result.boxes_all.detach() if self.mode == 'sgdet' else result.boxes_all)
        

        # Post Processing
        # nl_egde <= 0
        if edge_ctx is None:
            edge_rep = self.post_emb(result.rm_obj_preds)
        # nl_edge > 0
        else: 
            edge_rep = self.post_lstm(edge_ctx)  # [384, 4096*2]
     
        # Split into subject and object representations
        edge_rep = edge_rep.view(edge_rep.size(0), 2, self.pooling_dim)  #[384,2,4096]
        subj_rep = edge_rep[:, 0]  # [384,4096]
        obj_rep = edge_rep[:, 1]  # [384,4096]
        prod_rep = subj_rep[rel_inds[:, 1]] * obj_rep[rel_inds[:, 2]]  # prod_rep, rel_inds: [275,4096], [275,3]
    
        obj1 = self.obj1_fc(subj_rep[rel_inds[:, 1]])
        obj1 = obj1.view(obj1.size(0), self.num_classes, self.embdim)  # (275, 151, 10)
        obj2 = self.obj2_fc(obj_rep[rel_inds[:, 2]])
        obj2 = obj2.view(obj2.size(0), self.num_classes, self.embdim)  # (275, 151, 10)

        if self.training:
            prod_rep_neg = subj_rep[rel_inds_neg[:, 1]] * obj_rep[rel_inds_neg[:, 2]]
            obj1_neg = self.obj1_fc(subj_rep[rel_inds_neg[:, 1]])
            obj1_neg = obj1_neg.view(obj1_neg.size(0), self.num_classes, self.embdim)  # (275*self.neg_num, 151, 10)
            obj2_neg = self.obj2_fc(obj_rep[rel_inds_neg[:, 2]])
            obj2_neg = obj2_neg.view(obj2_neg.size(0), self.num_classes, self.embdim)  # (275*self.neg_num, 151, 10)

        if self.use_vision: # True when sgdet
            # union rois: fmap.detach--RoIAlignFunction--roifmap--vr [275,4096]
            vr = self.visual_rep(result.fmap.detach(), rois, rel_inds[:, 1:])
            #vr = self.visual_rep(result.fmap.detach(), im_bboxes.detach(), rel_inds[:, 1:])

            if self.limit_vision:  # False when sgdet
                # exact value TBD
                prod_rep = torch.cat((prod_rep[:,:2048] * vr[:,:2048], prod_rep[:,2048:]), 1) 
            else:
                prod_rep = prod_rep * vr  # [275,4096]
                rel_emb = self.rel_seq(prod_rep)
                rel_emb = rel_emb.view(rel_emb.size(0), self.num_rels, self.embdim)  # (275, 51, 10)
                if self.training:
                    vr_neg = self.visual_rep(result.fmap.detach(), rois, rel_inds_neg[:, 1:])
                    prod_rep_neg = prod_rep_neg * vr_neg  # [275*self.neg_num, 4096]
                    rel_emb_neg = self.rel_seq(prod_rep_neg)
                    

        if self.use_tanh:  # False when sgdet
            prod_rep = F.tanh(prod_rep)

        result.rel_dists = self.rel_compress(prod_rep)  # [275,51]

        if self.use_bias:  # True when sgdet
            result.rel_dists = result.rel_dists + self.freq_bias.index_with_labels(torch.stack((
                result.rm_obj_preds[rel_inds[:, 1]],
                result.rm_obj_preds[rel_inds[:, 2]],
            ), 1))


        if self.training:
            # pos_exp: [275, 100] * self.neg_num
            twod_inds1 = arange(rel_inds[:, 1]) * self.num_classes + result.rm_obj_preds.data[rel_inds[:, 1]]
            twod_inds2 = arange(rel_inds[:, 2]) * self.num_classes + result.rm_obj_preds.data[rel_inds[:, 2]]
            rel_type = result.rel_labels[:, 3].data # [275]
            twod_inds_r = arange(rel_type) * self.num_rels + rel_type
            
            twod_inds1 = twod_inds1[:, None].expand(twod_inds1.size(0), self.neg_num).contiguous().view(-1)
            twod_inds2 = twod_inds2[:, None].expand(twod_inds2.size(0), self.neg_num).contiguous().view(-1)
            twod_inds_r = twod_inds_r[:, None].expand(twod_inds_r.size(0), self.neg_num).contiguous().view(-1)
            result.pos = obj1.view(-1,self.embdim)[twod_inds1] + rel_emb.view(-1,self.embdim)[twod_inds_r] - obj2.view(-1,self.embdim)[twod_inds2]
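            # The repetition above duplicates each positive index self.neg_num times so that
            # result.pos lines up row-for-row with the self.neg_num negatives sampled per relation.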

            # neg_exp: [275 * self.neg_num, 100]
            twod_inds1_neg = arange(rel_inds_neg[:, 1]) * self.num_classes + result.rm_obj_preds.data[rel_inds_neg[:, 1]]
            twod_inds2_neg = arange(rel_inds_neg[:, 2]) * self.num_classes + result.rm_obj_preds.data[rel_inds_neg[:, 2]]
            rel_type_neg = rel_labels_neg[:, 3]  # [275 * neg_num]
            twod_inds_r_neg = arange(rel_type_neg) * self.num_rels + rel_type_neg
            result.neg = obj1_neg.view(-1,self.embdim)[twod_inds1_neg] + rel_emb_neg.view(-1,self.embdim)[twod_inds_r_neg] - obj2_neg.view(-1,self.embdim)[twod_inds2_neg]
            
            result.anchor = Variable(torch.zeros(result.pos.size(0), self.embdim).cuda())
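            # result.pos / result.neg are translation residuals (subject + predicate - object) for
            # matched and corrupted triples; the zero anchor suggests a TransE-style margin
            # objective computed outside this forward (an interpretation, not stated in this code).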
            return result
        
        ###################### Testing ###########################

        # extract each box's score corresponding to its predicted class
        twod_inds = arange(result.rm_obj_preds.data) * self.num_classes + result.rm_obj_preds.data
        result.obj_scores = F.softmax(result.rm_obj_dists, dim=1).view(-1)[twod_inds]   # [384]

        # Bbox regression
        if self.mode == 'sgdet':
            bboxes = result.boxes_all.view(-1, 4)[twod_inds].view(result.boxes_all.size(0), 4)
        else:
            # Boxes will get fixed by filter_dets function.
            bboxes = result.rm_box_priors

        rel_rep = F.softmax(result.rel_dists, dim=1)    # [275, 51]

        # sort product of obj1 * obj2 * rel
        return filter_dets(bboxes, result.obj_scores,
                           result.rm_obj_preds, rel_inds[:, 1:],
                           rel_rep, obj1, obj2, rel_emb)
    def forward(self,
                x,
                im_sizes,
                image_offset,
                gt_boxes=None,
                gt_classes=None,
                gt_rels=None,
                proposals=None,
                train_anchor_inds=None,
                return_fmap=False,
                depth_imgs=None):
        """
        Forward pass for relation detection
        :param x: Images@[batch_size, 3, IM_SIZE, IM_SIZE]
        :param im_sizes: a numpy array of (h, w, scale) for each image.
        :param image_offset: offset onto what image we're on for MGPU training (if single GPU this is 0)
        :param gt_boxes: [num_gt, 4] GT boxes over the batch.
        :param gt_classes: [num_gt, 2] gt boxes where each one is (img_id, class)
        :param gt_rels: [] gt relations
        :param proposals: region proposals retrieved from file
        :param train_anchor_inds: a [num_train, 2] array of indices for the anchors that will
                                  be used to compute the training loss. Each (img_ind, fpn_idx)
        :param return_fmap: if the object detector must return the extracted feature maps
        :param depth_imgs: depth images [batch_size, 1, IM_SIZE, IM_SIZE]
        :return: If train:
            scores, boxdeltas, labels, boxes, boxtargets, rpnscores, rpnboxes, rellabels
            
            if test:
            prob dists, boxes, img inds, maxscores, classes
            
        """

        if self.has_visual:
            # -- Feed forward the rgb images to Faster-RCNN
            result = self.detector(x,
                                   im_sizes,
                                   image_offset,
                                   gt_boxes,
                                   gt_classes,
                                   gt_rels,
                                   proposals,
                                   train_anchor_inds,
                                   return_fmap=True)
        else:
            # -- Get prior `result` object (instead of calling faster-rcnn's detector)
            result = self.get_prior_results(image_offset, gt_boxes, gt_classes,
                                            gt_rels)

        # -- Get RoI and relations
        rois, rel_inds = self.get_rois_and_rels(result, image_offset, gt_boxes,
                                                gt_classes, gt_rels)
        boxes = result.rm_box_priors

        # -- Determine subject and object indices
        subj_inds = rel_inds[:, 1]
        obj_inds = rel_inds[:, 2]

        # -- Prepare object predictions vector (PredCLS)
        # replace with ground truth labels
        result.obj_preds = result.rm_obj_labels
        # replace with one-hot distribution of ground truth labels
        result.rm_obj_dists = F.one_hot(result.rm_obj_labels.data,
                                        self.num_classes).float()
        obj_cls = result.rm_obj_dists
        result.rm_obj_dists = result.rm_obj_dists * 1000 + (
            1 - result.rm_obj_dists) * (-1000)
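        # Scaling the one-hot vector to +1000 / -1000 yields logits whose softmax is numerically
        # a delta on the ground-truth class, so any downstream softmax over rm_obj_dists recovers
        # the GT label with score ~1 (the PredCLS protocol).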

        rel_features = []
        # -- Extract RGB features
        if self.has_visual:
            # Feed the extracted features from first conv layers to the last 'classifier' layers (VGG)
            # Here, only the last 3 layers of VGG are being trained. Everything else (in self.detector)
            # is frozen.
            result.obj_fmap = self.get_roi_features(result.fmap.detach(), rois)

            # -- Create a pairwise relation vector out of visual features
            rel_visual = torch.cat(
                (result.obj_fmap[subj_inds], result.obj_fmap[obj_inds]), 1)
            rel_visual_fc = self.visual_hlayer(rel_visual)
            rel_visual_scale = self.visual_scale(rel_visual_fc)
            rel_features.append(rel_visual_scale)

        # -- Extract Location features
        if self.has_loc:
            # -- Create a pairwise relation vector out of location features
            rel_location = self.get_loc_features(boxes, subj_inds, obj_inds)
            rel_location_fc = self.location_hlayer(rel_location)
            rel_location_scale = self.location_scale(rel_location_fc)
            rel_features.append(rel_location_scale)

        # -- Extract Class features
        if self.has_class:
            if self.use_embed:
                obj_cls = obj_cls @ self.obj_embed.weight
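                # With the one-hot obj_cls built above, this lookup reduces to selecting the
                # embedding row of each box's ground-truth class.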
            # -- Create a pairwise relation vector out of class features
            rel_classme = torch.cat((obj_cls[subj_inds], obj_cls[obj_inds]), 1)
            rel_classme_fc = self.classme_hlayer(rel_classme)
            rel_classme_scale = self.classme_scale(rel_classme_fc)
            rel_features.append(rel_classme_scale)

        # -- Extract Depth features
        if self.has_depth:
            # -- Extract features from depth backbone
            depth_features = self.depth_backbone(depth_imgs)
            depth_rois_features = self.get_roi_features_depth(
                depth_features, rois)

            # -- Create a pairwise relation vector out of location features
            rel_depth = torch.cat((depth_rois_features[subj_inds],
                                   depth_rois_features[obj_inds]), 1)
            rel_depth_fc = self.depth_rel_hlayer(rel_depth)
            rel_depth_scale = self.depth_scale(rel_depth_fc)
            rel_features.append(rel_depth_scale)

        # -- Create concatenated feature vector
        rel_fusion = torch.cat(rel_features, 1)

        # -- Extract relation embeddings (penultimate layer)
        rel_embeddings = self.fusion_hlayer(rel_fusion)

        # -- Mix relation embeddings with UoBB features
        if self.has_visual and self.use_vision:
            uobb_features = self.get_union_features(result.fmap.detach(), rois,
                                                    rel_inds[:, 1:])
            if self.limit_vision:
                # exact value TBD
                uobb_limit = int(self.hidden_dim / 2)
                rel_embeddings = torch.cat((rel_embeddings[:, :uobb_limit] *
                                            uobb_features[:, :uobb_limit],
                                            rel_embeddings[:, uobb_limit:]), 1)
            else:
                rel_embeddings = rel_embeddings * uobb_features

        # -- Predict relation distances
        result.rel_dists = self.rel_out(rel_embeddings)

        # -- Frequency bias
        if self.use_bias:
            result.rel_dists = result.rel_dists + self.freq_bias.index_with_labels(
                torch.stack((
                    result.obj_preds[rel_inds[:, 1]],
                    result.obj_preds[rel_inds[:, 2]],
                ), 1))

        if self.training:
            return result

        # --- *** END OF ARCHITECTURE *** ---#

        twod_inds = arange(
            result.obj_preds.data) * self.num_classes + result.obj_preds.data
        result.obj_scores = F.softmax(result.rm_obj_dists,
                                      dim=1).view(-1)[twod_inds]

        # Bbox regression
        if self.mode == 'sgdet':
            bboxes = result.boxes_all.view(-1, 4)[twod_inds].view(
                result.boxes_all.size(0), 4)
        else:
            # Boxes will get fixed by filter_dets function.
            bboxes = result.rm_box_priors

        rel_rep = F.softmax(result.rel_dists, dim=1)
        # Filtering: Subject_Score * Pred_score * Obj_score, sorted and ranked
        return filter_dets(bboxes, result.obj_scores, result.obj_preds,
                           rel_inds[:, 1:], rel_rep)
    def forward(self,
                x,
                im_sizes,
                image_offset,
                gt_boxes=None,
                gt_masks=None,
                gt_classes=None,
                gt_rels=None,
                pred_boxes=None,
                pred_masks=None,
                pred_fmaps=None,
                pred_dists=None):
        """
        Forward pass for detection
        :param x: Images@[batch_size, 3, IM_SIZE, IM_SIZE]
        :param im_sizes: A numpy array of (h, w, scale) for each image.
        :param image_offset: Offset onto what image we're on for MGPU training (if single GPU this is 0)

        Training parameters:
        :param gt_boxes: [num_gt, 4] GT boxes over the batch.
        :param gt_classes: [num_gt, 2] gt boxes where each one is (img_id, class)
        :param train_anchor_inds: a [num_train, 2] array of indices for the anchors that will
                                  be used to compute the training loss. Each (img_ind, fpn_idx)
        :return: If train:
            scores, boxdeltas, labels, boxes, boxtargets, rpnscores, rpnboxes, rellabels

            if test:
            prob dists, boxes, img inds, maxscores, classes

        pred_fmaps  N*256*14*14
        pred_boxes  N*4
        pred_masks  N*28*28
        pred_dists  N*85

        """
        #print(pred_fmaps.shape, pred_boxes.shape, pred_masks.shape, pred_dists.shape)

        if self.training:
            im_inds = gt_classes[:, 0]
            rois = torch.cat((im_inds.float()[:, None], gt_boxes), 1)
            # actually is rel_assignment for sgcls
            # assign the GT relation labels; the RoIs themselves are unchanged
            rois, labels, rel_labels = proposal_assignments_gtbox(
                rois.data, gt_boxes.data, gt_classes.data, gt_rels.data,
                image_offset)
            #boxes = rois[:, 1:]
            pred_boxes = rois[:, 1:]
            pred_masks = gt_masks
            pred_dists = Variable(to_onehot(labels.data, self.num_classes))
        else:
            im_inds = pred_boxes[:, 0].long()
            pred_boxes = pred_boxes[:, 1:]
            labels = gt_classes[:, 1]
            rel_labels = None
            pred_dists = Variable(
                to_onehot(pred_dists.data.long(), self.num_classes))
            rois = torch.cat((im_inds[:, None].float(), pred_boxes), 1)
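        # Training uses the GT boxes/masks as detections (sgcls-style assignment via
        # proposal_assignments_gtbox); at test time the externally supplied pred_* tensors play
        # that role, with pred_dists one-hot encoded so both paths feed the same pipeline below.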

        result = Result()
        #pred_fmaps = pred_fmaps * self.downsample(pred_masks[:, None, :, :])
        #result.obj_fmap = self.roi_fmap_obj(pred_fmaps.view(len(pred_fmaps), -1))
        result.obj_fmap = self.obj_feature_map(pred_fmaps, rois)
        result.rm_obj_dists = pred_dists
        result.rm_obj_labels = labels
        result.rel_labels = rel_labels
        #result.boxes_all = None
        rel_inds = self.get_rel_inds(result.rel_labels, im_inds, pred_boxes)
        #rois = torch.cat((im_inds[:, None].float(), boxes), 1)

        # result.obj_fmap = self.obj_feature_map(result.fmap, rois)
        #  print(pred_fmaps[0][0][0])
        #  print(result.rm_obj_labels[0])
        #  print(result.rm_obj_dists[0][:10])
        #  print(pred_boxes.data[[0]])
        # Prevent gradients from flowing back into score_fc from elsewhere
        result.rm_obj_dists, result.obj_preds, edge_ctx = self.context(
            result.obj_fmap, result.rm_obj_dists, im_inds, result.rm_obj_labels
            if self.training or self.mode == 'predcls' else None,
            pred_boxes.data, None)

        if edge_ctx is None:
            edge_rep = self.post_emb(result.obj_preds)
        else:
            edge_rep = self.post_lstm(edge_ctx)

        # Split into subject and object representations
        edge_rep = edge_rep.view(edge_rep.size(0), 2, self.pooling_dim)

        subj_rep = edge_rep[:, 0]
        obj_rep = edge_rep[:, 1]

        prod_rep = subj_rep[rel_inds[:, 1]] * obj_rep[rel_inds[:, 2]]
        vr = self.visual_rep(pred_fmaps, rois, rel_inds[:, 1:])
        prod_rep = prod_rep * vr
        # if self.use_vision:
        #     vr = self.visual_rep(pred_fmaps, rois, rel_inds[:, 1:])
        #     if self.limit_vision:
        #         # exact value TBD
        #         prod_rep = torch.cat((prod_rep[:, :2048] * vr[:, :2048], prod_rep[:, 2048:]), 1)
        #     else:
        #         prod_rep = prod_rep * vr

        if self.use_tanh:
            prod_rep = F.tanh(prod_rep)

        result.rel_dists = self.rel_compress(prod_rep)

        if self.use_bias:
            result.rel_dists = result.rel_dists + self.freq_bias.index_with_labels(
                torch.stack((
                    result.obj_preds[rel_inds[:, 1]],
                    result.obj_preds[rel_inds[:, 2]],
                ), 1))

        if self.training:
            return result

        twod_inds = arange(
            result.obj_preds.data) * self.num_classes + result.obj_preds.data
        result.obj_scores = F.softmax(result.rm_obj_dists,
                                      dim=1).view(-1)[twod_inds]

        # # Bbox regression
        # if self.mode == 'sgdet':
        #     bboxes = result.boxes_all.view(-1, 4)[twod_inds].view(result.boxes_all.size(0), 4)
        # else:
        #     # Boxes will get fixed by filter_dets function.
        #     bboxes = result.rm_box_priors

        rel_rep = F.softmax(result.rel_dists, dim=1)
        return filter_dets_mask(pred_boxes, pred_masks, result.obj_scores,
                                result.obj_preds, rel_inds[:, 1:], rel_rep)