def gauss(args, mu, std):
    """Plot the density of a hyperboloid-wrapped normal on the Poincare disk.

    Wraps `mu` onto the Lorentz model, evaluates the wrapped-normal density on
    a regular tangent-space grid, and hands the Poincare projection of the grid
    to `plot_density`. Optionally trains/plots a flow on top of the density.
    """
    radius = torch.Tensor([args.radius]).to(args.dev)

    # Lift the (clamped) Euclidean mean onto the hyperboloid.
    mu = clamp(mu, min=-max_clamp_norm, max=max_clamp_norm)
    mu_h = exp_map_mu0(expand_proj_dims(mu), radius)
    p_z = HyperboloidWrappedNormal(radius, mu_h, std)

    # Build a 2-D grid in the tangent space at the origin.
    grid_x, grid_y = np.meshgrid(np.arange(-5, 5, 0.1), np.arange(-5, 5, 0.1))
    flat_x = torch.Tensor(grid_x).view(-1, 1)
    flat_y = torch.Tensor(grid_y).view(-1, 1)
    tangent_pts = expand_proj_dims(torch.cat([flat_x, flat_y], dim=1))
    tangent_pts = clamp(tangent_pts, min=-max_clamp_norm, max=max_clamp_norm).to(args.dev)

    # Map the grid onto the manifold (Lorentz model) and evaluate the density.
    on_mani = exp_map_mu0(tangent_pts, radius)
    probs = torch.exp(p_z.log_prob(on_mani))

    # Project manifold points (and the mean) to Poincare coordinates for plotting.
    xy_poincare = lorentz_to_poincare(on_mani.squeeze(), radius)
    mu_p = lorentz_to_poincare(mu_h, radius)
    plot_density(xy_poincare, probs, args.radius, args.namestr, mu=mu_p)

    if args.flow != 'none':
        plot_flow(args, radius, args.flow, p_z, args.namestr, args.n_blocks)
def forward(self, x_hyper, edge_index=None):
    """Inverse-direction pass of the tangent coupling flow on the hyperboloid.

    Maps `x_hyper` (points on the Lorentz model) to the tangent space at the
    origin, runs the RealNVP-style coupling blocks in reverse, and maps the
    result back onto the manifold.

    :param x_hyper: points on the hyperboloid.
    :param edge_index: graph connectivity, required when the coupling nets are
        graph layers (``layer_type != 'Linear'``). Added for consistency with
        the sibling ``forward``/``inverse`` methods; previously this name was
        referenced without being a parameter, raising ``NameError`` for
        non-Linear layers.
    :return: ``(z, log_det_J)`` — transformed manifold points and the
        accumulated log-determinant of the Jacobian.
    """
    x = inverse_exp_map_mu0(x_hyper, self.radius)
    x_mu0 = x[..., 1:]  # drop the time-like coordinate; tangent vectors at mu0
    z = x_mu0
    # Log-det of the logarithmic map that pulled x onto the tangent space.
    log_det_J = -1 * logmap_logdet(x, self.radius)
    for i in reversed(range(0, self.n_blocks)):
        if i > 0:  # Project between flow layers.
            z_proj_mu0 = inverse_exp_map_mu0(z, self.radius)
            z = z_proj_mu0[..., 1:]
            log_det_J -= logmap_logdet(z_proj_mu0, self.radius)
        z_ = self.mask[i] * z
        if self.layer_type != 'Linear':
            s = self.s[i](z_, edge_index)
            t = self.t[i](z_, edge_index)
        else:
            s = self.s[i](z_)
            t = self.t[i](z_)
        # Inverse affine coupling: undo scale/translation on the masked-out half.
        z = (1 - self.mask[i]) * (z - t) * torch.exp(-s) + z_
        log_det_J -= ((1 - self.mask[i]) * s).sum(dim=1)
    z_mu0 = expand_proj_dims(z)
    # Project back to the manifold.
    z = exp_map_mu0(z_mu0, self.radius)
    log_det_J -= _logdet(z_mu0, self.radius)
    return z, log_det_J
