示例#1
0
def gauss(args, mu, std):
    """Plot the density of a ``HyperboloidWrappedNormal`` on the Poincare disk.

    The mean ``mu`` is clamped to ``[-max_clamp_norm, max_clamp_norm]``,
    lifted with ``expand_proj_dims`` and mapped onto the hyperboloid with
    ``exp_map_mu0``. A regular grid with step 0.1 over ``[-5, 5)`` on the
    tangent plane at the origin is then mapped onto the manifold and
    evaluated under the distribution. If ``args.flow`` is not ``'none'``, a
    flow is also fitted and plotted via ``plot_flow``.

    Args:
        args: namespace providing ``radius``, ``dev``, ``namestr``, ``flow``
            and ``n_blocks``.
        mu: tangent-space mean (before lifting).
        std: scale passed straight to ``HyperboloidWrappedNormal``.
    """
    radius = torch.Tensor([args.radius]).to(args.dev)

    # Clamp, lift and exp-map the mean, then build the distribution.
    clamped_mu = clamp(mu, min=-max_clamp_norm, max=max_clamp_norm)
    mu_h = exp_map_mu0(expand_proj_dims(clamped_mu), radius)
    p_z = HyperboloidWrappedNormal(radius, mu_h, std)

    # Grid on the tangent plane at the origin, flattened to (N, 2).
    axis = np.arange(-5, 5, 0.1)
    grid_x, grid_y = np.meshgrid(axis, axis)
    col_x = torch.Tensor(grid_x).view(-1, 1)
    col_y = torch.Tensor(grid_y).view(-1, 1)
    tangent_pts = torch.cat([col_x, col_y], dim=1)

    # Lift, clamp, move to the target device and exp-map onto the manifold.
    lifted = clamp(expand_proj_dims(tangent_pts),
                   min=-max_clamp_norm,
                   max=max_clamp_norm).to(args.dev)
    on_mani = exp_map_mu0(lifted, radius)

    # Density (not log-density) of each grid point under p_z.
    probs = torch.exp(p_z.log_prob(on_mani))

    # Project both the grid and the mean onto the Poincare disk for plotting.
    xy_poincare = lorentz_to_poincare(on_mani.squeeze(), radius)
    mu_p = lorentz_to_poincare(mu_h, radius)

    plot_density(xy_poincare, probs, args.radius, args.namestr, mu=mu_p)
    if args.flow != 'none':
        plot_flow(args, radius, args.flow, p_z, args.namestr, args.n_blocks)
