예제 #1
0
def ransac_voting_layer(mask,
                        vertex,
                        round_hyp_num,
                        inlier_thresh=0.999,
                        confidence=0.99,
                        max_iter=20,
                        min_num=5,
                        max_num=30000):
    '''Estimate 2D keypoints by RANSAC voting over per-pixel direction fields.

    For every sample, keypoint hypotheses are generated from random pairs of
    foreground pixels (``ransac_voting.generate_hypothesis``), scored by how many
    pixels vote for them (``ransac_voting.voting_for_hypothesis``), and the best
    hypothesis per keypoint is refined as the least-squares intersection of the
    direction lines of its inlier pixels.

    :param mask:          [b,h,w] foreground mask; a pixel is foreground when its
                          uint8 cast is non-zero. The tensor is not modified.
    :param vertex:        [b,h,w,vn,2] per-pixel direction vectors for vn keypoints.
    :param round_hyp_num: number of hypotheses sampled per RANSAC round.
    :param inlier_thresh: threshold forwarded to ``voting_for_hypothesis``.
    :param confidence:    stop once 1 - (1 - r**2)**hyp_num exceeds this value,
                          where r is the smallest per-keypoint best inlier ratio.
    :param max_iter:      maximum number of RANSAC rounds.
    :param min_num:       samples with fewer foreground pixels get all-zero output.
    :param max_num:       larger foregrounds are randomly subsampled to roughly
                          this many pixels.
    :return: [b,vn,2] keypoint coordinates in (x, y) pixel order.
    '''
    b, h, w, vn, _ = vertex.shape
    device = mask.device
    batch_win_pts = []
    for bi in range(b):
        # Fresh bool mask: ``mask[bi].byte()`` alone aliases the caller's storage
        # when ``mask`` is already uint8, so subsampling used to modify the input.
        # Casting to uint8 first keeps the original foreground semantics.
        cur_mask = mask[bi].byte() != 0
        foreground_num = torch.sum(cur_mask)

        # if too few points, just skip it
        if foreground_num < min_num:
            batch_win_pts.append(
                torch.zeros([1, vn, 2], dtype=torch.float32, device=device))
            continue

        # if too many points, randomly down-sample them (out of place)
        if foreground_num > max_num:
            selection = torch.zeros(cur_mask.shape,
                                    dtype=torch.float32,
                                    device=device).uniform_(0, 1)
            cur_mask = cur_mask & (selection <
                                   (max_num / foreground_num.float()))

        # nonzero() yields (row, col); swap to (x, y)
        coords = torch.nonzero(cur_mask).float()[:, [1, 0]]  # [tn,2]
        tn = coords.shape[0]
        direct = vertex[bi].masked_select(cur_mask[:, :, None, None])
        direct = direct.view([tn, vn, 2])  # [tn,vn,2]

        all_win_ratio = torch.zeros([vn], dtype=torch.float32, device=device)
        all_win_pts = torch.zeros([vn, 2], dtype=torch.float32, device=device)
        vn_range = torch.arange(vn, device=device)

        hyp_num = 0
        cur_iter = 0
        while True:
            # Draw new pixel pairs every round; sampling once before the loop
            # made every round re-score exactly the same hypotheses.
            idxs = torch.zeros([round_hyp_num, vn, 2],
                               dtype=torch.int32,
                               device=device).random_(0, tn)

            # generate hypothesis
            cur_hyp_pts = ransac_voting.generate_hypothesis(
                direct, coords, idxs)  # [hn,vn,2]

            # voting for hypothesis
            cur_inlier = torch.zeros([round_hyp_num, vn, tn],
                                     dtype=torch.uint8,
                                     device=device)
            ransac_voting.voting_for_hypothesis(direct, coords, cur_hyp_pts,
                                                cur_inlier,
                                                inlier_thresh)  # [hn,vn,tn]

            # best hypothesis of this round per keypoint
            cur_inlier_counts = torch.sum(cur_inlier, 2)  # [hn,vn]
            cur_win_counts, cur_win_idx = torch.max(cur_inlier_counts, 0)  # [vn]
            cur_win_pts = cur_hyp_pts[cur_win_idx, vn_range]  # [vn,2]
            cur_win_ratio = cur_win_counts.float() / tn

            # keep the best hypothesis seen so far per keypoint
            larger_mask = all_win_ratio < cur_win_ratio
            all_win_pts[larger_mask, :] = cur_win_pts[larger_mask, :]
            all_win_ratio[larger_mask] = cur_win_ratio[larger_mask]

            # stop on confidence, or after exactly max_iter rounds
            hyp_num += round_hyp_num
            cur_iter += 1
            cur_min_ratio = torch.min(all_win_ratio)
            if ((1 - (1 - cur_min_ratio**2)**hyp_num) > confidence
                    or cur_iter >= max_iter):
                break

        # Refine: each inlier pixel p with direction d defines a line
        # n . x = n . p where n is d rotated by 90 degrees; solve the normal
        # equations (A^T A) x = A^T b per keypoint.
        normal = torch.stack([direct[:, :, 1], -direct[:, :, 0]], 2)  # [tn,vn,2]
        all_inlier = torch.zeros([1, vn, tn], dtype=torch.uint8, device=device)
        all_win_pts = torch.unsqueeze(all_win_pts, 0)  # [1,vn,2]
        ransac_voting.voting_for_hypothesis(direct, coords, all_win_pts,
                                            all_inlier,
                                            inlier_thresh)  # [1,vn,tn]

        all_inlier = torch.squeeze(all_inlier.float(), 0)  # [vn,tn]
        # outlier rows become all-zero and drop out of the least squares
        normal = normal.permute(1, 0, 2) * all_inlier.unsqueeze(2)  # [vn,tn,2]
        rhs = torch.sum(normal * coords.unsqueeze(0), 2)  # [vn,tn]
        ATA = torch.matmul(normal.permute(0, 2, 1), normal)  # [vn,2,2]
        ATb = torch.sum(normal * rhs.unsqueeze(2), 1)  # [vn,2]
        try:
            refined = torch.matmul(torch.inverse(ATA),
                                   ATb.unsqueeze(2))  # [vn,2,1]
        except RuntimeError:
            # Singular system (e.g. a keypoint with no inliers or only parallel
            # directions): fall back to zeros for this sample.
            batch_win_pts.append(
                torch.zeros([1, vn, 2], dtype=torch.float32, device=device))
        else:
            batch_win_pts.append(refined[None, :, :, 0])

    return torch.cat(batch_win_pts)
