Exemplo n.º 1
0
 def inverse(self, z_hyper, edge_index=None):
     """Apply every coupling block in order and accumulate the log-det.

     The input is log-mapped with ``inverse_exp_map_mu0`` and its leading
     coordinate is dropped. Each block applies a masked affine coupling
     (``x * exp(s) + t`` on the unmasked coordinates), clamps the result,
     and sends it through ``exp_map_mu0``. From the second block onwards
     the running value is log-mapped again before the coupling.

     Args:
         z_hyper: Input tensor; presumably points on the hyperboloid of
             radius ``self.radius`` (TODO confirm against callers).
         edge_index: Forwarded to the ``s``/``t`` networks whenever
             ``self.layer_type`` is not ``'Linear'``. Previously this name
             was read without being defined, which raised ``NameError``
             for non-linear layer types.

     Returns:
         Tuple ``(x, log_det_J)``: the output of the last ``exp_map_mu0``
         call and the accumulated log-determinant terms.

     Side effects:
         ``self.preclamp_norm`` is overwritten in every block with the
         maximum entry of the coupling output before clamping.
     """
     z = inverse_exp_map_mu0(z_hyper, self.radius)
     # Work on the log-mapped vector without its first coordinate.
     x = z[..., 1:]
     log_det_J = logmap_logdet(z, self.radius)
     for i in range(self.n_blocks):
         if i > 0:
             # Project between flow layers: the previous block ended with
             # exp_map_mu0, so log-map before the next coupling.
             x_proj_mu0 = inverse_exp_map_mu0(x, self.radius)
             x = x_proj_mu0[..., 1:]
             log_det_J += logmap_logdet(x_proj_mu0, self.radius)
         mask = self.mask[i]
         x_ = x * mask
         if self.layer_type != 'Linear':
             s = self.s[i](x_, edge_index)
             t = self.t[i](x_, edge_index)
         else:
             s = self.s[i](x_)
             t = self.t[i](x_)
         x = x_ + (1 - mask) * (x * torch.exp(s) + t)
         self.preclamp_norm = x.max()
         x = clamp(x, min=-max_clamp_norm, max=max_clamp_norm)
         log_det_J += ((1 - mask) * s).sum(dim=1)  # log det dx/du
         x_mu0 = expand_proj_dims(x)
         # Project back to the manifold
         x = exp_map_mu0(x_mu0, self.radius)
         log_det_J += _logdet(x_mu0, self.radius)
     return x, log_det_J
Exemplo n.º 2
0
 def forward(self, x_hyper, edge_index=None):
     """Apply the coupling blocks in reverse order with inverted couplings.

     The input is log-mapped with ``inverse_exp_map_mu0`` and its leading
     coordinate is dropped. Blocks are visited from ``n_blocks - 1`` down
     to ``0``; each one inverts the masked affine coupling
     (``(z - t) * exp(-s)`` on the unmasked coordinates) and sends the
     result through ``exp_map_mu0``.

     Args:
         x_hyper: Input tensor; presumably points on the hyperboloid of
             radius ``self.radius`` (TODO confirm against callers).
         edge_index: Forwarded to the ``s``/``t`` networks whenever
             ``self.layer_type`` is not ``'Linear'``. Previously this name
             was read without being defined, which raised ``NameError``
             for non-linear layer types.

     Returns:
         Tuple ``(z, log_det_J)``: the output of the last ``exp_map_mu0``
         call and the accumulated (negated) log-determinant terms.
     """
     x = inverse_exp_map_mu0(x_hyper, self.radius)
     z = x[..., 1:]
     log_det_J = -1 * logmap_logdet(x, self.radius)
     first_visited = self.n_blocks - 1
     for i in reversed(range(self.n_blocks)):
         # Re-project on every iteration except the first one executed.
         # On entry z is already a log-mapped vector; after each block it
         # is the output of exp_map_mu0 and must be log-mapped again. The
         # previous guard (`i > 0`) skipped the wrong iteration because
         # the loop runs from the highest index down.
         if i < first_visited:
             # Project between flow layers
             z_proj_mu0 = inverse_exp_map_mu0(z, self.radius)
             z = z_proj_mu0[..., 1:]
             log_det_J -= logmap_logdet(z_proj_mu0, self.radius)
         mask = self.mask[i]
         z_ = mask * z
         if self.layer_type != 'Linear':
             s = self.s[i](z_, edge_index)
             t = self.t[i](z_, edge_index)
         else:
             s = self.s[i](z_)
             t = self.t[i](z_)
         z = (1 - mask) * (z - t) * torch.exp(-s) + z_
         log_det_J -= ((1 - mask) * s).sum(dim=1)
         z_mu0 = expand_proj_dims(z)
         # Project back to the manifold
         z = exp_map_mu0(z_mu0, self.radius)
         log_det_J -= _logdet(z_mu0, self.radius)
     return z, log_det_J