示例#2
0
    def inverse(self, z_hyper, edge_index=None):
        """Apply the coupling blocks in order 0..n_blocks-1 to hyperboloid points.

        ``z_hyper`` is log-mapped to the tangent space at the origin once.
        Each block then splits coordinates with ``self.mask[i]``. The masked
        part is the input to the ``s``/``t`` networks. The unmasked part is:

        1. scaled by ``exp(s)``;
        2. parallel-transported from the origin to a point ``t`` built from
           the ``t`` network output;
        3. exp-mapped at ``t``;
        4. log-mapped back to the origin.

        Intermediate values are clamped to
        ``[-max_clamp_norm, max_clamp_norm]``.

        Args:
            z_hyper: points on the hyperboloid (Lorentz model).
            edge_index: forwarded to the ``s``/``t`` networks when
                ``self.layer_type != 'Linear'``.

        Returns:
            ``(x, log_det_J)``: the result exp-mapped back onto the hyperboloid
            and the accumulated log-determinant terms.

        Side effects:
            If any clamp capped a maximum during this call,
            ``self.preclamp_norm`` is set to a 1-element tensor holding the
            mean of the recorded pre-clamp maxima. Otherwise it is left
            unchanged.
        """
        z = inverse_exp_map_mu0(z_hyper, self.radius)
        # Drop the leading coordinate of the tangent vector at the origin.
        x = z[..., 1:]
        log_det_J = logmap_logdet(z, self.radius)
        # Pre-clamp maxima, recorded only when a clamp actually capped a value.
        preclamp_norm_list = []
        for i in range(0, self.n_blocks):
            x_ = x * self.mask[i]
            if self.layer_type != 'Linear':
                s = self.s[i](x_, edge_index)
                t_out = self.t[i](x_, edge_index)
            else:
                s = self.s[i](x_)
                t_out = self.t[i](x_)
            # Point t, used below as parallel-transport destination and
            # exp-map base point.
            t_proj = proj_vec(t_out, self.radius)
            t1, t_rest = t_proj[:, 0].unsqueeze(1), t_proj[:, 1:]
            t = self.create_masked_t((1 - self.mask[i]), t1, t_rest)
            # (1-b) \odot \tilde{x} \odot exp(s(b \odot \tilde{x}))
            x_pt_arg = expand_proj_dims((1 - self.mask[i]) * x * torch.exp(s))

            # (1-b) \odot \textnormal{PT}_{\textbf{o}\to t(b \odot \tilde{x})
            pt = parallel_transport_mu0(x_pt_arg, dst=t, radius=self.radius)
            preclamp_norm = pt.max()
            pt = clamp(pt, min=-max_clamp_norm, max=max_clamp_norm)
            # A post-clamp max equal to the bound means the clamp was active.
            if pt.max() == max_clamp_norm:
                preclamp_norm_list.append(preclamp_norm)
            x_t = exp_map(x=pt, at_point=t, radius=self.radius)
            log_det_J += _logdet(pt, self.radius, subdim=(self.mask[i]).sum())
            preclamp_norm = x_t.max()
            x_t = clamp(x_t, min=-max_clamp_norm, max=max_clamp_norm)
            if x_t.max() == max_clamp_norm:
                preclamp_norm_list.append(preclamp_norm)

            #\log_{\textbf{o}}(\textnormal{exp}_{t()}(\textnormal{PT}_{\textbf{o}\to t()))
            x_0_full = inverse_exp_map_mu0(x_t, self.radius)
            x_0 = x_0_full[..., 1:]
            log_det_J += logmap_logdet(x_0_full,
                                       self.radius,
                                       subdim=(self.mask[i]).sum())
            # Keep the masked part; replace the unmasked part with the result.
            x = x_ + (1 - self.mask[i]) * x_0
            log_det_J += ((1 - self.mask[i]) * s).sum(dim=1)  # log det dx/du

            preclamp_norm = x.max()
            x = clamp(x, min=-max_clamp_norm, max=max_clamp_norm)
            if x.max() == max_clamp_norm:
                preclamp_norm_list.append(preclamp_norm)

        x_mu0 = expand_proj_dims(x)
        # Project back to Manifold
        x = exp_map_mu0(x_mu0, self.radius)
        log_det_J += _logdet(x_mu0, self.radius)

        self.preclamp_norm = torch.Tensor([
            sum(preclamp_norm_list) / len(preclamp_norm_list)
        ]) if preclamp_norm_list else self.preclamp_norm
        return x, log_det_J
示例#3
0
 def forward(self, x_hyper, edge_index=None):
     """Run the coupling blocks in reverse order (n_blocks-1 down to 0).

     ``x_hyper`` is log-mapped to the tangent space at the origin once. Each
     block then:

     1. log-maps its input back to the tangent space, except in the first
        iteration, whose input already is a tangent vector;
     2. applies the inverse masked affine coupling ``(z - t) * exp(-s)`` on
        the unmasked coordinates;
     3. exp-maps the result back onto the hyperboloid.

     Args:
         x_hyper: points on the hyperboloid (Lorentz model).
         edge_index: forwarded to the ``s``/``t`` networks when
             ``self.layer_type != 'Linear'``. Defaults to ``None`` so existing
             callers that pass only ``x_hyper`` keep working.

     Returns:
         ``(z, log_det_J)``: the transformed hyperboloid points and the
         accumulated (negated) log-determinant terms.
     """
     x = inverse_exp_map_mu0(x_hyper, self.radius)
     # Drop the leading coordinate of the tangent vector at the origin.
     z = x[..., 1:]
     log_det_J = -1 * logmap_logdet(x, self.radius)
     first_block = self.n_blocks - 1  # first index visited by reversed()
     for i in reversed(range(0, self.n_blocks)):
         if i != first_block:
             # Project between Flow Layers: every block after the first one
             # visited receives a hyperboloid point from the previous exp-map.
             z_proj_mu0 = inverse_exp_map_mu0(z, self.radius)
             z = z_proj_mu0[..., 1:]
             log_det_J -= logmap_logdet(z_proj_mu0, self.radius)
         z_ = self.mask[i] * z
         if self.layer_type != 'Linear':
             s = self.s[i](z_, edge_index)
             t = self.t[i](z_, edge_index)
         else:
             s = self.s[i](z_)
             t = self.t[i](z_)
         # Inverse affine coupling on the unmasked coordinates.
         z = (1 - self.mask[i]) * (z - t) * torch.exp(-s) + z_
         log_det_J -= ((1 - self.mask[i]) * s).sum(dim=1)
         z_mu0 = expand_proj_dims(z)
         # Project back to Manifold
         z = exp_map_mu0(z_mu0, self.radius)
         log_det_J -= _logdet(z_mu0, self.radius)
     return z, log_det_J
