예제 #1
0
    def __init__(self, dataset='paired_mnist'):
        """Configure hyper-parameters and sub-networks for ``dataset``.

        Args:
            dataset: One of ``'paired_mnist'``, ``'lincs_rnn'`` or
                ``'lincs_rnn_reverse'``. It selects the latent sizes, the
                loss weights, the discriminator/generator step counts and
                the encoder/decoder classes.

        Raises:
            ValueError: If ``dataset`` is not one of the supported names.
        """
        super(BiAAE, self).__init__()
        self.dataset = dataset

        def load_pretrained(module, ckpt_path):
            # Deserialize on the CPU: load_state_dict copies the values into
            # the parameters of the freshly built module, so nothing is gained
            # by staging the checkpoint on 'cuda:1', and hosts without a
            # second GPU can still build the model.
            module.load_state_dict(torch.load(ckpt_path, map_location='cpu'))
            return module

        if self.dataset == 'paired_mnist':
            self.z_dim = 16
            self.joint_dim = 4

            self.loss_rec_lambda_x = 10
            self.loss_rec_lambda_y = 10

            self.loss_normal_lambda = 0.3
            self.loss_indep_lambda_x = 1
            self.loss_indep_lambda_y = 1
            self.loss_mse_lambda = 0.1

            self.discr_steps = 1
            self.gen_steps = 1

            self.enc_x = MnistCNNEncoder(out_dim=self.z_dim)
            self.enc_y = MnistCNNEncoder(out_dim=self.z_dim)

            self.dec_x = MnistCNNDecoder(in_dim=self.z_dim)
            self.dec_y = MnistCNNDecoder(in_dim=self.z_dim)

        elif self.dataset == 'lincs_rnn':
            self.z_dim = 20
            self.joint_dim = 10

            self.loss_rec_lambda_x = 5
            self.loss_rec_lambda_y = 1

            self.loss_normal_lambda = 0.5
            self.loss_indep_lambda_x = 0.5
            self.loss_indep_lambda_y = 0.5
            self.loss_mse_lambda = 0.5

            self.discr_steps = 3
            self.gen_steps = 1

            # x goes through the pretrained RNN pair, y through ExprDiff nets.
            rnn_1 = load_pretrained(RNNEncoder(out_dim=88),
                                    '../saved_models/rnn_enc.ckpt')
            self.enc_x = FinetunedEncoder(rnn_1, out_dim=self.z_dim)
            self.enc_y = ExprDiffEncoder(out_dim=self.z_dim)

            rnn_2 = load_pretrained(RNNDecoder(in_dim=44),
                                    '../saved_models/rnn_dec.ckpt')
            self.dec_x = FinetunedDecoder(rnn_2, in_dim=self.z_dim)

            self.dec_y = ExprDiffDecoder(in_dim=self.z_dim)

        elif self.dataset == 'lincs_rnn_reverse':
            self.z_dim = 20
            self.joint_dim = 10

            self.loss_rec_lambda_x = 1
            self.loss_rec_lambda_y = 0.2

            self.loss_normal_lambda = 0.5
            self.loss_indep_lambda_x = 0.5
            self.loss_indep_lambda_y = 0.5
            self.loss_mse_lambda = 0.5

            self.discr_steps = 3
            self.gen_steps = 1

            # Same networks as 'lincs_rnn' with the x and y roles swapped.
            rnn_1 = load_pretrained(RNNEncoder(out_dim=88),
                                    '../saved_models/rnn_enc.ckpt')
            self.enc_y = FinetunedEncoder(rnn_1, out_dim=self.z_dim)
            self.enc_x = ExprDiffEncoder(out_dim=self.z_dim)

            rnn_2 = load_pretrained(RNNDecoder(in_dim=44),
                                    '../saved_models/rnn_dec.ckpt')
            self.dec_y = FinetunedDecoder(rnn_2, in_dim=self.z_dim)

            self.dec_x = ExprDiffDecoder(in_dim=self.z_dim)

        else:
            # Fail here instead of returning a module without any sub-network.
            raise ValueError(
                "Unsupported dataset {!r}; expected 'paired_mnist', "
                "'lincs_rnn' or 'lincs_rnn_reverse'".format(self.dataset))

        # All three discriminators have the same input width in every
        # configuration, so they are built once, after the encoders and
        # decoders, in the same order as before.
        discr_in_dim = 2 * self.z_dim - self.joint_dim
        self.discr = FCDiscriminator(in_dim=discr_in_dim, use_sigmoid=False)
        self.discr_indep_x = FCDiscriminator(in_dim=discr_in_dim,
                                             use_sigmoid=False)
        self.discr_indep_y = FCDiscriminator(in_dim=discr_in_dim,
                                             use_sigmoid=False)
