def __init__(self, in_features, out_features):
    """Create the weight and bias parameters of a ``hyperDense`` layer.

    Parameters
    ----------
    in_features : int
        first dimension of the weight ``self.w``
    out_features : int
        second dimension of ``self.w`` and length of the bias ``self.b``

    Notes
    -----
    Neither tensor is given an explicit ``manifold=`` argument, so both use
    whatever default manifold ``gt.ManifoldTensor`` applies.
    """
    super(hyperDense, self).__init__()
    self.in_features = in_features
    self.out_features = out_features

    # Weight entries ~ U(-bound, bound), bound = sqrt(1 / in_features).
    bound = (1 / in_features) ** 0.5
    weight = gt.ManifoldTensor(in_features, out_features).uniform_(-bound, bound)
    self.w = gt.ManifoldParameter(weight)

    # Bias starts at all zeros.
    bias = gt.ManifoldTensor(out_features).zero_()
    self.b = gt.ManifoldParameter(bias)
def __init__(self, input_size, hidden_size):
    """Create the recurrent, input and bias parameters of a ``hyperRNN``.

    Parameters
    ----------
    input_size : int
        leading dimension of the input weight ``self.u``
    hidden_size : int
        hidden dimension; also sets the uniform init bound

    Notes
    -----
    Shapes: ``w`` is (hidden_size, 2, hidden_size, 2), ``u`` is
    (input_size, 2, hidden_size, 2), ``b`` is (hidden_size, 2) and lives on a
    ``gt.PoincareBall()``.
    """
    super(hyperRNN, self).__init__()
    self.input_size = input_size
    self.hidden_size = hidden_size

    bound = (1 / hidden_size) ** 0.5

    # Recurrent weight is drawn before the input weight; the order fixes RNG use.
    recurrent = gt.ManifoldTensor(hidden_size, 2, hidden_size, 2).uniform_(-bound, bound)
    self.w = gt.ManifoldParameter(recurrent)

    incoming = gt.ManifoldTensor(input_size, 2, hidden_size, 2).uniform_(-bound, bound)
    self.u = gt.ManifoldParameter(incoming)

    # Zero bias on the Poincare ball.
    ball_bias = gt.ManifoldTensor(hidden_size, 2, manifold=gt.PoincareBall()).zero_()
    self.b = gt.ManifoldParameter(ball_bias)
def __init__(self, input_size, hidden_size, ball):
    """Create the gate weights and biases of a GRU-style cell on ``ball``.

    Parameters
    ----------
    input_size : int
        rows of the input-to-hidden weights ``u_*``
    hidden_size : int
        hidden dimension; rows of ``w_*``, columns of every weight, and
        length of every bias
    ball
        manifold attached to the three biases ``b_*`` (stored as ``self.ball``)

    Notes
    -----
    The suffixes ``z``, ``r`` and ``h`` presumably name the update, reset and
    candidate gates; confirm this against the forward pass.
    """
    super().__init__()
    self.input_size = input_size
    self.hidden_size = hidden_size
    self.ball = ball
    bound = (1 / hidden_size) ** 0.5

    def _uniform_weight(rows):
        # (rows, hidden_size) weight drawn from U(-bound, bound).
        return gt.ManifoldParameter(
            gt.ManifoldTensor(rows, hidden_size).uniform_(-bound, bound))

    def _zero_bias():
        # Zero vector of length hidden_size attached to self.ball.
        return gt.ManifoldParameter(
            gt.ManifoldTensor(hidden_size, manifold=self.ball).zero_())

    # Hidden-to-hidden weights. The creation order below fixes RNG consumption.
    self.w_z = _uniform_weight(hidden_size)
    self.w_r = _uniform_weight(hidden_size)
    self.w_h = _uniform_weight(hidden_size)
    # Input-to-hidden weights.
    self.u_z = _uniform_weight(input_size)
    self.u_r = _uniform_weight(input_size)
    self.u_h = _uniform_weight(input_size)
    # Biases on the ball.
    self.b_z = _zero_bias()
    self.b_r = _zero_bias()
    self.b_h = _zero_bias()