示例#4
0
 def inverse(self, z_hyper, edge_index=None):
     """Run the coupling blocks in order 0..n_blocks-1.

     ``z_hyper`` is log-mapped to the tangent space at the origin once. Each
     block then:

     1. log-maps its input back to the tangent space, except for block 0,
        whose input already is a tangent vector;
     2. applies the masked affine coupling ``x * exp(s) + t`` on the
        unmasked coordinates;
     3. clamps the result to ``[-max_clamp_norm, max_clamp_norm]``;
     4. exp-maps it back onto the hyperboloid.

     Args:
         z_hyper: points on the hyperboloid (Lorentz model).
         edge_index: forwarded to the ``s``/``t`` networks when
             ``self.layer_type != 'Linear'``. Defaults to ``None`` so existing
             callers that pass only ``z_hyper`` keep working.

     Returns:
         ``(x, log_det_J)``: the transformed hyperboloid points and the
         accumulated log-determinant terms.

     Side effects:
         Sets ``self.preclamp_norm`` to the pre-clamp maximum of the last
         block's coupling output.
     """
     z = inverse_exp_map_mu0(z_hyper, self.radius)
     # Drop the leading coordinate of the tangent vector at the origin.
     x = z[..., 1:]
     log_det_J = logmap_logdet(z, self.radius)
     for i in range(0, self.n_blocks):
         if i > 0:
             # Project between Flow Layers: blocks after the first receive a
             # hyperboloid point from the previous block's exp-map.
             x_proj_mu0 = inverse_exp_map_mu0(x, self.radius)
             x = x_proj_mu0[..., 1:]
             log_det_J += logmap_logdet(x_proj_mu0, self.radius)
         x_ = x * self.mask[i]
         if self.layer_type != 'Linear':
             s = self.s[i](x_, edge_index)
             t = self.t[i](x_, edge_index)
         else:
             s = self.s[i](x_)
             t = self.t[i](x_)
         # Affine coupling on the unmasked coordinates.
         x = x_ + (1 - self.mask[i]) * (x * torch.exp(s) + t)
         self.preclamp_norm = x.max()
         x = clamp(x, min=-max_clamp_norm, max=max_clamp_norm)
         log_det_J += ((1 - self.mask[i]) * s).sum(dim=1)  # log det dx/du
         x_mu0 = expand_proj_dims(x)
         # Project back to Manifold
         x = exp_map_mu0(x_mu0, self.radius)
         log_det_J += _logdet(x_mu0, self.radius)
     return x, log_det_J
