def test_weighted_centroids_zeros(method, train, zero):
    """Run WeightedPoincareCentroids on all-zero weights.

    A ``(2, 5, 17)`` zero tensor is passed through the layer; the result must
    have shape ``(2, 5, 7)`` and pass the ball's on-manifold check.
    """
    manifold = geoopt.PoincareBall()
    zero_weights = torch.zeros(2, 5, 17)
    centroid_layer = geoopt_layers.poincare.WeightedPoincareCentroids(
        7, 17, method=method, ball=manifold, zero=zero
    )
    centroid_layer = centroid_layer.train(train)
    result = centroid_layer(zero_weights)
    assert result.shape == (2, 5, 7)
    manifold.assert_check_point_on_manifold(result)
def test_dist_centroids_2d_multi(squared, train, zero):
    """Check Distance2PoincareCentroids2d output on a 5-d input.

    The output must contain no NaNs and have shape ``(2, 3, 10, 5, 5)``.
    """
    manifold = geoopt.PoincareBall()
    # Sample with the point coordinates last, then move them to axis 2.
    samples = manifold.random(2, 3, 5, 5, 2)
    samples = samples.permute(0, 1, 4, 2, 3)
    dist_layer = geoopt_layers.poincare.Distance2PoincareCentroids2d(
        centroid_shape=2,
        num_centroids=10,
        ball=manifold,
        squared=squared,
        zero=zero,
    )
    dist_layer = dist_layer.train(train)
    distances = dist_layer(samples)
    assert not torch.isnan(distances).any()
    assert distances.shape == (2, 3, 10, 5, 5)
def test_linear_new_ball1_origin():
    """MobiusLinear with a learnable origin must return points on the ball.

    The same ball instance is used for both ``ball`` and ``ball_out``.
    """
    manifold = geoopt.PoincareBall()
    samples = manifold.random(2, 3, 5)
    linear = geoopt_layers.poincare.MobiusLinear(
        5, 5, ball=manifold, ball_out=manifold, learn_origin=True
    )
    result = linear(samples)
    manifold.assert_check_point_on_manifold(result)
def test_batch_norm(bias, train):
    """MobiusBatchNorm keeps the ``(2, 5)`` shape and stays on the ball."""
    manifold = geoopt.PoincareBall()
    samples = manifold.random(2, 5)
    norm_layer = geoopt_layers.poincare.MobiusBatchNorm(
        5, ball=manifold, bias=bias
    )
    norm_layer.train(train)
    result = norm_layer(samples)
    assert result.shape == (2, 5)
    manifold.assert_check_point_on_manifold(result)
def test_random_init_mobius_conv():
    """A freshly initialised MobiusConv2d maps ball points to ball points.

    The expected output shape for the ``(3, 2, 5, 3, 3)`` input is
    ``(3, 4, 7, 1, 1)``.
    """
    manifold = geoopt.PoincareBall()
    # Build the layer before sampling so random-state consumption order is kept.
    mobius_conv = geoopt_layers.poincare.MobiusConv2d(
        5, 3, dim_out=7, points_in=2, points_out=4, ball=manifold
    )
    samples = manifold.random(3, 2, 3, 3, 5)
    samples = samples.permute(0, 1, 4, 2, 3)
    result = mobius_conv(samples)
    assert result.shape == (3, 4, 7, 1, 1)
    # Move the point coordinates back to the last axis for the manifold check.
    manifold.assert_check_point_on_manifold(result.permute(0, 1, 3, 4, 2))
def test_product():
    """The origin of a four-factor ProductManifold lies on that manifold."""
    factors = (
        (geoopt.Sphere(), 10),
        (geoopt.PoincareBall(), 3),
        (geoopt.Stiefel(), (20, 2)),
        (geoopt.Euclidean(), 43),
    )
    product = geoopt.ProductManifold(*factors)
    origin_batch = product.origin(20, product.n_elements)
    product.assert_check_point_on_manifold(origin_batch)
