Example #1
    def union_pairs(self, im_inds):
        rel_cands = im_inds.data[:, None] == im_inds.data[None]
        rel_cands.view(-1)[diagonal_inds(rel_cands)] = 0
        rel_inds = rel_cands.nonzero()
        rel_inds = torch.cat((im_inds[rel_inds[:, 0]][:, None].data, rel_inds), -1)
        return rel_inds
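Note: every example in this file masks out self-pairs with a diagonal_inds helper whose definition is not included. A minimal sketch of what such a helper likely does, assuming a square 2-D mask flattened row-major, is shown below; the flat indices it returns are the positions of the (i, i) entries.

import torch

def diagonal_inds(mask):
    # Flat indices of the main diagonal of a square 2-D tensor.
    # Writing 0 at these positions of mask.view(-1) removes the (i, i) self-pairs.
    assert mask.dim() == 2 and mask.size(0) == mask.size(1)
    size = mask.size(0)
    return (size + 1) * torch.arange(size, dtype=torch.long, device=mask.device)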
Example #2
    def get_rel_inds(self, rel_labels, im_inds, box_priors):
        """
        Get the relationship candidates
        :param rel_labels: array of relation labels
        :param im_inds:  image indices
        :param box_priors: RoI bounding boxes
        :return: rel_inds
        """
        if self.training:
            rel_inds = rel_labels[:, :3].data.clone()
        else:
            rel_cands = im_inds.data[:, None] == im_inds.data[None]
            rel_cands.view(-1)[diagonal_inds(rel_cands)] = 0

            # Require overlap for detection
            if self.require_overlap:
                rel_cands = rel_cands & (bbox_overlaps(box_priors.data,
                                                       box_priors.data) > 0)

                # if there are fewer than 100 things then we might as well add some?
                amt_to_add = 100 - rel_cands.long().sum()

            rel_cands = rel_cands.nonzero()
            if rel_cands.dim() == 0:
                rel_cands = im_inds.data.new(1, 2).fill_(0)

            rel_inds = torch.cat((im_inds.data[rel_cands[:, 0]][:, None], rel_cands), 1)
        return rel_inds
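For reference, the test-time branch above can be traced on toy data. The snippet below is purely illustrative (it inlines the diagonal-index trick instead of calling diagonal_inds) and shows that rel_inds ends up as rows of [img_ind, subject box ind, object box ind].

import torch

# Toy batch: boxes 0-2 belong to image 0, boxes 3-4 to image 1.
im_inds = torch.tensor([0, 0, 0, 1, 1])

rel_cands = im_inds[:, None] == im_inds[None]            # [5, 5] same-image mask
n = rel_cands.size(0)
rel_cands.view(-1)[(n + 1) * torch.arange(n)] = False    # drop (i, i) self-pairs

rel_cands = rel_cands.nonzero()                           # [num_pairs, 2]
rel_inds = torch.cat((im_inds[rel_cands[:, 0]][:, None], rel_cands), 1)
print(rel_inds)   # each row: [img_ind, subject box ind, object box ind]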
Example #3
    def get_rel_inds(self, rel_labels, im_inds, box_priors):
        """Get relation index
        Args:
            rel_labels: Variable
            im_inds: Variable
            box_priors: Variable
        """
        # Get the relationship candidates
        if self.training:
            rel_inds = rel_labels[:, :3].data.contiguous().clone()
        else:
            rel_cands = im_inds.data[:, None] == im_inds.data[None]
            rel_cands.view(-1)[diagonal_inds(rel_cands)] = 0

            # Require overlap for detection
            if self.require_overlap:
                rel_cands = rel_cands & (bbox_overlaps(box_priors.data,
                                                       box_priors.data) > 0)

                # if there are fewer than 100 things then we might as well add some?
                amt_to_add = 100 - rel_cands.long().sum()

            rel_cands = rel_cands.nonzero()
            if rel_cands.dim() == 0:
                rel_cands = im_inds.data.new(1, 2).fill_(0)

            rel_inds = torch.cat(
                (im_inds.data[rel_cands[:, 0]][:, None].contiguous(), rel_cands), 1)
        return rel_inds
Example #4
    def get_msg_rel_inds(self, im_inds, box_priors, box_score):

        rel_cands = im_inds.data[:, None] == im_inds.data[None]
        rel_cands.view(-1)[diagonal_inds(rel_cands)] = 0

        if self.require_overlap:
            rel_cands = rel_cands & (bbox_overlaps(box_priors.data,
                                                   box_priors.data) > conf.overlap_thresh)
        rel_cands = rel_cands.nonzero()
        if rel_cands.dim() == 0:
            rel_cands = im_inds.data.new(1, 2).fill_(0)

        rel_inds = torch.cat((im_inds.data[rel_cands[:, 0]][:, None], rel_cands), 1)
        return rel_inds
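The require_overlap filter relies on bbox_overlaps, a pairwise IoU matrix between two sets of boxes; its definition is not shown here. A minimal sketch, assuming [x1, y1, x2, y2] boxes and ignoring the +1 pixel convention some implementations use, follows.

import torch

def bbox_overlaps(boxes_a, boxes_b):
    # Pairwise IoU between [Na, 4] and [Nb, 4] boxes in [x1, y1, x2, y2] format -> [Na, Nb].
    x1 = torch.max(boxes_a[:, None, 0], boxes_b[None, :, 0])
    y1 = torch.max(boxes_a[:, None, 1], boxes_b[None, :, 1])
    x2 = torch.min(boxes_a[:, None, 2], boxes_b[None, :, 2])
    y2 = torch.min(boxes_a[:, None, 3], boxes_b[None, :, 3])

    inter = (x2 - x1).clamp(min=0) * (y2 - y1).clamp(min=0)
    area_a = (boxes_a[:, 2] - boxes_a[:, 0]) * (boxes_a[:, 3] - boxes_a[:, 1])
    area_b = (boxes_b[:, 2] - boxes_b[:, 0]) * (boxes_b[:, 3] - boxes_b[:, 1])
    union = (area_a[:, None] + area_b[None, :] - inter).clamp(min=1e-8)
    return inter / union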
Example #5
    def get_rel_inds(self, rel_labels, im_inds, box_priors, box_score):

        if self.training:
            rel_inds = rel_labels[:, :3].data.clone()
        else:
            rel_cands = im_inds.data[:, None] == im_inds.data[None]
            rel_cands.view(-1)[diagonal_inds(rel_cands)] = 0

            # Require overlap in the test stage (detection)
            if self.require_overlap:
                rel_cands = rel_cands & (bbox_overlaps(box_priors.data,
                                                       box_priors.data) > 0)
            rel_cands = rel_cands.nonzero()
            if rel_cands.dim() == 0:
                rel_cands = im_inds.data.new(1, 2).fill_(0)

            rel_inds = torch.cat((im_inds.data[rel_cands[:, 0]][:, None], rel_cands), 1)
        return rel_inds
Example #6
def rel_assignments_sgcls(rois, gt_boxes, gt_classes, gt_rels, image_offset):
    """
    Sample relations to balance the proportion of positive and negative samples
    :param rois: [img_ind, x1, y1, x2, y2]
    :param gt_boxes: [num_boxes, 4] array of [x0, y0, x1, y1]. Not needed, it seems
    :param gt_classes: [num_boxes, 2] array of [img_ind, class]
        Note, the img_inds here start at image_offset
    :param gt_rels: [num_boxes, 4] array of [img_ind, box_0, box_1, rel type].
        Note, the img_inds here start at image_offset
    :param image_offset: offset applied to the image indices in gt_classes/gt_rels
    :return:
        rois: [num_rois, 5]
        labels: [num_rois] array of labels
        rel_labels: [num_rels, 4] (img ind, box0 ind, box1ind, rel type)
    """
    im_inds = rois[:,0].long()

    num_im = im_inds[-1] + 1

    # Offset the image indices in fg_rels to refer to absolute indices (not just within img i)
    fg_rels = gt_rels.clone()
    fg_rels[:,0] -= image_offset
    offset = {}
    for i, s, e in enumerate_by_image(im_inds):
        offset[i] = s
    for i, s, e in enumerate_by_image(fg_rels[:, 0]):
        fg_rels[s:e, 1:3] += offset[i]

    # Try ALL things, not just intersections.
    is_cand = (im_inds[:, None] == im_inds[None])
    is_cand.view(-1)[diagonal_inds(is_cand)] = 0

    # # Compute salience
    # gt_inds = fg_rels[:, 1:3].contiguous().view(-1)
    # labels_arange = labels.data.new(labels.size(0))
    # torch.arange(0, labels.size(0), out=labels_arange)
    # salience_labels = ((gt_inds[:, None] == labels_arange[None]).long().sum(0) > 0).long()
    # labels = torch.stack((labels, salience_labels), 1)

    # Add in some BG labels

    # NOW WE HAVE TO EXCLUDE THE FGs.
    # TODO: check if this causes an error if many duplicate GTs haven't been filtered out

    is_cand.view(-1)[fg_rels[:,1]*im_inds.size(0) + fg_rels[:,2]] = 0
    is_bgcand = is_cand.nonzero()
    # TODO: make this sample on a per image case
    # If too many then sample
    num_fg = min(fg_rels.size(0), int(RELS_PER_IMG * REL_FG_FRACTION * num_im))
    if num_fg < fg_rels.size(0):
        fg_rels = random_choose(fg_rels, num_fg)

    # If too many then sample
    num_bg = min(is_bgcand.size(0) if is_bgcand.dim() > 0 else 0,
                 int(RELS_PER_IMG * num_im) - num_fg)
    if num_bg > 0:
        bg_rels = torch.cat((
            im_inds[is_bgcand[:, 0]][:, None],
            is_bgcand,
            (is_bgcand[:, 0, None] < -10).long(),
        ), 1)

        if num_bg < is_bgcand.size(0):
            bg_rels = random_choose(bg_rels, num_bg)
        rel_labels = torch.cat((fg_rels, bg_rels), 0)
    else:
        rel_labels = fg_rels


    # last sort by rel.
    _, perm = torch.sort(rel_labels[:, 0]*(gt_boxes.size(0)**2) +
                         rel_labels[:,1]*gt_boxes.size(0) + rel_labels[:,2])

    rel_labels = rel_labels[perm].contiguous()

    labels = gt_classes[:,1].contiguous()
    return rois, labels, rel_labels
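rel_assignments_sgcls iterates over per-image segments with enumerate_by_image, which is also not shown. A plausible sketch, assuming im_inds is a 1-D tensor sorted by image, yields (image index, start, end) for each contiguous run of equal values.

def enumerate_by_image(im_inds):
    # Yield (img_ind, start, end) for each contiguous run of equal image indices.
    # Assumes im_inds is 1-D and sorted by image; wrap Variables with .data first.
    vals = im_inds.tolist()
    cur, start = vals[0], 0
    for i, v in enumerate(vals):
        if v != cur:
            yield cur, start, i
            cur, start = v, i
    yield cur, start, len(vals)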
Example #7
def rel_anchor_target(rois, gt_boxes, gt_classes, scores, gt_rels,
                      image_offset):
    """
    Use all RoI pairs and sample some of them to train the relation proposal module.
    Note: ONLY for mode SGDET!!!!
    rois are from the RPN.
    We take the CO_Overlap strategy from Graph-RCNN to sample fg and bg rels
    :param rois: N, 5
    :param scores: N, N
    :param gt_rels:
    :return:
    """
    im_inds = rois[:, 0].long()
    num_im = im_inds[-1] + 1

    # Offset the image indices in fg_rels to refer to absolute indices (not just within img i)
    fg_rels = gt_rels.clone()
    fg_rels[:, 0] -= image_offset
    offset = {}
    for i, s, e in enumerate_by_image(gt_classes[:, 0]):
        offset[i] = s
    for i, s, e in enumerate_by_image(fg_rels[:, 0]):
        fg_rels[s:e, 1:3] += offset[i]

    gt_box_pairs = torch.cat(
        (gt_boxes[fg_rels[:, 1]], gt_boxes[fg_rels[:, 2]]), 1)  # Ngtp, 8

    # get all potential pairs
    is_cand = (im_inds[:, None] == im_inds[None])
    is_cand.view(-1)[diagonal_inds(is_cand)] = 0

    all_pair_inds = torch.nonzero(is_cand)
    all_box_pairs = torch.cat(
        (rois[:, 1:][all_pair_inds[:, 0]], rois[:, 1:][all_pair_inds[:, 1]]),
        1)

    num_pairs = np.zeros(num_im + 1).astype(np.int32)
    id_to_iminds = {}
    for i, s, e in enumerate_by_image(im_inds):
        num_pairs[i + 1] = (e - s) * (e - s - 1)
        id_to_iminds[i] = im_inds[s]
    cumsum_num_pairs = np.cumsum(num_pairs).astype(np.int32)

    all_rel_inds = []
    for i in range(1, num_im + 1):
        all_pair_inds_i = all_pair_inds[
            cumsum_num_pairs[i - 1]:cumsum_num_pairs[i]]
        all_box_pairs_i = all_box_pairs[
            cumsum_num_pairs[i - 1]:cumsum_num_pairs[i]]
        gt_box_pairs_i = gt_box_pairs[torch.nonzero(
            fg_rels[:, 0] == (i - 1)).view(-1)]
        labels = gt_rels.new(all_box_pairs_i.size(0)).fill_(-1)

        overlaps = co_bbox_overlaps(all_box_pairs_i,
                                    gt_box_pairs_i)  ## Np, Ngtp
        max_overlaps, argmax_overlaps = torch.max(overlaps, 1)  ## Np
        gt_max_overlaps, _ = torch.max(overlaps, 0)  ## Ngtp

        labels[max_overlaps < 0.15] = 0
        gt_max_overlaps[gt_max_overlaps == 0] = 1e-5

        # fg rel: for each gt pair, the max overlap anchor is fg
        keep = torch.sum(
            overlaps.eq(gt_max_overlaps.view(1, -1).expand_as(overlaps)),
            1)  # Np
        if torch.sum(keep) > 0:
            labels[keep > 0] = 1

        # fg rel: above thresh
        labels[max_overlaps >= 0.25] = 1

        num_fg = int(RELPN_BATCHSIZE * RELPN_FG_FRACTION)
        sum_fg = torch.sum((labels == 1).int())
        sum_bg = torch.sum((labels == 0).int())

        if sum_fg > num_fg:
            fg_inds = torch.nonzero(labels == 1).view(-1)
            rand_num = torch.from_numpy(np.random.permutation(
                fg_inds.size(0))).type_as(gt_boxes).long()
            disable_inds = fg_inds[rand_num[:fg_inds.size(0) - num_fg]]
            labels[disable_inds] = -1
        num_bg = RELPN_BATCHSIZE - torch.sum((labels == 1).int())

        if sum_bg > num_bg:
            bg_inds = torch.nonzero(labels == 0).view(-1)
            rand_num = torch.from_numpy(np.random.permutation(
                bg_inds.size(0))).type_as(gt_boxes).long()
            disable_inds = bg_inds[rand_num[:bg_inds.size(0) - num_bg]]
            labels[disable_inds] = -1

        keep_inds = torch.nonzero(labels >= 0).view(-1)
        labels = labels[keep_inds]
        all_pair_inds_i = all_pair_inds_i[keep_inds]

        im_inds_i = torch.LongTensor([id_to_iminds[i - 1]] *
                                     keep_inds.size(0)).view(-1, 1).cuda(
                                         all_pair_inds.get_device())
        all_pair_inds_i = torch.cat(
            (im_inds_i, all_pair_inds_i, labels.view(-1, 1)), 1)
        all_rel_inds.append(all_pair_inds_i)

    all_rel_inds = torch.cat(all_rel_inds, 0)
    # sort by rel
    _, perm = torch.sort(all_rel_inds[:, 0] * (rois.size(0)**2) +
                         all_rel_inds[:, 1] * rois.size(0) +
                         all_rel_inds[:, 2])
    all_rel_inds = all_rel_inds[perm].contiguous()
    return all_rel_inds
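rel_anchor_target scores each candidate pair against the GT pairs with co_bbox_overlaps, whose definition is not included. A hedged sketch, assuming it combines the subject IoU and the object IoU of the [N, 8] pair boxes (taking their product is one plausible choice, not the confirmed definition), and reusing the bbox_overlaps sketch shown earlier:

def co_bbox_overlaps(box_pairs_a, box_pairs_b):
    # Pairwise overlap between [Na, 8] and [Nb, 8] box pairs -> [Na, Nb].
    # Each row is a (subject, object) pair: [sx1, sy1, sx2, sy2, ox1, oy1, ox2, oy2].
    # Combining the two IoUs by product is an assumption, not the original definition.
    subj_iou = bbox_overlaps(box_pairs_a[:, :4], box_pairs_b[:, :4])
    obj_iou = bbox_overlaps(box_pairs_a[:, 4:], box_pairs_b[:, 4:])
    return subj_iou * obj_iou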
Example #8
def proposal_assignments_gtbox(rois,
                               gt_boxes,
                               gt_classes,
                               gt_rels,
                               image_offset,
                               RELS_PER_IMG,
                               sample_factor=-1):
    """
    Assign object detection proposals to ground-truth targets. Produces proposal
    classification labels and bounding-box regression targets.
    :param rois: [img_ind, x1, y1, x2, y2]
    :param gt_boxes: [num_boxes, 4] array of [x0, y0, x1, y1]. Not needed, it seems
    :param gt_classes: [num_boxes, 2] array of [img_ind, class]
        Note, the img_inds here start at image_offset
    :param gt_rels: [num_boxes, 4] array of [img_ind, box_0, box_1, rel type].
        Note, the img_inds here start at image_offset
    :param image_offset: offset applied to the image indices in gt_classes/gt_rels
    :return:
        rois: [num_rois, 5]
        labels: [num_rois] array of labels
        bbox_targets [num_rois, 4] array of targets for the labels.
        rel_labels: [num_rels, 4] (img ind, box0 ind, box1ind, rel type)
    """
    im_inds = rois[:, 0].long()

    num_im = im_inds[-1] + 1

    # Offset the image indices in fg_rels to refer to absolute indices (not just within img i)
    fg_rels = gt_rels.clone()
    fg_rels[:, 0] -= image_offset
    offset = {}
    for i, s, e in enumerate_by_image(im_inds):
        offset[i] = s
    for i, s, e in enumerate_by_image(fg_rels[:, 0]):
        fg_rels[s:e, 1:3] += offset[i]

    # Try ALL things, not just intersections.
    is_cand = (im_inds[:, None] == im_inds[None])
    is_cand.view(-1)[diagonal_inds(is_cand)] = 0

    # NOW WE HAVE TO EXCLUDE THE FGs.
    is_cand.view(-1)[fg_rels[:, 1] * im_inds.size(0) + fg_rels[:, 2]] = 0
    is_bgcand = torch.nonzero(is_cand)

    # TODO: make this sample on a per image case
    # If too many then sample
    num_fg = min(fg_rels.size(0), int(RELS_PER_IMG * REL_FG_FRACTION * num_im))
    if num_fg < fg_rels.size(0):
        fg_rels = random_choose(fg_rels, num_fg)

    # If too many then sample
    is_train = num_im > 1  # assume num_im = 1 at test time (except for the det mode, which we don't use for now)
    sample_bg = is_train and sample_factor > -1

    num_bg = min(
        is_bgcand.size(0) if is_bgcand.dim() > 0 else 0,
        int(num_fg * sample_factor) if sample_bg else
        (int(RELS_PER_IMG * num_im) -
         num_fg))  # sample num_fg at training time

    if num_bg > 0:
        bg_rels = torch.cat((
            im_inds[is_bgcand[:, 0]][:, None],
            is_bgcand,
            (is_bgcand[:, 0, None] < -10).long(),
        ), 1)

        if num_bg < is_bgcand.size(0):
            bg_rels = random_choose(
                bg_rels, num_bg
            )  # at test time will correspond to the baseline approach

        rel_labels = torch.cat((fg_rels, bg_rels), 0)
    else:
        rel_labels = fg_rels

    # last sort by rel.
    _, perm = torch.sort(rel_labels[:, 0] * (gt_boxes.size(0)**2) +
                         rel_labels[:, 1] * gt_boxes.size(0) +
                         rel_labels[:, 2])

    rel_labels = rel_labels[perm].contiguous()

    labels = gt_classes[:, 1].contiguous()
    return rois, labels, rel_labels
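Both sampling functions above call random_choose to subsample rows of a tensor; its definition is not shown. A simple sketch, assuming it draws num rows uniformly without replacement, could look like this.

import torch

def random_choose(tensor, num):
    # Return num rows of tensor, sampled uniformly without replacement.
    assert tensor.size(0) >= num
    perm = torch.randperm(tensor.size(0), device=tensor.device)[:num]
    return tensor[perm].contiguous()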
Example #9
    def forward(self, x, im_sizes, image_offset,
                gt_boxes=None, gt_classes=None, gt_rels=None, proposals=None, train_anchor_inds=None,
                return_fmap=False):
        """
        Forward pass for detection
        :param x: Images@[batch_size, 3, IM_SIZE, IM_SIZE]
        :param im_sizes: A numpy array of (h, w, scale) for each image.
        :param image_offset: Offset onto what image we're on for MGPU training (if single GPU this is 0)

        Training parameters:
        :param gt_boxes: [num_gt, 4] GT boxes over the batch.
        :param gt_classes: [num_gt, 2] gt classes where each one is (img_id, class)
        :param train_anchor_inds: a [num_train, 2] array of indices for the anchors that will
                                  be used to compute the training loss. Each (img_ind, fpn_idx)
        :return: If train:
            scores, boxdeltas, labels, boxes, boxtargets, rpnscores, rpnboxes, rellabels
            
            if test:
            prob dists, boxes, img inds, maxscores, classes
            
        """

        # Detector
        result = self.detector(x, im_sizes, image_offset, gt_boxes, gt_classes, gt_rels, proposals,
                               train_anchor_inds, return_fmap=True)
        if result.is_none():
            raise ValueError("heck")
        im_inds = result.im_inds - image_offset
        # boxes: [#boxes, 4], without box deltas; this is where the narrow error comes from, so it may need .detach()
        boxes = result.rm_box_priors    # .detach()   

        if self.training and result.rel_labels is None:
            assert self.mode == 'sgdet' # sgcls's result.rel_labels is gt and not None
            # rel_labels: [num_rels, 4] (img ind, box0 ind, box1ind, rel type)
            result.rel_labels = rel_assignments(im_inds.data, boxes.data, result.rm_obj_labels.data,
                                                gt_boxes.data, gt_classes.data, gt_rels.data,
                                                image_offset, filter_non_overlap=True,
                                                num_sample_per_gt=1)

        #torch.cat((result.rel_labels[:,0].contiguous().view(rel_inds.size(0),1),result.rm_obj_labels[result.rel_labels[:,1]].view(rel_inds.size(0),1),result.rm_obj_labels[result.rel_labels[:,2]].view(rel_inds.size(0),1),result.rel_labels[:,3].contiguous().view(rel_inds.size(0),1)),-1)
        #bbox_overlaps(boxes.data[55:57].contiguous().view(-1,1), boxes.data[8].contiguous().view(-1,1))
        rel_inds = self.get_rel_inds(result.rel_labels, im_inds, boxes)  #[275,3], [im_inds, box1_inds, box2_inds]
        
        # rois: [#boxes, 5]
        rois = torch.cat((im_inds[:, None].float(), boxes), 1)
        # result.rm_obj_fmap: [384, 4096]
        result.rm_obj_fmap = self.obj_feature_map(result.fmap.detach(), rois)  # detach: prevent gradients from flowing backward through the feature map

        # BiLSTM
        result.rm_obj_dists, result.rm_obj_preds, edge_ctx = self.context(
            result.rm_obj_fmap,   # has been detached above
            # rm_obj_dists: [#boxes, 151]; Prevent gradients from flowing back into score_fc from elsewhere
            result.rm_obj_dists.detach(),  # .detach:Returns a new Variable, detached from the current graph
            im_inds, result.rm_obj_labels if self.training or self.mode == 'predcls' else None,
            boxes.data, result.boxes_all if self.mode == 'sgdet' else result.boxes_all)
        

        # Post Processing
        # nl_edge <= 0
        if edge_ctx is None:
            edge_rep = self.post_emb(result.rm_obj_preds)
        # nl_edge > 0
        else: 
            edge_rep = self.post_lstm(edge_ctx)  # [384, 4096*2]
     
        # Split into subject and object representations
        edge_rep = edge_rep.view(edge_rep.size(0), 2, self.pooling_dim)  #[384,2,4096]
        subj_rep = edge_rep[:, 0]  # [384,4096]
        obj_rep = edge_rep[:, 1]  # [384,4096]
        prod_rep = subj_rep[rel_inds[:, 1]] * obj_rep[rel_inds[:, 2]]  # prod_rep, rel_inds: [275,4096], [275,3]
    

        if self.use_vision: # True when sgdet
            # union rois: fmap.detach--RoIAlignFunction--roifmap--vr [275,4096]
            vr = self.visual_rep(result.fmap.detach(), rois, rel_inds[:, 1:])

            if self.limit_vision:  # False when sgdet
                # exact value TBD
                prod_rep = torch.cat((prod_rep[:,:2048] * vr[:,:2048], prod_rep[:,2048:]), 1) 
            else:
                prod_rep = prod_rep * vr  # [275,4096]


        if self.use_tanh:  # False when sgdet
            prod_rep = F.tanh(prod_rep)

        result.rel_dists = self.rel_compress(prod_rep)  # [275,51]

        if self.use_bias:  # True when sgdet
            result.rel_dists = result.rel_dists + self.freq_bias.index_with_labels(torch.stack((
                result.rm_obj_preds[rel_inds[:, 1]],
                result.rm_obj_preds[rel_inds[:, 2]],
            ), 1))

        # Attention: pos should use rm_obj_labels/rel_labels for obj/rel scores; neg should use rm_obj_preds/max_rel_score for obj/rel scores
        if self.training: 
            judge = result.rel_labels.data[:,3] != 0
            if judge.sum() != 0:  # gt_rel exists in rel_inds
                # positive overall score
                select_rel_inds = torch.arange(rel_inds.size(0)).view(-1,1).long().cuda()[result.rel_labels.data[:,3] != 0]
                com_rel_inds = rel_inds[select_rel_inds]
                twod_inds = arange(result.rm_obj_labels.data) * self.num_classes + result.rm_obj_labels.data  # dist: [-10,10]
                result.obj_scores = F.softmax(result.rm_obj_dists, dim=1).view(-1)[twod_inds]   # only 1/4 of the 384 obj_dists will be updated, because only 1/4 of the objects' labels are non-zero
              
                obj_scores0 = result.obj_scores[com_rel_inds[:,1]]
                obj_scores1 = result.obj_scores[com_rel_inds[:,2]]
                rel_rep = F.softmax(result.rel_dists[select_rel_inds], dim=1)    # result.rel_dists has grad
                rel_score = rel_rep.gather(1, result.rel_labels[select_rel_inds][:,3].contiguous().view(-1,1)).view(-1)  # not use squeeze(); SqueezeBackward, GatherBackward
                prob_score = rel_score * obj_scores0 * obj_scores1

                # negative overall score
                rel_cands = im_inds.data[:, None] == im_inds.data[None]
                rel_cands.view(-1)[diagonal_inds(rel_cands)] = 0   # self relation = 0
                if self.require_overlap:     
                    rel_cands = rel_cands & (bbox_overlaps(boxes.data, boxes.data) > 0)   # Require overlap for detection
                rel_cands = rel_cands.nonzero()  # [#, 2]
                if rel_cands.dim() == 0:
                    print("rel_cands.dim() == 0!!!!!!!!!!!!!!!!!!!!!!!!!!!!")
                    rel_cands = im_inds.data.new(1, 2).fill_(0) # shaped: [1,2], [0, 0]
                rel_cands = torch.cat((im_inds.data[rel_cands[:, 0]][:, None], rel_cands), 1) # rel_cands' value should be [0, 384]
                rel_inds_neg = rel_cands

                vr_neg = self.visual_rep(result.fmap.detach(), rois, rel_inds_neg[:, 1:])
                subj_obj = subj_rep[rel_inds_neg[:, 1]] * obj_rep[rel_inds_neg[:, 2]]
                prod_rep_neg =  subj_obj * vr_neg
                rel_dists_neg = self.rel_compress(prod_rep_neg)
                all_rel_rep_neg = F.softmax(rel_dists_neg, dim=1)
                _, pred_classes_argmax_neg = all_rel_rep_neg.data[:,1:].max(1)
                pred_classes_argmax_neg = pred_classes_argmax_neg + 1
                all_rel_pred_neg = torch.cat((rel_inds_neg, pred_classes_argmax_neg.view(-1,1)), 1)
                ind_old = torch.ones(all_rel_pred_neg.size(0)).byte().cuda()
                for i in range(com_rel_inds.size(0)):    # delete box pairs with the same rel type as a pos triplet
                    ind_i = ((all_rel_pred_neg[:, 0] == com_rel_inds[i, 0]) &
                             (all_rel_pred_neg[:, 1] == com_rel_inds[i, 1]) &
                             (result.rm_obj_preds.data[all_rel_pred_neg[:, 1]] == result.rm_obj_labels.data[com_rel_inds[i, 1]]) &
                             (all_rel_pred_neg[:, 2] == com_rel_inds[i, 2]) &
                             (result.rm_obj_preds.data[all_rel_pred_neg[:, 2]] == result.rm_obj_labels.data[com_rel_inds[i, 2]]) &
                             (all_rel_pred_neg[:, 3] == result.rel_labels.data[select_rel_inds][i, 3]))
                    ind_i = (1 - ind_i).byte()
                    ind_old = ind_i & ind_old

                rel_inds_neg = rel_inds_neg.masked_select(ind_old.view(-1,1).expand(-1,3) == 1).view(-1,3)
                rel_rep_neg = all_rel_rep_neg.masked_select(Variable(ind_old.view(-1,1).expand(-1,51)) == 1).view(-1,51)
                pred_classes_argmax_neg = pred_classes_argmax_neg.view(-1,1)[ind_old.view(-1,1) == 1]
                rel_labels_pred_neg = all_rel_pred_neg.masked_select(ind_old.view(-1,1).expand(-1,4) == 1).view(-1,4)

                max_rel_score_neg = rel_rep_neg.gather(1, Variable(pred_classes_argmax_neg.view(-1,1))).view(-1)  # not use squeeze()
                twod_inds_neg = arange(result.rm_obj_preds.data) * self.num_classes + result.rm_obj_preds.data
                obj_scores_neg = F.softmax(result.rm_obj_dists, dim=1).view(-1)[twod_inds_neg] 
                obj_scores0_neg = Variable(obj_scores_neg.data[rel_inds_neg[:,1]])
                obj_scores1_neg = Variable(obj_scores_neg.data[rel_inds_neg[:,2]])
                all_score_neg = max_rel_score_neg * obj_scores0_neg * obj_scores1_neg
                # delete those triplet whose score is lower than pos triplets
                prob_score_neg = all_score_neg[all_score_neg.data > prob_score.data.min()] if (all_score_neg.data > prob_score.data.min()).sum() != 0 else all_score_neg


                # use all rel_inds; they are already independent of im_inds, which is only used to extract regions from the image and produce rel_inds
                # 384 boxes---(rel_inds)(rel_inds_neg)--->prob_score,prob_score_neg 
                flag = torch.cat((torch.ones(prob_score.size(0),1).cuda(),torch.zeros(prob_score_neg.size(0),1).cuda()),0)
                all_prob = torch.cat((prob_score,prob_score_neg), 0)  # Variable, [#pos_inds+#neg_inds, 1]

                _, sort_prob_inds = torch.sort(all_prob.data, dim=0, descending=True)

                sorted_flag = flag[sort_prob_inds].view(-1)  # can be used to check distribution of pos and neg
                sorted_all_prob = all_prob[sort_prob_inds]  # Variable
                
                # positive triplet score
                pos_exp = sorted_all_prob[sorted_flag == 1]  # Variable 
                # negative triplet score
                neg_exp = sorted_all_prob[sorted_flag == 0]  # Variable

                # determine how many rows will be updated in rel_dists_neg
                pos_repeat = torch.zeros(1, 1)
                neg_repeat = torch.zeros(1, 1)
                for i in range(pos_exp.size(0)):
                    if ( neg_exp.data > pos_exp.data[i] ).sum() != 0:
                        int_part = (neg_exp.data > pos_exp.data[i]).sum()
                        temp_pos_inds = torch.ones(int_part) * i
                        pos_repeat =  torch.cat((pos_repeat, temp_pos_inds.view(-1,1)), 0)
                        temp_neg_inds = torch.arange(int_part)
                        neg_repeat = torch.cat((neg_repeat, temp_neg_inds.view(-1,1)), 0)
                    else:
                        temp_pos_inds = torch.ones(1)* i
                        pos_repeat =  torch.cat((pos_repeat, temp_pos_inds.view(-1,1)), 0)
                        temp_neg_inds = torch.arange(1)
                        neg_repeat = torch.cat((neg_repeat, temp_neg_inds.view(-1,1)), 0)

                """
                int_part = neg_exp.size(0) // pos_exp.size(0)
                decimal_part = neg_exp.size(0) % pos_exp.size(0)
                int_inds = torch.arange(pos_exp.size(0))[:,None].expand_as(torch.Tensor(pos_exp.size(0), int_part)).contiguous().view(-1)
                int_part_inds = (int(pos_exp.size(0) -1) - int_inds).long().cuda() # use minimum pos to correspond maximum negative
                if decimal_part == 0:
                    expand_inds = int_part_inds
                else:
                    expand_inds = torch.cat((torch.arange(pos_exp.size(0))[(pos_exp.size(0) - decimal_part):].long().cuda(), int_part_inds), 0)  
                
                result.pos = pos_exp[expand_inds]
                result.neg = neg_exp
                result.anchor = Variable(torch.zeros(result.pos.size(0)).cuda())
                """
                result.pos = pos_exp[pos_repeat.cuda().long().view(-1)]
                result.neg = neg_exp[neg_repeat.cuda().long().view(-1)]
                result.anchor = Variable(torch.zeros(result.pos.size(0)).cuda())
                

                result.ratio = torch.ones(3).cuda()
                result.ratio[0] = result.ratio[0] * (sorted_flag.nonzero().min() / (prob_score.size(0) + all_score_neg.size(0)))
                result.ratio[1] = result.ratio[1] * (sorted_flag.nonzero().max() / (prob_score.size(0) + all_score_neg.size(0)))
                result.ratio[2] = result.ratio[2] * (prob_score.size(0) + all_score_neg.size(0))

                return result

            else:  # no gt_rel in rel_inds
                print("no gt_rel in rel_inds!!!!!!!!!!!!!!!!!!!!!!!!!!!!")
                ipdb.set_trace()
                # testing triplet proposal
                rel_cands = im_inds.data[:, None] == im_inds.data[None]
                # self relation = 0
                rel_cands.view(-1)[diagonal_inds(rel_cands)] = 0
                # Require overlap for detection
                if self.require_overlap:
                    rel_cands = rel_cands & (bbox_overlaps(boxes.data, boxes.data) > 0)
                rel_cands = rel_cands.nonzero()
                if rel_cands.dim() == 0:
                    print("rel_cands.dim() == 0!!!!!!!!!!!!!!!!!!!!!!!!!!!!")
                    rel_cands = im_inds.data.new(1, 2).fill_(0)
                rel_cands = torch.cat((im_inds.data[rel_cands[:, 0]][:, None], rel_cands), 1)
                rel_labels_neg = rel_cands
                rel_inds_neg = rel_cands

                twod_inds_neg = arange(result.rm_obj_preds.data) * self.num_classes + result.rm_obj_preds.data
                obj_scores_neg = F.softmax(result.rm_obj_dists, dim=1).view(-1)[twod_inds_neg]
                vr_neg = self.visual_rep(result.fmap.detach(), rois, rel_inds_neg[:, 1:])
                subj_obj = subj_rep[rel_inds_neg[:, 1]] * obj_rep[rel_inds_neg[:, 2]]
                prod_rep_neg = subj_obj * vr_neg
                rel_dists_neg = self.rel_compress(prod_rep_neg)
                # negative overall score
                obj_scores0_neg = Variable(obj_scores_neg.data[rel_inds_neg[:,1]])
                obj_scores1_neg = Variable(obj_scores_neg.data[rel_inds_neg[:,2]])
                rel_rep_neg = F.softmax(rel_dists_neg, dim=1)
                _, pred_classes_argmax_neg = rel_rep_neg.data[:,1:].max(1)
                pred_classes_argmax_neg = pred_classes_argmax_neg + 1

                max_rel_score_neg = rel_rep_neg.gather(1, Variable(pred_classes_argmax_neg.view(-1,1))).view(-1)  # not use squeeze()
                prob_score_neg = max_rel_score_neg * obj_scores0_neg * obj_scores1_neg

                result.pos = Variable(torch.zeros(prob_score_neg.size(0)).cuda())
                result.neg = prob_score_neg
                result.anchor = Variable(torch.zeros(prob_score_neg.size(0)).cuda())

                result.ratio = torch.ones(3,1).cuda()

                return result
        ###################### Testing ###########################

        # extract corresponding scores according to the box's preds
        twod_inds = arange(result.rm_obj_preds.data) * self.num_classes + result.rm_obj_preds.data
        result.obj_scores = F.softmax(result.rm_obj_dists, dim=1).view(-1)[twod_inds]   # [384]

        # Bbox regression
        if self.mode == 'sgdet':
            bboxes = result.boxes_all.view(-1, 4)[twod_inds].view(result.boxes_all.size(0), 4)
        else:
            # Boxes will get fixed by filter_dets function.
            bboxes = result.rm_box_priors

        rel_rep = F.softmax(result.rel_dists, dim=1)    # [275, 51]
        
        # sort product of obj1 * obj2 * rel
        return filter_dets(bboxes, result.obj_scores,
                           result.rm_obj_preds, rel_inds[:, 1:],
                           rel_rep)
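During training, the forward pass above returns result.pos, result.neg, and result.anchor (all zeros) rather than a loss; the actual objective is computed elsewhere and is not shown here. One plausible margin-based formulation over these outputs, offered only as an illustrative sketch and not as the original training loss, is below.

import torch

def triplet_margin_objective(pos, neg, anchor, margin=0.2):
    # Hypothetical objective: push every positive triplet score above its
    # paired negative by `margin`, and keep positives above the zero anchor.
    # This is an assumption about how pos/neg/anchor might be consumed.
    rank_loss = (margin + neg - pos).clamp(min=0).mean()
    floor_loss = (margin + anchor - pos).clamp(min=0).mean()
    return rank_loss + floor_loss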
Example #11
def proposal_assignments_gtbox(rois,
                               gt_boxes,
                               gt_classes,
                               gt_rels,
                               image_offset,
                               fg_thresh=0.5):
    """
    Assign object detection proposals to ground-truth targets. Produces proposal
    classification labels and bounding-box regression targets.
    :param rois: [img_ind, x1, y1, x2, y2]
    :param gt_boxes: [num_boxes, 4] array of [x0, y0, x1, y1]. Not needed, it seems
    :param gt_classes: [num_boxes, 2] array of [img_ind, class]
        Note, the img_inds here start at image_offset
    :param gt_rels: [num_boxes, 4] array of [img_ind, box_0, box_1, rel type].
        Note, the img_inds here start at image_offset
    :param fg_thresh: Overlap threshold for a ROI to be considered foreground (if >= fg_thresh)
    :return:
        rois: [num_rois, 5]
        labels: [num_rois] array of labels
        bbox_targets [num_rois, 4] array of targets for the labels.
        rel_labels: [num_rels, 4] (img ind, box0 ind, box1ind, rel type)
    """
    im_inds = rois[:, 0].long()

    num_im = im_inds[-1] + 1

    # Offset the image indices in fg_rels to refer to absolute indices (not just within img i)
    fg_rels = gt_rels.clone()
    fg_rels[:, 0] -= image_offset
    offset = {}
    for i, s, e in enumerate_by_image(im_inds):
        offset[i] = s
    for i, s, e in enumerate_by_image(fg_rels[:, 0]):
        fg_rels[s:e, 1:3] += offset[i]

    #----------------------------------------------------------------------------#
    fg_rel_list = []
    for i in range(num_im):
        fg_rel_list.append(sum(fg_rels[:, 0] == i).item())
    longest_len = max(fg_rel_list)
    bg_rel_length = [longest_len - i for i in fg_rel_list]
    #----------------------------------------------------------------------------#

    # Try ALL things, not just intersections.
    is_cand = (im_inds[:, None] == im_inds[None])
    is_cand.view(-1)[diagonal_inds(is_cand)] = 0

    # # Compute salience
    # gt_inds = fg_rels[:, 1:3].contiguous().view(-1)
    # labels_arange = labels.data.new(labels.size(0))
    # torch.arange(0, labels.size(0), out=labels_arange)
    # salience_labels = ((gt_inds[:, None] == labels_arange[None]).long().sum(0) > 0).long()
    # labels = torch.stack((labels, salience_labels), 1)

    # Add in some BG labels

    # NOW WE HAVE TO EXCLUDE THE FGs.
    # TODO: check if this causes an error if many duplicate GTs haven't been filtered out

    is_cand.view(-1)[fg_rels[:, 1] * im_inds.size(0) + fg_rels[:, 2]] = 0
    is_bgcand = is_cand.nonzero()
    # TODO: make this sample on a per image case
    # If too many then sample
    num_fg = min(fg_rels.size(0), int(RELS_PER_IMG * REL_FG_FRACTION * num_im))
    if num_fg < fg_rels.size(0):
        fg_rels = random_choose(fg_rels, num_fg)

    # If too many then sample
    num_bg = min(
        is_bgcand.size(0) if is_bgcand.dim() > 0 else 0, int(num_fg / 2))

    bg_rels = torch.cat((
        im_inds[is_bgcand[:, 0]][:, None],
        is_bgcand,
        (is_bgcand[:, 0, None] < -10).long(),
    ), 1)
    rel_labels = fg_rels
    for i, j in enumerate(bg_rel_length):
        if bg_rels[bg_rels[:, 0] == i, :].shape[0] >= j:
            bg_rel_per_image = random_choose(bg_rels[bg_rels[:, 0] == i, :], j)
        else:
            bg_rel_per_image = torch.cat(
                (bg_rels[bg_rels[:, 0] == i, :],
                 random_choose(bg_rels[bg_rels[:, 0] == i, :],
                               j - bg_rels[bg_rels[:, 0] == i, :].shape[0])),
                0)
        rel_labels = torch.cat((rel_labels, bg_rel_per_image), 0)

    # last sort by rel.
    _, perm = torch.sort(rel_labels[:, 0] * (gt_boxes.size(0)**2) +
                         rel_labels[:, 1] * gt_boxes.size(0) +
                         rel_labels[:, 2])

    rel_labels = rel_labels[perm].contiguous()

    labels = gt_classes[:, 1].contiguous()

    return rois, labels, rel_labels