Exemplo n.º 3
0
    def inverse(self, z_hyper, edge_index=None):
        """Push ``z_hyper`` through all coupling blocks and return the log-det.

        The input is log-mapped with ``inverse_exp_map_mu0`` and its leading
        coordinate dropped. Each block scales the unmasked coordinates by
        ``exp(s)``, parallel-transports them to the point ``t`` built from
        the ``t`` network output, applies ``exp_map`` at ``t`` and log-maps
        the result again. After the last block the running value is sent
        through ``exp_map_mu0``.

        Args:
            z_hyper: Input tensor; presumably points on the hyperboloid of
                radius ``self.radius`` (TODO confirm against callers).
            edge_index: Forwarded to the ``s``/``t`` networks whenever
                ``self.layer_type`` is not ``'Linear'``.

        Returns:
            Tuple ``(x, log_det_J)``: the transformed tensor and the
            accumulated log-determinant terms.

        Side effects:
            ``self.preclamp_norm`` becomes the mean of the pre-clamp maxima
            recorded during this call if any clamp reached
            ``max_clamp_norm``; otherwise it is re-assigned its old value.
        """
        z_tangent = inverse_exp_map_mu0(z_hyper, self.radius)
        x = z_tangent[..., 1:]
        log_det_J = logmap_logdet(z_tangent, self.radius)
        saturated_peaks = []

        def _clamp_tracked(tensor):
            # Clamp to [-max_clamp_norm, max_clamp_norm]; when the upper
            # bound is reached, remember the maximum seen before clamping.
            peak = tensor.max()
            bounded = clamp(tensor, min=-max_clamp_norm, max=max_clamp_norm)
            if bounded.max() == max_clamp_norm:
                saturated_peaks.append(peak)
            return bounded

        for i in range(self.n_blocks):
            keep = self.mask[i]
            free = 1 - keep
            x_ = x * keep
            if self.layer_type == 'Linear':
                s = self.s[i](x_)
                t_out = self.t[i](x_)
            else:
                s = self.s[i](x_, edge_index)
                t_out = self.t[i](x_, edge_index)
            t_proj = proj_vec(t_out, self.radius)
            t_head = t_proj[:, 0].unsqueeze(1)
            t_tail = t_proj[:, 1:]
            t = self.create_masked_t(1 - keep, t_head, t_tail)

            # (1-b) \odot \tilde{x} \odot exp(s(b \odot \tilde{x}))
            x_pt_arg = expand_proj_dims(free * x * torch.exp(s))

            # (1-b) \odot PT_{o -> t(b \odot \tilde{x})}
            pt = _clamp_tracked(
                parallel_transport_mu0(x_pt_arg, dst=t, radius=self.radius))
            x_t = exp_map(x=pt, at_point=t, radius=self.radius)
            log_det_J += _logdet(pt, self.radius, subdim=keep.sum())
            x_t = _clamp_tracked(x_t)

            # log_o(exp_{t()}(PT_{o -> t()}))
            x_0_full = inverse_exp_map_mu0(x_t, self.radius)
            log_det_J += logmap_logdet(x_0_full,
                                       self.radius,
                                       subdim=keep.sum())
            x = x_ + free * x_0_full[..., 1:]
            log_det_J += (free * s).sum(dim=1)  # log det dx/du

            x = _clamp_tracked(x)

        x_mu0 = expand_proj_dims(x)
        # Project back to the manifold
        x = exp_map_mu0(x_mu0, self.radius)
        log_det_J += _logdet(x_mu0, self.radius)

        if saturated_peaks:
            new_preclamp_norm = torch.Tensor(
                [sum(saturated_peaks) / len(saturated_peaks)])
        else:
            new_preclamp_norm = self.preclamp_norm
        self.preclamp_norm = new_preclamp_norm
        return x, log_det_J