示例#5
0
def some_density(args):
    """Plot an analytic density, given as ``-f1`` in log space, on the Poincare disk.

    A grid from ``setup_grid(5, n_pts)`` on the tangent plane at the origin
    is lifted, clamped and exp-mapped onto the hyperboloid. ``-f1`` of each
    grid point is used as an unnormalised log-density, from which the
    ``_logdet`` term of the lifted points is subtracted. If ``args.flow`` is
    not ``'none'``, a flow targeting ``f1`` is then fitted via ``plot_flow``.

    Every tensor is placed on ``args.dev``, the device attribute the other
    plotting helpers use, rather than a hard-coded CUDA device. This keeps
    the log-density and the log-det correction on the same device.

    Args:
        args: namespace providing ``radius``, ``dev``, ``namestr`` and ``flow``.
    """
    radius = torch.Tensor([args.radius]).to(args.dev)
    n_pts = 100

    def f1(z):
        # Energy along the first tangent coordinate; -f1 is the log-density.
        return torch.sin(6 * math.pi * z[:, 0] / 4)

    _, _, zz = setup_grid(5, n_pts)
    zz = zz.to(args.dev)
    base_prob_dist = -f1(zz)

    # Map x, y coordinates on tangent space at origin to manifold (Lorentz model).
    threedim = expand_proj_dims(zz)
    clamped_threedim = clamp(threedim, min=-max_clamp_norm,
                             max=max_clamp_norm)
    on_mani = exp_map_mu0(clamped_threedim, radius)

    # Calculate densities of x, y coords on Lorentz model.
    log_det = _logdet(clamped_threedim, radius)
    log_probs = base_prob_dist - log_det
    probs = torch.exp(log_probs)

    # Calculate the poincare coordinates
    xy_poincare = lorentz_to_poincare(on_mani.squeeze(), radius)

    plot_density(xy_poincare, probs, radius, args.namestr)
    if args.flow != 'none':
        plot_flow(args, radius, args.flow, f1, args.namestr)
示例#6
0
    def forward(self, x_hyper, edge_index=None):
        """Run the coupling blocks in reverse order (n_blocks-1 down to 0).

        ``x_hyper`` is log-mapped to the tangent space at the origin once.
        Each block then processes the unmasked coordinates:

        1. lift them, exp-map at the origin, and log-map at a point ``t``
           built from the ``t`` network output;
        2. parallel-transport the result from ``t`` back to the origin with
           ``inverse_parallel_transport_mu0``;
        3. scale it by ``exp(-s)`` and recombine it with the masked part.

        Intermediate values are clamped to
        ``[-max_clamp_norm, max_clamp_norm]``.

        Args:
            x_hyper: points on the hyperboloid (Lorentz model).
            edge_index: forwarded to the ``s``/``t`` networks when
                ``self.layer_type != 'Linear'``.

        Returns:
            ``(z, log_det_J)``: the result exp-mapped back onto the hyperboloid
            and the accumulated (negated) log-determinant terms.
        """
        x = inverse_exp_map_mu0(x_hyper, self.radius)
        # Drop the leading coordinate of the tangent vector at the origin.
        z = x[..., 1:]
        log_det_J = -1 * logmap_logdet(x, self.radius)
        for i in reversed(range(0, self.n_blocks)):
            z_ = self.mask[i] * z
            if self.layer_type != 'Linear':
                s = self.s[i](z_, edge_index)
                t_out = self.t[i](z_, edge_index)
            else:
                s = self.s[i](z_)
                t_out = self.t[i](z_)
            # Point t: base point for the log-map and source of the inverse
            # parallel transport below.
            t_proj = proj_vec(t_out, self.radius)

            t1, t_rest = t_proj[:, 0].unsqueeze(1), t_proj[:, 1:]
            t = self.create_masked_t((1 - self.mask[i]), t1, t_rest)

            # Unmasked part: lift, clamp, exp-map at the origin.
            z_2 = expand_proj_dims((1 - self.mask[i]) * z)
            z_2 = clamp(z_2, min=-max_clamp_norm, max=max_clamp_norm)
            z_exp_2 = exp_map_mu0(z_2, self.radius)
            log_det_J -= _logdet(z_2, self.radius, subdim=(self.mask[i]).sum())

            # Log-map at t.
            z_exp_2 = clamp(z_exp_2, min=-max_clamp_norm, max=max_clamp_norm)
            z_inv_pt_arg = inverse_exp_map(x=z_exp_2,
                                           at_point=t,
                                           radius=self.radius)
            log_det_J -= logmap_logdet(z_inv_pt_arg,
                                       self.radius,
                                       subdim=(self.mask[i]).sum())

            # Transport from t back to the origin; drop the leading coordinate.
            z_inv_pt_arg = clamp(z_inv_pt_arg,
                                 min=-max_clamp_norm,
                                 max=max_clamp_norm)
            pt = inverse_parallel_transport_mu0(z_inv_pt_arg,
                                                src=t,
                                                radius=self.radius)
            pt = pt[..., 1:]

            # Undo the exp(s) scaling on the unmasked part; keep the masked part.
            z = (1 - self.mask[i]) * pt * torch.exp(-s) + z_
            log_det_J -= ((1 - self.mask[i]) * s).sum(dim=1)

        z_mu0 = expand_proj_dims(z)
        z = exp_map_mu0(z_mu0, self.radius)
        log_det_J -= _logdet(z_mu0, self.radius)
        return z, log_det_J
