def _rel_forward_train(self, x, gt_bboxes, gt_labels, gt_rel, gt_instid,
                       rel_train_cfg, semantic_feat=None, im_width=None,
                       im_height=None, debug_image=None, debug_filename=None):
    """Compute the relation-head training loss on ground-truth boxes.

    Candidate subject/object pairs are built from the ground-truth boxes by
    ``sample_pairs``, matched to annotated relations by ``assign_pairs``, and
    scored by ``self.reldn_head``. The scoring uses concatenated
    subject / object / union RoI features (taken from the last mask RoI
    extractor) plus a spatial feature.

    NOTE(review): another ``_rel_forward_train`` with a different signature
    is defined later in this file; if both live in the same class, that later
    definition shadows this one.

    Args:
        x: multi-level feature maps; only the first
            ``len(mask_roi_extractor.featmap_strides)`` levels are used.
        gt_bboxes: ground-truth boxes (same number of rows as ``gt_labels``
            and ``gt_instid``, enforced by the assert below).
        gt_labels: ground-truth class labels.
        gt_rel: ground-truth relation annotations, consumed by
            ``assign_pairs``.
        gt_instid: ground-truth instance ids.
        rel_train_cfg: relation training config (not read in this method).
        semantic_feat: semantic feature map; added to the RoI features when
            ``self.with_semantic`` and ``'mask' in self.semantic_fusion``.
        im_width, im_height: image size forwarded to ``get_spatial_feature``.
        debug_image, debug_filename: unused; kept for interface
            compatibility.

    Returns:
        The value returned by ``self.reldn_head.loss``.
    """
    assert gt_bboxes.shape[0] == gt_labels.shape[0] == gt_instid.shape[0]
    combined_bboxes = gt_bboxes
    combined_labels = gt_labels
    sbj_bboxes, sbj_labels, sbj_idxs, obj_bboxes, obj_labels, obj_idxs = \
        sample_pairs(combined_bboxes, combined_labels)

    # Per-instance RoI features; the extractor is run without gradients.
    with torch.no_grad():
        mask_roi_extractor = self.mask_roi_extractor[-1]
        mask_rois = bbox2roi([combined_bboxes])
        mask_feats = mask_roi_extractor(
            x[:len(mask_roi_extractor.featmap_strides)], mask_rois)
        if self.with_semantic and 'mask' in self.semantic_fusion:
            mask_semantic_feat = self.semantic_roi_extractor(
                [semantic_feat], mask_rois)
            # Resize semantic RoI features to match the instance features.
            if mask_semantic_feat.shape[-2:] != mask_feats.shape[-2:]:
                mask_semantic_feat = F.adaptive_avg_pool2d(
                    mask_semantic_feat, mask_feats.shape[-2:])
            mask_feats += mask_semantic_feat

    relations, relation_targets = \
        assign_pairs(sbj_bboxes, sbj_labels, sbj_idxs,
                     obj_bboxes, obj_labels, obj_idxs,
                     gt_bboxes, gt_labels, gt_rel, gt_instid)

    num_relations = relations.shape[0]
    num_relation_cats = 31  # hard coded
    # Multi-hot target matrix: one row per candidate pair.
    targets = relations.new_zeros((num_relations, num_relation_cats),
                                  dtype=torch.long)
    for i in range(num_relations):
        targets[i, relation_targets[i]] = 1

    # relations[:, -2] / relations[:, -1] index the subject / object boxes.
    sbj_rows = relations[:, -2].long()
    obj_rows = relations[:, -1].long()

    # Union boxes of each subject/object pair.
    sbj_bboxes = combined_bboxes[sbj_rows, :]
    obj_bboxes = combined_bboxes[obj_rows, :]
    union_bboxes = bbox_union(sbj_bboxes, obj_bboxes)
    union_rois = bbox2roi([union_bboxes])

    # Visual features for subject, object and union regions.
    sbj_feats = mask_feats[sbj_rows, ...]
    obj_feats = mask_feats[obj_rows, ...]
    with torch.no_grad():
        union_feats = mask_roi_extractor(
            x[:len(mask_roi_extractor.featmap_strides)], union_rois)
        if self.with_semantic and 'mask' in self.semantic_fusion:
            union_semantic_feat = self.semantic_roi_extractor(
                [semantic_feat], union_rois)
            if union_semantic_feat.shape[-2:] != union_feats.shape[-2:]:
                union_semantic_feat = F.adaptive_avg_pool2d(
                    union_semantic_feat, union_feats.shape[-2:])
            union_feats += union_semantic_feat

    sbj_feats = sbj_feats.view(num_relations, -1)
    obj_feats = obj_feats.view(num_relations, -1)
    union_feats = union_feats.view(num_relations, -1)
    assert sbj_feats.shape == obj_feats.shape == union_feats.shape

    visual_features = torch.cat([sbj_feats, obj_feats, union_feats], dim=-1)
    spatial_features = get_spatial_feature(sbj_bboxes, obj_bboxes,
                                           im_width, im_height)
    # NOTE(review): sbj_labels / obj_labels come straight from sample_pairs,
    # whereas boxes and features above are re-gathered through `relations`.
    # This assumes assign_pairs preserves pair order — confirm.
    prd_vis_scores, prd_bias_scores, prd_spt_scores = self.reldn_head(
        visual_features, sbj_labels=sbj_labels, obj_labels=obj_labels,
        spt_feat=spatial_features, sbj_feat=sbj_feats, obj_feat=obj_feats)

    label_weights = prd_vis_scores.new_ones(len(prd_vis_scores),
                                            dtype=torch.long)
    loss = self.reldn_head.loss(prd_vis_scores, targets, label_weights,
                                prd_spt_score=prd_spt_scores)
    return loss
