def forward(self, x, c=None):
    """Apply a Mobius linear map, optionally followed by a Mobius bias.

    Args:
        x: input points passed to ``pmath.mobius_matvec`` together with
            ``self.weight``.
        c: curvature; ``self.c`` is used when ``None``.

    Returns:
        ``project(mobius_matvec(weight, x))`` when ``self.bias`` is ``None``,
        otherwise ``project(mobius_add(mobius_matvec(weight, x),
        expmap0(bias)))``. Every step uses the same curvature ``c``.
    """
    if c is None:
        c = self.c
    mv = pmath.mobius_matvec(self.weight, x, c=c)
    if self.bias is None:
        return pmath.project(mv, c=c)
    # The raw bias parameter is mapped with expmap0 before being Mobius-added.
    bias = pmath.expmap0(self.bias, c=c)
    # Pass ``c`` to mobius_add as every other pmath call does. Without it the
    # addition does not use the layer's curvature, while its operands and the
    # final projection do.
    return pmath.project(pmath.mobius_add(mv, bias, c=c), c=c)
def full_conv(channels, kers_full_weight, c=None, padding=0):
    """Hyperbolic 2-D convolution assembled from per-input-channel convolutions.

    Each input channel ``j`` is convolved with ``kers_full_weight[:, j]`` by
    ``ker_by_channel``. The per-channel results are accumulated with
    ``pmath.mobius_add``, and the running sum is re-projected with
    ``pmath.project`` after every update.

    Args:
        channels: 4-D tensor unpacked as ``(bs, c_in, m1, m2)``.
        kers_full_weight: 4-D tensor unpacked as ``(c_out, _, _, k)``. Only
            the last dimension is read as the kernel size and is used for both
            output spatial sizes. NOTE(review): this presumably assumes a
            square kernel -- confirm.
        c: curvature forwarded to ``ker_by_channel`` and ``pmath``.
        padding: forwarded to ``ker_by_channel`` and used in the output size.

    Returns:
        The accumulated result viewed as
        ``(bs, c_out, m1 - k + 1 + 2*padding, m2 - k + 1 + 2*padding)``.

    Raises:
        ValueError: if ``channels`` has zero input channels. Previously this
            surfaced as ``AttributeError`` on ``None.view``.
    """
    bs, c_in, m1, m2 = channels.size()
    c_out, _, _, k = kers_full_weight.size()
    if c_in == 0:
        raise ValueError("full_conv requires at least one input channel (got c_in=0)")

    out_mat = None
    for j in range(c_in):
        # NOTE(review): the original comment gives temp_ker's layout as
        # bs * c_out x (m-k+1)^2 -- confirm against ker_by_channel.
        temp_ker = ker_by_channel(
            channels[:, j, :, :], kers_full_weight[:, j, :, :], c=c, padding=padding
        )
        out_mat = temp_ker if out_mat is None else pmath.mobius_add(out_mat, temp_ker, c=c)
        # Re-project after each accumulation to keep the running sum on the ball.
        out_mat = pmath.project(out_mat, c=c)

    return out_mat.view(bs, c_out, m1 - k + 1 + 2 * padding, m2 - k + 1 + 2 * padding)
def forward(self, x, c=None):
    """Hyperbolic convolution followed by an optional Mobius bias.

    The convolution is delegated to ``full_conv`` with curvature ``c`` and
    padding ``self.padding``. Its output is flattened to one row per
    (dim 0, dim 1) pair, each row is passed through ``pmath.project``, and the
    result is viewed back to the convolution's output size.

    When ``self.bias`` is set, ``expmap0(self.bias)`` is broadcast to the
    convolution output's size and Mobius-added row-wise before projection.

    Args:
        x: input forwarded to ``full_conv``.
        c: curvature; ``self.c`` is used when ``None``.

    Returns:
        A tensor with the same size as the ``full_conv`` output.
    """
    curv = self.c if c is None else c
    conv_out = full_conv(x, self.weight, c=curv, padding=self.padding)
    full_shape = conv_out.size()
    n_rows = full_shape[0] * full_shape[1]

    if self.bias is None:
        flat = conv_out.view(n_rows, -1)
        return pmath.project(flat, c=curv).view(full_shape)

    # NOTE(review): assuming self.bias is 1-D (one entry per output channel),
    # expmap0 maps the whole bias vector at once. Each channel's entry is then
    # repeated across every position of its row, so the row norm grows with
    # the row length. Confirm this stays inside the ball for intended sizes.
    ball_bias = pmath.expmap0(self.bias, c=curv)
    ball_bias = ball_bias[None, :, None, None].expand_as(conv_out)

    summed = pmath.mobius_add(
        conv_out.contiguous().view(n_rows, -1),
        ball_bias.contiguous().view(n_rows, -1),
        c=curv,
    ).view(full_shape)
    return pmath.project(summed.view(n_rows, -1), c=curv).view(full_shape)
def forward(self, x1, x2, c=None):
    """Combine two inputs by Mobius-adding their sub-layer outputs.

    Args:
        x1: input passed to ``self.l1``.
        x2: input passed to ``self.l2``.
        c: curvature; ``self.c`` is used when ``None``.

    Returns:
        ``pmath.mobius_add(self.l1(x1, c=c), self.l2(x2, c=c), c=c)``.
    """
    if c is None:
        c = self.c
    # Forward the resolved curvature to both sub-layers. This keeps the
    # operands and the Mobius addition at the same curvature; previously an
    # explicit ``c`` reached only mobius_add.
    # NOTE(review): assumes l1/l2 accept a ``c`` keyword, like the
    # ``forward(self, x, c=None)`` hyperbolic layers in this file -- confirm.
    return pmath.mobius_add(self.l1(x1, c=c), self.l2(x2, c=c), c=c)