예제 #2
0
class BiAAE(pl.LightningModule):
    """Adversarial autoencoder over paired objects ``(x, y)``.

    Each encoder output (``z_dim`` wide) is split into an exclusive part
    (``z_x`` / ``z_y``, ``z_dim - joint_dim`` wide) and a shared part
    (``s_x`` / ``s_y``, ``joint_dim`` wide). Three discriminators operate on
    the concatenation ``(z, s, z)``; training alternates between a generator
    optimizer (encoders and decoders) and a discriminator optimizer.
    """

    def __init__(self, dataset='paired_mnist'):
        """Configure hyper-parameters and sub-networks for ``dataset``.

        Args:
            dataset: One of ``'paired_mnist'``, ``'lincs_rnn'`` or
                ``'lincs_rnn_reverse'``. It selects the latent sizes, the
                loss weights, the discriminator/generator step counts and
                the encoder/decoder classes.

        Raises:
            ValueError: If ``dataset`` is not one of the supported names.
        """
        super(BiAAE, self).__init__()
        self.dataset = dataset

        def load_pretrained(module, ckpt_path):
            # Deserialize on the CPU: load_state_dict copies the values into
            # the parameters of the freshly built module, so nothing is gained
            # by staging the checkpoint on 'cuda:1', and hosts without a
            # second GPU can still build the model.
            module.load_state_dict(torch.load(ckpt_path, map_location='cpu'))
            return module

        if self.dataset == 'paired_mnist':
            self.z_dim = 16
            self.joint_dim = 4

            self.loss_rec_lambda_x = 10
            self.loss_rec_lambda_y = 10

            self.loss_normal_lambda = 0.3
            self.loss_indep_lambda_x = 1
            self.loss_indep_lambda_y = 1
            self.loss_mse_lambda = 0.1

            self.discr_steps = 1
            self.gen_steps = 1

            self.enc_x = MnistCNNEncoder(out_dim=self.z_dim)
            self.enc_y = MnistCNNEncoder(out_dim=self.z_dim)

            self.dec_x = MnistCNNDecoder(in_dim=self.z_dim)
            self.dec_y = MnistCNNDecoder(in_dim=self.z_dim)

        elif self.dataset == 'lincs_rnn':
            self.z_dim = 20
            self.joint_dim = 10

            self.loss_rec_lambda_x = 5
            self.loss_rec_lambda_y = 1

            self.loss_normal_lambda = 0.5
            self.loss_indep_lambda_x = 0.5
            self.loss_indep_lambda_y = 0.5
            self.loss_mse_lambda = 0.5

            self.discr_steps = 3
            self.gen_steps = 1

            # x goes through the pretrained RNN pair, y through ExprDiff nets.
            rnn_1 = load_pretrained(RNNEncoder(out_dim=88),
                                    '../saved_models/rnn_enc.ckpt')
            self.enc_x = FinetunedEncoder(rnn_1, out_dim=self.z_dim)
            self.enc_y = ExprDiffEncoder(out_dim=self.z_dim)

            rnn_2 = load_pretrained(RNNDecoder(in_dim=44),
                                    '../saved_models/rnn_dec.ckpt')
            self.dec_x = FinetunedDecoder(rnn_2, in_dim=self.z_dim)

            self.dec_y = ExprDiffDecoder(in_dim=self.z_dim)

        elif self.dataset == 'lincs_rnn_reverse':
            self.z_dim = 20
            self.joint_dim = 10

            self.loss_rec_lambda_x = 1
            self.loss_rec_lambda_y = 0.2

            self.loss_normal_lambda = 0.5
            self.loss_indep_lambda_x = 0.5
            self.loss_indep_lambda_y = 0.5
            self.loss_mse_lambda = 0.5

            self.discr_steps = 3
            self.gen_steps = 1

            # Same networks as 'lincs_rnn' with the x and y roles swapped.
            rnn_1 = load_pretrained(RNNEncoder(out_dim=88),
                                    '../saved_models/rnn_enc.ckpt')
            self.enc_y = FinetunedEncoder(rnn_1, out_dim=self.z_dim)
            self.enc_x = ExprDiffEncoder(out_dim=self.z_dim)

            rnn_2 = load_pretrained(RNNDecoder(in_dim=44),
                                    '../saved_models/rnn_dec.ckpt')
            self.dec_y = FinetunedDecoder(rnn_2, in_dim=self.z_dim)

            self.dec_x = ExprDiffDecoder(in_dim=self.z_dim)

        else:
            # Fail here instead of returning a module without any sub-network.
            raise ValueError(
                "Unsupported dataset {!r}; expected 'paired_mnist', "
                "'lincs_rnn' or 'lincs_rnn_reverse'".format(self.dataset))

        # Width of the concatenated (z_x, s, z_y) vector that training_step
        # feeds to every discriminator. Built once, after the encoders and
        # decoders, in the same order as before.
        discr_in_dim = 2 * self.z_dim - self.joint_dim
        self.discr = FCDiscriminator(in_dim=discr_in_dim, use_sigmoid=False)
        self.discr_indep_x = FCDiscriminator(in_dim=discr_in_dim,
                                             use_sigmoid=False)
        self.discr_indep_y = FCDiscriminator(in_dim=discr_in_dim,
                                             use_sigmoid=False)

    def _split_code(self, code):
        """Split an encoder output into its (exclusive, shared) parts.

        Explicit section sizes are used so that exactly two tensors come back
        for any ``joint_dim``; splitting by the single chunk size
        ``z_dim - joint_dim`` yields two parts only while
        ``joint_dim <= z_dim - joint_dim``.
        """
        return torch.split(code,
                           [self.z_dim - self.joint_dim, self.joint_dim], -1)

    # ------------------------------------------------------------------------
    #               TRAINING
    def get_latents(self, batch):
        """Return the code used to generate x from y.

        The exclusive part is standard normal noise shaped like ``z_y``; the
        shared part is the one the y-encoder produced. The x element of
        ``batch`` is unpacked but not used.
        """
        # pair of objects
        x, y = batch

        z_y, s_y = self._split_code(self.enc_y(y))
        z_x = torch.randn_like(z_y)

        return torch.cat((z_x, s_y), 1)

    def get_log_p_x_by_y(self, batch):
        """Return ``dec_x.get_log_prob`` of ``batch[0]`` under get_latents."""
        return self.dec_x.get_log_prob(batch[0], self.get_latents(batch))

    def restore(self, batch):
        """Reconstruct both objects of the pair from their own codes.

        Returns:
            Tuple ``(x_rest, y_rest)`` produced by ``dec_x.sample`` and
            ``dec_y.sample``.
        """
        # pair of objects
        x, y = batch

        # compute encoder outputs and split them into joint and exclusive parts
        z_x, s_x = self._split_code(self.enc_x(x))
        z_y, s_y = self._split_code(self.enc_y(y))

        x_rest = self.dec_x.sample(torch.cat((z_x, s_x), 1))
        y_rest = self.dec_y.sample(torch.cat((z_y, s_y), 1))

        return (x_rest, y_rest)

    def sample(self, y):
        """Sample x given y: noise exclusive part plus the shared part of y."""
        z_y, s_y = self._split_code(self.enc_y(y))
        z_x = torch.randn_like(z_y)

        sampled_x = self.dec_x.sample(z=torch.cat((z_x, s_y), 1))
        return sampled_x

    def sample_y(self, x):
        """Sample y given x: noise exclusive part plus the shared part of x."""
        z_x, s_x = self._split_code(self.enc_x(x))
        z_y = torch.randn_like(z_x)

        sampled_y = self.dec_y.sample(z=torch.cat((z_y, s_x), 1))
        return sampled_y

    def training_step(self, batch, batch_nb, optimizer_i):
        """Compute the loss for the optimizer selected by ``optimizer_i``.

        Args:
            batch: Pair ``(x, y)``.
            batch_nb: Batch index (unused here).
            optimizer_i: ``0`` for the generator loss, ``1`` for the
                discriminator loss.

        Returns:
            Dict with the scalar ``'loss'`` and a ``'log'`` dict of its
            components. Any other ``optimizer_i`` falls through and returns
            ``None``.
        """
        # pair of objects
        x, y = batch

        # compute encoder outputs and split them into joint and exclusive parts
        z_x, s_x = self._split_code(self.enc_x(x))
        z_y, s_y = self._split_code(self.enc_y(y))

        if optimizer_i == 0:  # GENERATOR LOSS
            bce = nn.BCEWithLogitsLoss()

            # Reconstruction losses: each object is scored with its own shared
            # part and with the shared part of its partner, then averaged.
            loss_x_rec = -(self.dec_x.get_log_prob(x, torch.cat(
                (z_x, s_y), 1)).mean() +
                           self.dec_x.get_log_prob(x, torch.cat(
                               (z_x, s_x), 1)).mean()) / 2

            loss_y_rec = -(self.dec_y.get_log_prob(y, torch.cat(
                (z_y, s_y), 1)).mean() +
                           self.dec_y.get_log_prob(y, torch.cat(
                               (z_y, s_x), 1)).mean()) / 2

            # compute mse between common parts
            loss_mse = torch.norm(s_x - s_y, p=2, dim=-1).mean()

            # run discriminators; the generator wants target 1, the label the
            # discriminator step assigns to normal noise / shuffled tuples
            discr_outputs = self.discr(
                torch.cat((torch.cat(
                    (z_x, s_x, z_y), dim=-1), torch.cat(
                        (z_x, s_y, z_y), dim=-1)),
                          dim=0))
            loss_norm = bce(discr_outputs, torch.ones_like(discr_outputs))

            # Only z_x receives gradient from the x-independence term.
            discr_outputs_x = self.discr_indep_x(
                torch.cat((z_x, s_y.detach(), z_y.detach()), dim=-1))
            loss_indep_x = bce(discr_outputs_x,
                               torch.ones_like(discr_outputs_x))

            # Only z_y receives gradient from the y-independence term.
            discr_outputs_y = self.discr_indep_y(
                torch.cat((z_x.detach(), s_x.detach(), z_y), dim=-1))
            loss_indep_y = bce(discr_outputs_y,
                               torch.ones_like(discr_outputs_y))

            g_loss = (loss_x_rec * self.loss_rec_lambda_x +
                      loss_y_rec * self.loss_rec_lambda_y +
                      loss_norm * self.loss_normal_lambda +
                      loss_indep_x * self.loss_indep_lambda_x +
                      loss_indep_y * self.loss_indep_lambda_y +
                      loss_mse * self.loss_mse_lambda)

            return {
                'loss': g_loss,
                'log': {
                    'loss_g': g_loss,
                    'x_rec': loss_x_rec,
                    'y_rec': loss_y_rec,
                    'loss_norm': loss_norm,
                    'loss_indep_x': loss_indep_x,
                    'loss_indep_y': loss_indep_y,
                    'loss_mse': loss_mse
                }
            }

        elif optimizer_i == 1:  # DISCRIMINATOR LOSS
            bce = nn.BCEWithLogitsLoss()

            # The encoders must not be updated by the discriminator loss.
            z_x = z_x.detach()
            s_x = s_x.detach()
            s_y = s_y.detach()
            z_y = z_y.detach()

            # normal noise loss: encoder codes get target 0, noise target 1
            real_inputs_1 = torch.cat((torch.cat(
                (z_x, s_x, z_y), dim=-1), torch.cat((z_x, s_y, z_y), dim=-1)),
                                      dim=0)
            real_dec_out_1 = self.discr(real_inputs_1)

            fake_inputs_1 = torch.randn_like(real_inputs_1)
            fake_dec_out_1 = self.discr(fake_inputs_1)

            probs_1 = torch.cat((real_dec_out_1, fake_dec_out_1), 0)
            targets_1 = torch.cat((torch.zeros_like(real_dec_out_1),
                                   torch.ones_like(fake_dec_out_1)), 0)

            d_loss_normal = bce(probs_1, targets_1)

            # independence loss for z_x: matched tuples get target 0, tuples
            # whose z_x rows were permuted across the batch get target 1
            real_inputs_2 = torch.cat((z_x, s_y, z_y), dim=-1)
            real_dec_out_2 = self.discr_indep_x(real_inputs_2)

            real_input_shuffled_2 = torch.cat(
                (z_x[np.random.permutation(z_x.shape[0])], s_y, z_y), dim=-1)
            fake_dec_out_2 = self.discr_indep_x(real_input_shuffled_2)

            probs_2 = torch.cat((real_dec_out_2, fake_dec_out_2), 0)
            targets_2 = torch.cat((torch.zeros_like(real_dec_out_2),
                                   torch.ones_like(fake_dec_out_2)), 0)

            d_loss_indep_x = bce(probs_2, targets_2)

            # independence loss for z_y, built the same way
            real_inputs_3 = torch.cat((z_x, s_x, z_y), dim=-1)
            real_dec_out_3 = self.discr_indep_y(real_inputs_3)

            real_input_shuffled_3 = torch.cat(
                (z_x, s_x, z_y[np.random.permutation(z_y.shape[0])]), dim=-1)
            fake_dec_out_3 = self.discr_indep_y(real_input_shuffled_3)

            probs_3 = torch.cat((real_dec_out_3, fake_dec_out_3), 0)
            targets_3 = torch.cat((torch.zeros_like(real_dec_out_3),
                                   torch.ones_like(fake_dec_out_3)), 0)

            d_loss_indep_y = bce(probs_3, targets_3)

            return {
                'loss': d_loss_normal + d_loss_indep_x + d_loss_indep_y,
                'log': {
                    'loss_d_normal': d_loss_normal,
                    'loss_d_indep_x': d_loss_indep_x,
                    'loss_d_indep_y': d_loss_indep_y
                }
            }

    def configure_optimizers(self):
        """Return ``[generator, discriminator]`` Adam optimizers.

        Index 0 updates the encoders and decoders, index 1 the three
        discriminators; a StepLR schedule is attached to the discriminator
        optimizer only.
        """
        gen_params = torch.nn.ModuleList(
            [self.enc_x, self.dec_x, self.enc_y, self.dec_y])
        discr_params = torch.nn.ModuleList(
            [self.discr_indep_x, self.discr_indep_y, self.discr])

        gen_optim = torch.optim.Adam(gen_params.parameters(),
                                     lr=3e-4,
                                     betas=(0.5, 0.9))
        discr_optim = torch.optim.Adam(discr_params.parameters(),
                                       lr=3e-4,
                                       betas=(0.5, 0.9))

        discriminator_sched = StepLR(discr_optim, step_size=1000, gamma=0.99)

        return [gen_optim, discr_optim], [discriminator_sched]

    def zero_grad(self):
        """Clear the gradients of every sub-network."""
        self.enc_x.zero_grad()
        self.dec_x.zero_grad()
        self.enc_y.zero_grad()
        self.dec_y.zero_grad()

        self.discr.zero_grad()
        self.discr_indep_x.zero_grad()
        self.discr_indep_y.zero_grad()

    def optimizer_step(self, current_epoch, batch_nb, optimizer, optimizer_i,
                       optimizer_closure):
        """Step only the optimizer whose turn it is, then clear gradients.

        Batches are grouped in cycles of ``discr_steps + gen_steps``: the
        first ``discr_steps`` batches of a cycle update the discriminators
        (optimizer 1), the remaining ones the generator (optimizer 0).
        Gradients are zeroed whether or not the optimizer stepped.
        """
        discr_step = (batch_nb % (self.discr_steps + self.gen_steps)) < \
                     self.discr_steps

        gen_step = (not discr_step)

        if optimizer_i == 0:
            if gen_step:
                optimizer.step()
            optimizer.zero_grad()
            self.zero_grad()

        if optimizer_i == 1:
            if discr_step:
                optimizer.step()
            optimizer.zero_grad()
            self.zero_grad()

        # Any additional optimizer is stepped on every batch.
        if optimizer_i > 1:
            optimizer.step()
            optimizer.zero_grad()
            self.zero_grad()
