def icp_pytorch(src, dst, max_iter, threshold=0.005, ratio=0.5):
    """Rigidly align ``src`` to ``dst`` with point-to-point ICP and return the
    best-matching correspondences.

    Each iteration matches every source point to a point of ``dst`` via the
    compiled ``_C.nn_search`` extension (called on ``.cuda()`` copies). It then
    estimates the least-squares rigid transform with the SVD (Kabsch) method
    and applies it to ``src``. Iteration stops after ``max_iter`` rounds, or
    once the mean matching distance changes by less than ``threshold``.

    Args:
        src: Source points, presumably an ``(N, D)`` float tensor
            (``D == 3`` expected) -- TODO confirm.
        dst: Target points; indexed with the neighbor indices.
        max_iter: Maximum number of ICP iterations; must be >= 1.
        threshold: Convergence tolerance on the change of the mean matching
            distance between iterations.
        ratio: Fraction of source points to keep, choosing the matches with
            the smallest distance.

    Returns:
        int64 CPU tensor of shape ``(int(N * ratio), 2)``. Its rows are
        ``(src_index, dst_index)``, ordered by increasing distance. The
        matches come from the last search, i.e. they were measured before the
        final transform was applied.

    Raises:
        ValueError: If ``max_iter < 1`` (no correspondences would exist).
    """
    if max_iter < 1:
        raise ValueError("max_iter must be >= 1")

    prev_dist = 0.0
    # dst never changes, so its copy for the search is made only once.
    dst_cuda = dst.cuda()

    for _ in range(max_iter):
        # 1. Find Nearest Neighbor
        idx, dist = _C.nn_search(src.cuda(), dst_cuda)
        dst_temp = dst[idx]

        # 2. Cross-covariance (H) of the centered point sets
        src_center = src.mean(dim=0)
        dst_temp_center = dst_temp.mean(dim=0)
        src_norm = src - src_center
        dst_temp_norm = dst_temp - dst_temp_center
        h_matrix = torch.mm(src_norm.T, dst_temp_norm)

        # 3. SVD: h_matrix = U diag(S) V^T (torch.svd returns V, not V^T)
        U, _, V = torch.svd(h_matrix)  # FIXME: very slow

        # 4. Rotation acting on row vectors (p -> p @ R) and translation
        R = torch.mm(U, V.T)
        if torch.det(R) < 0:
            # The best orthogonal fit is a reflection. Flipping the axis of the
            # smallest singular value (last column) yields a proper rotation.
            V = V.clone()
            V[:, -1] = -V[:, -1]
            R = torch.mm(U, V.T)
        # R acts on row vectors, so the centroid must be rotated the same way;
        # this maps src_center exactly onto dst_temp_center.
        t = dst_temp_center - torch.mm(src_center.unsqueeze(0), R).squeeze(0)

        # 5. Transform and check convergence
        src = torch.mm(src, R) + t.unsqueeze(0)
        mean_dist = dist.mean().item()
        if abs(mean_dist - prev_dist) < threshold:
            break
        prev_dist = mean_dist

    # Keep the `ratio` fraction of matches with the smallest distance.
    _, mink = torch.topk(-dist, int(src.size(0) * ratio))
    # Build the pairs directly as int64 (a float32 table would corrupt
    # indices >= 2**24). Keep table and selector on one device.
    corres = torch.stack(
        (torch.arange(src.size(0), device=idx.device), idx.long()), dim=1)
    return corres[mink.to(corres.device)].cpu()
# Example no. 2
# 0
def nn_search(query, ref):
    """Nearest neighbor search.

    Delegates the search to the compiled ``_C.nn_search`` extension and
    packs its output into a table.

    Returns:
        Tensor of shape ``(N, 3)``, where ``N = query.size(0)``. It is
        allocated with ``torch.empty``, so it uses the default dtype and
        device. Its columns are ``(query_index, ref_index, distance)``. With
        a float32 default dtype, indices above 2**24 are not exactly
        representable.
    """
    # NOTE(review): a later ``nn_search`` definition in this file replaces
    # this one at import time.
    nn_idx, nn_dist = _C.nn_search(query, ref)
    num_query = query.size(0)

    # TODO: post-processing time?
    table = torch.empty(num_query, 3)
    columns = (torch.arange(num_query), nn_idx, nn_dist)
    for col, values in enumerate(columns):
        table[:, col] = values
    return table
