Example #1
def viewFromError(nCam, gtImage, predImage, predPoints, projPoints, splatter, offset=None):
    allPositions = torch.from_numpy(read_ply("example_data/pointclouds/sphere_300.ply", nCam)).to(device=splatter.camera.device)[:, :3].unsqueeze(0)
    device = splatter.camera.device
    focalLength = splatter.camera.focalLength
    width = splatter.camera.width
    height = splatter.camera.height
    sv = splatter.camera.sv

    if offset is None:
        offset = splatter.camera.focalLength * 0.5
    fromP = allPositions * offset

    diff = torch.sum((gtImage - predImage).abs(), dim=-1)
    diff = torch.nn.functional.avg_pool2d(diff.unsqueeze(0), 9, stride=4, padding=4, ceil_mode=False, count_include_pad=False).squeeze(0)
    # decode the flat argmax into (row, column) of the pooled error map,
    # then scale by the pooling stride to recover image coordinates
    idx = int(diff.argmax())
    w = (idx % diff.shape[1]) * 4
    h = (idx // diff.shape[1]) * 4
    # average points projected inside this region
    _, knn_idx, _ = operations.group_knn(5, torch.tensor([w, h, 1], dtype=projPoints.dtype, device=projPoints.device).view(1, 1, 3).expand(projPoints.shape[0], -1, -1),
                                         projPoints, unique=False, NCHW=False)
    # B, 1, K
    PN = knn_idx.shape[1]  # number of query locations (here 1, the error hotspot)
    knn_points = torch.gather(predPoints.unsqueeze(1).expand(-1, PN, -1, -1), 2, knn_idx.unsqueeze(-1).expand(-1, -1, -1, predPoints.shape[-1]))
    center = torch.mean(knn_points, dim=-2).to(device=device)
    ups = torch.tensor([0, 0, 1], dtype=center.dtype, device=device).view(1, 1, 3).expand_as(fromP)
    ups = ups + torch.randn_like(ups) * 0.0001
    rotation, position = batchLookAt(fromP, center, ups)
    cameras = []
    for i in range(nCam):
        cam = PinholeCamera(device=device, focalLength=focalLength, width=width, height=height, sv=sv)
        cam.rotation = rotation[:, i, :, :]
        cam.position = position[:, i, :]
        cameras.append(cam)

    return diff.max(), cameras
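
A minimal, self-contained sketch of the error-hotspot lookup used above (pool the per-pixel error, then decode the argmax into image coordinates). All names here are illustrative and this is not the repository's API, just the idea reduced to plain torch:

import torch
import torch.nn.functional as F

def error_hotspot(gt_image, pred_image, kernel=9, stride=4):
    # per-pixel error summed over channels: (H, W)
    diff = (gt_image - pred_image).abs().sum(dim=-1)
    # smooth the error map so a single noisy pixel does not dominate
    pooled = F.avg_pool2d(diff.unsqueeze(0), kernel, stride=stride,
                          padding=kernel // 2, count_include_pad=False).squeeze(0)
    # decode the flat argmax into (row, col) of the pooled map,
    # then scale by the stride to get approximate image coordinates
    idx = int(pooled.argmax())
    row, col = divmod(idx, pooled.shape[1])
    return pooled.max(), (col * stride, row * stride)

# usage with random (H, W, 3) images
gt, pred = torch.rand(64, 96, 3), torch.rand(64, 96, 3)
max_err, (w, h) = error_hotspot(gt, pred)
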
Example #2
def _computeDensity(points, knn_k=33, radius=0.1):
    """Per-point density: sum of truncated Gaussian weights over the k nearest neighbors."""
    radius2 = radius*radius
    if points.is_cuda:
        knn_points, knn_idx, distance2 = operations.group_knn(knn_k, points, points, unique=False, NCHW=False)
    else:
        knn_points, knn_idx, distance2 = operations.faiss_knn(knn_k, points, points, NCHW=False)
    knn_points = knn_points[:, :, 1:, :].contiguous().detach()
    knn_idx = knn_idx[:, :, 1:].contiguous()
    distance2 = distance2[:, :, 1:].contiguous()
    # truncated Gaussian weight exp(-d^2 / (4 r^2)), zeroed outside the ball radius
    weight = torch.exp(-distance2/radius2/4)
    weight = torch.where(distance2>radius2, torch.zeros_like(weight), weight)
    weight = torch.sum(weight, dim=-1)
    return weight
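
Roughly the same truncated-Gaussian density can be computed without the project's group_knn/faiss_knn helpers by brute-forcing all pairwise distances with torch.cdist. This is a simplified sketch (O(N^2) memory, so only for small clouds), not the original implementation:

import torch

def gaussian_ball_density(points, radius=0.1):
    # points: (B, N, 3); returns a per-point density of shape (B, N)
    radius2 = radius * radius
    d2 = torch.cdist(points, points).pow(2)             # (B, N, N) squared distances
    weight = torch.exp(-d2 / radius2 / 4)                # Gaussian falloff
    weight = torch.where(d2 > radius2, torch.zeros_like(weight), weight)
    # subtract the self-contribution (distance 0 -> weight 1), mirroring
    # how the k-NN version drops the first neighbor
    return weight.sum(dim=-1) - 1.0

density = gaussian_ball_density(torch.rand(1, 256, 3), radius=0.1)
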
Example #3
    def applyAverageTerm(self, points_data, normals_data, original_points, idxList=None, original_density=None):
        """
        points          B,N,3
        original_points B,N,3
        original_density B,N,1
        """
        points = points_data
        normals = normals_data
        if idxList is not None:
            points = torch.gather(points_data, 1, idxList.expand(-1, -1, points_data.shape[-1]))
            normals = torch.gather(normals_data, 1, idxList.expand(-1, -1, normals_data.shape[-1]))

        PN = points.shape[1]

        knn_k = 16
        if points.is_cuda:
            knn_points, knn_idx, distance2 = operations.group_knn(knn_k, points, original_points, unique=False, NCHW=False)
        else:
            knn_points, knn_idx, distance2 = operations.faiss_knn(knn_k, points, original_points, NCHW=False)
        radius2 = self.repulsion_radius*self.repulsion_radius
        # ball query find center
        knn_points = torch.where(distance2.unsqueeze(-1)>radius2, torch.zeros_like(knn_points), knn_points)
        weight = torch.exp(-distance2/radius2/4)
        weight = torch.where(distance2>radius2, torch.zeros_like(weight), weight)
        # original density term
        if original_density is not None:
            if original_density.dim() == 3:
                original_density = original_density.squeeze(-1)
            original_density_weight = torch.gather(original_density.unsqueeze(1).expand(-1, PN, -1), 2, knn_idx)
            original_density_weight = torch.where(distance2>radius2, torch.zeros_like(original_density_weight), original_density_weight)
            weight = weight * original_density_weight

        weightSum = torch.sum(weight, dim=-1, keepdim=True) + 1e-8
        weight /= weightSum

        # find average
        originalAverage = torch.sum(knn_points * weight.unsqueeze(-1), dim=-2)
        # project to its normal
        update = dot(originalAverage - points, normals, dim=-1).unsqueeze(-1) * normals * self.average_weight
        if idxList is not None:
            points_data.scatter_add_(1, idxList.expand(-1, -1, points_data.shape[-1]), update)
            return
        points += update
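
The core of the update above is a weighted local average of the original points, projected onto each point's normal. A stripped-down, standalone version of that step (without the radius truncation, density weighting and idxList scatter of the method above) could look like the following; all names and the step size are illustrative:

import torch

def average_term_update(points, normals, original_points, radius=0.1, k=16, step=0.5):
    # points, normals: (B, N, 3); original_points: (B, M, 3)
    d2 = torch.cdist(points, original_points).pow(2)                # (B, N, M)
    d2, idx = d2.topk(k, dim=-1, largest=False)                     # k nearest originals
    knn = torch.gather(original_points.unsqueeze(1).expand(-1, points.shape[1], -1, -1),
                       2, idx.unsqueeze(-1).expand(-1, -1, -1, 3))  # (B, N, k, 3)
    weight = torch.exp(-d2 / (4 * radius * radius))
    weight = weight / (weight.sum(dim=-1, keepdim=True) + 1e-8)
    average = (knn * weight.unsqueeze(-1)).sum(dim=-2)              # (B, N, 3)
    # move each point along its own normal towards the local average
    offset = ((average - points) * normals).sum(dim=-1, keepdim=True) * normals
    return points + step * offset

pts = torch.rand(1, 128, 3)
nrm = torch.nn.functional.normalize(torch.randn(1, 128, 3), dim=-1)
new_pts = average_term_update(pts, nrm, torch.rand(1, 256, 3))
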
Example #4
    def forward(self, xyz1, xyz2, **kwargs):
        xyz1 = xyz1.contiguous()
        xyz2 = xyz2.contiguous()
        B, N, C = xyz1.shape
        grouped_points, idx, _ = group_knn(self.nn_size,
                                           xyz1,
                                           xyz1,
                                           unique=True,
                                           NCHW=False)
        group_center = torch.mean(grouped_points, dim=2, keepdim=True)
        grouped_points = grouped_points - group_center
        # fit pca
        # B*N, k, C
        allpoints = grouped_points.view(-1, self.nn_size, C).contiguous()
        U, S, V = batch_svd(allpoints)
        # V is (B*N, C, C); its last column is the normal of the fitted plane
        normals = V[:, :, -1].view(B, N, C).detach()
        # FIXME what about the sign of normal
        ptof1 = dot_product((xyz1 - group_center.squeeze(2)), normals, dim=-1)

        # for xyz2 use the same neighborhood
        grouped_points = torch.gather(
            xyz2.unsqueeze(1).expand(-1, N, -1, -1), 2,
            idx.unsqueeze(-1).expand(-1, -1, -1, C))
        group_center = torch.mean(grouped_points, dim=2, keepdim=True)
        grouped_points = grouped_points - group_center
        # B*N, k, C
        allpoints = grouped_points.view(-1, self.nn_size, C).contiguous()
        U, S, V = batch_svd(allpoints)
        # V is (B*N, C, C); take its last column again as the plane normal
        normals = V[:, :, -1].view(B, N, C).detach()
        ptof2 = dot_product((xyz2 - group_center.squeeze(2)), normals, dim=-1)
        # compare the absolute values of ptof1 and ptof2 (the magnitude only
        # captures how strongly the surface bends, not the bending direction)
        loss = self.metric(ptof1.abs(), ptof2.abs())
        # additionally penalize bending introduced in xyz2 that is absent in xyz1 (flat -> curved)
        bent = ptof2 - ptof1
        bent.masked_fill_(bent < 0, 0.0)
        bent = self.metric(bent, torch.zeros_like(bent))
        # bent.masked_fill_(bent<=1.0, 0.0)
        loss += 5 * bent
        return loss
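
The normal estimation in this loss is a plain PCA plane fit: center the k neighbors, run an SVD, and take the right-singular vector of the smallest singular value as the plane normal. A standalone sketch of that fit using torch.linalg.svd (instead of the repository's batch_svd) is given below; the function name and defaults are illustrative:

import torch

def pca_point_to_plane(points, k=10):
    # points: (B, N, 3); signed distance of each point to the plane fitted to its k-NN
    B, N, C = points.shape
    _, idx = torch.cdist(points, points).topk(k, dim=-1, largest=False)   # (B, N, k)
    neigh = torch.gather(points.unsqueeze(1).expand(-1, N, -1, -1),
                         2, idx.unsqueeze(-1).expand(-1, -1, -1, C))      # (B, N, k, 3)
    center = neigh.mean(dim=2, keepdim=True)
    centered = (neigh - center).reshape(-1, k, C)                         # (B*N, k, 3)
    # the right-singular vector of the smallest singular value is the plane normal
    _, _, Vh = torch.linalg.svd(centered, full_matrices=False)
    normals = Vh[:, -1, :].reshape(B, N, C).detach()
    return ((points - center.squeeze(2)) * normals).sum(dim=-1)

dist = pca_point_to_plane(torch.rand(2, 100, 3))
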
Example #5
    def applyProjection(self, points_data, normals_data, nonvisibility_data, idxList=None, decay=1.0):
        if self.projection_weight <= 0:
            return
        batchSize, PN, _ = points_data.shape
        if PN <= 3:
            return
        normals_data = normalize(normals_data)
        knn_k = 33
        sharpness_sigma = self.sharpness_sigma
        projection_weight = self.projection_weight
        rradius = self.projection_radius
        points = points_data
        normals = normals_data
        nonvisibility_data = nonvisibility_data.to(device=points.device)
        nonvisibility = nonvisibility_data
        if idxList is not None:
            points = torch.gather(points, 1, idxList.expand(-1, -1, points.shape[-1]))
            normals = torch.gather(normals, 1, idxList.expand(-1, -1, normals.shape[-1]))
            nonvisibility = torch.gather(nonvisibility, 1, idxList.expand(-1, -1, nonvisibility.shape[-1]))
        PN = points.shape[1]
        rradius2 = rradius**2
        iradius = 1/(rradius2)/4
        # first KNN (B, N, k, c)
        if points.is_cuda:
            knn_points, knn_idx, distance2 = operations.group_knn(knn_k, points, points_data, unique=False, NCHW=False)
        else:
            knn_points, knn_idx, distance2 = operations.faiss_knn(knn_k, points, points_data, NCHW=False)
            # distance2 = distance2 * distance2
        knn_points = knn_points[:, :, 1:, :].contiguous()
        knn_idx = knn_idx[:, :, 1:].contiguous()
        distance2 = distance2[:, :, 1:].contiguous()
        if torch.all(distance2[:, :, 0] > rradius2):
            return
        knn_normals = torch.gather(normals_data.unsqueeze(1).expand(-1, PN, -1, -1), 2, knn_idx.unsqueeze(-1).expand(-1, -1, -1, normals.shape[-1]))

        # give invisible points a small weight
        phi = torch.gather(nonvisibility_data.unsqueeze(1).expand(-1, PN, -1, -1), 2, knn_idx.unsqueeze(-1)).squeeze(-1)
        # phi = torch.where(phi > 0, torch.full([1, 1, 1], 1.0), torch.full([1, 1, 1], 1.0))
        phi = 1 / (1+phi)**2

        # B, N, k
        theta = torch.exp(-distance2*iradius)
        # B, N, k
        # bandwidth derived from the sharpness angle (degrees -> radians before the cosine)
        sharpness_bandwidth = max(1e-5, 1-np.cos(sharpness_sigma*np.pi/180.0, dtype=np.float32))
        sharpness_bandwidth *= sharpness_bandwidth
        # B, N, k
        psi = torch.exp(-torch.pow(1-torch.sum(normals.unsqueeze(2)*knn_normals, dim=-1), 2)/sharpness_bandwidth)
        weight = psi * theta * phi
        weight = torch.where(distance2 > rradius2, torch.zeros_like(weight), weight)
        # B, N, k, dot product
        project_dist_sum = torch.sum((points.unsqueeze(2) - knn_points)*knn_normals, dim=-1)*weight
        # B, N, 1
        project_dist_sum = torch.sum(project_dist_sum, dim=-1, keepdim=True)+1e-10
        # B, N, 1
        project_weight_sum = torch.sum(weight, dim=-1, keepdim=True)+1e-10
        # B, N, c
        normal_sum = torch.sum(knn_normals*weight.unsqueeze(-1), dim=2)

        update_normal = normal_sum/project_weight_sum
        update_normal = normalize(update_normal)
        # too few neighbors or project_weight_sum too small
        update_normal = torch.where((torch.sum(distance2 <= rradius2, dim=-1) < 3).unsqueeze(-1) | (project_weight_sum < 1e-7), torch.zeros_like(update_normal), update_normal)
        point_update = -(update_normal * (project_dist_sum / project_weight_sum))
        point_update *= (self.projection_weight*decay)
        point_update = torch.clamp(point_update, -0.02, 0.02)
        if not _check_values(point_update):
            import pdb; pdb.set_trace()
        # apply this force
        if idxList is not None:
            points_data.scatter_add_(1, idxList.expand(-1, -1, points_data.shape[-1]), point_update)
            return
        points_data += point_update
        if self.verbose:
            saved_variables["projection"] = point_update.cpu()
            saved_variables["pweight"] = weight.cpu()
Example #6
    def pointRegularizerLoss(self, points_data, normals_data, nonvisibility_data, idxList=None, include_projection=False, use_density=False):
        if self.repulsion_weight <= 0 and self.projection_weight <= 0:
            return
        batchSize, PN, _ = points_data.shape
        if PN <= 3:
            return
        knn_k = 33
        normals_data = normalize(normals_data)
        points = points_data
        normals = normals_data
        nonvisibility_data = nonvisibility_data.to(device=points.device)
        nonvisibility = nonvisibility_data
        if idxList is not None:
            points = torch.gather(points, 1, idxList.expand(-1, -1, points.shape[-1]))
            nonvisibility = torch.gather(nonvisibility, 1, idxList.expand(-1, -1, nonvisibility.shape[-1]))
            normals = torch.gather(normals, 1, idxList.expand(-1, -1, normals.shape[-1]))

        PN = points.shape[1]
        rradius = self.repulsion_radius
        rradius2 = rradius**2
        # repulsion force to projPoints/cameraPoints
        iradius = 1/(rradius2)/2
        # first KNN (B, N, k, c)
        if points.is_cuda:
            knn_points, knn_idx, distance2 = operations.group_knn(knn_k, points, points_data, unique=False, NCHW=False)
        else:
            knn_points, knn_idx, distance2 = operations.faiss_knn(knn_k, points, points_data, NCHW=False)
            # distance2 = distance2 * distance2
        knn_points = knn_points[:, :, 1:, :].contiguous().detach()
        knn_idx = knn_idx[:, :, 1:].contiguous()
        distance2 = distance2[:, :, 1:].contiguous()
        knn_normals = torch.gather(normals_data.unsqueeze(1).expand(-1, PN, -1, -1), 2, knn_idx.unsqueeze(-1).expand(-1, -1, -1, normals.shape[-1]))
        knn_v = knn_points - points.unsqueeze(dim=2)
        # phi, psi and theta are used for finding local plane
        # while only psi is used for repulsion loss weight
        # B, N, k
        phi = torch.gather(nonvisibility_data.unsqueeze(1).expand(-1, PN, -1, -1), 2, knn_idx.unsqueeze(-1)).squeeze(-1)
        # visibility = 1 / (nonvisibility+1)
        phi = 1/(phi+1)
        # # quantize phi, either 1 or 0.1
        # phi = torch.where(phi > 0, torch.full([1, 1, 1], 1.0, device=phi.device), torch.full([1, 1, 1], 0.5, device=phi.device))
        psi = torch.exp(-distance2*iradius)
        sharpness_bandwidth = max(1e-5, 1-np.cos(self.sharpness_sigma*np.pi/180.0, dtype=np.float32))
        sharpness_bandwidth *= sharpness_bandwidth
        # B, N, k
        theta = torch.exp(-torch.pow(1-torch.sum(normals.unsqueeze(2)*knn_normals, dim=-1), 2)/sharpness_bandwidth)
        weight = phi*psi*theta
        weightSum = torch.sum(weight, dim=2, keepdim=True)
        weight /= (weightSum+1e-10)
        # project to local plane
        var = weight.unsqueeze(-1)*(knn_points - torch.sum(weight.unsqueeze(-1)*knn_points, dim=2, keepdim=True))
        # the previous step introduces small numeric error due to weighting
        var = torch.where(var.abs() / torch.max(var.abs(), dim=-1, keepdim=True)[0] < 1e-2, torch.zeros_like(var), var)
        # BN, k, 3
        _, _, V = operations.batch_svd(var.view(-1, knn_k-1, 3))
        V = V.detach()
        totalLoss = 0
        ploss = 0
        rloss = 0
        if include_projection and self.projection_weight > 0:
            # projection minimize distance to the plane
            Vp = V.clone()
            # BN, 1, 3, 1 -- last principal direction is the normal of the local plane
            Vn = Vp.unsqueeze(1)[:, :, :, 2:3]
            # BN, k, 3
            knn_v_p = knn_v.clone()
            # x@V@Vt
            projection_v = torch.matmul(torch.matmul(knn_v_p.view(-1, knn_k-1, 1, 3), Vn), Vn.transpose(-2,-1)).squeeze(-2)
            # BN, k
            distance2 = torch.sum(projection_v*projection_v, dim=-1)
            # B,N,k
            distance2 = distance2.view(batchSize, -1, knn_k-1)
            # mask out neighbors outside the radius, then weight with the
            # visibility, angular and distance similarity
            distance2 = torch.where(distance2 > rradius2, torch.zeros_like(distance2), distance2)
            ploss = torch.mean(distance2*weight.detach())*self.projection_weight
            totalLoss += ploss

        if self.repulsion_weight > 0:
            # repulsion: project onto the first two principal axes by zeroing the last column of V
            # BN, 3, 3
            V[:, :, -1] = 0
            # BN, k-1, 3, 3
            V = V.unsqueeze(-3).expand(-1, knn_k-1, -1, -1)
            # BN, k, 3
            knn_v_r = knn_v.clone()
            knn_v_r.register_hook(lambda x: x.clamp(-0.02, 0.02))
            # BN, k, 3
            repulsion_v = torch.matmul(torch.matmul(knn_v_r.view(-1, knn_k-1, 1, 3), V), V.transpose(-2, -1)).squeeze(-2)
            # repulsion_v = knn_v_r
            # BN, k
            distance2 = torch.sum(repulsion_v * repulsion_v, dim=-1)
            distance2 = distance2.view(batchSize, -1, knn_k-1)
            # loss = torch.exp(-distance2*iradius)
            rloss = 1/torch.sqrt(distance2+1e-4)
            # loss = -distance2
            # loss = 1/(distance2+0.001)
            # (torch.sqrt(distance2+1e-8) - self.repulsion_radius)**2
            rloss = torch.where(distance2 > rradius2, torch.zeros_like(rloss), rloss)
            # B,N,k
            weight = torch.where(distance2 > rradius2, torch.zeros_like(psi), psi)
            if use_density:
                densityWeights = _computeDensity(points)
                weight = weight * densityWeights.unsqueeze(-1)
            weightSum = torch.sum(weight, dim=-1, keepdim=True)+1e-8
            rloss = rloss * weight.detach()
            # B,N
            rloss /= weightSum
            rloss = torch.mean(rloss)*self.repulsion_weight
            totalLoss += rloss

        if include_projection:
            return ploss, rloss
        return totalLoss
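
Stripped of the visibility, normal and tangent-plane projection machinery, the repulsion branch reduces to an inverse-distance penalty 1/sqrt(d^2 + eps) between each point and its in-radius neighbors, weighted by the spatial Gaussian psi. A bare-bones standalone version of that term (names and defaults are illustrative) could be written as:

import torch

def repulsion_loss(points, radius=0.05, k=8, eps=1e-4):
    # points: (B, N, 3); penalizes points that sit too close to their k nearest neighbors
    radius2 = radius * radius
    d2 = torch.cdist(points, points).pow(2)
    d2, _ = d2.topk(k + 1, dim=-1, largest=False)
    d2 = d2[:, :, 1:]                                      # drop the self-match
    penalty = 1.0 / torch.sqrt(d2 + eps)                   # grows as neighbors get closer
    weight = torch.exp(-d2 / (2 * radius2))                # spatial Gaussian psi
    weight = torch.where(d2 > radius2, torch.zeros_like(weight), weight)
    penalty = penalty * weight.detach()
    return (penalty / (weight.sum(dim=-1, keepdim=True) + 1e-8)).mean()

pts = torch.rand(1, 200, 3, requires_grad=True)
loss = repulsion_loss(pts)
loss.backward()   # the gradient on pts pushes nearby points apart in an optimizer step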