def __init__(self, input_size, hidden_size):
    """Create the gate weights and biases of ``GRUCell``.

    Parameters
    ----------
    input_size : int
        rows of the input-to-hidden weights ``u_*``
    hidden_size : int
        hidden dimension; rows of ``w_*``, columns of every weight, and
        length of every bias

    Notes
    -----
    Each bias gets its own freshly constructed ``gt.PoincareBall()``.
    The suffixes ``z``, ``r`` and ``h`` presumably name the update, reset and
    candidate gates; confirm this against the forward pass.
    """
    super(GRUCell, self).__init__()
    self.input_size = input_size
    self.hidden_size = hidden_size
    bound = (1 / hidden_size) ** 0.5

    def _uniform_weight(rows):
        # (rows, hidden_size) weight drawn from U(-bound, bound).
        return gt.ManifoldParameter(
            gt.ManifoldTensor(rows, hidden_size).uniform_(-bound, bound))

    def _zero_ball_bias():
        # A new PoincareBall per call, matching one instance per bias.
        return gt.ManifoldParameter(
            gt.ManifoldTensor(hidden_size, manifold=gt.PoincareBall()).zero_())

    # Hidden-to-hidden weights. The creation order below fixes RNG consumption.
    self.w_z = _uniform_weight(hidden_size)
    self.w_r = _uniform_weight(hidden_size)
    self.w_h = _uniform_weight(hidden_size)
    # Input-to-hidden weights.
    self.u_z = _uniform_weight(input_size)
    self.u_r = _uniform_weight(input_size)
    self.u_h = _uniform_weight(input_size)
    # Zero biases on the Poincare ball.
    self.b_z = _zero_ball_bias()
    self.b_r = _zero_ball_bias()
    self.b_h = _zero_ball_bias()
def unary_case(manifold):
    """Build a ``UnaryCase`` whose base point comes from the ``mannopt`` reference.

    The reference manifold class is looked up in ``mannopt`` by
    ``type(manifold)``; a missing entry raises ``KeyError``. NumPy's RNG
    (seed 42) produces ``x``. Torch's RNG (seed 43) produces ``ex``, the
    tangent ``v`` and ``ev``.
    """
    manifold_type = type(manifold)
    shape = shapes[manifold_type]
    reference = mannopt[manifold_type](*shape)

    np.random.seed(42)
    sample = reference.rand()
    x = geoopt.ManifoldTensor(torch.from_numpy(sample), manifold=manifold)

    torch.manual_seed(43)
    ex = geoopt.ManifoldTensor(torch.randn_like(x), manifold=manifold)
    v = x.proju(torch.randn_like(x))
    ev = torch.randn_like(x)
    return UnaryCase(shape, x, ex, v, ev, manifold, reference)
def __init__(self, input_size, hidden_size):
    """Create the Euclidean RNN parameters.

    Parameters
    ----------
    input_size : int
        rows of the input weight ``self.u``
    hidden_size : int
        hidden dimension (shape of ``self.w``, columns of ``self.u``,
        length of ``self.b``)

    Notes
    -----
    Weights use Xavier-uniform bounds ``sqrt(6 / (fan_in + fan_out))``. The
    bias is a ``torch.randn`` draw scaled by 1e-5, attached to
    ``gt.Euclidean()``.
    """
    super(EuclRNN, self).__init__()
    self.manifold = gt.Euclidean()
    self.input_size = input_size
    self.hidden_size = hidden_size

    bound_w = (6 / (self.hidden_size + self.hidden_size)) ** 0.5
    bound_u = (6 / (self.input_size + self.hidden_size)) ** 0.5

    # w, then u, then the bias: this order fixes RNG consumption.
    hidden_init = gt.ManifoldTensor(hidden_size, hidden_size).uniform_(-bound_w, bound_w)
    self.w = gt.ManifoldParameter(hidden_init)
    input_init = gt.ManifoldTensor(input_size, hidden_size).uniform_(-bound_u, bound_u)
    self.u = gt.ManifoldParameter(input_init)

    tiny_bias = torch.randn(hidden_size) * 1e-5
    self.b = gt.ManifoldParameter(tiny_bias, manifold=self.manifold)