def test_batch_norm_2d_multi(bias, train):
    """MobiusBatchNorm2d with shape ``(3, 2)`` preserves shape and manifold."""
    manifold = geoopt.PoincareBall()
    samples = manifold.random(2, 3, 5, 5, 2)
    samples = samples.permute(0, 1, 4, 2, 3)
    norm_layer = geoopt_layers.poincare.MobiusBatchNorm2d(
        (3, 2), ball=manifold, bias=bias
    )
    norm_layer.train(train)
    result = norm_layer(samples)
    assert result.shape == (2, 3, 2, 5, 5)
    # Point coordinates go back to the last axis before the manifold check.
    manifold.assert_check_point_on_manifold(result.permute(0, 1, 3, 4, 2))
def __init__(self, c, args):
    """Set up the Euclidean and hyperbolic classification heads.

    Parameters
    ----------
    c
        Curvature value; passed to the parent constructor, stored on
        ``self.c`` and used to build the Poincare ball.
    args
        Configuration object. Attributes read here: ``manifold``, ``dim``,
        ``act``, ``dataset``, ``n_classes``, ``concat`` (non-pubmed only),
        ``alpha``, ``dropout``, ``bias``, ``atten``, ``dist``, ``device``,
        ``drop_e`` and ``drop_h``.
    """
    super(DualDecoder, self).__init__(c)
    self.manifold = getattr(manifolds, args.manifold)()
    self.in_features = args.dim
    activation = getattr(F, args.act)

    # The two dataset branches differ only in the third and fourth positional
    # arguments of the conv layers (presumably head count and concat flag;
    # verify against GATConv / HGATConv).
    if args.dataset == 'pubmed':
        third_arg, fourth_arg = 8, False
    else:
        third_arg, fourth_arg = 1, args.concat
    self.cls_e = GATConv(
        self.in_features, args.n_classes, third_arg, fourth_arg,
        args.alpha, args.dropout, args.bias, lambda x: x,
    )
    self.cls_h = HGATConv(
        self.manifold, self.in_features, args.dim, third_arg, fourth_arg,
        args.alpha, args.dropout, args.bias, activation,
        atten=args.atten, dist=args.dist,
    )

    self.out_features = args.n_classes
    self.c = c
    self.ball = ball = geoopt.PoincareBall(c=c)
    self.sphere = sphere = geoopt.manifolds.Sphere()
    self.scale = nn.Parameter(torch.zeros(self.out_features))

    # Draw the two random tensors in this order to keep RNG consumption stable.
    point = torch.randn(self.out_features, self.in_features) / 4
    point = pmath.expmap0(point.to(args.device), c=c)
    tangent = torch.randn(self.out_features, self.in_features)
    self.point = geoopt.ManifoldParameter(point, manifold=ball)
    with torch.no_grad():
        self.tangent = geoopt.ManifoldParameter(tangent, manifold=sphere).proj_()
    self.decoder_name = 'DualDecoder'

    # prob weight
    self.w_e = nn.Linear(args.n_classes, 1, bias=False)
    self.w_h = nn.Linear(args.dim, 1, bias=False)
    self.drop_e = args.drop_e
    self.drop_h = args.drop_h
    self.reset_param()
def test_weighted_centroids_2d_multi(method, train, zero):
    """WeightedPoincareCentroids2d returns on-ball points of the right shape.

    The ``(2, 3, 7, 5, 5)`` input must produce a ``(2, 3, 2, 5, 5)`` output.
    """
    manifold = geoopt.PoincareBall()
    samples = manifold.random(2, 3, 5, 5, 7)
    samples = samples.permute(0, 1, 4, 2, 3)
    centroid_layer = geoopt_layers.poincare.WeightedPoincareCentroids2d(
        2, 7, method=method, ball=manifold, zero=zero
    )
    centroid_layer = centroid_layer.train(train)
    result = centroid_layer(samples)
    assert result.shape == (2, 3, 2, 5, 5)
    manifold.assert_check_point_on_manifold(result.permute(0, 1, 3, 4, 2))
