Example #1
class RNN_VAE(pl.LightningModule):  # class header as shown in Example #4
    def __init__(self):
        super(RNN_VAE, self).__init__()
        self.z_dim = 44

        self.enc = RNNEncoder(out_dim=2 * self.z_dim)
        self.dec = RNNDecoder(in_dim=self.z_dim)

        self.beta = 0.01
Example #2
class BiAAE(pl.LightningModule):  # base class not shown in this snippet; assumed to match the RNN_VAE example
    def __init__(self, dataset='paired_mnist'):
        super(BiAAE, self).__init__()
        self.dataset = dataset

        if self.dataset == 'paired_mnist':
            self.z_dim = 16
            self.joint_dim = 4

            self.loss_rec_lambda_x = 10
            self.loss_rec_lambda_y = 10

            self.loss_normal_lambda = 0.3
            self.loss_indep_lambda_x = 1
            self.loss_indep_lambda_y = 1
            self.loss_mse_lambda = 0.1

            self.discr_steps = 1
            self.gen_steps = 1

            self.enc_x = MnistCNNEncoder(out_dim=self.z_dim)
            self.enc_y = MnistCNNEncoder(out_dim=self.z_dim)

            self.dec_x = MnistCNNDecoder(in_dim=self.z_dim)
            self.dec_y = MnistCNNDecoder(in_dim=self.z_dim)

            self.discr = FCDiscriminator(in_dim=2 * self.z_dim -
                                         self.joint_dim,
                                         use_sigmoid=False)

            self.discr_indep_x = FCDiscriminator(in_dim=2 * self.z_dim -
                                                 self.joint_dim,
                                                 use_sigmoid=False)

            self.discr_indep_y = FCDiscriminator(in_dim=2 * self.z_dim -
                                                 self.joint_dim,
                                                 use_sigmoid=False)

        elif self.dataset == 'lincs_rnn':
            self.z_dim = 20
            self.joint_dim = 10

            self.loss_rec_lambda_x = 5
            self.loss_rec_lambda_y = 1

            self.loss_normal_lambda = 0.5
            self.loss_indep_lambda_x = 0.5
            self.loss_indep_lambda_y = 0.5
            self.loss_mse_lambda = 0.5

            self.discr_steps = 3
            self.gen_steps = 1

            rnn_1 = RNNEncoder(out_dim=88)
            rnn_1.load_state_dict(
                torch.load('../saved_models/rnn_enc.ckpt',
                           map_location='cuda:1'))
            self.enc_x = FinetunedEncoder(rnn_1, out_dim=self.z_dim)
            self.enc_y = ExprDiffEncoder(out_dim=self.z_dim)

            rnn_2 = RNNDecoder(in_dim=44)
            rnn_2.load_state_dict(
                torch.load('../saved_models/rnn_dec.ckpt',
                           map_location='cuda:1'))
            self.dec_x = FinetunedDecoder(rnn_2, in_dim=self.z_dim)

            self.dec_y = ExprDiffDecoder(in_dim=self.z_dim)

            self.discr = FCDiscriminator(in_dim=2 * self.z_dim -
                                         self.joint_dim,
                                         use_sigmoid=False)

            self.discr_indep_x = FCDiscriminator(in_dim=2 * self.z_dim -
                                                 self.joint_dim,
                                                 use_sigmoid=False)

            self.discr_indep_y = FCDiscriminator(in_dim=2 * self.z_dim -
                                                 self.joint_dim,
                                                 use_sigmoid=False)
        elif self.dataset == 'lincs_rnn_reverse':
            self.z_dim = 20
            self.joint_dim = 10

            self.loss_rec_lambda_x = 1
            self.loss_rec_lambda_y = 0.2

            self.loss_normal_lambda = 0.5
            self.loss_indep_lambda_x = 0.5
            self.loss_indep_lambda_y = 0.5
            self.loss_mse_lambda = 0.5

            self.discr_steps = 3
            self.gen_steps = 1

            rnn_1 = RNNEncoder(out_dim=88)
            rnn_1.load_state_dict(
                torch.load('../saved_models/rnn_enc.ckpt',
                           map_location='cuda:1'))
            self.enc_y = FinetunedEncoder(rnn_1, out_dim=self.z_dim)
            self.enc_x = ExprDiffEncoder(out_dim=self.z_dim)

            rnn_2 = RNNDecoder(in_dim=44)
            rnn_2.load_state_dict(
                torch.load('../saved_models/rnn_dec.ckpt',
                           map_location='cuda:1'))
            self.dec_y = FinetunedDecoder(rnn_2, in_dim=self.z_dim)

            self.dec_x = ExprDiffDecoder(in_dim=self.z_dim)

            self.discr = FCDiscriminator(in_dim=2 * self.z_dim -
                                         self.joint_dim,
                                         use_sigmoid=False)

            self.discr_indep_x = FCDiscriminator(in_dim=2 * self.z_dim -
                                                 self.joint_dim,
                                                 use_sigmoid=False)

            self.discr_indep_y = FCDiscriminator(in_dim=2 * self.z_dim -
                                                 self.joint_dim,
                                                 use_sigmoid=False)
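The discriminator input size of 2 * self.z_dim - self.joint_dim corresponds to concatenating the exclusive part of one latent code, the shared joint_dim block, and the exclusive part of the other code. Below is a minimal dimension check only; the convention that the last joint_dim units carry the shared part is an assumption, since the actual split is not shown in this snippet.

import torch

# Hypothetical helper illustrating the dimension arithmetic only; the real
# split used in BiAAE's training step is not part of this snippet.
def discriminator_input(z_x, z_y, joint_dim):
    x_excl, x_joint = z_x[:, :-joint_dim], z_x[:, -joint_dim:]
    y_excl = z_y[:, :-joint_dim]
    return torch.cat([x_excl, x_joint, y_excl], dim=-1)

z_x = torch.randn(8, 16)   # z_dim = 16 for 'paired_mnist'
z_y = torch.randn(8, 16)
print(discriminator_input(z_x, z_y, joint_dim=4).shape)  # (8, 28) = 2*16 - 4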
Example #3
class VCCA(pl.LightningModule):  # base class not shown in this snippet; assumed to match the RNN_VAE example
    def __init__(self, dataset='paired_mnist'):
        super(VCCA, self).__init__()
        self.dataset = dataset

        if self.dataset == 'paired_mnist':
            self.z_dim = 16
            self.joint_dim = 4

            self.enc_x = MnistCNNEncoder(out_dim=2 *
                                         (self.z_dim - self.joint_dim))
            self.enc_y = MnistCNNEncoder(out_dim=2 * self.z_dim)

            self.loss_rec_lambda_x = 10
            self.loss_rec_lambda_y = 10

            self.beta = 0.01

            self.dec_x = MnistCNNDecoder(in_dim=self.z_dim)
            self.dec_y = MnistCNNDecoder(in_dim=self.z_dim)
        elif self.dataset == 'lincs_rnn':
            self.z_dim = 20
            self.joint_dim = 10

            rnn_1 = RNNEncoder(out_dim=88)
            rnn_1.load_state_dict(
                torch.load('../saved_models/rnn_enc.ckpt',
                           map_location='cuda:1'))
            self.enc_x = FinetunedEncoder(rnn_1,
                                          out_dim=2 *
                                          (self.z_dim - self.joint_dim))

            self.enc_y = ExprDiffEncoder(out_dim=2 * self.z_dim)

            self.loss_rec_lambda_x = 5
            self.loss_rec_lambda_y = 1

            self.beta = 1

            rnn_2 = RNNDecoder(in_dim=44)
            rnn_2.load_state_dict(
                torch.load('../saved_models/rnn_dec.ckpt',
                           map_location='cuda:1'))
            self.dec_x = FinetunedDecoder(rnn_2, in_dim=self.z_dim)
            self.dec_y = ExprDiffDecoder(in_dim=self.z_dim)
        elif self.dataset == 'lincs_rnn_reverse':
            self.z_dim = 20
            self.joint_dim = 10

            rnn_1 = RNNEncoder(out_dim=88)
            rnn_1.load_state_dict(
                torch.load('../saved_models/rnn_enc.ckpt',
                           map_location='cuda:1'))
            self.enc_y = FinetunedEncoder(rnn_1, out_dim=2 * self.z_dim)

            self.enc_x = ExprDiffEncoder(out_dim=2 *
                                         (self.z_dim - self.joint_dim))

            self.loss_rec_lambda_x = 1
            self.loss_rec_lambda_y = 0.2

            self.beta = 1

            rnn_2 = RNNDecoder(in_dim=44)
            rnn_2.load_state_dict(
                torch.load('../saved_models/rnn_dec.ckpt',
                           map_location='cuda:1'))
            self.dec_y = FinetunedDecoder(rnn_2, in_dim=self.z_dim)
            self.dec_x = ExprDiffDecoder(in_dim=self.z_dim)
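In this VCCA variant the encoder output sizes encode two things: the factor 2 covers the (means, log-variance) pair of a diagonal Gaussian, and enc_x parameterizes only the private z_dim - joint_dim coordinates while enc_y parameterizes the full z_dim. A small shape check under the 'paired_mnist' settings above; the split order into means and log-variances is an assumption.

import torch

# Shapes only; the wiring of the shared block is not shown in this snippet.
z_dim, joint_dim = 16, 4
enc_x_out = torch.randn(8, 2 * (z_dim - joint_dim))   # (8, 24)
enc_y_out = torch.randn(8, 2 * z_dim)                 # (8, 32)

mu_x, logvar_x = torch.split(enc_x_out, z_dim - joint_dim, dim=-1)
mu_y, logvar_y = torch.split(enc_y_out, z_dim, dim=-1)
print(mu_x.shape, mu_y.shape)                         # (8, 12) and (8, 16)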
Example #4
import numpy as np
import torch
from torch.utils.data import DataLoader

import pytorch_lightning as pl
from rdkit import Chem

# Project-local modules (RNNEncoder, RNNDecoder, MolecularDataset) are assumed
# to be importable from the surrounding repository.


class RNN_VAE(pl.LightningModule):
    def __init__(self):
        super(RNN_VAE, self).__init__()
        self.z_dim = 44

        self.enc = RNNEncoder(out_dim=2 * self.z_dim)
        self.dec = RNNDecoder(in_dim=self.z_dim)

        self.beta = 0.01

    @staticmethod
    def sample_repar_z(means, logvar):
        # reparameterization trick: z = mu + sigma * eps, eps ~ N(0, I)
        return means + torch.randn_like(means) * torch.exp(0.5 * logvar)

    @staticmethod
    def kl_div(means_q, logvar_q, means_p=None, logvar_p=None):
        # KL(q || p) between diagonal Gaussians, averaged over the batch
        if means_p is None:  # prior is N(0, I)
            return -0.5 * torch.mean(
                torch.sum(1 + logvar_q - means_q.pow(2) - logvar_q.exp(),
                          dim=-1))
        else:
            return -0.5 * torch.mean(
                torch.sum(1 + logvar_q - logvar_p -
                          ((means_q - means_p).pow(2) + logvar_q.exp()) *
                          (-logvar_p).exp(),
                          dim=-1))

    def training_step(self, batch, batch_nb):
        # pair of objects
        x, _ = batch

        # compute proposal distributions
        p_z_means, p_z_logvar = torch.split(self.enc(x), self.z_dim, -1)

        # sample z
        z_sample = self.sample_repar_z(p_z_means, p_z_logvar)

        # compute kl divergence
        kl = self.kl_div(p_z_means, p_z_logvar)

        # compute reconstruction loss
        x_by_z_logprob = self.dec.get_log_prob(x=x, z=z_sample).mean()

        loss = (-x_by_z_logprob + self.beta * kl)

        return {
            'loss': loss,
            'log': {
                'x_by_z_logprob': x_by_z_logprob,
                'kl': kl
            }
        }

    def validation_step(self, batch, batch_nb):
        x, _ = batch

        # compute proposal distributions
        p_z_means, p_z_logvar = torch.split(self.enc(x), self.z_dim, -1)

        # sample z
        z_sample = self.sample_repar_z(p_z_means, p_z_logvar)

        sampled_sm = self.dec.sample(z_sample)

        valid_proc = len([
            s for s in sampled_sm if Chem.MolFromSmiles(s) is not None
        ]) / len(sampled_sm)
        unique_proc = len(np.unique(sampled_sm)) / len(sampled_sm)
        eq_proc = len([
            s for (x_ob, s) in zip(x, sampled_sm)
            if (Chem.MolFromSmiles(s) is not None) and
            (Chem.MolToSmiles(Chem.MolFromSmiles(s))) == x_ob
        ]) / len(sampled_sm)

        return {'valid': valid_proc, 'unique': unique_proc, 'equal': eq_proc}

    def validation_end(self, outputs):
        val_stats = {}

        val_stats['val_valid'] = np.array([x['valid'] for x in outputs]).mean()
        val_stats['val_unique'] = np.array([x['unique']
                                            for x in outputs]).mean()
        val_stats['val_equal'] = np.array([x['equal'] for x in outputs]).mean()

        return {'log': val_stats}

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=3e-4)

    @pl.data_loader
    def train_dataloader(self):
        return DataLoader(MolecularDataset('../data/train.txt', train=True),
                          batch_size=512,
                          shuffle=True,
                          num_workers=10)

    @pl.data_loader
    def val_dataloader(self):
        return DataLoader(MolecularDataset('../data/train.txt', train=False),
                          batch_size=512,
                          shuffle=True,
                          num_workers=10)
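For context, a minimal way to run Example #4. Trainer options are left at defaults because the exact flags depend on the (older) pytorch-lightning version implied by validation_end and @pl.data_loader.

import pytorch_lightning as pl

# Minimal usage sketch; the dataloaders are provided by the module itself.
model = RNN_VAE()
trainer = pl.Trainer()
trainer.fit(model)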
Example #5
class CVAE(pl.LightningModule):  # base class not shown in this snippet; assumed to match the RNN_VAE example
    def __init__(self, dataset='paired_mnist'):
        super(CVAE, self).__init__()
        self.dataset = dataset

        if self.dataset == 'paired_mnist':
            self.z_dim = 8

            self.loss_rec_lambda_x = 10
            self.loss_rec_lambda_y = 10

            self.beta = 0.01

            self.enc_y = MnistCNNEncoder(out_dim=2 * self.z_dim)

            self.enc_xy = JointEncoder(MnistCNNEncoder(out_dim=2 * self.z_dim,
                                                       short_tail=True),
                                       MnistCNNEncoder(out_dim=2 * self.z_dim,
                                                       short_tail=True),
                                       out_dim=2 * self.z_dim)

            self.dec_x_cond = ConditionedDecoder(
                dec=MnistCNNDecoder(in_dim=self.z_dim + self.z_dim),
                cond=MnistCNNEncoder(out_dim=self.z_dim, short_tail=True))
        elif self.dataset == 'lincs_rnn':
            self.z_dim = 10

            self.loss_rec_lambda_x = 5
            self.loss_rec_lambda_y = 1

            self.beta = 1

            rnn_1 = RNNEncoder(out_dim=88)
            rnn_1.load_state_dict(
                torch.load('../saved_models/rnn_enc.ckpt',
                           map_location='cuda:1'))
            enc_x = FinetunedEncoder(rnn_1, out_dim=2 * self.z_dim)

            self.enc_xy = JointEncoder(enc_x,
                                       ExprDiffEncoder(out_dim=2 * self.z_dim),
                                       out_dim=2 * self.z_dim)

            self.enc_y = ExprDiffEncoder(out_dim=2 * self.z_dim)

            rnn_2 = RNNDecoder(in_dim=44)
            rnn_2.load_state_dict(
                torch.load('../saved_models/rnn_dec.ckpt',
                           map_location='cuda:1'))
            dec_x = FinetunedDecoder(rnn_2, in_dim=2 * self.z_dim)
            self.dec_x_cond = ConditionedDecoder(
                dec=dec_x, cond=ExprDiffEncoder(out_dim=self.z_dim))

        elif self.dataset == 'lincs_rnn_reverse':
            self.z_dim = 10

            self.loss_rec_lambda_x = 1
            self.loss_rec_lambda_y = 0.2

            self.beta = 1

            rnn_1 = RNNEncoder(out_dim=88)
            rnn_1.load_state_dict(
                torch.load('../saved_models/rnn_enc.ckpt',
                           map_location='cuda:1'))
            self.enc_y = FinetunedEncoder(rnn_1, out_dim=2 * self.z_dim)

            self.enc_xy = JointEncoder(ExprDiffEncoder(out_dim=2 * self.z_dim),
                                       FinetunedEncoder(rnn_1,
                                                        out_dim=2 *
                                                        self.z_dim),
                                       out_dim=2 * self.z_dim)

            rnn_2 = RNNEncoder(out_dim=88)
            rnn_2.load_state_dict(
                torch.load('../saved_models/rnn_enc.ckpt',
                           map_location='cuda:1'))

            self.dec_x_cond = ConditionedDecoder(
                dec=ExprDiffDecoder(in_dim=2 * self.z_dim),
                cond=FinetunedEncoder(rnn_2, out_dim=self.z_dim))
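The CVAE training step is not part of this snippet. The sketch below shows one standard way the modules above could be wired, assuming enc_xy(x, y) returns the parameters of q(z | x, y), enc_y(y) returns those of the prior p(z | y), and dec_x_cond(z, y) decodes x conditioned on y; all three call signatures are assumptions.

import torch

# Hypothetical wiring only; interfaces of JointEncoder / ConditionedDecoder
# are assumed, not taken from this snippet.
def cvae_step(model, x, y):
    mu_q, logvar_q = torch.split(model.enc_xy(x, y), model.z_dim, dim=-1)
    mu_p, logvar_p = torch.split(model.enc_y(y), model.z_dim, dim=-1)

    # reparameterized sample from the posterior q(z | x, y)
    z = mu_q + torch.randn_like(mu_q) * torch.exp(0.5 * logvar_q)

    # KL(q(z | x, y) || p(z | y)) for diagonal Gaussians
    kl = -0.5 * torch.mean(
        torch.sum(1 + logvar_q - logvar_p -
                  ((mu_q - mu_p).pow(2) + logvar_q.exp()) * (-logvar_p).exp(),
                  dim=-1))

    x_rec = model.dec_x_cond(z, y)   # decode x conditioned on y (assumed signature)
    return x_rec, kl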