import numpy as np
import torch

# NOTE: cfg and the box utilities used below (box_unnorm_torch, grasp_decode,
# labels2points, box_recover_scale_torch, box_filter, bbox_transform_inv,
# clip_boxes) are assumed to be imported from the project's utility modules.


def grasp_inference(cls_prob, box_output, im_info, box_prior=None, topN=False,
                    recover_imscale=True):
    assert box_output.dim() == 2, "Multi-instance batch inference has not been implemented."

    # When keeping only the top-N grasps, do not threshold on the score.
    if not topN:
        thresh = 0.5
    else:
        thresh = 0

    # TODO: Inference for anchor-free algorithms has not been implemented.
    if box_prior is None:
        raise NotImplementedError("Inference for anchor-free algorithms has not been implemented.")

    scores = cls_prob

    # Un-normalize the regression deltas and decode them against the anchor priors.
    normalizer = {'mean': cfg.FCGN.BBOX_NORMALIZE_MEANS, 'std': cfg.FCGN.BBOX_NORMALIZE_STDS}
    box_output = box_unnorm_torch(box_output, normalizer, d_box=5,
                                  class_agnostic=True, n_cls=None)
    pred_label = grasp_decode(box_output, box_prior)
    # Convert (x, y, w, h, theta) grasp labels to the 8-value corner-point form.
    pred_boxes = labels2points(pred_label)

    # Discard grasps with any corner outside the image bounds.
    imshape = np.tile(np.array([im_info[1], im_info[0]]),
                      pred_boxes.shape[:-2] + (int(pred_boxes.size(-2)), int(pred_boxes.size(-1) / 2)))
    imshape = torch.from_numpy(imshape).type_as(pred_boxes)
    keep = (((pred_boxes > imshape) | (pred_boxes < 0)).sum(-1) == 0)
    pred_boxes = pred_boxes[keep]
    scores = scores[keep]

    scores = scores.squeeze()
    pred_boxes = pred_boxes.squeeze()

    if recover_imscale:
        # Map boxes back from the network input scale to the original image scale.
        pred_boxes = box_recover_scale_torch(pred_boxes, im_info[3], im_info[2])

    grasps, scores, _ = box_filter(pred_boxes, scores[:, 1], thresh, use_nms=False)
    grasps = np.concatenate((grasps, np.expand_dims(scores, -1)), axis=-1)
    if topN:
        grasps = grasps[:topN]
    return grasps
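# A minimal usage sketch (hypothetical: `fcgn_head` and `anchors` stand in for
# the single-image FCGN forward pass and its anchor priors; shapes follow the
# conventions above):
#
#   cls_prob, box_output = fcgn_head(features)   # (K*A, 2) scores, (K*A, 5) deltas
#   grasps = grasp_inference(cls_prob, box_output, im_info,
#                            box_prior=anchors, topN=5)
#   # grasps: numpy array of shape (<=5, 9) -- 8 corner coordinates plus a score;
#   # the [:topN] slice assumes box_filter returns score-ranked detections.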
def objgrasp_inference(o_cls_prob, o_box_output, g_cls_prob, g_box_output, im_info,
                       rois=None, class_agnostic=True, n_classes=None,
                       g_box_prior=None, for_vis=False, topN_g=False,
                       recover_imscale=True):
    """
    :param o_cls_prob: N x N_cls tensor
    :param o_box_output: N x 4 tensor
    :param g_cls_prob: N x K*A x 2 tensor
    :param g_box_output: N x K*A x 5 tensor
    :param im_info: size-4 tensor
    :param rois: N x 5 tensor (batch index followed by 4 box coordinates)
    :param g_box_prior: N x K*A x 5 tensor
    :return: per-class lists (or, if for_vis, arrays) of object boxes and grasps

    Note:
        1. This function supports ROI-GN both with and without the object branch.
           Without the object branch, o_cls_prob and o_box_output are None, and
           the object detection results are given directly by the ROIs.
        2. This function can only detect one image per invocation.
    """
    if rois is None:
        raise RuntimeError("You must specify rois for ROI-GN.")
    if g_box_prior is None:
        raise NotImplementedError("Inference for anchor-free algorithms has not been implemented.")

    o_scores = o_cls_prob
    rois = rois[:, 1:5]
    g_scores = g_cls_prob

    if for_vis:
        o_thresh = 0.5
    else:
        o_thresh = 0.
        # For quantitative evaluation, keep exactly the best grasp per object.
        topN_g = 1
    if not topN_g:
        g_thresh = 0.5
    else:
        g_thresh = 0.

    # infer grasp boxes
    normalizer = {'mean': cfg.FCGN.BBOX_NORMALIZE_MEANS, 'std': cfg.FCGN.BBOX_NORMALIZE_STDS}
    g_box_output = box_unnorm_torch(g_box_output, normalizer, d_box=5,
                                    class_agnostic=True, n_cls=None)
    # N x K*A x 5
    g_box_output = g_box_output.view(g_box_prior.size())
    grasp_pred = grasp_decode(g_box_output, g_box_prior)

    # N x K*A x 1
    rois_w = (rois[:, 2] - rois[:, 0]).view(-1).unsqueeze(1).unsqueeze(2).expand_as(grasp_pred[:, :, 0:1])
    rois_h = (rois[:, 3] - rois[:, 1]).view(-1).unsqueeze(1).unsqueeze(2).expand_as(grasp_pred[:, :, 1:2])
    keep_mask = (grasp_pred[:, :, 0:1] > 0) & (grasp_pred[:, :, 1:2] > 0) & \
                (grasp_pred[:, :, 0:1] < rois_w) & (grasp_pred[:, :, 1:2] < rois_h)
    grasp_scores = g_scores.contiguous().view(rois.size(0), -1, 2)

    # N x 1 x 1
    xleft = rois[:, 0].view(-1).unsqueeze(1).unsqueeze(2)
    ytop = rois[:, 1].view(-1).unsqueeze(1).unsqueeze(2)
    # rois offset: shift grasp centers from ROI-local to image coordinates
    grasp_pred[:, :, 0:1] = grasp_pred[:, :, 0:1] + xleft
    grasp_pred[:, :, 1:2] = grasp_pred[:, :, 1:2] + ytop
    # N x K*A x 8
    grasp_pred_boxes = labels2points(grasp_pred).contiguous().view(rois.size(0), -1, 8)
    # N x K*A
    grasp_pos_scores = grasp_scores[:, :, 1]

    if topN_g:
        # Rank grasps within each ROI by their positive score.
        # N x K*A
        _, grasp_score_idx = torch.sort(grasp_pos_scores, dim=-1, descending=True)
        _, grasp_idx_rank = torch.sort(grasp_score_idx, dim=-1)
        # N x K*A mask
        topn_grasp = topN_g
        grasp_maxscore_mask = (grasp_idx_rank < topn_grasp)
        # N x topN
        grasp_maxscores = grasp_pos_scores[grasp_maxscore_mask].contiguous().view(rois.size()[:1] + (topn_grasp,))
        # N x topN x 8
        grasp_pred_boxes = grasp_pred_boxes[grasp_maxscore_mask].view(rois.size()[:1] + (topn_grasp, 8))
    else:
        raise NotImplementedError("Currently ROI-GN only supports top-N grasp detection for each object.")

    # infer object boxes
    if cfg.TRAIN.COMMON.BBOX_REG:
        if cfg.TRAIN.COMMON.BBOX_NORMALIZE_TARGETS_PRECOMPUTED:
            normalizer = {'mean': cfg.TRAIN.COMMON.BBOX_NORMALIZE_MEANS,
                          'std': cfg.TRAIN.COMMON.BBOX_NORMALIZE_STDS}
            box_output = box_unnorm_torch(o_box_output, normalizer, 4, class_agnostic, n_classes)
        else:
            # Regression targets were not normalized during training.
            box_output = o_box_output
        pred_boxes = bbox_transform_inv(rois, box_output, 1)
        pred_boxes = clip_boxes(pred_boxes, im_info, 1)
    else:
        pred_boxes = rois.clone()

    if recover_imscale:
        pred_boxes = box_recover_scale_torch(pred_boxes, im_info[3], im_info[2])
        grasp_pred_boxes = box_recover_scale_torch(grasp_pred_boxes, im_info[3], im_info[2])

    # Index 0 is a placeholder for the background class.
    all_box = [[]]
    all_grasp = [[]]
    for j in range(1, n_classes):
        if class_agnostic or not cfg.TRAIN.COMMON.BBOX_REG:
            cls_boxes = pred_boxes
        else:
            cls_boxes = pred_boxes[:, j * 4:(j + 1) * 4]
        cls_dets, cls_scores, box_keep_inds = box_filter(cls_boxes, o_scores[:, j], o_thresh, use_nms=True)
        cls_dets = np.concatenate((cls_dets, np.expand_dims(cls_scores, -1)), axis=-1)
        grasps = (grasp_pred_boxes.cpu().numpy())[box_keep_inds]
        if for_vis:
            # Replace the score column with the class label for plotting.
            cls_dets[:, -1] = j
        else:
            grasps = np.squeeze(grasps, axis=1)
        all_box.append(cls_dets)
        all_grasp.append(grasps)

    if for_vis:
        all_box = np.concatenate(all_box[1:], axis=0)
        all_grasp = np.concatenate(all_grasp[1:], axis=0)

    return all_box, all_grasp
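# A minimal usage sketch (hypothetical: `roign_head` and `anchors` stand in for
# the actual ROI-GN forward pass and grasp anchor priors; n_classes is arbitrary):
#
#   o_cls, o_box, g_cls, g_box, rois = roign_head(image)
#   boxes, grasps = objgrasp_inference(o_cls, o_box, g_cls, g_box, im_info,
#                                      rois=rois, n_classes=32,
#                                      g_box_prior=anchors, for_vis=True,
#                                      topN_g=3)
#   # boxes:  (M, 5) array -- 4 box coordinates plus the class label
#   # grasps: (M, 3, 8) array -- top-3 grasp rectangles per detected object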