def __init__(self):
    """Create the encoder/decoder pair and the KL weight of the VAE."""
    super(RNN_VAE, self).__init__()

    # Multiplier applied to the KL term when the training loss is formed.
    self.beta = 0.01

    # Latent width: the encoder emits twice this many values, the decoder
    # consumes exactly this many.
    latent_size = 44
    self.z_dim = latent_size

    self.enc = RNNEncoder(out_dim=latent_size * 2)
    self.dec = RNNDecoder(in_dim=latent_size)
def __init__(self, dataset='paired_mnist'):
    """Build the BiAAE sub-networks and loss weights for ``dataset``.

    Args:
        dataset: Name of the configuration to build. One of
            ``'paired_mnist'``, ``'lincs_rnn'`` or ``'lincs_rnn_reverse'``.
            The reverse variant wires the same kinds of networks as
            ``'lincs_rnn'`` with the x and y roles swapped and different
            reconstruction weights.

    Raises:
        ValueError: If ``dataset`` is not one of the supported names.
    """
    super(BiAAE, self).__init__()

    def load_pretrained(module, ckpt_path):
        # Deserialize onto the CPU instead of a hard-coded 'cuda:1':
        # load_state_dict copies the values into ``module``'s own
        # parameters, so the result is the same and construction no longer
        # requires a second GPU to be present.
        module.load_state_dict(torch.load(ckpt_path, map_location='cpu'))
        return module

    self.dataset = dataset

    if self.dataset == 'paired_mnist':
        self.z_dim = 16
        self.joint_dim = 4

        self.loss_rec_lambda_x = 10
        self.loss_rec_lambda_y = 10

        self.loss_normal_lambda = 0.3
        self.loss_indep_lambda_x = 1
        self.loss_indep_lambda_y = 1
        self.loss_mse_lambda = 0.1

        self.discr_steps = 1
        self.gen_steps = 1

        self.enc_x = MnistCNNEncoder(out_dim=self.z_dim)
        self.enc_y = MnistCNNEncoder(out_dim=self.z_dim)

        self.dec_x = MnistCNNDecoder(in_dim=self.z_dim)
        self.dec_y = MnistCNNDecoder(in_dim=self.z_dim)
    elif self.dataset == 'lincs_rnn':
        self.z_dim = 20
        self.joint_dim = 10

        self.loss_rec_lambda_x = 5
        self.loss_rec_lambda_y = 1

        self.loss_normal_lambda = 0.5
        self.loss_indep_lambda_x = 0.5
        self.loss_indep_lambda_y = 0.5
        self.loss_mse_lambda = 0.5

        self.discr_steps = 3
        self.gen_steps = 1

        rnn_1 = load_pretrained(RNNEncoder(out_dim=88),
                                '../saved_models/rnn_enc.ckpt')
        self.enc_x = FinetunedEncoder(rnn_1, out_dim=self.z_dim)
        self.enc_y = ExprDiffEncoder(out_dim=self.z_dim)

        rnn_2 = load_pretrained(RNNDecoder(in_dim=44),
                                '../saved_models/rnn_dec.ckpt')
        self.dec_x = FinetunedDecoder(rnn_2, in_dim=self.z_dim)
        self.dec_y = ExprDiffDecoder(in_dim=self.z_dim)
    elif self.dataset == 'lincs_rnn_reverse':
        self.z_dim = 20
        self.joint_dim = 10

        self.loss_rec_lambda_x = 1
        self.loss_rec_lambda_y = 0.2

        self.loss_normal_lambda = 0.5
        self.loss_indep_lambda_x = 0.5
        self.loss_indep_lambda_y = 0.5
        self.loss_mse_lambda = 0.5

        self.discr_steps = 3
        self.gen_steps = 1

        rnn_1 = load_pretrained(RNNEncoder(out_dim=88),
                                '../saved_models/rnn_enc.ckpt')
        self.enc_y = FinetunedEncoder(rnn_1, out_dim=self.z_dim)
        self.enc_x = ExprDiffEncoder(out_dim=self.z_dim)

        rnn_2 = load_pretrained(RNNDecoder(in_dim=44),
                                '../saved_models/rnn_dec.ckpt')
        self.dec_y = FinetunedDecoder(rnn_2, in_dim=self.z_dim)
        self.dec_x = ExprDiffDecoder(in_dim=self.z_dim)
    else:
        # Fail here rather than later with an AttributeError on a model
        # that has no sub-networks.
        raise ValueError(
            "unsupported dataset %r; expected 'paired_mnist', 'lincs_rnn' "
            "or 'lincs_rnn_reverse'" % (dataset,))

    # Every configuration uses three discriminators with the same input
    # width, created after the encoders/decoders.
    discr_in_dim = 2 * self.z_dim - self.joint_dim
    self.discr = FCDiscriminator(in_dim=discr_in_dim, use_sigmoid=False)
    self.discr_indep_x = FCDiscriminator(in_dim=discr_in_dim,
                                         use_sigmoid=False)
    self.discr_indep_y = FCDiscriminator(in_dim=discr_in_dim,
                                         use_sigmoid=False)
