    def forward(self, x):
        if self.train_x:
            xp = pmath.project(pmath.expmap0(self.xp, c=self.c), c=self.c)
            return self.grad_fix(
                pmath.project(pmath.expmap(xp, x, c=self.c), c=self.c))
        return self.grad_fix(
            pmath.project(pmath.expmap0(x, c=self.c), c=self.c))
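# A minimal standalone sketch of the map above, assuming `pmath` is the same
# Poincare-ball math module used by all of these snippets (the import path and
# the helper name `to_poincare` are assumptions, not part of the original code):
import torch
import pmath

def to_poincare(x, c=1.0):
    # exponential map at the origin, then clip the result back inside the ball
    return pmath.project(pmath.expmap0(x, c=c), c=c)

# usage: map a batch of Euclidean feature vectors onto the ball
feats = torch.randn(4, 16)
ball_feats = to_poincare(feats, c=1.0)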
    def forward(self, x, c=None):
        if c is None:
            c = self.c
        mv = pmath.mobius_matvec(self.weight, x, c=c)
        if self.bias is None:
            return pmath.project(mv, c=c)
        else:
            bias = pmath.expmap0(self.bias, c=c)
            return pmath.project(pmath.mobius_add(mv, bias), c=c)
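# A sketch of a complete hyperbolic linear module around the forward above,
# assuming the weight is an ordinary Euclidean parameter applied via the Mobius
# matrix-vector product and the bias lives in the tangent space at the origin;
# the class name and initialization below are assumptions, not the original code.
import torch
import torch.nn as nn
import pmath  # same Poincare-ball math module as above (import path assumed)

class HypLinearSketch(nn.Module):
    def __init__(self, in_features, out_features, c=1.0, bias=True):
        super().__init__()
        self.c = c
        self.weight = nn.Parameter(torch.empty(out_features, in_features))
        self.bias = nn.Parameter(torch.zeros(out_features)) if bias else None
        nn.init.xavier_uniform_(self.weight)

    def forward(self, x, c=None):
        if c is None:
            c = self.c
        mv = pmath.mobius_matvec(self.weight, x, c=c)
        if self.bias is None:
            return pmath.project(mv, c=c)
        bias = pmath.expmap0(self.bias, c=c)
        return pmath.project(pmath.mobius_add(mv, bias), c=c)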
def ker_by_channel(channel, ker, c=None, padding=0):
    channel = nn.ConstantPad2d(padding, 0)(channel)
    c_out, kernel_size, _ = ker.size()
    bs, m1, m2 = channel.size()
    channel = pmath.logmap0(channel.view(bs, -1), c=c).view(bs, m1, m2)
    channel = nn.functional.conv2d(
        channel.unsqueeze(1), ker.unsqueeze(1), bias=None).view(bs * c_out, -1)
    channel = pmath.expmap0(channel, c=c)
    channel = pmath.project(channel, c=c)
    return channel
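# Shape check for ker_by_channel with dummy data (torch/nn/pmath imports assumed
# as in the snippets above): 2 samples of a single 6x6 channel, 3 kernels of
# size 3, no padding -> (2 * 3) rows of length (6 - 3 + 1)^2 = 16 on the ball.
bs, m, c_out, k = 2, 6, 3, 3
chan = pmath.project(pmath.expmap0(torch.randn(bs, m * m), c=1.0), c=1.0).view(bs, m, m)
kers = torch.randn(c_out, k, k)
out = ker_by_channel(chan, kers, c=1.0, padding=0)
# out.shape == torch.Size([6, 16])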
    def forward(self, x, c=None):
        if c is None:
            c = self.c

        x_eucl = pmath.logmap0(x, c=c)
        out = self.conv(x_eucl)
        x_hyp = pmath.expmap0(out, c=c)
        x_hyp_proj = pmath.project(x_hyp, c=c)

        return x_hyp_proj
    def forward(self, x, c=None):
        if c is None:
            c = self.c

        # note that logmap and expmap are taken with respect to the origin
        x_eucl = pmath.logmap0(x, c=c)
        out = self.lin(x_eucl)
        x_hyp = pmath.expmap0(out, c=c)
        x_hyp_proj = pmath.project(x_hyp, c=c)

        return x_hyp_proj
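# The two forwards above follow the same recipe: pull the point back to the
# tangent space at the origin, run an ordinary Euclidean layer there, then map
# the result onto the ball and clip it. A generic helper capturing that pattern
# (a sketch; `mobius_wrap` is not a name from the original code):
def mobius_wrap(layer, x, c):
    x_eucl = pmath.logmap0(x, c=c)                       # ball -> tangent space at 0
    out = layer(x_eucl)                                  # any Euclidean module / function
    return pmath.project(pmath.expmap0(out, c=c), c=c)   # back onto the ball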
    def forward(self, x, c=None):
        if c is None:
            c = self.c

        # map x back to R^n, run the convolution, then map the result back to H^n
#         x = pmath.logmap0(x.view(x.size(0) * x.size(1), -1), c=c).view(x.size())
        out = full_conv(x, self.weight, c=c, padding=self.padding)
#         out = pmath.expmap0(out.view(out.size(0) * out.size(1), -1), c=c).view(out.size())

        # now map the bias onto the ball and Mobius-add it
        if self.bias is None:
            return pmath.project(out.view(out.size(0) * out.size(1), -1), c=c).view(out.size())
        else:
            bias = pmath.expmap0(self.bias, c=c)
            bias = bias.unsqueeze(0).unsqueeze(2).unsqueeze(3).expand_as(out)
            # Mobius-add the bias, then clip the result back inside the Poincare ball
            interm = pmath.mobius_add(
                out.contiguous().view(out.size(0) * out.size(1), -1),
                bias.contiguous().view(bias.size(0) * bias.size(1), -1),
                c=c).view(out.size())
            normed = pmath.project(
                interm.view(interm.size(0) * interm.size(1), -1),
                c=c).view(interm.size())
            return normed
def full_conv(channels, kers_full_weight, c=None, padding=0):
    bs, c_in, m1, m2 = channels.size()
    c_out, _, _, k = kers_full_weight.size()
    out_mat = None

    for j in range(c_in):
        temp_ker = ker_by_channel(channels[:, j, :, :], kers_full_weight[:, j, :, :], c=c, padding=padding)
        # temp_ker : bs * c_out x (m-k+1)^2
        if j == 0:
            out_mat = temp_ker
        else:
            out_mat = pmath.mobius_add(out_mat, temp_ker, c=c)
            out_mat = pmath.project(out_mat, c=c)

    return out_mat.view(bs, c_out, m1-k+1 + 2 * padding, m2-k+1 + 2*padding)
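# Usage sketch for full_conv with dummy shapes (torch/pmath assumed as above):
# 2 samples, 3 input channels of size 8x8, 4 output channels, 3x3 kernels.
x = pmath.project(pmath.expmap0(torch.randn(2 * 3, 8 * 8), c=1.0), c=1.0).view(2, 3, 8, 8)
w = torch.randn(4, 3, 3, 3)
y = full_conv(x, w, c=1.0, padding=0)
# y.shape == torch.Size([2, 4, 6, 6])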
Example #8
    def forward(self, x):
        hyp_result = []
        klein_result = []
        for i in range(self.inst_num):
            med = x[:, i, :, :, :]
            x_size = x.size()
            med_input = med.view(x_size[0], x_size[2], x_size[3], x_size[4])
            med_out = self.resnet18(med_input)
            hyp_med_out = self.tp(med_out)

            hyp_med_out = torch.reshape(hyp_med_out,
                                        (hyp_med_out.shape[0], 1, -1))
            hyp_result.append(hyp_med_out)

            klein_med_out = pmath.p2k(hyp_med_out, c=self.c)
            klein_med_out = pmath.project(klein_med_out, c=self.c)

            klein_med_out = torch.reshape(klein_med_out,
                                          (klein_med_out.shape[0], 1, -1))
            klein_result.append(klein_med_out)

        hyp_img_embed = torch.cat(hyp_result, dim=1)
        klein_img_embed = torch.cat(klein_result, dim=1)

        weight_result = []
        lorenz_f_result = []
        for i in range(self.inst_num):
            hyp_img_embed_input_med = hyp_img_embed[:, i, :]
            hyp_img_embed_size = hyp_img_embed.size()
            hyp_img_embed_input = hyp_img_embed_input_med.view(
                hyp_img_embed_size[0], -1)
            weight = self.att_weight(hyp_img_embed_input)

            klein_img_embed_input_med = klein_img_embed[:, i, :]
            klein_img_embed_size = klein_img_embed.size()
            klein_img_embed_input = klein_img_embed_input_med.view(
                klein_img_embed_size[0], -1)

            lorenz_f = pmath.lorenz_factor(klein_img_embed_input,
                                           c=self.c,
                                           dim=1,
                                           keepdim=True)

            weight_result.append(weight)
            lorenz_f_result.append(lorenz_f)

        embed_weight = torch.cat(weight_result, 1)
        out_weight = embed_weight
        lamb = torch.cat(lorenz_f_result, 1)
        alpha = torch.nn.Softmax(dim=1)(embed_weight)

        alpha_lamb = alpha * lamb
        alpha_lamb_sum = torch.sum(alpha_lamb, dim=1)
        alpha_lamb_sum = alpha_lamb_sum.unsqueeze(dim=1)
        alpha_lamb_norm = alpha_lamb / alpha_lamb_sum
        alpha_lamb_norm = torch.reshape(alpha_lamb_norm,
                                        (alpha_lamb_norm.shape[0], 1, -1))
        rep = torch.bmm(alpha_lamb_norm, klein_img_embed)
        rep = torch.reshape(rep, (rep.shape[0], -1))

        rep = pmath.project(rep, c=self.c)

        rep = pmath.k2p(rep, c=self.c)
        label = self.label_pred(rep)

        return label, out_weight
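# The aggregation above is an attention-weighted Einstein midpoint: instance
# embeddings are mapped from the Poincare ball to the Klein model, averaged with
# weights alpha_i * lambda_i / sum_j alpha_j * lambda_j (lambda = Lorentz factor),
# and mapped back. A sketch of that step in isolation (pmath assumed as above;
# `einstein_midpoint` and the dim/keepdim choices are assumptions):
def einstein_midpoint(x_poincare, alpha, c):
    # x_poincare: (batch, n_inst, dim) points on the ball, alpha: (batch, n_inst)
    x_klein = pmath.project(pmath.p2k(x_poincare, c=c), c=c)
    lamb = pmath.lorenz_factor(x_klein, c=c, dim=-1, keepdim=False)  # (batch, n_inst)
    w = alpha * lamb
    w = w / w.sum(dim=1, keepdim=True)
    mid = pmath.project(torch.bmm(w.unsqueeze(1), x_klein).squeeze(1), c=c)
    return pmath.k2p(mid, c=c)  # back to the Poincare ball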
    def forward(self, x):
        if self.train_x:
            xp = pmath.project(pmath.expmap0(self.xp, c=self.c), c=self.c)
            return pmath.logmap(xp, x, c=self.c)
        return pmath.logmap0(x, c=self.c)
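# Round-trip check pairing this logmap with the expmap-based forward in the
# first snippet (dummy data; torch/pmath assumed as above):
v = 0.1 * torch.randn(4, 16)
p = pmath.project(pmath.expmap0(v, c=1.0), c=1.0)
v_back = pmath.logmap0(p, c=1.0)
# v_back should be close to v for points well inside the ball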
Example #10
    def forward(self, x):
        #self.tp.c.data = self.tp.c.data.clamp(min=c_limit)
        hyp_result = []
        klein_result = []
        for i in range(self.inst_num):
            med = x[:, i, :, :, :]
            x_size = x.size()
            med_input = med.view(x_size[0], x_size[2], x_size[3], x_size[4])
            med_out = self.lenet(med_input)
            hyp_med_out = self.tp(med_out)

            if torch.isnan(hyp_med_out).any():
                print('error: hyp_med_out')
                print(med_out.data.cpu())
                print(hyp_med_out.data.cpu())
                print(self.tp)
                sys.exit()

            hyp_med_out = torch.reshape(hyp_med_out,
                                        (hyp_med_out.shape[0], 1, -1))
            hyp_result.append(hyp_med_out)

            klein_med_out = pmath.p2k(hyp_med_out, c=self.c)
            klein_med_out = pmath.project(klein_med_out, c=self.c)

            if torch.isnan(klein_med_out).any():
                print('klein_med_out')

            klein_med_out = torch.reshape(klein_med_out,
                                          (klein_med_out.shape[0], 1, -1))
            klein_result.append(klein_med_out)

        hyp_img_embed = torch.cat(hyp_result, dim=1)
        klein_img_embed = torch.cat(klein_result, dim=1)

        weight_result = []
        lorenz_f_result = []
        for i in range(self.inst_num):
            hyp_img_embed_input_med = hyp_img_embed[:, i, :]
            hyp_img_embed_size = hyp_img_embed.size()
            hyp_img_embed_input = hyp_img_embed_input_med.view(
                hyp_img_embed_size[0], -1)
            weight = self.att_weight(hyp_img_embed_input)

            if torch.isnan(weight).any():
                print('weight')

            klein_img_embed_input_med = klein_img_embed[:, i, :]
            klein_img_embed_size = klein_img_embed.size()
            klein_img_embed_input = klein_img_embed_input_med.view(
                klein_img_embed_size[0], -1)

            lorenz_f = pmath.lorenz_factor(klein_img_embed_input,
                                           c=self.c,
                                           dim=1,
                                           keepdim=True)

            if torch.isnan(lorenz_f).any():
                print('lorenz_f')

            weight_result.append(weight)
            lorenz_f_result.append(lorenz_f)

        embed_weight = torch.cat(weight_result, 1)
        out_weight = embed_weight
        lamb = torch.cat(lorenz_f_result, 1)
        alpha = torch.nn.Softmax(dim=1)(embed_weight)

        alpha_lamb = alpha * lamb
        alpha_lamb_sum = torch.sum(alpha_lamb, dim=1)
        alpha_lamb_sum = alpha_lamb_sum.unsqueeze(dim=1)
        alpha_lamb_norm = alpha_lamb / alpha_lamb_sum
        alpha_lamb_norm = torch.reshape(alpha_lamb_norm,
                                        (alpha_lamb_norm.shape[0], 1, -1))
        rep = torch.bmm(alpha_lamb_norm, klein_img_embed)
        rep = torch.reshape(rep, (rep.shape[0], -1))

        rep = pmath.project(rep, c=self.c)

        if torch.isnan(rep).any():
            print('rep')

        rep = pmath.k2p(rep, c=self.c)
        msi = self.msi_pred(rep)

        if torch.isnan(msi).any():
            print('msi')

        return msi, out_weight
Example #11
    def forward(self, x, c=None):
        if c is None:
            c = self.c

        # clip the (flattened per-channel) input so it lies inside the Poincare ball
        x = pmath.project(x.view(x.size(0) * x.size(1), -1),
                          c=c).view(x.size())

        # BLOCK 1

        x = self.c1(x, c=c)
        # batch norm
        #         x = pmath.logmap0(x, c=c)
        #         x = self.b1(x)
        #         x = pmath.expmap0(x, c=c)
        #         x = pmath.project(x, c=c)

        # blocked relu and maxpool 2
        #         x = pmath.logmap0(x, c=c)
        #         x = nn.ReLU()(x)
        #         x = nn.MaxPool2d(2)(x)
        #         x = pmath.expmap0(x, c=c)
        #         x = pmath.project(x, c=c)

        # separate relu and maxpool 2
        x = pmath.logmap0(x.view(x.size(0) * x.size(1), -1),
                          c=c).view(x.size())
        x = nn.ReLU()(x)
        x = pmath.expmap0(x.view(x.size(0) * x.size(1), -1),
                          c=c).view(x.size())

        x = pmath.project(x.view(x.size(0) * x.size(1), -1),
                          c=c).view(x.size())

        x = pmath.logmap0(x.view(x.size(0) * x.size(1), -1),
                          c=c).view(x.size())
        x = nn.MaxPool2d(2)(x)
        x = pmath.expmap0(x.view(x.size(0) * x.size(1), -1),
                          c=c).view(x.size())

        x = pmath.project(x.view(x.size(0) * x.size(1), -1),
                          c=c).view(x.size())

        #         # BLOCK 2

        #         x = self.c2(x, c=c)
        #         # batch norm
        # #         x = pmath.logmap0(x, c=c)
        # #         x = self.b2(x)
        # #         x = pmath.expmap0(x, c=c)
        # #         x = pmath.project(x, c=c)

        #         # blocked relu and maxpool 2
        # #         x = pmath.logmap0(x, c=c)
        # #         x = nn.ReLU()(x)
        # #         x = nn.MaxPool2d(2)(x)
        # #         x = pmath.expmap0(x, c=c)
        # #         x = pmath.project(x, c=c)

        #         # separate relu and maxpool 2
        #         x = pmath.logmap0(x.view(x.size(0) * x.size(1), -1), c=c).view(x.size())
        #         x = nn.ReLU()(x)
        #         x = pmath.expmap0(x.view(x.size(0) * x.size(1), -1), c=c).view(x.size())
        #         x = pmath.project(x.view(x.size(0) * x.size(1), -1), c=c).view(x.size())

        #         x = pmath.logmap0(x.view(x.size(0) * x.size(1), -1), c=c).view(x.size())
        #         x = nn.MaxPool2d(2)(x)
        #         x = pmath.expmap0(x.view(x.size(0) * x.size(1), -1), c=c).view(x.size())
        #         x = pmath.project(x.view(x.size(0) * x.size(1), -1), c=c).view(x.size())

        #         # BLOCK 3

        #         x = self.c3(x, c=c)
        #         # batch norm
        # #         x = pmath.logmap0(x, c=c)
        # #         x = self.b3(x)
        # #         x = pmath.expmap0(x, c=c)
        # #         x = pmath.project(x, c=c)

        #         # blocked relu and maxpool 2
        # #         x = pmath.logmap0(x, c=c)
        # #         x = nn.ReLU()(x)
        # #         x = nn.MaxPool2d(2)(x)
        # #         x = pmath.expmap0(x, c=c)
        # #         x = pmath.project(x, c=c)

        #         # separate relu and maxpool 2
        #         x = pmath.logmap0(x.view(x.size(0) * x.size(1), -1), c=c).view(x.size())
        #         x = nn.ReLU()(x)
        #         x = pmath.expmap0(x.view(x.size(0) * x.size(1), -1), c=c).view(x.size())
        #         x = pmath.project(x.view(x.size(0) * x.size(1), -1), c=c).view(x.size())

        #         x = pmath.logmap0(x.view(x.size(0) * x.size(1), -1), c=c).view(x.size())
        #         x = nn.MaxPool2d(2)(x)
        #         x = pmath.expmap0(x.view(x.size(0) * x.size(1), -1), c=c).view(x.size())
        #         x = pmath.project(x.view(x.size(0) * x.size(1), -1), c=c).view(x.size())

        #         # BLOCK 4

        #         x = self.c4(x, c=c)
        #         # batch norm
        # #         x = pmath.logmap0(x, c=c)
        # #         x = self.b4(x)
        # #         x = pmath.expmap0(x, c=c)
        # #         x = pmath.project(x, c=c)

        #         # blocked relu and maxpool 2
        # #         x = pmath.logmap0(x, c=c)
        # #         x = nn.ReLU()(x)
        # #         x = nn.MaxPool2d(2)(x)
        # #         x = pmath.expmap0(x, c=c)
        # #         x = pmath.project(x, c=c)

        #         # separate relu and maxpool 2
        #         x = pmath.logmap0(x.view(x.size(0) * x.size(1), -1), c=c).view(x.size())
        #         x = nn.ReLU()(x)
        #         x = pmath.expmap0(x.view(x.size(0) * x.size(1), -1), c=c).view(x.size())
        #         x = pmath.project(x.view(x.size(0) * x.size(1), -1), c=c).view(x.size())

        #         x = pmath.logmap0(x.view(x.size(0) * x.size(1), -1), c=c).view(x.size())
        #         x = nn.MaxPool2d(2)(x)
        #         x = pmath.expmap0(x.view(x.size(0) * x.size(1), -1), c=c).view(x.size())
        #         x = pmath.project(x.view(x.size(0) * x.size(1), -1), c=c).view(x.size())

        # final pool
        x = pmath.logmap0(x.view(x.size(0) * x.size(1), -1),
                          c=c).view(x.size())
        x = nn.MaxPool2d(5)(x)
        x = pmath.expmap0(x.view(x.size(0) * x.size(1), -1),
                          c=c).view(x.size())
        x = pmath.project(x.view(x.size(0) * x.size(1), -1),
                          c=c).view(x.size())

        # x is currently N x 512 x 1 x 1
        # NOTE: flattening the spatial dims here may distort the geometry; as a
        # workaround the flattened activations are treated as one Euclidean
        # vector and mapped back onto the ball with expmap0
        x = x.view(x.size(0), -1)
        x = pmath.expmap0(x, c=c)
        x = pmath.project(x, c=c)
        return x
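# The blocks above repeat one move: flatten each (sample, channel) plane, pull
# it to the tangent space at the origin, run a plain Euclidean op (ReLU or
# MaxPool2d), and map the result back onto the ball. A helper capturing that
# per-channel wrapping (a sketch; `hyp_channelwise` is not a name from the
# original code, pmath assumed as above):
def hyp_channelwise(op, x, c):
    b, ch = x.size(0), x.size(1)
    x = pmath.logmap0(x.view(b * ch, -1), c=c).view(x.size())
    x = op(x)  # e.g. nn.ReLU() or nn.MaxPool2d(2); may change the spatial size
    x = pmath.expmap0(x.view(b * ch, -1), c=c).view(x.size())
    x = pmath.project(x.view(b * ch, -1), c=c).view(x.size())
    return x

# usage: x = hyp_channelwise(nn.ReLU(), x, c); x = hyp_channelwise(nn.MaxPool2d(2), x, c)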
Example #12
    def forward(self, x, c=None):
        if c is None:
            c = self.c

        # BLOCK 1

        x = self.c1(x)
        # batch norm
        #         x = pmath.logmap0(x, c=c)
        #         x = self.b1(x)
        #         x = pmath.expmap0(x, c=c)
        #         x = pmath.project(x, c=c)

        # blocked relu and maxpool 2
        #         x = pmath.logmap0(x, c=c)
        #         x = nn.ReLU()(x)
        #         x = nn.MaxPool2d(2)(x)
        #         x = pmath.expmap0(x, c=c)
        #         x = pmath.project(x, c=c)

        # separate relu and maxpool 2
        #         x = pmath.logmap0(x, c=c)
        x = nn.ReLU()(x)
        #         x = pmath.expmap0(x, c=c)
        #         x = pmath.project(x, c=c)

        #         x = pmath.logmap0(x, c=c)
        x = nn.MaxPool2d(2)(x)
        #         x = pmath.expmap0(x, c=c)
        #         x = pmath.project(x, c=c)

        # BLOCK 2

        x = self.c2(x)
        # batch norm
        #         x = pmath.logmap0(x, c=c)
        #         x = self.b2(x)
        #         x = pmath.expmap0(x, c=c)
        #         x = pmath.project(x, c=c)

        # blocked relu and maxpool 2
        #         x = pmath.logmap0(x, c=c)
        #         x = nn.ReLU()(x)
        #         x = nn.MaxPool2d(2)(x)
        #         x = pmath.expmap0(x, c=c)
        #         x = pmath.project(x, c=c)

        # separate relu and maxpool 2
        #         x = pmath.logmap0(x, c=c)
        x = nn.ReLU()(x)
        #         x = pmath.expmap0(x, c=c)
        #         x = pmath.project(x, c=c)

        #         x = pmath.logmap0(x, c=c)
        x = nn.MaxPool2d(2)(x)
        #         x = pmath.expmap0(x, c=c)
        #         x = pmath.project(x, c=c)

        # BLOCK 3

        x = self.c3(x)
        # batch norm
        #         x = pmath.logmap0(x, c=c)
        #         x = self.b3(x)
        #         x = pmath.expmap0(x, c=c)
        #         x = pmath.project(x, c=c)

        # blocked relu and maxpool 2
        #         x = pmath.logmap0(x, c=c)
        #         x = nn.ReLU()(x)
        #         x = nn.MaxPool2d(2)(x)
        #         x = pmath.expmap0(x, c=c)
        #         x = pmath.project(x, c=c)

        # separate relu and maxpool 2
        #         x = pmath.logmap0(x, c=c)
        x = nn.ReLU()(x)
        #         x = pmath.expmap0(x, c=c)
        #         x = pmath.project(x, c=c)

        #         x = pmath.logmap0(x, c=c)
        x = nn.MaxPool2d(2)(x)
        #         x = pmath.expmap0(x, c=c)
        #         x = pmath.project(x, c=c)

        # BLOCK 4, to hyperbolic

        x = self.e2p(x)

        x = self.c4(x, c=c)
        # batch norm
        #         x = pmath.logmap0(x, c=c)
        #         x = self.b4(x)
        #         x = pmath.expmap0(x, c=c)
        #         x = pmath.project(x, c=c)

        # blocked relu and maxpool 2
        #         x = pmath.logmap0(x, c=c)
        #         x = nn.ReLU()(x)
        #         x = nn.MaxPool2d(2)(x)
        #         x = pmath.expmap0(x, c=c)
        #         x = pmath.project(x, c=c)

        # separate relu and maxpool 2
        x = pmath.logmap0(x.view(x.size(0) * x.size(1), -1),
                          c=c).view(x.size())
        x = nn.ReLU()(x)
        x = pmath.expmap0(x.view(x.size(0) * x.size(1), -1),
                          c=c).view(x.size())
        x = pmath.project(x.view(x.size(0) * x.size(1), -1),
                          c=c).view(x.size())

        x = pmath.logmap0(x.view(x.size(0) * x.size(1), -1),
                          c=c).view(x.size())
        x = nn.MaxPool2d(2)(x)
        x = pmath.expmap0(x.view(x.size(0) * x.size(1), -1),
                          c=c).view(x.size())
        x = pmath.project(x.view(x.size(0) * x.size(1), -1),
                          c=c).view(x.size())

        # final pool
        x = pmath.logmap0(x.view(x.size(0) * x.size(1), -1),
                          c=c).view(x.size())
        x = nn.MaxPool2d(5)(x)
        x = pmath.expmap0(x.view(x.size(0) * x.size(1), -1),
                          c=c).view(x.size())
        x = pmath.project(x.view(x.size(0) * x.size(1), -1),
                          c=c).view(x.size())

        # x is currently N x 512 x 1 x 1
        # NOTE: flattening the spatial dims here may distort the geometry; as a
        # workaround the flattened activations are treated as one Euclidean
        # vector and mapped back onto the ball with expmap0
        x = x.view(x.size(0), -1)
        x = pmath.expmap0(x, c=c)
        x = pmath.project(x, c=c)
        return x
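# The `self.e2p(x)` call above moves the Euclidean feature maps onto the ball
# before the hyperbolic block. A minimal sketch of such a Euclidean-to-Poincare
# module, flattening per (sample, channel) plane as the rest of the network does
# (the class name and exact flattening are assumptions, not the original e2p;
# torch.nn/pmath imports assumed as above):
class ToPoincareSketch(nn.Module):
    def __init__(self, c=1.0):
        super().__init__()
        self.c = c

    def forward(self, x):
        b, ch = x.size(0), x.size(1)
        flat = x.view(b * ch, -1)
        out = pmath.project(pmath.expmap0(flat, c=self.c), c=self.c)
        return out.view(x.size())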