def test_fails_Stiefel():
    """Check that ``Stiefel.random_naive`` rejects unsupported sizes.

    Both the scalar size ``()`` and the size ``(5, 10)`` are expected to
    raise ``ValueError``.
    """
    # Build the manifold outside ``pytest.raises`` so that only the call under
    # test can satisfy the expectation; a ValueError raised by the constructor
    # would otherwise make the test pass vacuously.
    manifold = geoopt.Stiefel()
    with pytest.raises(ValueError):
        manifold.random_naive(())
    with pytest.raises(ValueError):
        manifold.random_naive((5, 10))
def test_fails_Stiefel():
    """Check that ``Stiefel.origin`` rejects unsupported sizes.

    Both the scalar size ``()`` and the size ``(5, 10)`` are expected to
    raise ``ValueError``.
    """
    # Build the manifold outside ``pytest.raises`` so that only the call under
    # test can satisfy the expectation; a ValueError raised by the constructor
    # would otherwise make the test pass vacuously.
    manifold = geoopt.Stiefel()
    with pytest.raises(ValueError):
        manifold.origin(())
    with pytest.raises(ValueError):
        manifold.origin((5, 10))
def _product_unary_case(ex, ev, manifolds):
    """Build a ``UnaryCase`` for a ``ProductManifold`` composed of ``manifolds``.

    Args:
        ex: one raw point tensor per factor manifold; its shape is used as the
            factor's shape in the product.
        ev: one raw vector tensor per factor manifold.
        manifolds: the factor manifolds, in the same order as ``ex``/``ev``.

    Returns:
        A ``UnaryCase`` holding, in this order: the product's entry in
        ``manifold_shapes``, the packed projected points, the packed raw points,
        the packed projected vectors, the packed raw vectors, and the product
        manifold itself.
    """
    # Project every point first, then every vector at its projected point
    # (same call order as the original inline code).
    x = [manifold.projx(point) for manifold, point in zip(manifolds, ex)]
    v = [manifold.proju(point, vec) for manifold, point, vec in zip(manifolds, x, ev)]
    product_manifold = geoopt.ProductManifold(
        *((manifold, point.shape) for manifold, point in zip(manifolds, ex))
    )
    return UnaryCase(
        manifold_shapes[geoopt.ProductManifold],
        product_manifold.pack_point(*x),
        product_manifold.pack_point(*ex),
        product_manifold.pack_point(*v),
        product_manifold.pack_point(*ev),
        product_manifold,
    )


def product_case():
    """Yield ``UnaryCase`` fixtures for ``geoopt.ProductManifold``.

    Two cases are yielded, each generated right after ``torch.manual_seed(42)``:
    Sphere x PoincareBall x Stiefel x Euclidean, then the same product without
    the Stiefel factor.
    """
    torch.manual_seed(42)
    ex = [torch.randn(10), torch.randn(3) / 10, torch.randn(3, 2), torch.randn(())]
    ev = [torch.randn(10), torch.randn(3) / 10, torch.randn(3, 2), torch.randn(())]
    manifolds = [
        geoopt.Sphere(),
        geoopt.PoincareBall(),
        geoopt.Stiefel(),
        geoopt.Euclidean(),
    ]
    yield _product_unary_case(ex, ev, manifolds)

    # One more case without the Stiefel factor; reseed so its tensors do not
    # depend on the draws made for the first case.
    torch.manual_seed(42)
    ex = [torch.randn(10), torch.randn(3) / 10, torch.randn(())]
    ev = [torch.randn(10), torch.randn(3) / 10, torch.randn(())]
    manifolds = [
        geoopt.Sphere(),
        geoopt.PoincareBall(),
        geoopt.Euclidean(),
    ]
    yield _product_unary_case(ex, ev, manifolds)
def test_product():
    """Check that a random sample from a four-factor product manifold lies on it.

    The factors are Sphere(10), PoincareBall(3), Stiefel((20, 2)) and
    Euclidean(43). The sample has 20 rows of ``n_elements`` entries each.
    """
    factors = (
        (geoopt.Sphere(), 10),
        (geoopt.PoincareBall(), 3),
        (geoopt.Stiefel(), (20, 2)),
        (geoopt.Euclidean(), 43),
    )
    product = geoopt.ProductManifold(*factors)
    batch = product.random(20, product.n_elements)
    product.assert_check_point_on_manifold(batch)
