Пример #1
0
    def register_point_sets(self, x):
        """Jointly register several point sets against a shared set of clusters.

        The method alternates, for ``self.params.num_iters`` iterations,
        between (1) computing soft point-to-cluster assignments ``a`` from the
        current distances, precisions, priors and feature posteriors,
        (2) solving for one rotation and translation per point set via an SVD,
        and (3) re-estimating the cluster positions ``X``, the precisions ``Q``
        and the feature distribution.

        Args:
            x: Mapping that is indexed with the keys ``"coordinates_ds"``
                (the point sets ``Vs``), ``"features"`` and ``"att"``
                (weights multiplied into the assignments when
                ``self.params.use_attention`` is set).
                NOTE(review): the containers are used like the project's
                TensorList/TensorListList types (``@``, ``permute``,
                ``detach`` ...) -- confirm against the caller.

        Returns:
            dict with keys ``Rs``, ``ts``, ``X``, ``Riter``, ``titer``, ``Vs``
            (the transformed point sets), ``Vs_iter`` and ``ow``; the key
            ``features_w`` is added only when ``self.params.use_attention`` is
            set. ``Riter``, ``titer`` and ``Vs_iter`` are returned without
            their last entry.
        """

        Vs = x["coordinates_ds"]
        features = x["features"]
        features_w = x["att"]
        # One entry per element of Vs; used as the ``repeat`` argument below.
        repeat_list = [len(V) for V in Vs]
        init_R, init_t = get_init_transformation_list(
            Vs, self.params.get("mean_init", True))

        # Point sets after applying the initial transformation.
        TVs = init_R @ Vs + init_t

        # Initial cluster means (X) and precisions (Q), one entry per element
        # of TVs.
        X = TensorList()
        Q = TensorList()
        # NOTE(review): ``mu`` is never filled in this method, so it is still
        # empty when handed to VonMisesModelList below.
        mu = TensorList()
        # NOTE(review): ``Fs`` is not used inside the loop body.
        for TV, Fs in zip(TVs, features):
            if self.params.cluster_init == "box":
                Xi = get_randn_box_cluster_means_list(TV, self.params.K)
            else:
                Xi = get_randn_sphere_cluster_means_list(
                    TV, self.params.K,
                    self.params.get("cluster_mean_scale", 1.0))
            Q.append(
                get_scaled_cluster_precisions_list(
                    TV, Xi, self.params.cluster_precision_scale))
            X.append(Xi.T)

        # Select the feature model. "vonmises" and the fallback branch build
        # the same model class; only the "vonmises" branch passes ``mu``.
        if self.params.feature_distr_parameters.model == "vonmises":
            feature_distr = feature_models.VonMisesModelList(
                self.params.feature_distr_parameters,
                self.params.K,
                features.detach(),
                self.feature_s,
                repeat_list=repeat_list,
                mu=mu)
        elif self.params.feature_distr_parameters.model == "none":
            feature_distr = feature_models.BaseFeatureModel()
        else:
            feature_distr = feature_models.VonMisesModelList(
                self.params.feature_distr_parameters,
                self.params.K,
                features.detach(),
                self.feature_s,
                repeat_list=repeat_list)

        feature_distr.to(self.params.device)

        X = TensorListList(X, repeat=repeat_list)
        self.betas = get_default_beta(Q, self.params.gamma)

        # Compute the observation weights
        if self.params.use_dare_weighting:
            observation_weights = empirical_estimate(Vs, self.params.ow_args)
            # Cap every weight at ow_reg_factor times the mean taken over
            # dim 0 of the corresponding weight tensor.
            ow_reg_factor = 8.0
            ow_mean = observation_weights.mean(dim=0, keepdim=True)
            for idx in range(len(observation_weights)):
                for idxx in range(len(observation_weights[idx])):
                    observation_weights[idx][idxx][observation_weights[idx][idxx] > ow_reg_factor * ow_mean[idx][idxx]] \
                        = ow_reg_factor * ow_mean[idx][idxx]

        else:
            # Scalar weight: leaves the assignments unchanged when multiplied.
            observation_weights = 1.0

        # NOTE(review): ``sqe`` presumably returns squared Euclidean distances
        # between transformed points and cluster means -- confirm.
        ds = TVs.permute(1, 0).sqe(X).permute(1, 0)

        if self.params.debug:
            self.visdom.register(
                dict(pcds=Vs[0].cpu(), X=X[0][0].cpu(), c=None),
                'point_clouds', 2, 'init')
            time.sleep(1)

        # These values are overwritten inside the loop before being read, so
        # they only reach the return statement if the loop body never runs.
        Rs = init_R.to(self.params.device)
        ts = init_t.to(self.params.device)

        self.betas = TensorListList(self.betas, repeat=repeat_list)
        QL = TensorListList(Q, repeat=repeat_list)
        # Per-iteration history of rotations, translations and transformed
        # point sets.
        Riter = TensorListList()
        titer = TensorListList()
        TVs_iter = TensorListList()
        # Scalar prior until (optionally) re-estimated at the end of an
        # iteration.
        priors = 1
        for i in range(self.params.num_iters):
            if i in self.params.backprop_iter:
                # Iterations listed in backprop_iter use the inputs as-is
                # (nothing is detached here).
                features_f = features
                if self.params.use_attention:
                    features_w_f = features_w
                else:
                    features_w_f = 1.0
            else:
                # All other iterations work on detached copies of the inputs
                # and of the running estimates.
                features_f = features.detach()
                if self.params.use_attention:
                    features_w_f = features_w.detach()
                else:
                    features_w_f = 1.0
                feature_distr.detach()
                ds = ds.detach()
                QL = QL.detach()
                X = X.detach()

            Qt = QL.permute(1, 0)

            # Unnormalised spatial term: priors * exp(-0.5 * ds * Q) * Q^1.5.
            ap = priors * (-0.5 * ds * QL).exp() * QL.pow(1.5)

            # The feature posterior is only used from the second iteration on.
            if i > 0:
                pyz_feature = feature_distr.posteriors(features_f)
            else:
                pyz_feature = 1.0

            a = ap * pyz_feature

            # self.betas is added to the normaliser, so each row of ``a`` can
            # sum to less than one.
            ac_den = a.sum(dim=-1, keepdim=True) + self.betas
            a = a / ac_den  # normalize row-wise

            a = a * observation_weights * features_w_f

            # Weighted statistics that feed the rotation/translation solve.
            L = a.sum(dim=-2, keepdim=True).permute(1, 0)
            W = (Vs @ a) * QL

            b = L * Qt  # weights, b
            mW = W.sum(dim=-1, keepdim=True)
            mX = (b.permute(1, 0) @ X).permute(1, 0)
            z = L.permute(1, 0) @ Qt
            P = (W @ X).permute(1, 0) - mX @ mW.permute(1, 0) / z

            # Compute R and t
            # The SVD is run on the CPU; results are moved to the target
            # device afterwards.
            svd_list_list = P.cpu().svd()
            Rs = TensorListList()
            for svd_list in svd_list_list:
                Rs_list = TensorList()
                for svd in svd_list:
                    uu, vv = svd.U, svd.V
                    vvt = vv.permute(1, 0)
                    detuvt = uu @ vvt
                    detuvt = detuvt.det()
                    # Scale the last column of U by det(U V^T); presumably this
                    # keeps the result a proper rotation rather than a
                    # reflection.
                    S = torch.ones(1, 3)
                    S[:, -1] = detuvt
                    Rs_list.append((uu * S) @ vvt)

                Rs.append(Rs_list)

            Rs = Rs.to(self.params.device)
            Riter.append(Rs)
            ts = (mX - Rs @ mW) / z
            titer.append(ts)

            # Re-transform the original point sets with the new estimates.
            TVs = Rs @ Vs + ts

            TVs_iter.append(TVs.clone())
            if self.params.debug:
                self.visdom.register(
                    dict(pcds=TVs[0].cpu(), X=X[0][0].cpu(), c=None),
                    'point_clouds', 2, 'registration-iter')
                time.sleep(0.2)

            # Update X
            den = L.sum_list()

            # Cluster positions stay fixed until i exceeds
            # fix_cluster_pos_iter.
            if self.params.fix_cluster_pos_iter < i:
                X = (TVs @ a).permute(1, 0)
                X = TensorListList(X.sum_list() / den, repeat_list)

            # Update Q
            ds = TVs.permute(1, 0).sqe(X).permute(1, 0)

            wn = (a * ds).sum(dim=-2, keepdim=True).sum_list()
            # Q = 3 * den / (wn + 3 * den * epsilon); the epsilon term keeps
            # the denominator away from zero when wn is tiny.
            Q = (3 * den /
                 (wn.permute(1, 0) + 3 * den * self.params.epsilon)).permute(
                     1, 0)
            QL = TensorListList(Q, repeat=repeat_list)

            feature_distr.maximize(a=a, y=features_f, den=den)
            if self.params.get("update_priors", False):
                priors = TensorListList(den.permute(1, 0) /
                                        ((self.params.gamma + 1) * den.sum()),
                                        repeat=repeat_list)

        # NOTE(review): ``features_w_f`` is only bound inside the loop above,
        # so with use_attention set and num_iters == 0 the first branch would
        # raise a NameError.
        if self.params.use_attention:
            out = dict(Rs=Rs,
                       ts=ts,
                       X=X,
                       Riter=Riter[:-1],
                       titer=titer[:-1],
                       Vs=TVs,
                       Vs_iter=TVs_iter[:-1],
                       ow=observation_weights,
                       features_w=features_w_f)
        else:
            out = dict(Rs=Rs,
                       ts=ts,
                       X=X,
                       Riter=Riter[:-1],
                       titer=titer[:-1],
                       Vs=TVs,
                       Vs_iter=TVs_iter[:-1],
                       ow=observation_weights)
        return out