Exemplo n.º 4
0
def train_potential_flow(flow_model, n_blocks, radius, target,
                         n_epochs=1000, lr=1e-2, batch_size=256):
    """Fit a flow to a target potential by minimising a sampled objective.

    Each epoch draws ``batch_size`` samples from a ``HyperboloidWrappedNormal``
    prior, pushes them through ``flow.inverse``, log-maps the result, and
    minimises ``mean(q_log_prob - p_log_prob - logdet)`` with Adam, where
    ``p_log_prob = -target(z_mu0)``.

    Args:
        flow_model: Key into ``kwargs_flows`` selecting the flow constructor.
        n_blocks: First positional argument of the flow constructor.
        radius: Radius passed to the prior and the log-map helpers; wrapped
            in ``torch.tensor`` for the flow itself.
        target: Callable evaluated on the log-mapped samples (leading
            coordinate removed); its negation is used as the target
            log-probability.
        n_epochs: Number of optimisation steps (default 1000, the value that
            used to be hard-coded).
        lr: Adam learning rate (default 1e-2, previously hard-coded).
        batch_size: Samples drawn from the prior per step (default 256,
            previously hard-coded).

    Returns:
        The trained flow model. It is moved to CUDA, so a CUDA device is
        required.
    """
    # NOTE(review): the positional arguments 2, 128, 1 are forwarded as-is;
    # their meaning depends on the constructors stored in kwargs_flows.
    model = kwargs_flows[flow_model](n_blocks,
                                     2,
                                     128,
                                     1,
                                     layer_type='Linear',
                                     radius=torch.tensor(radius)).cuda()
    optimizer = optim.Adam(model.parameters(), lr=lr)

    num_samples = torch.Size([batch_size])
    mu_0_shape = torch.Size([1, 3])
    std_0_shape = torch.Size([1, 2])
    prior = HyperboloidWrappedNormal(radius,
                                     torch.zeros(mu_0_shape).cuda(),
                                     torch.ones(std_0_shape).cuda())
    for epoch in range(n_epochs):
        optimizer.zero_grad()
        z_0 = prior.rsample(num_samples).squeeze()
        z_0 = clamp(z_0, min=-max_clamp_norm, max=max_clamp_norm)
        q_log_prob = prior.log_prob(z_0)
        z_hyper, logdet = model.inverse(z_0)
        z_hyper = clamp(z_hyper, min=-max_clamp_norm, max=max_clamp_norm)
        # Evaluate the target on the log-mapped sample, so the log-map's
        # log-det term is added to the flow's own log-det.
        z_k = inverse_exp_map_mu0(z_hyper, radius)
        z_mu0 = z_k[..., 1:]
        logdet += logmap_logdet(z_k, radius)
        p_log_prob = -1 * target(z_mu0)
        loss = (q_log_prob - p_log_prob - logdet).mean()
        loss.backward()
        optimizer.step()
        print("Epoch:{} Loss:{}".format(epoch, loss.item()))

    return model
