def forward(self, x, c=None):
    if c is None:
        c = self.c
    # Möbius matrix-vector product of the layer weight with the input point.
    mv = pmath.mobius_matvec(self.weight, x, c=c)
    if self.bias is None:
        return pmath.project(mv, c=c)
    else:
        # Lift the Euclidean bias onto the ball before the Möbius addition.
        bias = pmath.expmap0(self.bias, c=c)
        return pmath.project(pmath.mobius_add(mv, bias, c=c), c=c)
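# --- Reference sketch (not from this repo) ----------------------------------
# pmath above is the project's Poincaré-ball math module. Assuming it follows
# the closed forms of Ganea et al., "Hyperbolic Neural Networks" (2018), as in
# the hyptorch library, minimal versions of the four primitives used in this
# file could look like the sketch below; illustrative only, underscored names
# to avoid clashing with the real module.
import torch

_EPS = 1e-5
_MAX_NORM = 1 - 1e-3  # clamp factor keeping points strictly inside the ball

def _project(x, c=1.0):
    # Pull points that drifted outside the ball of radius 1/sqrt(c) back in.
    norm = x.norm(dim=-1, keepdim=True).clamp_min(_EPS)
    max_norm = _MAX_NORM / c ** 0.5
    return torch.where(norm > max_norm, x / norm * max_norm, x)

def _expmap0(u, c=1.0):
    # Exponential map at the origin: tanh(sqrt(c)*||u||) * u / (sqrt(c)*||u||).
    sqrt_c = c ** 0.5
    u_norm = u.norm(dim=-1, keepdim=True).clamp_min(_EPS)
    return torch.tanh(sqrt_c * u_norm) * u / (sqrt_c * u_norm)

def _mobius_add(x, y, c=1.0):
    # Möbius addition, the hyperbolic analogue of vector addition.
    x2 = (x * x).sum(dim=-1, keepdim=True)
    y2 = (y * y).sum(dim=-1, keepdim=True)
    xy = (x * y).sum(dim=-1, keepdim=True)
    num = (1 + 2 * c * xy + c * y2) * x + (1 - c * x2) * y
    denom = 1 + 2 * c * xy + c ** 2 * x2 * y2
    return num / denom.clamp_min(_EPS)

def _mobius_matvec(m, x, c=1.0):
    # Möbius matrix-vector product: apply M as if in the tangent space at 0.
    sqrt_c = c ** 0.5
    x_norm = x.norm(dim=-1, keepdim=True).clamp_min(_EPS)
    mx = x @ m.transpose(-1, -2)
    mx_norm = mx.norm(dim=-1, keepdim=True).clamp_min(_EPS)
    scale = torch.tanh(mx_norm / x_norm * torch.atanh((sqrt_c * x_norm).clamp(max=1 - _EPS)))
    return scale * mx / (mx_norm * sqrt_c)
# -----------------------------------------------------------------------------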
def full_conv(channels, kers_full_weight, c=None, padding=0):
    # Möbius-accumulate the per-input-channel correlations into output maps.
    bs, c_in, m1, m2 = channels.size()
    c_out, _, _, k = kers_full_weight.size()
    out_mat = None

    for j in range(c_in):
        # temp_ker : bs*c_out x (m1-k+1+2*padding)*(m2-k+1+2*padding)
        temp_ker = ker_by_channel(channels[:, j, :, :], kers_full_weight[:, j, :, :], c=c, padding=padding)
        if j == 0:
            out_mat = temp_ker
        else:
            # Combine channel contributions with Möbius addition, projecting
            # back inside the ball after each step for numerical safety.
            out_mat = pmath.mobius_add(out_mat, temp_ker, c=c)
            out_mat = pmath.project(out_mat, c=c)

    return out_mat.view(bs, c_out, m1 - k + 1 + 2 * padding, m2 - k + 1 + 2 * padding)
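# --- Shape-level sketch (hypothetical, not the repo's helper) ----------------
# ker_by_channel is defined elsewhere in this repo. Purely to document the
# shapes full_conv expects, a Euclidean stand-in with the same interface could
# be written with torch.nn.functional.conv2d; the real helper presumably uses
# Möbius operations instead.
import torch
import torch.nn.functional as F

def _ker_by_channel_euclidean(channel, kers, c=None, padding=0):
    # channel: bs x m1 x m2   (one input channel, whole batch)
    # kers:    c_out x k x k  (this channel's slice of the full weight)
    out = F.conv2d(channel.unsqueeze(1),  # bs x 1 x m1 x m2
                   kers.unsqueeze(1),     # c_out x 1 x k x k
                   padding=padding)       # bs x c_out x (m1-k+1+2p) x (m2-k+1+2p)
    # Flatten to bs*c_out rows so full_conv can Möbius-add across channels and
    # later view the result back to bs x c_out x H x W.
    return out.reshape(out.size(0) * out.size(1), -1)
# -----------------------------------------------------------------------------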
def forward(self, x, c=None):
    if c is None:
        c = self.c

    # Hyperbolic convolution over the input maps. full_conv operates on the
    # ball directly; a commented-out variant cast x to R^n with logmap0 and
    # mapped the result back with expmap0.
    out = full_conv(x, self.weight, c=c, padding=self.padding)

    # Now add the bias in the ball. Each (sample, channel) feature map is
    # flattened to a single point before the Möbius operations.
    if self.bias is None:
        return pmath.project(out.view(out.size(0) * out.size(1), -1), c=c).view(out.size())
    else:
        bias = pmath.expmap0(self.bias, c=c)
        bias = bias.unsqueeze(0).unsqueeze(2).unsqueeze(3).expand_as(out)
        interm = pmath.mobius_add(
            out.contiguous().view(out.size(0) * out.size(1), -1),
            bias.contiguous().view(bias.size(0) * bias.size(1), -1),
            c=c,
        ).view(out.size())
        # Project back inside the ball so the result stays a valid point.
        normed = pmath.project(interm.view(interm.size(0) * interm.size(1), -1), c=c).view(interm.size())
        return normed
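# --- Illustrative shape check (names hypothetical) ---------------------------
# The bias path above broadcasts one scalar bias per output channel over the
# whole feature map, then flattens both tensors to bs*c_out rows before the
# Möbius addition. A standalone check of just the reshaping (expmap0 is
# skipped here, since only the shapes are being verified):
import torch

_bs, _c_out, _h, _w = 2, 4, 5, 5
_out = torch.randn(_bs, _c_out, _h, _w) * 0.01  # stand-in conv output
_bias = torch.randn(_c_out) * 0.01              # one bias value per channel
_bias_e = _bias.unsqueeze(0).unsqueeze(2).unsqueeze(3).expand_as(_out)
_lhs = _out.contiguous().view(_bs * _c_out, -1)
_rhs = _bias_e.contiguous().view(_bs * _c_out, -1)
assert _lhs.shape == _rhs.shape == (_bs * _c_out, _h * _w)
# -----------------------------------------------------------------------------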
def forward(self, x1, x2, c=None):
    if c is None:
        c = self.c
    # Combine the two hyperbolic linear branches with Möbius addition.
    return pmath.mobius_add(self.l1(x1), self.l2(x2), c=c)
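# --- Usage note --------------------------------------------------------------
# Möbius addition is neither commutative nor associative, so the order of the
# two branches above is semantically meaningful. A quick numeric check, reusing
# the reference _mobius_add sketched earlier in this file (an assumption, not
# repo code):
import torch

_x = torch.tensor([[0.3, 0.1]])
_y = torch.tensor([[-0.2, 0.4]])
assert not torch.allclose(_mobius_add(_x, _y, c=1.0), _mobius_add(_y, _x, c=1.0))
# -----------------------------------------------------------------------------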