def __init__(self, input_size, hidden_size):
    """Create the Mobius RNN parameters.

    Parameters
    ----------
    input_size : int
        rows of the input weight ``self.u``
    hidden_size : int
        hidden dimension (shape of ``self.w``, columns of ``self.u``,
        length of ``self.b``)

    Notes
    -----
    Weights use Xavier-uniform bounds ``sqrt(6 / (fan_in + fan_out))``. The
    bias is a ``torch.randn`` draw scaled by 1e-5, mapped with
    ``pmath.expmap0`` at curvature ``self.ball.k`` and attached to ``self.ball``.
    """
    super(MobiusRNN, self).__init__()
    self.ball = gt.PoincareBall()
    self.input_size = input_size
    self.hidden_size = hidden_size

    bound_w = (6 / (self.hidden_size + self.hidden_size)) ** 0.5
    bound_u = (6 / (self.input_size + self.hidden_size)) ** 0.5

    # w, then u, then the bias: this order fixes RNG consumption.
    hidden_init = gt.ManifoldTensor(hidden_size, hidden_size).uniform_(-bound_w, bound_w)
    self.w = gt.ManifoldParameter(hidden_init)
    input_init = gt.ManifoldTensor(input_size, hidden_size).uniform_(-bound_u, bound_u)
    self.u = gt.ManifoldParameter(input_init)

    tangent_bias = torch.randn(hidden_size) * 1e-5
    ball_bias = pmath.expmap0(tangent_bias, k=self.ball.k)
    self.b = gt.ManifoldParameter(ball_bias, manifold=self.ball)
def sphere_projection_case():
    """Yield ``UnaryCase`` fixtures for SphereProjection and SphereProjectionExact.

    ``ex`` and ``ev`` are float64 normal samples scaled by 1/3. The sample is
    used unprojected as the point ``x``, so the expected projection ``ex``
    equals ``x`` and the expected tangent ``v`` equals ``ev``. Both manifolds
    are cast to float64 and share the same data.
    """
    torch.manual_seed(42)
    shape = manifold_shapes[geoopt.manifolds.SphereProjection]
    ex = torch.randn(*shape, dtype=torch.float64) / 3
    ev = torch.randn(*shape, dtype=torch.float64) / 3
    # NOTE(review): no projection is applied here. This relies on the
    # SphereProjection default curvature accepting any point of R^n
    # (presumably a positive curvature, where the stereographic model covers
    # all of R^n). The default is not 0 as for Stereographic; confirm against
    # the manifold's __init__.
    x = ex
    ex = x.clone()
    v = ev.clone()
    manifold = geoopt.manifolds.SphereProjection().to(dtype=torch.float64)
    x = geoopt.ManifoldTensor(x, manifold=manifold)
    case = UnaryCase(shape, x, ex, v, ev, manifold)
    yield case
    manifold = geoopt.manifolds.SphereProjectionExact().to(dtype=torch.float64)
    x = geoopt.ManifoldTensor(x, manifold=manifold)
    case = UnaryCase(shape, x, ex, v, ev, manifold)
    yield case
def sphere_case():
    """Yield ``UnaryCase`` fixtures for ``geoopt.Sphere`` and ``geoopt.SphereExact``.

    ``x`` is ``ex`` divided by its norm. ``v`` is ``ev`` with its component
    along ``x`` removed (``ev - (x @ ev) * x``). Data is float64.
    """
    torch.manual_seed(42)
    shape = manifold_shapes[geoopt.manifolds.Sphere]
    ex = torch.randn(*shape, dtype=torch.float64)
    ev = torch.randn(*shape, dtype=torch.float64)

    x = ex / torch.norm(ex)
    v = ev - (x @ ev) * x

    for manifold_cls in (geoopt.Sphere, geoopt.SphereExact):
        manifold = manifold_cls()
        x = geoopt.ManifoldTensor(x, manifold=manifold)
        yield UnaryCase(shape, x, ex, v, ev, manifold)
