Example #1
    def compute_gploss(self, zy_in, imgid, batch_id, label_flg=0):
        tensor_mat = zy_in

        Sg_Pred = torch.zeros([self.train_batch_size, 1])
        Sg_Pred = Sg_Pred.cuda()
        LSg_Pred = torch.zeros([self.train_batch_size, 1])
        LSg_Pred = LSg_Pred.cuda()
        sq_err = 0

        for i in range(tensor_mat.shape[0]):
            tmp_i = self.dict_unlbl[
                imgid[i]] if label_flg == 0 else self.dict_lbl[
                    imgid[i]]  # image id in the dictionary
            tensor = tensor_mat[i, :, :, :]  # z tensor
            tensor_vec = tensor.view(-1, self.z_height *
                                     self.z_width)  # z tensor to a vector
            ker_UU = self.kernel_comp(
                tensor_vec, tensor_vec, self.z_height, self.z_width,
                self.z_numchnls,
                self.z_numchnls)  # k(z,z), i.e kernel value for z,z

            nearest_vl = self.ker_unlbl[
                tmp_i, :] if label_flg == 0 else self.ker_lbl[
                    tmp_i, :]  #kernel values are used to get neighbors
            tp32_vec = np.array(
                sorted(range(len(nearest_vl)),
                       key=lambda k: nearest_vl[k])[-1 * self.num_nearest:])
            lt32_vec = np.array(
                sorted(range(len(nearest_vl)),
                       key=lambda k: nearest_vl[k])[:self.num_nearest])

            # Nearest neighbor latent space labeled vectors
            near_dic_lbl = np.zeros((self.num_nearest, self.z_numchnls,
                                     self.z_height, self.z_width))
            for j in range(self.num_nearest):
                near_dic_lbl[j, :] = self.Fz_lbl[tp32_vec[j], :, :, :]
            near_vec_lbl = np.reshape(near_dic_lbl,
                                      (self.num_nearest * self.z_numchnls,
                                       self.z_height * self.z_width))
            far_dic_lbl = np.zeros((self.num_nearest, self.z_numchnls,
                                    self.z_height, self.z_width))
            for j in range(self.num_nearest):
                far_dic_lbl[j, :] = self.Fz_lbl[lt32_vec[j], :, :, :]
            far_vec_lbl = np.reshape(far_dic_lbl,
                                     (self.num_nearest * self.z_numchnls,
                                      self.z_height * self.z_width))

            # computing kernel matrix of nearest label vectors
            # and then computing (K_L+sig^2I)^(-1)
            ker_LL = cosine_similarity(near_vec_lbl, near_vec_lbl)
            inv_ker = inv(ker_LL +
                          1.0 * np.eye(self.num_nearest * self.z_numchnls))
            farker_LL = cosine_similarity(far_vec_lbl, far_vec_lbl)
            farinv_ker = inv(farker_LL +
                             1.0 * np.eye(self.num_nearest * self.z_numchnls))

            mn_pre = np.matmul(
                inv_ker, near_vec_lbl)  # used for computing mean prediction
            mn_pre = mn_pre.astype(np.float32)

            # converting required variables to cuda tensors
            near_vec_lbl = torch.from_numpy(near_vec_lbl.astype(np.float32))
            far_vec_lbl = torch.from_numpy(far_vec_lbl.astype(np.float32))
            inv_ker = torch.from_numpy(inv_ker.astype(np.float32))
            farinv_ker = torch.from_numpy(farinv_ker.astype(np.float32))
            inv_ker = inv_ker.cuda()
            farinv_ker = farinv_ker.cuda()
            near_vec_lbl = near_vec_lbl.cuda()
            far_vec_lbl = far_vec_lbl.cuda()

            mn_pre = torch.from_numpy(
                mn_pre)  # used for mean prediction (mu) or z_pseudo
            mn_pre = mn_pre.cuda()

            # Identity matrix
            Eye = torch.eye(self.z_numchnls)
            Eye = Eye.cuda()

            # computing sigma or variance between nearest labeled vectors and unlabeled vector
            ker_UL = self.kernel_comp(tensor_vec, near_vec_lbl, self.z_height,
                                      self.z_width, self.z_numchnls,
                                      self.z_numchnls * self.num_nearest)
            sigma_est = ker_UU - torch.matmul(
                ker_UL, torch.matmul(inv_ker, ker_UL.t())) + Eye
            # computing variance between farthest labeled vectors and unlabeled vector
            Farker_UL = self.kernel_comp(tensor_vec, far_vec_lbl,
                                         self.z_height, self.z_width,
                                         self.z_numchnls,
                                         self.z_numchnls * self.num_nearest)
            far_sigma_est = ker_UU - torch.matmul(
                Farker_UL, torch.matmul(farinv_ker, Farker_UL.t())) + Eye

            # computing mean prediction
            mean_pred = torch.matmul(ker_UL,
                                     mn_pre)  #mean prediction (mu) or z_pseudo

            inv_sigma = torch.inverse(sigma_est)
            sq_err += torch.mean(
                torch.matmul((tensor_vec - mean_pred).t(),
                             torch.matmul(inv_sigma, (tensor_vec - mean_pred)))
            ) + 1.0 * self.lambda_var * torch.log(
                torch.det(sigma_est)) - 0.000001 * self.lambda_var * torch.log(
                    torch.det(far_sigma_est))
            Sg_Pred[i, :] = torch.log(torch.det(sigma_est))
            LSg_Pred[i, :] = torch.log(torch.det(far_sigma_est))

        if not (batch_id % 100):
            print(LSg_Pred.max().item(),
                  Sg_Pred.max().item(),
                  sq_err.mean().item() / self.train_batch_size,
                  Sg_Pred.mean().item())

        gp_loss = 1.0 * sq_err / self.train_batch_size

        return gp_loss
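The loop above evaluates the standard Gaussian-process posterior for each latent vector. A minimal standalone sketch of the two formulas (mean and variance), with illustrative names rather than the class API above, and a cosine kernel assumed in place of kernel_comp:

import torch
import torch.nn.functional as F

def gp_posterior(z, L, noise=1.0):
    # z: (1, d) query latent vector; L: (m, d) nearest labelled latent vectors.
    k = lambda a, b: F.cosine_similarity(a.unsqueeze(1), b.unsqueeze(0), dim=-1)
    K_LL_inv = torch.inverse(k(L, L) + noise * torch.eye(L.shape[0]))  # (K_L + sig^2 I)^-1
    K_zL = k(z, L)                                  # cross-kernel k(z, L)
    mean = K_zL @ K_LL_inv @ L                      # mean prediction (mu) or z_pseudo
    var = k(z, z) - K_zL @ K_LL_inv @ K_zL.t()      # predictive variance (sigma)
    return mean, var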
Example #2
    def _test_single_corresponding_points_alignment(
        self,
        batch_size=10,
        n_points=100,
        dim=3,
        use_pointclouds=False,
        estimate_scale=False,
        reflect=False,
        allow_reflection=False,
        random_weights=False,
    ):
        """
        Executes a single test for `corresponding_points_alignment` for a
        specific setting of the inputs / outputs.
        """

        device = torch.device("cuda:0")

        # initialize a ground-truth point cloud
        X = TestCorrespondingPointsAlignment.init_point_cloud(
            batch_size=batch_size,
            n_points=n_points,
            dim=dim,
            device=device,
            use_pointclouds=use_pointclouds,
            random_pcl_size=True,
        )

        # generate the true transformation
        R, T, s = TestCorrespondingPointsAlignment.generate_pcl_transformation(
            batch_size=batch_size,
            scale=estimate_scale,
            reflect=reflect,
            dim=dim,
            device=device,
        )

        if reflect:
            # generate random reflection M and apply to the rotations
            M = TestCorrespondingPointsAlignment.generate_random_reflection(
                batch_size=batch_size, dim=dim, device=device)
            R = torch.bmm(M, R)

        weights = None
        if random_weights:
            template = X.points_padded() if use_pointclouds else X
            weights = torch.rand_like(template[:, :, 0])
            weights = weights / weights.sum(dim=1, keepdim=True)
            # zero out some weights, as zero weights are a common use case;
            # this sets weights below ~0.3 of the average exactly to zero
            weights *= (weights * template.size()[1] > 0.3).to(weights)
            if use_pointclouds:  # convert to List[Tensor]
                weights = [
                    w[:npts]
                    for w, npts in zip(weights, X.num_points_per_cloud())
                ]

        # apply the generated transformation to the generated
        # point cloud X
        X_t = _apply_pcl_transformation(X, R, T, s=s)

        # run the CorrespondingPointsAlignment algorithm
        R_est, T_est, s_est = points_alignment.corresponding_points_alignment(
            X,
            X_t,
            weights,
            allow_reflection=allow_reflection,
            estimate_scale=estimate_scale,
        )

        assert_error_message = (
            f"Corresponding_points_alignment assertion failure for "
            f"n_points={n_points}, "
            f"dim={dim}, "
            f"use_pointclouds={use_pointclouds}, "
            f"estimate_scale={estimate_scale}, "
            f"reflect={reflect}, "
            f"allow_reflection={allow_reflection},"
            f"random_weights={random_weights}.")

        # if we test the weighted case, check that weights help with noise
        if random_weights and not use_pointclouds and n_points >= (dim + 10):
            # add noise to 20% points with smallest weight
            X_noisy = X_t.clone()
            _, mink_idx = torch.topk(-weights, int(n_points * 0.2), dim=1)
            mink_idx = mink_idx[:, :, None].expand(-1, -1, X_t.shape[-1])
            X_noisy.scatter_add_(
                1, mink_idx, 0.3 * torch.randn_like(mink_idx, dtype=X_t.dtype))

            def align_and_get_mse(weights_):
                R_n, T_n, s_n = points_alignment.corresponding_points_alignment(
                    X_noisy,
                    X_t,
                    weights_,
                    allow_reflection=allow_reflection,
                    estimate_scale=estimate_scale,
                )

                X_t_est = _apply_pcl_transformation(X_noisy, R_n, T_n, s=s_n)

                return (((X_t_est - X_t) * weights[..., None])**
                        2).sum(dim=(1, 2)) / weights.sum(dim=-1)

            # check that using weights leads to lower weighted_MSE(X_noisy, X_t)
            self.assertTrue(
                torch.all(
                    align_and_get_mse(weights) <= align_and_get_mse(None)))

        if reflect and not allow_reflection:
            # check that all rotations have det=1
            self._assert_all_close(torch.det(R_est),
                                   R_est.new_ones(batch_size),
                                   assert_error_message)

        else:
            # mask out inputs with too few non-degenerate points for assertions
            w = (torch.ones_like(R_est[:, 0, 0])
                 if weights is None or n_points >= dim + 10 else
                 (weights > 0.0).all(dim=1).to(R_est))
            # check that the estimated transformation is the same
            # as the ground truth
            if n_points >= (dim + 1):
                # the checks on transforms apply only when
                # the problem setup is unambiguous
                msg = assert_error_message
                self._assert_all_close(R_est,
                                       R,
                                       msg,
                                       w[:, None, None],
                                       atol=1e-5)
                self._assert_all_close(T_est, T, msg, w[:, None])
                self._assert_all_close(s_est, s, msg, w)

                # check that the orthonormal part of the
                # transformation has a correct determinant (+1/-1)
                desired_det = R_est.new_ones(batch_size)
                if reflect:
                    desired_det *= -1.0
                self._assert_all_close(torch.det(R_est), desired_det, msg, w)

            # check that the transformed point cloud
            # X matches X_t
            X_t_est = _apply_pcl_transformation(X, R_est, T_est, s=s_est)
            self._assert_all_close(X_t,
                                   X_t_est,
                                   assert_error_message,
                                   w[:, None, None],
                                   atol=1e-5)
Example #3
        jacobian[1, 1, inode] = jacobian[1, 1, inode] * scaling_factor[inode]
        jacobian[1, 2, inode] = jacobian[1, 2, inode] * scaling_factor[inode]

# Metrics and Determinant
metrics = torch.empty(3, 3, nnodes_if, dtype=torch.float64)
jinv = torch.empty(nnodes_if, dtype=torch.float64)
for inode in range(0, nnodes_if):
    ijacobian = torch.empty(3, 3, dtype=torch.float64)
    imetric = torch.empty(3, 3, dtype=torch.float64)
    for irow in range(0, 3):
        for icol in range(0, 3):
            ijacobian[irow, icol] = jacobian[irow, icol, inode]
    # Compute jacobian for the ith node
    update_progress("Computing Jinv and Metric          ",
                    inode / (nnodes_if - 1))
    jinv[inode] = torch.det(ijacobian)
    imetric = torch.inverse(ijacobian)
    for irow in range(0, 3):
        for icol in range(0, 3):
            metrics[irow, icol, inode] = imetric[irow, icol]

# Normals
normals = torch.empty(nnodes_if, ndirs, dtype=torch.float64)
if iface == 1 or iface == 2:
    for inode in range(0, nnodes_if):
        normals[inode, 0] = jinv[inode] * metrics[0, 0, inode]
        normals[inode, 1] = jinv[inode] * metrics[0, 1, inode]
        normals[inode, 2] = jinv[inode] * metrics[0, 2, inode]
        update_progress("Computing Normals                  ",
                        inode / (nnodes_if - 1))
if iface == 3 or iface == 4:
    # body truncated in the source; by analogy with faces 1 and 2, these
    # faces would presumably use the second row of the metric tensor:
    for inode in range(0, nnodes_if):
        normals[inode, 0] = jinv[inode] * metrics[1, 0, inode]
        normals[inode, 1] = jinv[inode] * metrics[1, 1, inode]
        normals[inode, 2] = jinv[inode] * metrics[1, 2, inode]
        update_progress("Computing Normals                  ",
                        inode / (nnodes_if - 1))
Example #4
r = (torch.rand(2, 2) - 0.5) * 2  # values between -1 and 1
print('A random matrix, r:')
print(r)

# Common mathematical operations are supported:
print('\nAbsolute value of r:')
print(torch.abs(r))

# ...as are trigonometric functions:
print('\nInverse sine of r:')
print(torch.asin(r))

# ...and linear algebra operations like determinant and singular value decomposition
print('\nDeterminant of r:')
print(torch.det(r))
print('\nSingular value decomposition of r:')
print(torch.svd(r))

# ...and statistical and aggregate operations:
print('\nAverage and standard deviation of r:')
print(torch.std_mean(r))
print('\nMaximum value of r:')
print(torch.max(r))

##########################################################################
# There’s a good deal more to know about the power of PyTorch tensors,
# including how to set them up for parallel computations on GPU - we’ll be
# going into more depth in another video.
#
# PyTorch Models
Example #5
    def forward(self,
                src_coords,
                tgt_coords,
                weights,
                convert_from_pixels=True):
        """ This modules used differentiable singular value decomposition to compute the rotations and translations that
            best align matched pointclouds (src and tgt).
        Args:
            src_coords (torch.tensor): (b,N,2) source keypoint locations
            tgt_coords (torch.tensor): (b,N,2) target keypoint locations
            weights (torch.tensor): (b,1,N) weight score associated with each src-tgt match
            convert_from_pixels (bool): if true, input is in pixel coordinates and must be converted to metric
        Returns:
            R_tgt_src (torch.tensor): (b,3,3) rotation from src to tgt
            t_src_tgt_intgt (torch.tensor): (b,3,1) translation from tgt to src as measured in tgt
        """
        if src_coords.size(0) > tgt_coords.size(0):
            BW = src_coords.size(0)
            B = int(BW / self.window_size)
            kp_inds, _ = get_indices(B, self.window_size)
            src_coords = src_coords[kp_inds]
        assert (src_coords.size() == tgt_coords.size())
        B = src_coords.size(0)  # B x N x 2
        if convert_from_pixels:
            src_coords = convert_to_radar_frame(src_coords, self.config)
            tgt_coords = convert_to_radar_frame(tgt_coords, self.config)
        if src_coords.size(2) < 3:
            pad = 3 - src_coords.size(2)
            src_coords = F.pad(src_coords, [0, pad, 0, 0])
        if tgt_coords.size(2) < 3:
            pad = 3 - tgt_coords.size(2)
            tgt_coords = F.pad(tgt_coords, [0, pad, 0, 0])
        src_coords = src_coords.transpose(2, 1)  # B x 3 x N
        tgt_coords = tgt_coords.transpose(2, 1)

        # Compute weighted centroids
        w = torch.sum(weights, dim=2, keepdim=True) + 1e-4
        src_centroid = torch.sum(src_coords * weights, dim=2,
                                 keepdim=True) / w  # B x 3 x 1
        tgt_centroid = torch.sum(tgt_coords * weights, dim=2, keepdim=True) / w

        src_centered = src_coords - src_centroid  # B x 3 x N
        tgt_centered = tgt_coords - tgt_centroid

        W = torch.bmm(tgt_centered * weights, src_centered.transpose(
            2, 1)) / w  # B x 3 x 3

        try:
            U, _, V = torch.svd(W)
        except RuntimeError:  # torch.svd sometimes has convergence issues, this has yet to be patched.
            print(W)
            print('Adding turbulence to patch convergence issue')
            U, _, V = torch.svd(W + 1e-4 * W.mean() *
                                torch.rand(1, 3).to(self.gpuid))

        det_UV = torch.det(U) * torch.det(V)
        ones = torch.ones(B, 2).type_as(V)
        S = torch.diag_embed(torch.cat((ones, det_UV.unsqueeze(1)),
                                       dim=1))  # B x 3 x 3

        # Compute rotation and translation (T_tgt_src)
        R_tgt_src = torch.bmm(U, torch.bmm(S, V.transpose(2, 1)))  # B x 3 x 3
        t_tgt_src_insrc = src_centroid - torch.bmm(R_tgt_src.transpose(2, 1),
                                                   tgt_centroid)  # B x 3 x 1
        t_src_tgt_intgt = -R_tgt_src.bmm(t_tgt_src_insrc)

        return R_tgt_src, t_src_tgt_intgt
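For reference, a self-contained sketch of the weighted SVD (Kabsch) step performed above for a single (3, N) point pair; the names are illustrative and not part of this module's API:

import torch

def weighted_kabsch(src, tgt, w):
    # src, tgt: (3, N) matched points; w: (N,) non-negative match weights.
    w = w / w.sum()
    src_c = src - (src * w).sum(dim=1, keepdim=True)  # remove weighted centroids
    tgt_c = tgt - (tgt * w).sum(dim=1, keepdim=True)
    W = (tgt_c * w) @ src_c.t()                       # weighted cross-covariance
    U, _, V = torch.svd(W)
    d = torch.det(U) * torch.det(V)                   # reflection correction
    S = torch.diag(torch.tensor([1.0, 1.0, d.item()]))
    return U @ S @ V.t()                              # rotation with tgt ~ R @ src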
Example #6
    def det_unique_single_double(self, input):
        """Computes the SD of single/double excitations

        The determinants of the single excitations
        are calculated from the ground state determinant and
        the ground state Slater matrices with one column modified.
        See : Monte Carlo Methods in ab initio quantum chemistry
        B.L. Hammond, appendix B1


        Note: if the states in the configs are specified in order and
        we end up with excitations that come from a deep orbital, the resulting
        Slater matrix has one column changed (with the new orbital) and several
        permutations. We therefore need to multiply the Slater determinant
        by (-1)^nperm.


        .. math::

            MO = [ A | B ]
            det(Exc_{ij}) = (det(A) * A^{-1} * B)_{i,j}

        Args:
            input (torch.tensor): MO matrices nbatch x nelec x nmo

        """

        nbatch = input.shape[0]

        if not hasattr(self.exc_mask, 'index_unique_single_up'):
            self.exc_mask.get_index_unique_single()

        if not hasattr(self.exc_mask, 'index_unique_double_up'):
            self.exc_mask.get_index_unique_double()

        do_single = len(self.exc_mask.index_unique_single_up) != 0
        do_double = len(self.exc_mask.index_unique_double_up) != 0

        # occupied orbital matrix + det and inv on spin up
        Aup = input[:, :self.nup, :self.nup]
        detAup = torch.det(Aup)

        # occupied orbital matrix + det and inv on spin down
        Adown = input[:, self.nup:, :self.ndown]
        detAdown = torch.det(Adown)

        # store all the dets we need
        det_out_up = detAup.unsqueeze(-1).clone()
        det_out_down = detAdown.unsqueeze(-1).clone()

        # return the ground state
        if self.config_method == 'ground_state':
            return det_out_up, det_out_down

        # inverse of the occupied orbital matrices
        invAup = torch.inverse(Aup)
        invAdown = torch.inverse(Adown)

        # virtual orbital matrices spin up/down
        Bup = input[:, :self.nup, self.nup:self.index_max_orb_up]
        Bdown = input[:, self.nup:,
                      self.ndown: self.index_max_orb_down]

        # compute the products of A^-1 and B
        mat_exc_up = (invAup @ Bup)
        mat_exc_down = (invAdown @ Bdown)

        if do_single:

            # determinant of the unique excitation spin up
            det_single_up = mat_exc_up.view(
                nbatch, -1)[:, self.exc_mask.index_unique_single_up]

            # determinant of the unique excitation spin down
            det_single_down = mat_exc_down.view(
                nbatch, -1)[:, self.exc_mask.index_unique_single_down]

            # multiply with ground state determinant
            # and account for permutation for deep excitation
            det_single_up = detAup.unsqueeze(-1) * \
                det_single_up.view(nbatch, -1)

            # multiply with ground state determinant
            # and account for permutation for deep excitation
            det_single_down = detAdown.unsqueeze(-1) * \
                det_single_down.view(nbatch, -1)

            # accumulate the dets
            det_out_up = torch.cat((det_out_up, det_single_up), dim=1)
            det_out_down = torch.cat(
                (det_out_down, det_single_down), dim=1)

        if do_double:

            # det of unique spin up double exc
            det_double_up = mat_exc_up.view(
                nbatch, -1)[:, self.exc_mask.index_unique_double_up]

            det_double_up = bdet2(
                det_double_up.view(nbatch, -1, 2, 2))

            det_double_up = detAup.unsqueeze(-1) * det_double_up

            # det of unique spin down double exc
            det_double_down = mat_exc_down.view(
                nbatch, -1)[:, self.exc_mask.index_unique_double_down]

            det_double_down = bdet2(
                det_double_down.view(nbatch, -1, 2, 2))

            det_double_down = detAdown.unsqueeze(-1) * det_double_down

            det_out_up = torch.cat((det_out_up, det_double_up), dim=1)
            det_out_down = torch.cat(
                (det_out_down, det_double_down), dim=1)

        return det_out_up, det_out_down
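The column-replacement identity in the docstring is Cramer's rule; a quick numerical check:

import torch

n, j = 4, 2
A = torch.randn(n, n, dtype=torch.float64)
b = torch.randn(n, dtype=torch.float64)
Aj = A.clone()
Aj[:, j] = b                     # A with column j replaced by b
print(torch.allclose(torch.det(Aj),
                     torch.det(A) * (torch.inverse(A) @ b)[j]))  # True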
Example #7
def det_keepdim(A):
    return torch.det(A).view(1, 1)
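Usage note: the view keeps a (1, 1) matrix rather than a 0-dim scalar, e.g.:

import torch

print(det_keepdim(torch.eye(3)))  # tensor([[1.]])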
Example #8
def compute_relative_pose_with_ransac(target_keypoints, source_keypoints):
    """
    :param target_keypoints: N * 2
    :param source_keypoints: N * 2
    :return: T_target_source_best: 4 * 4
             score: float
    """
    assert (target_keypoints.shape == source_keypoints.shape)
    num_matches = len(target_keypoints)
    n, k = 1000, 10
    if num_matches < k:
        return None, None

    target_keypoints = torch.Tensor(target_keypoints)
    source_keypoints = torch.Tensor(source_keypoints)

    selections = np.random.choice(num_matches, (n, k), replace=True)

    target_sub_keypoints = target_keypoints[selections]  # N * k * 2
    source_sub_keypoints = source_keypoints[selections]  # N * k * 2
    target_centers = target_sub_keypoints.mean(dim=1)  # N * 2
    source_centers = source_sub_keypoints.mean(dim=1)  # N * 2
    target_sub_keypoints_centered = target_sub_keypoints - target_centers.unsqueeze(
        1)
    source_sub_keypoints_centered = source_sub_keypoints - source_centers.unsqueeze(
        1)
    cov = source_sub_keypoints_centered.transpose(
        1, 2) @ target_sub_keypoints_centered
    u, s, v = torch.svd(cov)  # u: N*2*2, s: N*2, v: N*2*2

    v_neg = v.clone()
    v_neg[:, :, 1] *= -1

    rot_mats_neg = v_neg @ u.transpose(1, 2)
    rot_mats_pos = v @ u.transpose(1, 2)
    determinants = torch.det(rot_mats_pos)

    rot_mats_neg_list = [rot_mat_neg for rot_mat_neg in rot_mats_neg]
    rot_mats_pos_list = [rot_mat_pos for rot_mat_pos in rot_mats_pos]

    rot_mats_list = [
        rot_mat_pos if determinant > 0 else rot_mat_neg
        for (determinant, rot_mat_pos, rot_mat_neg
             ) in zip(determinants, rot_mats_pos_list, rot_mats_neg_list)
    ]
    rotations = torch.stack(rot_mats_list)  # N * 2 * 2
    translations = torch.einsum("nab,nb->na", -rotations,
                                source_centers) + target_centers  # N * 2
    diff = source_keypoints @ rotations.transpose(
        1, 2) + translations.unsqueeze(1) - target_keypoints
    distances_squared = torch.sum(diff * diff, dim=2)

    distance_tolerance = 0.5
    scores = (distances_squared < (distance_tolerance**2)).sum(dim=1)
    score = torch.max(scores)
    best_index = torch.argmax(scores)
    rotation = rotations[best_index]
    translation = translations[best_index]
    T_target_source = torch.cat((rotation, translation[..., None]), dim=1)
    T_target_source = torch.cat((T_target_source, torch.Tensor([[0, 0, 1]])),
                                dim=0)
    return T_target_source, score
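A hypothetical noise-free smoke test for the function above; the transform values are illustrative:

import numpy as np

theta = np.deg2rad(30.0)
R = np.array([[np.cos(theta), -np.sin(theta)],
              [np.sin(theta), np.cos(theta)]])
source = np.random.rand(100, 2) * 10.0
target = source @ R.T + np.array([1.0, -2.0])
T_target_source, score = compute_relative_pose_with_ransac(target, source)
print(score)             # ~100: every match is an inlier
print(T_target_source)   # approximately [[R | t], [0, 0, 1]]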
Example #9
def motion_synchronization_spectral(R,
                                    t,
                                    c: torch.Tensor = None,
                                    fallback_on_error: bool = False):
    n_view = len(R)
    n_batch = R[0][0].size(0)
    assert len(R) == n_view and len(R[0]) == n_view
    assert len(t) == n_view and len(t[0]) == n_view

    if c is None:
        c = torch.ones((n_batch, n_view, n_view), device=R[0][0].device)
        c -= torch.eye(n_view,
                       device=c.device).unsqueeze(0).repeat(n_batch, 1, 1)

    # Rotation Sync. (R_{ij} = R_{i0} * R_{j0}^T)
    L = []
    c_rowsum = (torch.sum(c, dim=-1) - torch.diagonal(c, dim1=-2, dim2=-1))
    for view_i in range(n_view):
        L_row = []
        for view_j in range(n_view):
            if view_i == view_j:
                L_row.append(
                    torch.eye(3, dtype=c.dtype,
                              device=c.device).unsqueeze(0).repeat(
                                  n_batch, 1, 1) *
                    c_rowsum[:, view_i].unsqueeze(1).unsqueeze(2))
            else:
                L_row.append(-c[:, view_i, view_j].unsqueeze(1).unsqueeze(2) *
                             R[view_i][view_j].transpose(-1, -2))
        L_row = torch.stack(L_row, dim=-2)
        L.append(L_row)
    L = torch.stack(L, dim=1)
    L = L.reshape((-1, n_view * 3, n_view * 3))

    # Solve for the 3 smallest eigenvectors (symmetric eigendecomposition)
    try:
        e, V = torch.symeig(L, eigenvectors=True)
    except RuntimeError:
        if fallback_on_error:
            final_R = torch.stack(R[0], dim=1)
            final_t = torch.stack(t[0], dim=1)
            return final_R, final_t
        else:
            raise

    V = V[..., :3].reshape(n_batch, n_view, 3, 3)
    detV = torch.det(V.detach().contiguous()).sum(dim=1)
    Vlc = V[..., -1] * detV.sign().unsqueeze(1).unsqueeze(2)
    V = torch.cat([V[..., :-1], Vlc.unsqueeze(-1)], dim=-1)

    V = V.reshape(n_batch * n_view, 3, 3)
    u, s, v = torch.svd(V, some=False, compute_uv=True)
    R_optimal = torch.bmm(u, v.transpose(1, 2)).reshape(n_batch, n_view, 3, 3)

    # Translation Sync.
    b_elements = []
    for view_i in range(n_view):
        b_row_elements = []
        for view_j in range(n_view):
            b_row_elements.append(
                -c[:, view_i, view_j].unsqueeze(1) * torch.einsum(
                    'bnm,bn->bm', R[view_i][view_j], t[view_i][view_j]))
        b_row_elements = sum(b_row_elements)
        b_elements.append(b_row_elements)
    b_elements = torch.stack(b_elements, dim=1)
    b_elements = b_elements.reshape((n_batch, n_view * 3))
    t_optimal, _ = torch.solve(b_elements.unsqueeze(-1), L)
    t_optimal = t_optimal.reshape(n_batch, n_view, 3)
    return R_optimal, t_optimal
Example #10
    def __iter__(self):
        for epi in range(self.episodes_per_epoch):
            batch = []

            for task in range(self.num_tasks):
                if self.fixed_tasks is None:
                    # Draw (num_s_candidates) candidate support sets of random classes
                    # support_candidates: List[Iterable[num_s_candidates]]  [[num_sample * k],...,[num_sample * k]]
                    support_candidates = []
                    episode_classes_set = []
                    for i in range(self.num_s_candidates):
                        episode_classes = list(
                            np.random.choice(
                                self.dataset.df['class_id'].unique(),
                                size=self.k,
                                replace=False))
                        episode_classes_set.append(episode_classes)

                        # [1(num_sample), 2(num_sample), ... , k(num_sample)]: len <num_sample * k>
                        # TODO: each num_sample -> need to average
                        episode_id = []
                        for j in episode_classes:
                            tmp_df = self.dataset.df[
                                self.dataset.df['class_id'] == j].sample(
                                    self.num_sample)
                            episode_id.extend(tmp_df['id'])
                        support_candidates.append(episode_id)
                else:
                    raise (ValueError(
                        'cant do fixed_tasks with importance sampling.'))

                self.model.eval()
                diversity = []
                # diversity: List[Iterable[num_s_candidates]]
                for i in range(self.num_s_candidates):
                    # get train-support sets
                    # (1) "num_sample" number of samples for each label in train dataset (support set)
                    support_id = support_candidates[i]
                    support_train_items = list(
                        map(self.dataset.__getitem__, support_id))
                    support_train_features = list(zip(*support_train_items))[0]
                    support_train_features = torch.stack(
                        support_train_features)
                    support_train_features = support_train_features.double(
                    ).cuda()

                    train_embeddings = self.model(
                        support_train_features)  # shape: (num_sample*k, 1600)

                    # (2) get feature for each label (train support set) using mean vector by manifold space
                    # DO: Each num_sample -> average
                    average_embeddings = []
                    for c in range(self.k):
                        # average the num_sample embeddings of class c
                        tmp = torch.mean(
                            train_embeddings[self.num_sample *
                                             c:self.num_sample * (c + 1)],
                            dim=0)
                        average_embeddings.append(tmp)

                    average_embeddings = torch.stack(
                        average_embeddings)  # shape: (k, 1600)

                    norm_embedding = average_embeddings / (
                        average_embeddings.pow(2).sum(
                            dim=1, keepdim=True).sqrt() + EPSILON)
                    trans_dot_norm = torch.mm(norm_embedding,
                                              norm_embedding.transpose(
                                                  0, 1))  # shape: (k, k)
                    determinant = torch.det(trans_dot_norm).item()
                    determinant = determinant**0.5  # square root
                    diversity.append(determinant)

                diversity = np.array(diversity)
                diversity[np.isnan(diversity)] = EPSILON
                '''
                # applied to softmax with temperature for a half of entire iterations (T: 20 -> 1)
                self.i_iter += 1
                
                temp_change_iteration = (self.total_epochs*self.episodes_per_epoch) // 3
                if temp_change_iteration > self.i_iter:
                    temperature = self.init_temperature - (self.init_temperature - 1)*(self.i_iter/temp_change_iteration)
                else:
                    temperature = 1

                if self.is_diversity:
                    supports_sampling_rate = softmax_with_temperature(diversity, temperature)
                else:
                    # similarity
                    supports_sampling_rate = softmax_with_temperature(-diversity, temperature)
                '''

                if self.is_diversity:
                    supports_sampling_rate = softmax_with_temperature(
                        diversity, 1.0)
                else:
                    similarity = (-1 * diversity)
                    supports_sampling_rate = softmax_with_temperature(
                        similarity, 1.0)

                support_choice = np.random.choice(
                    list(range(self.num_s_candidates)),
                    size=1,
                    replace=False,
                    p=supports_sampling_rate.astype(np.float64))

                support_candidates_id = support_candidates[support_choice[0]]
                sampling_support_classes = episode_classes_set[
                    support_choice[0]]

                df = self.dataset.df[self.dataset.df['class_id'].isin(
                    sampling_support_classes)]
                support_k = {k: None for k in sampling_support_classes}

                # Select support examples
                batch.extend(support_candidates_id)

                # self.num_sample == self.n
                for k in sampling_support_classes:
                    support_k[k] = support_candidates_id[k *
                                                         self.num_sample:(k +
                                                                          1) *
                                                         self.num_sample]

                # Select Query examples
                for k in sampling_support_classes:
                    query = df[(df['class_id'] == k)
                               & (~df['id'].isin(support_k[k]))].sample(self.q)
                    for i, q in query.iterrows():
                        batch.append(q['id'])

            yield np.stack(batch)
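The diversity score above in standalone form: the square root of the Gram determinant of the normalised embeddings, i.e. the volume they span (shapes illustrative):

import torch

emb = torch.randn(5, 1600)                        # k class-mean embeddings
norm_emb = emb / (emb.norm(dim=1, keepdim=True) + 1e-8)
gram = norm_emb @ norm_emb.t()                    # (k, k) cosine-similarity matrix
diversity = torch.det(gram).clamp(min=0).sqrt()   # near 0 when embeddings align
print(diversity.item())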
Example #11
    def calculate_log_normal(self, vector, mean, covariances):

        return 0.5 * (
            (-torch.log(torch.det(covariances))) -
            torch.mm(torch.mm((vector - mean), torch.inverse(covariances)),
                     torch.t(vector - mean)))
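Up to the constant -k/2 * log(2*pi), this is the log-density of a multivariate normal; a numerical check with assumed shapes (`vector` and `mean` as (1, k) rows):

import math
import torch
from torch.distributions import MultivariateNormal

k = 3
mean, cov = torch.zeros(1, k), 2.0 * torch.eye(k)
x = torch.randn(1, k)
manual = 0.5 * (-torch.log(torch.det(cov)) -
                torch.mm(torch.mm(x - mean, torch.inverse(cov)), torch.t(x - mean)))
ref = MultivariateNormal(mean[0], cov).log_prob(x[0])
print(manual.item(), (ref + 0.5 * k * math.log(2 * math.pi)).item())  # equal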
Example #12
def dmi(target, pred):
    # L_DMI of https://arxiv.org/pdf/1909.03388.pdf
    # mat = torch.mm(target.T, pred) / target.shape[0]  # normalizing makes the determinant too small
    mat = torch.mm(target.T, pred)
    return -torch.log(torch.abs(torch.det(mat)) + 0.001)
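A hypothetical usage sketch, assuming one-hot targets and softmax predictions:

import torch
import torch.nn.functional as F

target = F.one_hot(torch.tensor([0, 1, 2, 1]), num_classes=3).float()
pred = torch.softmax(torch.randn(4, 3), dim=1)
print(dmi(target, pred))  # scalar loss; smaller when pred tracks target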
Example #13
def QuantizedWeight(x, n, nbit=2, training=False):
    """
    Quantize weight.
    Args:
        x (torch.Tensor): a 4D tensor. [K x K x iC x oC] -> [oC x iC x K x K]
            Must have known number of channels, but can have other unknown dimensions.
        n (int or double): variance of weight initialization.
        nbit (int): number of bits of quantized weight. Defaults to 2.
        training (bool):
    Returns:
        torch.Tensor with attribute `variables`.
    Variable Names:
    * ``basis``: basis of quantized weight.
    Note:
        About multi-GPU training: moving averages across GPUs are not aggregated.
        Batch statistics are computed by main training tower. This is consistent with most frameworks.
    """
    oc, ic, k1, k2 = x.shape
    device = x.device

    init_basis = []
    base = NORM_PPF_0_75 * ((2. / n) ** 0.5) / (2 ** (nbit - 1))
    for j in range(nbit):
        init_basis.append([(2 ** j) * base for i in range(oc)])

    num_levels = 2 ** nbit
    delta = EPS

    # initialize level multiplier
    # binary code of each level:
    # shape: [num_levels, nbit]
    init_level_multiplier = []
    for i in range(num_levels):
        level_multiplier_i = [0. for j in range(nbit)]
        level_number = i
        for j in range(nbit):
            binary_code = level_number % 2
            if binary_code == 0:
                binary_code = -1
            level_multiplier_i[j] = float(binary_code)
            level_number = level_number // 2
        init_level_multiplier.append(level_multiplier_i)

    # initialize threshold multiplier
    # shape: [num_levels-1, num_levels]
    # [[0,0,0,0,0,0,0.5,0.5]
    #  [0,0,0,0,0,0.5,0.5,0,]
    #  [0,0,0,0,0.5,0.5,0,0,]
    #  ...
    #  [0.5,0.5,0,0,0,0,0,0,]]
    init_thrs_multiplier = []
    for i in range(1, num_levels):
        thrs_multiplier_i = [0. for j in range(num_levels)]
        thrs_multiplier_i[i - 1] = 0.5
        thrs_multiplier_i[i] = 0.5
        init_thrs_multiplier.append(thrs_multiplier_i)

    # [nbit, oc]
    basis = torch.tensor(init_basis, dtype=torch.float32, requires_grad=False)
    # [2**nbit, nbit] or [num_levels, nbit]
    level_codes = torch.tensor(init_level_multiplier)
    # [num_levels-1, num_levels]
    thrs_multiplier = torch.tensor(init_thrs_multiplier)

    # [oC x iC x K x K] -> [K x K x iC x oC]
    xp = x.permute((3, 2, 1, 0))

    N = 3
    # training
    if training:
        for _ in torch.arange(N):
            # calculate levels and sort [2**nbit, oc] or [num_levels, oc]
            levels = torch.matmul(level_codes, basis)
            levels, sort_id = torch.sort(levels, 0)

            # calculate threshold
            # [num_levels-1, oc]
            thrs = torch.matmul(thrs_multiplier, levels)

            # calculate level codes per channel
            # ix:sort_id [num_levels, oc], iy: torch.arange(nbit)
            # level_codes = [num_levels, nbit]
            # level_codes_channelwise [num_levels, oc, nbit]
            for oc_idx in torch.arange(oc):
                if oc_idx == 0:
                    level_codes_t = level_codes[
                        torch.meshgrid(sort_id[:, oc_idx], torch.arange(nbit, device=sort_id.device))].unsqueeze(1)
                    level_codes_channelwise = level_codes_t
                else:
                    level_codes_t = level_codes[
                        torch.meshgrid(sort_id[:, oc_idx], torch.arange(nbit, device=sort_id.device))].unsqueeze(1)
                    level_codes_channelwise = torch.cat((level_codes_channelwise, level_codes_t), 1)

            # calculate output y and its binary code
            # y [K, K, iC, oC]
            # bits_y [K x K x iC, oC, nbit]
            reshape_x = torch.reshape(xp, [-1, oc])
            y = torch.zeros_like(xp) + levels[0]  # output
            zero_y = torch.zeros_like(xp)
            bits_y = torch.full([reshape_x.shape[0], oc, nbit], -1., device=device)
            zero_bits_y = torch.zeros_like(bits_y)

            # [K x K x iC x oC] [1, oC]
            for i in torch.arange(num_levels - 1):
                g = torch.ge(xp, thrs[i])
                # [K, K, iC, oC] + [1, oC], [K, K, iC, oC] => [K, K, iC, oC]
                y = torch.where(g, zero_y + levels[i + 1], y)
                # [K x K x iC, oC, nbit]
                bits_y = torch.where(g.view(-1, oc, 1), zero_bits_y + level_codes_channelwise[i + 1], bits_y)

            # calculate BTxB
            # [oC, nbit, K x K x iC] x [oC, K x K x iC, nbit] => [oC, nbit, nbit]
            BTxB = torch.matmul(bits_y.permute(1, 2, 0), bits_y.permute(1, 0, 2)) + delta * torch.eye(nbit,
                                                                                                      device=device)
            # calculate inverse of BTxB
            # [oC, nbit, nbit]
            if nbit > 2:
                BTxB_inv = torch.inverse(BTxB)
            elif nbit == 2:
                det = torch.det(BTxB)
                BTxB_inv = torch.stack((BTxB[:, 1, 1], -BTxB[:, 0, 1], -BTxB[:, 1, 0], BTxB[:, 0, 0]),
                                       1).view(oc, nbit, nbit) / det.unsqueeze(-1).unsqueeze(-1)
            elif nbit == 1:
                BTxB_inv = 1 / BTxB
            else:
                BTxB_inv = None

            # calculate BTxX
            # bits_y [K x K x iC, oc, nbit] reshape_x [K x K x iC, oC]
            # [oC, nbit, K x K x iC] [oC, K x K x iC, 1] => [oC, nbit, 1]
            BTxX = torch.matmul(bits_y.permute(1, 2, 0), reshape_x.permute(1, 0).unsqueeze(-1))
            BTxX = BTxX + (delta * basis.permute(1, 0).unsqueeze(-1))  # + basis

            # calculate new basis
            # BTxB_inv: [oC, nbit, nbit] BTxX: [oC, nbit, 1]
            # [oC, nbit, nbit] x [oC, nbit, 1] => [oC, nbit, 1] => [nbit, oC]
            new_basis = torch.matmul(BTxB_inv, BTxX).squeeze(-1).permute(1, 0)
            # print(BTxB_inv.shape,BTxX.shape)
            # print(new_basis.shape)

            # create moving averages op
            basis -= (1 - MOVING_AVERAGES_FACTOR) * (basis - new_basis)
            # print("\nbasis:\n", basis)

    # calculate levels and sort [2**nbit, oc] or [num_levels, oc]
    levels = torch.matmul(level_codes, basis)
    levels, sort_id = torch.sort(levels, 0)

    # calculate threshold
    # [num_levels-1, oc]
    thrs = torch.matmul(thrs_multiplier, levels)

    # calculate level codes per channel
    # ix:sort_id [num_levels, oc], iy: torch.arange(nbit)
    # level_codes = [num_levels, nbit]
    # level_codes_channelwise [num_levels, oc, nbit]
    for oc_idx in torch.arange(oc):
        if oc_idx == 0:
            level_codes_t = level_codes[
                torch.meshgrid(sort_id[:, oc_idx], torch.arange(nbit, device=sort_id.device))].unsqueeze(1)
            level_codes_channelwise = level_codes_t
        else:
            level_codes_t = level_codes[
                torch.meshgrid(sort_id[:, oc_idx], torch.arange(nbit, device=sort_id.device))].unsqueeze(1)
            level_codes_channelwise = torch.cat((level_codes_channelwise, level_codes_t), 1)

    # calculate output y and its binary code
    # y [K, K, iC, oC]
    # bits_y [K x K x iC, oC, nbit]
    reshape_x = torch.reshape(xp, [-1, oc])
    y = torch.zeros_like(xp) + levels[0]  # output
    zero_y = torch.zeros_like(xp)
    bits_y = torch.full([reshape_x.shape[0], oc, nbit], -1., device=device)
    zero_bits_y = torch.zeros_like(bits_y)

    # [K x K x iC x oC] [1, oC]
    for i in torch.arange(num_levels - 1):
        g = torch.ge(xp, thrs[i])
        # [K, K, iC, oC] + [1, oC], [K, K, iC, oC] => [K, K, iC, oC]
        y = torch.where(g, zero_y + levels[i + 1], y)
        # [K x K x iC, oC, nbit]
        bits_y = torch.where(g.view(-1, oc, 1), zero_bits_y + level_codes_channelwise[i + 1], bits_y)

    return y.permute(3, 2, 1, 0), levels.permute(1, 0), thrs.permute(1, 0)
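A hypothetical call sketch; NORM_PPF_0_75, EPS and MOVING_AVERAGES_FACTOR are module-level constants in the original file, so the values below are assumptions:

import torch

NORM_PPF_0_75 = 0.6745            # Phi^-1(0.75) of the standard normal
EPS = 1e-5
MOVING_AVERAGES_FACTOR = 0.9

w = torch.randn(16, 8, 3, 3)      # [oC x iC x K x K]
n = 8 * 3 * 3                     # fan-in used for the initialisation variance
y, levels, thrs = QuantizedWeight(w, n, nbit=2, training=True)
print(y.shape, levels.shape, thrs.shape)  # (16, 8, 3, 3), (16, 4), (16, 3)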
Example #14
 def forward(self, *input):
     src_embedding = input[0]
     tgt_embedding = input[1]
     src = input[2]
     tgt = input[3]
     batch_size, d_k, src_num_points = src_embedding.size()
     tgt_num_points = tgt.shape[2]
     # temperature = input[4].view(batch_size, 1, 1)
     # (bs, np, np)
     dists = torch.matmul(src_embedding.transpose(2, 1).contiguous(), tgt_embedding) / math.sqrt(d_k)
     # affinity = dists / temperature
     affinity = dists
     log_perm_matrix = self.sinkhorn(affinity, n_iters=5)
     # (bs, np, np)
     perm_matrix = torch.exp(log_perm_matrix)
     perm_matrix_norm = perm_matrix / (torch.sum(perm_matrix, dim=2, keepdim=True) + 1e-8)
     # (bs, 3, np)
     # weighted_tgt = torch.matmul(tgt, perm_matrix_norm.transpose(2, 1).contiguous())
     # (bs, np, 1)
     # weights = torch.max(perm_matrix, dim=-1, keepdim=True)[0]
     # (bs, np)
     weights_row = torch.max(perm_matrix, dim=-1)[1]
     # (bs, np)
     weights_value, weights_col = torch.max(perm_matrix, dim=-2)
     tgt_idx = torch.arange(tgt_num_points).cuda()
     R = []
     src_centroid_array = []
     tgt_centroid_array = []
     for bs in range(batch_size):
         a = weights_row[bs, weights_col[bs]] == tgt_idx
         idx = torch.nonzero(a).squeeze(-1)
         # (3, k)
         bji_tgt = torch.matmul(tgt[bs, :, :], perm_matrix_norm[bs, weights_col[bs, idx], :].transpose(1, 0).contiguous())
         bji_src = src[bs, :, weights_col[bs, idx]]
         # (k, 1)
         weights = weights_value[bs, idx].unsqueeze(-1)
         # (k, 1)
         weights_norm = weights / (torch.sum(weights, dim=0, keepdim=True) + 1e-8)
         # (1, k)
         weights_norm = weights_norm.transpose(1, 0).contiguous()
         # (3, 1)
         src_centroid = torch.sum(torch.mul(bji_src, weights_norm), dim=1, keepdim=True)
         tgt_centroid = torch.sum(torch.mul(bji_tgt, weights_norm), dim=1, keepdim=True)
         src_centroid_array.append(src_centroid)
         tgt_centroid_array.append(tgt_centroid)
         src_centered = bji_src - src_centroid
         src_corr_centered = bji_tgt - tgt_centroid
         src_corr_centered = torch.mul(src_corr_centered, weights_norm)
         H = torch.matmul(src_centered, src_corr_centered.transpose(1, 0).contiguous()).cpu()
          try:
              u, s, v = torch.svd(H)
          except RuntimeError:
              # torch.svd can fail to converge; log the matrix and re-raise
              print(H)
              raise
         r = torch.matmul(v, u.transpose(1, 0)).contiguous()
         r_det = torch.det(r).item()
         diag = torch.from_numpy(np.array([[1.0, 0, 0],
                                           [0, 1.0, 0],
                                           [0, 0, r_det]]).astype('float32')).to(v.device)
         r = torch.matmul(torch.matmul(v, diag), u.transpose(1, 0)).contiguous()
         R.append(r)
      # the number of keypoints must be the same
     # src = torch.stack(bji_tgt_list, dim=0)
     # weighted_tgt = torch.stack(bji_src_list, dim=0)
     # weights = torch.stack(weights_list, dim=0).unsqueeze(-1)
     # weights_zeros = torch.zeros_like(weights)
     # (bs, 3, 1)
     src_centroid = torch.stack(src_centroid_array, dim=0).cuda()
     tgt_centroid = torch.stack(tgt_centroid_array, dim=0).cuda()
     R = torch.stack(R, dim=0).cuda()
     t = torch.matmul(-R, src_centroid) + tgt_centroid
     return R, t.view(batch_size, 3), perm_matrix_norm
Example #15
 def decision_function(z):
     q_phi = torch.exp(-(z - mu) ** 2 / (2 * logvar.exp()))
     riem_measure = torch.det(jacobian(decoder, z))
     return q_phi * riem_measure
Example #16
 def det(self):
     return torch.det(self)
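A minimal sketch of how such a method could be reached, assuming the enclosing class (not shown) subclasses torch.Tensor:

import torch

class DetTensor(torch.Tensor):
    def det(self):
        return torch.det(self)

print(torch.eye(3).as_subclass(DetTensor).det())  # tensor(1.)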
Example #17
def corresponding_points_alignment(
    X: Union[torch.Tensor, "Pointclouds"],
    Y: Union[torch.Tensor, "Pointclouds"],
    weights: Union[torch.Tensor, List[torch.Tensor], None] = None,
    estimate_scale: bool = False,
    allow_reflection: bool = False,
    eps: float = 1e-9,
) -> SimilarityTransform:
    """
    Finds a similarity transformation (rotation `R`, translation `T`
    and optionally scale `s`)  between two given sets of corresponding
    `d`-dimensional points `X` and `Y` such that:

    `s[i] X[i] R[i] + T[i] = Y[i]`,

    for all batch indexes `i` in the least squares sense.

    The algorithm is also known as Umeyama [1].

    Args:
        **X**: Batch of `d`-dimensional points of shape `(minibatch, num_point, d)`
            or a `Pointclouds` object.
        **Y**: Batch of `d`-dimensional points of shape `(minibatch, num_point, d)`
            or a `Pointclouds` object.
        **weights**: Batch of non-negative weights of
            shape `(minibatch, num_point)` or list of `minibatch` 1-dimensional
            tensors that may have different shapes; in that case, the length of
            i-th tensor should be equal to the number of points in X_i and Y_i.
            Passing `None` means uniform weights.
        **estimate_scale**: If `True`, also estimates a scaling component `s`
            of the transformation. Otherwise assumes an identity
            scale and returns a tensor of ones.
        **allow_reflection**: If `True`, allows the algorithm to return `R`
            which is orthonormal but has determinant==-1.
        **eps**: A scalar for clamping to avoid dividing by zero. Active for the
            code that estimates the output scale `s`.

    Returns:
        3-element named tuple `SimilarityTransform` containing
        - **R**: Batch of orthonormal matrices of shape `(minibatch, d, d)`.
        - **T**: Batch of translations of shape `(minibatch, d)`.
        - **s**: batch of scaling factors of shape `(minibatch, )`.

    References:
        [1] Shinji Umeyama: Least-Squares Estimation of
        Transformation Parameters Between Two Point Patterns
    """

    # make sure we convert input Pointclouds structures to tensors
    Xt, num_points = oputil.convert_pointclouds_to_tensor(X)
    Yt, num_points_Y = oputil.convert_pointclouds_to_tensor(Y)

    if (Xt.shape != Yt.shape) or (num_points != num_points_Y).any():
        raise ValueError(
            "Point sets X and Y have to have the same \
            number of batches, points and dimensions."
        )
    if weights is not None:
        if isinstance(weights, list):
            if any(np != w.shape[0] for np, w in zip(num_points, weights)):
                raise ValueError(
                    "number of weights should equal to the "
                    + "number of points in the point cloud."
                )
            weights = [w[..., None] for w in weights]
            weights = strutil.list_to_padded(weights)[..., 0]

        if Xt.shape[:2] != weights.shape:
            raise ValueError("weights should have the same first two dimensions as X.")

    b, n, dim = Xt.shape

    if (num_points < Xt.shape[1]).any() or (num_points < Yt.shape[1]).any():
        # in case we got Pointclouds as input, mask the unused entries in Xc, Yc
        mask = (
            torch.arange(n, dtype=torch.int64, device=Xt.device)[None]
            < num_points[:, None]
        ).type_as(Xt)
        weights = mask if weights is None else mask * weights.type_as(Xt)

    # compute the centroids of the point sets
    Xmu = oputil.wmean(Xt, weight=weights, eps=eps)
    Ymu = oputil.wmean(Yt, weight=weights, eps=eps)

    # mean-center the point sets
    Xc = Xt - Xmu
    Yc = Yt - Ymu

    total_weight = torch.clamp(num_points, 1)
    # special handling for heterogeneous point clouds and/or input weights
    if weights is not None:
        Xc *= weights[:, :, None]
        Yc *= weights[:, :, None]
        total_weight = torch.clamp(weights.sum(1), eps)

    if (num_points < (dim + 1)).any():
        warnings.warn(
            "The size of one of the point clouds is <= dim+1. "
            + "corresponding_points_alignment cannot return a unique rotation."
        )

    # compute the covariance XYcov between the point sets Xc, Yc
    XYcov = torch.bmm(Xc.transpose(2, 1), Yc)
    XYcov = XYcov / total_weight[:, None, None]

    # decompose the covariance matrix XYcov
    U, S, V = torch.svd(XYcov)

    # catch ambiguous rotation by checking the magnitude of singular values
    if (S.abs() <= AMBIGUOUS_ROT_SINGULAR_THR).any() and not (
        num_points < (dim + 1)
    ).any():
        warnings.warn(
            "Excessively low rank of "
            + "cross-correlation between aligned point clouds. "
            + "corresponding_points_alignment cannot return a unique rotation."
        )

    # identity matrix used for fixing reflections
    E = torch.eye(dim, dtype=XYcov.dtype, device=XYcov.device)[None].repeat(b, 1, 1)

    if not allow_reflection:
        # reflection test:
        #   checks whether the estimated rotation has det==1,
        #   if not, finds the nearest rotation s.t. det==1 by
        #   flipping the sign of the last singular vector U
        R_test = torch.bmm(U, V.transpose(2, 1))
        E[:, -1, -1] = torch.det(R_test)

    # find the rotation matrix by composing U and V again
    R = torch.bmm(torch.bmm(U, E), V.transpose(2, 1))

    if estimate_scale:
        # estimate the scaling component of the transformation
        trace_ES = (torch.diagonal(E, dim1=1, dim2=2) * S).sum(1)
        Xcov = (Xc * Xc).sum((1, 2)) / total_weight

        # the scaling component
        s = trace_ES / torch.clamp(Xcov, eps)

        # translation component
        T = Ymu[:, 0, :] - s[:, None] * torch.bmm(Xmu, R)[:, 0, :]
    else:
        # translation component
        T = Ymu[:, 0, :] - torch.bmm(Xmu, R)[:, 0, :]

        # unit scaling since we do not estimate scale
        s = T.new_ones(b)

    return SimilarityTransform(R, T, s)
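A noise-free sanity check of the `s[i] X[i] R[i] + T[i] = Y[i]` convention (hypothetical usage; in PyTorch3D the function is exposed under pytorch3d.ops):

import torch

X = torch.randn(2, 100, 3)
R_gt, _ = torch.linalg.qr(torch.randn(2, 3, 3))
R_gt = R_gt * torch.det(R_gt).sign()[:, None, None]  # force det = +1 (odd dim)
T_gt = torch.randn(2, 3)
Y = X.bmm(R_gt) + T_gt[:, None]
R, T, s = corresponding_points_alignment(X, Y)
print(torch.allclose(R, R_gt, atol=1e-4), torch.allclose(T, T_gt, atol=1e-4))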
Example #18
    def forward(self, gt_pose, avg_vector):
        # reshape
        gt_pose = gt_pose.reshape(3, 1)

        # get the delta Observation
        mean_xyt, var_xyt = self.getDeltaMeanVar(avg_vector)  # on rad

        Z = self.getGlobalObservation(mean_xyt).to(self.device)

        # P_z = np.diag([0.5,
        #      0.5,
        #      np.deg2rad(1.0)]) ** 2 # cov

        P_z = torch.diag(var_xyt).to(self.device, dtype=torch.float)  # cov

        Jacob_H = self.getJacobianH().to(self.device)

        S = torch.mm(torch.mm(Jacob_H, self.P_cov_pred), Jacob_H.t()) + P_z
        S = S.to(self.device, dtype=torch.float)

        K = torch.mm(torch.mm(self.P_cov_pred, Jacob_H.t()),
                     torch.inverse(S)).to(self.device, dtype=torch.float)

        # print('Pred_xyt: ', self.est_xyt_pred)
        # print('Global-Z: ', Z)
        # print('P_z: ')
        # print(P_z)
        # print('K_gain: ')
        # print(K)

        Delta = (Z - self.est_xyt_pred).to(self.device, dtype=torch.float)
        Delta[2] = self.wrapToPi(Delta[2])

        # print('Delta: ', Delta)
        # print('K@(Z-X): ', K @ Delta)

        self.est_xyt = self.est_xyt_pred + torch.mm(K, Delta)
        self.est_xyt[2] = self.wrapToPi(self.est_xyt[2])

        Iden = torch.eye(3).to(self.device, dtype=torch.float)

        self.P_cov = torch.mm((Iden - torch.mm(K, Jacob_H)),
                              self.P_cov_pred).to(self.device,
                                                  dtype=torch.float)

        #
        # LOSS
        # error_loss
        error = (gt_pose - self.est_xyt).reshape(3, 1).to(self.device,
                                                          dtype=torch.float)
        error[2] = self.wrapToPi(error[2])  # deBug
        error_loss = torch.mm(torch.mm(error.t(), torch.inverse(self.P_cov)),
                              error)

        # cov_loss
        cov_loss = torch.det(self.P_cov)

        loss = error_loss + self.lamda * cov_loss

        print('Error-loss: ', error_loss, ' Cov-loss: ', cov_loss, ' Loss: ',
              loss)

        return loss, self.est_xyt
Example #19
            # print(f'grad_B: {grad_B}')
            # print(f'grad_B: {torch.norm(grad_B, 1)}')
            

            # Update parameters in time step t-H with saved gradients 
            upd_B = binder.update_binding_entries_(Bs[0], grad_B, at_learning_rate, bm_momentum)

            # Compare binding matrix to ideal matrix
            c_bm = binder.compute_binding_matrix(upd_B)
            mat_loss = evaluator.FBE(c_bm, ideal_binding)
            bm_losses.append(mat_loss)
            print(f'loss of binding matrix (FBE): {mat_loss}')

            # Compute determinant of binding matrix
            det = torch.det(c_bm)
            bm_dets.append(det)
            print(f'determinant of binding matrix: {det}')
            
            
            # Zero out gradients for all parameters in all time steps of tuning horizon
            for i in range(tuning_length+1):
                for j in range(num_input_features):
                    for k in range(num_input_features):
                        Bs[i][j][k].requires_grad = False
                        Bs[i][j][k].grad.data.zero_()

            # Update all parameters for all time steps 
            for i in range(tuning_length+1):
                entries = []
                for j in range(num_input_features):
Example #20
def update_EM(X, K, gamma, A, pi, mu, sigma_sqr, threshold=5e-5,
       A_mode='GA', grad_mode='GA',
       max_em_steps=30, n_gd_steps=20):

  if type(X) is not torch.Tensor:
    X = torch.tensor(X)
  X = X.type(DTYPE).to(device)

  N, D = X.shape

  END = lambda dA, dsigma_sqr: (dA + dsigma_sqr) < threshold

  Y = None
  niters = 0
  dA, dsigma_sqr = 10, 10
  ret_time = {'E':[], 'obj':[]}
  grad_norms, objs= [], []

  if A_mode == 'random':
    A = ortho_group.rvs(D)
    A = to_tensor(A)
  elif A_mode == 'ICA':
    cov = X.T.matmul(X) / len(X)
    cnt = 0
    n_tries = 20
    while cnt < n_tries:
      try:
        ica = FastICA()
        _ = ica.fit_transform(X.cpu())
        Aorig = ica.mixing_

        # avoid numerical instability
        U, ss, V = np.linalg.svd(Aorig)
        ss /= ss[0]
        ss[ss < SINGULAR_SMALL] = SINGULAR_SMALL
        Aorig = (U * ss).dot(V)

        A = np.linalg.inv(Aorig)
        _, ss, _ = np.linalg.svd(A)
        A = to_tensor(A / ss[0])
        cnt = 2*n_tries
      except:
        cnt += 1
    if cnt != 2*n_tries:
      print('ICA failed. Use random.')
      A = to_tensor(ortho_group.rvs(D))

  while (not END(dA, dsigma_sqr)) and niters < max_em_steps:
    niters += 1
    A_prev, sigma_sqr_prev = A.clone(), sigma_sqr.clone()
    objs += [],

    if TIME: e_start = time()
    Y, w, w_sumN, w_sumNK = E(X, A, pi, mu, sigma_sqr, Y=Y)
    if TIME: ret_time['E'] += time() - e_start,

    # M-step
    if A_mode == 'ICA' or A_mode == 'None':
      pi, mu, sigma_sqr = update_pi_mu_sigma(X, A, w, w_sumN, w_sumNK)
      obj = get_objetive(X, A, pi, mu, sigma_sqr, w)
      objs[-1] += obj,

    if A_mode == 'CF': # closed-form updates
      if CHECK_OBJ:
        objs[-1] += get_objetive(X, A, pi, mu, sigma_sqr, w),

      for i in range(n_gd_steps):
        cf_start = time()
        if VERBOSE: print(A.view(-1))
        
        pi, mu, sigma_sqr = update_pi_mu_sigma(X, A, w, w_sumN, w_sumNK)

        if TIME: a_start = time()
        if grad_mode == 'CF1':
          A = set_grad_zero(X, A, w, mu, sigma_sqr)
          A = A.T
        elif grad_mode == 'CF2':
          cofs = get_cofactors(A)
          det = torch.det(A)
          if det < 0: # TODO: ignore neg det for now
            cofs = cofs * -1

          newA = A.clone()
          for i in range(D):
            for j in range(D):
              t1 = (w[:, i] * X[:,j,None]**2 / sigma_sqr[i]).sum() / N
              diff = (Y[i] - A[i,j] * X[:, j])[:, None] - mu[i]
              t2 = (w[:, i] * X[:,j,None] * diff / sigma_sqr[i]).sum() / N
              c1 = t1 * cofs[i,j]
              c2 = t1 * (det - A[i,j]*cofs[i,j]) + t2 * cofs[i,j]
              c3 = t2 * (det - A[i,j]*cofs[i,j]) - cofs[i,j]
              inner = c2**2 - 4*c1*c3
              if inner < 0:
                print('Problem solving for A[{},{}]: no real solution.'.format(i,j))
                pdb.set_trace()
              if c1 == 0:
                sol = - c3 / c2
              else:
                sol = (inner**0.5 - c2) / (2*c1)
              if False:
                # check whether obj gets improved with each updated entry of A
                curr_A = newA.clone()
                curr_A[i,j] = sol
                curr_obj = get_objetive(X, curr_A, pi, mu, sigma_sqr, w)
              newA[i,j] = sol
          A = newA.double()

        # avoid numerical instability
        U, ss, V = torch.svd(A)
        ss = ss / ss[0]
        ss[ss < SINGULAR_SMALL] = SINGULAR_SMALL
        A = (U * ss).matmul(V)

        if TIME:
          if 'A' not in ret_time: ret_time['A'] = []
          ret_time['A'] += time() - a_start,
          if 'CF' not in ret_time: ret_time['CF'] = []
          ret_time['CF'] += time() - cf_start,

        if CHECK_OBJ:
          if TIME: obj_start = time()
          obj = get_objetive(X, A, pi, mu, sigma_sqr, w)
          if TIME: ret_time['obj'] += time() - obj_start,
          objs[-1] += obj,
          if VERBOSE:
            print('iter {}: obj= {:.5f}'.format(i, obj))
        # pdb.set_trace()
      # pdb.set_trace()

    if A_mode == 'GA': # gradient ascent
      if CHECK_OBJ:
        objs[-1] += get_objetive(X, A, pi, mu, sigma_sqr, w),

      for i in range(n_gd_steps):
        ga_start = time()
        if VERBOSE: print(A.view(-1))
        
        pi, mu, sigma_sqr = update_pi_mu_sigma(X, A, w, w_sumN, w_sumNK)

        if TIME: a_start = time()
        # gradient steps
        grad, y_time = get_grad(X, A, w, mu, sigma_sqr)
        if TIME:
          if 'Y' not in ret_time:
            ret_time['Y'] = []
          ret_time['Y'] += y_time,
        if grad_mode == 'BTLS':
          # backtracking line search
          if TIME: obj_start = time()
          obj = get_objetive(X, A, pi, mu, sigma_sqr, w)
          if TIME: ret_time['obj'] += time() - obj_start,

          beta, t, flag = 0.6, 1, True
          gnorm = torch.norm(grad)
          n_iter, ITER_LIM = 0, 10
          while flag and n_iter < ITER_LIM:
            n_iter += 1
            Ap = A + t * grad
            _, ss, _ = torch.svd(Ap)
            Ap /= ss[0]
            if TIME: obj_start = time()
            obj_p = get_objetive(X, Ap, pi, mu, sigma_sqr, w)
            if TIME: ret_time['obj'] += time() - obj_start,
            t *= beta
            base = obj - 0.5 * t * gnorm
            flag = obj_p < base
          gamma = t
          if 'btls_nIters' not in ret_time: ret_time['btls_nIters'] = []
          ret_time['btls_nIters'] += n_iter,
        elif grad_mode == 'perturb':
          # perturb
          perturb = A.std() * 0.1 * torch.randn(A.shape).type(DTYPE).to(device)
          perturbed = A + perturb
          perturbed_grad, _ = get_grad(X, perturbed, w, mu, sigma_sqr)

          grad_diff = torch.norm(grad - perturbed_grad)
          gamma = 1 /(EPS_GRAD + grad_diff) * 0.03

        grad_norms += torch.norm(grad).item(),
        A += gamma * grad

        _, ss, _ = torch.svd(A)
        A /= ss[0]

        if TIME:
          if 'A' not in ret_time: ret_time['A'] = []
          ret_time['A'] += time() - a_start,
          if 'GA' not in ret_time: ret_time['GA'] = []
          ret_time['GA'] += time() - ga_start,

        if CHECK_OBJ:
          if TIME: obj_start = time()
          obj = get_objetive(X, A, pi, mu, sigma_sqr, w)
          if TIME: ret_time['obj'] += time() - obj_start,
          objs[-1] += obj,
          if VERBOSE:
            print('iter {}: obj= {:.5f}'.format(i, obj))
        # pdb.set_trace()
      # pdb.set_trace()

  if VERBOSE:
    print('#{}: dA={:.3e} / dsigma_sqr={:.3e}'.format(niters, dA, dsigma_sqr))
    print('A:', A.view(-1))

  if TIME:
    for key in ret_time:
      ret_time[key] = np.array(ret_time[key]) if ret_time[key] else 0

  # pdb.set_trace()
  return A, pi, mu, sigma_sqr, grad_norms, objs, ret_time 
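A minimal driver for update_EM might look like the sketch below. The parameter shapes are assumptions inferred from the indexing in the code (mu and sigma_sqr are indexed first by data dimension), and the helpers E, update_pi_mu_sigma, get_objetive, get_grad plus the module globals (DTYPE, device, TIME, VERBOSE, CHECK_OBJ, SINGULAR_SMALL, EPS_GRAD) must already be in scope.

import numpy as np
import torch

N, D, n_comp = 1000, 4, 3                     # samples, data dimensions, mixture components
X = np.random.randn(N, D)                     # update_EM converts non-tensors itself
A0 = torch.eye(D)                             # initial unmixing matrix (replaced if A_mode is 'random' or 'ICA')
pi0 = torch.full((D, n_comp), 1.0 / n_comp)   # assumed shape: per-dimension mixture weights
mu0 = torch.zeros(D, n_comp)                  # assumed shape: per-dimension component means
sig0 = torch.ones(D, n_comp)                  # assumed shape: per-dimension component variances

A, pi, mu, sigma_sqr, grad_norms, objs, ret_time = update_EM(
    X, n_comp, gamma=1e-3, A=A0, pi=pi0, mu=mu0, sigma_sqr=sig0,
    A_mode='GA', grad_mode='GA', max_em_steps=10)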
Example #21
0
    def test_det(self):
        x = torch.randn(2, 3, 5, 5, device=torch.device("cpu"))
        self.assertONNX(lambda x: torch.det(x), x, opset_version=11)
        self.assertONNX(lambda x: torch.linalg.det(x), x, opset_version=11)
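For reference, both entry points exercised by this test batch over the leading dimensions and return one determinant per trailing square matrix; a minimal standalone check:

import torch

A = torch.randn(2, 3, 5, 5)          # batch of 2 x 3 square matrices
d1 = torch.det(A)                    # shape (2, 3): one determinant per matrix
d2 = torch.linalg.det(A)             # same values via the torch.linalg namespace
assert torch.allclose(d1, d2)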
Example #22
0
 def forward(self, *input):
     src_embedding = input[0]
     tgt_embedding = input[1]
     src = input[2]
     tgt = input[3]
     batch_size, d_k, num_points_k = src_embedding.size()
     num_points = tgt.shape[2]
     temperature = input[4].view(batch_size, 1, 1)
     # (bs, np, np)
     dists = torch.matmul(
         src_embedding.transpose(2, 1).contiguous(),
         tgt_embedding) / math.sqrt(d_k)
     affinity = dists
     log_perm_matrix = self.sinkhorn(affinity, n_iters=5)
     # (bs, np, np)
     perm_matrix = torch.exp(log_perm_matrix)
     perm_matrix_norm = perm_matrix / (
         torch.sum(perm_matrix, dim=2, keepdim=True) + 1e-8)
     # (bs, 3, np)
     src_corr = torch.matmul(tgt,
                             perm_matrix_norm.transpose(2, 1).contiguous())
     # (bs, np, 1) intuitively, the weights to be sampled should take the maximum value
     weights = torch.sum(perm_matrix, dim=-1, keepdim=True)
     # weights_zeros = torch.zeros_like(weights)
     # # discard the lower half of the weights (bs,)
     # weights_median = torch.median(weights, dim=1)[0].squeeze(-1)
     # for i in range(batch_size):
     #     weights[i] = torch.where(weights[i] > weights_median[i], weights[i], weights_zeros[i])
     # (bs, np, 64)
     x = self.relu1(self.linear1(weights))
     # (bs, np, 1)
     x = self.linear2(x)
     x = torch.sigmoid(x)
     x = torch.log(x + 1e-8)
     # (bs, np, 1)
     corr_scores = x.repeat(1, self.n_keypoints, 1)
     temperature = temperature.view(batch_size, 1)
     corr_scores = corr_scores.view(batch_size * self.n_keypoints,
                                    num_points)
     temperature = temperature.repeat(1, self.n_keypoints, 1).view(-1, 1)
     corr_scores = F.gumbel_softmax(corr_scores, tau=temperature, hard=True)
     corr_scores = corr_scores.view(batch_size, self.n_keypoints,
                                    num_points)
     src_k = torch.matmul(corr_scores,
                          src.transpose(2, 1).contiguous()).transpose(
                              2, 1).contiguous()
     src_corr_k = torch.matmul(
         corr_scores,
         src_corr.transpose(2, 1).contiguous()).transpose(2,
                                                          1).contiguous()
     src_centered = src_k - src_k.mean(dim=2, keepdim=True)
     src_corr_centered = src_corr_k - src_corr_k.mean(dim=2, keepdim=True)
     H = torch.matmul(src_centered,
                      src_corr_centered.transpose(2, 1).contiguous()).cpu()
     R = []
     for i in range(src.size(0)):
         try:
             u, s, v = torch.svd(H[i])
         except Exception:
             print(H[i])
         r = torch.matmul(v, u.transpose(1, 0)).contiguous()
         r_det = torch.det(r).item()
         diag = torch.from_numpy(
             np.array([[1.0, 0, 0], [0, 1.0, 0],
                       [0, 0, r_det]]).astype('float32')).to(v.device)
         r = torch.matmul(torch.matmul(v, diag),
                          u.transpose(1, 0)).contiguous()
         R.append(r)
     R = torch.stack(R, dim=0).cuda()
     t = torch.matmul(-R, src_k.mean(
         dim=2, keepdim=True)) + src_corr_k.mean(dim=2, keepdim=True)
     return R, t.view(batch_size, 3), perm_matrix_norm, x
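The per-sample determinant correction above is the standard way of projecting an SVD-based estimate back onto proper rotations; a minimal standalone sketch of just that step:

import torch

H = torch.randn(3, 3)                          # cross-covariance of two centered point sets
u, s, v = torch.svd(H)
r = torch.matmul(v, u.transpose(1, 0))
if torch.det(r) < 0:                           # the SVD can return a reflection (det = -1)
    diag = torch.diag(torch.tensor([1.0, 1.0, -1.0]))
    r = torch.matmul(torch.matmul(v, diag), u.transpose(1, 0))
assert torch.det(r) > 0                        # r is now a proper rotation in SO(3)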
Example #23
0
def eigs_comp(A):
    p1 = A[:,0,1].pow(2) + A[:,0,2].pow(2) + A[:,1,2].pow(2)
    
    diag_matrix_flag = p1==0
    A_diag = A.clone()
    A_diag = A_diag[diag_matrix_flag]
    A_non_diag = A.clone()
    A_non_diag = A_non_diag[~diag_matrix_flag]
    #p1_diag = p1[diag_matrix_flag]
    p1_non_diag = p1.clone()
    p1_non_diag = p1_non_diag[~diag_matrix_flag]

    eig1_diag = torch.zeros(A_diag.shape[0],1, device='cuda')
    eig2_diag = torch.zeros(A_diag.shape[0],1, device='cuda')
    eig3_diag = torch.zeros(A_diag.shape[0],1, device='cuda')
    if A_diag.shape[0]>0:
        # for the diagonal matrices
        eig1_diag_tmp = A_diag[:,0,0].unsqueeze(1)
        eig2_diag_tmp = A_diag[:,1,1].unsqueeze(1)
        eig3_diag_tmp = A_diag[:,2,2].unsqueeze(1)
        EIG = torch.cat((eig1_diag_tmp,eig2_diag_tmp),1)
        EIG = torch.cat((EIG,eig3_diag_tmp),1)
        EIG, ind = EIG.sort(1)
        eig1_diag = EIG[:,2].unsqueeze(1)
        eig2_diag = EIG[:,1].unsqueeze(1)
        eig3_diag = EIG[:,0].unsqueeze(1)
        
    eig1_non_diag = torch.zeros(A_non_diag.shape[0],1, device='cuda')
    eig2_non_diag = torch.zeros(A_non_diag.shape[0],1, device='cuda')
    eig3_non_diag = torch.zeros(A_non_diag.shape[0],1, device='cuda') 
    if A_non_diag.shape[0]>0:
        # for the non-diagonal matrices
        tr_A = (A_non_diag[:,0,0]+A_non_diag[:,1,1]+A_non_diag[:,2,2])/3               # q = trace(A)/3, where trace(A) is the sum of the diagonal values
        p2 = (A_non_diag[:,0,0] - tr_A).pow(2) + (A_non_diag[:,1,1] - tr_A).pow(2) + (A_non_diag[:,2,2]- tr_A).pow(2) + 2 * p1_non_diag
        p3 = torch.sqrt(p2/6)
        I_matrix = torch.eye(3,device='cuda').repeat(A_non_diag.shape[0],1,1)
        tmp_tr_A = tr_A.view(A_non_diag.shape[0],1,1).repeat(1,3,3) 
        tmp_p = p3.view(A_non_diag.shape[0],1,1).repeat(1,3,3)
        B = (1 / tmp_p) * (A_non_diag - tmp_tr_A * I_matrix)    # I is the identity matrix
        tmp_det_B = torch.det(B)/2
            
        #grads1={}
        #def save_grad1(name):
        #    def hook(grad):
        #        grads1[name]=grad
        #    return hook
        #tmp_det_B.register_hook(save_grad1('tmp_det_B'))  
    
        # In exact arithmetic for a symmetric matrix  -1 <= tmp <= 1
        # but computation error can leave it slightly outside this range.
        pi_tmp = LLTMFunction.apply(tmp_det_B)
            
        # the eigenvalues satisfy eig3 <= eig2 <= eig1
        eig1_non_diag = tr_A + 2 * p3 * torch.cos(pi_tmp)
        eig3_non_diag = tr_A + 2 * p3 * torch.cos(pi_tmp + (2*math.pi/3))
        eig2_non_diag = 3 * tr_A - eig1_non_diag - eig3_non_diag     # since trace(A) = eig1 + eig2 + eig3
        
        eig1_non_diag = eig1_non_diag.unsqueeze(1)
        eig2_non_diag = eig2_non_diag.unsqueeze(1)
        eig3_non_diag = eig3_non_diag.unsqueeze(1)
    
    eig1 = torch.zeros(A.shape[0],1, device='cuda')
    eig1[diag_matrix_flag] = eig1_diag
    eig1[~diag_matrix_flag] = eig1_non_diag
    eig2 = torch.zeros(A.shape[0],1, device='cuda')
    eig2[diag_matrix_flag]=eig2_diag
    eig2[~diag_matrix_flag]=eig2_non_diag
    eig3 = torch.zeros(A.shape[0],1, device='cuda')
    eig3[diag_matrix_flag]=eig3_diag
    eig3[~diag_matrix_flag]=eig3_non_diag
    
    A_eig = torch.zeros(A.shape,device='cuda')
    A_eig[:,0,0]=eig1.squeeze()
    A_eig[:,1,1]=eig2.squeeze()
    A_eig[:,2,2]=eig3.squeeze()   
    
    return A_eig
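For reference, the non-diagonal branch above implements the classical closed form for the eigenvalues of a symmetric 3x3 matrix (cf. Smith, 1961). A CPU sketch of the same formula, using torch.acos with clamping in place of the custom LLTMFunction and skipping the diagonal special case (it assumes p > 0):

import math
import torch

def eigs_closed_form(A):
    # q = trace(A)/3; p measures the off-diagonal energy
    q = A.diagonal(dim1=-2, dim2=-1).sum(-1) / 3
    p1 = A[..., 0, 1]**2 + A[..., 0, 2]**2 + A[..., 1, 2]**2
    p2 = ((A.diagonal(dim1=-2, dim2=-1) - q[..., None])**2).sum(-1) + 2 * p1
    p = torch.sqrt(p2 / 6)
    B = (A - q[..., None, None] * torch.eye(3)) / p[..., None, None]
    # clamp det(B)/2 into [-1, 1] against round-off, as the comment above notes
    phi = torch.acos((torch.det(B) / 2).clamp(-1.0, 1.0)) / 3
    eig1 = q + 2 * p * torch.cos(phi)
    eig3 = q + 2 * p * torch.cos(phi + 2 * math.pi / 3)
    eig2 = 3 * q - eig1 - eig3                 # trace(A) = eig1 + eig2 + eig3
    return eig1, eig2, eig3

M = torch.randn(4, 3, 3)
M = (M + M.transpose(-1, -2)) / 2              # symmetrize
e1, e2, e3 = eigs_closed_form(M)
ref = torch.linalg.eigvalsh(M)                 # ascending eigenvalues
assert torch.allclose(e1, ref[:, 2], atol=1e-4)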
Example #24
0
    def _compute_cov(self, mu1, mu2, mu3, lengthscale, cov, var, Beta):
        # N x D
        mu1 = mu1
        # 1 x D
        mu2 = mu2
        # E
        mu3 = mu3
        # E by D
        cov_1 = lengthscale
        # D x D
        cov_2 = cov
        # E x 1
        var = var

        # E x N
        Beta = Beta

        # print(mu1.size())
        # print(mu2.size())
        # print(mu3.size())
        # print(lengthscale.size())
        # print(cov.size())
        # print(var.size())
        # print(Beta.size())
        # E x D x D tensor
        range_lis = range(0, lengthscale.size()[0])
        Mat1 = torch.stack(list(map(lambda i: lengthscale[i, :].diag(), range_lis)))
        # E x D x D tensor
        Mat2 = torch.stack(list(map(lambda i: torch.potrs(torch.eye(cov_2.size()[0]), (Mat1[i, :, :] + cov_2).potrf(upper=False), upper=False), range_lis)))
        # # Mat3 = torch.stack(Mat1) + torch.matmul(torch.stack(Mat1), torch.matmul(Mat2, torch.stack(Mat1)))
        # # Mat4 = cov_2 + torch.matmul(cov_2, torch.matmul(Mat2, cov_2))
        # N x E x D x 1
        Mu = mu1.unsqueeze(1).unsqueeze(-1) \
             - torch.matmul(Mat1,
                            torch.matmul(Mat2, mu1.unsqueeze(1).unsqueeze(-1))) \
             + mu2.unsqueeze(1).unsqueeze(-1) - torch.matmul(cov_2, torch.matmul(Mat2,
                                                                                 mu2.unsqueeze(1).unsqueeze(-1)))
        #Mu = torch.matmul(cov_2, torch.matmul(Mat2, mu1.unsqueeze(1).unsqueeze(-1))) - torch.matmul(cov_2, torch.matmul(Mat2, mu2.unsqueeze(1).unsqueeze(-1)))
        # N x E x D
        Mu = Mu.squeeze(-1)
        # E x 1
        Det1_func = torch.stack(list(map(lambda i: torch.det(Mat1[i, :, :]), range_lis)))
        Det2_func = torch.stack(list(map(lambda i: torch.det(Mat1[i, :, :] + cov_2), range_lis)))
        #
        # E x 1
        Det = torch.mul(torch.mul(Det1_func ** 0.5,  Det2_func ** -0.5), torch.tensor(var))
        # print(Det.size())
        # ####
        #
        # N x E x 1 x 1
        Mat3 = torch.matmul((mu1 - mu2).unsqueeze(1).unsqueeze(1),
                            torch.matmul(Mat2, (mu1 - mu2).unsqueeze(1).unsqueeze(-1)))
        # print(Mat3.size())
        Z = torch.mul(Det, torch.exp(-0.5 * Mat3.squeeze(-1)))

        # print(Z.transpose(dim0=-1, dim1=0).size())
        #
        #N x E x D
        Cov_xy = torch.mul(Beta, torch.mul(Z, Mu).transpose(dim0=-1, dim1=0))
        # E x D
        Cov_xy = torch.sum(Cov_xy, dim=-1) - torch.ger(mu2.view(-1), mu3.view(-1))
        Cov_yx = Cov_xy.transpose(dim0=0, dim1=1)

        # print(Cov_xy.size())
        return Cov_xy, Cov_yx
Example #25
0
def kabsch_transformation_estimation(x1, x2, weights=None, normalize_w = True, eps = 1e-7, best_k = 0, w_threshold = 0, compute_residuals = False):
    """
    Torch-differentiable implementation of the weighted Kabsch algorithm (https://en.wikipedia.org/wiki/Kabsch_algorithm). Based on the correspondences and weights, it calculates
    the rotation matrix that is optimal in the sense of the Frobenius norm (RMSD); from the estimated rotation matrix it then estimates the translation vector, hence solving
    the Procrustes problem. This implementation supports batch inputs.
    Args:
        x1            (torch array): points of the first point cloud [b,n,3]
        x2            (torch array): correspondences for the PC1 established in the feature space [b,n,3]
        weights       (torch array): weights denoting if the correspondence is an inlier (~1) or an outlier (~0) [b,n]
        normalize_w   (bool)       : flag for normalizing the weights to sum to 1
        best_k        (int)        : number of correspondences with highest weights to be used (if 0 all are used)
        w_threshold   (float)      : only use weights higher than this w_threshold (if 0 all are used)
    Returns:
        rot_matrices  (torch array): estimated rotation matrices [b,3,3]
        trans_vectors (torch array): estimated translation vectors [b,3,1]
        res           (torch array): pointwise residuals (Euclidean distance) [b,n]
        svd_failed    (bool)       : True if the SVD computation failed and the identity fallback was returned (gradient not valid); False on success
    """
    if weights is None:
        weights = torch.ones(x1.shape[0],x1.shape[1]).type_as(x1).to(x1.device)

    if normalize_w:
        sum_weights = torch.sum(weights,dim=1,keepdim=True) + eps
        weights = (weights/sum_weights)

    weights = weights.unsqueeze(2)

    if best_k > 0:
        indices = np.argpartition(weights.cpu().numpy(), -best_k, axis=1)[0,-best_k:,0]
        weights = weights[:,indices,:]
        x1 = x1[:,indices,:]
        x2 = x2[:,indices,:]

    if w_threshold > 0:
        weights[weights < w_threshold] = 0


    x1_mean = torch.matmul(weights.transpose(1,2), x1) / (torch.sum(weights, dim=1).unsqueeze(1) + eps)
    x2_mean = torch.matmul(weights.transpose(1,2), x2) / (torch.sum(weights, dim=1).unsqueeze(1) + eps)

    x1_centered = x1 - x1_mean
    x2_centered = x2 - x2_mean

    cov_mat = torch.matmul(x1_centered.transpose(1, 2),
                            (x2_centered * weights))

    try:
        u, s, v = torch.svd(cov_mat)

    except Exception as e:
        r = torch.eye(3,device=x1.device)
        r = r.repeat(x1_mean.shape[0],1,1)
        t = torch.zeros((x1_mean.shape[0],3,1), device=x1.device)

        res = transformation_residuals(x1, x2, r, t)

        return r, t, res, True

    tm_determinant = torch.det(torch.matmul(v.transpose(1, 2), u.transpose(1, 2)))

    determinant_matrix = torch.diag_embed(torch.cat((torch.ones((tm_determinant.shape[0],2),device=x1.device), tm_determinant.unsqueeze(1)), 1))

    rotation_matrix = torch.matmul(v,torch.matmul(determinant_matrix,u.transpose(1,2)))

    # translation vector
    translation_matrix = x2_mean.transpose(1,2) - torch.matmul(rotation_matrix,x1_mean.transpose(1,2))

    # Residuals
    res = None
    if compute_residuals:
        res = transformation_residuals(x1, x2, rotation_matrix, translation_matrix)

    return rotation_matrix, translation_matrix, res, False
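A hypothetical sanity check for the function above: apply a known rigid transform and verify that it is recovered. The random rotation is built from a batched QR factorization and sign-corrected to determinant +1:

import torch

b, n = 2, 50
x1 = torch.rand(b, n, 3)
q, _ = torch.linalg.qr(torch.randn(b, 3, 3))
q = q * torch.det(q).sign().view(b, 1, 1)      # force det(q) = +1 (valid for odd dims)
t_gt = torch.randn(b, 3, 1)
x2 = (q @ x1.transpose(1, 2) + t_gt).transpose(1, 2)

R, t, res, svd_failed = kabsch_transformation_estimation(x1, x2)
assert not svd_failed
assert torch.allclose(R, q, atol=1e-4) and torch.allclose(t, t_gt, atol=1e-4)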
Example #26
0
def kabsch_autograd(P, X, jacobian=None, use_cuda=False, eps=0.01):
    """
    Args:
        P (tensor): Measurements
        X (tensor): Scene points
        jacobian (tensor): 6x3N jacobian matrix of r&t w.r.t. scene point coord.
        use_cuda (bool): Flag to indicate whether to calculate jacobian
        eps (float): Epsilon used in finite difference approximation

    Return:
        r (tensor): 3x1 rotation vector that satisfies P = RX + T
        t (tensor): 3x1 translation vector that satisfies P = RX + T
    """

    print("Running Kabsch (Autograd)")

    if use_cuda:
        print("\tUsing CUDA")
        X = X.cuda()
        P = P.cuda()

    if jacobian is None:
        r, t = kabsch(P, X)
        return r, t

    print("\tComputing jacobian")

    X.requires_grad = True

    # compute centroid as average of coordinates
    tx = torch.mean(X, 0)
    tp = torch.mean(P, 0)

    # move centroid to origin
    Xc = X.sub(tx)
    Pc = P.sub(tp)

    A = torch.mm(torch.t(Pc), Xc)

    U, S, V = torch.svd(A)

    # flag for degeneracy
    degenerate = False

    # degenerate if any singular value is zero
    if torch.numel(torch.nonzero(S)) != torch.numel(S):
        degenerate = True

    # degenerate if singular values are not distinct
    if torch.abs(S[0] - S[1]) < 1e-8 or torch.abs(
            S[0] - S[2]) < 1e-8 or torch.abs(S[1] - S[2]) < 1e-8:
        degenerate = True

    # if degenerate, use finite difference for stability
    if degenerate:
        X.requires_grad = False
        return None, None

    # non-degenerate case, continue with Kabsch algorithm with autograd
    Vt = torch.t(V)

    d = torch.det(torch.mm(U, Vt))

    D = torch.FloatTensor([[1, 0, 0], [0, 1, 0], [0, 0, d]])

    if use_cuda:
        D = D.cuda()

    R = torch.mm(U, torch.mm(D, Vt))

    rod = Rodrigues.apply

    r = rod(R)  # rotation vector

    t = tp - torch.mm(R, tx.view(3, -1)).view(-1)  # translation vector

    numelR = torch.numel(r)

    # compute jacobian matrix
    for i in range(numelR):
        onehot = torch.zeros(numelR, dtype=torch.float32)
        onehot[i] = 1

        # jacobian of an element of r w.r.t X
        if use_cuda:
            r.backward(onehot.view(r.size()).cuda(), retain_graph=True)
        else:
            r.backward(onehot.view(r.size()), retain_graph=True)
        jacobian[i, :] = X.grad.data.view(-1)

        # zero the gradient for next element
        X.grad.data.zero_()

        # jacobian of an element of t w.r.t X
        if use_cuda:
            t.backward(onehot.view(t.size()).cuda(), retain_graph=True)
        else:
            t.backward(onehot.view(t.size()), retain_graph=True)
        jacobian[i + 3, :] = X.grad.data.view(-1)

        # zero the gradient for next element
        X.grad.data.zero_()

    return r.detach(), t.detach()
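A hypothetical call, assuming kabsch() and the Rodrigues autograd function from the surrounding module are in scope; the jacobian buffer must be preallocated with shape 6 x 3N (three rows for r, three for t):

import torch

N = 30
X = torch.randn(N, 3)                  # scene points
P = torch.randn(N, 3)                  # measurements
jac = torch.zeros(6, 3 * N)
r, t = kabsch_autograd(P, X, jacobian=jac)
if r is None:
    print("degenerate configuration; a finite-difference fallback is needed")
else:
    print(r.shape, t.shape, jac.abs().max())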
Example #27
0
def get_surface_high_res_mesh(sdf, resolution=100):
    # get low res mesh to sample point cloud
    grid = get_grid_uniform(100)
    z = []
    points = grid['grid_points']

    for i, pnts in enumerate(torch.split(points, 100000, dim=0)):
        z.append(sdf(pnts).detach().cpu().numpy())
    z = np.concatenate(z, axis=0)

    z = z.astype(np.float32)

    verts, faces, normals, values = measure.marching_cubes_lewiner(
        volume=z.reshape(grid['xyz'][1].shape[0], grid['xyz'][0].shape[0],
                         grid['xyz'][2].shape[0]).transpose([1, 0, 2]),
        level=0,
        spacing=(grid['xyz'][0][2] - grid['xyz'][0][1],
                 grid['xyz'][0][2] - grid['xyz'][0][1],
                 grid['xyz'][0][2] - grid['xyz'][0][1]))

    verts = verts + np.array(
        [grid['xyz'][0][0], grid['xyz'][1][0], grid['xyz'][2][0]])

    mesh_low_res = trimesh.Trimesh(verts, faces, vertex_normals=-normals)
    # return mesh_low_res

    components = mesh_low_res.split(only_watertight=False)
    areas = np.array([c.area for c in components], dtype=float)
    mesh_low_res = components[areas.argmax()]

    recon_pc = trimesh.sample.sample_surface(mesh_low_res, 10000)[0]
    recon_pc = torch.from_numpy(recon_pc).float().cuda()

    # Center and align the recon pc
    s_mean = recon_pc.mean(dim=0)
    s_cov = recon_pc - s_mean
    s_cov = torch.mm(s_cov.transpose(0, 1), s_cov)
    vecs = torch.eig(s_cov, True)[1].transpose(0, 1)
    if torch.det(vecs) < 0:
        vecs = torch.mm(
            torch.tensor([[1, 0, 0], [0, 0, 1], [0, 1, 0]]).cuda().float(),
            vecs)
    helper = torch.bmm(
        vecs.unsqueeze(0).repeat(recon_pc.shape[0], 1, 1),
        (recon_pc - s_mean).unsqueeze(-1)).squeeze()

    grid_aligned = get_grid(helper.cpu(), resolution)

    grid_points = grid_aligned['grid_points']

    g = []
    for i, pnts in enumerate(torch.split(grid_points, 100000, dim=0)):
        g.append(
            torch.bmm(
                vecs.unsqueeze(0).repeat(pnts.shape[0], 1, 1).transpose(1, 2),
                pnts.unsqueeze(-1)).squeeze() + s_mean)
    grid_points = torch.cat(g, dim=0)

    # MC to new grid
    points = grid_points
    z = []
    for i, pnts in enumerate(torch.split(points, 100000, dim=0)):
        z.append(sdf(pnts).detach().cpu().numpy())
    z = np.concatenate(z, axis=0)

    meshexport = None
    if (not (np.min(z) > 0 or np.max(z) < 0)):

        z = z.astype(np.float32)

        verts, faces, normals, values = measure.marching_cubes_lewiner(
            volume=z.reshape(grid_aligned['xyz'][1].shape[0],
                             grid_aligned['xyz'][0].shape[0],
                             grid_aligned['xyz'][2].shape[0]).transpose(
                                 [1, 0, 2]),
            level=0,
            spacing=(grid_aligned['xyz'][0][2] - grid_aligned['xyz'][0][1],
                     grid_aligned['xyz'][0][2] - grid_aligned['xyz'][0][1],
                     grid_aligned['xyz'][0][2] - grid_aligned['xyz'][0][1]))

        verts = torch.from_numpy(verts).cuda().float()
        verts = torch.bmm(
            vecs.unsqueeze(0).repeat(verts.shape[0], 1, 1).transpose(1, 2),
            verts.unsqueeze(-1)).squeeze()
        verts = (verts + grid_points[0]).cpu().numpy()

        meshexport = trimesh.Trimesh(verts, faces, vertex_normals=-normals)

    return meshexport
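A hypothetical call with a unit-sphere SDF, assuming the get_grid_uniform / get_grid helpers and a CUDA device are available as the function above requires:

import torch

sphere_sdf = lambda pnts: pnts.norm(dim=-1) - 0.5   # zero level set: sphere of radius 0.5
mesh = get_surface_high_res_mesh(sphere_sdf, resolution=128)
if mesh is not None:                                # None when the grid never crosses the level set
    mesh.export('sphere.ply')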
Example #28
0
    def __init__(
        self,
        inequality_constraints: Optional[Tuple[Tensor, Tensor]] = None,
        equality_constraints: Optional[Tuple[Tensor, Tensor]] = None,
        bounds: Optional[Tensor] = None,
        interior_point: Optional[Tensor] = None,
    ) -> None:
        r"""Initialize DelaunayPolytopeSampler.

        Args:
            inequality_constraints: Tensors `(A, b)` describing inequality
                constraints `A @ x <= b`, where `A` is a `n_ineq_con x d`-dim
                Tensor and `b` is a `n_ineq_con x 1`-dim Tensor, with `n_ineq_con`
                the number of inequalities and `d` the dimension of the sample space.
            equality_constraints: Tensors `(C, d)` describing the equality constraints
                `C @ x = d`, where `C` is a `n_eq_con x d`-dim Tensor and `d` is a
                `n_eq_con x 1`-dim Tensor with `n_eq_con` the number of equalities.
            bounds: A `2 x d`-dim tensor of box bounds, where `inf` (`-inf`) means
                that the respective dimension is unbounded from above (below).
            interior_point: A `d x 1`-dim Tensor representing a point in the
                (relative) interior of the polytope. If omitted, determined
                automatically by solving a Linear Program.

        Warning: The vertex enumeration performed in this algorithm can become
        extremely costly if there are a large number of inequalities. Similarly,
        the triangulation can get very expensive in high dimensions. Only use
        this algorithm for moderate dimensions / moderately complex constraint sets.
        An alternative is the `HitAndRunPolytopeSampler`.
        """
        super().__init__(
            inequality_constraints=inequality_constraints,
            equality_constraints=equality_constraints,
            bounds=bounds,
            interior_point=interior_point,
        )
        # shift coordinate system to be anchored at x0
        new_b = self.b - self.A @ self.x0
        if self.new_A.shape[-1] < 2:
            # if the polytope is in dim 1 (i.e. a line segment) Qhull won't work
            tshlds = new_b / self.new_A
            neg = self.new_A < 0
            self.y_min = tshlds[neg].max()
            self.y_max = tshlds[~neg].min()
            self.dim = 1
        else:
            # Qhull expects inputs of the form A @ x + b <= 0, so we need to negate here
            halfspaces = torch.cat([self.new_A, -new_b], dim=-1).cpu().numpy()
            vertices = HalfspaceIntersection(
                halfspaces=halfspaces,
                interior_point=np.zeros(self.new_A.shape[-1])).intersections
            self.dim = vertices.shape[-1]
            try:
                delaunay = Delaunay(vertices)
            except ValueError as e:
                if "Points cannot contain NaN" in str(e):
                    raise ValueError("Polytope is unbounded.")
                raise e  # pragma: no cover
            polytopes = torch.from_numpy(
                np.array([delaunay.points[s]
                          for s in delaunay.simplices]), ).to(self.A)
            volumes = torch.stack(
                [torch.det(p[1:] - p[0]).abs() for p in polytopes])
            self._polytopes = polytopes
            self._p = volumes / volumes.sum()
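Hypothetical usage (this snippet appears to match botorch's DelaunayPolytopeSampler, so the import path and draw signature below are assumptions): sampling uniformly from the triangle x >= 0, x1 + x2 <= 1 written as A @ x <= b:

import torch
from botorch.utils.sampling import DelaunayPolytopeSampler  # assumed import path

A = torch.tensor([[-1.0, 0.0], [0.0, -1.0], [1.0, 1.0]])
b = torch.tensor([[0.0], [0.0], [1.0]])
sampler = DelaunayPolytopeSampler(inequality_constraints=(A, b))
samples = sampler.draw(n=16)     # 16 x 2 points uniformly distributed over the triangle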
Example #29
0
def umeyama(src, dst, estimate_scale):
    """Estimate N-D similarity transformation with or without scaling.
    Parameters
    ----------
    src : (M, N) array
        Source coordinates.
    dst : (M, N) array
        Destination coordinates.
    estimate_scale : bool
        Whether to estimate scaling factor.
    Returns
    -------
    T : (N + 1, N + 1)
        The homogeneous similarity transformation matrix. The matrix contains
        NaN values only if the problem is not well-conditioned.
    References
    ----------
    .. [1] "Least-squares estimation of transformation parameters between two
            point patterns", Shinji Umeyama, PAMI 1991, :DOI:`10.1109/34.88573`
    """

    num = src.shape[0]
    dim = src.shape[1]
    device = src.device

    # identical point sets: the transform is the identity
    if torch.sum(torch.abs(src - dst)) == 0:
        return torch.eye(dim + 1).to(device)

    # Compute mean of src and dst.
    src_mean = src.mean(axis=0)
    dst_mean = dst.mean(axis=0)

    # Subtract mean from src and dst.
    src_demean = src - src_mean
    dst_demean = dst - dst_mean

    # Eq. (38).
    A = dst_demean.T @ src_demean / num

    # Eq. (39).
    d = torch.ones((dim, )).type_as(src).to(device)
    if torch.det(A) < 0:
        d[dim - 1] = -1

    T = torch.eye(dim + 1).type_as(src).to(device)

    # note: torch.svd returns V (not V^T), so after this transpose the variable
    # V actually holds V^T, which is what the U @ diag(d) @ V products below expect
    U, S, Vt = torch.svd(A)
    V = Vt.T

    # Eq. (40) and (43).
    rank = torch.matrix_rank(A)
    if rank == 0:
        # return float('nan') * T
        return torch.eye(dim + 1).to(device)
    elif rank == dim - 1:
        if torch.det(U) * torch.det(V) > 0:
            T[:dim, :dim] = U @ V
        else:
            s = d[dim - 1]
            d[dim - 1] = -1
            T[:dim, :dim] = U @ torch.diag(d) @ V
            d[dim - 1] = s
    else:
        T[:dim, :dim] = U @ torch.diag(d) @ V

    if estimate_scale:
        # Eq. (41) and (42); torch.var defaults to the unbiased (1/(n-1))
        # estimator, so request the biased (1/n) variance the paper uses
        scale = 1.0 / src_demean.var(axis=0, unbiased=False).sum() * (S @ d)
    else:
        scale = 1.0

    T[:dim, dim] = dst_mean - scale * (T[:dim, :dim] @ src_mean.T)
    T[:dim, :dim] *= scale
    return T
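A hypothetical round trip for the function above: apply a known similarity transform (rotation about z, scale 2, plus a translation) and verify that the recovered homogeneous matrix maps src onto dst:

import math
import torch

c, s = math.cos(0.3), math.sin(0.3)
R = torch.tensor([[c, -s, 0.0], [s, c, 0.0], [0.0, 0.0, 1.0]])
src = torch.rand(100, 3)
dst = 2.0 * src @ R.T + torch.tensor([0.5, -0.2, 1.0])

T = umeyama(src, dst, estimate_scale=True)
assert torch.allclose(T[:3, :3] @ src.T + T[:3, 3:4], dst.T, atol=1e-3)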
Example #30
0
    def get_score_gpu(self,
                      input: Tensor,
                      gradient: Tensor,
                      size: int = 2,
                      method='noble',
                      **kwargs) -> Tensor:
        channels, height, width = gradient.shape[1], gradient.shape[
            2], gradient.shape[3]

        # 3-dim : (height x width) x channels x 1
        X = gradient.flatten(2).permute(2, 1, 0)

        # 3-dim : (height x width) x channels x channels
        X = torch.bmm(X, X.permute(0, 2, 1))

        # 2-dim : (height x width) x (channels x channels)
        X = X.reshape(-1, channels * channels)

        # 4-dim : 1 x (channels x channels) x height x width
        X = X.reshape(height, width,
                      channels * channels).permute(2, 0, 1).unsqueeze(0)

        kernel_size = 2 * size + 1
        weight = torch.ones(
            size=[channels * channels, 1, kernel_size, kernel_size],
            dtype=X.dtype,
            device=X.device)

        # Get a sensitivity matrix X for a rgb perturbation vector
        # 4-dim : 1 x (channels x channels) x height x width
        X = F.conv2d(X,
                     weight=weight,
                     bias=None,
                     stride=1,
                     padding=size,
                     dilation=1,
                     groups=channels * channels)
        if method == 'noble':
            # 3-dim : (height x width) x channels x channels
            X_mat = X.unsqueeze(0).reshape(channels, channels, height,
                                           width).permute(2, 3, 0, 1).reshape(
                                               -1, channels, channels)

            X_det = torch.det(X_mat)
            X_trace = X_mat.diagonal(dim1=-2, dim2=-1).sum(-1)
            score = 2 * X_det / (X_trace + 1e-6)
            score = score.reshape(height, width)
        elif method == 'fro':
            score = X.norm(dim=1, keepdim=True)
        elif method == 'shi-tomasi':
            X_mat = X.unsqueeze(0).reshape(channels, channels, height,
                                           width).permute(2, 3, 0, 1).reshape(
                                               -1, channels, channels)

            S, _ = torch.symeig(X_mat, eigenvectors=False)

            # torch.symeig returns eigenvalues in ascending order, so this
            # selects the largest eigenvalue per pixel
            return S[:, -1].reshape(height, width)

        elif method == 'sampling':
            sampling_method, num_samples = self.parse_sampling_kwargs(**kwargs)

            # unfold = F.unfold(input, kernel_size, dilation=1, padding=size, stride=1).reshape(
            #     1, channels, -1, height * width).squeeze(0).permute(2, 0, 1)
            # diff = unfold - unfold.mean(2, keepdim=True)

            # cov = diff.bmm(diff.permute(0, 2, 1)) / diff.shape[-1]
            # L = torch.cholesky(cov, upper=False).to(dtype=X.dtype)

            # 3-dim : (height x width) x (channels x channels)
            X_mat = X.unsqueeze(0).reshape(channels, channels, height,
                                           width).permute(2, 3, 0, 1).reshape(
                                               -1, channels, channels)

            samples = torch.randn(
                (X_mat.shape[0], num_samples, X_mat.shape[2]),
                dtype=X_mat.dtype,
                device=X_mat.device)
            samples /= samples.norm(dim=-1, keepdim=True)
            #samples = torch.bmm(samples, L.permute(0, 2, 1))

            score = torch.bmm(samples, X_mat)

            if sampling_method == 'std':
                score = (score * samples).sum(-1).sqrt().std(-1).reshape(
                    height, width)
            elif sampling_method == 'mean':
                score = (score * samples).sum(-1).sqrt().mean(-1).reshape(
                    height, width)
            elif sampling_method == 'min':
                score = (score * samples).sum(-1).sqrt().min(-1)[0].reshape(
                    height, width)
            elif sampling_method == 'max':
                score = (score * samples).sum(-1).sqrt().max(-1)[0].reshape(
                    height, width)
        else:
            raise Exception(
                'method should be one of {\'noble\', \'fro\', \'shi-tomasi\', \'sampling\'}.')

        if score.dim() == 2:
            score = score.unsqueeze(0).unsqueeze(0)

        return score.repeat(1, channels, 1, 1)
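A hypothetical call (detector standing in for an instance of the surrounding class): scoring per-pixel sensitivity from an image gradient in (1, C, H, W) layout:

import torch

img = torch.rand(1, 3, 64, 64)      # input image (currently unused by the scoring branches)
grad = torch.randn(1, 3, 64, 64)    # gradient of some loss w.r.t. the image
score = detector.get_score_gpu(img, grad, size=2, method='noble')
print(score.shape)                  # expected: torch.Size([1, 3, 64, 64])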