예제 #3
0
    def __init__(self, dataset='paired_mnist'):
        """Configure hyper-parameters and sub-networks for ``dataset``.

        ``enc_x`` emits ``2 * (z_dim - joint_dim)`` values and ``enc_y``
        emits ``2 * z_dim`` values; both decoders take ``z_dim`` inputs.

        Args:
            dataset: One of ``'paired_mnist'``, ``'lincs_rnn'`` or
                ``'lincs_rnn_reverse'``.

        Raises:
            ValueError: If ``dataset`` is not one of the supported names.
        """
        super(VCCA, self).__init__()
        self.dataset = dataset

        def load_pretrained(module, ckpt_path):
            # Deserialize on the CPU: load_state_dict copies the values into
            # the parameters of the freshly built module, so nothing is gained
            # by staging the checkpoint on 'cuda:1', and hosts without a
            # second GPU can still build the model.
            module.load_state_dict(torch.load(ckpt_path, map_location='cpu'))
            return module

        if self.dataset == 'paired_mnist':
            self.z_dim = 16
            self.joint_dim = 4

            self.enc_x = MnistCNNEncoder(out_dim=2 *
                                         (self.z_dim - self.joint_dim))
            self.enc_y = MnistCNNEncoder(out_dim=2 * self.z_dim)

            self.loss_rec_lambda_x = 10
            self.loss_rec_lambda_y = 10

            self.beta = 0.01

            self.dec_x = MnistCNNDecoder(in_dim=self.z_dim)
            self.dec_y = MnistCNNDecoder(in_dim=self.z_dim)
        elif self.dataset == 'lincs_rnn':
            self.z_dim = 20
            self.joint_dim = 10

            # x goes through the pretrained RNN pair, y through ExprDiff nets.
            rnn_1 = load_pretrained(RNNEncoder(out_dim=88),
                                    '../saved_models/rnn_enc.ckpt')
            self.enc_x = FinetunedEncoder(rnn_1,
                                          out_dim=2 *
                                          (self.z_dim - self.joint_dim))

            self.enc_y = ExprDiffEncoder(out_dim=2 * self.z_dim)

            self.loss_rec_lambda_x = 5
            self.loss_rec_lambda_y = 1

            self.beta = 1

            rnn_2 = load_pretrained(RNNDecoder(in_dim=44),
                                    '../saved_models/rnn_dec.ckpt')
            self.dec_x = FinetunedDecoder(rnn_2, in_dim=self.z_dim)
            self.dec_y = ExprDiffDecoder(in_dim=self.z_dim)
        elif self.dataset == 'lincs_rnn_reverse':
            self.z_dim = 20
            self.joint_dim = 10

            # Same networks as 'lincs_rnn' with the x and y roles swapped.
            rnn_1 = load_pretrained(RNNEncoder(out_dim=88),
                                    '../saved_models/rnn_enc.ckpt')
            self.enc_y = FinetunedEncoder(rnn_1, out_dim=2 * self.z_dim)

            self.enc_x = ExprDiffEncoder(out_dim=2 *
                                         (self.z_dim - self.joint_dim))

            self.loss_rec_lambda_x = 1
            self.loss_rec_lambda_y = 0.2

            self.beta = 1

            rnn_2 = load_pretrained(RNNDecoder(in_dim=44),
                                    '../saved_models/rnn_dec.ckpt')
            self.dec_y = FinetunedDecoder(rnn_2, in_dim=self.z_dim)
            self.dec_x = ExprDiffDecoder(in_dim=self.z_dim)
        else:
            # Fail here instead of returning a module without any sub-network.
            raise ValueError(
                "Unsupported dataset {!r}; expected 'paired_mnist', "
                "'lincs_rnn' or 'lincs_rnn_reverse'".format(self.dataset))