def bounded_domain_case():
    """Yield a ``UnaryCase`` for the Siegel ``BoundedDomain`` manifold.

    ``ex`` is a random complex128 matrix passed through ``sym``. ``x`` is
    rebuilt from ``takagi_eig(ex)`` with the eigenvalues clamped to at most
    ``1 - 1e-5``. ``ev`` is a symmetrised random matrix scaled by 1/10.
    ``v = sym(a @ ev @ a)`` with ``a = I - conj(x) @ x``.
    """
    torch.manual_seed(42)
    shape = manifold_shapes[geoopt.manifolds.BoundedDomain]
    sym = geoopt.linalg.batch_linalg.sym
    csym = geoopt.manifolds.siegel.csym_math

    # Random symmetric sample; x is its projection.
    ex = sym(torch.randn(*shape, dtype=torch.complex128))
    x = ex.clone()
    eigvals, basis = csym.takagi_eig(x)
    clamped = torch.clamp(eigvals, max=1 - 1e-5)
    diag = torch.diag_embed(clamped).type_as(x)
    x = basis.conj() @ diag @ basis.conj().transpose(-1, -2)

    # Tangent sample and its expected projection at x.
    ev = sym(torch.randn(*shape, dtype=torch.complex128) / 10)
    eye = csym.identity_like(x)
    factor = eye - (x.conj() @ x)
    v = sym(factor @ ev @ factor)

    manifold = geoopt.BoundedDomain()
    yield UnaryCase(shape, geoopt.ManifoldTensor(x, manifold=manifold), ex, v, ev, manifold)
def poincare_case():
    """Yield ``UnaryCase`` fixtures for ``PoincareBall`` and ``PoincareBallExact``.

    ``x`` rescales the float64 sample ``ex`` to norm ``tanh(||ex||)``, so it
    lies strictly inside the unit ball. ``ex`` is then replaced by a copy of
    ``x``, and the expected tangent ``v`` is a copy of ``ev``. Both manifolds
    are cast to float64.
    """
    torch.manual_seed(42)
    shape = manifold_shapes[geoopt.manifolds.PoincareBall]
    ex = torch.randn(*shape, dtype=torch.float64) / 3
    ev = torch.randn(*shape, dtype=torch.float64) / 3

    length = torch.norm(ex)
    x = torch.tanh(length) * ex / length
    ex = x.clone()
    v = ev.clone()

    for manifold_cls in (geoopt.PoincareBall, geoopt.PoincareBallExact):
        manifold = manifold_cls().to(dtype=torch.float64)
        x = geoopt.ManifoldTensor(x, manifold=manifold)
        yield UnaryCase(shape, x, ex, v, ev, manifold)
def birkhoff_case():
    """Yield a ``UnaryCase`` for the ``BirkhoffPolytope`` manifold.

    ``ex`` is the absolute value of a float64 normal sample. ``x = ex * (r @ c)``
    is obtained by alternating row/column rescaling (Sinkhorn-style). The loop
    stops after ``max_iter`` rounds or once ``max|cinv * c - 1| <= tol``.
    ``v`` is the reference projection ``proju_original(x, ev)``.
    """
    torch.manual_seed(42)
    shape = manifold_shapes[geoopt.manifolds.BirkhoffPolytope]
    ex = torch.randn(*shape, dtype=torch.float64).abs()
    ev = torch.randn(*shape, dtype=torch.float64)

    max_iter = 100
    eps = 1e-12  # keeps the reciprocals finite
    tol = 1e-5

    def _row_scale(col_scale):
        # r = 1 / (ex @ c^T + eps)
        return 1.0 / ((ex @ col_scale.transpose(-1, -2)) + eps)

    c = 1.0 / (torch.sum(ex, dim=-2, keepdim=True) + eps)
    r = _row_scale(c)
    for _ in range(max_iter):
        cinv = torch.matmul(r.transpose(-1, -2), ex)
        if torch.max(torch.abs(cinv * c - 1)) <= tol:
            break
        c = 1.0 / (cinv + eps)
        r = _row_scale(c)

    x = ex * (r @ c)
    v = proju_original(x, ev)
    manifold = geoopt.manifolds.BirkhoffPolytope()
    yield UnaryCase(shape, geoopt.ManifoldTensor(x, manifold=manifold), ex, v, ev, manifold)