示例#7
0
 def sample_projection_mu0(
         self, x: Tensor, at_point: Tensor,
         radius: Tensor) -> Tuple[Tensor, Tuple[Tensor, Tensor]]:
     """Move a tangent sample at the origin onto the manifold around ``at_point``.

     ``x`` is lifted with ``expand_proj_dims`` and parallel-transported from
     the origin to ``at_point``. It is then clamped to
     ``[-max_clamp_norm, max_clamp_norm]`` and exp-mapped at ``at_point``.

     Returns:
         ``(x_proj, (pt, x))``: the point on the manifold, together with the
         clamped transported vector and the original input.
     """
     transported = clamp(
         parallel_transport_mu0(expand_proj_dims(x),
                                dst=at_point,
                                radius=radius),
         min=-max_clamp_norm,
         max=max_clamp_norm,
     )
     on_manifold = exp_map(transported, at_point=at_point, radius=radius)
     return on_manifold, (transported, x)
示例#8
0
    def forward(self, x, edge_index):
        """Encode graph node features into a posterior on the hyperboloid.

        A ReLU'd ``conv1`` layer feeds two heads, ``conv_mu`` and
        ``conv_logvar``. The mean is clamped and exp-mapped onto the
        hyperboloid. The scale is ``softplus(logvar) + 1e-5``.
        ``self.q_z``/``self.p_z`` are set from ``self.reparametrize``, and a
        sample is drawn with ``rsample_with_parts``.

        Returns:
            ``(z_0, mu_h, std, data)``: the sample, the hyperboloid mean, the
            scale, and the extra parts returned by ``rsample_with_parts``.
        """
        hidden = F.relu(self.conv1(x, edge_index))
        raw_mu = self.conv_mu(hidden, edge_index)
        logvar = self.conv_logvar(hidden, edge_index)

        mu = clamp(raw_mu, min=-max_clamp_norm, max=max_clamp_norm)
        # Sanity checks on both heads (mean checked first, as before).
        for head in (mu, logvar):
            assert torch.isfinite(head).all()

        mu_h = exp_map_mu0(expand_proj_dims(mu), self.radius)
        assert torch.isfinite(mu_h).all()

        # The small epsilon keeps the scale away from zero (prevents collapse).
        std = F.softplus(logvar) + 1e-5
        assert torch.isfinite(std).all()

        self.q_z, self.p_z = self.reparametrize(mu_h, std)
        sample, parts = self.q_z.rsample_with_parts()
        return sample, mu_h, std, parts
示例#9
0
    def bottleneck(self, h):
        """Turn encoder features ``h`` into a hyperboloid sample.

        ``fc_mean``/``fc_logvar`` give the tangent mean and log-variance. The
        mean is clamped and exp-mapped onto the hyperboloid. The scale is
        ``softplus(logvar) + 1e-5``. ``self.q_z``, ``self.p_z`` and
        ``self.data`` are stored as side effects.

        Returns:
            ``(z, mu_h, std)``: the sample, the hyperboloid mean and the scale.
        """
        raw_mu = self.fc_mean(h)
        logvar = self.fc_logvar(h)
        mu = clamp(raw_mu, min=-max_clamp_norm, max=max_clamp_norm)
        # Sanity checks on both heads (mean checked first, as before).
        for head in (mu, logvar):
            assert torch.isfinite(head).all()

        mu_h = exp_map_mu0(expand_proj_dims(mu), self.radius)
        assert torch.isfinite(mu_h).all()

        # The small epsilon keeps the scale away from zero (prevents collapse).
        std = F.softplus(logvar) + 1e-5
        assert torch.isfinite(std).all()

        q_z, p_z = self.reparametrize(mu_h, std)
        self.q_z, self.p_z = q_z, p_z
        z, self.data = q_z.rsample_with_parts()
        return z, mu_h, std