예제 #4
0
class VCCA(pl.LightningModule):
    """Variational model over paired objects ``(x, y)``.

    ``enc_x`` parameterizes a Gaussian over a private code of width
    ``z_dim - joint_dim``; ``enc_y`` parameterizes a Gaussian over ``z_dim``
    values that are split into a private part and a shared part of width
    ``joint_dim``. Both decoders consume ``z_dim`` values.
    """

    def __init__(self, dataset='paired_mnist'):
        """Configure hyper-parameters and sub-networks for ``dataset``.

        Args:
            dataset: One of ``'paired_mnist'``, ``'lincs_rnn'`` or
                ``'lincs_rnn_reverse'``.

        Raises:
            ValueError: If ``dataset`` is not one of the supported names.
        """
        super(VCCA, self).__init__()
        self.dataset = dataset

        def load_pretrained(module, ckpt_path):
            # Deserialize on the CPU: load_state_dict copies the values into
            # the parameters of the freshly built module, so nothing is gained
            # by staging the checkpoint on 'cuda:1', and hosts without a
            # second GPU can still build the model.
            module.load_state_dict(torch.load(ckpt_path, map_location='cpu'))
            return module

        if self.dataset == 'paired_mnist':
            self.z_dim = 16
            self.joint_dim = 4

            self.enc_x = MnistCNNEncoder(out_dim=2 *
                                         (self.z_dim - self.joint_dim))
            self.enc_y = MnistCNNEncoder(out_dim=2 * self.z_dim)

            self.loss_rec_lambda_x = 10
            self.loss_rec_lambda_y = 10

            self.beta = 0.01

            self.dec_x = MnistCNNDecoder(in_dim=self.z_dim)
            self.dec_y = MnistCNNDecoder(in_dim=self.z_dim)
        elif self.dataset == 'lincs_rnn':
            self.z_dim = 20
            self.joint_dim = 10

            # x goes through the pretrained RNN pair, y through ExprDiff nets.
            rnn_1 = load_pretrained(RNNEncoder(out_dim=88),
                                    '../saved_models/rnn_enc.ckpt')
            self.enc_x = FinetunedEncoder(rnn_1,
                                          out_dim=2 *
                                          (self.z_dim - self.joint_dim))

            self.enc_y = ExprDiffEncoder(out_dim=2 * self.z_dim)

            self.loss_rec_lambda_x = 5
            self.loss_rec_lambda_y = 1

            self.beta = 1

            rnn_2 = load_pretrained(RNNDecoder(in_dim=44),
                                    '../saved_models/rnn_dec.ckpt')
            self.dec_x = FinetunedDecoder(rnn_2, in_dim=self.z_dim)
            self.dec_y = ExprDiffDecoder(in_dim=self.z_dim)
        elif self.dataset == 'lincs_rnn_reverse':
            self.z_dim = 20
            self.joint_dim = 10

            # Same networks as 'lincs_rnn' with the x and y roles swapped.
            rnn_1 = load_pretrained(RNNEncoder(out_dim=88),
                                    '../saved_models/rnn_enc.ckpt')
            self.enc_y = FinetunedEncoder(rnn_1, out_dim=2 * self.z_dim)

            self.enc_x = ExprDiffEncoder(out_dim=2 *
                                         (self.z_dim - self.joint_dim))

            self.loss_rec_lambda_x = 1
            self.loss_rec_lambda_y = 0.2

            self.beta = 1

            rnn_2 = load_pretrained(RNNDecoder(in_dim=44),
                                    '../saved_models/rnn_dec.ckpt')
            self.dec_y = FinetunedDecoder(rnn_2, in_dim=self.z_dim)
            self.dec_x = ExprDiffDecoder(in_dim=self.z_dim)
        else:
            # Fail here instead of returning a module without any sub-network.
            raise ValueError(
                "Unsupported dataset {!r}; expected 'paired_mnist', "
                "'lincs_rnn' or 'lincs_rnn_reverse'".format(self.dataset))

    @staticmethod
    def sample_repar_z(means, logvar):
        """Draw ``means + eps * exp(0.5 * logvar)`` with ``eps ~ N(0, I)``."""
        return means + torch.randn_like(means) * torch.exp(0.5 * logvar)

    @staticmethod
    def kl_div(means_q, logvar_q, means_p=None, logvar_p=None):
        """Return KL(q || p) between diagonal Gaussians, averaged over rows.

        The divergence is summed over the last dimension and averaged over
        the remaining ones.

        Args:
            means_q: Means of q.
            logvar_q: Log-variances of q.
            means_p: Means of p. ``None`` selects the standard normal prior
                (``logvar_p`` is then ignored).
            logvar_p: Log-variances of p. ``None`` together with a given
                ``means_p`` is treated as unit variance.
        """
        if means_p is None:  # prior is N(0, I)
            return -0.5 * torch.mean(
                torch.sum(1 + logvar_q - means_q.pow(2) - logvar_q.exp(),
                          dim=-1))

        if logvar_p is None:
            # Previously this combination failed on ``1 - None``.
            logvar_p = torch.zeros_like(means_p)

        # The quadratic term is the distance between the two means; using
        # means_q alone would silently ignore means_p.
        return -0.5 * torch.mean(
            torch.sum(1 - logvar_p + logvar_q -
                      ((means_q - means_p).pow(2) + logvar_q.exp()) *
                      (-logvar_p).exp(),
                      dim=-1))

    def _split_private_shared(self, stats):
        """Split a ``z_dim``-wide tensor into (private, shared) parts.

        Explicit section sizes keep this a two-way split for any
        ``joint_dim``; a single chunk size of ``z_dim - joint_dim`` yields
        two parts only while ``joint_dim <= z_dim - joint_dim``.
        """
        return torch.split(stats,
                           [self.z_dim - self.joint_dim, self.joint_dim], -1)

    def _posteriors(self, x, y):
        """Run both encoders and return the Gaussian parameters.

        Returns:
            Three ``(means, logvar)`` pairs: private x code, private y code
            and shared code (the latter two both come from ``enc_y``).
        """
        z_x_means, z_x_logvar = torch.split(self.enc_x(x),
                                            self.z_dim - self.joint_dim, -1)
        zs_y_means, zs_y_logvar = torch.split(self.enc_y(y), self.z_dim, -1)

        z_y_means, s_y_means = self._split_private_shared(zs_y_means)
        z_y_logvar, s_y_logvar = self._split_private_shared(zs_y_logvar)

        return ((z_x_means, z_x_logvar), (z_y_means, z_y_logvar),
                (s_y_means, s_y_logvar))

    def get_latents(self, batch):
        """Sample the full ``z_dim`` code that ``enc_y`` assigns to y.

        The x element of ``batch`` is unpacked but not used.
        """
        # pair of objects
        x, y = batch

        # compute proposal distributions
        p_z_y_means, p_z_y_logvar = torch.split(self.enc_y(y), self.z_dim, -1)

        # sample z
        return self.sample_repar_z(p_z_y_means, p_z_y_logvar)

    def get_log_p_x_by_y(self, batch):
        """Return ``dec_x.get_log_prob`` of ``batch[0]`` under get_latents."""
        return self.dec_x.get_log_prob(batch[0], self.get_latents(batch))

    def restore(self, batch):
        """Reconstruct both objects of the pair.

        x is decoded from its private code and the shared code of y; y from
        its own private and shared codes.

        Returns:
            Tuple ``(rest_x, rest_y)``.
        """
        # pair of objects
        x, y = batch

        z_x_q, z_y_q, s_y_q = self._posteriors(x, y)

        # Sampling order (z_x, z_y, s_y) is kept so the random stream is
        # consumed as before.
        z_x_sample = self.sample_repar_z(*z_x_q)
        z_y_sample = self.sample_repar_z(*z_y_q)
        s_y_sample = self.sample_repar_z(*s_y_q)

        rest_x = self.dec_x.sample(torch.cat((z_x_sample, s_y_sample), -1))
        rest_y = self.dec_y.sample(torch.cat((z_y_sample, s_y_sample), -1))

        return (rest_x, rest_y)

    def sample(self, y):
        """Sample x from the full code that ``enc_y`` assigns to y."""
        # compute proposal distributions
        p_z_y_means, p_z_y_logvar = torch.split(self.enc_y(y), self.z_dim, -1)

        # sample z
        z_y_sample = self.sample_repar_z(p_z_y_means, p_z_y_logvar)

        sampled_x = self.dec_x.sample(z=z_y_sample)
        return sampled_x

    def sample_y(self, x):
        """Sample y given x.

        ``enc_x`` only produces the private code (``z_dim - joint_dim``
        wide), so its output is split by that width and the shared part is
        drawn from N(0, I) before decoding with ``dec_y``. The previous
        version split by ``z_dim`` and decoded with ``dec_x``, which does not
        match the sizes ``enc_x`` is built with.

        NOTE(review): no network maps x to the shared code, so filling it
        from the prior is an assumption - confirm the intended behavior.
        """
        # compute proposal distributions
        p_z_x_means, p_z_x_logvar = torch.split(self.enc_x(x),
                                                self.z_dim - self.joint_dim,
                                                -1)

        # sample z
        z_x_sample = self.sample_repar_z(p_z_x_means, p_z_x_logvar)
        s_prior = torch.randn((*z_x_sample.shape[:-1], self.joint_dim),
                              dtype=z_x_sample.dtype,
                              device=z_x_sample.device)

        sampled_y = self.dec_y.sample(torch.cat((z_x_sample, s_prior), -1))
        return sampled_y

    def training_step(self, batch, batch_nb):
        """Compute the weighted negative log-likelihood plus KL loss.

        ``loss = -(loss_rec_lambda_x * log p(x | z_x, s_y)
        + loss_rec_lambda_y * log p(y | z_y, s_y))
        + beta * (KL(z_x) + KL(z_y) + KL(s_y))``, where every KL term is
        taken against the standard normal prior.

        Returns:
            Dict with ``'loss'`` and a ``'log'`` dict of its components.
        """
        # pair of objects
        x, y = batch

        # compute proposal distributions
        z_x_q, z_y_q, s_y_q = self._posteriors(x, y)

        # sample z
        z_x_sample = self.sample_repar_z(*z_x_q)
        z_y_sample = self.sample_repar_z(*z_y_q)
        s_y_sample = self.sample_repar_z(*s_y_q)

        # compute kl divergence
        z_x_kl = self.kl_div(*z_x_q)
        z_y_kl = self.kl_div(*z_y_q)
        s_y_kl = self.kl_div(*s_y_q)

        # compute reconstruction loss
        x_z_logprob = self.dec_x.get_log_prob(
            x, torch.cat((z_x_sample, s_y_sample), -1))
        y_z_logprob = self.dec_y.get_log_prob(
            y, torch.cat((z_y_sample, s_y_sample), -1))

        loss = -(self.loss_rec_lambda_x * x_z_logprob + self.loss_rec_lambda_y
                 * y_z_logprob) + self.beta * (z_x_kl + z_y_kl + s_y_kl)

        return {
            'loss': loss,
            'log': {
                'x_rec': -x_z_logprob,
                'x_kl': z_x_kl,
                'y_rec': -y_z_logprob,
                'y_kl': z_y_kl,
                'common_kl': s_y_kl
            }
        }

    def configure_optimizers(self):
        """Return Adam over all parameters; the rate depends on the dataset."""
        if self.dataset == 'paired_mnist':
            return torch.optim.Adam(self.parameters(), lr=3e-4)
        elif 'lincs' in self.dataset:
            return torch.optim.Adam(self.parameters(), lr=1e-3)