def _rel_forward_train_binary(self, x, gt_bboxes, gt_labels, gt_rel,
                              gt_instid, debug_image=None,
                              debug_filename=None):
    """Compute a binary "pair has a relation" loss on ground-truth boxes.

    Candidate pairs come from ``sample_pairs`` and ``assign_pairs``.
    ``relations[:, 2]`` is used as the relation label (0 = no relation), and
    ``relations[:, -2]`` / ``relations[:, -1]`` are used as subject / object
    box indices. All positive pairs are kept, together with at most three
    times as many randomly chosen negatives.

    Args:
        x: multi-level feature maps; only the first
            ``len(mask_roi_extractor.featmap_strides)`` levels are used.
        gt_bboxes: ground-truth boxes (same number of rows as ``gt_labels``
            and ``gt_instid``, enforced by the assert below).
        gt_labels: ground-truth class labels.
        gt_rel: ground-truth relation annotations, consumed by
            ``assign_pairs``.
        gt_instid: ground-truth instance ids.
        debug_image: if not None, one PNG per sampled pair is written under
            ``./tmp/<debug_filename without its last 4 characters>/``.
        debug_filename: image filename, used for the debug directory and in
            error logging.

    Returns:
        The value of ``self.reldn_binary_head.loss(pred, targets > 0)``.

    Raises:
        Re-raises any exception from ``sample_pairs`` or from reading
        ``relations[:, 2]``, after printing diagnostic context.
    """
    assert gt_bboxes.shape[0] == gt_labels.shape[0] == gt_instid.shape[0]
    try:
        sbj_bboxes, sbj_labels, sbj_idxs, obj_bboxes, obj_labels, obj_idxs \
            = sample_pairs(gt_bboxes, gt_labels)
    except Exception:
        # Log the offending sample, then propagate. Falling through (as the
        # old bare `except:` did) only fails later with an UnboundLocalError
        # on `sbj_bboxes` that hides the real cause.
        print(gt_bboxes, debug_filename)
        raise

    mask_roi_extractor = self.mask_roi_extractor[-1]
    with torch.no_grad():
        mask_rois = bbox2roi([gt_bboxes])
        mask_feats = mask_roi_extractor(
            x[:len(mask_roi_extractor.featmap_strides)], mask_rois)

    # NOTE(review): in _rel_forward_train, assign_pairs is unpacked into
    # (relations, relation_targets); here its result is used directly —
    # confirm which return signature applies.
    relations = assign_pairs(sbj_bboxes, sbj_labels, sbj_idxs,
                             obj_bboxes, obj_labels, obj_idxs,
                             gt_bboxes, gt_labels, gt_rel, gt_instid)
    try:
        targets = relations[:, 2].long()
    except Exception:
        print(debug_filename, relations, gt_bboxes, sbj_bboxes, obj_bboxes)
        raise

    # Keep every positive pair plus at most 3x as many random negatives.
    pos_indexs = (targets > 0).nonzero()
    neg_indexs = (targets == 0).nonzero()
    num_neg = len(neg_indexs)
    num_pos = len(pos_indexs)
    randperm = torch.randperm(num_neg)
    neg_indexs = neg_indexs[randperm[:min(num_neg, num_pos * 3)]]
    indexs = torch.cat((pos_indexs, neg_indexs), dim=0).squeeze(1)
    relations = relations[indexs, :]
    targets = relations[:, 2].long()

    # Union boxes and their RoI features for the sampled pairs.
    sbj_bboxes = gt_bboxes[relations[:, -2].long(), :4]
    obj_bboxes = gt_bboxes[relations[:, -1].long(), :4]
    union_bboxes = bbox_union(sbj_bboxes, obj_bboxes)
    union_rois = bbox2roi([union_bboxes])
    with torch.no_grad():
        union_feats = mask_roi_extractor(
            x[:len(mask_roi_extractor.featmap_strides)], union_rois)

    if debug_image is not None:
        import os
        import matplotlib.pyplot as plt
        debug_dir = os.path.join('./tmp/{}'.format(debug_filename[:-4]))
        os.makedirs(debug_dir, exist_ok=True)
        debug_num_relation = relations.shape[0]
        for i in range(debug_num_relation):
            sbj_label = relations[i, 0].to(dtype=torch.int)
            obj_label = relations[i, 1].to(dtype=torch.int)
            rel_label = relations[i, 2].to(dtype=torch.int)
            sbj_index = relations[i, 3].to(dtype=torch.int)
            obj_index = relations[i, 4].to(dtype=torch.int)
            sbj_bbox = gt_bboxes[sbj_index, :4]
            obj_bbox = gt_bboxes[obj_index, :4]
            union_bbox = union_bboxes[i, :4]

            # plot subject (red), object (green) and union (white) boxes
            fig, ax = plt.subplots(1, 1, figsize=(10, 10))
            ax.imshow(debug_image, cmap=plt.cm.gray)
            x1, y1, x2, y2 = sbj_bbox
            ax.add_artist(
                plt.Rectangle((x1, y1), x2 - x1 + 1, y2 - y1 + 1,
                              fill=False, color='r'))
            sbj_text = PicDatasetV20.CATEGORIES[sbj_label - 1]['name']
            ax.add_artist(
                plt.Text(x1, y1, sbj_text, size='x-large', color='r'))
            x1, y1, x2, y2 = obj_bbox
            ax.add_artist(
                plt.Rectangle((x1, y1), x2 - x1 + 1, y2 - y1 + 1,
                              fill=False, color='g'))
            obj_text = PicDatasetV20.CATEGORIES[obj_label - 1]['name']
            ax.add_artist(
                plt.Text(x1, y1, obj_text, size='x-large', color='g'))
            x1, y1, x2, y2 = union_bbox
            ax.add_artist(
                plt.Rectangle((x1, y1), x2 - x1 + 1, y2 - y1 + 1,
                              fill=False, color='w'))
            rel_text = PicDatasetV20.REL_CATEGORIES[rel_label - 1]['name'] \
                if rel_label > 0 else 'None'
            ax.add_artist(
                plt.Text(x1, y1, rel_text, size='x-large', color='w',
                         bbox=dict(facecolor='k', alpha=0.5)))
            title = '<{}, {}, {}>'.format(sbj_text, obj_text, rel_text)
            ax.set_title(title)
            ax.axis('off')
            savename = os.path.join(debug_dir, '{:05d}.png'.format(i))
            plt.savefig(savename, dpi=100)
            plt.close()

    sbj_feats = mask_feats[relations[:, -2].long(), ...]
    obj_feats = mask_feats[relations[:, -1].long(), ...]
    num_relations = relations.shape[0]
    sbj_feats = sbj_feats.view(num_relations, -1)
    obj_feats = obj_feats.view(num_relations, -1)
    union_feats = union_feats.view(num_relations, -1)
    pred = self.reldn_binary_head(sbj_feats, obj_feats, union_feats)
    loss = self.reldn_binary_head.loss(pred, targets > 0)
    return loss
