Esempio n. 1
0
    def forward(self, input):
        """Mix hyperbolic and Euclidean class scores and pass them to the parent decoder.

        Args:
            input: pair ``((x, x_e), adj)``; ``x`` feeds the hyperbolic
                branch (``self.cls_h``), ``x_e`` feeds the Euclidean branch
                (``self.cls_e``), and ``adj`` is passed to both.

        Returns:
            Whatever ``super(DualDecoder, self).forward`` returns for the
            mixed scores (the parent implementation is not visible here).
        """
        x, x_e = input[0]
        adj = input[1]
        '''Euclidean probs'''
        # cls_e returns a pair; only its first element is used as the Euclidean scores.
        probs_e, _ = self.cls_e((x_e, adj))
        '''Hyper probs'''
        x = self.cls_h((x, adj))
        # Insert an axis at -2, presumably so x broadcasts against one
        # hyperplane per class in self.point / self.tangent -- TODO confirm shapes.
        x = x.unsqueeze(-2)
        # Signed distance to the hyperplane(s) given by p=self.point and
        # a=self.tangent, at curvature self.ball.c.
        distance = pmath.dist2plane(x=x,
                                    p=self.point,
                                    a=self.tangent,
                                    c=self.ball.c,
                                    signed=True)
        # Multiply by exp(self.scale), a strictly positive factor.
        probs_h = distance * self.scale.exp()
        '''Prob. Assembling'''
        # Sigmoid-gated mixing weights: the hyperbolic one is computed from
        # self.manifold.logmap0 of the hyperbolic features, the Euclidean one
        # from probs_e. Dropout is applied to each only when self.training.
        # NOTE(review): x.squeeze() drops every size-1 axis, not just the one
        # added above; squeeze(-2) would be safer if a batch of one is possible.
        # NOTE(review): this uses self.c while dist2plane above uses
        # self.ball.c -- confirm both hold the same curvature.
        w_h = torch.sigmoid(
            self.w_h(self.manifold.logmap0(x.squeeze(), self.c)))
        w_h = F.dropout(w_h, p=self.drop_h, training=self.training)
        w_e = torch.sigmoid(self.w_e(probs_e))
        w_e = F.dropout(w_e, p=self.drop_e, training=self.training)

        # Place the two weights side by side as columns (requires w_h and w_e
        # to have the same number of elements), then rescale each row to unit
        # L1 norm so the pair sums to 1 (unless both were dropped to zero).
        w = torch.cat([w_h.view(-1, 1), w_e.view(-1, 1)], dim=-1)
        w = F.normalize(w, p=1, dim=-1)
        # NOTE(review): only the LAST row of w is used, so every sample is mixed
        # with the same pair of weights. If per-sample mixing was intended this
        # should be w[:, 0:1] * probs_h + w[:, 1:2] * probs_e -- confirm intent.
        probs = w[-1, 0] * probs_h + w[-1, 1] * probs_e

        return super(DualDecoder, self).forward(probs)
Esempio n. 2
0
 def forward(self, input):
     """Score ``input`` against every class hyperplane.

     A new axis is inserted at position -2 (presumably so ``input``
     broadcasts against one hyperplane per class held in ``self.point`` /
     ``self.tangent`` -- TODO confirm shapes). The signed distance to the
     hyperplane(s) is computed at curvature ``self.ball.c``.

     Returns:
         The signed distances multiplied by ``exp(self.scale)``.
     """
     expanded = input.unsqueeze(-2)
     signed_dist = pmath.dist2plane(
         x=expanded,
         p=self.point,
         a=self.tangent,
         c=self.ball.c,
         signed=True,
     )
     # exp keeps the multiplier strictly positive.
     scale_factor = self.scale.exp()
     return signed_dist * scale_factor
Esempio n. 3
0
def hyperplane_dist(x, a, b, c, keepdim=False):
    """Return the distance from ``x`` to the hyperplane defined by ``a`` and ``b``.

    Thin wrapper over ``pmath.dist2plane`` that swaps the argument order:
    ``b`` is forwarded as the hyperplane point ``p`` and ``a`` as the
    direction ``a``. ``c`` (curvature) and ``keepdim`` are passed through
    unchanged.
    """
    return pmath.dist2plane(x=x, p=b, a=a, c=c, keepdim=keepdim)
Esempio n. 4
0
import shutil

import seaborn as sns
from matplotlib import rcParams

# Enable LaTeX text rendering only when a ``latex`` executable is on PATH.
# With ``text.usetex`` on and no LaTeX toolchain installed, matplotlib fails
# as soon as it draws text, which would crash the whole script.
if shutil.which("latex") is not None:
    rcParams["text.latex.preamble"] = r"\usepackage{amsmath}"
    rcParams["text.usetex"] = True

sns.set_style("white")

# Disk of the given radius, sampled on a 100x100 square grid.
radius = 1
coords = np.linspace(-radius, radius, 100)

# Hyperplane through the point ``x`` with direction ``v`` (drawn as an arrow below).
x = torch.tensor([-0.75, 0])
v = torch.tensor([0.1 / 3, -1 / 3])

xx, yy = np.meshgrid(coords, coords)
dist2 = xx**2 + yy**2
mask = dist2 <= radius**2  # True for grid points inside the disk
grid = np.stack([xx, yy], axis=-1)

# Distance from every grid point to the hyperplane. Points outside the disk
# are set to NaN through a boolean mask with the same shape as ``dists``
# (matplotlib treats NaN as masked, so they stay blank in the contour plot).
dists = pmath.dist2plane(torch.from_numpy(grid).float(), x, v)
dists[torch.from_numpy(~mask)] = np.nan

# Boundary of the disk; the axis limits leave a 10% margin around it.
circle = plt.Circle((0, 0), radius, fill=False, color="b")
plt.gca().add_artist(circle)
plt.xlim(-1.1 * radius, 1.1 * radius)
plt.ylim(-1.1 * radius, 1.1 * radius)

plt.gca().set_aspect("equal")
# Filled contours of the log-distance.
plt.contourf(grid[..., 0],
             grid[..., 1],
             dists.log().numpy(),
             levels=100,
             cmap="inferno")
plt.colorbar()
plt.scatter(*x, color="g")
plt.arrow(*x, *v, color="g", width=0.01)