def inverse(self, z_hyper, edge_index=None):
    """Forward-direction pass of the tangent coupling flow on the hyperboloid.

    Maps `z_hyper` to the tangent space at the origin, applies the coupling
    blocks in order, and maps the result back onto the manifold.

    :param z_hyper: points on the hyperboloid.
    :param edge_index: graph connectivity, required when the coupling nets are
        graph layers (``layer_type != 'Linear'``). Added for consistency with
        the sibling methods; previously this name was referenced without being
        a parameter, raising ``NameError`` for non-Linear layers.
    :return: ``(x, log_det_J)`` — transformed manifold points and the
        accumulated log-determinant of the Jacobian.
    """
    z = inverse_exp_map_mu0(z_hyper, self.radius)
    z_mu0 = z[..., 1:]  # drop the time-like coordinate
    x = z_mu0
    # Log-det of the logarithmic map that pulled z onto the tangent space.
    log_det_J = logmap_logdet(z, self.radius)
    for i in range(0, self.n_blocks):
        if i > 0:  # Project between flow layers.
            x_proj_mu0 = inverse_exp_map_mu0(x, self.radius)
            x = x_proj_mu0[..., 1:]
            log_det_J += logmap_logdet(x_proj_mu0, self.radius)
        x_ = x * self.mask[i]
        if self.layer_type != 'Linear':
            s = self.s[i](x_, edge_index)
            t = self.t[i](x_, edge_index)
        else:
            s = self.s[i](x_)
            t = self.t[i](x_)
        # Affine coupling on the masked-out half.
        x = x_ + (1 - self.mask[i]) * (x * torch.exp(s) + t)
        self.preclamp_norm = x.max()  # record norm before clamping (diagnostics)
        x = clamp(x, min=-max_clamp_norm, max=max_clamp_norm)
        log_det_J += ((1 - self.mask[i]) * s).sum(dim=1)  # log det dx/du
    x_mu0 = expand_proj_dims(x)
    # Project back to the manifold.
    x = exp_map_mu0(x_mu0, self.radius)
    log_det_J += _logdet(x_mu0, self.radius)
    return x, log_det_J
def some_density(args):
    """Plot an analytic test density (a sinusoid) mapped onto the Poincare disk.

    Evaluates ``f1`` on a tangent-space grid, corrects with the exp-map
    log-determinant, and plots the resulting density in Poincare coordinates.
    Optionally trains/plots a flow targeting the same density.
    """
    # Use args.dev for device placement, consistent with gauss()/mixture();
    # previously this function hard-coded .cuda().
    radius = torch.Tensor([args.radius]).to(args.dev)
    n_pts = 100
    # Target (unnormalized) log-density on the tangent plane.
    f1 = lambda z: torch.sin(6 * math.pi * z[:, 0] / 4)
    xx, yy, zz = setup_grid(5, n_pts)
    base_prob_dist = -f1(zz)

    # Map x, y coordinates on tangent space at origin to manifold (Lorentz model).
    twodim = zz
    threedim = expand_proj_dims(twodim).to(args.dev)
    clamped_threedim = clamp(threedim, min=-max_clamp_norm, max=max_clamp_norm).to(args.dev)
    on_mani = exp_map_mu0(clamped_threedim, radius)

    # Density change-of-variables: subtract the exp-map log-determinant.
    log_det = _logdet(clamped_threedim, radius)
    log_probs = base_prob_dist - log_det
    probs = torch.exp(log_probs)

    # Calculate the Poincare coordinates and plot.
    xy_poincare = lorentz_to_poincare(on_mani.squeeze(), radius)
    plot_density(xy_poincare, probs, radius, args.namestr)
    if args.flow != 'none':
        plot_flow(args, radius, args.flow, f1, args.namestr)