예제 #2
0
def estimate_voting_distribution_with_mean(mask, vertex, mean, round_hyp_num=256, min_hyp_num=4096, topk=128, inlier_thresh=0.99, min_num=5, max_num=30000, output_hyp=False):
    '''Estimate the covariance of RANSAC voting hypotheses around ``mean``.

    For each sample, ceil(min_hyp_num / round_hyp_num) rounds of
    ``round_hyp_num`` hypotheses are generated and scored. Hypotheses whose
    inlier ratio is more than 0.1 below the best one (per keypoint) get weight
    zero. The rest are used, weighted by inlier ratio, to form a covariance
    around ``mean``.

    :param mask:          [b,h,w] label map; pixels equal to 1 are foreground.
    :param vertex:        [b,h,w,vn,2] per-pixel direction vectors.
    :param mean:          assumed [b,vn,2] (unsqueezed at dim 2 and broadcast
                          against [b,vn,hn,2]); returned unchanged.
    :param round_hyp_num: hypotheses sampled per round.
    :param min_hyp_num:   minimum hypotheses per sample, rounded up to a
                          multiple of round_hyp_num.
    :param topk:          unused.
    :param inlier_thresh: threshold forwarded to ``voting_for_hypothesis``.
    :param min_num:       samples with fewer foreground pixels contribute
                          placeholder hypotheses (all zeros, inlier ratio 1).
    :param max_num:       larger foregrounds are randomly subsampled to roughly
                          this many pixels.
    :param output_hyp:    unused.
    :return: (mean, cov) where cov is [b,vn,2,2].
    '''
    b, h, w, vn, _ = vertex.shape
    device = mask.device
    # Every sample, skipped or not, must contribute the same number of
    # hypotheses so the batch can be concatenated below.
    round_num = int(-(-min_hyp_num // round_hyp_num))  # ceil division
    hyp_total = round_num * round_hyp_num

    all_hyp_pts, all_inlier_ratio = [], []
    for bi in range(b):
        cur_mask = mask[bi] == 1
        foreground = torch.sum(cur_mask)

        # too few points: placeholder hypotheses sized like a real sample
        # (previously min_hyp_num rows, which broke torch.cat whenever
        # min_hyp_num was not a multiple of round_hyp_num)
        if foreground < min_num:
            all_hyp_pts.append(torch.zeros([1, hyp_total, vn, 2], dtype=torch.float32, device=device))  # [1,hn,vn,2]
            all_inlier_ratio.append(torch.ones([1, hyp_total, vn], dtype=torch.float32, device=device))  # [1,hn,vn]
            continue

        # too many points: randomly down-sample them
        if foreground > max_num:
            selection = torch.zeros(cur_mask.shape, dtype=torch.float32, device=device).uniform_(0, 1)
            cur_mask = cur_mask & (selection < (max_num / foreground.float()))
            foreground = torch.sum(cur_mask)

        # nonzero() yields (row, col); swap to (x, y)
        coords = torch.nonzero(cur_mask).float()[:, [1, 0]]  # [tn,2]
        tn = coords.shape[0]
        direct = vertex[bi].masked_select(cur_mask[:, :, None, None]).view([tn, vn, 2])  # [tn,vn,2]

        cur_hyp_pts, cur_inlier_ratio = [], []
        for _ in range(round_num):
            idxs = torch.zeros([round_hyp_num, vn, 2], dtype=torch.int32, device=device).random_(0, tn)
            hyp_pts = ransac_voting.generate_hypothesis(direct, coords, idxs)  # [hn,vn,2]

            inlier = torch.zeros([round_hyp_num, vn, tn], dtype=torch.uint8, device=device)
            ransac_voting.voting_for_hypothesis(direct, coords, hyp_pts, inlier, inlier_thresh)  # [hn,vn,tn]

            cur_hyp_pts.append(hyp_pts)
            cur_inlier_ratio.append(torch.sum(inlier, 2).float() / foreground.float())  # [hn,vn]

        all_hyp_pts.append(torch.cat(cur_hyp_pts, 0).unsqueeze(0))
        all_inlier_ratio.append(torch.cat(cur_inlier_ratio, 0).unsqueeze(0))

    all_hyp_pts = torch.cat(all_hyp_pts, 0).permute(0, 2, 1, 3)      # b,vn,hn,2
    all_inlier_ratio = torch.cat(all_inlier_ratio, 0).permute(0, 2, 1)  # b,vn,hn

    # zero the weight of hypotheses more than 0.1 below the best ratio
    thresh = torch.max(all_inlier_ratio, 2)[0] - 0.1  # b,vn
    all_inlier_ratio[all_inlier_ratio < torch.unsqueeze(thresh, 2)] = 0.0

    diff_pts = all_hyp_pts - torch.unsqueeze(mean, 2)  # b,vn,hn,2
    weighted_diff_pts = diff_pts * torch.unsqueeze(all_inlier_ratio, 3)
    cov = torch.matmul(diff_pts.transpose(2, 3), weighted_diff_pts)  # b,vn,2,2
    # +1e-3 avoids division by zero when every weight is zero (no inliers)
    cov /= torch.sum(all_inlier_ratio, 2)[:, :, None, None] + 1e-3  # b,vn,2,2

    return mean, cov