def origin(self, *size, dtype=None, device=None, seed=42) -> "geoopt.ManifoldTensor":
    """
    Zero point origin.

    The returned tensor is all zeros except the first entry of the last
    dimension, which is set to ``sqrt(self.k)``.

    Parameters
    ----------
    size : shape
        the desired shape
    device : torch.device
        the desired device; defaults to ``self.k.device``
    dtype : torch.dtype
        the desired dtype; defaults to ``self.k.dtype``
    seed : int
        ignored

    Returns
    -------
    ManifoldTensor
        zero point on the manifold
    """
    dtype = self.k.dtype if dtype is None else dtype
    device = self.k.device if device is None else device

    point = torch.zeros(*size, dtype=dtype, device=device)
    point[..., 0] = torch.sqrt(self.k)
    return geoopt.ManifoldTensor(point, manifold=self)
def random_normal(
    self, *size, mean=0.0, std=1.0, device=None, dtype=None
) -> "geoopt.ManifoldTensor":
    """
    Create a point on the manifold by sampling ``Normal(mean, std)``.

    The shape is first validated with ``self._assert_check_shape``. A
    standard-normal tensor of the requested size is then scaled by ``std``
    and shifted by ``mean``.

    Parameters
    ----------
    size : shape
        the desired shape
    mean : float|tensor
        mean value for the Normal distribution
    std : float|tensor
        std value for the Normal distribution
    device : torch.device
        the desired device
    dtype : torch.dtype
        the desired dtype

    Returns
    -------
    ManifoldTensor
        random point on the manifold
    """
    self._assert_check_shape(size2shape(*size), "x")
    mean = torch.as_tensor(mean, device=device, dtype=dtype)
    std = torch.as_tensor(std, device=device, dtype=dtype)
    # new_empty inherits std's dtype/device for the raw standard-normal draw.
    standard = std.new_empty(*size).normal_()
    return geoopt.ManifoldTensor(standard * std + mean, manifold=self)
def test_deepcopy():
    """``copy.deepcopy`` must keep the geoopt tensor subclass of each object."""
    for tensor_cls in (geoopt.ManifoldTensor, geoopt.ManifoldParameter):
        duplicate = copy.deepcopy(tensor_cls())
        assert isinstance(duplicate, tensor_cls)
def origin(self, *size, dtype=None, device=None, seed=42) -> "geoopt.ManifoldTensor":
    """
    Zero point origin.

    Parameters
    ----------
    size : shape
        the desired shape
    device : torch.device
        the desired device
    dtype : torch.dtype
        the desired dtype
    seed : int
        ignored

    Returns
    -------
    ManifoldTensor
        all-zeros tensor of the requested shape, attached to this manifold
    """
    zeros = torch.zeros(*size, dtype=dtype, device=device)
    return geoopt.ManifoldTensor(zeros, manifold=self)
def test_compare_manifolds():
    """ManifoldParameter must reject a tensor whose manifold differs from ``manifold=``.

    ``Euclidean()`` and ``Euclidean(ndim=1)`` are expected to compare unequal,
    so wrapping should raise ``ValueError("Manifolds do not match...")``.
    """
    plain_euclid = geoopt.Euclidean()
    vector_euclid = geoopt.Euclidean(ndim=1)
    tensor = geoopt.ManifoldTensor(10, manifold=plain_euclid)
    with pytest.raises(ValueError, match="Manifolds do not match"):
        geoopt.ManifoldParameter(tensor, manifold=vector_euclid)
def stereographic_case():
    """Yield ``UnaryCase`` fixtures for ``Stereographic`` and ``StereographicExact``.

    The float64 sample is used unprojected as the point ``x``. The default
    curvature is 0 (flat), so any point is valid. Hence ``ex`` is a copy of
    ``x`` and ``v`` is a copy of ``ev``. Both manifolds are cast to float64.
    """
    torch.manual_seed(42)
    shape = manifold_shapes[geoopt.manifolds.Stereographic]
    ex = torch.randn(*shape, dtype=torch.float64) / 3
    ev = torch.randn(*shape, dtype=torch.float64) / 3

    x = ex
    ex = x.clone()
    v = ev.clone()

    for manifold_cls in (geoopt.Stereographic, geoopt.StereographicExact):
        manifold = manifold_cls().to(dtype=torch.float64)
        x = geoopt.ManifoldTensor(x, manifold=manifold)
        yield UnaryCase(shape, x, ex, v, ev, manifold)