def test_wrapped_normal_StereographicProductManifold():
    """Wrapped-normal samples stay on the product manifold and reference it."""
    factors = (
        (geoopt.PoincareBall(), 2),
        (geoopt.SphereProjection(), 2),
        (geoopt.Stereographic(), 2),
    )
    product = geoopt.StereographicProductManifold(*factors)
    center = product.random(6)
    draws = product.wrapped_normal(4, 6, mean=center)
    product.assert_check_point_on_manifold(draws)
    assert draws.manifold is product
def __init__(self, input_size, hidden_size):
    """Create the recurrent weights ``w``, ``u`` and the bias ``b``.

    ``w`` and ``u`` are drawn uniformly from ``[-k, k]`` with
    ``k = sqrt(1 / hidden_size)``; ``b`` is zero-initialised and tagged with a
    Poincare ball manifold. No manifold is passed for ``w`` and ``u``.
    """
    super(hyperRNN, self).__init__()
    self.input_size = input_size
    self.hidden_size = hidden_size
    bound = (1 / hidden_size) ** 0.5

    hidden_weight = gt.ManifoldTensor(hidden_size, 2, hidden_size, 2)
    self.w = gt.ManifoldParameter(hidden_weight.uniform_(-bound, bound))

    input_weight = gt.ManifoldTensor(input_size, 2, hidden_size, 2)
    self.u = gt.ManifoldParameter(input_weight.uniform_(-bound, bound))

    zero_bias = gt.ManifoldTensor(hidden_size, 2, manifold=gt.PoincareBall())
    self.b = gt.ManifoldParameter(zero_bias.zero_())
def test_expmap2d():
    """Expmap2d lands on the ball and Logmap2d inverts it within 1e-4."""
    manifold = geoopt.PoincareBall()
    tangent_input = torch.randn(2, 2, 5, 5)
    exp_layer = geoopt_layers.Expmap2d(manifold, origin_shape=(2,))
    log_layer = geoopt_layers.Logmap2d(manifold, origin_shape=(2,))

    mapped = exp_layer(tangent_input)
    assert mapped.shape == (2, 2, 5, 5)
    # Axis 1 is moved last before the manifold check.
    manifold.assert_check_point_on_manifold(mapped.permute(0, 2, 3, 1))

    recovered = log_layer(mapped)
    np.testing.assert_allclose(
        tangent_input.detach(), recovered.detach(), atol=1e-4
    )
def test_poincare_mean_scatter():
    """Scatter means over two index groups match per-group poincare_mean."""
    manifold = geoopt.PoincareBall()
    samples = manifold.random(10, 5, std=1 / 5 ** 0.5)
    group_index = torch.tensor([0, 0, 0, 0, 0, 1, 1, 1, 1, 1])
    hyp_math = geoopt_layers.poincare.math

    scattered = hyp_math.poincare_mean_scatter(
        samples, index=group_index, ball=manifold
    )
    assert scattered.shape == (2, 5)

    first_mean = hyp_math.poincare_mean(samples[:5], ball=manifold)
    second_mean = hyp_math.poincare_mean(samples[5:], ball=manifold)
    np.testing.assert_allclose(scattered[0], first_mean, atol=1e-5)
    np.testing.assert_allclose(scattered[1], second_mean, atol=1e-5)
def test_average_equals_conv(mm):
    """With uniform averaging weights MobiusConv2d equals poincare_mean.

    When ``mm`` is truthy the matmul weight is set to identity first.
    """
    manifold = geoopt.PoincareBall()
    mobius_conv = geoopt_layers.poincare.MobiusConv2d(
        5, 3, dim_out=5, ball=manifold, matmul=mm
    )
    with torch.no_grad():
        if mm:
            torch.nn.init.eye_(mobius_conv.weight_mm)
        # Uniform averaging weights are set in both cases.
        mobius_conv.weight_avg.fill_(1)

    samples = manifold.random(1, 1, 3, 3, 5)
    expected = geoopt_layers.poincare.math.poincare_mean(
        samples, dim=-1, ball=manifold
    )
    conv_out = mobius_conv(samples.permute(0, 1, 4, 2, 3))
    actual = conv_out.detach().view(-1)
    np.testing.assert_allclose(expected, actual, atol=1e-5, rtol=1e-5)
    manifold.assert_check_point_on_manifold(actual)