Exemplo n.º 5
0
    def encode(self, *args, **kwargs):
        """Encode the inputs, optionally apply the flow, and cache the results.

        All arguments are forwarded to ``self.encoder``; the second positional
        argument is additionally used as ``edge_index`` for the flow. The
        encoder output ``z`` is clamped and multiplied by a mask that is zero
        for rows whose ``std`` sums to exactly 1 over the last dimension.

        Side effects: sets ``self.__mu_h__``, ``self.__std__``,
        ``self.posterior_parts``, ``self.mask`` and ``self.z_k``; with a flow
        model also ``self.sum_log_det_jac`` and the flow's ``base_dist_mean``
        / ``base_dist_var``; for decoders outside the listed names also
        ``self.decoder_logdet``.

        Returns:
            ``(z, z_k)`` when ``self.decoder_name`` is one of ``'fermi'``,
            ``'tanh'``, ``'distance'``, ``'softmax'``; otherwise
            ``(z, z_mu0)`` where ``z_mu0`` is ``z_k`` after
            ``inverse_exp_map_mu0``.
        """
        # The sample is already reparameterized
        edge_index = args[1]
        encoded = self.encoder(*args, **kwargs)
        z, self.__mu_h__, self.__std__, self.posterior_parts = encoded
        # TODO: Double check this masking
        std_row_sum = self.__std__.sum(dim=-1)
        self.mask = (std_row_sum != 1).int().unsqueeze(1)
        z = self.mask * clamp(z, min=-max_clamp_norm, max=max_clamp_norm)

        flow = self.encoder.flow_model
        if flow:
            flow.base_dist_mean = self.__mu_h__
            flow.base_dist_var = self.__std__
            z_k, self.sum_log_det_jac = flow.inverse(z, edge_index)
            z_k = clamp(z_k, min=-max_clamp_norm, max=max_clamp_norm)
        else:
            z_k = z
        # z_k is what the KL term uses: the prior is a wrapped Gaussian,
        # so this value should not live on the tangent space.
        self.z_k = z_k

        if self.decoder_name in ('fermi', 'tanh', 'distance', 'softmax'):
            return z, z_k

        # Log-map z back to T_o H for the remaining decoders
        radius = self.encoder.radius
        z_mu0 = inverse_exp_map_mu0(z_k, radius)
        self.decoder_logdet = logmap_logdet(z_mu0, radius)
        return z, z_mu0
Exemplo n.º 6
0
 def forward(self, x):
     """Apply the layer: log-map, multiply by the weight, exp-map, activate.

     ``x`` is log-mapped with ``inverse_exp_map_mu0``, multiplied by
     ``self.weight`` transposed, and sent back through ``exp_map_mu0``. When
     ``self.use_bias`` is set, ``self.bias`` is parallel-transported and
     applied via ``exp_map``. The result is passed through
     ``self.hyper_act``.
     """
     tangent_in = inverse_exp_map_mu0(x, self.radius)
     projected = tangent_in.matmul(self.weight.t())
     output = exp_map_mu0(projected, self.radius)
     if self.use_bias:
         transported_bias = parallel_transport_mu0(self.bias, output,
                                                   self.radius)
         # NOTE(review): exp_map is given the layer *input* x as its second
         # argument and no radius, while the bias was transported towards
         # the linear output above - confirm this is intended.
         output = exp_map(transported_bias, x)
     return self.hyper_act(output)
Exemplo n.º 7
0
 def forward(self, z, edge_index):
     """Return one probability per edge listed in ``edge_index``.

     For each edge the negated ``hyperboloid_dist`` between its endpoints is
     shifted by ``r`` and divided by ``t``, then squashed with a sigmoid.
     ``r`` and ``t`` come from ``self.r_mlp`` / ``self.t_mlp`` applied to the
     concatenated endpoint features produced by ``self.r`` / ``self.t`` on
     the log-mapped ``z``.
     """
     src, dst = edge_index[0], edge_index[1]

     def _endpoint_pairs(node_feats):
         # Concatenate source and destination features of every edge.
         return torch.cat((node_feats[src], node_feats[dst]), dim=1)

     z_mu0 = inverse_exp_map_mu0(z, self.radius)
     neg_dist = -1 * hyperboloid_dist(z[src], z[dst], self.radius).squeeze()
     r_gnn = self.r(z_mu0, edge_index)
     t_gnn = self.t(z_mu0, edge_index)
     r = self.r_mlp(_endpoint_pairs(r_gnn)).squeeze()
     t = self.t_mlp(_endpoint_pairs(t_gnn)).squeeze()
     return torch.sigmoid((neg_dist - r) / t).squeeze()