def origin(
    self, *size, dtype=None, device=None, seed=42
) -> "geoopt.ManifoldTensor":
    """
    Zero point origin.

    The requested shape is validated with ``self._assert_check_shape``
    before the all-zeros tensor is built.

    Parameters
    ----------
    size : shape
        the desired shape
    device : torch.device
        the desired device
    dtype : torch.dtype
        the desired dtype
    seed : int
        ignored

    Returns
    -------
    ManifoldTensor
        all-zeros tensor attached to this manifold
    """
    self._assert_check_shape(size2shape(*size), "x")
    zeros = torch.zeros(*size, dtype=dtype, device=device)
    return geoopt.ManifoldTensor(zeros, manifold=self)
def unary_case(manifold):
    """Build a ``UnaryCase`` for ``manifold``.

    When ``type(manifold)`` has an entry in ``mannopt``, the base point ``x``
    is that reference manifold's ``rand()`` cast to float64. Otherwise ``x``
    is ``0.1 * randn`` in float64 and the reference manifold is ``None``.
    Seeds: NumPy 42, torch 43 (both set before any sampling).
    """
    manifold_type = type(manifold)
    shape = shapes[manifold_type]
    np.random.seed(42)
    torch.manual_seed(43)

    if manifold_type not in mannopt:
        reference = None
        base = torch.randn(shape, dtype=torch.float64) * 0.1
    else:
        reference = mannopt[manifold_type](*shape)
        base = torch.from_numpy(reference.rand().astype("float64"))
    x = geoopt.ManifoldTensor(base, manifold=manifold)

    ex = geoopt.ManifoldTensor(torch.randn_like(x), manifold=manifold)
    v = x.proju(torch.randn_like(x))
    ev = torch.randn_like(x)
    return UnaryCase(shape, x, ex, v, ev, manifold, reference)
def euclidean_stiefel_case():
    """Yield ``UnaryCase`` fixtures for ``EuclideanStiefel`` and its Exact variant.

    ``x`` is the product ``U @ Vh`` of the thin SVD of ``ex``. ``v`` is
    ``ev - x @ sym(x^T ev)``, where ``sym(A) = (A + A^T) / 2``.
    """
    torch.manual_seed(42)
    shape = manifold_shapes[geoopt.manifolds.EuclideanStiefel]
    ex = torch.randn(*shape, dtype=torch.float64)
    ev = torch.randn(*shape, dtype=torch.float64)

    left, _, right_h = torch.linalg.svd(ex, full_matrices=False)
    x = torch.einsum("...ik,...kj->...ij", left, right_h)

    xt_ev = x.t() @ ev
    v = ev - x @ (xt_ev + xt_ev.t()) / 2

    for manifold_cls in (
        geoopt.manifolds.EuclideanStiefel,
        geoopt.manifolds.EuclideanStiefelExact,
    ):
        manifold = manifold_cls()
        x = geoopt.ManifoldTensor(x, manifold=manifold)
        yield UnaryCase(shape, x, ex, v, ev, manifold)
def euclidean_stiefel_case():
    """Yield ``UnaryCase`` fixtures for ``EuclideanStiefel`` and its Exact variant.

    ``x`` is ``U @ Vh`` from the reduced SVD of ``ex``. ``v`` is
    ``ev - x @ sym(x^T ev)``, where ``sym(A) = (A + A^T) / 2``.
    """
    torch.manual_seed(42)
    shape = manifold_shapes[geoopt.manifolds.EuclideanStiefel]
    ex = torch.randn(*shape, dtype=torch.float64)
    ev = torch.randn(*shape, dtype=torch.float64)
    # torch.svd is deprecated. torch.linalg.svd with full_matrices=False is the
    # reduced SVD (torch.svd's some=True) and returns Vh = V^T directly, so
    # u @ vh equals the former u @ v.t().
    u, _, vh = torch.linalg.svd(ex, full_matrices=False)
    x = u @ vh
    nonsym = x.t() @ ev
    v = ev - x @ (nonsym + nonsym.t()) / 2
    manifold = geoopt.manifolds.EuclideanStiefel()
    x = geoopt.ManifoldTensor(x, manifold=manifold)
    case = UnaryCase(shape, x, ex, v, ev, manifold)
    yield case
    manifold = geoopt.manifolds.EuclideanStiefelExact()
    x = geoopt.ManifoldTensor(x, manifold=manifold)
    case = UnaryCase(shape, x, ex, v, ev, manifold)
    yield case