def test_poincare_mean_scatter_tangent(linkomb, method):
    """poincare_mean_scatter with ``method``/``lincomb`` stays on the ball."""
    manifold = geoopt.PoincareBall()
    samples = manifold.random(10, 5, std=1 / 5 ** 0.5)
    group_index = torch.tensor([0, 0, 0, 0, 0, 1, 1, 1, 1, 1])
    scattered = geoopt_layers.poincare.math.poincare_mean_scatter(
        samples,
        index=group_index,
        ball=manifold,
        method=method,
        lincomb=linkomb,
    )
    assert scattered.shape == (2, 5)
    manifold.assert_check_point_on_manifold(scattered)
def __init__(self, in_features, out_features, c=1.0):
    """Create the ``scale``, ``point`` and ``tangent`` parameters.

    ``point`` is a ``(out_features, in_features)`` normal sample divided by 4
    and mapped through ``expmap0`` onto a ball of curvature ``c``; ``tangent``
    is a normal sample projected in place onto the sphere manifold; ``scale``
    starts at zero.
    """
    super().__init__()
    self.in_features = in_features
    self.out_features = out_features
    shape = (out_features, in_features)

    self.ball = geoopt.PoincareBall(c=c)
    self.sphere = geoopt.manifolds.Sphere()
    self.scale = torch.nn.Parameter(torch.zeros(out_features))

    # Sampling order (point first, tangent second) is kept for reproducibility.
    ball_point = pmath.expmap0(torch.randn(*shape) / 4, c=c)
    direction = torch.randn(*shape)
    self.point = geoopt.ManifoldParameter(ball_point, manifold=self.ball)
    with torch.no_grad():
        self.tangent = geoopt.ManifoldParameter(
            direction, manifold=self.sphere
        ).proj_()
def __init__(self, user_num, item_num, embedding_dim, c=1):
    """Build user/item embedding tables that share one Poincare ball.

    Also creates a ``HyperbolicDistanceLayer`` with the same ``c`` and a
    ``Linear(1, 1)`` layer stored as ``post_sim``.
    """
    super().__init__()
    shared_ball = geoopt.PoincareBall(c=c)
    self.user_embeddings = LookupEmbedding(
        user_num, embedding_dim, manifold=shared_ball
    )
    self.item_embeddings = LookupEmbedding(
        item_num, embedding_dim, manifold=shared_ball
    )
    self.sim_layer = HyperbolicDistanceLayer(c=c)
    self.post_sim = nn.Linear(1, 1)
def __init__(self, output_dim, input_dims, second_input_dim=None, third_input_dim=None, nonlin=None):
    """Create one bias-free MobiusLinear per input plus a shared ball bias.

    ``lin_b`` uses ``second_input_dim`` when given, else ``input_dims``.
    ``lin_c`` is only created when ``third_input_dim`` is truthy.
    """
    super(MobiusConcat, self).__init__()
    if second_input_dim is not None:
        second_dims = second_input_dim
    else:
        second_dims = input_dims

    self.lin_a = MobiusLinear(input_dims, output_dim, bias=False, nonlin=nonlin)
    self.lin_b = MobiusLinear(second_dims, output_dim, bias=False, nonlin=nonlin)
    if third_input_dim:
        self.lin_c = MobiusLinear(third_input_dim, output_dim, bias=False, nonlin=nonlin)

    self.ball = gt.PoincareBall()
    # Small random tangent vector mapped onto the ball.
    small_noise = torch.randn(output_dim) * 1e-5
    bias_point = pmath.expmap0(small_noise, k=self.ball.k)
    self.bias = gt.ManifoldParameter(bias_point, manifold=self.ball)