Exemplo n.º 8
0
    def hyper_act(self, x):
        r'''
        Op: \sigma(x)
        Input:
        x: A Hyperbolic Vector in H^n

        Output:
        sigma_x_mu0: A Hyperbolic Vector in H^n after \sigma(x)
        '''
        # Log-map, apply ReLU to the log-mapped vector, then exp-map back.
        tangent = inverse_exp_map_mu0(x, self.radius)
        return exp_map_mu0(F.relu(tangent), self.radius)
Exemplo n.º 9
0
 def forward_all(self, z, edge_index):
     """Return a dense matrix holding the edge probabilities.

     Starts from an identity matrix of size ``z.shape[0]`` and writes, for
     every edge ``(edge_index[0][k], edge_index[1][k])``, the probability
     ``sigmoid((-dist - r) / t)`` into the matching cell. ``r`` and ``t``
     come from ``self.r_mlp`` / ``self.t_mlp`` applied to the concatenated
     endpoint features produced by ``self.r`` / ``self.t`` on the
     log-mapped ``z``.

     Fixes relative to the previous version:
       * ``t`` is now computed from ``inp_t``; it was fed ``inp_r``, so the
         output of ``self.t`` was never used.
       * the distance is flattened to one value per edge so that it lines
         up with the squeezed ``r`` / ``t`` (the per-edge ``forward``
         squeezes it as well).

     Returns:
         Tensor created by ``torch.eye(z.shape[0])`` with edge cells
         overwritten; entries for pairs not in ``edge_index`` keep the
         identity values.
     """
     src, dst = edge_index[0], edge_index[1]
     adj = torch.eye(z.shape[0])
     z_mu0 = inverse_exp_map_mu0(z, self.radius)
     dist = -1 * hyperboloid_dist(z[src], z[dst], self.radius).reshape(-1)
     r_gnn = self.r(z_mu0, edge_index)
     t_gnn = self.t(z_mu0, edge_index)
     inp_r = torch.cat((r_gnn[src], r_gnn[dst]), dim=1)
     inp_t = torch.cat((t_gnn[src], t_gnn[dst]), dim=1)
     r = self.r_mlp(inp_r).squeeze()
     t = self.t_mlp(inp_t).squeeze()
     probs = torch.sigmoid((dist - r) / t)
     # Element-wise writes: later duplicates of an edge overwrite earlier
     # ones, exactly as in the original index loop.
     for row, col, prob in zip(src, dst, probs):
         adj[row][col] = prob
     return adj