def _rel_forward_train(self, x, gt_bboxes, gt_labels, gt_rel, gt_instid,
                       rel_train_cfg, im_width=None, im_height=None,
                       debug_image=None, debug_filename=None):
    """Compute the relation-head training loss from ground-truth boxes.

    Pairs are produced by ``sample_pairs`` and labelled by ``assign_pairs``.
    Each pair is described by subject, object and union RoI features taken
    from the last *bbox* RoI extractor, plus a spatial feature. The pairs are
    then scored by ``self.reldn_head``.

    Args:
        x: multi-level feature maps; only the first
            ``len(bbox_roi_extractor.featmap_strides)`` levels are used.
        gt_bboxes: ground-truth boxes, n rows.
        gt_labels: ground-truth labels, n rows.
        gt_rel: ground-truth relations (previously documented as k x 3),
            consumed by ``assign_pairs``.
        gt_instid: ground-truth instance ids, n rows (asserted below).
        rel_train_cfg: relation training config (not read in this method).
        im_width, im_height: image size forwarded to ``get_spatial_feature``.
        debug_image, debug_filename: unused.

    Returns:
        The value returned by ``self.reldn_head.loss``.
    """
    assert gt_bboxes.shape[0] == gt_labels.shape[0] == gt_instid.shape[0]
    sbj_bboxes, sbj_labels, sbj_idxs, obj_bboxes, obj_labels, obj_idxs = \
        sample_pairs(gt_bboxes, gt_labels)

    # Per-box RoI features from the bbox branch, computed without gradients.
    with torch.no_grad():
        extractor = self.bbox_roi_extractor[-1]
        levels = x[:len(extractor.featmap_strides)]
        box_feats = extractor(levels, bbox2roi([gt_bboxes]))

    # Label candidate pairs with a relation class (no-relation included).
    relations, relation_targets = assign_pairs(
        sbj_bboxes, sbj_labels, sbj_idxs, obj_bboxes, obj_labels, obj_idxs,
        gt_bboxes, gt_labels, gt_rel, gt_instid)

    num_relations = relations.shape[0]
    num_relation_cats = 11  # hard coded
    # Multi-hot targets; no dtype is given, so it follows `relations`.
    targets = relations.new_zeros((num_relations, num_relation_cats))
    for row in range(num_relations):
        targets[row, relation_targets[row]] = 1

    # The last two columns of `relations` index subject / object boxes.
    sbj_rows = relations[:, -2].long()
    obj_rows = relations[:, -1].long()
    sbj_bboxes = gt_bboxes[sbj_rows, :]
    obj_bboxes = gt_bboxes[obj_rows, :]
    union_rois = bbox2roi([bbox_union(sbj_bboxes, obj_bboxes)])

    # NOTE(review): only the subject features go through `self.sa`; the
    # object features do not — confirm this asymmetry is intended.
    sbj_feats = self.sa(box_feats[sbj_rows, ...])
    obj_feats = box_feats[obj_rows, ...]
    with torch.no_grad():
        union_feats = extractor(levels, union_rois)

    sbj_feats = sbj_feats.view(num_relations, -1)
    obj_feats = obj_feats.view(num_relations, -1)
    union_feats = union_feats.view(num_relations, -1)
    assert sbj_feats.shape == obj_feats.shape == union_feats.shape

    visual_features = torch.cat([sbj_feats, obj_feats, union_feats], dim=-1)
    spatial_features = get_spatial_feature(sbj_bboxes, obj_bboxes,
                                           im_width, im_height)
    # NOTE(review): sbj_labels / obj_labels are the sample_pairs outputs, not
    # re-gathered through `relations` like the boxes — confirm ordering.
    prd_vis_scores, _prd_bias_scores, prd_spt_scores = self.reldn_head(
        visual_features, sbj_labels=sbj_labels, obj_labels=obj_labels,
        spt_feat=spatial_features, sbj_feat=sbj_feats, obj_feat=obj_feats)

    weights = prd_vis_scores.new_ones(len(prd_vis_scores), dtype=torch.long)
    return self.reldn_head.loss(prd_vis_scores, targets, weights,
                                prd_spt_score=prd_spt_scores)
