def __init__(self, dim, latent_dim, beta_min, beta_prior=None, **kwargs):
    """
    Initialisation.

    Parameters
    ----------
    :param dim: dimension of the ambient high-dimensional sphere manifold
    :param latent_dim: dimension of the latent low-dimensional sphere manifold
    :param beta_min: minimum value of the inverse square lengthscale parameter beta

    Optional parameters
    -------------------
    :param beta_prior: prior on the parameter beta
    :param kwargs: additional arguments
    """
    super(NestedSphereGaussianKernel, self).__init__(has_lengthscale=False, **kwargs)
    self.beta_min = beta_min
    self.dim = dim
    self.latent_dim = latent_dim

    # Add the beta parameter, corresponding to the inverse square lengthscale parameter.
    beta_num_dims = 1
    self.register_parameter(name="raw_beta",
                            parameter=torch.nn.Parameter(torch.zeros(*self.batch_shape, 1, beta_num_dims)))

    if beta_prior is not None:
        self.register_prior("beta_prior", beta_prior, lambda: self.beta, lambda v: self._set_beta(v))

    # A GreaterThan constraint is defined on the beta parameter to guarantee the positive-definiteness of the kernel.
    # The value of beta_min can be determined, e.g., experimentally.
    self.register_constraint("raw_beta", GreaterThan(self.beta_min))

    # Add projection parameters
    for d in range(self.dim, self.latent_dim, -1):
        # Axes parameters
        # Register
        axis_name = "raw_axis_S" + str(d)
        # axis = torch.zeros(1, d)
        # axis[:, 0] = 1
        axis = torch.randn(1, d)
        axis = axis / torch.norm(axis)
        axis = axis.repeat(*self.batch_shape, 1, 1)
        self.register_parameter(name=axis_name, parameter=torch.nn.Parameter(axis))
        # Corresponding manifold
        axis_manifold_name = "raw_axis_S" + str(d) + "_manifold"
        setattr(self, axis_manifold_name, pyman_man.Sphere(d))

    # Distances to the axes (constant), fixed at pi/2
    self.distances_to_axis = [np.pi / 2 * torch.ones(1, 1) for d in range(self.dim, self.latent_dim, -1)]
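
# Minimal usage sketch for the constructor above. It assumes that the full NestedSphereGaussianKernel
# class (a gpytorch kernel whose __init__ is shown here) is defined in this module; the dimensions and
# beta_min value below are arbitrary illustration choices.
if __name__ == "__main__":
    kernel = NestedSphereGaussianKernel(dim=5, latent_dim=3, beta_min=6.5)
    # raw_beta is mapped through the GreaterThan(beta_min) constraint registered above, so the
    # effective beta (inverse square lengthscale) always stays above beta_min.
    beta = kernel.raw_beta_constraint.transform(kernel.raw_beta)
    # One unit-norm axis parameter is registered per nested projection S^(d-1) -> S^(d-2),
    # here raw_axis_S5 (1 x 5) and raw_axis_S4 (1 x 4).
    print(beta, [name for name, _ in kernel.named_parameters()])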
# If the dimension is changed:
# - beta_min must be adapted
# - the optimization domain must be updated, as one domain per dimension is required by gpflowopt
dim = 3

# Beta min value
beta_min = 6.5

# True to display sphere figures (possible only if the dimension is 3, i.e., 3D graphs)
display_figures = True

# Number of BO iterations
nb_iter_bo = 25

# Instantiate the manifold
sphere_manifold = pyman_man.Sphere(dim)


# Define the function to optimize with BO
# It must output a numpy [1, 1]-shaped array
def test_function(x):
    if np.ndim(x) < 2:
        x = x[None]

    # Projection into the tangent space of the base.
    # The base is fixed at (1, 0, 0, ...) for simplicity, so the tangent plane is aligned with the axis x.
    # The first coordinate of x_proj is always 0, so that vectors in the tangent space can be expressed in a
    # (dim-1)-dimensional space by simply ignoring the first coordinate.
    base = np.zeros((1, dim))
    base[0, 0] = 1.
    x_proj = sphere_manifold.log(base, x)[0]
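
# Quick check of the property used in test_function (a minimal sketch, reusing np, dim and
# sphere_manifold from above and assuming the pymanopt version in use exposes Sphere.rand() and
# Sphere.log() as in the rest of this code): the log map at the base (1, 0, ..., 0) returns a tangent
# vector orthogonal to the base, so its first coordinate is numerically zero and can be dropped.
_base = np.zeros((1, dim))
_base[0, 0] = 1.
_x = sphere_manifold.rand()[None]          # random point on the sphere
_x_proj = sphere_manifold.log(_base, _x)[0]
print(_x_proj[0])                          # ~0: the tangent vector lives in a (dim-1)-dimensional subspace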
def optimize_reconstruction_parameters_nested_spd(x_data, x_data_projected, projection_matrix, inner_solver,
                                                  cost_function=min_affine_invariant_distance_reconstruction_cost,
                                                  nb_init_candidates=100, maxiter=50):
    """
    This function computes the parameters of the mapping "projection_from_nested_spd_to_spd" from nested SPD
    matrices Y = W'XW to SPD matrices Xrec, so that the distance between the original data X and the reconstructed
    data Xrec is minimized.
    To do so, we consider the nested SPD matrix Y = W'XW as the d x d upper-left part of the rotated matrix
    Xr = R'XR, where R = [W, V] and Xr = [Y B; B' C].
    In order to recover X, we assume a constant SPD matrix C, and B = Y^0.5*K*C^0.5 to ensure the positive
    definiteness of Xr, with K a contraction matrix (norm(K) <= 1). We first reconstruct Xr, and then
    Xrec = R*Xr*R'.
    We minimize the squared distance between X and Xrec by optimizing the complement V of the projection matrix W,
    the bottom SPD matrix C, and the contraction matrix K. The contraction matrix K is parametrized here as a
    norm-1 matrix multiplied by a factor in [0, 1] (the factor is optimized without constraints and mapped to
    [0, 1] with a sigmoid transform).
    The augmented Lagrangian method (ALM) on Riemannian manifolds is used to optimize the parameters on the
    product of manifolds G(D, D-d), SPD(D-d), S(d*(D-d)) and Eucl(1), while respecting the constraint W'V = 0.

    Parameters
    ----------
    :param x_data: set of high-dimensional SPD matrices (N x D x D)
    :param x_data_projected: set of low-dimensional SPD matrices (projected from x_data) (N x d x d)
    :param projection_matrix: element of the Grassmann manifold (D x d)
    :param inner_solver: inner solver for the ALM on Riemannian manifolds

    Optional parameters
    -------------------
    :param nb_init_candidates: number of initial candidates for the optimization
    :param maxiter: maximum number of iterations of the ALM solver

    Returns
    -------
    :return projection_complement_matrix: element of the Grassmann manifold (D x D-d), so that
        torch.mm(projection_complement_matrix.T, projection_matrix) = 0
    :return bottom_spd_matrix: bottom-right part of the rotated SPD matrix (D-d x D-d)
    :return contraction_matrix: matrix whose norm is <= 1 (d x D-d)
    """
    # Dimensions
    dim = x_data.shape[1]
    latent_dim = projection_matrix.shape[1]

    # Product of manifolds for the optimization
    manifolds_list = [pyman_man.Grassmann(dim, dim - latent_dim),
                      pyman_man.PositiveDefinite(dim - latent_dim),
                      pyman_man.Sphere(latent_dim * (dim - latent_dim)),
                      pyman_man.Euclidean(1)]
    product_manifold = pyman_man.Product(manifolds_list)

    # Constraint on the norm of the contraction matrix
    contraction_norm_constraint = gpytorch.constraints.Interval(0., 1.)

    # Constraint W'V = 0
    def constraint_fct(parameters):
        cost = torch.norm(torch.mm(parameters[0].T, projection_matrix))
        zero_element_needed_for_correct_grad = 0. * torch.norm(parameters[1]) + 0. * torch.norm(parameters[2]) + \
                                               0. * torch.norm(parameters[3])
        return cost + zero_element_needed_for_correct_grad

    # Reconstruction cost
    def reconstruction_cost(parameters):
        projection_complement_matrix = parameters[0]
        bottom_spd_matrix = parameters[1]
        contraction_norm = contraction_norm_constraint.transform(parameters[3])
        contraction_matrix = contraction_norm * parameters[2].view(latent_dim, dim - latent_dim)
        return cost_function(x_data, x_data_projected, projection_matrix, projection_complement_matrix,
                             bottom_spd_matrix, contraction_matrix)

    # Generate candidates for the initial point
    x0_candidates = [product_manifold.rand() for i in range(nb_init_candidates)]
    x0_candidates_torch = []
    for x0 in x0_candidates:
        x0_candidates_torch.append([torch.from_numpy(x) for x in x0])
    y0_candidates = [reconstruction_cost(x0_candidates_torch[i]) for i in range(nb_init_candidates)]

    # Initialize with the best of the candidates
    y0, x_init_idx = torch.Tensor(y0_candidates).min(0)
    x0 = x0_candidates[x_init_idx]

    # Define the optimization problem
    reconstruction_problem = Problem(manifold=product_manifold, cost=reconstruction_cost, arg=torch.Tensor(),
                                     verbosity=0)

    # Define the ALM solver
    solver = AugmentedLagrangeMethod(maxiter=maxiter, inner_solver=inner_solver, lambdas_fact=0.05)

    # Solve
    spd_parameters_np = solver.solve(reconstruction_problem, x=x0, eq_constraints=constraint_fct)

    # Parameters to torch data
    projection_complement_matrix = torch.from_numpy(spd_parameters_np[0])
    bottom_spd_matrix = torch.from_numpy(spd_parameters_np[1])
    contraction_norm = contraction_norm_constraint.transform(torch.from_numpy(spd_parameters_np[3]))
    contraction_matrix = contraction_norm * torch.from_numpy(spd_parameters_np[2]).view(latent_dim,
                                                                                        dim - latent_dim)

    return projection_complement_matrix, bottom_spd_matrix, contraction_matrix
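
# Hedged usage sketch for the function above. The shapes follow the docstring (X: N x D x D, W: D x d,
# Y = W'XW: N x d x d); torch and pyman_man are assumed imported as in the rest of this module. The
# choice of pymanopt's ConjugateGradient as the ALM inner solver is only an assumption, any Riemannian
# solver compatible with AugmentedLagrangeMethod could be passed instead.
if __name__ == "__main__":
    from pymanopt.solvers import ConjugateGradient  # assumption: accepted as inner solver by the ALM

    N, D, d = 10, 5, 3

    # Random SPD data X (N x D x D): A A' plus a small ridge to ensure positive definiteness
    a = torch.randn(N, D, D, dtype=torch.float64)
    x_data = a.matmul(a.transpose(-2, -1)) + 1e-3 * torch.eye(D, dtype=torch.float64)

    # Random projection W (D x d) on the Grassmann manifold and nested data Y = W'XW (N x d x d)
    projection_matrix = torch.from_numpy(pyman_man.Grassmann(D, d).rand())
    x_data_projected = projection_matrix.T.matmul(x_data).matmul(projection_matrix)

    inner_solver = ConjugateGradient(maxiter=50)
    V, C, K = optimize_reconstruction_parameters_nested_spd(x_data, x_data_projected,
                                                            projection_matrix, inner_solver)

    # V is (approximately) orthogonal to W, C is a (D-d) x (D-d) SPD block, and norm(K) <= 1
    print(torch.mm(V.T, projection_matrix).norm(), C.shape, K.norm())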
seed = 1234
# Set numpy and pytorch seeds
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

# Define the dimension
dim = 5

# Define the latent space dimension
latent_dim = 3

# Instantiate the manifolds
sphere_manifold = pyman_man.Sphere(dim)
latent_sphere_manifold = pyman_man.Sphere(latent_dim)

# Define the test function
# Parameters for the nested test function
sphere_axes_test = [torch.zeros(dim - d, dtype=torch.float64) for d in range(0, dim - latent_dim)]
for d in range(0, dim - latent_dim):
    sphere_axes_test[d][0] = 1.
sphere_distances_to_axes_test = [torch.tensor(np.pi / 4, dtype=torch.float64)
                                 for d in range(0, dim - latent_dim)]

# Nested test function