def __init__(self, dataset='paired_mnist'):
    """Build the VCCA encoders, decoders and loss weights for ``dataset``.

    In every configuration ``enc_x`` emits ``2 * (z_dim - joint_dim)``
    values, ``enc_y`` emits ``2 * z_dim`` values, and both decoders take
    ``z_dim`` inputs.

    Args:
        dataset: Name of the configuration to build. One of
            ``'paired_mnist'``, ``'lincs_rnn'`` or ``'lincs_rnn_reverse'``.

    Raises:
        ValueError: If ``dataset`` is not one of the supported names.
    """
    super(VCCA, self).__init__()

    def load_pretrained(module, ckpt_path):
        # Deserialize onto the CPU instead of a hard-coded 'cuda:1':
        # load_state_dict copies the values into ``module``'s own
        # parameters, so the result is the same and construction no longer
        # requires a second GPU to be present.
        module.load_state_dict(torch.load(ckpt_path, map_location='cpu'))
        return module

    self.dataset = dataset

    if self.dataset == 'paired_mnist':
        self.z_dim = 16
        self.joint_dim = 4

        self.enc_x = MnistCNNEncoder(
            out_dim=2 * (self.z_dim - self.joint_dim))
        self.enc_y = MnistCNNEncoder(out_dim=2 * self.z_dim)

        self.loss_rec_lambda_x = 10
        self.loss_rec_lambda_y = 10
        self.beta = 0.01

        self.dec_x = MnistCNNDecoder(in_dim=self.z_dim)
        self.dec_y = MnistCNNDecoder(in_dim=self.z_dim)
    elif self.dataset == 'lincs_rnn':
        self.z_dim = 20
        self.joint_dim = 10

        rnn_1 = load_pretrained(RNNEncoder(out_dim=88),
                                '../saved_models/rnn_enc.ckpt')
        self.enc_x = FinetunedEncoder(
            rnn_1, out_dim=2 * (self.z_dim - self.joint_dim))
        self.enc_y = ExprDiffEncoder(out_dim=2 * self.z_dim)

        self.loss_rec_lambda_x = 5
        self.loss_rec_lambda_y = 1
        self.beta = 1

        rnn_2 = load_pretrained(RNNDecoder(in_dim=44),
                                '../saved_models/rnn_dec.ckpt')
        self.dec_x = FinetunedDecoder(rnn_2, in_dim=self.z_dim)
        self.dec_y = ExprDiffDecoder(in_dim=self.z_dim)
    elif self.dataset == 'lincs_rnn_reverse':
        self.z_dim = 20
        self.joint_dim = 10

        rnn_1 = load_pretrained(RNNEncoder(out_dim=88),
                                '../saved_models/rnn_enc.ckpt')
        self.enc_y = FinetunedEncoder(rnn_1, out_dim=2 * self.z_dim)
        self.enc_x = ExprDiffEncoder(
            out_dim=2 * (self.z_dim - self.joint_dim))

        self.loss_rec_lambda_x = 1
        self.loss_rec_lambda_y = 0.2
        self.beta = 1

        rnn_2 = load_pretrained(RNNDecoder(in_dim=44),
                                '../saved_models/rnn_dec.ckpt')
        self.dec_y = FinetunedDecoder(rnn_2, in_dim=self.z_dim)
        self.dec_x = ExprDiffDecoder(in_dim=self.z_dim)
    else:
        # Fail here rather than later with an AttributeError on a model
        # that has no sub-networks.
        raise ValueError(
            "unsupported dataset %r; expected 'paired_mnist', 'lincs_rnn' "
            "or 'lincs_rnn_reverse'" % (dataset,))