def __init__(self, input_size, hidden_size):
    """Create Xavier-uniform weights ``w``, ``u`` and a near-zero ball bias.

    ``w`` is ``(hidden_size, hidden_size)`` and ``u`` is
    ``(input_size, hidden_size)``. The bias is a ``1e-5``-scaled normal sample
    mapped onto the ball with ``expmap0``.
    """
    super(MobiusRNN, self).__init__()
    self.ball = gt.PoincareBall()
    self.input_size = input_size
    self.hidden_size = hidden_size

    # xavier uniform bounds: sqrt(6 / (fan_a + fan_b))
    bound_w = (6 / (self.hidden_size + self.hidden_size)) ** 0.5
    bound_u = (6 / (self.input_size + self.hidden_size)) ** 0.5

    hidden_weight = gt.ManifoldTensor(hidden_size, hidden_size)
    self.w = gt.ManifoldParameter(hidden_weight.uniform_(-bound_w, bound_w))
    input_weight = gt.ManifoldTensor(input_size, hidden_size)
    self.u = gt.ManifoldParameter(input_weight.uniform_(-bound_u, bound_u))

    small_noise = torch.randn(hidden_size) * 1e-5
    bias_point = pmath.expmap0(small_noise, k=self.ball.k)
    self.b = gt.ManifoldParameter(bias_point, manifold=self.ball)
def test_dist_planes_2d_multi(squared, train, zero, signed):
    """Check Distance2PoincareHyperplanes2d output on a 5-d input.

    The output must contain no NaNs and have shape ``(2, 3, 10, 5, 5)``.
    """
    manifold = geoopt.PoincareBall()
    samples = manifold.random(2, 3, 5, 5, 2)
    samples = samples.permute(0, 1, 4, 2, 3)
    plane_layer = geoopt_layers.poincare.Distance2PoincareHyperplanes2d(
        plane_shape=2,
        num_planes=10,
        ball=manifold,
        squared=squared,
        zero=zero,
        signed=signed,
    )
    plane_layer = plane_layer.train(train)
    distances = plane_layer(samples)
    assert not torch.isnan(distances).any()
    assert distances.shape == (2, 3, 10, 5, 5)
def test_remap_provided_origin():
    """RemapLambda with explicit origins maps sphere points onto the ball."""
    source = geoopt.Sphere()
    target = geoopt.PoincareBall()
    samples = source.random(1, 10)
    linear = torch.nn.Linear(10, 13)
    source_origin = source.origin(10)
    target_origin = target.origin(13)
    remap = geoopt_layers.RemapLambda(
        linear,
        source_manifold=source,
        target_manifold=target,
        source_origin=source_origin,
        target_origin=target_origin,
    )
    result = remap(samples)
    target.assert_check_point_on_manifold(result)
def __init__(self, *args, hyperbolic_input=True, hyperbolic_bias=True, nonlin=None, c=1.0, **kwargs):
    """Initialise the parent layer, then re-initialise bias and weight.

    When a bias exists and ``hyperbolic_bias`` is true, the bias becomes a
    ManifoldParameter on the ball, set to ``expmap0`` of a ``1e-3``-scaled
    normal sample. The weight is always re-drawn with a Xavier-uniform bound.
    """
    super().__init__(*args, **kwargs)
    self.ball = gt.PoincareBall(c=c)

    if self.bias is not None and hyperbolic_bias:
        self.bias = gt.ManifoldParameter(self.bias, manifold=self.ball)
        with torch.no_grad():
            noise = self.bias.normal_() * 1e-3
            self.bias.set_(pmath.expmap0(noise, k=self.ball.k))

    with torch.no_grad():
        dim_a, dim_b = self.weight.size()
        bound = (6 / (dim_a + dim_b)) ** 0.5  # xavier uniform
        self.weight.uniform_(-bound, bound)

    self.hyperbolic_bias = hyperbolic_bias
    self.hyperbolic_input = hyperbolic_input
    self.nonlin = nonlin
