import numpy as np
import torch

# cfg and the box utilities used below (box_unnorm_torch, grasp_decode,
# labels2points, box_recover_scale_torch, box_filter, bbox_transform_inv,
# clip_boxes) are assumed to be provided elsewhere in this repository.


def grasp_inference(cls_prob, box_output, im_info, box_prior=None, topN=False, recover_imscale=True):
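    """
    :param cls_prob: N x 2 tensor of per-anchor (background, grasp) scores
    :param box_output: N x 5 tensor of normalized grasp regression outputs
    :param im_info: size 4 tensor (image height, width and the rescaling factors)
    :param box_prior: N x 5 tensor of anchor priors (anchor-based decoding only)
    :param topN: if an integer, return the N highest-scoring grasps; if False,
        return all grasps scoring above 0.5
    :param recover_imscale: if True, map boxes back to the original image scale
    :return: M x 9 numpy array of grasps (8 vertex coordinates plus a score)
    """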
    assert box_output.dim() == 2, "Multi-instance batch inference has not been implemented."
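    # With top-N selection, keep every candidate (threshold 0) and truncate after
    # scoring; otherwise keep only confident grasps.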
    if not topN:
        thresh = 0.5
    else:
        thresh = 0

    # TODO: Inference for anchor free algorithms has not been implemented.
    if box_prior is None:
        raise NotImplementedError("Inference for anchor free algorithms has not been implemented.")

    scores = cls_prob
    normalizer = {'mean': cfg.FCGN.BBOX_NORMALIZE_MEANS, 'std': cfg.FCGN.BBOX_NORMALIZE_STDS}
    box_output = box_unnorm_torch(box_output, normalizer, d_box=5, class_agnostic=True, n_cls=None)

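    # Decode the anchor-relative regression outputs into (x, y, w, h, theta) grasp
    # labels, then convert each rotated rectangle into its four vertices (8 coords).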
    pred_label = grasp_decode(box_output, box_prior)
    pred_boxes = labels2points(pred_label)

    # Build a (width, height) template with the same shape as pred_boxes so every
    # (x, y) vertex can be checked against the image bounds in one comparison.
    imshape = np.tile(np.array([im_info[1], im_info[0]]),
                      pred_boxes.shape[:-2] + (int(pred_boxes.size(-2)), int(pred_boxes.size(-1)) // 2))
    imshape = torch.from_numpy(imshape).type_as(pred_boxes)
    # Keep only grasps whose vertices all lie inside the image.
    keep = (((pred_boxes > imshape) | (pred_boxes < 0)).sum(-1) == 0)
    pred_boxes = pred_boxes[keep]
    scores = scores[keep]

    scores = scores.squeeze()
    pred_boxes = pred_boxes.squeeze()
    if recover_imscale:
        pred_boxes = box_recover_scale_torch(pred_boxes, im_info[3], im_info[2])

    grasps, scores, _ = box_filter(pred_boxes, scores[:, 1], thresh, use_nms=False)
    grasps = np.concatenate((grasps, np.expand_dims(scores, -1)), axis=-1)
    if topN:
        grasps = grasps[:topN]
    return grasps
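
# Illustrative usage sketch (hypothetical names: `anchors` must come from the
# repo's grasp anchor generator, and cfg must already be loaded):
#
#     cls_prob = torch.rand(1600, 2)      # per-anchor (background, grasp) scores
#     box_output = torch.randn(1600, 5)   # normalized (dx, dy, dw, dh, dtheta)
#     im_info = torch.tensor([480., 640., 1., 1.])
#     grasps = grasp_inference(cls_prob, box_output, im_info,
#                              box_prior=anchors, topN=5)
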
def objgrasp_inference(o_cls_prob, o_box_output, g_cls_prob, g_box_output, im_info, rois=None,
                       class_agnostic=True, n_classes=None, g_box_prior=None, for_vis=False, topN_g=False,
                       recover_imscale=True):
    """
    :param o_cls_prob: N x N_cls tensor
    :param o_box_output: N x 4 tensor
    :param g_cls_prob: N x K*A x 2 tensor
    :param g_box_output: N x K*A x 5 tensor
    :param im_info: size 4 tensor
    :param rois: N x 4 tensor
    :param g_box_prior: N x K*A * 5 tensor
    :return:

    Note:
    1 This function simultaneously supports ROI-GN with or without object branch. If no object branch, o_cls_prob
    and o_box_output will be none, and object detection results are shown in the form of ROIs.
    2 This function can only detect one image per invoking.
    """
    if rois is None:
        raise RuntimeError("You must specify rois for ROI-GN.")

    o_scores = o_cls_prob
    # Strip the batch-index column, keeping the 4 box coordinates.
    rois = rois[:, 1:5]

    g_scores = g_cls_prob

    if for_vis:
        # Visualization: show only reasonably confident objects.
        o_thresh = 0.5
    else:
        # Evaluation: keep all objects and only the single best grasp per object.
        o_thresh = 0.
        topN_g = 1

    if not topN_g:
        g_thresh = 0.5
    else:
        g_thresh = 0.

    if g_box_prior is None:
        raise NotImplementedError("Inference for anchor free algorithms has not been implemented.")

    # infer grasp boxes
    normalizer = {'mean': cfg.FCGN.BBOX_NORMALIZE_MEANS, 'std': cfg.FCGN.BBOX_NORMALIZE_STDS}
    g_box_output = box_unnorm_torch(g_box_output, normalizer, d_box=5, class_agnostic=True, n_cls=None)
    g_box_output = g_box_output.view(g_box_prior.size())
    # N x K*A x 5
    grasp_pred = grasp_decode(g_box_output, g_box_prior)

    # N x K*A x 1
    rois_w = (rois[:, 2] - rois[:, 0]).view(-1).unsqueeze(1).unsqueeze(2).expand_as(grasp_pred[:, :, 0:1])
    rois_h = (rois[:, 3] - rois[:, 1]).view(-1).unsqueeze(1).unsqueeze(2).expand_as(grasp_pred[:, :, 1:2])
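    # Flag grasps whose centers lie inside their RoI (centers are still
    # RoI-relative at this point; the RoI offset is added below).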
    keep_mask = (grasp_pred[:, :, 0:1] > 0) & (grasp_pred[:, :, 1:2] > 0) & \
                (grasp_pred[:, :, 0:1] < rois_w) & (grasp_pred[:, :, 1:2] < rois_h)
    grasp_scores = g_scores.contiguous().view(rois.size(0), -1, 2)
    # N x 1 x 1
    xleft = rois[:, 0].view(-1).unsqueeze(1).unsqueeze(2)
    ytop = rois[:, 1].view(-1).unsqueeze(1).unsqueeze(2)
    # rois offset
    grasp_pred[:, :, 0:1] = grasp_pred[:, :, 0:1] + xleft
    grasp_pred[:, :, 1:2] = grasp_pred[:, :, 1:2] + ytop
    # N x K*A x 8
    grasp_pred_boxes = labels2points(grasp_pred).contiguous().view(rois.size(0), -1, 8)
    # N x K*A positive (grasp) scores; grasps whose centers fall outside their
    # RoI are zeroed out so the top-N ranking below can never select them.
    grasp_pos_scores = grasp_scores[:, :, 1] * keep_mask.squeeze(-1).type_as(grasp_scores)
    if topN_g:
        # N x K*A
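        # Double argsort: the first sort orders scores descending; sorting those
        # indices again yields each grasp's rank, so (grasp_idx_rank < topN_g)
        # masks the top-N grasps of every object.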
        _, grasp_score_idx = torch.sort(grasp_pos_scores, dim=-1, descending=True)
        _, grasp_idx_rank = torch.sort(grasp_score_idx, dim=-1)
        # N x K*A mask
        topn_grasp = topN_g
        grasp_maxscore_mask = (grasp_idx_rank < topn_grasp)
        # N x topN
        grasp_maxscores = grasp_pos_scores[grasp_maxscore_mask].contiguous().view(rois.size()[:1] + (topn_grasp,))
        # N x topN x 8
        grasp_pred_boxes = grasp_pred_boxes[grasp_maxscore_mask].view(rois.size()[:1] + (topn_grasp, 8))
    else:
        raise NotImplementedError("ROI-GN currently supports only top-N grasp detection for each object.")

    # infer object boxes
    if cfg.TRAIN.COMMON.BBOX_REG:
        box_output = o_box_output
        if cfg.TRAIN.COMMON.BBOX_NORMALIZE_TARGETS_PRECOMPUTED:
            normalizer = {'mean': cfg.TRAIN.COMMON.BBOX_NORMALIZE_MEANS, 'std': cfg.TRAIN.COMMON.BBOX_NORMALIZE_STDS}
            box_output = box_unnorm_torch(o_box_output, normalizer, 4, class_agnostic, n_classes)
        # Apply the regressed deltas to the RoIs and clip to the image bounds.
        pred_boxes = bbox_transform_inv(rois, box_output, 1)
        pred_boxes = clip_boxes(pred_boxes, im_info, 1)
    else:
        pred_boxes = rois.clone()

    if recover_imscale:
        pred_boxes = box_recover_scale_torch(pred_boxes, im_info[3], im_info[2])
        grasp_pred_boxes = box_recover_scale_torch(grasp_pred_boxes, im_info[3], im_info[2])

    # Per-class result lists; index 0 (background) is intentionally left empty.
    all_box = [[]]
    all_grasp = [[]]
    for j in range(1, n_classes):
        if class_agnostic or not cfg.TRAIN.COMMON.BBOX_REG:
            cls_boxes = pred_boxes
        else:
            cls_boxes = pred_boxes[:, j * 4:(j + 1) * 4]
        cls_dets, cls_scores, box_keep_inds = box_filter(cls_boxes, o_scores[:, j], o_thresh, use_nms=True)
        cls_dets = np.concatenate((cls_dets, np.expand_dims(cls_scores, -1)), axis=-1)
        grasps = (grasp_pred_boxes.cpu().numpy())[box_keep_inds]

        if for_vis:
            # For visualization, overwrite the score column with the class index.
            cls_dets[:, -1] = j
        else:
            # topN_g == 1 here, so drop the singleton grasp dimension.
            grasps = np.squeeze(grasps, axis=1)
        all_box.append(cls_dets)
        all_grasp.append(grasps)

    if for_vis:
        # Flatten the per-class lists (skipping background) into single arrays.
        all_box = np.concatenate(all_box[1:], axis=0)
        all_grasp = np.concatenate(all_grasp[1:], axis=0)

    return all_box, all_grasp
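
# Illustrative usage sketch (hypothetical names: `rois` come from the RPN head and
# `g_anchors` from the grasp anchor generator; cfg must already be loaded):
#
#     boxes, grasps = objgrasp_inference(o_cls_prob, o_box_output,
#                                        g_cls_prob, g_box_output, im_info,
#                                        rois=rois, class_agnostic=True,
#                                        n_classes=32, g_box_prior=g_anchors,
#                                        for_vis=True, topN_g=3)
#     # boxes:  M x 5 array (4 box coordinates + class index)
#     # grasps: M x 3 x 8 array of grasp rectangles (four vertices each)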