def forward(self, x):
    if self.train_x:
        xp = pmath.project(pmath.expmap0(self.xp, c=self.c), c=self.c)
        return self.grad_fix(
            pmath.project(pmath.expmap(xp, x, c=self.c), c=self.c))
    return self.grad_fix(
        pmath.project(pmath.expmap0(x, c=self.c), c=self.c))
def forward(self, x, c=None):
    if c is None:
        c = self.c
    mv = pmath.mobius_matvec(self.weight, x, c=c)
    if self.bias is None:
        return pmath.project(mv, c=c)
    else:
        bias = pmath.expmap0(self.bias, c=c)
        # pass the curvature explicitly so mobius_add does not silently
        # fall back to its default c
        return pmath.project(pmath.mobius_add(mv, bias, c=c), c=c)
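# --- hedged usage sketch ------------------------------------------------------
# The layer above is a Mobius linear map:
#     y = proj((W otimes_c x) oplus_c expmap0(b)),
# i.e. a Mobius matrix-vector product followed by Mobius addition of the
# exp-mapped Euclidean bias. A minimal standalone sketch using only the pmath
# ops this file already relies on; the shapes and curvature below are
# illustrative assumptions, not values from the original code.
def _mobius_linear_example(c=1.0):
    w = torch.randn(8, 16) * 0.1   # Euclidean weight matrix, out x in
    b = torch.zeros(8)             # Euclidean bias, lives in the tangent space
    x = pmath.project(pmath.expmap0(torch.randn(4, 16) * 0.1, c=c), c=c)
    mv = pmath.mobius_matvec(w, x, c=c)
    y = pmath.project(pmath.mobius_add(mv, pmath.expmap0(b, c=c), c=c), c=c)
    return y                       # (4, 8) points on the Poincare ball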
def ker_by_channel(channel, ker, c=None, padding=0):
    # zero-pad the spatial dims of the single input channel
    channel = nn.ConstantPad2d(padding, 0)(channel)
    c_out, kernel_size, _ = ker.size()
    bs, m1, m2 = channel.size()
    # map the channel to the tangent space at the origin
    channel = pmath.logmap0(channel.view(bs, -1), c=c).view(bs, m1, m2)
    # run an ordinary Euclidean convolution in the tangent space
    channel = nn.functional.conv2d(channel.unsqueeze(1), ker.unsqueeze(1),
                                   bias=None).view(bs * c_out, -1)
    # map the result back onto the Poincare ball and renormalize
    channel = pmath.expmap0(channel, c=c)
    channel = pmath.project(channel, c=c)
    return channel
def forward(self, x, c=None):
    if c is None:
        c = self.c
    x_eucl = pmath.logmap0(x, c=c)
    out = self.conv(x_eucl)
    x_hyp = pmath.expmap0(out, c=c)
    x_hyp_proj = pmath.project(x_hyp, c=c)
    return x_hyp_proj
def forward(self, x, c=None):
    if c is None:
        c = self.c
    # note that logmap and expmap are taken with respect to the origin
    x_eucl = pmath.logmap0(x, c=c)
    out = self.lin(x_eucl)
    x_hyp = pmath.expmap0(out, c=c)
    x_hyp_proj = pmath.project(x_hyp, c=c)
    return x_hyp_proj
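# --- the log-exp wrapping pattern ---------------------------------------------
# Both wrappers above apply a Euclidean layer f in the tangent space at the
# origin: x -> proj(expmap0(f(logmap0(x)))). A generic helper capturing the
# pattern (hypothetical name, not part of the original code):
def wrap_tangent(f, x, c):
    x_eucl = pmath.logmap0(x, c=c)                       # ball -> tangent space
    out = f(x_eucl)                                      # ordinary Euclidean op
    return pmath.project(pmath.expmap0(out, c=c), c=c)   # back onto the ball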
def forward(self, x, c=None):
    if c is None:
        c = self.c
    # cast x back to R^n channel-wise, do the conv there, and return to
    # hyperbolic space (this all happens inside full_conv)
    out = full_conv(x, self.weight, c=c, padding=self.padding)
    # now add the hyperbolic bias
    if self.bias is None:
        return pmath.project(
            out.view(out.size(0) * out.size(1), -1), c=c).view(out.size())
    else:
        bias = pmath.expmap0(self.bias, c=c)
        bias = bias.unsqueeze(0).unsqueeze(2).unsqueeze(3).expand_as(out)
        interm = pmath.mobius_add(
            out.contiguous().view(out.size(0) * out.size(1), -1),
            bias.contiguous().view(bias.size(0) * bias.size(1), -1),
            c=c).view(out.size())
        # project back onto the ball
        normed = pmath.project(
            interm.view(interm.size(0) * interm.size(1), -1),
            c=c).view(interm.size())
        return normed
def full_conv(channels, kers_full_weight, c=None, padding=0):
    bs, c_in, m1, m2 = channels.size()
    c_out, _, _, k = kers_full_weight.size()
    out_mat = None
    for j in range(c_in):
        # temp_ker: bs * c_out x (m - k + 1)^2
        temp_ker = ker_by_channel(channels[:, j, :, :],
                                  kers_full_weight[:, j, :, :],
                                  c=c, padding=padding)
        if j == 0:
            out_mat = temp_ker
        else:
            # accumulate the per-channel responses with Mobius addition
            out_mat = pmath.mobius_add(out_mat, temp_ker, c=c)
            out_mat = pmath.project(out_mat, c=c)
    return out_mat.view(bs, c_out,
                        m1 - k + 1 + 2 * padding,
                        m2 - k + 1 + 2 * padding)
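# --- hedged usage sketch for full_conv ------------------------------------------
# full_conv runs one tangent-space convolution per input channel
# (ker_by_channel above) and folds the per-channel responses together with
# Mobius addition instead of a Euclidean sum. A minimal sketch, assuming the
# pmath API used throughout this file; all shapes and the curvature are
# illustrative choices, not values from the original code.
def _full_conv_example(c=1.0):
    bs, c_in, c_out, m, k = 2, 3, 4, 8, 3
    channels = pmath.project(
        pmath.expmap0(torch.randn(bs * c_in, m * m) * 0.1, c=c),
        c=c).view(bs, c_in, m, m)
    weight = torch.randn(c_out, c_in, k, k) * 0.1
    out = full_conv(channels, weight, c=c, padding=0)
    return out  # (bs, c_out, m - k + 1, m - k + 1) = (2, 4, 6, 6)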
def forward(self, x):
    hyp_result = []
    klein_result = []
    # embed each instance, keeping both Poincare and Klein coordinates
    for i in range(self.inst_num):
        med = x[:, i, :, :, :]
        x_size = x.size()
        med_input = med.view(x_size[0], x_size[2], x_size[3], x_size[4])
        med_out = self.resnet18(med_input)
        hyp_med_out = self.tp(med_out)
        hyp_med_out = torch.reshape(hyp_med_out, (hyp_med_out.shape[0], 1, -1))
        hyp_result.append(hyp_med_out)
        klein_med_out = pmath.p2k(hyp_med_out, c=self.c)
        klein_med_out = pmath.project(klein_med_out, c=self.c)
        klein_med_out = torch.reshape(klein_med_out,
                                      (klein_med_out.shape[0], 1, -1))
        klein_result.append(klein_med_out)
    hyp_img_embed = torch.cat(hyp_result, dim=1)
    klein_img_embed = torch.cat(klein_result, dim=1)
    weight_result = []
    lorenz_f_result = []
    # attention logits come from the Poincare embeddings, Lorentz factors
    # from the Klein embeddings
    for i in range(self.inst_num):
        hyp_img_embed_input = hyp_img_embed[:, i, :].view(
            hyp_img_embed.size(0), -1)
        weight = self.att_weight(hyp_img_embed_input)
        klein_img_embed_input = klein_img_embed[:, i, :].view(
            klein_img_embed.size(0), -1)
        lorenz_f = pmath.lorenz_factor(klein_img_embed_input, c=self.c,
                                       dim=1, keepdim=True)
        weight_result.append(weight)
        lorenz_f_result.append(lorenz_f)
    embed_weight = torch.cat(weight_result, 1)
    out_weight = embed_weight
    lamb = torch.cat(lorenz_f_result, 1)
    alpha = torch.nn.Softmax(dim=1)(embed_weight)
    # Einstein midpoint: Lorentz-factor-weighted average in the Klein model
    alpha_lamb = alpha * lamb
    alpha_lamb_sum = torch.sum(alpha_lamb, dim=1).unsqueeze(dim=1)
    alpha_lamb_norm = alpha_lamb / alpha_lamb_sum
    alpha_lamb_norm = torch.reshape(alpha_lamb_norm,
                                    (alpha_lamb_norm.shape[0], 1, -1))
    rep = torch.bmm(alpha_lamb_norm, klein_img_embed)
    rep = torch.reshape(rep, (rep.shape[0], -1))
    rep = pmath.project(rep, c=self.c)
    rep = pmath.k2p(rep, c=self.c)
    label = self.label_pred(rep)
    return label, out_weight
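# --- note on the aggregation above ----------------------------------------------
# The weighted pooling in this forward pass is the Einstein midpoint in the
# Klein model: with attention weights alpha_i and Lorentz factors lambda_i,
#
#     rep = sum_i ( alpha_i * lambda_i / sum_j alpha_j * lambda_j ) * x_i^K ,
#
# computed via torch.bmm and then mapped back to the Poincare ball with k2p.
# A minimal standalone version of that step (hypothetical helper; lamb is
# assumed precomputed with pmath.lorenz_factor as in the loop above):
def einstein_midpoint(alpha, lamb, x_klein):
    # alpha, lamb: (bs, n); x_klein: (bs, n, d) points in the Klein model
    w = alpha * lamb
    w = w / w.sum(dim=1, keepdim=True)
    return torch.bmm(w.unsqueeze(1), x_klein).squeeze(1)  # (bs, d)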
def forward(self, x):
    if self.train_x:
        xp = pmath.project(pmath.expmap0(self.xp, c=self.c), c=self.c)
        return pmath.logmap(xp, x, c=self.c)
    return pmath.logmap0(x, c=self.c)
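# --- hedged round-trip sketch ----------------------------------------------------
# This forward is the inverse of the expmap-based forward above: for points
# well inside the ball, logmap0(project(expmap0(v))) recovers v. A quick
# self-check under an assumed curvature c = 1.0 and small inputs (project can
# clip points near the boundary, so large v would not round-trip exactly):
def _roundtrip_check(c=1.0):
    v = torch.randn(4, 16) * 0.1
    x = pmath.project(pmath.expmap0(v, c=c), c=c)
    v_rec = pmath.logmap0(x, c=c)
    return torch.allclose(v, v_rec, atol=1e-4)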
def forward(self, x):
    # self.tp.c.data = self.tp.c.data.clamp(min=c_limit)
    hyp_result = []
    klein_result = []
    for i in range(self.inst_num):
        med = x[:, i, :, :, :]
        x_size = x.size()
        med_input = med.view(x_size[0], x_size[2], x_size[3], x_size[4])
        med_out = self.lenet(med_input)
        hyp_med_out = self.tp(med_out)
        if torch.isnan(hyp_med_out).any():
            print('error: hyp_med_out')
            print(med_out.data.cpu())
            print(hyp_med_out.data.cpu())
            print(self.tp)
            sys.exit()
        hyp_med_out = torch.reshape(hyp_med_out, (hyp_med_out.shape[0], 1, -1))
        hyp_result.append(hyp_med_out)
        klein_med_out = pmath.p2k(hyp_med_out, c=self.c)
        klein_med_out = pmath.project(klein_med_out, c=self.c)
        if torch.isnan(klein_med_out).any():
            print('error: klein_med_out')
        klein_med_out = torch.reshape(klein_med_out,
                                      (klein_med_out.shape[0], 1, -1))
        klein_result.append(klein_med_out)
    hyp_img_embed = torch.cat(hyp_result, dim=1)
    klein_img_embed = torch.cat(klein_result, dim=1)
    weight_result = []
    lorenz_f_result = []
    for i in range(self.inst_num):
        hyp_img_embed_input = hyp_img_embed[:, i, :].view(
            hyp_img_embed.size(0), -1)
        weight = self.att_weight(hyp_img_embed_input)
        if torch.isnan(weight).any():
            print('error: weight')
        klein_img_embed_input = klein_img_embed[:, i, :].view(
            klein_img_embed.size(0), -1)
        lorenz_f = pmath.lorenz_factor(klein_img_embed_input, c=self.c,
                                       dim=1, keepdim=True)
        if torch.isnan(lorenz_f).any():
            print('error: lorenz_f')
        weight_result.append(weight)
        lorenz_f_result.append(lorenz_f)
    embed_weight = torch.cat(weight_result, 1)
    out_weight = embed_weight
    lamb = torch.cat(lorenz_f_result, 1)
    alpha = torch.nn.Softmax(dim=1)(embed_weight)
    # Einstein midpoint: Lorentz-factor-weighted average in the Klein model
    alpha_lamb = alpha * lamb
    alpha_lamb_sum = torch.sum(alpha_lamb, dim=1).unsqueeze(dim=1)
    alpha_lamb_norm = alpha_lamb / alpha_lamb_sum
    alpha_lamb_norm = torch.reshape(alpha_lamb_norm,
                                    (alpha_lamb_norm.shape[0], 1, -1))
    rep = torch.bmm(alpha_lamb_norm, klein_img_embed)
    rep = torch.reshape(rep, (rep.shape[0], -1))
    rep = pmath.project(rep, c=self.c)
    if torch.isnan(rep).any():
        print('error: rep')  # original mistakenly printed 'klein_med_out' here
    rep = pmath.k2p(rep, c=self.c)
    msi = self.msi_pred(rep)
    if torch.isnan(msi).any():
        print('error: msi')
    return msi, out_weight
def forward(self, x, c=None):
    if c is None:
        c = self.c
    # make sure the Euclidean input is properly normalized onto the ball
    x = pmath.project(x.view(x.size(0) * x.size(1), -1), c=c).view(x.size())

    # BLOCK 1
    x = self.c1(x, c=c)
    # relu, applied in the tangent space at the origin
    x = pmath.logmap0(x.view(x.size(0) * x.size(1), -1), c=c).view(x.size())
    x = nn.ReLU()(x)
    x = pmath.expmap0(x.view(x.size(0) * x.size(1), -1), c=c).view(x.size())
    x = pmath.project(x.view(x.size(0) * x.size(1), -1), c=c).view(x.size())
    # maxpool, applied in the tangent space at the origin
    x = pmath.logmap0(x.view(x.size(0) * x.size(1), -1), c=c).view(x.size())
    x = nn.MaxPool2d(2)(x)
    x = pmath.expmap0(x.view(x.size(0) * x.size(1), -1), c=c).view(x.size())
    x = pmath.project(x.view(x.size(0) * x.size(1), -1), c=c).view(x.size())

    # BLOCKS 2-4 (self.c2 .. self.c4) repeat the BLOCK 1 pattern and are
    # commented out in this version, as are the batch-norm layers and the
    # "blocked" relu+maxpool variant that shares one logmap/expmap pair.

    # final pool
    x = pmath.logmap0(x.view(x.size(0) * x.size(1), -1), c=c).view(x.size())
    x = nn.MaxPool2d(5)(x)
    x = pmath.expmap0(x.view(x.size(0) * x.size(1), -1), c=c).view(x.size())
    x = pmath.project(x.view(x.size(0) * x.size(1), -1), c=c).view(x.size())
    # x is currently N x 512 x 1 x 1; plain flattening may mess with the
    # geometry, so view it as a Euclidean vector and expmap back to
    # hyperbolic space
    x = x.view(x.size(0), -1)
    x = pmath.expmap0(x, c=c)
    x = pmath.project(x, c=c)
    return x
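# --- note on the final flattening -------------------------------------------------
# This network (and the hybrid variant below) ends by reading the pooled
# N x 512 x 1 x 1 tensor as a plain Euclidean vector and re-embedding it with
# a single expmap0, following the comment above about .view() mixing
# coordinates from separate copies of the ball. That step in isolation
# (hypothetical helper name):
def flatten_to_ball(x, c):
    x = x.view(x.size(0), -1)  # read all channel coords as one Euclidean vector
    return pmath.project(pmath.expmap0(x, c=c), c=c)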
def forward(self, x, c=None):
    if c is None:
        c = self.c
    # BLOCKS 1-3: ordinary Euclidean conv / relu / maxpool (the batch-norm
    # layers and the hyperbolic relu/maxpool variants of these blocks are
    # commented out in this version)
    x = self.c1(x)
    x = nn.ReLU()(x)
    x = nn.MaxPool2d(2)(x)
    x = self.c2(x)
    x = nn.ReLU()(x)
    x = nn.MaxPool2d(2)(x)
    x = self.c3(x)
    x = nn.ReLU()(x)
    x = nn.MaxPool2d(2)(x)

    # BLOCK 4: map to the Poincare ball, then run the hyperbolic conv
    x = self.e2p(x)
    x = self.c4(x, c=c)
    # relu, applied in the tangent space at the origin
    x = pmath.logmap0(x.view(x.size(0) * x.size(1), -1), c=c).view(x.size())
    x = nn.ReLU()(x)
    x = pmath.expmap0(x.view(x.size(0) * x.size(1), -1), c=c).view(x.size())
    x = pmath.project(x.view(x.size(0) * x.size(1), -1), c=c).view(x.size())
    # maxpool, applied in the tangent space at the origin
    x = pmath.logmap0(x.view(x.size(0) * x.size(1), -1), c=c).view(x.size())
    x = nn.MaxPool2d(2)(x)
    x = pmath.expmap0(x.view(x.size(0) * x.size(1), -1), c=c).view(x.size())
    x = pmath.project(x.view(x.size(0) * x.size(1), -1), c=c).view(x.size())
    # final pool
    x = pmath.logmap0(x.view(x.size(0) * x.size(1), -1), c=c).view(x.size())
    x = nn.MaxPool2d(5)(x)
    x = pmath.expmap0(x.view(x.size(0) * x.size(1), -1), c=c).view(x.size())
    x = pmath.project(x.view(x.size(0) * x.size(1), -1), c=c).view(x.size())
    # x is currently N x 512 x 1 x 1; plain flattening may mess with the
    # geometry, so view it as a Euclidean vector and expmap back to
    # hyperbolic space
    x = x.view(x.size(0), -1)
    x = pmath.expmap0(x, c=c)
    x = pmath.project(x, c=c)
    return x