def poincare_case():
    """Yield float64 UnaryCase objects for PoincareBall and PoincareBallExact.

    The base point is a seeded normal sample rescaled to norm
    ``tanh(norm)``. The second case wraps the ManifoldTensor produced for the
    first case, as the statements run in sequence.
    """
    torch.manual_seed(42)
    shape = manifold_shapes[geoopt.manifolds.PoincareBall]
    ex = torch.randn(*shape, dtype=torch.float64) / 3
    ev = torch.randn(*shape, dtype=torch.float64) / 3
    raw_norm = torch.norm(ex)
    x = torch.tanh(raw_norm) * ex / raw_norm
    ex = x.clone()
    v = ev.clone()

    def build_case(point, manifold):
        # Wrap the point for the given manifold and assemble the case tuple.
        wrapped = geoopt.ManifoldTensor(point, manifold=manifold)
        return wrapped, UnaryCase(shape, wrapped, ex, v, ev, manifold)

    x, case = build_case(x, geoopt.PoincareBall().to(dtype=torch.float64))
    yield case
    x, case = build_case(x, geoopt.PoincareBallExact().to(dtype=torch.float64))
    yield case
def __init__(self, in_features, out_features, k=-1.0, fp64_hyper=True):
    """Create the ``scale``, ``point`` and ``tangent`` parameters.

    ``k`` is converted to a tensor; the ball is built with ``c = |k|`` while
    ``expmap0`` receives ``k`` itself. ``fp64_hyper`` is stored unchanged.
    """
    k = torch.tensor(k)
    super().__init__()
    self.in_features = in_features
    self.out_features = out_features
    shape = (out_features, in_features)

    self.ball = geoopt.PoincareBall(c=k.abs())
    self.sphere = geoopt.manifolds.Sphere()
    self.scale = torch.nn.Parameter(torch.zeros(out_features))

    # Sampling order (point first, tangent second) is kept for reproducibility.
    ball_point = gmath.expmap0(torch.randn(*shape) / 4, k=k)
    direction = torch.randn(*shape)
    self.point = geoopt.ManifoldParameter(ball_point, manifold=self.ball)
    self.fp64_hyper = fp64_hyper
    with torch.no_grad():
        self.tangent = geoopt.ManifoldParameter(
            direction, manifold=self.sphere
        ).proj_()
def __init__(self, in_features, out_features, c=1.0):
    """
    Create the ball points ``p_k`` and the Euclidean tangents ``a_k``.

    :param in_features: number of dimensions of the input
    :param out_features: number of classes
    :param c: value passed to ``PoincareBall`` as ``c``
    """
    super().__init__()
    self.in_features = in_features
    self.out_features = out_features
    self.ball = gt.PoincareBall(c=c)

    # Near-origin points: tiny normal noise mapped onto the ball.
    small_noise = torch.randn(out_features, in_features) * 1e-5
    ball_points = pmath.expmap0(small_noise, k=self.ball.k)
    self.p_k = gt.ManifoldParameter(ball_points, manifold=self.ball)

    bound = (6 / (out_features + in_features)) ** 0.5  # xavier uniform
    directions = torch.Tensor(out_features, in_features)
    torch.nn.init.uniform_(directions, -bound, bound)
    self.a_k = torch.nn.Parameter(directions)
