def union_pairs(self, im_inds):
    """All ordered pairs of boxes within the same image, excluding self-pairs."""
    rel_cands = im_inds.data[:, None] == im_inds.data[None]
    rel_cands.view(-1)[diagonal_inds(rel_cands)] = 0
    rel_inds = rel_cands.nonzero()
    rel_inds = torch.cat((im_inds[rel_inds[:, 0]][:, None].data, rel_inds), -1)
    return rel_inds
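# `diagonal_inds` above comes from the surrounding codebase and is not defined in
# this section. A minimal sketch, assuming it returns the flattened indices of the
# diagonal of an NxN matrix (so self-pairs can be zeroed out):
def diagonal_inds(tensor):
    """Flattened indices i * N + i of the diagonal of an NxN tensor."""
    assert tensor.size(0) == tensor.size(1)
    size = tensor.size(0)
    arange_inds = tensor.new(size).long()
    torch.arange(0, size, out=arange_inds)
    return (size + 1) * arange_inds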
def get_rel_inds(self, rel_labels, im_inds, box_priors):
    """
    Get the relationship candidates
    :param rel_labels: array of relation labels
    :param im_inds: image indices
    :param box_priors: RoI bounding boxes
    :return: rel_inds
    """
    if self.training:
        rel_inds = rel_labels[:, :3].data.clone()
    else:
        rel_cands = im_inds.data[:, None] == im_inds.data[None]
        rel_cands.view(-1)[diagonal_inds(rel_cands)] = 0

        # Require overlap for detection
        if self.require_overlap:
            rel_cands = rel_cands & (bbox_overlaps(box_priors.data, box_priors.data) > 0)

            # if there are fewer than 100 things, then we might as well add some?
            amt_to_add = 100 - rel_cands.long().sum()  # (unused)

        rel_cands = rel_cands.nonzero()
        if rel_cands.dim() == 0:
            rel_cands = im_inds.data.new(1, 2).fill_(0)

        rel_inds = torch.cat((im_inds.data[rel_cands[:, 0]][:, None], rel_cands), 1)
    return rel_inds
def get_rel_inds(self, rel_labels, im_inds, box_priors):
    """Get relation indices.

    Args:
        rel_labels: Variable
        im_inds: Variable
        box_priors: Variable
    """
    # Get the relationship candidates
    if self.training:
        rel_inds = rel_labels[:, :3].data.contiguous().clone()
    else:
        rel_cands = im_inds.data[:, None] == im_inds.data[None]
        rel_cands.view(-1)[diagonal_inds(rel_cands)] = 0

        # Require overlap for detection
        if self.require_overlap:
            rel_cands = rel_cands & (bbox_overlaps(box_priors.data, box_priors.data) > 0)

            # if there are fewer than 100 things, then we might as well add some?
            amt_to_add = 100 - rel_cands.long().sum()  # (unused)

        rel_cands = rel_cands.nonzero()
        if rel_cands.dim() == 0:
            rel_cands = im_inds.data.new(1, 2).fill_(0)

        rel_inds = torch.cat(
            (im_inds.data[rel_cands[:, 0]][:, None].contiguous(), rel_cands), 1)
    return rel_inds
def get_msg_rel_inds(self, im_inds, box_priors, box_score):
    rel_cands = im_inds.data[:, None] == im_inds.data[None]
    rel_cands.view(-1)[diagonal_inds(rel_cands)] = 0

    if self.require_overlap:
        rel_cands = rel_cands & (bbox_overlaps(box_priors.data, box_priors.data) > conf.overlap_thresh)

    rel_cands = rel_cands.nonzero()
    if rel_cands.dim() == 0:
        rel_cands = im_inds.data.new(1, 2).fill_(0)

    rel_inds = torch.cat((im_inds.data[rel_cands[:, 0]][:, None], rel_cands), 1)
    return rel_inds
def get_rel_inds(self, rel_labels, im_inds, box_priors, box_score):
    if self.training:
        rel_inds = rel_labels[:, :3].data.clone()
    else:
        rel_cands = im_inds.data[:, None] == im_inds.data[None]
        rel_cands.view(-1)[diagonal_inds(rel_cands)] = 0

        # Require overlap for detection (test stage only)
        if self.require_overlap:
            rel_cands = rel_cands & (bbox_overlaps(box_priors.data, box_priors.data) > 0)

        rel_cands = rel_cands.nonzero()
        if rel_cands.dim() == 0:
            rel_cands = im_inds.data.new(1, 2).fill_(0)

        rel_inds = torch.cat((im_inds.data[rel_cands[:, 0]][:, None], rel_cands), 1)
    return rel_inds
def rel_assignments_sgcls(rois, gt_boxes, gt_classes, gt_rels, image_offset):
    """
    Sample rels to balance the proportion of positive and negative samples.
    :param rois: [img_ind, x1, y1, x2, y2]
    :param gt_boxes: [num_boxes, 4] array of [x0, y0, x1, y1]. Not needed, it seems
    :param gt_classes: [num_boxes, 2] array of [img_ind, class].
        Note, the img_inds here start at image_offset
    :param gt_rels: [num_rels, 4] array of [img_ind, box_0, box_1, rel type].
        Note, the img_inds here start at image_offset
    :return:
        rois: [num_rois, 5]
        labels: [num_rois] array of labels
        rel_labels: [num_rels, 4] (img ind, box0 ind, box1 ind, rel type)
    """
    im_inds = rois[:, 0].long()
    num_im = im_inds[-1] + 1

    # Offset the image indices in fg_rels to refer to absolute indices (not just within img i)
    fg_rels = gt_rels.clone()
    fg_rels[:, 0] -= image_offset
    offset = {}
    for i, s, e in enumerate_by_image(im_inds):
        offset[i] = s
    for i, s, e in enumerate_by_image(fg_rels[:, 0]):
        fg_rels[s:e, 1:3] += offset[i]

    # Try ALL things, not just intersections.
    is_cand = (im_inds[:, None] == im_inds[None])
    is_cand.view(-1)[diagonal_inds(is_cand)] = 0

    # # Compute salience
    # gt_inds = fg_rels[:, 1:3].contiguous().view(-1)
    # labels_arange = labels.data.new(labels.size(0))
    # torch.arange(0, labels.size(0), out=labels_arange)
    # salience_labels = ((gt_inds[:, None] == labels_arange[None]).long().sum(0) > 0).long()
    # labels = torch.stack((labels, salience_labels), 1)

    # Add in some BG labels.
    # NOW WE HAVE TO EXCLUDE THE FGs.
    # TODO: check if this causes an error if many duplicate GTs haven't been filtered out
    is_cand.view(-1)[fg_rels[:, 1] * im_inds.size(0) + fg_rels[:, 2]] = 0
    is_bgcand = is_cand.nonzero()

    # TODO: make this sample on a per-image case
    # If there are too many FG rels, subsample
    num_fg = min(fg_rels.size(0), int(RELS_PER_IMG * REL_FG_FRACTION * num_im))
    if num_fg < fg_rels.size(0):
        fg_rels = random_choose(fg_rels, num_fg)

    # If there are too many BG candidates, subsample
    num_bg = min(is_bgcand.size(0) if is_bgcand.dim() > 0 else 0,
                 int(RELS_PER_IMG * num_im) - num_fg)
    if num_bg > 0:
        bg_rels = torch.cat((
            im_inds[is_bgcand[:, 0]][:, None],
            is_bgcand,
            (is_bgcand[:, 0, None] < -10).long(),  # always 0: the BG rel type
        ), 1)
        if num_bg < is_bgcand.size(0):
            bg_rels = random_choose(bg_rels, num_bg)
        rel_labels = torch.cat((fg_rels, bg_rels), 0)
    else:
        rel_labels = fg_rels

    # Last, sort by rel.
    _, perm = torch.sort(rel_labels[:, 0] * (gt_boxes.size(0) ** 2)
                         + rel_labels[:, 1] * gt_boxes.size(0)
                         + rel_labels[:, 2])
    rel_labels = rel_labels[perm].contiguous()

    labels = gt_classes[:, 1].contiguous()
    return rois, labels, rel_labels
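# `enumerate_by_image` and `random_choose` are helpers from the surrounding codebase.
# Minimal sketches consistent with how they are used above (assumptions, not the
# codebase's exact definitions): `enumerate_by_image` yields (img_ind, start, end)
# spans over a sorted 1-D tensor of image indices; `random_choose` keeps `num`
# randomly chosen rows.
def enumerate_by_image(im_inds):
    im_inds_np = im_inds.cpu().numpy()
    start = 0
    cur = int(im_inds_np[0])
    for i, val in enumerate(im_inds_np):
        if int(val) != cur:
            yield cur, start, i
            cur, start = int(val), i
    yield cur, start, len(im_inds_np)

def random_choose(tensor, num):
    """Keep `num` rows of `tensor`, sampled without replacement."""
    perm = torch.randperm(tensor.size(0))[:num]
    if tensor.is_cuda:
        perm = perm.cuda(tensor.get_device())
    return tensor[perm].contiguous()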
def rel_anchor_target(rois, gt_boxes, gt_classes, scores, gt_rels, image_offset):
    """
    Use all RoI pairs and sample some pairs to train the relation proposal module.
    Note: ONLY for mode SGDET! The rois come from the RPN.
    We take the co-overlap strategy from Graph R-CNN to sample fg and bg rels.
    :param rois: N, 5
    :param scores: N, N
    :param gt_rels:
    :return:
    """
    im_inds = rois[:, 0].long()
    num_im = im_inds[-1] + 1

    # Offset the image indices in fg_rels to refer to absolute indices (not just within img i)
    fg_rels = gt_rels.clone()
    fg_rels[:, 0] -= image_offset
    offset = {}
    for i, s, e in enumerate_by_image(gt_classes[:, 0]):
        offset[i] = s
    for i, s, e in enumerate_by_image(fg_rels[:, 0]):
        fg_rels[s:e, 1:3] += offset[i]

    gt_box_pairs = torch.cat(
        (gt_boxes[fg_rels[:, 1]], gt_boxes[fg_rels[:, 2]]), 1)  # Ngtp, 8

    # get all potential pairs
    is_cand = (im_inds[:, None] == im_inds[None])
    is_cand.view(-1)[diagonal_inds(is_cand)] = 0
    all_pair_inds = torch.nonzero(is_cand)
    all_box_pairs = torch.cat(
        (rois[:, 1:][all_pair_inds[:, 0]], rois[:, 1:][all_pair_inds[:, 1]]), 1)

    num_pairs = np.zeros(num_im + 1).astype(np.int32)
    id_to_iminds = {}
    for i, s, e in enumerate_by_image(im_inds):
        num_pairs[i + 1] = (e - s) * (e - s - 1)
        id_to_iminds[i] = im_inds[s]
    cumsum_num_pairs = np.cumsum(num_pairs).astype(np.int32)

    all_rel_inds = []
    for i in range(1, num_im + 1):
        all_pair_inds_i = all_pair_inds[cumsum_num_pairs[i - 1]:cumsum_num_pairs[i]]
        all_box_pairs_i = all_box_pairs[cumsum_num_pairs[i - 1]:cumsum_num_pairs[i]]
        gt_box_pairs_i = gt_box_pairs[torch.nonzero(fg_rels[:, 0] == (i - 1)).view(-1)]

        labels = gt_rels.new(all_box_pairs_i.size(0)).fill_(-1)
        overlaps = co_bbox_overlaps(all_box_pairs_i, gt_box_pairs_i)  # Np, Ngtp
        max_overlaps, argmax_overlaps = torch.max(overlaps, 1)  # Np
        gt_max_overlaps, _ = torch.max(overlaps, 0)  # Ngtp
        labels[max_overlaps < 0.15] = 0

        gt_max_overlaps[gt_max_overlaps == 0] = 1e-5
        # fg rel: for each gt pair, the max-overlap anchor is fg
        keep = torch.sum(overlaps.eq(gt_max_overlaps.view(1, -1).expand_as(overlaps)), 1)  # Np
        if torch.sum(keep) > 0:
            labels[keep > 0] = 1

        # fg rel: above thresh
        labels[max_overlaps >= 0.25] = 1

        num_fg = int(RELPN_BATCHSIZE * RELPN_FG_FRACTION)
        sum_fg = torch.sum((labels == 1).int())
        sum_bg = torch.sum((labels == 0).int())

        if sum_fg > num_fg:
            fg_inds = torch.nonzero(labels == 1).view(-1)
            rand_num = torch.from_numpy(
                np.random.permutation(fg_inds.size(0))).type_as(gt_boxes).long()
            disable_inds = fg_inds[rand_num[:fg_inds.size(0) - num_fg]]
            labels[disable_inds] = -1

        num_bg = RELPN_BATCHSIZE - torch.sum((labels == 1).int())
        if sum_bg > num_bg:
            bg_inds = torch.nonzero(labels == 0).view(-1)
            rand_num = torch.from_numpy(
                np.random.permutation(bg_inds.size(0))).type_as(gt_boxes).long()
            disable_inds = bg_inds[rand_num[:bg_inds.size(0) - num_bg]]
            labels[disable_inds] = -1

        keep_inds = torch.nonzero(labels >= 0).view(-1)
        labels = labels[keep_inds]
        all_pair_inds_i = all_pair_inds_i[keep_inds]
        im_inds_i = torch.LongTensor([id_to_iminds[i - 1]] * keep_inds.size(0)).view(-1, 1).cuda(all_pair_inds.get_device())
        all_pair_inds_i = torch.cat((im_inds_i, all_pair_inds_i, labels.view(-1, 1)), 1)
        all_rel_inds.append(all_pair_inds_i)

    all_rel_inds = torch.cat(all_rel_inds, 0)
    # sort by rel
    _, perm = torch.sort(all_rel_inds[:, 0] * (rois.size(0) ** 2)
                         + all_rel_inds[:, 1] * rois.size(0)
                         + all_rel_inds[:, 2])
    all_rel_inds = all_rel_inds[perm].contiguous()
    return all_rel_inds
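# `co_bbox_overlaps` scores 8-d box *pairs* against each other; its definition is not
# shown in this section. A hedged sketch, assuming a Graph R-CNN-style co-overlap
# where a pair-vs-pair match is IoU(subject, subject) * IoU(object, object):
def co_bbox_overlaps(pairs_a, pairs_b):
    """pairs_a: [Np, 8], pairs_b: [Ngtp, 8] of concatenated (subj, obj) boxes."""
    subj_iou = bbox_overlaps(pairs_a[:, :4].contiguous(), pairs_b[:, :4].contiguous())  # [Np, Ngtp]
    obj_iou = bbox_overlaps(pairs_a[:, 4:].contiguous(), pairs_b[:, 4:].contiguous())   # [Np, Ngtp]
    return subj_iou * obj_iou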
def proposal_assignments_gtbox(rois, gt_boxes, gt_classes, gt_rels, image_offset,
                               RELS_PER_IMG, sample_factor=-1):
    """
    Assign object detection proposals to ground-truth targets. Produces proposal
    classification labels and bounding-box regression targets.
    :param rois: [img_ind, x1, y1, x2, y2]
    :param gt_boxes: [num_boxes, 4] array of [x0, y0, x1, y1]. Not needed, it seems
    :param gt_classes: [num_boxes, 2] array of [img_ind, class].
        Note, the img_inds here start at image_offset
    :param gt_rels: [num_rels, 4] array of [img_ind, box_0, box_1, rel type].
        Note, the img_inds here start at image_offset
    :return:
        rois: [num_rois, 5]
        labels: [num_rois] array of labels
        bbox_targets: [num_rois, 4] array of targets for the labels.
        rel_labels: [num_rels, 4] (img ind, box0 ind, box1 ind, rel type)
    """
    im_inds = rois[:, 0].long()
    num_im = im_inds[-1] + 1

    # Offset the image indices in fg_rels to refer to absolute indices (not just within img i)
    fg_rels = gt_rels.clone()
    fg_rels[:, 0] -= image_offset
    offset = {}
    for i, s, e in enumerate_by_image(im_inds):
        offset[i] = s
    for i, s, e in enumerate_by_image(fg_rels[:, 0]):
        fg_rels[s:e, 1:3] += offset[i]

    # Try ALL things, not just intersections.
    is_cand = (im_inds[:, None] == im_inds[None])
    is_cand.view(-1)[diagonal_inds(is_cand)] = 0

    # NOW WE HAVE TO EXCLUDE THE FGs.
    is_cand.view(-1)[fg_rels[:, 1] * im_inds.size(0) + fg_rels[:, 2]] = 0
    is_bgcand = torch.nonzero(is_cand)

    # TODO: make this sample on a per-image case
    # If there are too many FG rels, subsample
    num_fg = min(fg_rels.size(0), int(RELS_PER_IMG * REL_FG_FRACTION * num_im))
    if num_fg < fg_rels.size(0):
        fg_rels = random_choose(fg_rels, num_fg)

    # If there are too many BG candidates, subsample
    is_train = num_im > 1  # assume num_im = 1 at test time (except for the det mode, which we don't use for now)
    sample_bg = is_train and sample_factor > -1
    num_bg = min(
        is_bgcand.size(0) if is_bgcand.dim() > 0 else 0,
        int(num_fg * sample_factor) if sample_bg else (int(RELS_PER_IMG * num_im) - num_fg))  # sample num_fg at training time

    if num_bg > 0:
        bg_rels = torch.cat((
            im_inds[is_bgcand[:, 0]][:, None],
            is_bgcand,
            (is_bgcand[:, 0, None] < -10).long(),  # always 0: the BG rel type
        ), 1)
        if num_bg < is_bgcand.size(0):
            # at test time this will correspond to the baseline approach
            bg_rels = random_choose(bg_rels, num_bg)
        rel_labels = torch.cat((fg_rels, bg_rels), 0)
    else:
        rel_labels = fg_rels

    # Last, sort by rel.
    _, perm = torch.sort(rel_labels[:, 0] * (gt_boxes.size(0) ** 2)
                         + rel_labels[:, 1] * gt_boxes.size(0)
                         + rel_labels[:, 2])
    rel_labels = rel_labels[perm].contiguous()

    labels = gt_classes[:, 1].contiguous()
    return rois, labels, rel_labels
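# Worked example of the fg/bg budget above (numbers assumed purely for illustration):
# with RELS_PER_IMG = 1024, REL_FG_FRACTION = 0.25 and num_im = 2, at most
# int(1024 * 0.25 * 2) = 512 fg rels survive subsampling; the bg budget is then
# either int(num_fg * sample_factor) when bg sampling is on, or
# int(1024 * 2) - num_fg = 1536 when it is off (assuming num_fg hit the 512 cap).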
def forward(self, x, im_sizes, image_offset, gt_boxes=None, gt_classes=None, gt_rels=None,
            proposals=None, train_anchor_inds=None, return_fmap=False):
    """
    Forward pass for detection
    :param x: Images@[batch_size, 3, IM_SIZE, IM_SIZE]
    :param im_sizes: A numpy array of (h, w, scale) for each image.
    :param image_offset: Offset onto what image we're on for MGPU training (if single GPU this is 0)

    Training parameters:
    :param gt_boxes: [num_gt, 4] GT boxes over the batch.
    :param gt_classes: [num_gt, 2] GT classes where each one is (img_id, class)
    :param train_anchor_inds: a [num_train, 2] array of indices for the anchors that
        will be used to compute the training loss. Each is (img_ind, fpn_idx)
    :return: If train: scores, boxdeltas, labels, boxes, boxtargets, rpnscores, rpnboxes, rellabels
        If test: prob dists, boxes, img inds, maxscores, classes
    """
    # Detector
    result = self.detector(x, im_sizes, image_offset, gt_boxes, gt_classes, gt_rels,
                           proposals, train_anchor_inds, return_fmap=True)
    if result.is_none():
        raise ValueError("heck")

    im_inds = result.im_inds - image_offset
    # boxes: [#boxes, 4], without box deltas; where the narrow error comes from, should .detach()
    boxes = result.rm_box_priors  # .detach()

    if self.training and result.rel_labels is None:
        assert self.mode == 'sgdet'  # sgcls's result.rel_labels is gt and not None
        # rel_labels: [num_rels, 4] (img ind, box0 ind, box1 ind, rel type)
        result.rel_labels = rel_assignments(im_inds.data, boxes.data, result.rm_obj_labels.data,
                                            gt_boxes.data, gt_classes.data, gt_rels.data,
                                            image_offset, filter_non_overlap=True,
                                            num_sample_per_gt=1)

    # Debug snippets:
    # torch.cat((result.rel_labels[:, 0].contiguous().view(rel_inds.size(0), 1),
    #            result.rm_obj_labels[result.rel_labels[:, 1]].view(rel_inds.size(0), 1),
    #            result.rm_obj_labels[result.rel_labels[:, 2]].view(rel_inds.size(0), 1),
    #            result.rel_labels[:, 3].contiguous().view(rel_inds.size(0), 1)), -1)
    # bbox_overlaps(boxes.data[55:57].contiguous().view(-1, 1), boxes.data[8].contiguous().view(-1, 1))

    rel_inds = self.get_rel_inds(result.rel_labels, im_inds, boxes)  # [275, 3]: [im_inds, box1_inds, box2_inds]

    # rois: [#boxes, 5]
    rois = torch.cat((im_inds[:, None].float(), boxes), 1)

    # result.rm_obj_fmap: [384, 4096]
    result.rm_obj_fmap = self.obj_feature_map(result.fmap.detach(), rois)  # detach: prevent gradients from flowing back

    # BiLSTM
    result.rm_obj_dists, result.rm_obj_preds, edge_ctx = self.context(
        result.rm_obj_fmap,  # has been detached above
        # rm_obj_dists: [#boxes, 151]; prevent gradients from flowing back into score_fc from elsewhere
        result.rm_obj_dists.detach(),  # .detach(): returns a new Variable, detached from the current graph
        im_inds,
        result.rm_obj_labels if self.training or self.mode == 'predcls' else None,
        boxes.data,
        result.boxes_all)

    # Post processing
    if edge_ctx is None:  # nl_edge <= 0
        edge_rep = self.post_emb(result.rm_obj_preds)
    else:  # nl_edge > 0
        edge_rep = self.post_lstm(edge_ctx)  # [384, 4096*2]

    # Split into subject and object representations
    edge_rep = edge_rep.view(edge_rep.size(0), 2, self.pooling_dim)  # [384, 2, 4096]
    subj_rep = edge_rep[:, 0]  # [384, 4096]
    obj_rep = edge_rep[:, 1]  # [384, 4096]
    prod_rep = subj_rep[rel_inds[:, 1]] * obj_rep[rel_inds[:, 2]]  # prod_rep: [275, 4096], rel_inds: [275, 3]

    if self.use_vision:  # True when sgdet
        # union rois: fmap.detach -- RoIAlignFunction -- roifmap -- vr [275, 4096]
        vr = self.visual_rep(result.fmap.detach(), rois, rel_inds[:, 1:])
        if self.limit_vision:  # False when sgdet
            # exact value TBD
            prod_rep = torch.cat((prod_rep[:, :2048] * vr[:, :2048], prod_rep[:, 2048:]), 1)
        else:
            prod_rep = prod_rep * vr  # [275, 4096]

    if self.use_tanh:  # False when sgdet
        prod_rep = F.tanh(prod_rep)

    result.rel_dists = self.rel_compress(prod_rep)  # [275, 51]

    if self.use_bias:  # True when sgdet
        result.rel_dists = result.rel_dists + self.freq_bias.index_with_labels(torch.stack((
            result.rm_obj_preds[rel_inds[:, 1]],
            result.rm_obj_preds[rel_inds[:, 2]],
        ), 1))

    # Attention: positives should use rm_obj_labels/rel_labels for obj/rel scores;
    # negatives should use rm_obj_preds/max_rel_score for obj/rel scores
    if self.training:
        judge = result.rel_labels.data[:, 3] != 0
        if judge.sum() != 0:  # gt rels exist in rel_inds
            # positive overall score
            select_rel_inds = torch.arange(rel_inds.size(0)).view(-1, 1).long().cuda()[result.rel_labels.data[:, 3] != 0]
            com_rel_inds = rel_inds[select_rel_inds]
            twod_inds = arange(result.rm_obj_labels.data) * self.num_classes + result.rm_obj_labels.data
            # dists: [-10, 10]
            result.obj_scores = F.softmax(result.rm_obj_dists, dim=1).view(-1)[twod_inds]
            # only 1/4 of the 384 obj_dists will be updated, because only 1/4 of the objs' labels are not 0
            obj_scores0 = result.obj_scores[com_rel_inds[:, 1]]
            obj_scores1 = result.obj_scores[com_rel_inds[:, 2]]
            rel_rep = F.softmax(result.rel_dists[select_rel_inds], dim=1)  # result.rel_dists has grad
            rel_score = rel_rep.gather(1, result.rel_labels[select_rel_inds][:, 3].contiguous().view(-1, 1)).view(-1)  # not using squeeze(); SqueezeBackward, GatherBackward
            prob_score = rel_score * obj_scores0 * obj_scores1

            # negative overall score
            rel_cands = im_inds.data[:, None] == im_inds.data[None]
            rel_cands.view(-1)[diagonal_inds(rel_cands)] = 0  # self relation = 0
            if self.require_overlap:
                rel_cands = rel_cands & (bbox_overlaps(boxes.data, boxes.data) > 0)  # require overlap for detection
            rel_cands = rel_cands.nonzero()  # [#, 2]
            if rel_cands.dim() == 0:
                print("rel_cands.dim() == 0!")
                rel_cands = im_inds.data.new(1, 2).fill_(0)  # shaped [1, 2]: [0, 0]
            rel_cands = torch.cat((im_inds.data[rel_cands[:, 0]][:, None], rel_cands), 1)  # rel_cands' values should be in [0, 384]

            rel_inds_neg = rel_cands
            vr_neg = self.visual_rep(result.fmap.detach(), rois, rel_inds_neg[:, 1:])
            subj_obj = subj_rep[rel_inds_neg[:, 1]] * obj_rep[rel_inds_neg[:, 2]]
            prod_rep_neg = subj_obj * vr_neg
            rel_dists_neg = self.rel_compress(prod_rep_neg)
            all_rel_rep_neg = F.softmax(rel_dists_neg, dim=1)
            _, pred_classes_argmax_neg = all_rel_rep_neg.data[:, 1:].max(1)
            pred_classes_argmax_neg = pred_classes_argmax_neg + 1
            all_rel_pred_neg = torch.cat((rel_inds_neg, pred_classes_argmax_neg.view(-1, 1)), 1)

            # delete those box pairs with the same rel type as the pos triplets
            ind_old = torch.ones(all_rel_pred_neg.size(0)).byte().cuda()
            for i in range(com_rel_inds.size(0)):
                ind_i = (all_rel_pred_neg[:, 0] == com_rel_inds[i, 0]) & \
                        (all_rel_pred_neg[:, 1] == com_rel_inds[i, 1]) & \
                        (result.rm_obj_preds.data[all_rel_pred_neg[:, 1]] == result.rm_obj_labels.data[com_rel_inds[i, 1]]) & \
                        (all_rel_pred_neg[:, 2] == com_rel_inds[i, 2]) & \
                        (result.rm_obj_preds.data[all_rel_pred_neg[:, 2]] == result.rm_obj_labels.data[com_rel_inds[i, 2]]) & \
                        (all_rel_pred_neg[:, 3] == result.rel_labels.data[select_rel_inds][i, 3])
                ind_i = (1 - ind_i).byte()
                ind_old = ind_i & ind_old
            rel_inds_neg = rel_inds_neg.masked_select(ind_old.view(-1, 1).expand(-1, 3) == 1).view(-1, 3)
            rel_rep_neg = all_rel_rep_neg.masked_select(Variable(ind_old.view(-1, 1).expand(-1, 51)) == 1).view(-1, 51)
            pred_classes_argmax_neg = pred_classes_argmax_neg.view(-1, 1)[ind_old.view(-1, 1) == 1]
            rel_labels_pred_neg = all_rel_pred_neg.masked_select(ind_old.view(-1, 1).expand(-1, 4) == 1).view(-1, 4)
            max_rel_score_neg = rel_rep_neg.gather(1, Variable(pred_classes_argmax_neg.view(-1, 1))).view(-1)  # not using squeeze()

            twod_inds_neg = arange(result.rm_obj_preds.data) * self.num_classes + result.rm_obj_preds.data
            obj_scores_neg = F.softmax(result.rm_obj_dists, dim=1).view(-1)[twod_inds_neg]
            obj_scores0_neg = Variable(obj_scores_neg.data[rel_inds_neg[:, 1]])
            obj_scores1_neg = Variable(obj_scores_neg.data[rel_inds_neg[:, 2]])
            all_score_neg = max_rel_score_neg * obj_scores0_neg * obj_scores1_neg

            # delete those triplets whose score is lower than the lowest pos triplet score
            prob_score_neg = all_score_neg[all_score_neg.data > prob_score.data.min()] \
                if (all_score_neg.data > prob_score.data.min()).sum() != 0 else all_score_neg

            # use all rel_inds; by now this is independent of im_inds, which is only used to
            # extract regions from the image and produce rel_inds
            # 384 boxes --(rel_inds)(rel_inds_neg)--> prob_score, prob_score_neg
            flag = torch.cat((torch.ones(prob_score.size(0), 1).cuda(),
                              torch.zeros(prob_score_neg.size(0), 1).cuda()), 0)
            all_prob = torch.cat((prob_score, prob_score_neg), 0)  # Variable, [#pos_inds + #neg_inds, 1]
            _, sort_prob_inds = torch.sort(all_prob.data, dim=0, descending=True)
            sorted_flag = flag[sort_prob_inds].view(-1)  # can be used to check the distribution of pos and neg
            sorted_all_prob = all_prob[sort_prob_inds]  # Variable
            # positive triplet scores
            pos_exp = sorted_all_prob[sorted_flag == 1]  # Variable
            # negative triplet scores
            neg_exp = sorted_all_prob[sorted_flag == 0]  # Variable

            # determine how many rows will be updated in rel_dists_neg
            pos_repeat = torch.zeros(1, 1)
            neg_repeat = torch.zeros(1, 1)
            for i in range(pos_exp.size(0)):
                if (neg_exp.data > pos_exp.data[i]).sum() != 0:
                    int_part = (neg_exp.data > pos_exp.data[i]).sum()
                    temp_pos_inds = torch.ones(int_part) * i
                    pos_repeat = torch.cat((pos_repeat, temp_pos_inds.view(-1, 1)), 0)
                    temp_neg_inds = torch.arange(int_part)
                    neg_repeat = torch.cat((neg_repeat, temp_neg_inds.view(-1, 1)), 0)
                else:
                    temp_pos_inds = torch.ones(1) * i
                    pos_repeat = torch.cat((pos_repeat, temp_pos_inds.view(-1, 1)), 0)
                    temp_neg_inds = torch.arange(1)
                    neg_repeat = torch.cat((neg_repeat, temp_neg_inds.view(-1, 1)), 0)

            """
            int_part = neg_exp.size(0) // pos_exp.size(0)
            decimal_part = neg_exp.size(0) % pos_exp.size(0)
            int_inds = torch.arange(pos_exp.size(0))[:, None].expand_as(
                torch.Tensor(pos_exp.size(0), int_part)).contiguous().view(-1)
            int_part_inds = (int(pos_exp.size(0) - 1) - int_inds).long().cuda()  # use the minimum pos to match the maximum neg
            if decimal_part == 0:
                expand_inds = int_part_inds
            else:
                expand_inds = torch.cat((torch.arange(pos_exp.size(0))[(pos_exp.size(0) - decimal_part):].long().cuda(),
                                         int_part_inds), 0)
            result.pos = pos_exp[expand_inds]
            result.neg = neg_exp
            result.anchor = Variable(torch.zeros(result.pos.size(0)).cuda())
            """

            result.pos = pos_exp[pos_repeat.cuda().long().view(-1)]
            result.neg = neg_exp[neg_repeat.cuda().long().view(-1)]
            result.anchor = Variable(torch.zeros(result.pos.size(0)).cuda())

            result.ratio = torch.ones(3).cuda()
            result.ratio[0] = result.ratio[0] * (sorted_flag.nonzero().min() / (prob_score.size(0) + all_score_neg.size(0)))
            result.ratio[1] = result.ratio[1] * (sorted_flag.nonzero().max() / (prob_score.size(0) + all_score_neg.size(0)))
            result.ratio[2] = result.ratio[2] * (prob_score.size(0) + all_score_neg.size(0))
            return result
        else:  # no gt rel in rel_inds
            print("no gt_rel in rel_inds!")
            ipdb.set_trace()

            # testing triplet proposal
            rel_cands = im_inds.data[:, None] == im_inds.data[None]
            rel_cands.view(-1)[diagonal_inds(rel_cands)] = 0  # self relation = 0
            # Require overlap for detection
            if self.require_overlap:
                rel_cands = rel_cands & (bbox_overlaps(boxes.data, boxes.data) > 0)
            rel_cands = rel_cands.nonzero()
            if rel_cands.dim() == 0:
                print("rel_cands.dim() == 0!")
                rel_cands = im_inds.data.new(1, 2).fill_(0)
            rel_cands = torch.cat((im_inds.data[rel_cands[:, 0]][:, None], rel_cands), 1)

            rel_labels_neg = rel_cands
            rel_inds_neg = rel_cands
            twod_inds_neg = arange(result.rm_obj_preds.data) * self.num_classes + result.rm_obj_preds.data
            obj_scores_neg = F.softmax(result.rm_obj_dists, dim=1).view(-1)[twod_inds_neg]
            vr_neg = self.visual_rep(result.fmap.detach(), rois, rel_inds_neg[:, 1:])
            subj_obj = subj_rep[rel_inds_neg[:, 1]] * obj_rep[rel_inds_neg[:, 2]]
            prod_rep_neg = subj_obj * vr_neg
            rel_dists_neg = self.rel_compress(prod_rep_neg)

            # negative overall score
            obj_scores0_neg = Variable(obj_scores_neg.data[rel_inds_neg[:, 1]])
            obj_scores1_neg = Variable(obj_scores_neg.data[rel_inds_neg[:, 2]])
            rel_rep_neg = F.softmax(rel_dists_neg, dim=1)
            _, pred_classes_argmax_neg = rel_rep_neg.data[:, 1:].max(1)
            pred_classes_argmax_neg = pred_classes_argmax_neg + 1
            max_rel_score_neg = rel_rep_neg.gather(1, Variable(pred_classes_argmax_neg.view(-1, 1))).view(-1)  # not using squeeze()
            prob_score_neg = max_rel_score_neg * obj_scores0_neg * obj_scores1_neg

            result.pos = Variable(torch.zeros(prob_score_neg.size(0)).cuda())
            result.neg = prob_score_neg
            result.anchor = Variable(torch.zeros(prob_score_neg.size(0)).cuda())
            result.ratio = torch.ones(3, 1).cuda()
            return result

    ###################### Testing ###########################
    # extract corresponding scores according to the box's preds
    twod_inds = arange(result.rm_obj_preds.data) * self.num_classes + result.rm_obj_preds.data
    result.obj_scores = F.softmax(result.rm_obj_dists, dim=1).view(-1)[twod_inds]  # [384]

    # Bbox regression
    if self.mode == 'sgdet':
        bboxes = result.boxes_all.view(-1, 4)[twod_inds].view(result.boxes_all.size(0), 4)
    else:
        # Boxes will get fixed by the filter_dets function.
        bboxes = result.rm_box_priors

    rel_rep = F.softmax(result.rel_dists, dim=1)  # [275, 51]
    # sort by the product of obj1 * obj2 * rel scores
    return filter_dets(bboxes, result.obj_scores, result.rm_obj_preds, rel_inds[:, 1:], rel_rep)
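# `result.pos` / `result.neg` / `result.anchor` above are laid out row-by-row for a
# triplet-style ranking objective. The actual loss lives elsewhere in the codebase;
# this is only a hedged sketch of one plausible way to consume them (the margin
# value 0.2 is assumed):
def ranking_loss_sketch(pos, neg, margin=0.2):
    # target = +1 asks the criterion to score each positive triplet above its paired negative
    target = Variable(torch.ones(pos.size(0)).cuda())
    return torch.nn.MarginRankingLoss(margin=margin)(pos, neg, target)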
def proposal_assignments_gtbox(rois, gt_boxes, gt_classes, gt_rels, image_offset, fg_thresh=0.5):
    """
    Assign object detection proposals to ground-truth targets. Produces proposal
    classification labels and bounding-box regression targets.
    :param rois: [img_ind, x1, y1, x2, y2]
    :param gt_boxes: [num_boxes, 4] array of [x0, y0, x1, y1]. Not needed, it seems
    :param gt_classes: [num_boxes, 2] array of [img_ind, class].
        Note, the img_inds here start at image_offset
    :param gt_rels: [num_rels, 4] array of [img_ind, box_0, box_1, rel type].
        Note, the img_inds here start at image_offset
    :param fg_thresh: overlap threshold for a RoI to be considered foreground (if >= fg_thresh)
    :return:
        rois: [num_rois, 5]
        labels: [num_rois] array of labels
        bbox_targets: [num_rois, 4] array of targets for the labels.
        rel_labels: [num_rels, 4] (img ind, box0 ind, box1 ind, rel type)
    """
    im_inds = rois[:, 0].long()
    num_im = im_inds[-1] + 1

    # Offset the image indices in fg_rels to refer to absolute indices (not just within img i)
    fg_rels = gt_rels.clone()
    fg_rels[:, 0] -= image_offset
    offset = {}
    for i, s, e in enumerate_by_image(im_inds):
        offset[i] = s
    for i, s, e in enumerate_by_image(fg_rels[:, 0]):
        fg_rels[s:e, 1:3] += offset[i]

    # ------------------------------------------------------------------------- #
    # Pad every image's rel set up to the length of the longest one with BG rels.
    fg_rel_list = []
    for i in range(num_im):
        fg_rel_list.append(sum(fg_rels[:, 0] == i).item())
    longest_len = max(fg_rel_list)
    bg_rel_length = [longest_len - i for i in fg_rel_list]
    # ------------------------------------------------------------------------- #

    # Try ALL things, not just intersections.
    is_cand = (im_inds[:, None] == im_inds[None])
    is_cand.view(-1)[diagonal_inds(is_cand)] = 0

    # # Compute salience
    # gt_inds = fg_rels[:, 1:3].contiguous().view(-1)
    # labels_arange = labels.data.new(labels.size(0))
    # torch.arange(0, labels.size(0), out=labels_arange)
    # salience_labels = ((gt_inds[:, None] == labels_arange[None]).long().sum(0) > 0).long()
    # labels = torch.stack((labels, salience_labels), 1)

    # Add in some BG labels.
    # NOW WE HAVE TO EXCLUDE THE FGs.
    # TODO: check if this causes an error if many duplicate GTs haven't been filtered out
    is_cand.view(-1)[fg_rels[:, 1] * im_inds.size(0) + fg_rels[:, 2]] = 0
    is_bgcand = is_cand.nonzero()

    # TODO: make this sample on a per-image case
    # If there are too many FG rels, subsample
    num_fg = min(fg_rels.size(0), int(RELS_PER_IMG * REL_FG_FRACTION * num_im))
    if num_fg < fg_rels.size(0):
        fg_rels = random_choose(fg_rels, num_fg)

    # If there are too many BG candidates, subsample
    num_bg = min(is_bgcand.size(0) if is_bgcand.dim() > 0 else 0, int(num_fg / 2))  # (unused below)
    bg_rels = torch.cat((
        im_inds[is_bgcand[:, 0]][:, None],
        is_bgcand,
        (is_bgcand[:, 0, None] < -10).long(),  # always 0: the BG rel type
    ), 1)

    rel_labels = fg_rels
    for i, j in enumerate(bg_rel_length):
        if bg_rels[bg_rels[:, 0] == i, :].shape[0] >= j:
            bg_rel_per_image = random_choose(bg_rels[bg_rels[:, 0] == i, :], j)
        else:
            bg_rel_per_image = torch.cat(
                (bg_rels[bg_rels[:, 0] == i, :],
                 random_choose(bg_rels[bg_rels[:, 0] == i, :],
                               j - bg_rels[bg_rels[:, 0] == i, :].shape[0])), 0)
        rel_labels = torch.cat((rel_labels, bg_rel_per_image), 0)

    # Last, sort by rel.
    _, perm = torch.sort(rel_labels[:, 0] * (gt_boxes.size(0) ** 2)
                         + rel_labels[:, 1] * gt_boxes.size(0)
                         + rel_labels[:, 2])
    rel_labels = rel_labels[perm].contiguous()

    labels = gt_classes[:, 1].contiguous()
    return rois, labels, rel_labels