예제 #5
0
    def __init__(self, dataset='paired_mnist'):
        """Configure hyper-parameters and sub-networks for ``dataset``.

        Builds ``enc_y``, a joint encoder ``enc_xy`` over both objects and a
        conditioned decoder ``dec_x_cond``.

        Args:
            dataset: One of ``'paired_mnist'``, ``'lincs_rnn'`` or
                ``'lincs_rnn_reverse'``.

        Raises:
            ValueError: If ``dataset`` is not one of the supported names.
        """
        super(CVAE, self).__init__()
        self.dataset = dataset

        def load_pretrained(module, ckpt_path):
            # Deserialize on the CPU: load_state_dict copies the values into
            # the parameters of the freshly built module, so nothing is gained
            # by staging the checkpoint on 'cuda:1', and hosts without a
            # second GPU can still build the model.
            module.load_state_dict(torch.load(ckpt_path, map_location='cpu'))
            return module

        if self.dataset == 'paired_mnist':
            self.z_dim = 8

            self.loss_rec_lambda_x = 10
            self.loss_rec_lambda_y = 10

            self.beta = 0.01

            self.enc_y = MnistCNNEncoder(out_dim=2 * self.z_dim)

            self.enc_xy = JointEncoder(MnistCNNEncoder(out_dim=2 * self.z_dim,
                                                       short_tail=True),
                                       MnistCNNEncoder(out_dim=2 * self.z_dim,
                                                       short_tail=True),
                                       out_dim=2 * self.z_dim)

            # The wrapped decoder takes 2 * z_dim inputs while the condition
            # encoder emits z_dim values.
            self.dec_x_cond = ConditionedDecoder(
                dec=MnistCNNDecoder(in_dim=self.z_dim + self.z_dim),
                cond=MnistCNNEncoder(out_dim=self.z_dim, short_tail=True))
        elif self.dataset == 'lincs_rnn':
            self.z_dim = 10

            self.loss_rec_lambda_x = 5
            self.loss_rec_lambda_y = 1

            self.beta = 1

            rnn_1 = load_pretrained(RNNEncoder(out_dim=88),
                                    '../saved_models/rnn_enc.ckpt')
            enc_x = FinetunedEncoder(rnn_1, out_dim=2 * self.z_dim)

            self.enc_xy = JointEncoder(enc_x,
                                       ExprDiffEncoder(out_dim=2 * self.z_dim),
                                       out_dim=2 * self.z_dim)

            self.enc_y = ExprDiffEncoder(out_dim=2 * self.z_dim)

            rnn_2 = load_pretrained(RNNDecoder(in_dim=44),
                                    '../saved_models/rnn_dec.ckpt')
            dec_x = FinetunedDecoder(rnn_2, in_dim=2 * self.z_dim)
            self.dec_x_cond = ConditionedDecoder(
                dec=dec_x, cond=ExprDiffEncoder(out_dim=self.z_dim))

        elif self.dataset == 'lincs_rnn_reverse':
            self.z_dim = 10

            self.loss_rec_lambda_x = 1
            self.loss_rec_lambda_y = 0.2

            self.beta = 1

            rnn_1 = load_pretrained(RNNEncoder(out_dim=88),
                                    '../saved_models/rnn_enc.ckpt')
            self.enc_y = FinetunedEncoder(rnn_1, out_dim=2 * self.z_dim)

            # NOTE(review): the same rnn_1 instance is handed to self.enc_y
            # and to the encoder inside self.enc_xy, whereas the condition
            # encoder below gets its own copy - confirm the sharing is
            # intended.
            self.enc_xy = JointEncoder(ExprDiffEncoder(out_dim=2 * self.z_dim),
                                       FinetunedEncoder(rnn_1,
                                                        out_dim=2 *
                                                        self.z_dim),
                                       out_dim=2 * self.z_dim)

            # Despite the name, rnn_2 is a second pretrained RNN *encoder*; it
            # backs the condition network of the decoder.
            rnn_2 = load_pretrained(RNNEncoder(out_dim=88),
                                    '../saved_models/rnn_enc.ckpt')

            self.dec_x_cond = ConditionedDecoder(
                dec=ExprDiffDecoder(in_dim=2 * self.z_dim),
                cond=FinetunedEncoder(rnn_2, out_dim=self.z_dim))
        else:
            # Fail here instead of returning a module without any sub-network.
            raise ValueError(
                "Unsupported dataset {!r}; expected 'paired_mnist', "
                "'lincs_rnn' or 'lincs_rnn_reverse'".format(self.dataset))