Exemplo n.º 10
0
    def forward(self, x_hyper, edge_index=None):
        """Visit the coupling blocks in reverse order and return the log-det.

        The input is log-mapped with ``inverse_exp_map_mu0`` and its leading
        coordinate dropped. Each block exp-maps the unmasked coordinates,
        log-maps them at the point ``t`` built from the ``t`` network output,
        applies ``inverse_parallel_transport_mu0`` and rescales by
        ``exp(-s)``. After the last block the running value is sent through
        ``exp_map_mu0``.

        Args:
            x_hyper: Input tensor; presumably points on the hyperboloid of
                radius ``self.radius`` (TODO confirm against callers).
            edge_index: Forwarded to the ``s``/``t`` networks whenever
                ``self.layer_type`` is not ``'Linear'``.

        Returns:
            Tuple ``(z, log_det_J)``: the transformed tensor and the
            accumulated (negated) log-determinant terms.
        """

        def _bounded(tensor):
            # Clamp every entry to [-max_clamp_norm, max_clamp_norm].
            return clamp(tensor, min=-max_clamp_norm, max=max_clamp_norm)

        x_tangent = inverse_exp_map_mu0(x_hyper, self.radius)
        z = x_tangent[..., 1:]
        log_det_J = -1 * logmap_logdet(x_tangent, self.radius)
        for i in reversed(range(self.n_blocks)):
            keep = self.mask[i]
            free = 1 - keep
            z_ = keep * z
            if self.layer_type == 'Linear':
                s = self.s[i](z_)
                t_out = self.t[i](z_)
            else:
                s = self.s[i](z_, edge_index)
                t_out = self.t[i](z_, edge_index)
            t_proj = proj_vec(t_out, self.radius)
            t_head = t_proj[:, 0].unsqueeze(1)
            t_tail = t_proj[:, 1:]
            t = self.create_masked_t(1 - keep, t_head, t_tail)

            # exp-map the unmasked coordinates
            z_2 = _bounded(expand_proj_dims(free * z))
            z_exp_2 = exp_map_mu0(z_2, self.radius)
            log_det_J -= _logdet(z_2, self.radius, subdim=keep.sum())

            # log-map at t
            z_exp_2 = _bounded(z_exp_2)
            z_inv_pt_arg = inverse_exp_map(x=z_exp_2,
                                           at_point=t,
                                           radius=self.radius)
            log_det_J -= logmap_logdet(z_inv_pt_arg,
                                       self.radius,
                                       subdim=keep.sum())

            # transport back from t and drop the leading coordinate
            z_inv_pt_arg = _bounded(z_inv_pt_arg)
            pt = inverse_parallel_transport_mu0(z_inv_pt_arg,
                                                src=t,
                                                radius=self.radius)[..., 1:]

            z = free * pt * torch.exp(-s) + z_
            log_det_J -= (free * s).sum(dim=1)

        z_mu0 = expand_proj_dims(z)
        z = exp_map_mu0(z_mu0, self.radius)
        log_det_J -= _logdet(z_mu0, self.radius)
        return z, log_det_J
Exemplo n.º 11
0
    def forward(self, x):
        """Encode ``x``, optionally apply the flow, and decode.

        Returns:
            Tuple ``(x_tilde, kld)``: the decoder output for the log-mapped
            latent, and ``self.kl_loss(...)`` minus the flow log-det and the
            log-map log-det.

        Side effects: with a flow model, sets its ``base_dist_mean``,
        ``base_dist_var`` and ``radius`` attributes.
        """
        z, mu_h, std = self.encode(x)
        z = clamp(z, min=-max_clamp_norm, max=max_clamp_norm)

        ### Flow ###
        flow = self.flow_model
        if flow:
            flow.base_dist_mean = mu_h
            flow.base_dist_var = std
            flow.radius = self.radius
            z_k, sum_log_det_jac = flow.inverse(z)
            z_k = clamp(z_k, min=-max_clamp_norm, max=max_clamp_norm)
        else:
            z_k, sum_log_det_jac = z, 0

        base_kld = self.kl_loss(self.q_z, self.p_z, z, z_k, self.data)
        z_mu0 = inverse_exp_map_mu0(z_k, self.radius)
        # This is not really the same KL Divergence and can be negative
        kld = base_kld - sum_log_det_jac - logmap_logdet(z_mu0, self.radius)
        return self.decode(z_mu0), kld