def _rel_forward_test(self, x, det_bboxes, det_labels, det_masks,
                      scale_factor, ori_shape, im_width=None, im_height=None,
                      semantic_feat=None):
    """Score relations between detected instances (mask-feature variant).

    NOTE(review): another ``_rel_forward_test`` with a different signature
    is defined later in this file; if both live in the same class, the later
    one shadows this definition.

    Args:
        x: multi-level feature maps; only the first
            ``len(mask_roi_extractor.featmap_strides)`` levels are used.
        det_bboxes: detected boxes (same number of rows as the labels/masks).
        det_labels: detected labels; ``det_labels + 1`` is passed to
            ``sample_pairs``.
        det_masks: mask predictions; a tensor is passed through ``sigmoid``
            and converted to numpy.
        scale_factor, ori_shape: copied into the result dict.
        im_width, im_height: image size forwarded to ``get_spatial_feature``.
        semantic_feat: semantic feature map; added to the RoI features when
            ``self.with_semantic`` and ``'mask' in self.semantic_fusion``.

    Returns:
        None when ``sample_pairs`` returns no subject boxes. Otherwise a dict
        with keys ``'ret_bbox'``, ``'ret_mask'``, ``'ret_label'``,
        ``'ret_relation'``, ``'scale_factor'`` and ``'ori_shape'``.

        Each ``'ret_relation'`` entry is
        ``[sbj_idx, obj_idx, rel_ids, rel_scores]``, where ``rel_ids`` is an
        index array:

        - in baseline mode, every id whose frequency-prior score exceeds
          ``rel_test_cfg.thresh``;
        - otherwise, the 5 highest-scoring ids.
    """
    assert det_labels.shape[0] == det_bboxes.shape[0] == det_masks.shape[0]
    if isinstance(det_masks, torch.Tensor):
        det_masks = det_masks.sigmoid().cpu().numpy()
    rel_test_cfg = self.test_cfg.rel
    run_baseline = rel_test_cfg.run_baseline

    # build pairs
    sbj_bboxes, sbj_labels, sbj_idxs, obj_bboxes, obj_labels, obj_idxs = \
        sample_pairs(det_bboxes, det_labels + 1)
    if sbj_bboxes is None:
        return None
    num_pairs = len(sbj_idxs)

    # Per-instance RoI features, computed without gradients.
    with torch.no_grad():
        mask_roi_extractor = self.mask_roi_extractor[-1]
        mask_rois = bbox2roi([det_bboxes])
        mask_feats = mask_roi_extractor(
            x[:len(mask_roi_extractor.featmap_strides)], mask_rois)
        if self.with_semantic and 'mask' in self.semantic_fusion:
            mask_semantic_feat = self.semantic_roi_extractor(
                [semantic_feat], mask_rois)
            if mask_semantic_feat.shape[-2:] != mask_feats.shape[-2:]:
                mask_semantic_feat = F.adaptive_avg_pool2d(
                    mask_semantic_feat, mask_feats.shape[-2:])
            mask_feats += mask_semantic_feat

    sbj_feats = mask_feats[sbj_idxs, ...]
    obj_feats = mask_feats[obj_idxs, ...]

    # union region features
    union_bboxes = bbox_union(sbj_bboxes, obj_bboxes)
    union_rois = bbox2roi([union_bboxes])
    with torch.no_grad():
        union_feats = mask_roi_extractor(
            x[:len(mask_roi_extractor.featmap_strides)], union_rois)
        if self.with_semantic and 'mask' in self.semantic_fusion:
            union_semantic_feat = self.semantic_roi_extractor(
                [semantic_feat], union_rois)
            if union_semantic_feat.shape[-2:] != union_feats.shape[-2:]:
                union_semantic_feat = F.adaptive_avg_pool2d(
                    union_semantic_feat, union_feats.shape[-2:])
            union_feats += union_semantic_feat

    sbj_feats = sbj_feats.view(num_pairs, -1)
    obj_feats = obj_feats.view(num_pairs, -1)
    union_feats = union_feats.view(num_pairs, -1)
    assert sbj_feats.shape == obj_feats.shape == union_feats.shape

    visual_features = torch.cat([sbj_feats, obj_feats, union_feats], dim=-1)
    spatial_features = get_spatial_feature(sbj_bboxes, obj_bboxes,
                                           im_width, im_height)
    prd_vis_scores, prd_bias_scores, prd_spt_scores = self.reldn_head(
        visual_features, sbj_labels=sbj_labels, obj_labels=obj_labels,
        spt_feat=spatial_features, sbj_feat=sbj_feats, obj_feat=obj_feats,
        run_baseline=run_baseline)

    # detect relations
    ret_bboxes = det_bboxes.cpu().numpy()
    ret_labels = det_labels.cpu().numpy()
    ret_masks = det_masks
    ret_relations = []

    # Convert once so both branches emit numpy indices rather than (possibly
    # CUDA) tensors; previously only the baseline branch converted them.
    sbj_idxs = sbj_idxs.cpu().numpy()
    obj_idxs = obj_idxs.cpu().numpy()
    prd_bias_scores = prd_bias_scores.cpu().numpy()
    # Never predict __no_relation__ (column 0).
    prd_bias_scores[:, 0] = 0

    if run_baseline:
        thresh = rel_test_cfg.thresh
        for i in range(prd_bias_scores.shape[0]):
            rel_scores = prd_bias_scores[i, :]
            # One entry per pair holding the whole array of ids above the
            # threshold (what iterating the `np.where` tuple produced).
            rel_ids = np.flatnonzero(rel_scores > thresh)
            ret_relations.append(
                [sbj_idxs[i], obj_idxs[i], rel_ids, rel_scores[rel_ids]])
    else:
        prd_vis_scores = prd_vis_scores.cpu().numpy()
        prd_spt_scores = prd_spt_scores.cpu().numpy()
        prd_score = (prd_vis_scores + prd_spt_scores) * prd_bias_scores
        for i in range(prd_score.shape[0]):
            rel_scores = prd_score[i, :]
            rel_ids = rel_scores.argsort()[-5:]  # top-5 relation ids
            ret_relations.append(
                [sbj_idxs[i], obj_idxs[i], rel_ids, rel_scores[rel_ids]])

    ret = {
        'ret_bbox': ret_bboxes,
        'ret_mask': ret_masks,
        'ret_label': ret_labels,
        'ret_relation': ret_relations,
        'scale_factor': scale_factor,
        'ori_shape': ori_shape
    }
    return ret