def euclidean_case():
    """Yield a ``UnaryCase`` for ``geoopt.Euclidean(ndim=1)``.

    Everything is valid in Euclidean space, so ``x`` is a copy of ``ex`` and
    ``v`` is a copy of ``ev`` (float64).
    """
    torch.manual_seed(42)
    shape = manifold_shapes[geoopt.manifolds.Euclidean]
    ex = torch.randn(*shape, dtype=torch.float64)
    ev = torch.randn(*shape, dtype=torch.float64)

    manifold = geoopt.Euclidean(ndim=1)
    point = geoopt.ManifoldTensor(ex.clone(), manifold=manifold)
    yield UnaryCase(shape, point, ex, ev.clone(), ev, manifold)
def test_pickle1():
    """Round-trip a Sphere ``ManifoldTensor`` through ``torch.save`` / ``torch.load``.

    Checks that the subclass, stride, storage offset, ``requires_grad``,
    values and manifold type all survive.
    """
    source = geoopt.ManifoldTensor(torch.ones(10), manifold=geoopt.Sphere())
    with tempfile.TemporaryDirectory() as workdir:
        target = os.path.join(workdir, "tens.t7")
        torch.save(source, target)
        # NOTE(review): recent PyTorch releases default torch.load to
        # weights_only=True, which may refuse this tensor subclass. Confirm
        # against the pinned torch version.
        restored = torch.load(target)
        assert isinstance(restored, geoopt.ManifoldTensor)
        assert source.stride() == restored.stride()
        assert source.storage_offset() == restored.storage_offset()
        assert source.requires_grad == restored.requires_grad
        np.testing.assert_allclose(source.detach(), restored.detach())
        assert isinstance(source.manifold, type(restored.manifold))
def test_stiefel_3d():
    """Batched Stiefel ``retr`` must match per-slice ``retr`` and stay on the manifold.

    Uses a batch of two slices, each with its own step ``t[i]``.
    """
    tens1 = geoopt.ManifoldTensor(2, 10, 20, manifold=geoopt.Stiefel()).normal_().proj_()
    vect1 = tens1.proju(torch.randn(*tens1.shape))
    t = torch.randn(tens1.shape[0])

    newt = tens1.retr(vect1, t)
    stiefel = tens1.manifold
    newt_manual = torch.stack(
        [stiefel.retr(tens1[i], vect1[i], t[i]) for i in range(2)]
    )

    numpy.testing.assert_allclose(newt_manual, newt, atol=1e-5)
    numpy.testing.assert_allclose(newt, stiefel.projx(newt), atol=1e-5)
def canonical_stiefel_case():
    """Yield a ``UnaryCase`` for ``CanonicalStiefel`` (torch default dtype).

    ``x`` is ``U @ Vh`` from the thin SVD of ``ex``. ``v`` is
    ``ev - x @ ev^T @ x``.
    """
    torch.manual_seed(42)
    shape = manifold_shapes[geoopt.manifolds.CanonicalStiefel]
    ex = torch.randn(*shape)
    ev = torch.randn(*shape)

    left, _, right_h = torch.linalg.svd(ex, full_matrices=False)
    x = torch.einsum("...ik,...kj->...ij", left, right_h)
    v = ev - x @ ev.t() @ x

    manifold = geoopt.manifolds.CanonicalStiefel()
    yield UnaryCase(shape, geoopt.ManifoldTensor(x, manifold=manifold), ex, v, ev, manifold)
def canonical_stiefel_case():
    """Yield a ``UnaryCase`` for ``CanonicalStiefel`` (torch default dtype).

    ``x`` is ``U @ Vh`` from the reduced SVD of ``ex``. ``v`` is
    ``ev - x @ ev^T @ x``.
    """
    torch.manual_seed(42)
    shape = manifold_shapes[geoopt.manifolds.CanonicalStiefel]
    ex = torch.randn(*shape)
    ev = torch.randn(*shape)
    # torch.svd is deprecated. torch.linalg.svd with full_matrices=False is the
    # reduced SVD and returns Vh = V^T directly, so u @ vh equals u @ v.t().
    u, _, vh = torch.linalg.svd(ex, full_matrices=False)
    x = u @ vh
    v = ev - x @ ev.t() @ x
    manifold = geoopt.manifolds.CanonicalStiefel()
    x = geoopt.ManifoldTensor(x, manifold=manifold)
    case = UnaryCase(shape, x, ex, v, ev, manifold)
    yield case