def forward(self, x_hyper, edge_index=None):
    """Inverse-direction pass of the wrapped coupling flow (with parallel transport).

    Pulls `x_hyper` to the tangent space at the origin, then for each block
    (in reverse order) undoes a coupling step whose translation lives on the
    manifold: the affected coordinates are exp-mapped, logged at the learned
    base point ``t``, parallel-transported back to the origin, and rescaled.
    Finally maps the result back onto the manifold.

    :param x_hyper: points on the hyperboloid (Lorentz model).
    :param edge_index: graph connectivity; used only when
        ``layer_type != 'Linear'`` (graph coupling networks).
    :return: ``(z, log_det_J)`` — transformed points and accumulated
        log-determinant of the Jacobian.
    """
    x = inverse_exp_map_mu0(x_hyper, self.radius)
    x_mu0 = x[..., 1:]  # drop the time-like coordinate
    # NOTE(review): log_det_J is immediately overwritten on the next line;
    # this first assignment only serves to alias z = x_mu0.
    log_det_J, z = x.new_zeros(x_mu0.shape[0]), x_mu0
    log_det_J = -1 * logmap_logdet(x, self.radius)
    for i in reversed(range(0, self.n_blocks)):
        z_ = self.mask[i] * z  # half kept fixed by this coupling block
        if self.layer_type != 'Linear':
            s = self.s[i](z_, edge_index)
            t_out = self.t[i](z_, edge_index)
        else:
            s = self.s[i](z_)
            t_out = self.t[i](z_)
        # Project the raw translation onto the manifold and re-mask it.
        t_proj = proj_vec(t_out, self.radius)
        t1, t_rest = t_proj[:, 0].unsqueeze(1), t_proj[:, 1:]
        t = self.create_masked_t((1 - self.mask[i]), t1, t_rest)
        # Exp-map the masked-out half at the origin.
        z_2 = expand_proj_dims((1 - self.mask[i]) * z)
        z_2 = clamp(z_2, min=-max_clamp_norm, max=max_clamp_norm)
        z_exp_2 = exp_map_mu0(z_2, self.radius)
        log_det_J -= _logdet(z_2, self.radius, subdim=(self.mask[i]).sum())
        z_exp_2 = clamp(z_exp_2, min=-max_clamp_norm, max=max_clamp_norm)
        # Log-map at the translation point t, then transport back to the origin.
        z_inv_pt_arg = inverse_exp_map(x=z_exp_2, at_point=t, radius=self.radius)
        log_det_J -= logmap_logdet(z_inv_pt_arg, self.radius, subdim=(self.mask[i]).sum())
        z_inv_pt_arg = clamp(z_inv_pt_arg, min=-max_clamp_norm, max=max_clamp_norm)
        pt = inverse_parallel_transport_mu0(z_inv_pt_arg, src=t, radius=self.radius)
        pt = pt[..., 1:]
        # Undo the scaling of the coupling step.
        z = (1 - self.mask[i]) * pt * torch.exp(-s) + z_
        log_det_J -= ((1 - self.mask[i]) * s).sum(dim=1)
    z_mu0 = expand_proj_dims(z)
    # Project back to the manifold.
    z = exp_map_mu0(z_mu0, self.radius)
    log_det_J -= _logdet(z_mu0, self.radius)
    return z, log_det_J