class RNN_VAE(pl.LightningModule):
    """Variational autoencoder built from an RNNEncoder / RNNDecoder pair.

    The encoder output is split along the last axis into a mean and a
    log-variance of width ``z_dim``. Training minimises the negative decoder
    log-probability plus ``beta`` times the KL divergence to N(0, I).
    """

    def __init__(self):
        super(RNN_VAE, self).__init__()

        self.z_dim = 44

        # The encoder emits means and log-variances, hence 2 * z_dim outputs.
        self.enc = RNNEncoder(out_dim=2 * self.z_dim)
        self.dec = RNNDecoder(in_dim=self.z_dim)

        # Weight of the KL term in the training loss.
        self.beta = 0.01

    @staticmethod
    def sample_repar_z(means, logvar):
        """Return ``means + eps * exp(0.5 * logvar)`` with ``eps ~ N(0, I)``."""
        return means + torch.randn_like(means) * torch.exp(0.5 * logvar)

    @staticmethod
    def kl_div(means_q, logvar_q, means_p=None, logvar_p=None):
        """Return KL(q || p) between diagonal Gaussians.

        The divergence is summed over the last axis and averaged over the
        remaining ones.

        Args:
            means_q: Means of q.
            logvar_q: Log-variances of q.
            means_p: Means of p. When ``None`` the prior is N(0, I) and
                ``logvar_p`` is ignored.
            logvar_p: Log-variances of p; required when ``means_p`` is given.

        Returns:
            A scalar tensor holding the averaged KL divergence.
        """
        if means_p is None:  # prior is N(0, I)
            return -0.5 * torch.mean(
                torch.sum(1 + logvar_q - means_q.pow(2) - logvar_q.exp(),
                          dim=-1))

        # The quadratic term must use the distance between the two means;
        # using means_q alone silently treats the prior as zero-centred.
        return -0.5 * torch.mean(
            torch.sum(1 - logvar_p + logvar_q -
                      ((means_q - means_p).pow(2) + logvar_q.exp()) *
                      (-logvar_p).exp(),
                      dim=-1))

    def training_step(self, batch, batch_nb):
        """Compute the loss for one batch.

        Args:
            batch: A pair whose first element is the input ``x``; the second
                element is ignored.
            batch_nb: Batch index (unused).

        Returns:
            A dict with the scalar ``'loss'`` and a ``'log'`` dict holding
            the mean decoder log-probability and the KL term.
        """
        # pair of objects; only the first one is used here
        x, _ = batch

        # compute proposal distribution parameters
        p_z_means, p_z_logvar = torch.split(self.enc(x), self.z_dim, -1)

        # sample z
        z_sample = self.sample_repar_z(p_z_means, p_z_logvar)

        # compute kl divergence to the N(0, I) prior
        kl = self.kl_div(p_z_means, p_z_logvar)

        # compute reconstruction term
        x_by_z_logprob = self.dec.get_log_prob(x=x, z=z_sample).mean()

        loss = (-x_by_z_logprob + self.beta * kl)

        return {
            'loss': loss,
            'log': {
                'x_by_z_logprob': x_by_z_logprob,
                'kl': kl
            }
        }

    def validation_step(self, batch, batch_nb):
        """Decode samples for one batch and score them.

        Returns:
            A dict with three fractions of the decoded batch: ``'valid'``
            (strings RDKit can parse), ``'unique'`` (distinct strings) and
            ``'equal'`` (parseable strings whose ``Chem.MolToSmiles`` output
            equals the corresponding input).
        """
        x, _ = batch

        # compute proposal distribution parameters
        p_z_means, p_z_logvar = torch.split(self.enc(x), self.z_dim, -1)

        # sample z
        z_sample = self.sample_repar_z(p_z_means, p_z_logvar)

        sampled_sm = self.dec.sample(z_sample)
        n_samples = len(sampled_sm)

        # Parse every sampled string once and reuse the result below.
        mols = [Chem.MolFromSmiles(s) for s in sampled_sm]

        valid_proc = sum(1 for m in mols if m is not None) / n_samples
        unique_proc = len(np.unique(sampled_sm)) / n_samples
        eq_proc = sum(
            1 for x_ob, m in zip(x, mols)
            if m is not None and Chem.MolToSmiles(m) == x_ob) / n_samples

        return {'valid': valid_proc, 'unique': unique_proc, 'equal': eq_proc}

    def validation_end(self, outputs):
        """Average the per-batch validation fractions for logging."""
        val_stats = {
            'val_' + key: np.array([out[key] for out in outputs]).mean()
            for key in ('valid', 'unique', 'equal')
        }

        return {'log': val_stats}

    def configure_optimizers(self):
        """Return an Adam optimizer over all parameters (lr = 3e-4)."""
        return torch.optim.Adam(self.parameters(), lr=3e-4)

    @pl.data_loader
    def train_dataloader(self):
        """Return the training loader over ``../data/train.txt``."""
        return DataLoader(MolecularDataset('../data/train.txt', train=True),
                          batch_size=512,
                          shuffle=True,
                          num_workers=10)

    @pl.data_loader
    def val_dataloader(self):
        """Return the validation loader (same file, ``train=False``)."""
        # NOTE(review): presumably MolecularDataset selects a held-out split
        # of the same file when train=False -- confirm.
        return DataLoader(MolecularDataset('../data/train.txt', train=False),
                          batch_size=512,
                          shuffle=True,
                          num_workers=10)