Пример #2
0
class GaussNewtonCG:
    """Gauss-Newton optimizer whose linear systems are solved by conjugate gradient.

    Each Gauss-Newton (GN) iteration evaluates the problem at the current
    variable, forms the right-hand side ``b = -J^T f0`` through autograd and
    then runs a number of conjugate gradient (CG) iterations on
    ``J^T J dx = b``, where the matrix-vector product is computed with two
    autograd passes (see :meth:`A`).

    Args:
        problem: Object that is called as ``problem(x)`` and provides
            ``initialize()``, ``M1(r)`` (preconditioner) and
            ``ip_input(a, b)`` (inner product).
        variable: The variable being optimized; it is updated in place.
        cg_eps: Forwarded to :meth:`run_CG` as ``eps`` (which currently does
            not use it).
        fletcher_reeves: Use ``beta = rho / rho_prev``. When False, the
            previous residual is kept and
            ``beta = (rho - <r_prev, z>) / rho_prev`` is used instead.
        standard_alpha: Use ``alpha = rho / <p, q>``; otherwise
            ``alpha = <p, r> / <p, q>``.
        direction_forget_factor: ``0`` resets the CG state at the start of
            every :meth:`run_CG` call; any other value keeps the state and
            divides ``rho`` by this factor when a previous direction exists.
        step_alpha: Initial scale applied to each GN step. It is multiplied
            by 1.2 after every GN iteration and capped at 1.0.
    """

    def __init__(self,
                 problem: MinimizationProblem,
                 variable: TensorList,
                 cg_eps=0.0,
                 fletcher_reeves=True,
                 standard_alpha=True,
                 direction_forget_factor=0,
                 step_alpha=1.0):

        self.fletcher_reeves = fletcher_reeves
        self.standard_alpha = standard_alpha
        self.direction_forget_factor = direction_forget_factor

        # CG state carried between iterations (see reset_state)
        self.p = None
        self.rho = torch.ones(1)
        self.r_prev = None

        # Right hand side
        self.b = None

        self.problem = problem
        self.x = variable

        self.cg_eps = cg_eps

        # Autograd temporaries of the current GN iteration (see clear_temp)
        self.f0 = None
        self.g = None
        self.dfdxt_g = None

        # Bookkeeping containers returned by run(); nothing in this class
        # writes to them.
        self.residuals = torch.zeros(0)
        self.external_losses = []
        self.internal_losses = []
        self.gradient_mags = torch.zeros(0)

        self.step_alpha = step_alpha

    def clear_temp(self):
        """Drop the autograd temporaries of the last GN iteration."""
        self.f0 = None
        self.g = None
        self.dfdxt_g = None

    def run(self, num_cg_iter, num_gn_iter=None):
        """Run the optimizer.

        Args:
            num_cg_iter: Either an int (the same number of CG iterations for
                every GN iteration) or a sequence holding the number of CG
                iterations for each GN iteration.
            num_gn_iter: Number of GN iterations. Required when
                ``num_cg_iter`` is an int, ignored otherwise.

        Returns:
            ``(external_losses, internal_losses, residuals)``, or ``None``
            when the schedule is empty.

        Raises:
            ValueError: If ``num_cg_iter`` is an int and ``num_gn_iter`` is
                not given.
        """
        # Validate the arguments first so that an invalid call does not
        # trigger the side effects of problem.initialize().
        if isinstance(num_cg_iter, int):
            if num_gn_iter is None:
                raise ValueError(
                    'Must specify number of GN iter if CG iter is constant')
            num_cg_iter = [num_cg_iter] * num_gn_iter

        self.problem.initialize()

        if len(num_cg_iter) == 0:
            return None

        for cg_iter in num_cg_iter:
            self.run_GN_iter(cg_iter)

        self.x.detach_()
        self.clear_temp()

        return self.external_losses, self.internal_losses, self.residuals

    def run_GN_iter(self, num_cg_iter):
        """Perform one Gauss-Newton iteration with ``num_cg_iter`` CG iterations."""

        self.x.requires_grad_(True)

        # Evaluate the problem and build J^T g with the graph kept, so that
        # A() can later differentiate it with respect to g.
        self.f0 = self.problem(self.x)
        self.g = self.f0.detach()
        self.g.requires_grad_(True)
        self.dfdxt_g = TensorList(
            torch.autograd.grad(self.f0, self.x, self.g,
                                create_graph=True))  # df/dx^t @ f0
        self.b = -self.dfdxt_g.detach()

        delta_x, _ = self.run_CG(num_cg_iter, eps=self.cg_eps)

        self.x.detach_()
        # run_CG returns None as the step when it performed no iteration;
        # treat that as a zero update instead of failing on `alpha * None`.
        if delta_x is not None:
            self.x += self.step_alpha * delta_x
        self.step_alpha = min(self.step_alpha * 1.2, 1.0)

    def reset_state(self):
        """Forget the CG search direction and the associated scalars."""
        self.p = None
        self.rho = torch.ones(1)
        self.r_prev = None

    def run_CG(self, num_iter, x=None, eps=0.0):
        """Main conjugate gradient method.

        Args:
            num_iter: Number of CG iterations.
            x: Initial guess; treated as zero when None. A given ``x`` is
                updated in place.
            eps: Accepted for interface compatibility; not used here.

        Returns:
            ``(x, [])`` where ``x`` is the final iterate (``None`` if no
            initial guess was given and ``num_iter`` is 0). The second
            element is always an empty list.
        """

        # Apply forgetting factor
        if self.direction_forget_factor == 0:
            self.reset_state()
        elif self.p is not None:
            self.rho /= self.direction_forget_factor

        # Initial residual r = b - A x (with x = 0 when no guess is given)
        if x is None:
            r = self.b.clone()
        else:
            r = self.b - self.A(x)

        # Loop over iterations
        for ii in range(num_iter):

            z = self.problem.M1(r)  # Preconditioner

            rho1 = self.rho
            self.rho = self.ip(r, z)

            if self.p is None:
                self.p = z.clone()
            else:
                if self.fletcher_reeves:
                    beta = self.rho / rho1
                else:
                    rho2 = self.ip(self.r_prev, z)
                    beta = (self.rho - rho2) / rho1

                # Negative beta is clamped to zero
                beta = beta.clamp(0)
                self.p = z + self.p * beta

            q = self.A(self.p)
            pq = self.ip(self.p, q)

            if self.standard_alpha:
                alpha = self.rho / pq
            else:
                alpha = self.ip(self.p, r) / pq

            # Save old r for PR formula
            if not self.fletcher_reeves:
                self.r_prev = r.clone()

            # Form new iterate
            if x is None:
                x = self.p * alpha
            else:
                x += self.p * alpha

            # The residual is not needed after the last iteration
            if ii < num_iter - 1:
                r -= q * alpha

        return x, []

    def A(self, x):
        """Return ``J^T J x`` for the Jacobian ``J`` of ``f0`` w.r.t. ``self.x``.

        The first autograd call differentiates ``J^T g`` with respect to ``g``
        to obtain ``J x``; the second one applies ``J^T`` to that result.
        """
        dfdx_x = torch.autograd.grad(self.dfdxt_g,
                                     self.g,
                                     x,
                                     retain_graph=True)
        return TensorList(
            torch.autograd.grad(self.f0, self.x, dfdx_x, retain_graph=True))

    def ip(self, a, b):
        """Inner product, delegated to ``problem.ip_input``."""
        return self.problem.ip_input(a, b)