def __init__(
    self,
    input_size,
    hidden_size,
    num_layers=1,
    bias=True,
    nonlin=None,
    hyperbolic_input=True,
    hyperbolic_hidden_state0=True,
    c=1.0,
):
    """Allocate per-layer weights and optional ball-valued biases.

    Parameters
    ----------
    input_size : int
        Second dimension of the first layer's ``weight_ih``.
    hidden_size : int
        Hidden width; each weight has ``3 * hidden_size`` rows.
    num_layers : int
        Number of entries in each parameter list.
    bias : bool
        When true, ``self.bias`` is a ParameterList with one
        ``(3, hidden_size)`` ManifoldParameter per layer; otherwise
        ``self.bias`` is registered as a ``None`` buffer.
    nonlin, hyperbolic_input, hyperbolic_hidden_state0
        Stored unchanged on the instance.
    c : float
        Value passed to ``PoincareBall`` as ``c``.

    The weights are allocated uninitialised here and ``reset_parameters()``
    is called at the end of construction.
    """
    super().__init__()
    self.ball = geoopt.PoincareBall(c=c)
    self.input_size = input_size
    self.hidden_size = hidden_size
    self.num_layers = num_layers
    # The boolean flag is deliberately NOT stored on ``self.bias``: a plain
    # attribute with that name makes ``register_buffer("bias", None)`` raise
    # ``KeyError: attribute 'bias' already exists`` when ``bias`` is false.
    self.weight_ih = torch.nn.ParameterList(
        [
            torch.nn.Parameter(
                torch.Tensor(3 * hidden_size, input_size if i == 0 else hidden_size)
            )
            for i in range(num_layers)
        ]
    )
    self.weight_hh = torch.nn.ParameterList(
        [
            torch.nn.Parameter(torch.Tensor(3 * hidden_size, hidden_size))
            for _ in range(num_layers)
        ]
    )
    if bias:
        biases = []
        for _ in range(num_layers):
            # Tiny normal noise mapped onto the ball; a distinct local name
            # avoids shadowing the ``bias`` flag.
            layer_bias = torch.randn(3, hidden_size) * 1e-5
            layer_bias = geoopt.ManifoldParameter(
                pmath.expmap0(layer_bias, c=self.ball.c), manifold=self.ball
            )
            biases.append(layer_bias)
        self.bias = torch.nn.ParameterList(biases)
    else:
        self.register_buffer("bias", None)
    self.nonlin = nonlin
    self.hyperbolic_input = hyperbolic_input
    self.hyperbolic_hidden_state0 = hyperbolic_hidden_state0
    self.reset_parameters()
def __init__(self, args, vocabs, word2vec=None):
    """Assemble embeddings, encoders, concat layers and the classifier.

    Each ``args.*_metric`` compared against ``cs.HY`` selects the hyperbolic
    (``Mobius*`` / ``PoincareBall``) variant of a component, otherwise the
    Euclidean one. Submodules are created in a fixed order.
    """
    self.args = args
    super(Model, self).__init__()

    if args.embedding_metric == cs.HY:
        self.word_embed_manifold = gt.PoincareBall()
    else:
        self.word_embed_manifold = gt.Euclidean()
    self.train_word_embeds = args.train_word_embeds == 1
    token_count = len(vocabs[cs.TOKEN_VOCAB].label2wordvec_idx)
    self.word_lut = self.init_lut(word2vec, token_count, args.word_emb_size)
    self.word_lut.requires_grad = self.train_word_embeds

    self.concat_dropout = nn.Dropout(p=args.concat_dropout)
    self.classif_dropout = nn.Dropout(p=args.classif_dropout)

    # encoders
    self.mention_encoder = MentionEncoder(vocabs[cs.CHAR_VOCAB], args)
    self.context_encoder = ContextEncoder(args)
    men_dim = self.mention_encoder.mention_output_dim
    char_dim = self.mention_encoder.char_output_dim
    ctx_dim = self.context_encoder.ctx_output_dim
    double_ctx_dim = ctx_dim * 2

    # ctx concat and attn
    if args.encoder_metric == cs.HY:
        ctx_concat_cls = hnn.MobiusConcat
    else:
        ctx_concat_cls = hnn.EuclConcat
    self.ctx_concat = ctx_concat_cls(double_ctx_dim, ctx_dim)
    self.ctx_attn = DistanceAttention(args, args.context_len * 2 + 2, double_ctx_dim)

    # full concat of mention and context
    input_classif_dim = men_dim + char_dim + double_ctx_dim
    if args.concat_metric == cs.HY:
        full_concat_cls = hnn.MobiusConcat
    else:
        full_concat_cls = hnn.EuclConcat
    self.full_concat = full_concat_cls(
        input_classif_dim,
        men_dim,
        second_input_dim=double_ctx_dim,
        third_input_dim=char_dim,
    )

    # classifier
    if args.mlr_metric == cs.HY:
        classifier_cls = hnn.MobiusMLR
    else:
        classifier_cls = hnn.EuclMLR
    self.classifier = classifier_cls(input_classif_dim, vocabs[cs.TYPE_VOCAB].size())

    self.attn_to_concat_map = define_mapping(args.attn_metric, args.concat_metric, args.c)
    self.concat_to_mlr_map = define_mapping(args.concat_metric, args.mlr_metric, args.c)
