def basis_representation(self, matrix_representation):
    """Calculate the coefficients of given matrix in the basis.

    Compute a 1d-array that corresponds to the input matrix in the basis
    representation. Any number of leading batch dimensions is preserved.

    Parameters
    ----------
    matrix_representation : array-like, shape=[..., n, n]
        Matrix.

    Returns
    -------
    basis_representation : array-like, shape=[..., dim]
        Representation in the basis.
    """
    if self.n == 2:
        # Single coefficient: the (1, 0) entry, kept as a length-1 last axis.
        return matrix_representation[..., 1, 0][..., None]
    if self.n == 3:
        # Stack along the last axis so every leading batch dimension keeps
        # its position. Stacking on axis 0 followed by a full transpose
        # reversed the batch axes whenever the input had rank > 3.
        return gs.stack(
            [
                matrix_representation[..., 2, 1],
                matrix_representation[..., 0, 2],
                matrix_representation[..., 1, 0],
            ],
            axis=-1,
        )
    # General case: strictly upper-triangular entries (k=1 skips the diagonal).
    return gs.triu_to_vec(matrix_representation, k=1)
def __sample_spd(self, samples):
    """Draw SPD matrices by pushing a Gaussian through the matrix exponential.

    The mean is mapped to a symmetric matrix with ``self.manifold.logm``.
    Its diagonal, followed by its strictly upper-triangular entries scaled
    by sqrt(2), forms the mean of a multivariate normal with covariance
    ``self.cov``. Every draw is folded back into a symmetric matrix (with the
    sqrt(2) scaling undone) and mapped through ``self.manifold.expm``.

    Parameters
    ----------
    samples : int
        Number of samples to draw.

    Returns
    -------
    samples_spd : array-like
        Sampled matrices, one per draw (presumably shape=[samples, n, n]).
    """
    dim_n = self.mean.shape[-1]
    root_two = gs.sqrt(2.0)

    # Euclidean coordinates of logm(mean): diagonal first, then the
    # strictly upper-triangular part weighted by sqrt(2).
    log_mean = self.manifold.logm(self.mean)
    euclidean_mean = gs.hstack(
        (gs.diagonal(log_mean), root_two * gs.triu_to_vec(log_mean, k=1))
    )

    draws = gs.random.multivariate_normal(euclidean_mean, self.cov, (samples,))

    # Undo the sqrt(2) weighting and mirror it into both triangles.
    off_diagonal = draws[:, dim_n:] / root_two
    symmetric_draws = gs.mat_from_diag_triu_tril(
        diag=draws[:, :dim_n], tri_upp=off_diagonal, tri_low=off_diagonal
    )
    return self.manifold.expm(symmetric_draws)
def sample(self, n_samples):
    """Generate samples for SPD manifold.

    The construction depends on the metric of ``self.manifold``:

    * ``SPDMetricLogEuclidean``: the Euclidean coordinates of
      ``logm(self.mean)`` (diagonal, then sqrt(2)-weighted strict upper
      triangle) are the mean passed to ``self.samples_sym``. The resulting
      matrices are mapped through ``expm``.
    * any other metric: matrices are drawn around zero with
      ``self.samples_sym`` and mapped as ``M^{1/2} expm(S) M^{1/2}`` with
      ``M = self.mean``, so a zero draw returns ``self.mean`` exactly.

    Parameters
    ----------
    n_samples : int
        Number of samples to draw.

    Returns
    -------
    samples : array-like
        Sampled SPD matrices.
    """
    if isinstance(self.manifold.metric, SPDMetricLogEuclidean):
        sym_matrix = self.manifold.logm(self.mean)
        mean_euclidean = gs.hstack((
            gs.diagonal(sym_matrix)[None, :],
            gs.sqrt(2.0) * gs.triu_to_vec(sym_matrix, k=1)[None, :],
        ))[0]
        samples_sym = self.samples_sym(mean_euclidean, self.cov, n_samples)
        return self.manifold.expm(samples_sym)

    samples_sym = self.samples_sym(
        gs.zeros(self.manifold.dim), self.cov, n_samples)
    mean_half = self.manifold.powerm(self.mean, 0.5)
    # Exponentiate first, then conjugate by mean^{1/2}. The reverse order,
    # expm(M^{1/2} S M^{1/2}), sends S = 0 to the identity instead of the
    # mean, so the samples would not be centred on self.mean.
    return Matrices.mul(mean_half, self.manifold.expm(samples_sym), mean_half)
def to_vector(mat):
    """Flatten a symmetric matrix into a vector.

    If the input is not (entirely) symmetric, a warning is logged and the
    matrix is symmetrized with ``Matrices.to_symmetric`` before flattening.

    Parameters
    ----------
    mat : array-like, shape=[..., n, n]
        Matrix.

    Returns
    -------
    vec : array-like, shape=[..., n(n+1)/2]
        Vector produced by ``gs.triu_to_vec``.
    """
    if gs.all(Matrices.is_symmetric(mat)):
        return gs.triu_to_vec(mat)

    # Fall back to the symmetric part so the triangle read below is meaningful.
    logging.warning("non-symmetric matrix encountered.")
    return gs.triu_to_vec(Matrices.to_symmetric(mat))