def inverse(self, z_hyper, edge_index=None):
    """Forward-direction pass of the wrapped coupling flow (with parallel transport).

    Pulls `z_hyper` to the tangent space at the origin and applies each
    coupling block: the masked-out half is scaled, parallel-transported to the
    learned base point ``t``, exp-mapped there, and logged back at the origin.
    Tracks pre-clamp norms for diagnostics whenever a clamp actually saturates.

    :param z_hyper: points on the hyperboloid (Lorentz model).
    :param edge_index: graph connectivity; used only when
        ``layer_type != 'Linear'``.
    :return: ``(x, log_det_J)`` — transformed points and accumulated
        log-determinant of the Jacobian.
    """
    z = inverse_exp_map_mu0(z_hyper, self.radius)
    z_mu0 = z[..., 1:]  # drop the time-like coordinate
    # NOTE(review): log_det_J is immediately overwritten on the next line;
    # this first assignment only serves to alias x = z_mu0.
    log_det_J, x = z_mu0.new_zeros(z_mu0.shape[0]), z_mu0
    log_det_J = logmap_logdet(z, self.radius)
    preclamp_norm_list = []  # norms recorded only when a clamp saturated
    for i in range(0, self.n_blocks):
        x_ = x * self.mask[i]  # half kept fixed by this coupling block
        if self.layer_type != 'Linear':
            s = self.s[i](x_, edge_index)
            t_out = self.t[i](x_, edge_index)
        else:
            s = self.s[i](x_)
            t_out = self.t[i](x_)
        # Project the raw translation onto the manifold and re-mask it.
        t_proj = proj_vec(t_out, self.radius)
        t1, t_rest = t_proj[:, 0].unsqueeze(1), t_proj[:, 1:]
        t = self.create_masked_t((1 - self.mask[i]), t1, t_rest)
        # (1-b) \odot \tilde{x} \odot exp(s(b \odot \tilde{x}))
        x_pt_arg = expand_proj_dims((1 - self.mask[i]) * x * torch.exp(s))
        # (1-b) \odot \textnormal{PT}_{\textbf{o}\to t(b \odot \tilde{x})
        pt = parallel_transport_mu0(x_pt_arg, dst=t, radius=self.radius)
        preclamp_norm = pt.max()
        pt = clamp(pt, min=-max_clamp_norm, max=max_clamp_norm)
        if pt.max() == max_clamp_norm:  # clamp saturated: remember the pre-clamp norm
            preclamp_norm_list.append(preclamp_norm)
        x_t = exp_map(x=pt, at_point=t, radius=self.radius)
        log_det_J += _logdet(pt, self.radius, subdim=(self.mask[i]).sum())
        preclamp_norm = x_t.max()
        x_t = clamp(x_t, min=-max_clamp_norm, max=max_clamp_norm)
        if x_t.max() == max_clamp_norm:
            preclamp_norm_list.append(preclamp_norm)
        # \log_{\textbf{o}}(\textnormal{exp}_{t()}(\textnormal{PT}_{\textbf{o}\to t()))
        x_0_full = inverse_exp_map_mu0(x_t, self.radius)
        x_0 = x_0_full[..., 1:]
        log_det_J += logmap_logdet(x_0_full, self.radius, subdim=(self.mask[i]).sum())
        x = x_ + (1 - self.mask[i]) * x_0
        log_det_J += ((1 - self.mask[i]) * s).sum(dim=1)  # log det dx/du
        preclamp_norm = x.max()
        x = clamp(x, min=-max_clamp_norm, max=max_clamp_norm)
        if x.max() == max_clamp_norm:
            preclamp_norm_list.append(preclamp_norm)
    x_mu0 = expand_proj_dims(x)
    # Project back to Manifold
    x = exp_map_mu0(x_mu0, self.radius)
    log_det_J += _logdet(x_mu0, self.radius)
    # Average of saturated pre-clamp norms; otherwise keep the previous value.
    self.preclamp_norm = torch.Tensor([
        sum(preclamp_norm_list) / len(preclamp_norm_list)
    ]) if preclamp_norm_list else self.preclamp_norm
    return x, log_det_J
def forward(self, x):
    """Hyperbolic linear layer: weight multiply in the tangent space at the
    origin, optional bias via parallel transport, then hyperbolic activation.

    :param x: input points on the hyperboloid.
    :return: activated output points on the hyperboloid.
    """
    # Pull input to the tangent space at the origin and apply the weight there.
    x_tangent_mu0 = inverse_exp_map_mu0(x, self.radius)
    output = exp_map_mu0(x_tangent_mu0.matmul(self.weight.t()), self.radius)
    if self.use_bias:
        # Transport the bias vector from the origin to `output`.
        output = parallel_transport_mu0(self.bias, output, self.radius)
        # NOTE(review): exp_map is called here without a radius argument and
        # with the layer *input* x as the base point, while elsewhere in this
        # file it is called as exp_map(x=..., at_point=..., radius=...).
        # Possibly intended: exp_map(transported_bias, at_point=output,
        # radius=self.radius) — confirm against exp_map's signature.
        output = exp_map(output, x)
    ret = output
    h = self.hyper_act(ret)
    return h
def hyper_act(self, x):
    """Hyperbolic activation: apply ReLU in the tangent space at the origin.

    Op: \\sigma(x)
    Input: x, a hyperbolic vector in H^n.
    Output: a hyperbolic vector in H^n after \\sigma(x).
    """
    # Pull to the tangent space, activate there, push back to the manifold.
    tangent_vec = inverse_exp_map_mu0(x, self.radius)
    activated = F.relu(tangent_vec)
    return exp_map_mu0(activated, self.radius)