def sphere_subspace_case():
    """Yield Sphere/SphereExact fixtures restricted to a random 2-D subspace.

    ``P = Q @ Q^T``, where ``Q`` comes from a reduced QR of the random
    ``subspace`` matrix. ``x`` is ``ex @ P^T`` normalised to unit norm. ``v``
    is ``ev`` with its ``x`` component removed, then multiplied by ``P^T``.
    """
    torch.manual_seed(42)
    shape = manifold_shapes[geoopt.manifolds.Sphere]
    subspace = torch.rand(shape[-1], 2, dtype=torch.float64)
    basis, _ = geoopt.linalg.batch_linalg.qr(subspace, "reduced")
    proj = basis @ basis.t()

    ex = torch.randn(*shape, dtype=torch.float64)
    ev = torch.randn(*shape, dtype=torch.float64)

    ex_in_subspace = ex @ proj.t()
    x = ex_in_subspace / torch.norm(ex_in_subspace)
    v = (ev - (x @ ev) * x) @ proj.t()

    for manifold_cls in (geoopt.Sphere, geoopt.SphereExact):
        manifold = manifold_cls(intersection=subspace)
        x = geoopt.ManifoldTensor(x, manifold=manifold)
        yield UnaryCase(shape, x, ex, v, ev, manifold)
def spd_case():
    """Yield a ``UnaryCase`` for ``geoopt.SymmetricPositiveDefinite(2)``.

    ``x`` is ``sym_funcm(sym(ex), torch.abs)``. Presumably this applies
    ``abs`` to the eigenvalues of the symmetrised sample; confirm in
    ``batch_linalg``. ``v`` is the symmetric part of ``ev``. Data is float64.
    """
    torch.manual_seed(42)
    shape = manifold_shapes[geoopt.manifolds.SymmetricPositiveDefinite]
    ex = torch.randn(*shape, dtype=torch.float64)
    ev = torch.randn(*shape, dtype=torch.float64)

    linalg = geoopt.linalg.batch_linalg
    x = linalg.sym_funcm(linalg.sym(ex), torch.abs)
    v = linalg.sym(ev)

    manifold = geoopt.SymmetricPositiveDefinite(2)
    yield UnaryCase(shape, geoopt.ManifoldTensor(x, manifold=manifold), ex, v, ev, manifold)
def sphere_compliment_case():
    """Yield Sphere/SphereExact fixtures restricted to the complement of a random line.

    ``P = I - Q @ Q^T``, where ``Q`` comes from a reduced QR of the random
    ``complement`` column. ``x`` is ``ex @ P^T`` normalised to unit norm.
    ``v`` is ``ev`` with its ``x`` component removed, then multiplied by
    ``P^T``.
    """
    torch.manual_seed(42)
    shape = manifold_shapes[geoopt.manifolds.Sphere]
    complement = torch.rand(shape[-1], 1, dtype=torch.float64)
    basis, _ = geoopt.linalg.batch_linalg.qr(complement, "reduced")

    # Build I - Q Q^T by negating Q Q^T and adding 1 on the diagonal.
    proj = -basis @ basis.transpose(-1, -2)
    diag_idx = torch.arange(proj.shape[-2])
    proj[..., diag_idx, diag_idx] += 1

    ex = torch.randn(*shape, dtype=torch.float64)
    ev = torch.randn(*shape, dtype=torch.float64)

    ex_in_complement = ex @ proj.t()
    x = ex_in_complement / torch.norm(ex_in_complement)
    v = (ev - (x @ ev) * x) @ proj.t()

    for manifold_cls in (geoopt.Sphere, geoopt.SphereExact):
        manifold = manifold_cls(complement=complement)
        x = geoopt.ManifoldTensor(x, manifold=manifold)
        yield UnaryCase(shape, x, ex, v, ev, manifold)