def test_adam_poincare():
    """RiemannianAdam drives a ball point to ``[0.5, 0.5]`` in 2000 steps.

    The loss is the squared Poincare distance to the target; the final point
    must match the target within ``1e-5``.
    """
    torch.manual_seed(44)
    hyp_math = geoopt.manifolds.poincare.math
    target = torch.tensor([0.5, 0.5])
    initial = torch.randn(2) / 2
    initial = hyp_math.expmap0(initial, c=1.0)
    param = geoopt.ManifoldParameter(initial, manifold=geoopt.PoincareBall())

    def closure():
        # ``optimizer`` is bound below, before the first call to this closure.
        optimizer.zero_grad()
        loss = hyp_math.dist(param, target) ** 2
        loss.backward()
        return loss.item()

    optimizer = geoopt.optim.RiemannianAdam([param], lr=1e-2)
    for _ in range(2000):
        optimizer.step(closure)
    np.testing.assert_allclose(param.data, target, atol=1e-5, rtol=1e-5)
def create_ball(ball=None, c=None):
    """
    Return the given PoincareBall, or build one from a curvature.

    Sharing one manifold across layers keeps their curvature parameters
    identical (e.g. for a scaled PoincareBall); mismatched curvatures between
    layers can lead to nans.

    Parameters
    ----------
    ball : geoopt.PoincareBall
        Existing manifold; returned as-is when not None.
    c : float
        Curvature used to create a new ball when ``ball`` is None.

    Returns
    -------
    geoopt.PoincareBall
    """
    if ball is not None:
        # A manifold was supplied: trust the input and hand it back untouched.
        return ball
    assert c is not None, "curvature of the ball should be explicitly specified"
    return geoopt.PoincareBall(c)
def __init__(self, *args, hyperbolic_input=True, hyperbolic_bias=True, nonlin=None, c=1.0, **kwargs):
    """Initialise the parent layer, then re-initialise bias and weight.

    Parameters
    ----------
    *args, **kwargs
        Forwarded unchanged to the parent constructor.
    hyperbolic_input : bool
        Stored on the instance.
    hyperbolic_bias : bool
        When true and the parent created a bias, the bias becomes a
        ManifoldParameter on the ball, set to ``expmap0`` of a normal sample
        divided by 4.
    nonlin
        Stored on the instance.
    c : float
        Value passed to ``PoincareBall`` as ``c``.

    The weight is always re-drawn from a normal distribution with
    ``std=1e-2``.
    """
    super().__init__(*args, **kwargs)
    # Always create the ball so ``self.ball`` exists in every configuration;
    # it used to be set only when a hyperbolic bias was present, leaving the
    # attribute undefined for ``bias=False`` or ``hyperbolic_bias=False``.
    self.ball = manifold = geoopt.PoincareBall(c=c)
    if self.bias is not None and hyperbolic_bias:
        self.bias = geoopt.ManifoldParameter(self.bias, manifold=manifold)
        with torch.no_grad():
            self.bias.set_(pmath.expmap0(self.bias.normal_() / 4, c=c))
    with torch.no_grad():
        self.weight.normal_(std=1e-2)
    self.hyperbolic_bias = hyperbolic_bias
    self.hyperbolic_input = hyperbolic_input
    self.nonlin = nonlin