Example #1
def pseudo_hyperbolic_gaussian(z, mu_h, cov, version, vt=None, u=None):
    # log-density of a pseudo-hyperbolic (wrapped) Gaussian on the Lorentz model,
    # evaluated at z, with mean mu_h on the hyperboloid and tangent-space covariance cov
    batch_size, n_h = mu_h.shape
    n = n_h - 1
    mu0 = to_cuda_var(torch.zeros(batch_size, n))
    v0 = torch.cat((to_cuda_var(torch.ones(batch_size, 1)), mu0),
                   1)  # origin of the hyperbolic space

    # skip the inverse exponential map when vt and u are already known
    if vt is None and u is None:
        u = inv_exp_map(z, mu_h)
        v = parallel_transport(u, mu_h, v0)
        vt = v[:, 1:]
    logp_vt = MultivariateNormal(mu0, cov).log_prob(vt).view(-1, 1)

    r = lorentz_tangent_norm(u)

    # log-determinant of the change of variables from the tangent space at the origin to z
    if version == 1:
        alpha = -lorentz_product(v0, mu_h)
        log_det_proj_mu = (n * (torch.log(torch.sinh(r)) - torch.log(r))
                           + torch.log(torch.cosh(r)) + torch.log(alpha))
    elif version == 2:
        log_det_proj_mu = (n - 1) * (torch.log(torch.sinh(r)) - torch.log(r))
    else:
        raise ValueError('version must be 1 or 2')

    logp_z = logp_vt - log_det_proj_mu

    return logp_vt, logp_z
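The function above relies on several Lorentz-model helpers (to_cuda_var, lorentz_product, lorentz_tangent_norm, inv_exp_map, parallel_transport) that are defined elsewhere in the repository. The following is a minimal sketch of what they could look like, following the standard hyperboloid-model formulas; only the names and call signatures are taken from the example, the bodies are reconstructions and may differ from the original implementation.

import torch


def to_cuda_var(x):
    # assumption: simply moves the tensor to the GPU when one is available
    return x.cuda() if torch.cuda.is_available() else x


def lorentz_product(x, y):
    # Minkowski (Lorentz) inner product <x, y>_L = -x0*y0 + sum_i xi*yi, shape (batch, 1)
    p = x * y
    return -p[:, :1] + p[:, 1:].sum(dim=1, keepdim=True)


def lorentz_tangent_norm(u):
    # norm of a tangent vector, sqrt(<u, u>_L), clamped for numerical stability
    return torch.sqrt(torch.clamp(lorentz_product(u, u), min=1e-10))


def inv_exp_map(z, mu):
    # logarithm map on the hyperboloid: the tangent vector at mu whose exponential map reaches z
    alpha = -lorentz_product(mu, z)  # alpha = -<mu, z>_L >= 1 on the hyperboloid
    denom = torch.sqrt(torch.clamp(alpha ** 2 - 1, min=1e-10))
    return torch.acosh(alpha) / denom * (z - alpha * mu)


def parallel_transport(v, nu, mu):
    # transport the tangent vector v from the tangent space at nu to the tangent space at mu
    alpha = -lorentz_product(nu, mu)
    return v + lorentz_product(mu - alpha * nu, v) / (alpha + 1) * (nu + mu)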
Example #2
    def kl_loss(self, mean, logv, vt, u, z):
        batch_size, n_h = mean.shape
        n = n_h - 1
        mu0 = to_cuda_var(torch.zeros(batch_size, n))
        mu0_h = lorentz_mapping_origin(mu0)
        diag = to_cuda_var(torch.eye(n).repeat(batch_size, 1, 1))
        cov = torch.exp(logv).unsqueeze(dim=2) * diag

        # posterior density
        _, logp_posterior_z = pseudo_hyperbolic_gaussian(z,
                                                         mean,
                                                         cov,
                                                         version=2,
                                                         vt=vt,
                                                         u=u)

        if self.prior == 'Standard':
            _, logp_prior_z = pseudo_hyperbolic_gaussian(z,
                                                         mu0_h,
                                                         diag,
                                                         version=2,
                                                         vt=None,
                                                         u=None)
            kl_loss = torch.sum(logp_posterior_z.squeeze() -
                                logp_prior_z.squeeze())
        else:
            raise NotImplementedError('only the Standard prior is supported')
        return kl_loss
Example #3
def lorentz_mapping(x):
    # map a Euclidean vector x in R^n onto the hyperboloid (Lorentz model)
    batch_size, n = x.shape
    # interpret x_t as an element of tangent space of the origin of hyperbolic space
    x_t = torch.cat((to_cuda_var(torch.zeros(batch_size, 1)), x), 1)
    # origin of the hyperbolic space
    v0 = torch.cat((to_cuda_var(torch.ones(
        batch_size, 1)), to_cuda_var(torch.zeros(batch_size, n))), 1)
    # exponential mapping
    z = exp_map(x_t, v0)
    return z
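lorentz_mapping above (and lorentz_sampling in Example #7 below) also assumes an exp_map helper. A minimal sketch under the same hyperboloid conventions, reusing lorentz_tangent_norm from the sketch after Example #1; again a reconstruction, not the original implementation.

import torch


def exp_map(u, mu):
    # exponential map on the hyperboloid: move away from mu along the tangent vector u
    r = lorentz_tangent_norm(u)  # ||u||_L, shape (batch, 1), clamped away from zero
    return torch.cosh(r) * mu + torch.sinh(r) * u / r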
Example #4
    def mmd_loss(self, zq):
        # prior samples: wrapped standard normal at the origin of the hyperboloid
        batch_size, n_h = zq.shape
        n = n_h - 1
        mu0 = to_cuda_var(torch.zeros(batch_size, n))
        mu0_h = lorentz_mapping_origin(mu0)
        logv = to_cuda_var(torch.zeros(batch_size, n))
        _, _, z = lorentz_sampling(mu0_h, logv)  # only the samples z are needed here
        # compute mmd between prior samples z and posterior samples zq
        mmd = self.compute_mmd(z, zq)
        return mmd
Example #5
    def kl_loss(self, mean, logv, z):
        batch_size, n = mean.shape
        diag = to_cuda_var(torch.eye(n).repeat(batch_size, 1, 1))
        cov = torch.exp(logv).unsqueeze(dim=-1) * diag

        # compute log probabilities of posterior
        z_posterior_pdf = MultivariateNormal(mean, cov)
        logp_posterior_z = z_posterior_pdf.log_prob(z)

        if self.prior == 'Standard':
            z_prior_pdf = MultivariateNormal(to_cuda_var(torch.zeros(n)), diag)
            logp_prior_z = z_prior_pdf.log_prob(z)
            kl_loss = torch.sum(logp_posterior_z.squeeze() -
                                logp_prior_z.squeeze())
        else:
            raise NotImplementedError('only the Standard prior is supported')
        return kl_loss
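With a standard-normal prior and a diagonal posterior covariance, the single-sample Monte Carlo estimate in Example #5 could also be replaced by the closed-form Gaussian KL. A sketch of that alternative (not the repository's code):

import torch


def kl_loss_closed_form(mean, logv):
    # analytic KL( N(mean, diag(exp(logv))) || N(0, I) ), summed over the batch
    return -0.5 * torch.sum(1 + logv - mean.pow(2) - logv.exp())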
Example #6
    def mmd_loss(self, zq):
        # true standard normal distribution samples
        true_samples = to_cuda_var(torch.randn([zq.shape[0],
                                                self.latent_size]))
        # compute mmd
        mmd = self.compute_mmd(true_samples, zq)
        return mmd
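Both MMD losses call self.compute_mmd, which is not included in these examples. A minimal sketch of a standard RBF-kernel MMD estimator that fits this call signature; this is an assumption about the implementation, shown as free functions rather than methods:

import torch


def gaussian_kernel(x, y):
    # pairwise RBF kernel k(xi, yj) = exp(-||xi - yj||^2 / dim)
    dim = x.size(1)
    d2 = (x.unsqueeze(1) - y.unsqueeze(0)).pow(2).sum(dim=2)
    return torch.exp(-d2 / dim)


def compute_mmd(x, y):
    # squared maximum mean discrepancy between samples x ~ p(z) and y ~ q(z)
    return (gaussian_kernel(x, x).mean()
            + gaussian_kernel(y, y).mean()
            - 2 * gaussian_kernel(x, y).mean())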
Example #7
def lorentz_sampling(mu_h, logvar):
    [batch_size, n_h] = mu_h.shape
    n = n_h - 1
    # step 1: sample a vector vt from the Gaussian N(0, diag(exp(logvar))) defined over R^n
    mu0 = to_cuda_var(torch.zeros(batch_size, n))
    std = torch.exp(0.5 * logvar)
    eps = torch.randn_like(mu0)
    vt = mu0 + std * eps  # reparameterization trick
    # step 2: interpret vt (zero-padded to v) as an element of the tangent space at the origin of the hyperbolic space
    v0 = torch.cat((to_cuda_var(torch.ones(
        batch_size, 1)), to_cuda_var(torch.zeros(batch_size, n))), 1)
    v = torch.cat((to_cuda_var(torch.zeros(batch_size, 1)), vt), 1)
    # step 3: parallel-transport v to u, which lies in the tangent space at mu_h
    u = parallel_transport(v, v0, mu_h)
    # step 4: Map u to hyperbolic space by exponential mapping
    z = exp_map(u, mu_h)
    return vt, u, z
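A small usage sketch tying Examples #1, #3, and #7 together: map Euclidean means onto the hyperboloid, draw wrapped-normal samples, and evaluate their log-density. The dimensions and the diagonal covariance construction are illustrative assumptions, and the helper sketches from the earlier examples are assumed to be in scope.

batch_size, n = 8, 4
mu_e = to_cuda_var(torch.randn(batch_size, n))     # Euclidean means
logvar = to_cuda_var(torch.zeros(batch_size, n))   # unit variances

mu_h = lorentz_mapping(mu_e)                       # means on the hyperboloid, shape (8, 5)
vt, u, z = lorentz_sampling(mu_h, logvar)          # wrapped-normal samples on the hyperboloid

diag = to_cuda_var(torch.eye(n).repeat(batch_size, 1, 1))
cov = torch.exp(logvar).unsqueeze(dim=2) * diag    # diagonal covariance, shape (8, 4, 4)
logp_vt, logp_z = pseudo_hyperbolic_gaussian(z, mu_h, cov, version=2, vt=vt, u=u)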
Example #8
    def reparameterize(self, hidden):
        # mean vector
        mean_z = self.hidden2mean(hidden)
        # logvar vector
        logv = self.hidden2logv(hidden)
        std = torch.exp(0.5 * logv)
        eps = to_cuda_var(torch.randn([mean_z.shape[0], self.latent_size]))
        z = mean_z + eps * std
        return mean_z, logv, z
Example #9
    def one_hot_embedding(self, input_sequence):
        # one-hot encode each token; tokens with index 0 are left as all-zero rows
        embeddings = np.zeros((input_sequence.shape[0],
                               input_sequence.shape[1], self.vocab_size),
                              dtype=np.float32)
        for b, batch in enumerate(input_sequence):
            for t, char in enumerate(batch):
                if char.item() != 0:
                    embeddings[b, t, char.item()] = 1
        return to_cuda_var(torch.from_numpy(embeddings))
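The nested Python loops above are straightforward but slow for long sequences. A vectorized equivalent using torch.nn.functional.one_hot, keeping the same convention that index-0 positions stay all-zero (an alternative sketch, not the repository's code):

import torch.nn.functional as F


def one_hot_embedding(self, input_sequence):
    # one-hot encode in a single vectorized call, then zero out index-0 positions
    onehot = F.one_hot(input_sequence.long(), num_classes=self.vocab_size).float()
    mask = (input_sequence != 0).unsqueeze(-1).float()
    return to_cuda_var(onehot * mask)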
Example #10
    def vae_loss(self, batch, num_samples):
        batch_size = len(batch['drug_name'])
        input_sequence = batch['drug_inputs']
        target_sequence = batch['drug_targets']
        input_sequence_length = batch['drug_len']

        # compute reconstruction loss
        sorted_lengths, sorted_idx = torch.sort(
            input_sequence_length, descending=True)  # change input order
        input_sequence = input_sequence[sorted_idx]

        hidden = self.encoder(
            input_sequence,
            sorted_lengths)  # hidden_factor, batch_size, hidden_size
        mean, logv, z = self.reparameterize(hidden)
        logp_drug = self.decoder(input_sequence, sorted_lengths, sorted_idx, z)

        max_len = torch.max(input_sequence_length).item()
        target = target_sequence[:, :max_len].contiguous().view(-1)
        logp = logp_drug.view(-1, logp_drug.size(2))

        # reconstruction loss
        recon_loss = self.RECON(logp, target) / batch_size
        # kl loss
        if self.beta > 0.0:
            kl_loss = self.kl_loss(mean, logv, z) / batch_size
        else:
            kl_loss = to_cuda_var(torch.tensor(0.0))
        # marginal kl loss
        if self.alpha > 0.0:
            mkl_loss = self.marginal_posterior_divergence(
                z, mean, logv, num_samples) / batch_size
        else:
            mkl_loss = to_cuda_var(torch.tensor(0.0))
        # MMD loss, p(z) ~ standard normal distribution
        if self.gamma > 0.0:
            mmd_loss = self.mmd_loss(z)
        else:
            mmd_loss = to_cuda_var(torch.tensor(0.0))
        return recon_loss, kl_loss, mkl_loss, mmd_loss
Example #11
    def forward(self, task, batch, num_samples):

        if task == 'vae':
            recon_loss, kl_loss, mkl_loss, mmd_loss = self.vae_loss(
                batch, num_samples)  # SMILES recon. loss
            return (recon_loss, kl_loss, mkl_loss, mmd_loss,
                    to_cuda_var(torch.tensor(0.0)))

        elif task == 'atc':
            local_ranking_loss = self.ranking_loss(
                batch)  # ATC local ranking loss
            return (to_cuda_var(torch.tensor(0.0)),
                    to_cuda_var(torch.tensor(0.0)),
                    to_cuda_var(torch.tensor(0.0)),
                    to_cuda_var(torch.tensor(0.0)),
                    local_ranking_loss)

        elif task == 'vae + atc':
            recon_loss, kl_loss, mkl_loss, mmd_loss = self.vae_loss(
                batch, num_samples)  # SMILES recon. loss
            local_ranking_loss = self.ranking_loss(
                batch)  # ATC local ranking loss
            return recon_loss, kl_loss, mkl_loss, mmd_loss, local_ranking_loss
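The forward pass only returns the individual loss terms; how they are weighted and combined happens in the training loop, which is not part of these examples. A hedged sketch of a plausible training step, assuming the beta, alpha and gamma coefficients that gate the terms in vae_loss are also the mixing weights (the repository's actual weighting may differ, and model/optimizer/batch are placeholder names):

recon_loss, kl_loss, mkl_loss, mmd_loss, ranking_loss = model('vae + atc', batch, num_samples)
total_loss = (recon_loss
              + model.beta * kl_loss
              + model.alpha * mkl_loss
              + model.gamma * mmd_loss
              + ranking_loss)
optimizer.zero_grad()
total_loss.backward()
optimizer.step()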
Example #12
    def marginal_posterior_divergence(self, z, mean, logv, num_samples):
        batch_size, n = mean.shape
        diag = to_cuda_var(torch.eye(n).repeat(1, 1, 1))

        logq_zb_lst = []
        logp_zb_lst = []
        for b in range(batch_size):
            zb = z[b, :].unsqueeze(0)
            mu_b = mean[b, :].unsqueeze(0)
            logv_b = logv[b, :].unsqueeze(0)
            cov_b = torch.exp(logv_b).unsqueeze(dim=2) * diag  # diag is already (1, n, n)

            # removing b-th mean and logv
            zr = zb.repeat(batch_size - 1, 1)
            mu_r = torch.cat((mean[:b, :], mean[b + 1:, :]))
            logv_r = torch.cat((logv[:b, :], logv[b + 1:, :]))
            diag_r = to_cuda_var(torch.eye(n).repeat(batch_size - 1, 1, 1))
            cov_r = torch.exp(logv_r).unsqueeze(dim=2) * diag_r

            # E[log q(zb)] = - H(q(z))
            zb_xb_posterior_pdf = MultivariateNormal(mu_b, cov_b)
            logq_zb_xb = zb_xb_posterior_pdf.log_prob(zb)

            zb_xr_posterior_pdf = MultivariateNormal(mu_r, cov_r)
            logq_zb_xr = zb_xr_posterior_pdf.log_prob(zr)

            yb1 = logq_zb_xb - torch.log(
                to_cuda_var(torch.tensor(num_samples).float()))
            yb2 = logq_zb_xr + torch.log(
                to_cuda_var(
                    torch.tensor((num_samples - 1) /
                                 ((batch_size - 1) * num_samples)).float()))
            yb = torch.cat([yb1, yb2], dim=0)
            logq_zb = torch.logsumexp(yb, dim=0)

            # E[log p(zb)]
            zb_prior_pdf = MultivariateNormal(to_cuda_var(torch.zeros(n)),
                                              diag)
            logp_zb = zb_prior_pdf.log_prob(zb)

            logq_zb_lst.append(logq_zb)
            logp_zb_lst.append(logp_zb)

        logq_zb = torch.stack(logq_zb_lst, dim=0)
        logp_zb = torch.stack(logp_zb_lst, dim=0).squeeze(-1)

        return (logq_zb - logp_zb).sum()
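The yb1/yb2/logsumexp lines above implement a stratified importance-weighted estimate of the log marginal (aggregate) posterior: the sample's own posterior is weighted by 1/N and each of the other B - 1 posteriors in the minibatch by (N - 1) / ((B - 1) * N), where N = num_samples is the dataset size and B = batch_size. In other words,

    log q(z_b) ≈ log[ (1/N) * q(z_b | x_b) + ((N - 1) / ((B - 1) * N)) * Σ_{r != b} q(z_b | x_r) ],

and the returned sum over the batch of log q(z_b) - log p(z_b) estimates the divergence between the marginal posterior and the prior. Example #14 below applies the same estimator with pseudo-hyperbolic Gaussian densities.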
Example #13
def lorentz_mapping_origin(x):
    # prepend a time-like coordinate of 1, mapping the Euclidean origin x = 0 to the hyperboloid origin (1, 0, ..., 0)
    batch_size, _ = x.shape
    return torch.cat((to_cuda_var(torch.ones(batch_size, 1)), x), 1)
Example #14
    def marginal_posterior_divergence(self, vt, u, z, mean, logv, num_samples):
        batch_size, n_h = mean.shape

        mu0 = to_cuda_var(torch.zeros(1, n_h - 1))
        mu0_h = lorentz_mapping_origin(mu0)
        diag0 = to_cuda_var(torch.eye(n_h - 1).repeat(1, 1, 1))

        logq_zb_lst = []
        logp_zb_lst = []
        for b in range(batch_size):
            vt_b = vt[b, :].unsqueeze(0)
            u_b = u[b, :].unsqueeze(0)
            zb = z[b, :].unsqueeze(0)
            mu_b = mean[b, :].unsqueeze(0)
            logv_b = logv[b, :].unsqueeze(0)
            cov_b = torch.exp(logv_b).unsqueeze(dim=2) * diag0  # diag0 is already (1, n_h - 1, n_h - 1)

            # removing b-th mean and logv
            vt_r = vt_b.repeat(batch_size - 1, 1)
            u_r = u_b.repeat(batch_size - 1, 1)
            zr = zb.repeat(batch_size - 1, 1)
            mu_r = torch.cat((mean[:b, :], mean[b + 1:, :]))
            logv_r = torch.cat((logv[:b, :], logv[b + 1:, :]))
            diag_r = to_cuda_var(
                torch.eye(n_h - 1).repeat(batch_size - 1, 1, 1))
            cov_r = torch.exp(logv_r).unsqueeze(dim=2) * diag_r

            # E[log q(zb)] = - H(q(z))
            _, logq_zb_xb = pseudo_hyperbolic_gaussian(zb,
                                                       mu_b,
                                                       cov_b,
                                                       version=2,
                                                       vt=vt_b,
                                                       u=u_b)
            _, logq_zb_xr = pseudo_hyperbolic_gaussian(zr,
                                                       mu_r,
                                                       cov_r,
                                                       version=2,
                                                       vt=vt_r,
                                                       u=u_r)

            yb1 = logq_zb_xb - torch.log(
                to_cuda_var(torch.tensor(num_samples).float()))
            yb2 = logq_zb_xr + torch.log(
                to_cuda_var(
                    torch.tensor((num_samples - 1) /
                                 ((batch_size - 1) * num_samples)).float()))
            yb = torch.cat([yb1, yb2], dim=0)
            logq_zb = torch.logsumexp(yb, dim=0)

            # E[log p(zb)]
            _, logp_zb = pseudo_hyperbolic_gaussian(zb,
                                                    mu0_h,
                                                    diag0,
                                                    version=2,
                                                    vt=None,
                                                    u=None)

            logq_zb_lst.append(logq_zb)
            logp_zb_lst.append(logp_zb)

        logq_zb = torch.stack(logq_zb_lst, dim=0)
        logp_zb = torch.stack(logp_zb_lst, dim=0).squeeze(-1)

        return (logq_zb - logp_zb).sum()