예제 #6
0
class Lat_SAAE(pl.LightningModule):
    def __init__(self, dataset='paired_mnist'):
        """Create hyper-parameters and sub-networks for the chosen dataset.

        For every supported dataset the two encoders each emit
        ``z_dim // 2`` features, ``dec_x`` consumes ``z_dim`` features,
        ``dec_y`` consumes ``z_dim // 2`` features and the discriminator
        ``discr`` operates on ``z_dim // 2`` features.

        Args:
            dataset: ``'paired_mnist'``, ``'lincs_rnn'`` or
                ``'lincs_rnn_reverse'``. For any other value only
                ``self.dataset`` is set; no hyper-parameters or
                sub-networks are created.
        """
        super(Lat_SAAE, self).__init__()
        self.dataset = dataset

        enc_ckpt = '../saved_models/rnn_enc.ckpt'
        dec_ckpt = '../saved_models/rnn_dec.ckpt'

        def load_pretrained(module, ckpt_path):
            # Deserialize onto the CPU rather than a hard-coded 'cuda:1':
            # load_state_dict copies the values into the module's own
            # parameters, so the checkpoint tensors do not have to live on a
            # particular GPU, and construction no longer fails on hosts with
            # fewer than two CUDA devices.
            module.load_state_dict(torch.load(ckpt_path, map_location='cpu'))
            return module

        if self.dataset == 'paired_mnist':
            self.z_dim = 8

            # weights of the reconstruction terms in the generator loss
            self.loss_rec_lambda_x = 10
            self.loss_rec_lambda_y = 10

            # weight of the adversarial (latent) term in the generator loss
            self.loss_latent_lambda = 2

            # lengths of the discriminator / generator phases used by
            # optimizer_step
            self.discr_steps = 1
            self.gen_steps = 1

            self.enc_x = MnistCNNEncoder(out_dim=self.z_dim // 2)
            self.enc_y = MnistCNNEncoder(out_dim=self.z_dim // 2)

            self.dec_x = MnistCNNDecoder(in_dim=self.z_dim)
            self.dec_y = MnistCNNDecoder(in_dim=self.z_dim // 2)

            self.discr = FCDiscriminator(in_dim=self.z_dim // 2,
                                         use_sigmoid=False)

        elif self.dataset == 'lincs_rnn':
            self.z_dim = 20

            self.loss_rec_lambda_x = 5
            self.loss_rec_lambda_y = 1
            self.loss_latent_lambda = 1

            self.discr_steps = 1
            self.gen_steps = 3

            # x side: pretrained RNN encoder/decoder wrapped for fine-tuning
            rnn_1 = load_pretrained(RNNEncoder(out_dim=88), enc_ckpt)
            self.enc_x = FinetunedEncoder(rnn_1, out_dim=self.z_dim // 2)

            self.enc_y = ExprDiffEncoder(out_dim=self.z_dim // 2)

            rnn_2 = load_pretrained(RNNDecoder(in_dim=44), dec_ckpt)
            self.dec_x = FinetunedDecoder(rnn_2, in_dim=self.z_dim)

            self.dec_y = ExprDiffDecoder(in_dim=self.z_dim // 2)

            self.discr = FCDiscriminator(in_dim=self.z_dim // 2,
                                         use_sigmoid=False)

        elif self.dataset == 'lincs_rnn_reverse':
            self.z_dim = 20

            self.loss_rec_lambda_x = 1
            self.loss_rec_lambda_y = 0.2
            self.loss_latent_lambda = 1

            self.discr_steps = 1
            self.gen_steps = 3

            # roles are swapped relative to 'lincs_rnn': the pretrained RNN
            # encoder/decoder now serve the y side
            rnn_1 = load_pretrained(RNNEncoder(out_dim=88), enc_ckpt)
            self.enc_y = FinetunedEncoder(rnn_1, out_dim=self.z_dim // 2)

            self.enc_x = ExprDiffEncoder(out_dim=self.z_dim // 2)

            rnn_2 = load_pretrained(RNNDecoder(in_dim=44), dec_ckpt)
            self.dec_y = FinetunedDecoder(rnn_2, in_dim=self.z_dim // 2)

            self.dec_x = ExprDiffDecoder(in_dim=self.z_dim)

            self.discr = FCDiscriminator(in_dim=self.z_dim // 2,
                                         use_sigmoid=False)

    # ------------------------------------------------------------------------
    #               TRAINING
    def get_latents(self, batch):
        """Build a decoder latent for ``batch`` using only its ``y`` element.

        ``y`` is encoded with ``enc_y``; the x part of the latent is not
        encoded from ``x`` but drawn from a standard normal with the same
        shape as the y code.

        Args:
            batch: a pair ``(x, y)``; ``x`` is ignored here.

        Returns:
            The noise and the y code concatenated along dimension 1
            (noise first).
        """
        _, y = batch

        code_y = self.enc_y(y)
        noise_x = torch.randn_like(code_y)

        return torch.cat([noise_x, code_y], dim=1)

    def get_log_p_x_by_y(self, batch):
        """Return ``dec_x.get_log_prob`` of ``batch[0]`` under the latents
        produced by ``get_latents(batch)``."""
        x = batch[0]
        latents = self.get_latents(batch)
        return self.dec_x.get_log_prob(x, latents)

    def restore(self, batch):
        """Encode a pair and sample reconstructions of both elements.

        Args:
            batch: a pair ``(x, y)``.

        Returns:
            Tuple ``(x_rest, y_rest)``: ``x_rest`` is sampled by ``dec_x``
            from the concatenated x and y codes, ``y_rest`` by ``dec_y`` from
            the y code alone.
        """
        x, y = batch

        # encode each element with its own encoder
        code_x = self.enc_x(x)
        code_y = self.enc_y(y)

        # x is decoded from both codes, y only from its own code
        joint_code = torch.cat([code_x, code_y], dim=1)
        x_rest = self.dec_x.sample(joint_code)
        y_rest = self.dec_y.sample(code_y)

        return x_rest, y_rest

    def sample(self, y):
        """Sample an ``x`` for the given ``y``.

        The y code from ``enc_y`` is paired with standard-normal noise of the
        same shape (noise first, concatenated along dimension 1) and the
        result is passed to ``dec_x.sample``.
        """
        code_y = self.enc_y(y)
        noise_x = torch.randn_like(code_y)

        latent = torch.cat([noise_x, code_y], dim=1)
        return self.dec_x.sample(z=latent)

    def training_step(self, batch, batch_nb, optimizer_i):
        """Compute the loss for the optimizer selected by ``optimizer_i``.

        Args:
            batch: a pair ``(x, y)``.
            batch_nb: batch index (unused here).
            optimizer_i: ``0`` for the generator loss, ``1`` for the
                discriminator loss.

        Returns:
            A dict with the scalar ``'loss'`` and a ``'log'`` dict of values
            to report, or ``None`` when ``optimizer_i`` is neither 0 nor 1.
        """
        x, y = batch

        # Both encoders run on every call, whichever optimizer is active.
        code_x = self.enc_x(x)
        code_y = self.enc_y(y)

        if optimizer_i == 0:  # generator (encoders + decoders)
            joint_code = torch.cat((code_x, code_y), 1)
            nll_x = -self.dec_x.get_log_prob(x, joint_code).mean()
            nll_y = -self.dec_y.get_log_prob(y, code_y).mean()

            # Adversarial term: push the discriminator's output on the
            # encoded x code towards the label it assigns to N(0, I) samples.
            criterion = nn.BCEWithLogitsLoss()
            logits = self.discr(code_x)
            fool_loss = criterion(logits, torch.ones_like(logits))

            total = (nll_x * self.loss_rec_lambda_x +
                     nll_y * self.loss_rec_lambda_y +
                     fool_loss * self.loss_latent_lambda)

            log_values = {
                'loss_g': total,
                'x_rec': nll_x,
                'y_rec': nll_y,
                'loss_norm': fool_loss
            }
            return {'loss': total, 'log': log_values}

        if optimizer_i == 1:  # discriminator
            # Detach so that this loss does not back-propagate into enc_x.
            encoded = code_x.detach()
            encoded_logits = self.discr(encoded)

            # Draws from N(0, I) with the same shape as the encoded codes.
            prior_draw = torch.randn_like(encoded)
            prior_logits = self.discr(prior_draw)

            # Encoded codes are labelled 0, prior draws are labelled 1.
            logits = torch.cat((encoded_logits, prior_logits), 0)
            labels = torch.cat((torch.zeros_like(encoded_logits),
                                torch.ones_like(prior_logits)), 0)

            criterion = nn.BCEWithLogitsLoss()
            critic_loss = criterion(logits, labels)

            return {'loss': critic_loss, 'log': {'loss_d_normal': critic_loss}}

        return None

    def configure_optimizers(self):
        """Create the generator and discriminator optimizers.

        Returns:
            ``([gen_optim, discr_optim], [scheduler])``: two Adam optimizers
            (index 0 covers both encoders and both decoders, index 1 covers
            ``discr``) and a ``StepLR`` schedule attached to the
            discriminator optimizer only.
        """
        adam_kwargs = {'lr': 3e-4, 'betas': (0.5, 0.9)}

        # Temporary container used only to gather the generator parameters.
        generator_modules = torch.nn.ModuleList(
            [self.enc_x, self.dec_x, self.enc_y, self.dec_y])

        generator_opt = torch.optim.Adam(generator_modules.parameters(),
                                         **adam_kwargs)
        critic_opt = torch.optim.Adam(self.discr.parameters(), **adam_kwargs)

        # Halve the discriminator learning rate every 5000 scheduler steps.
        critic_sched = StepLR(critic_opt, step_size=5000, gamma=0.5)

        return [generator_opt, critic_opt], [critic_sched]

    def zero_grad(self):
        """Call ``zero_grad()`` on each of the five sub-networks in turn."""
        for attr in ('enc_x', 'dec_x', 'enc_y', 'dec_y', 'discr'):
            getattr(self, attr).zero_grad()

    def optimizer_step(self, current_epoch, batch_nb, optimizer, optimizer_i,
                       optimizer_closure):
        """Step ``optimizer`` only during its phase, then clear gradients.

        Batches are grouped into cycles of ``discr_steps + gen_steps``; the
        first ``discr_steps`` batches of each cycle are discriminator turns
        and the remainder are generator turns. Optimizer 0 steps on generator
        turns, optimizer 1 on discriminator turns, and any optimizer with
        index above 1 steps on every batch. Gradients are cleared after every
        call for a handled index, whether or not a step was taken.

        Args:
            current_epoch: unused.
            batch_nb: index of the current batch, selects the phase.
            optimizer: the optimizer to (possibly) step.
            optimizer_i: index of ``optimizer``.
            optimizer_closure: unused.
        """
        cycle_len = self.discr_steps + self.gen_steps
        is_discr_turn = batch_nb % cycle_len < self.discr_steps

        if optimizer_i == 0:
            should_step = not is_discr_turn
        elif optimizer_i == 1:
            should_step = is_discr_turn
        elif optimizer_i > 1:
            should_step = True
        else:
            # no branch of the schedule applies to this index
            return

        if should_step:
            optimizer.step()
        optimizer.zero_grad()
        self.zero_grad()