def test_stiefel_3d():
    """Compare a batched Stiefel retraction against slice-by-slice retraction.

    A ``(2, 10, 20)`` ManifoldTensor is retracted along a projected random
    direction with step sizes ``t`` (one per leading index). The result must
    match calling ``manifold.retr`` on each leading slice separately, and it
    must be unchanged by ``projx``.
    """
    tens1 = geoopt.ManifoldTensor(2, 10, 20, manifold=geoopt.Stiefel()).normal_().proj_()
    vect1 = tens1.proju(torch.randn(*tens1.shape))
    t = torch.randn(tens1.shape[0])
    newt = tens1.retr(vect1, t)
    # Reference result: retract every leading slice on its own. Iterating over
    # the batch dimension (instead of hard-coding indices 0 and 1) keeps the
    # reference in sync if the batch size above changes.
    newt_manual = torch.stack(
        [
            tens1.manifold.retr(tens1[i], vect1[i], t[i])
            for i in range(tens1.shape[0])
        ]
    )
    numpy.testing.assert_allclose(newt_manual, newt, atol=1e-5)
    numpy.testing.assert_allclose(newt, tens1.manifold.projx(newt), atol=1e-5)
def __init__(self, in_features, out_features):
    """Create an ``(in_features, out_features)`` weight constrained to ``geoopt.Stiefel``.

    Args:
        in_features: number of rows of the weight; must be >= ``out_features``.
        out_features: number of columns of the weight.
    """
    super().__init__()
    # Presumably guarantees the weight is at least as tall as it is wide,
    # as orthonormal columns require.
    assert out_features <= in_features
    self.in_features, self.out_features = in_features, out_features
    self.shape = (in_features, out_features)
    self.manifold = geoopt.Stiefel()
    # Storage is allocated uninitialised here; presumably reset_parameters()
    # fills it (not visible in this chunk).
    self.weight = geoopt.ManifoldParameter(torch.empty(*self.shape), manifold=self.manifold)
    self.reset_parameters()
def __init__(self, in_features, out_features,device, bias=False):
    """Build a linear layer whose ``(out_features, in_features)`` weight lives on ``geoopt.Stiefel``.

    Args:
        in_features: input dimension (number of weight columns).
        out_features: output dimension (number of weight rows and bias length).
        device: device that the bias and, after initialisation, the weight are moved to.
        bias: when true, a learnable bias of length ``out_features`` is created;
            otherwise ``bias`` is registered as ``None``.
    """
    super().__init__()
    self.in_features = in_features
    self.out_features = out_features
    raw_weight = torch.Tensor(out_features, in_features)
    self.weight = geoopt.ManifoldParameter(data=raw_weight, manifold=geoopt.Stiefel())
    if not bias:
        self.register_parameter('bias', None)
    else:
        self.bias = torch.nn.Parameter(torch.Tensor(out_features).to(device))
    self.reset_parameters()
    # The weight is initialised on the default device and only afterwards has
    # its storage swapped for a copy on ``device``; the Parameter object itself
    # is kept.
    self.weight.data = self.weight.data.to(device)
def test_random_Stiefel():
    """Check a naive random Stiefel sample.

    A ``(3, 10, 10)`` sample must lie on the manifold and keep a reference to
    the exact manifold instance that produced it.
    """
    stiefel = geoopt.Stiefel()
    sample = stiefel.random_naive(3, 10, 10)
    stiefel.assert_check_point_on_manifold(sample)
    assert sample.manifold is stiefel
def __init__(self, mu, sigma):
    """Store a ``Normal(mu, sigma)`` distribution and a Stiefel parameter shaped like ``mu``.

    Args:
        mu: mean tensor of the distribution; also the shape template for ``x``.
        sigma: scale of the distribution.
    """
    super().__init__()
    self.d = torch.distributions.Normal(mu, sigma)
    # NOTE(review): the start value comes straight from randn_like and is not
    # projected, so ``x`` may begin off the manifold -- confirm this is intended.
    start = torch.randn_like(mu)
    self.x = geoopt.ManifoldParameter(start, manifold=geoopt.Stiefel())
def test_stiefel_2d():
    """Check that a unit-step retraction of a ``(10, 20)`` tensor is a fixed point of ``projx``.

    The tensor lives on ``geoopt.Stiefel`` and is retracted along a projected
    random direction.
    """
    point = geoopt.ManifoldTensor(10, 20, manifold=geoopt.Stiefel()).normal_().proj_()
    direction = point.proju(torch.randn(*point.shape))
    moved = point.retr(direction, 1.)
    reprojected = point.manifold.projx(moved)
    numpy.testing.assert_allclose(moved, reprojected, atol=1e-5)