# Example no. 3
# 0
def nn_search(query,
              ref,
              ratio=0.5,
              cur_label=None,
              prev_label=None,
              gt=False,
              ignore_label=255):
    """Nearest neighbor search returning a prefix of the correspondences.

    Every point in ``query`` is matched to a point in ``ref`` by the compiled
    ``_C.nn_search`` extension.

    Args:
        query: Query points; ``query.size(0)`` is the number of queries ``N``.
        ref: Reference points searched for neighbors.
        ratio: Fraction of correspondences to return. The first
            ``int(N * ratio)`` rows are kept in query order; they are NOT
            selected by distance.
        cur_label, prev_label, gt, ignore_label: Currently unused; kept for
            interface compatibility.

    Returns:
        int64 tensor on the default device, with rows
        ``(query_index, ref_index)``, sliced as ``[:int(N * ratio)]``.
    """
    idx, _ = _C.nn_search(query, ref)  # distances are not needed here
    N = query.size(0)

    # TODO: post-processing time?
    # Build the index pairs directly as int64: a float32 round-trip would
    # corrupt indices >= 2**24.
    query_idx = torch.arange(N)
    ref_idx = idx.to(device=query_idx.device, dtype=torch.long)
    corres = torch.stack((query_idx, ref_idx), dim=1)

    return corres[:int(N * ratio)]
# Example no. 4
# 0
def icp_pt(src, dst, max_iter, threshold):
    """Rigidly align ``src`` to ``dst`` with point-to-point ICP.

    Each iteration matches every source point to a point of ``dst`` via the
    compiled ``_C.nn_search`` extension. It then estimates the least-squares
    rigid transform with the SVD (Kabsch) method and applies it to ``src``.
    Iteration stops after ``max_iter`` rounds, or once the mean matching
    distance changes by less than ``threshold``.

    Args:
        src: Source points, presumably an ``(N, D)`` float tensor -- TODO
            confirm shape/device expectations of ``_C.nn_search``.
        dst: Target points; indexed with the returned neighbor indices.
        max_iter: Maximum number of ICP iterations; must be >= 1.
        threshold: Convergence tolerance on the change of the mean matching
            distance between iterations.

    Returns:
        ``(N, 3)`` tensor allocated with ``torch.empty`` (default dtype and
        device). Its columns are ``(src_index, dst_index, distance)``. The
        matches come from the last search, i.e. before the final transform
        was applied. With a float32 default dtype, indices above 2**24 are
        not exactly representable.

    Raises:
        ValueError: If ``max_iter < 1`` (no search would run).
    """
    if max_iter < 1:
        raise ValueError("max_iter must be >= 1")

    prev_dist = 0
    N = src.size(0)

    for _ in range(max_iter):
        # 1. Find Nearest Neighbor
        idx, dist = _C.nn_search(src, dst)  # TODO: to device
        dst_temp = dst[idx]

        # 2. Cross-covariance (H) of the centered point sets
        src_center = src.mean(dim=0)
        dst_temp_center = dst_temp.mean(dim=0)
        src_norm = src - src_center
        dst_temp_norm = dst_temp - dst_temp_center
        h_matrix = torch.mm(src_norm.T, dst_temp_norm)

        # 3. SVD: h_matrix = U diag(S) V^T (torch.svd returns V, not V^T)
        U, _, V = torch.svd(h_matrix)

        # 4. Rotation acting on row vectors (p -> p @ R) and translation
        R = torch.mm(U, V.T)
        if torch.det(R) < 0:
            # The best orthogonal fit is a reflection. Flipping the axis of the
            # smallest singular value (last column) yields a proper rotation.
            V = V.clone()
            V[:, -1] = -V[:, -1]
            R = torch.mm(U, V.T)
        # R acts on row vectors, so the centroid must be rotated the same way;
        # this maps src_center exactly onto dst_temp_center.
        t = dst_temp_center - torch.mm(src_center.unsqueeze(0), R).squeeze(0)

        # 5. Transform and check convergence
        src = torch.mm(src, R) + t.unsqueeze(0)
        mean_dist = dist.mean()
        if torch.abs(mean_dist - prev_dist).item() < threshold:
            break
        prev_dist = mean_dist

    corres = torch.empty(N, 3)
    corres[:, 0] = torch.arange(N)
    corres[:, 1] = idx
    corres[:, 2] = dist

    return corres