def get_loc_features(boxes, subj_inds, obj_inds):
    """
    Calculate the scale-invariant location feature
    :param boxes: ground-truth/detected boxes
    :param subj_inds: subject indices
    :param obj_inds: object indices
    :return: location_feature
    """
    boxes_centered = center_size(boxes.data)

    # -- Determine box's center and size (subj's box)
    center_subj = boxes_centered[subj_inds][:, 0:2]
    size_subj = boxes_centered[subj_inds][:, 2:4]

    # -- Determine box's center and size (obj's box)
    center_obj = boxes_centered[obj_inds][:, 0:2]
    size_obj = boxes_centered[obj_inds][:, 2:4]

    # -- Calculate the scale-invariant location features of the subject
    t_coord_subj = (center_subj - center_obj) / size_obj
    t_size_subj = torch.log(size_subj / size_obj)

    # -- Calculate the scale-invariant location features of the object
    t_coord_obj = (center_obj - center_subj) / size_subj
    t_size_obj = torch.log(size_obj / size_subj)

    # -- Put everything together
    location_feature = Variable(
        torch.cat((t_coord_subj, t_size_subj, t_coord_obj, t_size_obj), 1))
    return location_feature
def get_box_info(boxes):
    """
    input: [batch_size, (x1,y1,x2,y2)]
    output: [batch_size, (x1,y1,x2,y2,cx,cy,w,h)]
    """
    return torch.cat(
        (boxes / float(IM_SCALE), center_size(boxes) / float(IM_SCALE)), 1)
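
All of these snippets lean on center_size / point_form style conversions between corner-form and center-size boxes. Below is a minimal, self-contained sketch of what such helpers typically look like, written from how they are called here rather than copied from any repo (the +1 width/height convention is an assumption):

import torch

def center_size_sketch(boxes):
    """(x1, y1, x2, y2) -> (cx, cy, w, h); assumes corner-form input."""
    wh = boxes[:, 2:] - boxes[:, :2] + 1.0
    return torch.cat((boxes[:, :2] + 0.5 * (wh - 1.0), wh), 1)

def point_form_sketch(boxes):
    """(cx, cy, w, h) -> (x1, y1, x2, y2); inverse of the helper above."""
    half = 0.5 * (boxes[:, 2:] - 1.0)
    return torch.cat((boxes[:, :2] - half, boxes[:, :2] + half), 1)

# Quick round-trip check on two dummy boxes.
dummy = torch.tensor([[10., 20., 50., 80.], [0., 0., 31., 31.]])
assert torch.allclose(point_form_sketch(center_size_sketch(dummy)), dummy)
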
Example #3
    def coordinate_feats(self, boxes, rel_inds):
        coordinate_rep = {}
        coordinate_rep['center'] = center_size(boxes)
        coordinate_rep['point'] = torch.cat(
            (boxes, coordinate_rep['center'][:, 2:]), 1)
        sub_coordinate = {}
        sub_coordinate['center'] = coordinate_rep['center'][rel_inds[:, 1]]
        sub_coordinate['point'] = coordinate_rep['point'][rel_inds[:, 1]]

        obj_coordinate = {}
        obj_coordinate['center'] = coordinate_rep['center'][rel_inds[:, 2]]
        obj_coordinate['point'] = coordinate_rep['point'][rel_inds[:, 2]]
        edge_of_coordinate_rep = torch.zeros(sub_coordinate['center'].size(0),
                                             5).cuda().float()
        edge_of_coordinate_rep[:, 0] = (sub_coordinate['point'][:, 0] - obj_coordinate['center'][:, 0]) * 1.0 / \
                                       obj_coordinate['center'][:, 2]
        edge_of_coordinate_rep[:, 1] = (sub_coordinate['point'][:, 1] - obj_coordinate['center'][:, 1]) * 1.0 / \
                                       obj_coordinate['center'][:, 3]
        edge_of_coordinate_rep[:, 2] = (sub_coordinate['point'][:, 2] - obj_coordinate['center'][:, 0]) * 1.0 / \
                                       obj_coordinate['center'][:, 2]
        edge_of_coordinate_rep[:, 3] = (sub_coordinate['point'][:, 3] - obj_coordinate['center'][:, 1]) * 1.0 / \
                                       obj_coordinate['center'][:, 3]
        edge_of_coordinate_rep[:, 4] = sub_coordinate['point'][:, 4] * sub_coordinate['point'][:, 5] * 1.0 / \
                                       obj_coordinate['center'][:, 2] \
                                       / obj_coordinate['center'][:, 3]
        return edge_of_coordinate_rep
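
For reference, a small CPU-only sketch of the 5-dimensional edge representation built above, for a single subject/object pair. The helper name and inputs are illustrative, and the original code runs on CUDA tensors:

import torch

def edge_spatial_rep(sub_point, obj_center):
    """sub_point: (x1, y1, x2, y2, w, h); obj_center: (cx, cy, w, h)."""
    ox, oy, ow, oh = obj_center
    return torch.stack([
        (sub_point[0] - ox) / ow,                 # left edge offset, scaled by object width
        (sub_point[1] - oy) / oh,                 # top edge offset, scaled by object height
        (sub_point[2] - ox) / ow,                 # right edge offset
        (sub_point[3] - oy) / oh,                 # bottom edge offset
        sub_point[4] * sub_point[5] / (ow * oh),  # subject/object area ratio
    ])

sub = torch.tensor([10., 10., 50., 30., 41., 21.])   # x1, y1, x2, y2, w, h
obj = torch.tensor([40., 20., 61., 41.])             # cx, cy, w, h
print(edge_spatial_rep(sub, obj))
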
Example #4
    def fuse_message(self, union_box_feats, boxes, box_classes, rel_inds):
        """Fuse union Appearance features, box spatial information and NLP word features together
        Args:
            union_box_feats: Variable
            boxes: Variable
            box_classes: Variable
            rel_inds: Variable
        Returns:
            box_pair_feats:
        """
        bboxes = Variable(center_size(boxes.data))
        sub_bboxes = bboxes[rel_inds[:, 1].contiguous()]
        obj_bboxes = bboxes[rel_inds[:, 2].contiguous()]

        # Encode the object box relative to the subject box
        obj_bboxes[:, :2] = obj_bboxes[:, :2].contiguous() - sub_bboxes[:, :2].contiguous()   # (dx, dy)
        obj_bboxes[:, 2:] = obj_bboxes[:, 2:].contiguous() / sub_bboxes[:, 2:].contiguous()   # (w_o/w_s, h_o/h_s)
        obj_bboxes[:, :2] /= sub_bboxes[:, 2:].contiguous()                                   # (dx/w_s, dy/h_s)
        obj_bboxes[:, 2:] = torch.log(obj_bboxes[:, 2:].contiguous())                         # log size ratios

        bbox_spatial_feats = self.spatial_fc(obj_bboxes)

        box_word = self.classes_word_embedding(box_classes)
        box_pair_word = torch.cat((box_word[rel_inds[:, 1].contiguous()],
                                   box_word[rel_inds[:, 2].contiguous()]), 1)
        box_word_feats = self.word_fc(box_pair_word)

        # box_pair_feats: (NumOfRels, feature dim)
        box_pair_feats = torch.cat(
            (union_box_feats, bbox_spatial_feats, box_word_feats), 1)
        return box_pair_feats
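
The word-feature branch above concatenates the class embeddings of a relation's subject and object before a fully connected layer. A rough standalone sketch of that step; the layer sizes and dummy values are made up for illustration:

import torch
import torch.nn as nn

num_classes, embed_dim = 151, 200
classes_word_embedding = nn.Embedding(num_classes, embed_dim)
word_fc = nn.Linear(2 * embed_dim, 512)

box_classes = torch.randint(1, num_classes, (8,))        # predicted class per box
rel_inds = torch.tensor([[0, 1, 2], [0, 3, 4]])          # (img_ind, subj_ind, obj_ind)

box_word = classes_word_embedding(box_classes)            # (num_boxes, embed_dim)
box_pair_word = torch.cat((box_word[rel_inds[:, 1]],
                           box_word[rel_inds[:, 2]]), 1)  # (num_rels, 2 * embed_dim)
box_word_feats = word_fc(box_pair_word)                    # (num_rels, 512)
print(box_word_feats.shape)
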
Example #5
    def forward(self,
                obj_fmaps,
                obj_logits,
                im_inds,
                obj_labels=None,
                box_priors=None,
                boxes_per_cls=None):
        """
        Forward pass through the object and edge context
        :param obj_priors:
        :param obj_fmaps:
        :param im_inds:
        :param obj_labels:
        :param boxes:
        :return:
        """

        obj_embed = F.softmax(obj_logits, dim=1) @ self.obj_embed.weight

        pos_embed = self.pos_embed(center_size(box_priors))
        # obj_pre_rep = self.conver_fusion_feature(torch.cat((obj_fmaps, obj_embed, pos_embed), 1))
        obj_pre_rep = self.conver_fusion_feature(
            torch.cat((obj_embed, pos_embed), 1))
        # UNSURE WHAT TO DO HERE
        if self.mode == 'predcls':
            obj_dists2 = Variable(to_onehot(obj_labels.data, self.num_classes))
        else:
            obj_dists2 = self.decoder_lin(obj_pre_rep)

        if self.mode == 'sgdet' and not self.training:
            # NMS here for baseline
            probs = F.softmax(obj_dists2, 1)
            nms_mask = obj_dists2.data.clone()
            nms_mask.zero_()
            for c_i in range(1, obj_dists2.size(1)):
                scores_ci = probs.data[:, c_i]
                boxes_ci = boxes_per_cls.data[:, c_i]

                keep = apply_nms(scores_ci,
                                 boxes_ci,
                                 pre_nms_topn=scores_ci.size(0),
                                 post_nms_topn=scores_ci.size(0),
                                 nms_thresh=0.3)
                nms_mask[:, c_i][keep] = 1

            obj_preds = Variable(nms_mask * probs.data,
                                 volatile=True)[:, 1:].max(1)[1] + 1
        else:
            obj_preds = obj_labels if obj_labels is not None else obj_dists2[:, 1:].max(
                1)[1] + 1

        return obj_dists2, obj_preds, obj_pre_rep
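
A rough sketch of the per-class NMS masking used in the sgdet test branch above, with torchvision.ops.nms standing in for the repo's apply_nms; the class count and the dummy boxes are illustrative:

import torch
import torch.nn.functional as F
from torchvision.ops import nms

def nms_masked_preds(obj_dists, boxes_per_cls, nms_thresh=0.3):
    """obj_dists: (num_boxes, num_classes); boxes_per_cls: (num_boxes, num_classes, 4)."""
    probs = F.softmax(obj_dists, dim=1)
    nms_mask = torch.zeros_like(probs)
    for c_i in range(1, probs.size(1)):                 # skip background class 0
        keep = nms(boxes_per_cls[:, c_i], probs[:, c_i], nms_thresh)
        nms_mask[keep, c_i] = 1
    # The highest surviving non-background score decides the predicted class.
    return (nms_mask * probs)[:, 1:].max(1)[1] + 1

obj_dists = torch.randn(6, 5)
boxes_per_cls = torch.rand(6, 5, 2).repeat(1, 1, 2)     # duplicate corners so x2 >= x1, y2 >= y1
boxes_per_cls[..., 2:] += 1.0
print(nms_masked_preds(obj_dists, boxes_per_cls))
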
Example #6
    def roi_proposals(self,
                      fmap,
                      im_sizes,
                      nms_thresh=0.7,
                      pre_nms_topn=12000,
                      post_nms_topn=2000):
        """
        :param fmap: [batch_size, IM_SIZE/16, IM_SIZE/16, A, 6]
        :param im_sizes:        [batch_size, 3] numpy array of (h, w, scale)
        :return: rois: [num_rois <= post_nms_topn, 5] array of ROIs.
        """
        # print("*** RPNHead.roi_proposals ***")
        # print("pre_nms_topn", pre_nms_topn) # 6000
        # print("post_nms_topn", post_nms_topn) # 1000

        class_fmap = fmap[:, :, :, :, :2].contiguous()

        # Keep the objectness (foreground) probability for each anchor
        class_preds = F.softmax(class_fmap, 4)[..., 1].data.contiguous()

        box_fmap = fmap[:, :, :, :, 2:].data.contiguous()

        anchor_stacked = torch.cat([self.anchors[None]] * fmap.size(0), 0)
        box_preds = bbox_preds(anchor_stacked.view(-1, 4),
                               box_fmap.view(-1, 4)).view(*box_fmap.size())

        for i, (h, w, scale) in enumerate(im_sizes):
            # Zero out all the bad boxes h, w, A, 4
            h_end = int(h) // self.stride
            w_end = int(w) // self.stride
            if h_end < class_preds.size(1):
                class_preds[i, h_end:] = -0.01
            if w_end < class_preds.size(2):
                class_preds[i, :, w_end:] = -0.01

            # and clamp the others
            box_preds[i, :, :, :, 0].clamp_(min=0, max=w - 1)
            box_preds[i, :, :, :, 1].clamp_(min=0, max=h - 1)
            box_preds[i, :, :, :, 2].clamp_(min=0, max=w - 1)
            box_preds[i, :, :, :, 3].clamp_(min=0, max=h - 1)

        sizes = center_size(box_preds.view(-1, 4))
        class_preds.view(-1)[(sizes[:, 2] < 4) | (sizes[:, 3] < 4)] = -0.01
        return filter_roi_proposals(
            box_preds.view(-1, 4),
            class_preds.view(-1),
            boxes_per_im=np.array([np.prod(box_preds.size()[1:-1])] *
                                  fmap.size(0)),
            nms_thresh=nms_thresh,
            pre_nms_topn=pre_nms_topn,
            post_nms_topn=post_nms_topn)
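
A condensed sketch of the proposal clean-up performed above: clamp predicted boxes to the image and push the scores of degenerate (smaller than 4 px) boxes below zero so they never survive ranking. The helper is a simplified stand-in, not the repo's filter_roi_proposals:

import torch

def clean_proposals(box_preds, scores, im_h, im_w, min_size=4):
    """box_preds: (N, 4) in (x1, y1, x2, y2); scores: (N,) objectness."""
    box_preds = box_preds.clone()
    box_preds[:, 0].clamp_(min=0, max=im_w - 1)
    box_preds[:, 1].clamp_(min=0, max=im_h - 1)
    box_preds[:, 2].clamp_(min=0, max=im_w - 1)
    box_preds[:, 3].clamp_(min=0, max=im_h - 1)
    w = box_preds[:, 2] - box_preds[:, 0] + 1
    h = box_preds[:, 3] - box_preds[:, 1] + 1
    scores = scores.clone()
    scores[(w < min_size) | (h < min_size)] = -0.01   # same sentinel as the snippet above
    return box_preds, scores

boxes = torch.tensor([[-5., 10., 40., 700.], [3., 3., 4., 4.]])
scores = torch.tensor([0.9, 0.8])
print(clean_proposals(boxes, scores, im_h=600, im_w=800))
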
Example #7
    def add_noise(self, rois, iou, im_sizes):
        noise_pixels = torch.from_numpy((np.random.rand(rois.size(0), 2) - 0.5) * 2.0).cuda().float()
        boxes = rois[:, 1:].data.clone()
        c_boxes = center_size(boxes)
        h, w = c_boxes[:, 2], c_boxes[:, 3]
        delta = ((1.0 + iou) * (h + w) - torch.sqrt(((1.0 + iou) * (h + w)) ** 2 - 4 * ((1.0 - iou) ** 2) * h * w)) / (
                    2 * (1.0 - iou))
        c_boxes[:, :2] += noise_pixels * delta.unsqueeze(-1)
        p_boxes = point_form(c_boxes)
        H, W = im_sizes[0, :2]
        p_boxes[:, 0].clamp_(min=0, max=W - 1)
        p_boxes[:, 1].clamp_(min=0, max=H - 1)
        p_boxes[:, 2].clamp_(min=0, max=W - 1)
        p_boxes[:, 3].clamp_(min=0, max=H - 1)

        rois[:, 1:].data.copy_(p_boxes)

        return rois
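
A plain-Python sanity sketch of the delta formula used in add_noise: it reproduces the formula for one box, applies the worst-case shift (noise of ±1 on both axes, i.e. a displacement of delta in both x and y), and reports the IoU of the shifted box against the original. Function names here are illustrative:

def shift_iou(w, h, d):
    """IoU of a w x h box with a copy of itself shifted by d along both axes."""
    inter = max(w - d, 0.0) * max(h - d, 0.0)
    return inter / (2.0 * w * h - inter)

def jitter_delta(w, h, iou):
    """The per-box delta used by add_noise above, written out for a single box."""
    s = (1.0 + iou) * (h + w)
    return (s - (s * s - 4.0 * (1.0 - iou) ** 2 * h * w) ** 0.5) / (2.0 * (1.0 - iou))

w, h, target_iou = 120.0, 80.0, 0.7
d = jitter_delta(w, h, target_iou)
print(d, shift_iou(w, h, d))   # the worst-case shift keeps the IoU near (slightly above) the target here
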
Example #8
    def sort_rois(self, batch_idx, confidence, box_priors):
        """
        :param batch_idx: tensor with the image index of each box
        :param confidence: tensor with confidences between [0, 1)
        :param box_priors: tensor with (x1, y1, x2, y2) boxes
        :return: Permutation, inverse permutation, and the lengths transposed (same as _sort_by_score)
        """
        cxcywh = center_size(box_priors)
        if self.order == 'size':
            sizes = cxcywh[:, 2] * cxcywh[:, 3]
            # sizes = (box_priors[:, 2] - box_priors[:, 0] + 1) * (box_priors[:, 3] - box_priors[:, 1] + 1)
            assert sizes.min() > 0.0
            scores = sizes / (sizes.max() + 1)
        elif self.order == 'confidence':
            scores = confidence
        elif self.order == 'random':
            scores = torch.FloatTensor(np.random.rand(batch_idx.size(0))).cuda(batch_idx.get_device())
        elif self.order == 'leftright':
            centers = cxcywh[:, 0]
            scores = centers / (centers.max() + 1)
        else:
            raise ValueError("invalid mode {}".format(self.order))
        return _sort_by_score(batch_idx, scores)
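
A simplified sketch of the 'size' ordering above: score each box by its area, then produce a permutation that keeps boxes grouped by image and sorted by descending score within each image. The real _sort_by_score also returns the inverse permutation and transposed lengths for sequence packing, which this sketch omits:

import torch

def size_order(batch_idx, boxes_cxcywh):
    """Permutation ordering boxes by image, then by descending area within each image."""
    sizes = boxes_cxcywh[:, 2] * boxes_cxcywh[:, 3]
    scores = sizes / (sizes.max() + 1)             # in [0, 1), as in sort_rois above
    # The image index dominates the key; the score only breaks ties within an image.
    key = batch_idx.float() * 2.0 - scores
    return torch.argsort(key)

batch_idx = torch.tensor([0, 0, 1, 1])
boxes = torch.tensor([[50., 50., 10., 10.],
                      [60., 60., 40., 30.],
                      [20., 20., 5., 5.],
                      [30., 30., 25., 25.]])
print(size_order(batch_idx, boxes))   # -> tensor([1, 0, 3, 2])
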
Example #9
    def geo_layout_enc(self, box_priors, rel_inds):
        """
        Geometric Layout Encoding
        :param box_priors: [num_rois, 4] of (xmin, ymin, xmax, ymax)
        :param rel_inds: [num_rels, 3] of (img ind, box0 ind, box1 ind)
        :return: bos: [num_rels, 4] encoded relative geometric layout (bo|s)
        """
        cxcywh = center_size(box_priors.data)  # convert to (cx, cy, w, h)
        box_s = cxcywh[rel_inds[:, 1]]
        box_o = cxcywh[rel_inds[:, 2]]

        # relative location
        rlt_loc_x = torch.div((box_o[:, 0] - box_s[:, 0]),
                              box_s[:, 2]).view(-1, 1)
        rlt_loc_y = torch.div((box_o[:, 1] - box_s[:, 1]),
                              box_s[:, 3]).view(-1, 1)

        # scale information
        scl_info_w = torch.log(torch.div(box_o[:, 2], box_s[:, 2])).view(-1, 1)
        scl_info_h = torch.log(torch.div(box_o[:, 3], box_s[:, 3])).view(-1, 1)

        bos = torch.cat((rlt_loc_x, rlt_loc_y, scl_info_w, scl_info_h), 1)
        return bos
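
A tiny worked example of the relative layout encoding above for one subject/object pair, with numbers chosen so the result is easy to read off:

import torch

box_s = torch.tensor([[100., 100., 40., 20.]])   # subject: cx, cy, w, h
box_o = torch.tensor([[120., 110., 80., 40.]])   # object

rlt_loc_x = (box_o[:, 0] - box_s[:, 0]) / box_s[:, 2]   # (120 - 100) / 40 = 0.5
rlt_loc_y = (box_o[:, 1] - box_s[:, 1]) / box_s[:, 3]   # (110 - 100) / 20 = 0.5
scl_info_w = torch.log(box_o[:, 2] / box_s[:, 2])       # log(2) ~ 0.693
scl_info_h = torch.log(box_o[:, 3] / box_s[:, 3])       # log(2) ~ 0.693

bos = torch.stack((rlt_loc_x, rlt_loc_y, scl_info_w, scl_info_h), 1)
print(bos)   # tensor([[0.5000, 0.5000, 0.6931, 0.6931]])
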
Example #10
for img_i in trange(len(val)):
    gt_entry = {
        'gt_classes': val.gt_classes[img_i].copy(),
        'gt_relations': val.relationships[img_i].copy(),
        'gt_boxes': val.gt_boxes[img_i].copy(),
    }

    # Use the GT boxes (a random subset could be sampled instead; see the commented-out call)
    gt_indices = np.arange(
        gt_entry['gt_boxes'].shape[0]
    )  # np.random.choice(gt_entry['gt_boxes'].shape[0], 20)
    pred_boxes = gt_entry['gt_boxes'][gt_indices]

    # Jitter the boxes a bit
    pred_boxes = center_size(pred_boxes)
    pred_boxes[:, :2] += np.random.rand(pred_boxes.shape[0], 2) * 128
    pred_boxes[:, 2:] *= (1 + np.random.randn(pred_boxes.shape[0], 2).clip(-0.1, 0.1))
    pred_boxes = point_form(pred_boxes)

    obj_scores = np.random.rand(pred_boxes.shape[0])

    rels_to_use = np.column_stack(
        np.where(1 - np.diag(np.ones(pred_boxes.shape[0], dtype=np.int32))))
    rel_scores = np.random.rand(min(100, rels_to_use.shape[0]), 51)
    rel_scores = rel_scores / rel_scores.sum(1, keepdims=True)
    pred_rel_inds = rels_to_use[np.random.choice(rels_to_use.shape[0],
                                                 rel_scores.shape[0],
                                                 replace=False)]
Example #11
    def forward(self,
                obj_fmaps,
                obj_logits,
                im_inds,
                obj_labels=None,
                box_priors=None,
                boxes_per_cls=None):
        """
        Forward pass through the object and edge context
        :param obj_priors: from faster rcnn output boxes
        :param obj_fmaps: 4096-dim roi feature maps
        :param obj_logits: result.rm_obj_dists.detach()
        :param im_inds:
        :param obj_labels: od_obj_labels, gt
        :param boxes:
        :return: obj_dists2: [#boxes, 151], new score for boxes
                 obj_preds: [#boxes], prediction/class value
                 edge_ctx: [#boxes, 512], new features for boxes

        """

        # Object State:
        # obj_embed: [#boxes, 200], and self.obj_embed.weight are both Variable
        # obj_logits: result.rm_obj_dists.detach(), [#boxes, 151], detector scores before softmax
        obj_embed = F.softmax(obj_logits, dim=1) @ self.obj_embed.weight
        # center_size returns boxes as (center_x, center_y, width, height)
        # pos_embed: [#boxes, 128], Variable, from boxes after Sequential processing
        pos_embed = self.pos_embed(Variable(center_size(box_priors)))
        # obj_pre_rep: [#boxes, 4424], Variable
        obj_pre_rep = torch.cat((obj_fmaps, obj_embed, pos_embed), 1)

        if self.nl_obj > 0:
            # obj_dists2: [#boxes, 151], new score for box
            # obj_preds: [#boxes], prediction/class value
            # obj_ctx: [#boxes, 512], new features vector for box
            obj_dists2, obj_preds, obj_ctx = self.obj_ctx(
                obj_pre_rep,  #obj_fmaps,  # original: obj_pre_rep,
                obj_logits,
                im_inds,
                obj_labels,
                box_priors,
                boxes_per_cls,
            )
        else:
            # UNSURE WHAT TO DO HERE
            if self.mode == 'predcls':
                obj_dists2 = Variable(
                    to_onehot(obj_labels.data, self.num_classes))
            else:
                obj_dists2 = self.decoder_lin(obj_pre_rep)

            if self.mode == 'sgdet' and not self.training:
                # NMS here for baseline

                probs = F.softmax(obj_dists2, 1)
                nms_mask = obj_dists2.data.clone()
                nms_mask.zero_()
                for c_i in range(1, obj_dists2.size(1)):
                    scores_ci = probs.data[:, c_i]
                    boxes_ci = boxes_per_cls.data[:, c_i]

                    keep = apply_nms(scores_ci,
                                     boxes_ci,
                                     pre_nms_topn=scores_ci.size(0),
                                     post_nms_topn=scores_ci.size(0),
                                     nms_thresh=0.3)
                    nms_mask[:, c_i][keep] = 1

                obj_preds = Variable(nms_mask * probs.data,
                                     volatile=True)[:, 1:].max(1)[1] + 1
            else:
                obj_preds = obj_labels if obj_labels is not None else obj_dists2[:, 1:].max(
                    1)[1] + 1
            obj_ctx = obj_pre_rep

        # Edge State:
        edge_ctx = None

        if self.nl_edge > 0:
            # edge_ctx: [#boxes, 512]
            edge_ctx = self.edge_ctx(
                torch.cat((obj_fmaps, obj_ctx), 1)
                if self.pass_in_obj_feats_to_edge else obj_ctx,
                obj_dists=obj_dists2.detach(),  # Was previously obj_logits.
                im_inds=im_inds,
                obj_preds=obj_preds,
                box_priors=box_priors,
            )

        return obj_dists2, obj_preds, edge_ctx
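
In predcls mode the snippets above replace the classifier output with a one-hot "distribution" built from the ground-truth labels via to_onehot. A plausible minimal version of such a helper is sketched below; the on/off fill values used by the actual repo may differ:

import torch

def to_onehot_sketch(labels, num_classes, fill=-1.0, on=1.0):
    """labels: (N,) long tensor of class indices -> (N, num_classes) pseudo-distribution."""
    out = torch.full((labels.size(0), num_classes), fill)
    out[torch.arange(labels.size(0)), labels] = on
    return out

print(to_onehot_sketch(torch.tensor([2, 0, 1]), num_classes=4))
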
Example #12
    def forward(self,
                x,
                im_sizes,
                image_offset,
                gt_boxes=None,
                gt_classes=None,
                gt_rels=None,
                proposals=None,
                train_anchor_inds=None,
                return_fmap=False):
        """
        Forward pass for detection
        :param x: Images@[batch_size, 3, IM_SIZE, IM_SIZE]
        :param im_sizes: A numpy array of (h, w, scale) for each image.
        :param image_offset: Offset onto what image we're on for MGPU training (if single GPU this is 0)
        :param gt_boxes:

        Training parameters:
        :param gt_boxes: [num_gt, 4] GT boxes over the batch.
        :param gt_classes: [num_gt, 2] gt boxes where each one is (img_id, class)
        :param train_anchor_inds: a [num_train, 2] array of indices for the anchors that will
                                  be used to compute the training loss. Each (img_ind, fpn_idx)
        :return: If train:
            scores, boxdeltas, labels, boxes, boxtargets, rpnscores, rpnboxes, rellabels

            if test:
            prob dists, boxes, img inds, maxscores, classes

        """
        result = self.detector(x,
                               im_sizes,
                               image_offset,
                               gt_boxes,
                               gt_classes,
                               gt_rels,
                               proposals,
                               train_anchor_inds,
                               return_fmap=True)
        # rel_feat = self.relationship_feat.feature_map(x)

        if result.is_none():
            raise ValueError("Empty detection result")

        im_inds = result.im_inds - image_offset
        boxes = result.rm_box_priors

        if self.training and result.rel_labels is None:
            assert self.mode == 'sgdet'
            result.rel_labels = rel_assignments(im_inds.data,
                                                boxes.data,
                                                result.rm_obj_labels.data,
                                                gt_boxes.data,
                                                gt_classes.data,
                                                gt_rels.data,
                                                image_offset,
                                                filter_non_overlap=True,
                                                num_sample_per_gt=1)

        rel_inds = self.get_rel_inds(result.rel_labels, im_inds, boxes)
        spt_feats = self.get_boxes_encode(boxes, rel_inds)
        pair_inds = self.union_pairs(im_inds)

        if self.hook_for_grad:
            rel_inds = gt_rels[:, :-1].data

        if self.hook_for_grad:
            fmap = result.fmap
            fmap.register_hook(self.save_grad)
        else:
            fmap = result.fmap.detach()

        rois = torch.cat((im_inds[:, None].float(), boxes), 1)

        result.obj_fmap = self.obj_feature_map(fmap, rois)
        # result.obj_dists_head = self.obj_classify_head(obj_fmap_rel)

        obj_embed = F.softmax(result.rm_obj_dists,
                              dim=1) @ self.obj_embed.weight
        obj_embed_lstm = F.softmax(result.rm_obj_dists,
                                   dim=1) @ self.embeddings4lstm.weight
        pos_embed = self.pos_embed(Variable(center_size(boxes.data)))
        obj_pre_rep = torch.cat((result.obj_fmap, obj_embed, pos_embed), 1)
        obj_feats = self.merge_obj_feats(obj_pre_rep)
        # obj_feats=self.trans(obj_feats)
        obj_feats_lstm = torch.cat(
            (obj_feats, obj_embed_lstm),
            -1).contiguous().view(1, obj_feats.size(0), -1)

        # obj_feats = F.relu(obj_feats)

        phr_ori = self.visual_rep(fmap, rois, pair_inds[:, 1:])
        vr_indices = torch.from_numpy(
            intersect_2d(rel_inds[:, 1:].cpu().numpy(),
                         pair_inds[:, 1:].cpu().numpy()).astype(
                             np.uint8)).cuda().max(-1)[1]
        vr = phr_ori[vr_indices]

        phr_feats_high = self.get_phr_feats(phr_ori)

        obj_feats_lstm_output, (obj_hidden_states,
                                obj_cell_states) = self.lstm(obj_feats_lstm)

        rm_obj_dists1 = result.rm_obj_dists + self.context.decoder_lin(
            obj_feats_lstm_output.squeeze())
        obj_feats_output = self.obj_mps1(obj_feats_lstm_output.view(-1, obj_feats_lstm_output.size(-1)), \
                            phr_feats_high, im_inds, pair_inds)

        obj_embed_lstm1 = F.softmax(rm_obj_dists1,
                                    dim=1) @ self.embeddings4lstm.weight

        obj_feats_lstm1 = torch.cat((obj_feats_output, obj_embed_lstm1), -1).contiguous().view(1, \
                            obj_feats_output.size(0), -1)
        obj_feats_lstm_output, _ = self.lstm(
            obj_feats_lstm1, (obj_hidden_states, obj_cell_states))

        rm_obj_dists2 = rm_obj_dists1 + self.context.decoder_lin(
            obj_feats_lstm_output.squeeze())
        obj_feats_output = self.obj_mps1(obj_feats_lstm_output.view(-1, obj_feats_lstm_output.size(-1)), \
                            phr_feats_high, im_inds, pair_inds)

        # Prevent gradients from flowing back into score_fc from elsewhere
        result.rm_obj_dists, result.obj_preds = self.context(
            rm_obj_dists2, obj_feats_output, result.rm_obj_labels
            if self.training or self.mode == 'predcls' else None, boxes.data,
            result.boxes_all)

        obj_dtype = result.obj_fmap.data.type()
        obj_preds_embeds = torch.index_select(self.ort_embedding, 0,
                                              result.obj_preds).type(obj_dtype)
        transferred_boxes = torch.stack(
            (boxes[:, 0] / IM_SCALE, boxes[:, 3] / IM_SCALE,
             boxes[:, 2] / IM_SCALE, boxes[:, 1] / IM_SCALE,
             ((boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1])) /
             (IM_SCALE**2)), -1).type(obj_dtype)
        obj_features = torch.cat(
            (result.obj_fmap, obj_preds_embeds, transferred_boxes), -1)
        obj_features_merge = self.merge_obj_low(
            obj_features) + self.merge_obj_high(obj_feats_output)

        # Split into subject and object representations
        result.subj_rep = self.post_emb_s(obj_features_merge)[rel_inds[:, 1]]
        result.obj_rep = self.post_emb_o(obj_features_merge)[rel_inds[:, 2]]
        prod_rep = result.subj_rep * result.obj_rep

        # obj_pools = self.visual_obj(result.fmap.detach(), rois, rel_inds[:, 1:])
        # rel_pools = self.relationship_feat.union_rel_pooling(rel_feat, rois, rel_inds[:, 1:])
        # context_pools = torch.cat([obj_pools, rel_pools], 1)
        # merge_pool = self.merge_feat(context_pools)
        # vr = self.roi_fmap(merge_pool)

        # vr = self.rel_refine(vr)

        prod_rep = prod_rep * vr

        if self.use_tanh:
            prod_rep = F.tanh(prod_rep)

        prod_rep = torch.cat((prod_rep, spt_feats), -1)
        freq_gate = self.freq_gate(prod_rep)
        freq_gate = F.sigmoid(freq_gate)
        result.rel_dists = self.rel_compress(prod_rep)
        # result.rank_factor = self.ranking_module(prod_rep).view(-1)

        if self.use_bias:
            result.rel_dists = result.rel_dists + freq_gate * self.freq_bias.index_with_labels(
                torch.stack((
                    result.obj_preds[rel_inds[:, 1]],
                    result.obj_preds[rel_inds[:, 2]],
                ), 1))

        if self.training:
            return result

        twod_inds = arange(
            result.obj_preds.data) * self.num_classes + result.obj_preds.data
        result.obj_scores = F.softmax(result.rm_obj_dists,
                                      dim=1).view(-1)[twod_inds]

        # Bbox regression
        if self.mode == 'sgdet':
            bboxes = result.boxes_all.view(-1, 4)[twod_inds].view(
                result.boxes_all.size(0), 4)
        else:
            # Boxes will get fixed by filter_dets function.
            bboxes = result.rm_box_priors

        rel_rep = F.softmax(result.rel_dists, dim=1)
        # rel_rep = smooth_one_hot(rel_rep)
        # rank_factor = F.sigmoid(result.rank_factor)

        return filter_dets(bboxes, result.obj_scores, result.obj_preds,
                           rel_inds[:, 1:], rel_rep)
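
A loose sketch of the gated frequency-bias term added to rel_dists above. The real FrequencyBias is initialised from dataset (subject, object, predicate) co-occurrence statistics; here it is just a randomly initialised embedding table, and all sizes and values are illustrative:

import torch
import torch.nn as nn

num_classes, num_rels, feat_dim = 151, 51, 32

freq_bias = nn.Embedding(num_classes * num_classes, num_rels)  # one row per (subj, obj) class pair
freq_gate_fc = nn.Linear(feat_dim, num_rels)
rel_compress = nn.Linear(feat_dim, num_rels)

prod_rep = torch.randn(4, feat_dim)             # fused pair features
subj_preds = torch.tensor([5, 7, 1, 3])
obj_preds = torch.tensor([2, 7, 9, 3])

rel_dists = rel_compress(prod_rep)
gate = torch.sigmoid(freq_gate_fc(prod_rep))    # per-predicate gate in (0, 1)
bias = freq_bias(subj_preds * num_classes + obj_preds)
rel_dists = rel_dists + gate * bias             # gated class-pair prior, as in the snippet
print(rel_dists.shape)                          # torch.Size([4, 51])
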
Example #13
    def forward(self, obj_fmaps, obj_logits, im_inds, obj_labels=None,
                box_priors=None, boxes_per_cls=None, gt_forest=None,
                image_rois=None, image_fmap=None, co_occour=None,
                rel_labels=None, origin_img=None):
        """
        Forward pass through the object and edge context
        :param box_priors: [obj_num, (x1,y1,x2,y2)], float cuda
        :param obj_fmaps:
        :param im_inds: [obj_num] long variable
        :param obj_labels:
        :param boxes_per_cls:
        :return:
        """
        if self.mode == 'predcls':
            obj_logits = Variable(to_onehot(obj_labels.data, self.num_classes))
            
        obj_embed = F.softmax(obj_logits, dim=1) @ self.obj_embed.weight
        
        batch_size = image_rois.shape[0]
        # pseudo box and image index: to encode virtual node into original inputs
        pseudo_box_priors = torch.cat((box_priors, image_rois[:, 1:].contiguous().data), 0)  # [obj_num + batch_size, 4]
        pseudo_im_inds = torch.cat((im_inds, image_rois[:,0].contiguous().long().view(-1)), 0) # [obj_num + batch_size]
        pseudo_obj_fmaps = torch.cat((obj_fmaps.clone().detach(), image_fmap.detach()), 0)  # [obj_num + batch_size, 4096]
        virtual_embed = self.virtual_node_embed.weight[0].view(1, -1).expand(batch_size, -1)
        pseudo_obj_embed = torch.cat((obj_embed, virtual_embed), 0) # [obj_num + batch_size, embed_dim]
        if self.training or (self.mode == 'predcls'):
            pseudo_obj_labels = torch.cat((obj_labels, Variable(torch.randn(1).fill_(0).cuda()).expand(batch_size).long().view(-1)), 0)
        else:
            pseudo_obj_labels = None
        
        if self.mode == 'sgdet':
            obj_distributions = F.softmax(obj_logits, dim=1)[:,1:]
        else:
            obj_distributions = F.softmax(obj_logits[:,1:], dim=1)
        pseudo_obj_distributions = torch.cat((obj_distributions, Variable(torch.randn(batch_size, obj_distributions.shape[1]).fill_(0).cuda())), 0)
        # generate RL gen tree input
        box_embed = tree_utils.get_box_info(Variable(pseudo_box_priors)) # 8-digits
        overlap_embed, _ = tree_utils.get_overlap_info(pseudo_im_inds, Variable(pseudo_box_priors)) # 4-digits
        prepro_feat = self.feat_preprocess_net(pseudo_obj_fmaps, pseudo_obj_embed, box_embed, overlap_embed)
        pair_scores, pair_rel_gate, pair_rel_gt = self.rl_score_net(prepro_feat, pseudo_obj_distributions, co_occour, rel_labels, batch_size, im_inds, pseudo_im_inds)

        #print('node_scores', node_scores.data.cpu().numpy())
        arbitrary_forest, gen_tree_loss, entropy_loss = gen_tree.generate_forest(pseudo_im_inds, gt_forest, pair_scores, Variable(pseudo_box_priors), pseudo_obj_labels, self.use_rl_tree, self.training, self.mode)
        forest = arbitraryForest_to_biForest(arbitrary_forest)

        pseudo_pos_embed = self.pos_embed(Variable(center_size(pseudo_box_priors)))
        obj_pre_rep = torch.cat((pseudo_obj_fmaps, pseudo_obj_embed, pseudo_pos_embed), 1)
        if self.nl_obj > 0:
            obj_dists2, obj_preds, obj_ctx = self.obj_ctx(
                obj_pre_rep,
                pseudo_obj_labels,
                pseudo_box_priors,
                boxes_per_cls,
                forest,
                batch_size
            )
        else:
            print('Error, No obj ctx')

        edge_ctx = None
        if self.nl_edge > 0:
            edge_ctx = self.edge_ctx(
                torch.cat((pseudo_obj_fmaps, obj_ctx), 1) if self.pass_in_obj_feats_to_edge else obj_ctx,
                obj_preds=obj_preds,
                box_priors=pseudo_box_priors,
                forest = forest,
            )

        # draw tree
        if self.draw_tree and (self.draw_tree_count < self.draw_tree_max):
            for tree_idx in range(len(forest)):
                draw_tree_region(forest[tree_idx], origin_img, self.draw_tree_count)
                draw_tree_region_v2(forest[tree_idx], origin_img, self.draw_tree_count, obj_preds)
                self.draw_tree_count += 1

        # remove virtual nodes
        return obj_dists2, obj_preds[:-batch_size], edge_ctx[:-batch_size], gen_tree_loss, entropy_loss, pair_rel_gate, pair_rel_gt
Example #15
    def forward(self, obj_fmaps, obj_logits, im_inds, obj_labels=None, box_priors=None, boxes_per_cls=None, batch_size=None,
                rois=None, od_box_deltas=None, im_sizes=None, image_offset=None, gt_classes=None, gt_boxes=None, ):
        """
        Forward pass through the object and edge context
        :param obj_priors:
        :param obj_fmaps:
        :param im_inds:
        :param obj_labels:
        :param boxes:
        :return:
        """
        obj_embed = F.softmax(obj_logits, dim=1) @ self.obj_embed.weight
        pos_embed = self.pos_embed(Variable(center_size(box_priors)))
        obj_pre_rep = torch.cat((obj_fmaps, obj_embed, pos_embed), 1)


        if self.mode == 'predcls':
            obj_dists2 = Variable(to_onehot(obj_labels.data, self.num_classes))
        else:
            if self.mode == 'sgcls':

                obj_dists2 = self.decoder_lin1(obj_pre_rep)
                obj_dists2 = self.decoder_lin2(obj_dists2.view(-1, 1, 1024), 1)

                obj_dists2 = obj_dists2[1]

                obj_dists2 = self.decoder_lin3(obj_dists2.view(-1, 1024))

            else:
                # this is for sgdet

                obj_dists2 = self.decoder_lin1(obj_pre_rep)

                perm, inv_perm, ls_transposed = self.sort_rois(im_inds.data, None, box_priors)
                obj_dists2 = obj_dists2[perm].contiguous()
                obj_dists2 = PackedSequence(obj_dists2, torch.tensor(ls_transposed))
                obj_dists2, lengths1 = pad_packed_sequence(obj_dists2, batch_first=False)


                obj_dists2 = self.decoder_lin2(obj_dists2.view(-1, batch_size, 1024), batch_size)[1]


                obj_dists2, _ = pack_padded_sequence(obj_dists2, lengths1, batch_first=False)
                obj_dists2 = self.decoder_lin3(obj_dists2.view(-1, 1024))
                obj_dists2 = obj_dists2[inv_perm]


                if (not self.training and not self.mode == 'gtbox') or self.mode in ('sgdet', 'refinerels'):
                    # try: don't apply NMS here, but after our own obj_classifier
                    nms_inds, nms_scores, nms_preds, nms_boxes_assign, nms_boxes, nms_imgs = self.nms_boxes(
                        obj_dists2.clone().detach(),
                        rois,
                        od_box_deltas.clone().detach(), im_sizes,
                    )
                    im_inds = nms_imgs + image_offset
                    obj_dists2 = obj_dists2[nms_inds]
                    obj_fmap = obj_fmaps[nms_inds]
                    box_deltas = od_box_deltas[nms_inds]
                    box_priors = nms_boxes[:, 0]
                    rois = rois[nms_inds]

                    if self.training and not self.mode == 'gtbox':
                        # NOTE: If we're doing this during training, we need to assign labels here.
                        pred_to_gtbox = bbox_overlaps(box_priors, gt_boxes).data
                        pred_to_gtbox[im_inds.data[:, None] != gt_classes.data[None, :, 0]] = 0.0

                        max_overlaps, argmax_overlaps = pred_to_gtbox.max(1)
                        rm_obj_labels = gt_classes[:, 1][argmax_overlaps]
                        rm_obj_labels[max_overlaps < 0.5] = 0
                    else:
                        rm_obj_labels = None

        if self.mode == 'sgdet' and not self.training:  # have tried in training
            # NMS here for baseline

            probs = F.softmax(obj_dists2, 1)
            nms_mask = obj_dists2.data.clone()
            nms_mask.zero_()
            for c_i in range(1, obj_dists2.size(1)):
                scores_ci = probs.data[:, c_i]
                boxes_ci = nms_boxes.data[:, c_i]

                keep = apply_nms(scores_ci, boxes_ci,
                                 pre_nms_topn=scores_ci.size(0), post_nms_topn=scores_ci.size(0),
                                 nms_thresh=0.5)  # default nms_thresh is 0.3
                nms_mask[:, c_i][keep] = 1

            obj_preds = Variable(nms_mask * probs.data, volatile=True)[:, 1:].max(1)[1] + 1  # used on the sgdet test path

            # obj_preds = obj_dists2[:, 1:].max(1)[1] + 1
        else:
            if self.mode == 'sgdet':
                # use the GT-assigned labels when available
                obj_preds = rm_obj_labels if rm_obj_labels is not None else obj_dists2[:, 1:].max(1)[1] + 1
                # ... or use the predicted label instead:
                # obj_preds = obj_dists2[:, 1:].max(1)[1] + 1
            else:
                obj_preds = obj_labels if obj_labels is not None else obj_dists2[:, 1:].max(1)[1] + 1

        if self.mode == 'sgdet':
            return obj_dists2, obj_preds, im_inds, box_priors, rm_obj_labels, rois, nms_boxes
        else:
            return obj_dists2, obj_preds
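
The sgdet branch above routes obj_dists2 through a PackedSequence so that decoder_lin2 sees one image per batch column. Below is a stripped-down sketch of that pack/pad round trip on dummy data; the ROI sorting and the decoder itself are omitted:

import torch
from torch.nn.utils.rnn import pad_sequence, pack_padded_sequence, pad_packed_sequence

# Per-image box features flattened over the batch: 3 boxes in image 0, 2 in image 1.
im_inds = torch.tensor([0, 0, 0, 1, 1])
feats = torch.arange(10.).view(5, 2)

lengths = torch.bincount(im_inds)                       # tensor([3, 2]), already sorted descending
padded = pad_sequence(feats.split(lengths.tolist()))    # (max_len, num_images, feat_dim)

# Round trip through a PackedSequence, as the sgdet branch does around decoder_lin2.
packed = pack_padded_sequence(padded, lengths, enforce_sorted=True)
unpacked, out_lengths = pad_packed_sequence(packed)
assert torch.equal(unpacked, padded) and torch.equal(out_lengths, lengths)
print(unpacked.shape)                                   # torch.Size([3, 2, 2])
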
Example #16
    def forward(self,
                x,
                im_sizes,
                image_offset,
                gt_boxes=None,
                gt_classes=None,
                gt_rels=None,
                proposals=None,
                train_anchor_inds=None,
                return_fmap=False):
        """
        Forward pass for Relation detection
        Args:
            x: Images@[batch_size, 3, IM_SIZE, IM_SIZE]
            im_sizes: A numpy array of (h, w, scale) for each image.
            image_offset: Offset onto what image we're on for MGPU training (if single GPU this is 0)

            parameters for training:
            gt_boxes: [num_gt, 4] GT boxes over the batch.
            gt_classes: [num_gt, 2] gt boxes where each one is (img_id, class)
            gt_rels:
            proposals:
            train_anchor_inds: a [num_train, 2] array of indices for the anchors that will
                                  be used to compute the training loss. Each (img_ind, fpn_idx)
            return_fmap:

        Returns:
            If train:
                scores, boxdeltas, labels, boxes, boxtargets, rpnscores, rpnboxes, rellabels
            If test:
                prob dists, boxes, img inds, maxscores, classes
        """
        result = self.detector(x,
                               im_sizes,
                               image_offset,
                               gt_boxes,
                               gt_classes,
                               gt_rels,
                               proposals,
                               train_anchor_inds,
                               return_fmap=True)

        assert not result.is_none(), 'Empty detection result'

        # image_offset refers to the Blob offset: self.batch_size_per_gpu * index
        im_inds = result.im_inds - image_offset
        boxes = result.rm_box_priors
        obj_scores, box_classes = F.softmax(
            result.rm_obj_dists[:, 1:].contiguous(), dim=1).max(1)
        box_classes += 1

        num_img = im_inds[-1] + 1

        # embed(header='rel_model.py before rel_assignments')
        if self.training and result.rel_labels is None:
            assert self.mode == 'sgdet'

            # only in sgdet mode

            # shapes:
            # im_inds: (box_num,)
            # boxes: (box_num, 4)
            # rm_obj_labels: (box_num,)
            # gt_boxes: (box_num, 4)
        # gt_classes: (box_num, 2), probably [im_ind, class_ind]
            # gt_rels: (rel_num, 4)
            # image_offset: integer
            result.rel_labels = rel_assignments(im_inds.data,
                                                boxes.data,
                                                result.rm_obj_labels.data,
                                                gt_boxes.data,
                                                gt_classes.data,
                                                gt_rels.data,
                                                image_offset,
                                                filter_non_overlap=True,
                                                num_sample_per_gt=1)
        rel_inds = self.get_rel_inds(result.rel_labels, im_inds, boxes)
        rois = torch.cat((im_inds[:, None].float(), boxes), 1)
        # union boxes feats (NumOfRels, obj_dim)
        union_box_feats = self.visual_rep(result.fmap.detach(), rois,
                                          rel_inds[:, 1:].contiguous())
        # single box feats (NumOfBoxes, feats)
        box_feats = self.obj_feature_map(result.fmap.detach(), rois)
        # box spatial feats (NumOfBox, 4)
        bboxes = Variable(center_size(boxes.data))
        sub_bboxes = bboxes[rel_inds[:, 1].contiguous()]
        obj_bboxes = bboxes[rel_inds[:, 2].contiguous()]

        # Encode the object box relative to the subject box
        obj_bboxes[:, :2] = obj_bboxes[:, :2].contiguous() - sub_bboxes[:, :2].contiguous()   # (dx, dy)
        obj_bboxes[:, 2:] = obj_bboxes[:, 2:].contiguous() / sub_bboxes[:, 2:].contiguous()   # (w_o/w_s, h_o/h_s)
        obj_bboxes[:, :2] /= sub_bboxes[:, 2:].contiguous()                                   # (dx/w_s, dy/h_s)
        obj_bboxes[:, 2:] = torch.log(obj_bboxes[:, 2:].contiguous())                         # log size ratios

        bbox_spatial_feats = self.spatial_fc(obj_bboxes)

        box_word = self.classes_word_embedding(box_classes)
        box_pair_word = torch.cat((box_word[rel_inds[:, 1].contiguous()],
                                   box_word[rel_inds[:, 2].contiguous()]), 1)
        box_word_feats = self.word_fc(box_pair_word)

        # box_pair_feats: (NumOfRels, feature dim)
        box_pair_feats = torch.cat(
            (union_box_feats, bbox_spatial_feats, box_word_feats), 1)

        box_pair_score = self.relpn_fc(box_pair_feats)
        #embed(header='filter_rel_labels')
        if self.training:
            pn_rel_label = list()
            pn_pair_score = list()
            #print(result.rel_labels.shape)
            #print(result.rel_labels[:, 0].contiguous().squeeze())
            for i, s, e in enumerate_by_image(
                    result.rel_labels[:, 0].data.contiguous()):
                im_i_rel_label = result.rel_labels[s:e].contiguous()
                im_i_box_pair_score = box_pair_score[s:e].contiguous()

                im_i_rel_fg_inds = torch.nonzero(
                    im_i_rel_label[:, -1].contiguous()).squeeze()
                im_i_rel_fg_inds = im_i_rel_fg_inds.data.cpu().numpy()
                im_i_fg_sample_num = min(RELEVANT_PER_IM,
                                         im_i_rel_fg_inds.shape[0])
                if im_i_rel_fg_inds.size > 0:
                    im_i_rel_fg_inds = np.random.choice(
                        im_i_rel_fg_inds,
                        size=im_i_fg_sample_num,
                        replace=False)

                im_i_rel_bg_inds = torch.nonzero(
                    im_i_rel_label[:, -1].contiguous() == 0).squeeze()
                im_i_rel_bg_inds = im_i_rel_bg_inds.data.cpu().numpy()
                im_i_bg_sample_num = min(EDGES_PER_IM - im_i_fg_sample_num,
                                         im_i_rel_bg_inds.shape[0])
                if im_i_rel_bg_inds.size > 0:
                    im_i_rel_bg_inds = np.random.choice(
                        im_i_rel_bg_inds,
                        size=im_i_bg_sample_num,
                        replace=False)

                #print('{}/{} fg/bg in image {}'.format(im_i_fg_sample_num, im_i_bg_sample_num, i))
                result.rel_sample_pos = torch.Tensor(
                    [im_i_fg_sample_num]).cuda(im_i_rel_label.get_device())
                result.rel_sample_neg = torch.Tensor(
                    [im_i_bg_sample_num]).cuda(im_i_rel_label.get_device())

                im_i_keep_inds = np.append(im_i_rel_fg_inds, im_i_rel_bg_inds)
                im_i_pair_score = im_i_box_pair_score[
                    im_i_keep_inds.tolist()].contiguous()

                im_i_rel_pn_labels = Variable(
                    torch.zeros(im_i_fg_sample_num + im_i_bg_sample_num).type(
                        torch.LongTensor).cuda(x.get_device()))
                im_i_rel_pn_labels[:im_i_fg_sample_num] = 1

                pn_rel_label.append(im_i_rel_pn_labels)
                pn_pair_score.append(im_i_pair_score)

            result.rel_pn_dists = torch.cat(pn_pair_score, 0)
            result.rel_pn_labels = torch.cat(pn_rel_label, 0)

        box_pair_relevant = F.softmax(box_pair_score, dim=1)
        box_pos_pair_ind = torch.nonzero(box_pair_relevant[:, 1].contiguous(
        ) > box_pair_relevant[:, 0].contiguous()).squeeze()

        if box_pos_pair_ind.data.shape == torch.Size([]):
            return None
        #print('{}/{} trim edges'.format(box_pos_pair_ind.size(0), rel_inds.size(0)))
        result.rel_trim_pos = torch.Tensor([box_pos_pair_ind.size(0)]).cuda(
            box_pos_pair_ind.get_device())
        result.rel_trim_total = torch.Tensor([rel_inds.size(0)
                                              ]).cuda(rel_inds.get_device())

        # filtering relations
        filter_rel_inds = rel_inds[box_pos_pair_ind.data]
        filter_box_pair_feats = box_pair_feats[box_pos_pair_ind.data]
        if self.training:
            filter_rel_labels = result.rel_labels[box_pos_pair_ind.data]
            result.rel_labels = filter_rel_labels

        # message passing between boxes and relations
        #embed(header='mp')
        for _ in range(self.mp_iter_num):
            box_feats = self.message_passing(box_feats, filter_box_pair_feats,
                                             filter_rel_inds)
        box_cls_scores = self.cls_fc(box_feats)
        result.rm_obj_dists = box_cls_scores
        obj_scores, box_classes = F.softmax(box_cls_scores[:, 1:].contiguous(),
                                            dim=1).max(1)
        box_classes += 1  # skip background

        # TODO: add memory module
        # filter_box_pair_feats is to be added to memory
        # filter_box_pair_feats = self.memory_()

        # RelationCNN
        filter_box_pair_feats_fc1 = self.relcnn_fc1(filter_box_pair_feats)
        filter_box_pair_score = self.relcnn_fc2(filter_box_pair_feats_fc1)
        if not self.graph_cons:
            filter_box_pair_score = filter_box_pair_score.view(
                -1, 2, self.num_rels)
        result.rel_dists = filter_box_pair_score

        if self.training:
            return result

        pred_scores = F.softmax(result.rel_dists, dim=1)
        """
        filter_dets
        boxes: bbox regression else [num_box, 4]
        obj_scores: [num_box] probabilities for the scores
        obj_classes: [num_box] class labels integer
        rel_inds: [num_rel, 2] TENSOR consisting of (im_ind0, im_ind1)
        pred_scores: [num_rel, num_predicates] including the irrelevant class (#relclass + 1)
        """
        return filter_dets(boxes, obj_scores, box_classes,
                           filter_rel_inds[:, 1:].contiguous(), pred_scores)