def register_point_sets(self, x):
    """Register a batch of point sets with EM-based probabilistic alignment."""
    Vs = x["coordinates_ds"]
    features = x["features"]
    features_w = x["att"]
    repeat_list = [len(V) for V in Vs]

    # Initial transformation (identity rotation, optionally mean-centering translation)
    init_R, init_t = get_init_transformation_list(
        Vs, self.params.get("mean_init", True))

    TVs = init_R @ Vs + init_t

    # Initialize cluster means X and precisions Q for each point-set group
    X = TensorList()
    Q = TensorList()
    mu = TensorList()
    for TV, Fs in zip(TVs, features):
        if self.params.cluster_init == "box":
            Xi = get_randn_box_cluster_means_list(TV, self.params.K)
        else:
            Xi = get_randn_sphere_cluster_means_list(
                TV, self.params.K,
                self.params.get("cluster_mean_scale", 1.0))
        Q.append(get_scaled_cluster_precisions_list(
            TV, Xi, self.params.cluster_precision_scale))
        X.append(Xi.T)

    # Feature distribution attached to each mixture component
    if self.params.feature_distr_parameters.model == "vonmises":
        feature_distr = feature_models.VonMisesModelList(
            self.params.feature_distr_parameters, self.params.K,
            features.detach(), self.feature_s,
            repeat_list=repeat_list, mu=mu)
    elif self.params.feature_distr_parameters.model == "none":
        feature_distr = feature_models.BaseFeatureModel()
    else:
        feature_distr = feature_models.VonMisesModelList(
            self.params.feature_distr_parameters, self.params.K,
            features.detach(), self.feature_s,
            repeat_list=repeat_list)

    feature_distr.to(self.params.device)

    X = TensorListList(X, repeat=repeat_list)
    self.betas = get_default_beta(Q, self.params.gamma)

    # Compute the observation weights (DARE-style density adaptation),
    # clamped at ow_reg_factor times their mean to limit outlier influence
    if self.params.use_dare_weighting:
        observation_weights = empirical_estimate(Vs, self.params.ow_args)
        ow_reg_factor = 8.0
        ow_mean = observation_weights.mean(dim=0, keepdim=True)
        for idx in range(len(observation_weights)):
            for idxx in range(len(observation_weights[idx])):
                ow = observation_weights[idx][idxx]
                ow[ow > ow_reg_factor * ow_mean[idx][idxx]] = \
                    ow_reg_factor * ow_mean[idx][idxx]
    else:
        observation_weights = 1.0

    # Squared distances between transformed points and cluster means
    ds = TVs.permute(1, 0).sqe(X).permute(1, 0)

    if self.params.debug:
        self.visdom.register(
            dict(pcds=Vs[0].cpu(), X=X[0][0].cpu(), c=None),
            'point_clouds', 2, 'init')
        time.sleep(1)

    Rs = init_R.to(self.params.device)
    ts = init_t.to(self.params.device)

    self.betas = TensorListList(self.betas, repeat=repeat_list)
    QL = TensorListList(Q, repeat=repeat_list)

    Riter = TensorListList()
    titer = TensorListList()
    TVs_iter = TensorListList()
    priors = 1
    for i in range(self.params.num_iters):
        # Only backpropagate through the iterations listed in backprop_iter
        if i in self.params.backprop_iter:
            features_f = features
            if self.params.use_attention:
                features_w_f = features_w
            else:
                features_w_f = 1.0
        else:
            features_f = features.detach()
            if self.params.use_attention:
                features_w_f = features_w.detach()
            else:
                features_w_f = 1.0
            feature_distr.detach()
            ds = ds.detach()
            QL = QL.detach()
            X = X.detach()

        # E-step: spatial posteriors, weighted by the feature posteriors
        Qt = QL.permute(1, 0)
        ap = priors * (-0.5 * ds * QL).exp() * QL.pow(1.5)
        if i > 0:
            pyz_feature = feature_distr.posteriors(features_f)
        else:
            pyz_feature = 1.0

        a = ap * pyz_feature
        ac_den = a.sum(dim=-1, keepdim=True) + self.betas
        a = a / ac_den  # normalize row-wise (beta absorbs the outlier mass)
        a = a * observation_weights * features_w_f

        # M-step: sufficient statistics for the rigid transformation
        L = a.sum(dim=-2, keepdim=True).permute(1, 0)
        W = (Vs @ a) * QL
        b = L * Qt  # weights, b
        mW = W.sum(dim=-1, keepdim=True)
        mX = (b.permute(1, 0) @ X).permute(1, 0)
        z = L.permute(1, 0) @ Qt
        P = (W @ X).permute(1, 0) - mX @ mW.permute(1, 0) / z

        # Compute R and t: orthogonal Procrustes via SVD, with a sign
        # correction on the last singular vector to guarantee det(R) = 1
        svd_list_list = P.cpu().svd()
        Rs = TensorListList()
        for svd_list in svd_list_list:
            Rs_list = TensorList()
            for svd in svd_list:
                uu, vv = svd.U, svd.V
                vvt = vv.permute(1, 0)
                detuvt = (uu @ vvt).det()
                S = torch.ones(1, 3)
                S[:, -1] = detuvt
                Rs_list.append((uu * S) @ vvt)
            Rs.append(Rs_list)

        Rs = Rs.to(self.params.device)
        Riter.append(Rs)
        ts = (mX - Rs @ mW) / z
        titer.append(ts)
        TVs = Rs @ Vs + ts
        TVs_iter.append(TVs.clone())

        if self.params.debug:
            self.visdom.register(
                dict(pcds=TVs[0].cpu(), X=X[0][0].cpu(), c=None),
                'point_clouds', 2, 'registration-iter')
            time.sleep(0.2)

        # Update X (cluster means), kept fixed for the first iterations
        den = L.sum_list()
        if self.params.fix_cluster_pos_iter < i:
            X = (TVs @ a).permute(1, 0)
            X = TensorListList(X.sum_list() / den, repeat=repeat_list)

        # Update Q (cluster precisions)
        ds = TVs.permute(1, 0).sqe(X).permute(1, 0)
        wn = (a * ds).sum(dim=-2, keepdim=True).sum_list()
        Q = (3 * den / (wn.permute(1, 0) + 3 * den * self.params.epsilon)).permute(1, 0)
        QL = TensorListList(Q, repeat=repeat_list)

        # Update the feature distribution parameters
        feature_distr.maximize(a=a, y=features_f, den=den)

        if self.params.get("update_priors", False):
            priors = TensorListList(
                den.permute(1, 0) / ((self.params.gamma + 1) * den.sum()),
                repeat=repeat_list)

    out = dict(Rs=Rs, ts=ts, X=X, Riter=Riter[:-1], titer=titer[:-1],
               Vs=TVs, Vs_iter=TVs_iter[:-1], ow=observation_weights)
    if self.params.use_attention:
        out["features_w"] = features_w_f

    return out
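# --- Illustrative sketch (added for exposition; not part of the original
# pipeline). For a single point set, the two core steps of the EM loop above
# reduce to plain PyTorch: (1) the E-step responsibilities, matching
# `ap = priors * (-0.5 * ds * QL).exp() * QL.pow(1.5)` followed by row-wise
# normalization against the outlier term `self.betas`, and (2) the orthogonal
# Procrustes solve for R with the det(UV^T) sign fix. The function names and
# shapes below are assumptions for demonstration only.

import torch


def em_responsibilities(TV, X, Q, beta, priors=1.0):
    """TV: (N, 3) transformed points, X: (K, 3) cluster means,
    Q: (K,) isotropic cluster precisions, beta: scalar outlier mass."""
    ds = ((TV[:, None, :] - X[None, :, :]) ** 2).sum(dim=-1)  # (N, K) squared distances
    ap = priors * torch.exp(-0.5 * ds * Q) * Q.pow(1.5)       # unnormalized posteriors
    return ap / (ap.sum(dim=-1, keepdim=True) + beta)         # normalize incl. outlier term


def procrustes_rotation(P):
    """Closest rotation to a 3x3 correlation matrix P (Kabsch with reflection fix)."""
    U, _, V = torch.svd(P)
    S = torch.ones(3)
    S[-1] = torch.det(U @ V.t())  # flip the last column if U V^T is a reflection
    return (U * S) @ V.t()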
class GaussNewtonCG:
    """Gauss-Newton optimizer that solves each linearized subproblem with
    preconditioned conjugate gradient (CG)."""

    def __init__(self, problem: MinimizationProblem, variable: TensorList,
                 cg_eps=0.0, fletcher_reeves=True, standard_alpha=True,
                 direction_forget_factor=0, step_alpha=1.0):
        self.fletcher_reeves = fletcher_reeves
        self.standard_alpha = standard_alpha
        self.direction_forget_factor = direction_forget_factor

        # CG state
        self.p = None
        self.rho = torch.ones(1)
        self.r_prev = None

        # Right hand side
        self.b = None

        self.problem = problem
        self.x = variable

        self.cg_eps = cg_eps
        self.f0 = None
        self.g = None
        self.dfdxt_g = None

        self.residuals = torch.zeros(0)
        self.external_losses = []
        self.internal_losses = []
        self.gradient_mags = torch.zeros(0)

        self.step_alpha = step_alpha

    def clear_temp(self):
        self.f0 = None
        self.g = None
        self.dfdxt_g = None

    def run(self, num_cg_iter, num_gn_iter=None):
        self.problem.initialize()

        if isinstance(num_cg_iter, int):
            if num_gn_iter is None:
                raise ValueError(
                    'Must specify number of GN iter if CG iter is constant')
            num_cg_iter = [num_cg_iter] * num_gn_iter

        num_gn_iter = len(num_cg_iter)
        if num_gn_iter == 0:
            return

        # Outer Gauss-Newton loop
        for cg_iter in num_cg_iter:
            self.run_GN_iter(cg_iter)

        self.x.detach_()
        self.clear_temp()

        return self.external_losses, self.internal_losses, self.residuals

    def run_GN_iter(self, num_cg_iter):
        """Run a single GN iteration: linearize, solve with CG, take a step."""
        self.x.requires_grad_(True)

        self.f0 = self.problem(self.x)
        self.g = self.f0.detach()
        self.g.requires_grad_(True)

        # df/dx^t @ f0
        self.dfdxt_g = TensorList(
            torch.autograd.grad(self.f0, self.x, self.g, create_graph=True))
        self.b = -self.dfdxt_g.detach()

        delta_x, res = self.run_CG(num_cg_iter, eps=self.cg_eps)

        self.x.detach_()
        self.x += self.step_alpha * delta_x
        self.step_alpha = min(self.step_alpha * 1.2, 1.0)

    def reset_state(self):
        self.p = None
        self.rho = torch.ones(1)
        self.r_prev = None

    def run_CG(self, num_iter, x=None, eps=0.0):
        """Main conjugate gradient method."""

        # Apply forgetting factor
        if self.direction_forget_factor == 0:
            self.reset_state()
        elif self.p is not None:
            self.rho /= self.direction_forget_factor

        if x is None:
            r = self.b.clone()
        else:
            r = self.b - self.A(x)

        # Loop over iterations
        for ii in range(num_iter):
            z = self.problem.M1(r)  # Preconditioner

            rho1 = self.rho
            self.rho = self.ip(r, z)

            if self.p is None:
                self.p = z.clone()
            else:
                if self.fletcher_reeves:
                    beta = self.rho / rho1
                else:
                    # Polak-Ribiere formula
                    rho2 = self.ip(self.r_prev, z)
                    beta = (self.rho - rho2) / rho1

                beta = beta.clamp(0)
                self.p = z + self.p * beta

            q = self.A(self.p)
            pq = self.ip(self.p, q)

            if self.standard_alpha:
                alpha = self.rho / pq
            else:
                alpha = self.ip(self.p, r) / pq

            # Save old r for the PR formula
            if not self.fletcher_reeves:
                self.r_prev = r.clone()

            # Form new iterate
            if x is None:
                x = self.p * alpha
            else:
                x += self.p * alpha

            if ii < num_iter - 1:
                r -= q * alpha

        return x, []

    def A(self, x):
        # Gauss-Newton matrix-vector product (J^T J) x, computed with two
        # autograd passes instead of forming the Jacobian explicitly
        dfdx_x = torch.autograd.grad(self.dfdxt_g, self.g, x, retain_graph=True)
        return TensorList(
            torch.autograd.grad(self.f0, self.x, dfdx_x, retain_graph=True))

    def ip(self, a, b):
        # Inner product defined by the problem
        return self.problem.ip_input(a, b)
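# --- Usage sketch (an assumption, not part of the original code). ---
# GaussNewtonCG is duck-typed: the problem only needs initialize(), __call__()
# returning a residual TensorList, M1() (preconditioner) and ip_input()
# (inner product). The toy problem below, fitting w to minimize ||A w - b||^2,
# is purely illustrative; `torch` and `TensorList` are the module's existing
# imports.

class ToyLeastSquares:
    def __init__(self, A, b):
        self.A = A
        self.b = b

    def initialize(self):
        pass

    def __call__(self, x: TensorList) -> TensorList:
        # Residuals f(w) = A w - b; Gauss-Newton linearizes these around w
        return TensorList([self.A @ x[0] - self.b])

    def M1(self, r):
        return r  # identity preconditioner

    def ip_input(self, a, b):
        # Euclidean inner product over all tensors in the lists
        return sum((ai * bi).sum() for ai, bi in zip(a, b))


if __name__ == '__main__':
    A, b = torch.randn(10, 3), torch.randn(10)
    w = TensorList([torch.zeros(3)])
    optimizer = GaussNewtonCG(ToyLeastSquares(A, b), w)
    optimizer.run(num_cg_iter=5, num_gn_iter=3)
    # w[0] now approximates the least-squares solution of A w = b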