def forward(self, x, edge_index):
    """Graph-conv encoder head producing a hyperbolic posterior.

    Encodes node features, lifts the mean onto the hyperboloid, builds the
    variational posterior, and draws a reparametrized sample.

    :return: ``(z_0, mu_h, std, data)`` — sample, hyperbolic mean, std, and
        auxiliary sampling data.
    """
    hidden = F.relu(self.conv1(x, edge_index))
    mu = clamp(self.conv_mu(hidden, edge_index),
               min=-max_clamp_norm, max=max_clamp_norm)
    logvar = self.conv_logvar(hidden, edge_index)
    assert torch.isfinite(mu).all()
    assert torch.isfinite(logvar).all()

    # Lift the Euclidean mean onto the hyperboloid.
    mu_h = exp_map_mu0(expand_proj_dims(mu), self.radius)
    assert torch.isfinite(mu_h).all()

    # +eps prevents collapse
    std = F.softplus(logvar) + 1e-5
    assert torch.isfinite(std).all()

    self.q_z, self.p_z = self.reparametrize(mu_h, std)
    z_0, data = self.q_z.rsample_with_parts()
    return z_0, mu_h, std, data
def bottleneck(self, h):
    """VAE bottleneck producing a hyperbolic posterior from encoder features.

    :param h: encoder features.
    :return: ``(z, mu_h, std)`` — reparametrized sample, hyperbolic mean, std.
        Also stores ``q_z``, ``p_z`` and the sampling ``data`` on ``self``.
    """
    mu = clamp(self.fc_mean(h), min=-max_clamp_norm, max=max_clamp_norm)
    logvar = self.fc_logvar(h)
    assert torch.isfinite(mu).all()
    assert torch.isfinite(logvar).all()

    # Lift the Euclidean mean onto the hyperboloid.
    mu_h = exp_map_mu0(expand_proj_dims(mu), self.radius)
    assert torch.isfinite(mu_h).all()

    # +eps prevents collapse
    std = F.softplus(logvar) + 1e-5
    assert torch.isfinite(std).all()

    self.q_z, self.p_z = self.reparametrize(mu_h, std)
    z, self.data = self.q_z.rsample_with_parts()
    return z, mu_h, std
def mixture(args):
    """Plot a 2-D mixture dataset mapped onto the Poincare disk as a 2-D histogram.

    Samples from a named 2-D dataset, lifts the samples onto the Lorentz model,
    projects them to Poincare coordinates, and saves a hist2d plot. Optionally
    trains/plots a flow fit to the samples.
    """
    radius = torch.Tensor([args.radius]).to(args.dev)
    samples = sample_2d_data(args.dataset, 100000).to(args.dev)
    samples = clamp(samples, min=-max_clamp_norm, max=max_clamp_norm)
    xi = samples[:, 0].detach().cpu().numpy()
    yi = samples[:, 1].detach().cpu().numpy()

    # Lift samples onto the hyperboloid, then project to Poincare coordinates.
    samples_h = exp_map_mu0(expand_proj_dims(samples), radius)
    xy_poincare = lorentz_to_poincare(samples_h.squeeze(), radius)

    fig = plt.figure()
    ax = fig.add_subplot(111)
    p_z = None  # no analytic target density for empirical samples

    # Define points within circle
    range_lim = 5
    ax.hist2d(xy_poincare[:, 0].detach().cpu().numpy(),
              xy_poincare[:, 1].detach().cpu().numpy(),
              range=[[-range_lim, range_lim], [-range_lim, range_lim]],
              bins=5000, cmap='magma')
    ax.axis('off')
    # Makes the circle look like a circle
    ax.axis('equal')
    ax.set_xlim(-args.axis_lim, args.axis_lim)
    ax.set_ylim(-args.axis_lim, args.axis_lim)

    # Save the full figure...
    fig.savefig('install/{}.png'.format(args.namestr))
    print("saved to install/{}.png".format(args.namestr))
    # Close the figure so repeated calls don't accumulate open figures.
    plt.close(fig)

    if args.flow != 'none':
        plot_flow(args, radius, args.flow, p_z, args.namestr,
                  n_blocks=args.n_blocks, samples=samples_h)