def __init__(self, dataset='paired_mnist'):
    """Build the CVAE encoders, conditioned decoder and loss weights.

    Every configuration creates ``enc_y`` and ``enc_xy`` (both emitting
    ``2 * z_dim`` values) and ``dec_x_cond``, a ``ConditionedDecoder`` whose
    decoder takes ``2 * z_dim`` inputs and whose conditioning network emits
    ``z_dim`` values.

    Args:
        dataset: Name of the configuration to build. One of
            ``'paired_mnist'``, ``'lincs_rnn'`` or ``'lincs_rnn_reverse'``.

    Raises:
        ValueError: If ``dataset`` is not one of the supported names.
    """
    super(CVAE, self).__init__()

    def load_pretrained(module, ckpt_path):
        # Deserialize onto the CPU instead of a hard-coded 'cuda:1':
        # load_state_dict copies the values into ``module``'s own
        # parameters, so the result is the same and construction no longer
        # requires a second GPU to be present.
        module.load_state_dict(torch.load(ckpt_path, map_location='cpu'))
        return module

    self.dataset = dataset

    if self.dataset == 'paired_mnist':
        self.z_dim = 8

        self.loss_rec_lambda_x = 10
        self.loss_rec_lambda_y = 10
        self.beta = 0.01

        self.enc_y = MnistCNNEncoder(out_dim=2 * self.z_dim)
        self.enc_xy = JointEncoder(
            MnistCNNEncoder(out_dim=2 * self.z_dim, short_tail=True),
            MnistCNNEncoder(out_dim=2 * self.z_dim, short_tail=True),
            out_dim=2 * self.z_dim)

        self.dec_x_cond = ConditionedDecoder(
            dec=MnistCNNDecoder(in_dim=self.z_dim + self.z_dim),
            cond=MnistCNNEncoder(out_dim=self.z_dim, short_tail=True))
    elif self.dataset == 'lincs_rnn':
        self.z_dim = 10

        self.loss_rec_lambda_x = 5
        self.loss_rec_lambda_y = 1
        self.beta = 1

        rnn_1 = load_pretrained(RNNEncoder(out_dim=88),
                                '../saved_models/rnn_enc.ckpt')
        enc_x = FinetunedEncoder(rnn_1, out_dim=2 * self.z_dim)
        self.enc_xy = JointEncoder(enc_x,
                                   ExprDiffEncoder(out_dim=2 * self.z_dim),
                                   out_dim=2 * self.z_dim)
        self.enc_y = ExprDiffEncoder(out_dim=2 * self.z_dim)

        rnn_2 = load_pretrained(RNNDecoder(in_dim=44),
                                '../saved_models/rnn_dec.ckpt')
        dec_x = FinetunedDecoder(rnn_2, in_dim=2 * self.z_dim)
        self.dec_x_cond = ConditionedDecoder(
            dec=dec_x, cond=ExprDiffEncoder(out_dim=self.z_dim))
    elif self.dataset == 'lincs_rnn_reverse':
        self.z_dim = 10

        self.loss_rec_lambda_x = 1
        self.loss_rec_lambda_y = 0.2
        self.beta = 1

        rnn_1 = load_pretrained(RNNEncoder(out_dim=88),
                                '../saved_models/rnn_enc.ckpt')
        # NOTE(review): the same rnn_1 instance is passed to both
        # FinetunedEncoder wrappers below, so enc_y and enc_xy presumably
        # share its weights -- confirm this is intended.
        self.enc_y = FinetunedEncoder(rnn_1, out_dim=2 * self.z_dim)
        self.enc_xy = JointEncoder(
            ExprDiffEncoder(out_dim=2 * self.z_dim),
            FinetunedEncoder(rnn_1, out_dim=2 * self.z_dim),
            out_dim=2 * self.z_dim)

        # Here the conditioning network is a second, separately loaded copy
        # of the pretrained RNN *encoder*; the decoder itself is not
        # pretrained.
        rnn_2 = load_pretrained(RNNEncoder(out_dim=88),
                                '../saved_models/rnn_enc.ckpt')
        self.dec_x_cond = ConditionedDecoder(
            dec=ExprDiffDecoder(in_dim=2 * self.z_dim),
            cond=FinetunedEncoder(rnn_2, out_dim=self.z_dim))
    else:
        # Fail here rather than later with an AttributeError on a model
        # that has no sub-networks.
        raise ValueError(
            "unsupported dataset %r; expected 'paired_mnist', 'lincs_rnn' "
            "or 'lincs_rnn_reverse'" % (dataset,))