import numpy as np
import torch
from torch.distributions import MultivariateNormal


def pseudo_hyperbolic_gaussian(z, mu_h, cov, version, vt=None, u=None):
    batch_size, n_h = mu_h.shape
    n = n_h - 1
    mu0 = to_cuda_var(torch.zeros(batch_size, n))
    v0 = torch.cat((to_cuda_var(torch.ones(batch_size, 1)), mu0), 1)  # origin of the hyperbolic space

    # skip the inverse exp. mapping if vt and u are already known
    if vt is None and u is None:
        u = inv_exp_map(z, mu_h)
        v = parallel_transport(u, mu_h, v0)
        vt = v[:, 1:]

    # base log-density of the tangent sample vt under N(0, cov)
    logp_vt = (MultivariateNormal(mu0, cov).log_prob(vt)).view(-1, 1)

    # log-determinant correction of the projection back onto the hyperboloid
    r = lorentz_tangent_norm(u)
    if version == 1:
        alpha = -lorentz_product(v0, mu_h)
        log_det_proj_mu = (n * (torch.log(torch.sinh(r)) - torch.log(r))
                           + torch.log(torch.cosh(r)) + torch.log(alpha))
    elif version == 2:
        log_det_proj_mu = (n - 1) * (torch.log(torch.sinh(r)) - torch.log(r))

    logp_z = logp_vt - log_det_proj_mu
    return logp_vt, logp_z

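# The helpers lorentz_product and lorentz_tangent_norm used above are not part of this
# listing. A minimal sketch of what they are assumed to compute, following the standard
# Lorentz-model definitions (hypothetical implementations, not the authors' code):
def lorentz_product_sketch(x, y):
    # Minkowski (Lorentzian) inner product <x, y>_L = -x0 * y0 + sum_i xi * yi,
    # computed row-wise for inputs of shape (batch_size, n + 1)
    p = x * y
    return -p[:, :1] + p[:, 1:].sum(dim=1, keepdim=True)


def lorentz_tangent_norm_sketch(u):
    # Lorentzian norm ||u||_L = sqrt(<u, u>_L) of a tangent vector u; the inner product
    # is non-negative on tangent spaces of the hyperboloid, the clamp guards round-off
    return torch.sqrt(torch.clamp(lorentz_product_sketch(u, u), min=1e-10))
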
def kl_loss(self, mean, logv, vt, u, z):
    batch_size, n_h = mean.shape
    n = n_h - 1
    mu0 = to_cuda_var(torch.zeros(batch_size, n))
    mu0_h = lorentz_mapping_origin(mu0)
    diag = to_cuda_var(torch.eye(n).repeat(batch_size, 1, 1))
    cov = torch.exp(logv).unsqueeze(dim=2) * diag

    # posterior density
    _, logp_posterior_z = pseudo_hyperbolic_gaussian(z, mean, cov, version=2, vt=vt, u=u)

    # prior density
    if self.prior == 'Standard':
        _, logp_prior_z = pseudo_hyperbolic_gaussian(z, mu0_h, diag, version=2, vt=None, u=None)
    else:
        raise NotImplementedError('Only the Standard prior is supported here.')

    kl_loss = torch.sum(logp_posterior_z.squeeze() - logp_prior_z.squeeze())
    return kl_loss

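# Hedged usage sketch for the hyperbolic KL term above: the Euclidean mean is first mapped
# onto the hyperboloid, lorentz_sampling returns the tangent sample vt, the transported
# vector u and the point z, and those are what this kl_loss consumes. Variable names are
# illustrative, drawn from the other functions in this listing, not an exact call path:
#
#   mean_h = lorentz_mapping(self.hidden2mean(hidden))
#   logv = self.hidden2logv(hidden)
#   vt, u, z = lorentz_sampling(mean_h, logv)
#   kl = self.kl_loss(mean_h, logv, vt, u, z) / batch_size
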
def lorentz_mapping(x):
    # map a Euclidean vector x onto the hyperboloid via the exponential map at the origin
    [batch_size, n] = x.shape

    # interpret x as an element of the tangent space at the origin of the hyperbolic space
    x_t = torch.cat((to_cuda_var(torch.zeros(batch_size, 1)), x), 1)

    # origin of the hyperbolic space
    v0 = torch.cat((to_cuda_var(torch.ones(batch_size, 1)),
                    to_cuda_var(torch.zeros(batch_size, n))), 1)

    # exponential mapping
    z = exp_map(x_t, v0)
    return z

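# exp_map is not shown in this listing. A minimal sketch under the usual Lorentz-model
# definition (hypothetical implementation, reusing the norm sketch above):
def exp_map_sketch(u, mu):
    # exponential map at mu: exp_mu(u) = cosh(r) * mu + sinh(r) * u / r, with r = ||u||_L
    r = lorentz_tangent_norm_sketch(u)
    return torch.cosh(r) * mu + torch.sinh(r) * u / r
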
def mmd_loss(self, zq):
    # prior samples: wrapped normal at the origin of the hyperbolic space
    batch_size, n_h = zq.shape
    n = n_h - 1
    mu0 = to_cuda_var(torch.zeros(batch_size, n))
    mu0_h = lorentz_mapping_origin(mu0)
    logv = to_cuda_var(torch.zeros(batch_size, n))
    vt, u, z = lorentz_sampling(mu0_h, logv)

    # compute mmd
    mmd = self.compute_mmd(z, zq)
    return mmd

def kl_loss(self, mean, logv, z):
    batch_size, n = mean.shape
    diag = to_cuda_var(torch.eye(n).repeat(batch_size, 1, 1))
    cov = torch.exp(logv).unsqueeze(dim=-1) * diag

    # compute log probabilities of posterior
    z_posterior_pdf = MultivariateNormal(mean, cov)
    logp_posterior_z = z_posterior_pdf.log_prob(z)

    # compute log probabilities of prior
    if self.prior == 'Standard':
        z_prior_pdf = MultivariateNormal(to_cuda_var(torch.zeros(n)), diag)
        logp_prior_z = z_prior_pdf.log_prob(z)
    else:
        raise NotImplementedError('Only the Standard prior is supported here.')

    kl_loss = torch.sum(logp_posterior_z.squeeze() - logp_prior_z.squeeze())
    return kl_loss

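# The term above is a single-sample Monte Carlo estimate of KL(q(z|x) || p(z)). For a
# standard normal prior the same KL also has the familiar closed form; a hedged sketch
# (not part of the listing) that can serve as a sanity check on the estimate:
def analytic_gaussian_kl_sketch(mean, logv):
    # closed-form KL( N(mean, diag(exp(logv))) || N(0, I) ), summed over the batch
    return -0.5 * torch.sum(1 + logv - mean.pow(2) - logv.exp())
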
def mmd_loss(self, zq):
    # true standard normal distribution samples
    true_samples = to_cuda_var(torch.randn([zq.shape[0], self.latent_size]))

    # compute mmd
    mmd = self.compute_mmd(true_samples, zq)
    return mmd

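# compute_mmd is referenced but not defined in this listing. A minimal sketch of a common
# RBF-kernel MMD estimator it might correspond to (hypothetical, not the authors' code):
def rbf_kernel_sketch(x, y):
    # pairwise RBF kernel k(x_i, y_j) = exp(-||x_i - y_j||^2 / dim)
    x_size, y_size, dim = x.shape[0], y.shape[0], x.shape[1]
    x = x.unsqueeze(1).expand(x_size, y_size, dim)
    y = y.unsqueeze(0).expand(x_size, y_size, dim)
    return torch.exp(-((x - y) ** 2).mean(dim=2))


def compute_mmd_sketch(x, y):
    # MMD^2 estimate: E[k(x, x')] + E[k(y, y')] - 2 E[k(x, y)]
    return (rbf_kernel_sketch(x, x).mean() + rbf_kernel_sketch(y, y).mean()
            - 2 * rbf_kernel_sketch(x, y).mean())
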
def lorentz_sampling(mu_h, logvar):
    [batch_size, n_h] = mu_h.shape
    n = n_h - 1

    # step 1: sample a vector vt from the Gaussian distribution N(0, cov) defined over R^n
    mu0 = to_cuda_var(torch.zeros(batch_size, n))
    std = torch.exp(0.5 * logvar)
    eps = torch.randn_like(mu0)
    vt = mu0 + std * eps  # reparameterization trick

    # step 2: interpret vt as an element of the tangent space at the origin of the hyperbolic space
    v0 = torch.cat((to_cuda_var(torch.ones(batch_size, 1)),
                    to_cuda_var(torch.zeros(batch_size, n))), 1)
    v = torch.cat((to_cuda_var(torch.zeros(batch_size, 1)), vt), 1)

    # step 3: parallel transport v to u, which belongs to the tangent space at mu_h
    u = parallel_transport(v, v0, mu_h)

    # step 4: map u to the hyperbolic space by the exponential mapping
    z = exp_map(u, mu_h)
    return vt, u, z

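# parallel_transport and inv_exp_map (used here and in pseudo_hyperbolic_gaussian) are not
# shown either. Sketches under the standard hyperboloid formulas, with the argument order
# parallel_transport(vector, source, destination) implied by the calls above (hypothetical):
def parallel_transport_sketch(v, src, dst):
    # PT_{src -> dst}(v) = v + <dst - alpha * src, v>_L / (alpha + 1) * (src + dst),
    # where alpha = -<src, dst>_L
    alpha = -lorentz_product_sketch(src, dst)
    coef = (lorentz_product_sketch(dst, v)
            - alpha * lorentz_product_sketch(src, v)) / (alpha + 1)
    return v + coef * (src + dst)


def inv_exp_map_sketch(z, mu):
    # logarithm map at mu: log_mu(z) = arccosh(alpha) / sqrt(alpha^2 - 1) * (z - alpha * mu),
    # where alpha = -<mu, z>_L; the clamp guards alpha ~ 1 when z ~ mu
    alpha = -lorentz_product_sketch(mu, z)
    denom = torch.sqrt(torch.clamp(alpha ** 2 - 1, min=1e-10))
    return torch.acosh(alpha) / denom * (z - alpha * mu)
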
def reparameterize(self, hidden):
    # mean vector
    mean_z = self.hidden2mean(hidden)
    # logvar vector
    logv = self.hidden2logv(hidden)
    std = torch.exp(0.5 * logv)
    eps = to_cuda_var(torch.randn([mean_z.shape[0], self.latent_size]))
    z = mean_z + eps * std
    return mean_z, logv, z

def one_hot_embedding(self, input_sequence):
    embeddings = np.zeros((input_sequence.shape[0], input_sequence.shape[1], self.vocab_size),
                          dtype=np.float32)
    for b, batch in enumerate(input_sequence):
        for t, char in enumerate(batch):
            if char.item() != 0:  # index 0 (padding) is left as an all-zero row
                embeddings[b, t, char.item()] = 1
    return to_cuda_var(torch.from_numpy(embeddings))

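# The Python loop above can be replaced by a vectorized equivalent; a sketch assuming
# input_sequence is a LongTensor and that index 0 is the padding token, as the loop implies:
def one_hot_embedding_vectorized_sketch(self, input_sequence):
    one_hot = torch.nn.functional.one_hot(input_sequence, num_classes=self.vocab_size).float()
    one_hot = one_hot * (input_sequence != 0).unsqueeze(-1).float()  # zero out padding rows
    return to_cuda_var(one_hot)
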
def vae_loss(self, batch, num_samples):
    batch_size = len(batch['drug_name'])
    input_sequence = batch['drug_inputs']
    target_sequence = batch['drug_targets']
    input_sequence_length = batch['drug_len']

    # compute reconstruction loss
    sorted_lengths, sorted_idx = torch.sort(input_sequence_length, descending=True)
    input_sequence = input_sequence[sorted_idx]  # change input order
    hidden = self.encoder(input_sequence, sorted_lengths)  # hidden_factor, batch_size, hidden_size
    mean, logv, z = self.reparameterize(hidden)
    logp_drug = self.decoder(input_sequence, sorted_lengths, sorted_idx, z)

    target = target_sequence[:, :torch.max(input_sequence_length).item()].contiguous().view(-1)
    logp = logp_drug.view(-1, logp_drug.size(2))

    # reconstruction loss
    recon_loss = self.RECON(logp, target) / batch_size

    # kl loss
    if self.beta > 0.0:
        kl_loss = self.kl_loss(mean, logv, z) / batch_size
    else:
        kl_loss = to_cuda_var(torch.tensor(0.0))

    # marginal kl loss
    if self.alpha > 0.0:
        mkl_loss = self.marginal_posterior_divergence(z, mean, logv, num_samples) / batch_size
    else:
        mkl_loss = to_cuda_var(torch.tensor(0.0))

    # MMD loss, p(z) ~ standard normal distribution
    if self.gamma > 0.0:
        mmd_loss = self.mmd_loss(z)
    else:
        mmd_loss = to_cuda_var(torch.tensor(0.0))

    return recon_loss, kl_loss, mkl_loss, mmd_loss

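# Hypothetical assembly of the training objective from the four terms returned above,
# using the beta / alpha / gamma weights that already gate each term inside vae_loss
# (the combination below is an illustrative sketch, not taken from this listing):
#
#   recon_loss, kl_loss, mkl_loss, mmd_loss = self.vae_loss(batch, num_samples)
#   total_loss = recon_loss + self.beta * kl_loss + self.alpha * mkl_loss + self.gamma * mmd_loss
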
def forward(self, task, batch, num_samples):
    if task == 'vae':
        # SMILES reconstruction loss only
        recon_loss, kl_loss, mkl_loss, mmd_loss = self.vae_loss(batch, num_samples)
        return recon_loss, kl_loss, mkl_loss, mmd_loss, to_cuda_var(torch.tensor(0.0))
    elif task == 'atc':
        # ATC local ranking loss only
        local_ranking_loss = self.ranking_loss(batch)
        return (to_cuda_var(torch.tensor(0.0)), to_cuda_var(torch.tensor(0.0)),
                to_cuda_var(torch.tensor(0.0)), to_cuda_var(torch.tensor(0.0)),
                local_ranking_loss)
    elif task == 'vae + atc':
        # SMILES reconstruction loss + ATC local ranking loss
        recon_loss, kl_loss, mkl_loss, mmd_loss = self.vae_loss(batch, num_samples)
        local_ranking_loss = self.ranking_loss(batch)
        return recon_loss, kl_loss, mkl_loss, mmd_loss, local_ranking_loss

def marginal_posterior_divergence(self, z, mean, logv, num_samples):
    batch_size, n = mean.shape
    diag = to_cuda_var(torch.eye(n).repeat(1, 1, 1))

    logq_zb_lst = []
    logp_zb_lst = []
    for b in range(batch_size):
        zb = z[b, :].unsqueeze(0)
        mu_b = mean[b, :].unsqueeze(0)
        logv_b = logv[b, :].unsqueeze(0)
        diag_b = to_cuda_var(torch.eye(n).repeat(1, 1, 1))
        cov_b = torch.exp(logv_b).unsqueeze(dim=2) * diag_b

        # removing b-th mean and logv
        zr = zb.repeat(batch_size - 1, 1)
        mu_r = torch.cat((mean[:b, :], mean[b + 1:, :]))
        logv_r = torch.cat((logv[:b, :], logv[b + 1:, :]))
        diag_r = to_cuda_var(torch.eye(n).repeat(batch_size - 1, 1, 1))
        cov_r = torch.exp(logv_r).unsqueeze(dim=2) * diag_r

        # E[log q(zb)] = - H(q(z))
        zb_xb_posterior_pdf = MultivariateNormal(mu_b, cov_b)
        logq_zb_xb = zb_xb_posterior_pdf.log_prob(zb)

        zb_xr_posterior_pdf = MultivariateNormal(mu_r, cov_r)
        logq_zb_xr = zb_xr_posterior_pdf.log_prob(zr)

        yb1 = logq_zb_xb - torch.log(to_cuda_var(torch.tensor(num_samples).float()))
        yb2 = logq_zb_xr + torch.log(
            to_cuda_var(torch.tensor((num_samples - 1) / ((batch_size - 1) * num_samples)).float()))
        yb = torch.cat([yb1, yb2], dim=0)
        logq_zb = torch.logsumexp(yb, dim=0)

        # E[log p(zb)]
        zb_prior_pdf = MultivariateNormal(to_cuda_var(torch.zeros(n)), diag)
        logp_zb = zb_prior_pdf.log_prob(zb)

        logq_zb_lst.append(logq_zb)
        logp_zb_lst.append(logp_zb)

    logq_zb = torch.stack(logq_zb_lst, dim=0)
    logp_zb = torch.stack(logp_zb_lst, dim=0).squeeze(-1)
    return (logq_zb - logp_zb).sum()

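# A hedged reading of the two weighted terms above: with N = num_samples (the dataset size)
# and B = batch_size, the logsumexp over yb evaluates a minibatch estimate of the
# aggregated posterior,
#
#   log q^(z_b) = log[ (1/N) * q(z_b | x_b) + (N-1) / ((B-1) * N) * sum_{r != b} q(z_b | x_r) ],
#
# so the returned sum over b of log q^(z_b) - log p(z_b) is a sample estimate of the
# divergence between the aggregated posterior q(z) and the prior p(z).
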
def lorentz_mapping_origin(x):
    # prepend the time-like coordinate 1; for x = 0 this is the origin of the hyperbolic space
    batch_size, _ = x.shape
    return torch.cat((to_cuda_var(torch.ones(batch_size, 1)), x), 1)

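# to_cuda_var appears throughout this listing but is not defined here. A minimal sketch of
# what it is assumed to do (move a tensor to the GPU when one is available; hypothetical):
def to_cuda_var_sketch(x):
    return x.cuda() if torch.cuda.is_available() else x
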
def marginal_posterior_divergence(self, vt, u, z, mean, logv, num_samples):
    batch_size, n_h = mean.shape
    mu0 = to_cuda_var(torch.zeros(1, n_h - 1))
    mu0_h = lorentz_mapping_origin(mu0)
    diag0 = to_cuda_var(torch.eye(n_h - 1).repeat(1, 1, 1))

    logq_zb_lst = []
    logp_zb_lst = []
    for b in range(batch_size):
        vt_b = vt[b, :].unsqueeze(0)
        u_b = u[b, :].unsqueeze(0)
        zb = z[b, :].unsqueeze(0)
        mu_b = mean[b, :].unsqueeze(0)
        logv_b = logv[b, :].unsqueeze(0)
        diag_b = to_cuda_var(torch.eye(n_h - 1).repeat(1, 1, 1))
        cov_b = torch.exp(logv_b).unsqueeze(dim=2) * diag_b

        # removing b-th mean and logv
        vt_r = vt_b.repeat(batch_size - 1, 1)
        u_r = u_b.repeat(batch_size - 1, 1)
        zr = zb.repeat(batch_size - 1, 1)
        mu_r = torch.cat((mean[:b, :], mean[b + 1:, :]))
        logv_r = torch.cat((logv[:b, :], logv[b + 1:, :]))
        diag_r = to_cuda_var(torch.eye(n_h - 1).repeat(batch_size - 1, 1, 1))
        cov_r = torch.exp(logv_r).unsqueeze(dim=2) * diag_r

        # E[log q(zb)] = - H(q(z))
        _, logq_zb_xb = pseudo_hyperbolic_gaussian(zb, mu_b, cov_b, version=2, vt=vt_b, u=u_b)
        _, logq_zb_xr = pseudo_hyperbolic_gaussian(zr, mu_r, cov_r, version=2, vt=vt_r, u=u_r)

        yb1 = logq_zb_xb - torch.log(to_cuda_var(torch.tensor(num_samples).float()))
        yb2 = logq_zb_xr + torch.log(
            to_cuda_var(torch.tensor((num_samples - 1) / ((batch_size - 1) * num_samples)).float()))
        yb = torch.cat([yb1, yb2], dim=0)
        logq_zb = torch.logsumexp(yb, dim=0)

        # E[log p(zb)]
        _, logp_zb = pseudo_hyperbolic_gaussian(zb, mu0_h, diag0, version=2, vt=None, u=None)

        logq_zb_lst.append(logq_zb)
        logp_zb_lst.append(logp_zb)

    logq_zb = torch.stack(logq_zb_lst, dim=0)
    logp_zb = torch.stack(logp_zb_lst, dim=0).squeeze(-1)
    return (logq_zb - logp_zb).sum()