示例#10
0
def mixture(args):
    """Plot a 2-D toy dataset wrapped onto the hyperboloid, then optionally fit a flow.

    100000 samples of ``args.dataset`` are clamped, lifted, and exp-mapped
    onto the hyperboloid. They are projected to the Poincare disk and
    rendered as a 2-D histogram, saved to ``install/<namestr>.png``. If
    ``args.flow`` is not ``'none'``, ``plot_flow`` is then called with the
    hyperboloid samples.

    Args:
        args: namespace providing ``radius``, ``dev``, ``dataset``,
            ``axis_lim``, ``namestr``, ``flow`` and ``n_blocks``.
    """
    radius = torch.Tensor([args.radius]).to(args.dev)
    samples = sample_2d_data(args.dataset, 100000).to(args.dev)
    samples = clamp(samples, min=-max_clamp_norm, max=max_clamp_norm)
    samples_h = exp_map_mu0(expand_proj_dims(samples), radius)
    # Calculate the poincare coordinates
    xy_poincare = lorentz_to_poincare(samples_h.squeeze(), radius)
    # Copy to host once instead of once per coordinate.
    xy_np = xy_poincare.detach().cpu().numpy()
    # No prior is defined for this dataset; plot_flow receives None.
    p_z = None

    fig = plt.figure()
    ax = fig.add_subplot(111)

    # Histogram extent on each axis.
    range_lim = 5
    ax.hist2d(xy_np[:, 0],
              xy_np[:, 1],
              range=[[-range_lim, range_lim], [-range_lim, range_lim]],
              bins=5000,
              cmap='magma')

    ax.axis('off')
    # Makes the circle look like a circle
    ax.axis('equal')
    ax.set_xlim(-args.axis_lim, args.axis_lim)
    ax.set_ylim(-args.axis_lim, args.axis_lim)

    # Save the full figure, then release it so repeated calls don't
    # accumulate open figures.
    fig.savefig('install/{}.png'.format(args.namestr))
    print("saved to install/{}.png".format(args.namestr))
    plt.close(fig)

    if args.flow != 'none':
        plot_flow(args,
                  radius,
                  args.flow,
                  p_z,
                  args.namestr,
                  n_blocks=args.n_blocks,
                  samples=samples_h)
示例#11
0
    def MC_log_likelihood(self, x):
        """
        Importance-sampled Monte Carlo estimate of log p(x).

        ``self.K`` latent samples are drawn from q(z|x) for each input. The
        estimate is ``logsumexp(log p(x|z) + log p(z) - log q(z|x)) - log K``
        over the sample axis.

        :param x: Mini-batch of inputs; repeated ``self.K`` times along a new
            leading axis to match the latent samples.
        :return: Tuple ``(log_p_x, mi)``.

            - ``log_p_x`` is the Monte Carlo estimate of the log-likelihood.
            - ``mi`` is ``logsumexp(log q(z|x) - log p(z)) - log K`` over the
              same samples. NOTE(review): confirm this is the intended
              mutual-information estimator.
        """
        n = self.K
        sample_shape = torch.Size([n])

        x_encoded = self.encoder(x)
        mu, logvar = self.fc_mean(x_encoded), self.fc_logvar(x_encoded)
        mu = clamp(mu, min=-max_clamp_norm, max=max_clamp_norm)
        mu_h = exp_map_mu0(expand_proj_dims(mu), self.radius)

        # +eps prevents collapse
        std = F.softplus(logvar) + 1e-5
        q_z, p_z = self.reparametrize(mu_h, std)

        # Numerically more stable.
        z, log_q_z_x, log_p_z = self.rsample_log_probs(sample_shape, q_z, p_z)
        # Decode from the tangent space at the origin. The log-map's log-det
        # term is subtracted from log q(z|x).
        z = inverse_exp_map_mu0(z, self.radius)
        log_q_z_x = log_q_z_x - logmap_logdet(z, self.radius)

        x_mb_ = self.decode(z)
        x_orig = x.repeat((n, 1, 1))
        log_p_x_z = -self.recon_loss(x_mb_, x_orig).sum(dim=-1)

        assert log_p_x_z.shape == log_p_z.shape
        assert log_q_z_x.shape == log_p_z.shape
        joint = (log_p_x_z + log_p_z - log_q_z_x)
        # Reduce over the sample axis (dim 0).
        log_p_x = joint.logsumexp(dim=0) - np.log(n)

        mi = (log_q_z_x - log_p_z).logsumexp(dim=0) - np.log(n)

        return log_p_x, mi