Exemplo n.º 12
0
    def MC_log_likelihood(self, x):
        """
        Monte Carlo estimate of the log-likelihood using ``self.K`` samples.

        :param x: Mini-batch of inputs.
        :return: Tuple ``(log_p_x, mi)``: ``logsumexp`` over the sample
            dimension of ``log p(x|z) + log p(z) - log q(z|x)`` minus
            ``log(K)``, and the same reduction applied to
            ``log q(z|x) - log p(z)``.
        """
        n = self.K
        sample_shape = torch.Size([n])

        x_encoded = self.encoder(x)
        mu, logvar = self.fc_mean(x_encoded), self.fc_logvar(x_encoded)
        mu = clamp(mu, min=-max_clamp_norm, max=max_clamp_norm)
        mu_h = exp_map_mu0(expand_proj_dims(mu), self.radius)

        # +eps prevents collapse
        std = F.softplus(logvar) + 1e-5
        q_z, p_z = self.reparametrize(mu_h, std)

        # Numerically more stable.
        z, log_q_z_x, log_p_z = self.rsample_log_probs(sample_shape, q_z, p_z)
        # NOTE(review): if rsample_log_probs already returns a log-mapped
        # sample with the log-map log-det applied, the two lines below apply
        # that correction a second time - confirm against its implementation.
        z = inverse_exp_map_mu0(z, self.radius)
        log_q_z_x = log_q_z_x - logmap_logdet(z, self.radius)

        x_mb_ = self.decode(z)
        x_orig = x.repeat((n, 1, 1))
        log_p_x_z = -self.recon_loss(x_mb_, x_orig).sum(dim=-1)

        # All three terms must share one shape before they are combined.
        assert log_p_x_z.shape == log_p_z.shape
        assert log_q_z_x.shape == log_p_z.shape
        log_n = np.log(n)
        joint = log_p_x_z + log_p_z - log_q_z_x
        log_p_x = joint.logsumexp(dim=0) - log_n
        mi = (log_q_z_x - log_p_z).logsumexp(dim=0) - log_n

        return log_p_x, mi
Exemplo n.º 13
0
    def rsample_log_probs(
            self, sample_shape: torch.Size, q_z: HyperboloidWrappedNormal,
            p_z: HyperboloidWrappedNormal) -> Tuple[Tensor, Tensor, Tensor]:
        """Sample from ``q_z`` and return the log-mapped sample with log-probs.

        A reparameterised sample is drawn from ``q_z`` and clamped. When
        ``self.flow_model`` is set, the sample is flattened to
        ``(-1, z_dim + 1)``, pushed through ``flow_model.inverse`` and
        reshaped back to ``(sample_shape[0], -1, z_dim + 1)``. The result is
        log-mapped with ``inverse_exp_map_mu0`` and both log-probabilities
        are corrected by the log-map log-det.

        Args:
            sample_shape: Shape handed to ``q_z.rsample_with_parts``; its
                first entry is used as the leading dimension when reshaping
                the flow output.
            q_z: Distribution to sample from.
            p_z: Second distribution passed to ``self._log_prob``.

        Returns:
            Tuple ``(z_mu0, log_q_z_k_x, log_p_z_k)``.
        """
        z, posterior_parts = q_z.rsample_with_parts(sample_shape)
        z = clamp(z, min=-max_clamp_norm, max=max_clamp_norm)
        if self.flow_model:
            n_draws = sample_shape[0]
            latent_width = self.z_dim + 1
            # The flow works on a flat batch; restore the sample axis after.
            flat_z = z.view(-1, latent_width)
            flat_z_k, flat_log_det = self.flow_model.inverse(flat_z)
            flat_z_k = clamp(flat_z_k,
                             min=-max_clamp_norm,
                             max=max_clamp_norm)
            z_k = flat_z_k.view(n_draws, -1, latent_width)
            sum_log_det_jac = flat_log_det.view(n_draws, -1)
        else:
            z_k, sum_log_det_jac = z, 0

        z_mu0 = inverse_exp_map_mu0(z_k, self.radius)
        log_q_z_x, log_p_z_k = self._log_prob(q_z, p_z, z, z_k,
                                              posterior_parts)
        # The log-map correction is identical for both densities, so it is
        # computed once and reused.
        tangent_log_det = logmap_logdet(z_mu0, self.radius)
        log_q_z_k_x = log_q_z_x - sum_log_det_jac - tangent_log_det
        log_p_z_k = log_p_z_k - tangent_log_det
        return z_mu0, log_q_z_k_x, log_p_z_k