def MC_log_likelihood(self, x):
    """Monte Carlo estimate of the marginal log-likelihood and a MI-style term.

    Draws ``self.K`` posterior samples per input, evaluates the importance
    weights log p(x|z) + log p(z) - log q(z|x), and log-mean-exps over samples.

    :param x: Mini-batch of inputs; the number of MC samples is ``self.K``.
    :return: ``(log_p_x, mi)`` — per-example log-likelihood estimate and the
        log-mean-exp of log q(z|x) - log p(z).
    """
    n = self.K
    sample_shape = torch.Size([n])

    # Encode and build the hyperbolic posterior (same recipe as bottleneck()).
    x_encoded = self.encoder(x)
    mu, logvar = self.fc_mean(x_encoded), self.fc_logvar(x_encoded)
    mu = clamp(mu, min=-max_clamp_norm, max=max_clamp_norm)
    mu_h = exp_map_mu0(expand_proj_dims(mu), self.radius)
    # +eps prevents collapse
    std = F.softplus(logvar) + 1e-5
    q_z, p_z = self.reparametrize(mu_h, std)

    # Numerically more stable joint sampling of z with both log-probs.
    z, log_q_z_x, log_p_z = self.rsample_log_probs(sample_shape, q_z, p_z)

    # Decode in the tangent space at the origin; correct q with the log-map
    # change of variables.
    z = inverse_exp_map_mu0(z, self.radius)
    log_q_z_x = log_q_z_x - logmap_logdet(z, self.radius)
    x_mb_ = self.decode(z)
    x_orig = x.repeat((n, 1, 1))
    log_p_x_z = -self.recon_loss(x_mb_, x_orig).sum(dim=-1)
    assert log_p_x_z.shape == log_p_z.shape
    assert log_q_z_x.shape == log_p_z.shape

    # Importance-weighted estimate: log (1/n) sum_k w_k.
    joint = (log_p_x_z + log_p_z - log_q_z_x)
    log_p_x = joint.logsumexp(dim=0) - np.log(n)
    mi = (log_q_z_x - log_p_z).logsumexp(dim=0) - np.log(n)
    return log_p_x, mi
def plot_flow(args, radius, flow, target, namestr, n_blocks=2, samples=None):
    """Train a flow (on a target density or on samples) and plot its density
    on the Poincare disk.

    :param args: run configuration (expects ``dev``).
    :param radius: curvature radius tensor of the Lorentz model.
    :param flow: flow type identifier.
    :param target: target density/callable (ignored when ``samples`` is given).
    :param namestr: name used for the saved plot.
    :param n_blocks: number of flow blocks (density-fitting path only).
    :param samples: optional manifold samples to fit the flow to.
    """
    # Map x, y coordinates on tangent space at origin to manifold (Lorentz model).
    x = torch.linspace(-5, 5, 100)
    xx, yy = torch.meshgrid((x, x))
    twodim = torch.stack((xx.flatten(), yy.flatten()), dim=1)
    threedim = expand_proj_dims(twodim)
    clamped_threedim = clamp(threedim, min=-max_clamp_norm,
                             max=max_clamp_norm).to(args.dev)
    # Use args.dev throughout instead of hard-coded .cuda() for consistency
    # with the rest of the file.
    on_mani = exp_map_mu0(clamped_threedim, radius).to(args.dev)

    if samples is not None:
        flow_model = train_flow_density(args, flow, n_blocks, radius, samples,
                                        clamped_threedim, on_mani)
    else:
        flow_model = train_flow(args, flow, radius, target, clamped_threedim,
                                on_mani)

    # Calculate densities of x, y coords on Lorentz model.
    flow_model.base_dist_mean = torch.zeros_like(on_mani).to(args.dev)
    flow_model.base_dist_var = torch.ones(on_mani.shape[0], 2).to(args.dev)
    probs = flow_model.log_prob(on_mani)
    # Change of variables for the log map back to the tangent plane.
    probs += logmap_logdet(clamped_threedim.to(args.dev), radius)
    probs = torch.exp(probs)

    # Calculate the poincare coordinates and plot.
    xy_poincare = lorentz_to_poincare(on_mani.squeeze(), radius)
    plot_density(xy_poincare, probs, flow_model.radius, namestr, flow=flow)