示例#12
0
def plot_flow(args, radius, flow, target, namestr, n_blocks=2, samples=None):
    """Train a flow and plot its density on the Poincare disk.

    A 100x100 grid over ``[-5, 5]^2`` on the tangent plane at the origin is
    lifted, clamped, and exp-mapped onto the hyperboloid. The flow is then
    trained:

    - with ``train_flow_density`` on ``samples`` when ``samples`` is given;
    - otherwise with ``train_flow`` against ``target``.

    The trained model's log-probability on the grid, plus the
    ``logmap_logdet`` term of the lifted grid, is exponentiated and plotted.

    All grid tensors live on ``args.dev``. Previously part of this function
    moved tensors to ``args.dev`` and part forced ``.cuda()``, which mixed
    devices whenever ``args.dev`` was not CUDA.

    Args:
        args: namespace providing at least ``dev``.
        radius: hyperboloid radius tensor.
        flow: flow identifier forwarded to the training helper and to
            ``plot_density``.
        target: target passed to ``train_flow`` when ``samples`` is None.
        namestr: output name forwarded to ``plot_density``.
        n_blocks: number of flow blocks for ``train_flow_density``.
        samples: optional hyperboloid samples to fit by density estimation.
    """
    # NOTE(review): this figure is not referenced below. It is kept in case
    # plot_density draws on the current pyplot figure — confirm, then remove.
    fig = plt.figure()
    fig.add_subplot(555)

    # Map x, y coordinates on tangent space at origin to manifold (Lorentz model).
    x = torch.linspace(-5, 5, 100)
    xx, yy = torch.meshgrid((x, x))
    twodim = torch.stack((xx.flatten(), yy.flatten()), dim=1)
    threedim = expand_proj_dims(twodim)
    clamped_threedim = clamp(threedim, min=-max_clamp_norm,
                             max=max_clamp_norm).to(args.dev)
    on_mani = exp_map_mu0(clamped_threedim, radius).to(args.dev)

    if samples is not None:
        flow_model = train_flow_density(args, flow, n_blocks, radius, samples,
                                        clamped_threedim, on_mani)
    else:
        flow_model = train_flow(args, flow, radius, target, clamped_threedim,
                                on_mani)

    # Zero-mean / unit-variance base parameters, on the grid's device.
    flow_model.base_dist_mean = torch.zeros_like(on_mani)
    flow_model.base_dist_var = torch.ones(on_mani.shape[0], 2,
                                          device=on_mani.device)

    # Calculate densities of x, y coords on Lorentz model. Out-of-place add
    # so the tensor returned by log_prob is not mutated.
    log_probs = flow_model.log_prob(on_mani)
    log_probs = log_probs + logmap_logdet(clamped_threedim, radius)
    probs = torch.exp(log_probs)

    # Calculate the poincare coordinates
    xy_poincare = lorentz_to_poincare(on_mani.squeeze(), radius)
    plot_density(xy_poincare, probs, flow_model.radius, namestr, flow=flow)
    # Release the figure created above so repeated calls don't accumulate.
    plt.close(fig)