def _rel_forward_test(self, x, det_bboxes, det_labels, scale_factor,
                      ori_shape, filename=None, im_width=None,
                      im_height=None):
    """Predict subject-object interactions for one image's detections.

    Args:
        x: multi-level feature maps; only the first
            ``len(bbox_roi_extractor.featmap_strides)`` levels are used.
        det_bboxes: detected boxes; ``[:, :4]`` is the box and ``[:, -1]``
            is used as the detection score.
        det_labels: detected labels; ``+ 1`` is applied for pair sampling
            and for the emitted ``category_id``.
        scale_factor, ori_shape, filename: unused here.
        im_width, im_height: image size forwarded to ``get_spatial_feature``.

    Returns:
        dict with two lists:

        - ``'predictions'``: one entry per detection that takes part in a
          candidate pair, listed once even when it appears both as subject
          and as object. Entries contain ``'bbox'``, ``'category_id'`` (str)
          and ``'score'``.
        - ``'hoi_prediction'``: entries whose ``'subject_id'`` and
          ``'object_id'`` index into ``'predictions'``.

        Both lists are empty when no pairs are sampled.
    """
    assert det_labels.shape[0] == det_bboxes.shape[0]
    rel_test_cfg = self.test_cfg.rel
    run_baseline = rel_test_cfg.run_baseline

    # build pairs
    sbj_bboxes, sbj_labels, sbj_idxs, obj_bboxes, obj_labels, obj_idxs = \
        sample_pairs(det_bboxes, det_labels + 1,
                     overlap=True, overlap_th=0.4, test=True)
    ret = {'predictions': [], 'hoi_prediction': []}
    if sbj_bboxes is None or sbj_bboxes.shape[0] == 0:
        return ret
    num_pairs = len(sbj_idxs)

    # RoI features for instances and pair unions, without gradients.
    with torch.no_grad():
        bbox_roi_extractor = self.bbox_roi_extractor[-1]
        bbox_rois = bbox2roi([det_bboxes])
        bbox_feats = bbox_roi_extractor(
            x[:len(bbox_roi_extractor.featmap_strides)], bbox_rois)
    sbj_feats = bbox_feats[sbj_idxs, ...]
    obj_feats = bbox_feats[obj_idxs, ...]
    union_bboxes = bbox_union(sbj_bboxes, obj_bboxes)
    union_rois = bbox2roi([union_bboxes])
    with torch.no_grad():
        union_feats = bbox_roi_extractor(
            x[:len(bbox_roi_extractor.featmap_strides)], union_rois)

    sbj_feats = sbj_feats.view(num_pairs, -1)
    obj_feats = obj_feats.view(num_pairs, -1)
    union_feats = union_feats.view(num_pairs, -1)
    assert sbj_feats.shape == obj_feats.shape == union_feats.shape

    visual_features = torch.cat([sbj_feats, obj_feats, union_feats], dim=-1)
    spatial_features = get_spatial_feature(sbj_bboxes, obj_bboxes,
                                           im_width, im_height)
    prd_vis_scores, prd_bias_scores, prd_spt_scores = self.reldn_head(
        visual_features, sbj_labels=sbj_labels, obj_labels=obj_labels,
        spt_feat=spatial_features, sbj_feat=sbj_feats, obj_feat=obj_feats,
        run_baseline=run_baseline)

    # detect relations
    ret_bboxes = det_bboxes.cpu().numpy()
    ret_labels = det_labels.cpu().numpy()
    thresh = rel_test_cfg.thresh
    sbj_idxs = sbj_idxs.cpu().numpy()
    obj_idxs = obj_idxs.cpu().numpy()

    # Emit each participating detection exactly once: subjects first, then
    # objects not already emitted. Plain list concatenation used to
    # duplicate detections that occur in both roles and overwrote their
    # output index.
    unique_entity_idxs = dict.fromkeys(
        list(np.unique(sbj_idxs)) + list(np.unique(obj_idxs)))
    entity_index_to_output_index = {}
    for output_index, entity_index in enumerate(unique_entity_idxs):
        entity_index_to_output_index[entity_index] = output_index
        ret['predictions'].append({
            'bbox': ret_bboxes[entity_index, :4].tolist(),
            'category_id': str(ret_labels[entity_index] + 1),
            'score': float(ret_bboxes[entity_index, -1])
        })

    def _append_hoi(sbj_id, obj_id, rel_ids, rel_scores):
        """Append one 'hoi_prediction' entry per relation id for a pair."""
        sbj_output_index = entity_index_to_output_index[sbj_id]
        obj_output_index = entity_index_to_output_index[obj_id]
        for rel_id in rel_ids:
            ret['hoi_prediction'].append({
                'subject_id': int(sbj_output_index),
                'object_id': int(obj_output_index),
                'category_id': int(rel_id),
                'score': float(rel_scores[rel_id])
            })

    prd_bias_scores = prd_bias_scores.cpu().numpy()
    # Never predict relation id 0 (__no_relation__).
    prd_bias_scores[:, 0] = 0

    if run_baseline:
        for i in range(prd_bias_scores.shape[0]):
            sbj_id = sbj_idxs[i]
            obj_id = obj_idxs[i]
            rel_scores = prd_bias_scores[i, :]
            obj_cat = ret_labels[obj_id] + 1
            # Hard-coded object-category -> relation overrides (original
            # comments: "horse" -> "ride").
            if obj_cat in (6, 7, 8):
                rel_ids = [6]
            elif obj_cat == 9:
                rel_ids = [8]
            else:
                rel_ids = np.where(rel_scores > thresh)[0]
            _append_hoi(sbj_id, obj_id, rel_ids, rel_scores)
    else:
        prd_vis_scores = prd_vis_scores.cpu().numpy()
        prd_spt_scores = prd_spt_scores.cpu().numpy()
        prd_score = (prd_vis_scores + prd_spt_scores) * prd_bias_scores
        for i in range(prd_score.shape[0]):
            sbj_id = sbj_idxs[i]
            obj_id = obj_idxs[i]
            # Skip pairs where either detection is low-confidence; scores are
            # read from the host copy to avoid per-pair device access.
            sbj_score = float(ret_bboxes[sbj_id, -1])
            obj_score = float(ret_bboxes[obj_id, -1])
            if sbj_score < 0.4 or obj_score < 0.4:
                continue
            rel_scores = prd_score[i, :]
            rel_ids = np.where(rel_scores > thresh)[0]
            _append_hoi(sbj_id, obj_id, rel_ids, rel_scores)
    return ret