def __init__(self):
            super().__init__(dataset='lincs_rnn')

            # SMILES encoder pretrained elsewhere; wrapped so it can be
            # finetuned down to this model's latent size.
            base_encoder = RNNEncoder(out_dim=88)
            checkpoint = torch.load('../saved_models/rnn_enc.ckpt', map_location='cuda:1')
            base_encoder.load_state_dict(checkpoint)
            self.mine_enc = FinetunedEncoder(base_encoder, out_dim=self.z_dim)
            # statistics network over a pair of latent codes (MINE critic)
            self.mine_fc = FCDiscriminator(in_dim=2 * self.z_dim)
# --- Example #2 ---
        def __init__(self):
            super().__init__(dataset='paired_mnist')

            # encoder and statistics network used by the MINE estimator
            self.mine_enc = MnistCNNEncoder(out_dim=self.z_dim)
            self.mine_fc = FCDiscriminator(in_dim=2 * self.z_dim)

            # frozen pretrained classifier, used only to score the quality
            # of generated digits (never trained here)
            clf = MNISTClassifier()
            clf.load_state_dict(
                torch.load('../saved_models/mnist_clf.ckpt'))
            clf.eval()
            clf.freeze()
            self.mnist_clf = clf
# --- Example #3 ---
    def __init__(self, dataset='paired_mnist'):
        """Set per-dataset hyperparameters and build all submodules.

        Supported datasets: 'paired_mnist', 'lincs_rnn', and
        'lincs_rnn_reverse' (the LINCS task with x and y swapped).
        """
        super(BiAAE, self).__init__()
        self.dataset = dataset

        def make_latent_discr():
            # critic over a concatenated latent code (z_x, s, z_y);
            # emits raw logits for use with BCEWithLogitsLoss
            return FCDiscriminator(in_dim=2 * self.z_dim - self.joint_dim,
                                   use_sigmoid=False)

        def load_pretrained(module, ckpt_path):
            # restore pretrained weights; checkpoints were saved on 'cuda:1'
            module.load_state_dict(
                torch.load(ckpt_path, map_location='cuda:1'))
            return module

        if self.dataset == 'paired_mnist':
            # total per-object latent size and the shared part of it
            self.z_dim = 16
            self.joint_dim = 4

            # reconstruction loss weights
            self.loss_rec_lambda_x = 10
            self.loss_rec_lambda_y = 10

            # adversarial / consistency loss weights
            self.loss_normal_lambda = 0.3
            self.loss_indep_lambda_x = 1
            self.loss_indep_lambda_y = 1
            self.loss_mse_lambda = 0.1

            # optimizer alternation schedule (see optimizer_step)
            self.discr_steps = 1
            self.gen_steps = 1

            self.enc_x = MnistCNNEncoder(out_dim=self.z_dim)
            self.enc_y = MnistCNNEncoder(out_dim=self.z_dim)
            self.dec_x = MnistCNNDecoder(in_dim=self.z_dim)
            self.dec_y = MnistCNNDecoder(in_dim=self.z_dim)

            self.discr = make_latent_discr()
            self.discr_indep_x = make_latent_discr()
            self.discr_indep_y = make_latent_discr()

        elif self.dataset == 'lincs_rnn':
            self.z_dim = 20
            self.joint_dim = 10

            self.loss_rec_lambda_x = 5
            self.loss_rec_lambda_y = 1

            self.loss_normal_lambda = 0.5
            self.loss_indep_lambda_x = 0.5
            self.loss_indep_lambda_y = 0.5
            self.loss_mse_lambda = 0.5

            self.discr_steps = 3
            self.gen_steps = 1

            # x: molecules (pretrained RNN), y: expression differences
            self.enc_x = FinetunedEncoder(
                load_pretrained(RNNEncoder(out_dim=88),
                                '../saved_models/rnn_enc.ckpt'),
                out_dim=self.z_dim)
            self.enc_y = ExprDiffEncoder(out_dim=self.z_dim)

            self.dec_x = FinetunedDecoder(
                load_pretrained(RNNDecoder(in_dim=44),
                                '../saved_models/rnn_dec.ckpt'),
                in_dim=self.z_dim)
            self.dec_y = ExprDiffDecoder(in_dim=self.z_dim)

            self.discr = make_latent_discr()
            self.discr_indep_x = make_latent_discr()
            self.discr_indep_y = make_latent_discr()

        elif self.dataset == 'lincs_rnn_reverse':
            self.z_dim = 20
            self.joint_dim = 10

            self.loss_rec_lambda_x = 1
            self.loss_rec_lambda_y = 0.2

            self.loss_normal_lambda = 0.5
            self.loss_indep_lambda_x = 0.5
            self.loss_indep_lambda_y = 0.5
            self.loss_mse_lambda = 0.5

            self.discr_steps = 3
            self.gen_steps = 1

            # same modules as 'lincs_rnn' with the roles of x and y swapped
            self.enc_y = FinetunedEncoder(
                load_pretrained(RNNEncoder(out_dim=88),
                                '../saved_models/rnn_enc.ckpt'),
                out_dim=self.z_dim)
            self.enc_x = ExprDiffEncoder(out_dim=self.z_dim)

            self.dec_y = FinetunedDecoder(
                load_pretrained(RNNDecoder(in_dim=44),
                                '../saved_models/rnn_dec.ckpt'),
                in_dim=self.z_dim)
            self.dec_x = ExprDiffDecoder(in_dim=self.z_dim)

            self.discr = make_latent_discr()
            self.discr_indep_x = make_latent_discr()
            self.discr_indep_y = make_latent_discr()
# --- Example #4 ---
class BiAAE(pl.LightningModule):
    """Bidirectional adversarial autoencoder over paired samples (x, y).

    Each object is encoded into an exclusive part ``z`` and a shared part
    ``s`` (sizes ``z_dim - joint_dim`` and ``joint_dim``).  Adversarial
    discriminators push the joint code towards N(0, I) and enforce
    independence between the exclusive parts.
    """

    def __init__(self, dataset='paired_mnist'):
        # Per-dataset hyperparameters and submodules.  Supported datasets:
        # 'paired_mnist', 'lincs_rnn', 'lincs_rnn_reverse' (x/y swapped).
        super(BiAAE, self).__init__()
        self.dataset = dataset

        if self.dataset == 'paired_mnist':
            # total per-object latent size and the shared part of it
            self.z_dim = 16
            self.joint_dim = 4

            # reconstruction loss weights
            self.loss_rec_lambda_x = 10
            self.loss_rec_lambda_y = 10

            # adversarial / consistency loss weights
            self.loss_normal_lambda = 0.3
            self.loss_indep_lambda_x = 1
            self.loss_indep_lambda_y = 1
            self.loss_mse_lambda = 0.1

            # optimizer alternation schedule (see optimizer_step)
            self.discr_steps = 1
            self.gen_steps = 1

            self.enc_x = MnistCNNEncoder(out_dim=self.z_dim)
            self.enc_y = MnistCNNEncoder(out_dim=self.z_dim)

            self.dec_x = MnistCNNDecoder(in_dim=self.z_dim)
            self.dec_y = MnistCNNDecoder(in_dim=self.z_dim)

            # critics over the concatenated code (z_x, s, z_y); raw logits
            self.discr = FCDiscriminator(in_dim=2 * self.z_dim -
                                         self.joint_dim,
                                         use_sigmoid=False)

            self.discr_indep_x = FCDiscriminator(in_dim=2 * self.z_dim -
                                                 self.joint_dim,
                                                 use_sigmoid=False)

            self.discr_indep_y = FCDiscriminator(in_dim=2 * self.z_dim -
                                                 self.joint_dim,
                                                 use_sigmoid=False)

        elif self.dataset == 'lincs_rnn':
            self.z_dim = 20
            self.joint_dim = 10

            self.loss_rec_lambda_x = 5
            self.loss_rec_lambda_y = 1

            self.loss_normal_lambda = 0.5
            self.loss_indep_lambda_x = 0.5
            self.loss_indep_lambda_y = 0.5
            self.loss_mse_lambda = 0.5

            self.discr_steps = 3
            self.gen_steps = 1

            # x: molecules via pretrained RNN encoder, finetuned to z_dim
            rnn_1 = RNNEncoder(out_dim=88)
            rnn_1.load_state_dict(
                torch.load('../saved_models/rnn_enc.ckpt',
                           map_location='cuda:1'))
            self.enc_x = FinetunedEncoder(rnn_1, out_dim=self.z_dim)
            self.enc_y = ExprDiffEncoder(out_dim=self.z_dim)

            rnn_2 = RNNDecoder(in_dim=44)
            rnn_2.load_state_dict(
                torch.load('../saved_models/rnn_dec.ckpt',
                           map_location='cuda:1'))
            self.dec_x = FinetunedDecoder(rnn_2, in_dim=self.z_dim)

            self.dec_y = ExprDiffDecoder(in_dim=self.z_dim)

            self.discr = FCDiscriminator(in_dim=2 * self.z_dim -
                                         self.joint_dim,
                                         use_sigmoid=False)

            self.discr_indep_x = FCDiscriminator(in_dim=2 * self.z_dim -
                                                 self.joint_dim,
                                                 use_sigmoid=False)

            self.discr_indep_y = FCDiscriminator(in_dim=2 * self.z_dim -
                                                 self.joint_dim,
                                                 use_sigmoid=False)
        elif self.dataset == 'lincs_rnn_reverse':
            self.z_dim = 20
            self.joint_dim = 10

            self.loss_rec_lambda_x = 1
            self.loss_rec_lambda_y = 0.2

            self.loss_normal_lambda = 0.5
            self.loss_indep_lambda_x = 0.5
            self.loss_indep_lambda_y = 0.5
            self.loss_mse_lambda = 0.5

            self.discr_steps = 3
            self.gen_steps = 1

            # same modules as 'lincs_rnn' with the roles of x and y swapped
            rnn_1 = RNNEncoder(out_dim=88)
            rnn_1.load_state_dict(
                torch.load('../saved_models/rnn_enc.ckpt',
                           map_location='cuda:1'))
            self.enc_y = FinetunedEncoder(rnn_1, out_dim=self.z_dim)
            self.enc_x = ExprDiffEncoder(out_dim=self.z_dim)

            rnn_2 = RNNDecoder(in_dim=44)
            rnn_2.load_state_dict(
                torch.load('../saved_models/rnn_dec.ckpt',
                           map_location='cuda:1'))
            self.dec_y = FinetunedDecoder(rnn_2, in_dim=self.z_dim)

            self.dec_x = ExprDiffDecoder(in_dim=self.z_dim)

            self.discr = FCDiscriminator(in_dim=2 * self.z_dim -
                                         self.joint_dim,
                                         use_sigmoid=False)

            self.discr_indep_x = FCDiscriminator(in_dim=2 * self.z_dim -
                                                 self.joint_dim,
                                                 use_sigmoid=False)

            self.discr_indep_y = FCDiscriminator(in_dim=2 * self.z_dim -
                                                 self.joint_dim,
                                                 use_sigmoid=False)

    # ------------------------------------------------------------------------
    #               TRAINING
    def get_latents(self, batch):
        """Build a decoder-x input: prior noise plus the shared code of y."""
        _, y = batch

        exclusive_dim = self.z_dim - self.joint_dim
        z_y, s_y = torch.split(self.enc_y(y), exclusive_dim, -1)
        # the exclusive part of x is sampled from the N(0, I) prior
        prior_noise = torch.randn_like(z_y)

        return torch.cat((prior_noise, s_y), 1)

    def get_log_p_x_by_y(self, batch):
        """Log-likelihood of x under dec_x, conditioned on latents from y."""
        x = batch[0]
        latents = self.get_latents(batch)
        return self.dec_x.get_log_prob(x, latents)

    def restore(self, batch):
        """Reconstruct both objects from their own encoder outputs."""
        x, y = batch

        exclusive_dim = self.z_dim - self.joint_dim
        # split each code into exclusive (z) and shared (s) parts
        z_x, s_x = torch.split(self.enc_x(x), exclusive_dim, -1)
        z_y, s_y = torch.split(self.enc_y(y), exclusive_dim, -1)

        restored_x = self.dec_x.sample(torch.cat((z_x, s_x), 1))
        restored_y = self.dec_y.sample(torch.cat((z_y, s_y), 1))

        return (restored_x, restored_y)

    def sample(self, y):
        """Sample x given y: prior noise for the exclusive part, shared code from y."""
        z_y, s_y = torch.split(self.enc_y(y), self.z_dim - self.joint_dim, -1)
        prior_noise = torch.randn_like(z_y)

        return self.dec_x.sample(z=torch.cat((prior_noise, s_y), 1))

    def sample_y(self, x):
        """Sample y given x: prior noise for the exclusive part, shared code from x."""
        z_x, s_x = torch.split(self.enc_x(x), self.z_dim - self.joint_dim, -1)
        prior_noise = torch.randn_like(z_x)

        sampled_y = self.dec_y.sample(z=torch.cat((prior_noise, s_x), 1))
        return sampled_y

    def training_step(self, batch, batch_nb, optimizer_i):
        """One optimization step.

        ``optimizer_i`` selects the objective: 0 = generator (encoders and
        decoders), 1 = discriminators.  Returns a Lightning-style dict with
        the loss and values to log.
        """
        # pair of objects
        x, y = batch

        # compute encoder outputs and split them into joint and exclusive parts
        z_x, s_x = torch.split(self.enc_x(x), self.z_dim - self.joint_dim, -1)
        z_y, s_y = torch.split(self.enc_y(y), self.z_dim - self.joint_dim, -1)

        if optimizer_i == 0:  # GENERATOR LOSS
            # Reconstruction losses: each object is decoded both from its own
            # shared part and from its partner's; the two terms are averaged.
            loss_x_rec = -(self.dec_x.get_log_prob(x, torch.cat(
                (z_x, s_y), 1)).mean() +
                           self.dec_x.get_log_prob(x, torch.cat(
                               (z_x, s_x), 1)).mean()) / 2

            loss_y_rec = -(self.dec_y.get_log_prob(y, torch.cat(
                (z_y, s_y), 1)).mean() +
                           self.dec_y.get_log_prob(y, torch.cat(
                               (z_y, s_x), 1)).mean()) / 2

            # penalty pulling the shared parts of x and y together
            # (Euclidean distance, kept under the historical "mse" name)
            loss_mse = torch.norm(s_x - s_y, p=2, dim=-1).mean()

            # fool self.discr into labeling the joint code as N(0, I) noise
            discr_outputs = self.discr(
                torch.cat((torch.cat(
                    (z_x, s_x, z_y), dim=-1), torch.cat(
                        (z_x, s_y, z_y), dim=-1)),
                          dim=0))
            loss_norm = nn.BCEWithLogitsLoss()(discr_outputs,
                                               torch.ones_like(discr_outputs))

            # independence critics: only the non-detached code receives
            # gradient, so each loss shapes a single encoder output
            discr_outputs_x = self.discr_indep_x(
                torch.cat((z_x, s_y.detach(), z_y.detach()), dim=-1))
            loss_indep_x = nn.BCEWithLogitsLoss()(
                discr_outputs_x, torch.ones_like(discr_outputs_x))

            discr_outputs_y = self.discr_indep_y(
                torch.cat((z_x.detach(), s_x.detach(), z_y), dim=-1))
            loss_indep_y = nn.BCEWithLogitsLoss()(
                discr_outputs_y, torch.ones_like(discr_outputs_y))

            # weighted sum of all generator objectives
            g_loss = (loss_x_rec * self.loss_rec_lambda_x +
                      loss_y_rec * self.loss_rec_lambda_y +
                      loss_norm * self.loss_normal_lambda +
                      loss_indep_x * self.loss_indep_lambda_x +
                      loss_indep_y * self.loss_indep_lambda_y +
                      loss_mse * self.loss_mse_lambda)

            return {
                'loss': g_loss,
                'log': {
                    'loss_g': g_loss,
                    'x_rec': loss_x_rec,
                    'y_rec': loss_y_rec,
                    'loss_norm': loss_norm,
                    'loss_indep_x': loss_indep_x,
                    'loss_indep_y': loss_indep_y,
                    'loss_mse': loss_mse
                }
            }

        elif optimizer_i == 1:  # DISCRIMINATOR LOSS
            # discriminators must not backprop into the encoders
            z_x = z_x.detach()
            s_x = s_x.detach()
            s_y = s_y.detach()
            z_y = z_y.detach()

            # normal noise loss: real codes (label 0) vs N(0, I) (label 1)
            real_inputs_1 = torch.cat((torch.cat(
                (z_x, s_x, z_y), dim=-1), torch.cat((z_x, s_y, z_y), dim=-1)),
                                      dim=0)
            real_dec_out_1 = self.discr(real_inputs_1)

            fake_inputs_1 = torch.randn_like(real_inputs_1)
            fake_dec_out_1 = self.discr(fake_inputs_1)

            probs_1 = torch.cat((real_dec_out_1, fake_dec_out_1), 0)
            targets_1 = torch.cat((torch.zeros_like(real_dec_out_1),
                                   torch.ones_like(fake_dec_out_1)), 0)

            d_loss_normal = nn.BCEWithLogitsLoss()(probs_1, targets_1)

            # independence loss: paired codes vs codes with z_x shuffled
            # across the batch (breaking any dependence on (s_y, z_y))
            real_inputs_2 = torch.cat((z_x, s_y, z_y), dim=-1)
            real_dec_out_2 = self.discr_indep_x(real_inputs_2)

            real_input_shuffled_2 = torch.cat(
                (z_x[np.random.permutation(z_x.shape[0])], s_y, z_y), dim=-1)
            fake_dec_out_2 = self.discr_indep_x(real_input_shuffled_2)

            probs_2 = torch.cat((real_dec_out_2, fake_dec_out_2), 0)
            targets_2 = torch.cat((torch.zeros_like(real_dec_out_2),
                                   torch.ones_like(fake_dec_out_2)), 0)

            d_loss_indep_x = nn.BCEWithLogitsLoss()(probs_2, targets_2)

            # same construction for the y side, shuffling z_y instead
            real_inputs_3 = torch.cat((z_x, s_x, z_y), dim=-1)
            real_dec_out_3 = self.discr_indep_y(real_inputs_3)

            real_input_shuffled_3 = torch.cat(
                (z_x, s_x, z_y[np.random.permutation(z_y.shape[0])]), dim=-1)
            fake_dec_out_3 = self.discr_indep_y(real_input_shuffled_3)

            probs_3 = torch.cat((real_dec_out_3, fake_dec_out_3), 0)
            targets_3 = torch.cat((torch.zeros_like(real_dec_out_3),
                                   torch.ones_like(fake_dec_out_3)), 0)

            d_loss_indep_y = nn.BCEWithLogitsLoss()(probs_3, targets_3)

            return {
                'loss': d_loss_normal + d_loss_indep_x + d_loss_indep_y,
                'log': {
                    'loss_d_normal': d_loss_normal,
                    'loss_d_indep_x': d_loss_indep_x,
                    'loss_d_indep_y': d_loss_indep_y
                }
            }

    def configure_optimizers(self):
        """Two Adam optimizers (generator vs discriminators) plus a StepLR
        schedule that slowly decays the discriminator learning rate."""
        generator_modules = torch.nn.ModuleList(
            [self.enc_x, self.dec_x, self.enc_y, self.dec_y])
        discriminator_modules = torch.nn.ModuleList(
            [self.discr_indep_x, self.discr_indep_y, self.discr])

        adam_kwargs = dict(lr=3e-4, betas=(0.5, 0.9))
        gen_optim = torch.optim.Adam(generator_modules.parameters(),
                                     **adam_kwargs)
        discr_optim = torch.optim.Adam(discriminator_modules.parameters(),
                                       **adam_kwargs)

        discr_sched = StepLR(discr_optim, step_size=1000, gamma=0.99)

        return [gen_optim, discr_optim], [discr_sched]

    def zero_grad(self):
        """Clear accumulated gradients on every submodule."""
        for module in (self.enc_x, self.dec_x, self.enc_y, self.dec_y,
                       self.discr, self.discr_indep_x, self.discr_indep_y):
            module.zero_grad()

    def optimizer_step(self, current_epoch, batch_nb, optimizer, optimizer_i,
                       optimizer_closure):
        """Alternate discriminator and generator updates on a fixed schedule:
        the first ``discr_steps`` batches of each cycle train the
        discriminators, the remaining ``gen_steps`` train the generator."""
        cycle_position = batch_nb % (self.discr_steps + self.gen_steps)
        is_discr_turn = cycle_position < self.discr_steps

        if optimizer_i == 0:
            # generator optimizer: step only on generator turns
            if not is_discr_turn:
                optimizer.step()
            optimizer.zero_grad()
            self.zero_grad()
        elif optimizer_i == 1:
            # discriminator optimizer: step only on discriminator turns
            if is_discr_turn:
                optimizer.step()
            optimizer.zero_grad()
            self.zero_grad()
        elif optimizer_i > 1:
            # any extra optimizer always steps
            optimizer.step()
            optimizer.zero_grad()
            self.zero_grad()
        def __init__(self):
            # NOTE(review): this looks like a fragment pasted in from another
            # example (a MINE wrapper for the reverse LINCS task); as written
            # it is nested inside optimizer_step and never invoked — confirm
            # and relocate.
            super().__init__(dataset='lincs_rnn_reverse')

            # encoder and statistics network for the MINE estimator
            self.mine_enc = ExprDiffEncoder(out_dim=self.z_dim)
            self.mine_fc = FCDiscriminator(in_dim=2 * self.z_dim)
# --- Example #6 ---
class Lat_SAAE(pl.LightningModule):
    """Latent supervised adversarial autoencoder baseline.

    y is autoencoded through ``enc_y``/``dec_y``; x is decoded from the
    concatenation of its own code ``z_x`` and the code of y.  A single
    discriminator pushes ``z_x`` towards N(0, I).
    """

    def __init__(self, dataset='paired_mnist'):
        # Per-dataset hyperparameters and submodules.  Supported datasets:
        # 'paired_mnist', 'lincs_rnn', 'lincs_rnn_reverse' (x/y swapped).
        super(Lat_SAAE, self).__init__()
        self.dataset = dataset

        if self.dataset == 'paired_mnist':
            # full latent size; each encoder emits half of it
            self.z_dim = 8

            # loss weights
            self.loss_rec_lambda_x = 10
            self.loss_rec_lambda_y = 10

            self.loss_latent_lambda = 2

            # optimizer alternation schedule (see optimizer_step)
            self.discr_steps = 1
            self.gen_steps = 1

            self.enc_x = MnistCNNEncoder(out_dim=self.z_dim // 2)
            self.enc_y = MnistCNNEncoder(out_dim=self.z_dim // 2)

            # dec_x consumes (z_x, z_y); dec_y consumes z_y alone
            self.dec_x = MnistCNNDecoder(in_dim=self.z_dim)
            self.dec_y = MnistCNNDecoder(in_dim=self.z_dim // 2)

            # critic over z_x only; raw logits for BCEWithLogitsLoss
            self.discr = FCDiscriminator(in_dim=self.z_dim // 2,
                                         use_sigmoid=False)
        elif self.dataset == 'lincs_rnn':
            self.z_dim = 20

            self.loss_rec_lambda_x = 5
            self.loss_rec_lambda_y = 1
            self.loss_latent_lambda = 1

            self.discr_steps = 1
            self.gen_steps = 3

            # x: molecules via pretrained RNN encoder, finetuned to z_dim//2
            rnn_1 = RNNEncoder(out_dim=88)
            rnn_1.load_state_dict(
                torch.load('../saved_models/rnn_enc.ckpt',
                           map_location='cuda:1'))
            self.enc_x = FinetunedEncoder(rnn_1, out_dim=self.z_dim // 2)

            self.enc_y = ExprDiffEncoder(out_dim=self.z_dim // 2)

            rnn_2 = RNNDecoder(in_dim=44)
            rnn_2.load_state_dict(
                torch.load('../saved_models/rnn_dec.ckpt',
                           map_location='cuda:1'))
            self.dec_x = FinetunedDecoder(rnn_2, in_dim=self.z_dim)

            self.dec_y = ExprDiffDecoder(in_dim=self.z_dim // 2)

            self.discr = FCDiscriminator(in_dim=self.z_dim // 2,
                                         use_sigmoid=False)

        elif self.dataset == 'lincs_rnn_reverse':
            self.z_dim = 20

            self.loss_rec_lambda_x = 1
            self.loss_rec_lambda_y = 0.2
            self.loss_latent_lambda = 1

            self.discr_steps = 1
            self.gen_steps = 3

            # same modules as 'lincs_rnn' with the roles of x and y swapped
            rnn_1 = RNNEncoder(out_dim=88)
            rnn_1.load_state_dict(
                torch.load('../saved_models/rnn_enc.ckpt',
                           map_location='cuda:1'))
            self.enc_y = FinetunedEncoder(rnn_1, out_dim=self.z_dim // 2)

            self.enc_x = ExprDiffEncoder(out_dim=self.z_dim // 2)

            rnn_2 = RNNDecoder(in_dim=44)
            rnn_2.load_state_dict(
                torch.load('../saved_models/rnn_dec.ckpt',
                           map_location='cuda:1'))
            self.dec_y = FinetunedDecoder(rnn_2, in_dim=self.z_dim // 2)

            self.dec_x = ExprDiffDecoder(in_dim=self.z_dim)

            self.discr = FCDiscriminator(in_dim=self.z_dim // 2,
                                         use_sigmoid=False)

    # ------------------------------------------------------------------------
    #               TRAINING
    def get_latents(self, batch):
        """Build a decoder-x input: prior noise for z_x plus the code of y."""
        # pair of objects
        x, y = batch

        z_y = self.enc_y(y)
        # z_x is sampled from the N(0, I) prior
        z_x = torch.randn_like(z_y)

        return torch.cat((z_x, z_y), 1)

    def get_log_p_x_by_y(self, batch):
        """Log-likelihood of x under dec_x, conditioned on latents from y."""
        return self.dec_x.get_log_prob(batch[0], self.get_latents(batch))

    def restore(self, batch):
        """Reconstruct both objects from their own encoder outputs."""
        # pair of objects
        x, y = batch

        z_x = self.enc_x(x)
        z_y = self.enc_y(y)

        # x is decoded from both codes; y from its own code only
        x_rest = self.dec_x.sample(torch.cat((z_x, z_y), 1))
        y_rest = self.dec_y.sample(z_y)

        return (x_rest, y_rest)

    def sample(self, y):
        """Sample x given y: prior noise for z_x, encoded code for z_y."""
        z_y = self.enc_y(y)
        z_x = torch.randn_like(z_y)

        sampled_x = self.dec_x.sample(z=torch.cat((z_x, z_y), 1))
        return sampled_x

    def training_step(self, batch, batch_nb, optimizer_i):
        """One optimization step; optimizer_i selects generator (0) or
        discriminator (1) objective."""
        # pair of objects
        x, y = batch

        z_x = self.enc_x(x)
        z_y = self.enc_y(y)

        if optimizer_i == 0:  # GENERATOR LOSS
            # negative log-likelihood reconstruction terms
            rec_x = -self.dec_x.get_log_prob(x, torch.cat(
                (z_x, z_y), 1)).mean()
            rec_y = -self.dec_y.get_log_prob(y, z_y).mean()

            # fool the discriminator into labeling z_x as prior noise
            discr_outputs = self.discr(z_x)
            latent_loss = nn.BCEWithLogitsLoss()(
                discr_outputs, torch.ones_like(discr_outputs))

            g_loss = (rec_x * self.loss_rec_lambda_x +
                      rec_y * self.loss_rec_lambda_y +
                      latent_loss * self.loss_latent_lambda)

            return {
                'loss': g_loss,
                'log': {
                    'loss_g': g_loss,
                    'x_rec': rec_x,
                    'y_rec': rec_y,
                    'loss_norm': latent_loss
                }
            }
        elif optimizer_i == 1:  # DISCRIMINATOR LOSS
            # do not backprop into the encoder
            z_x = z_x.detach()

            # encoded z_x (label 0) vs samples from N(0, I) (label 1)
            real_inputs = z_x
            real_dec_out = self.discr(real_inputs)

            fake_inputs = torch.randn_like(real_inputs)
            fake_dec_out = self.discr(fake_inputs)

            probs = torch.cat((real_dec_out, fake_dec_out), 0)
            targets = torch.cat((torch.zeros_like(real_dec_out),
                                 torch.ones_like(fake_dec_out)), 0)

            d_loss = nn.BCEWithLogitsLoss()(probs, targets)

            return {'loss': d_loss, 'log': {'loss_d_normal': d_loss}}

    def configure_optimizers(self):
        """Two Adam optimizers (generator vs discriminator) plus a StepLR
        schedule on the discriminator learning rate."""
        gen_params = torch.nn.ModuleList(
            [self.enc_x, self.dec_x, self.enc_y, self.dec_y])
        discr_params = torch.nn.ModuleList([self.discr])

        gen_optim = torch.optim.Adam(gen_params.parameters(),
                                     lr=3e-4,
                                     betas=(0.5, 0.9))
        discr_optim = torch.optim.Adam(discr_params.parameters(),
                                       lr=3e-4,
                                       betas=(0.5, 0.9))

        discriminator_sched = StepLR(discr_optim, step_size=5000, gamma=0.5)

        return [gen_optim, discr_optim], [discriminator_sched]

    def zero_grad(self):
        """Clear accumulated gradients on every submodule."""
        self.enc_x.zero_grad()
        self.dec_x.zero_grad()
        self.enc_y.zero_grad()
        self.dec_y.zero_grad()
        self.discr.zero_grad()

    def optimizer_step(self, current_epoch, batch_nb, optimizer, optimizer_i,
                       optimizer_closure):
        """Alternate discriminator and generator updates: the first
        ``discr_steps`` batches of each cycle train the discriminator, the
        remaining ``gen_steps`` train the generator."""
        discr_step = (batch_nb % (self.discr_steps + self.gen_steps)) < \
                     self.discr_steps

        gen_step = (not discr_step)

        if optimizer_i == 0:
            # generator optimizer: step only on generator turns
            if gen_step:
                optimizer.step()
            optimizer.zero_grad()
            self.zero_grad()

        if optimizer_i == 1:
            # discriminator optimizer: step only on discriminator turns
            if discr_step:
                optimizer.step()
            optimizer.zero_grad()
            self.zero_grad()

        if optimizer_i > 1:
            # any extra optimizer always steps
            optimizer.step()
            optimizer.zero_grad()
            self.zero_grad()
# --- Example #7 ---
def do_epoch(args, mode: str, net: Any, device: Any, loader: DataLoader, epc: int,
             loss_fns: List[Callable], loss_weights: List[float], loss_fns_source: List[Callable],
             loss_weights_source: List[float], new_w: int, num_steps: int, C: int, savedir: str = "",
             optimizer: Any = None, target_loader: Any = None, lambda_adv_target: float = 0.001) -> Tuple[List, List, List, List]:
    """Run one train/val epoch of adversarial domain adaptation over paired
    source/target loaders.

    Iterates the zipped loaders, performs the combined forward/backward step
    (segmentation + adversarial + inf + consistency losses plus the domain
    discriminator), logs per-image losses and metrics, and saves predicted
    images under `savedir`.

    Returns:
        (losses_vec, source_vec, target_vec, baseline_target_vec) — mean
        losses and per-domain metric summaries for the epoch.
    """
    assert mode in ["train", "val"]
    L: int = len(loss_fns)

    if mode == "train":
        net.train()
        desc = f">> Training   ({epc})"
    elif mode == "val":
        net.eval()
        desc = f">> Validation ({epc})"

    total_iteration, total_images = len(loader), len(loader.dataset)

    # per-image loss logs, filled batch by batch
    loss_seg_log = np.zeros(total_images)
    loss_cons_log = np.zeros(total_images)
    loss_inf_log = np.zeros(total_images)
    loss_adv_log = np.zeros(total_images)
    loss_D_log = np.zeros(total_images)

    # source metrics
    dices_log_s = np.zeros((total_images, C))
    posim_log_s = np.zeros(total_images)
    haussdorf_log_s = np.zeros((total_images, C))

    # target metrics
    dices_log_t = np.zeros((total_images, C))
    dices_baseline_log_t = np.zeros((total_images, C))
    posim_log_t = np.zeros(total_images)
    haussdorf_log_t = np.zeros((total_images, C))

    cudnn.benchmark = True
    # NOTE(review): the domain discriminator is re-created (and its weights
    # re-initialized) on every call, i.e. every epoch — confirm intentional.
    model_D = FCDiscriminator(num_classes=C)
    model_D.train()
    model_D.to(device)
    optimizer_D = torch.optim.Adam(model_D.parameters(), lr=args.l_rate_D, betas=(0.9, 0.99))
    tq_iter = tqdm_(enumerate(zip(loader, target_loader)), total=total_iteration, desc=desc)
    done: int = 0
    dice_3d_s = 0
    dice_3d_sd_s = 0
    dice_3d_t = 0
    dice_3d_sd_t = 0
    baseline_target_vec = [0, 0]
    with warnings.catch_warnings():
        warnings.simplefilter("ignore")
        for j, (source_data, target_data) in tq_iter:
            source_data[1:] = [e.to(device) for e in source_data[1:]]  # Move all tensors to device
            filenames_source, source_image, source_gt = source_data
            target_data[1:] = [e.to(device) for e in target_data[1:]]  # Move all tensors to device
            filenames_target, target_image, target_gt = target_data[:3]
            # trailing entries of target_data carry L label tensors then bounds
            labels = target_data[3:3 + L]
            bounds = target_data[3 + L:]
            assert len(labels) == len(bounds)
            B = len(target_image)

            source_probs, target_probs, loss, loss_seg, loss_adv, loss_inf, loss_cons, loss_D = for_back_step_comb(optimizer, mode, source_image, target_image, source_gt, labels,
                                                                            net, loss_fns, loss_weights, loss_fns_source, loss_weights_source, new_w, device, bounds,
                                                                             model_D, optimizer_D, lambda_adv_target)

            # compute metrics for current batch; resize GTs to match the
            # network output resolution when cropping is enabled
            if new_w > 0:
                source_gt = resize(source_gt, new_w)
                target_gt = resize(target_gt, new_w)
                labels[0] = resize(labels[0], new_w)

            dices_s, _, posim_s, haussdorf_s = compute_metrics(source_probs, source_gt, source_gt)
            dices_t, dices_baseline_t, posim_t, haussdorf_t = compute_metrics(target_probs, target_gt, labels[0])

            do_save_images(target_probs, savedir, filenames_target, mode, epc)
            do_save_images(source_probs, savedir, filenames_source, "_".join(("source", mode)), epc)

            # keep metrics in ndarrays
            sm_slice = slice(done, done + B)
            loss_seg_log[sm_slice] = loss_seg
            loss_cons_log[sm_slice] = loss_cons
            loss_adv_log[sm_slice] = loss_adv
            loss_inf_log[sm_slice] = loss_inf
            loss_D_log[sm_slice] = loss_D

            dices_log_s[sm_slice, ...] = dices_s
            haussdorf_log_s[sm_slice] = haussdorf_s
            posim_log_s[sm_slice] = posim_s

            dices_log_t[sm_slice, ...] = dices_t
            dices_baseline_log_t[sm_slice, ...] = dices_baseline_t
            haussdorf_log_t[sm_slice] = haussdorf_t
            posim_log_t[sm_slice] = posim_t

            done += B

    # calculate mean of metrics on all images
    loss_seg_log = loss_seg_log.mean()
    loss_adv_log = loss_adv_log.mean()
    loss_cons_log = loss_cons_log.mean()
    loss_inf_log = loss_inf_log.mean()
    # BUG FIX: previously averaged loss_inf_log again, so the reported
    # discriminator loss was a copy of the inf loss
    loss_D_log = loss_D_log.mean()

    # first select positive and negative images
    dice_posim_log_s = np.compress(posim_log_s, [dices_log_s[:, 1]]).mean()
    dice_negim_log_s = np.compress(1 - posim_log_s, [dices_log_s[:, 1]]).mean()

    dice_posim_log_t = np.compress(posim_log_t, [dices_log_t[:, 1]]).mean()
    dice_negim_log_t = np.compress(1 - posim_log_t, [dices_log_t[:, 1]]).mean()

    # mean on the source images
    dices_log_s = dices_log_s[:, -1].mean()
    haussdorf_log_s = haussdorf_log_s[:, -1].mean()

    # mean on the target images
    dices_log_t = dices_log_t[:, -1].mean()
    haussdorf_log_t = haussdorf_log_t[:, -1].mean()

    # dice3D gives back the 3d dice mean on images
    if not args.debug:
        dice_3d_s, dice_3d_sd_s = dice3d(args.workdir,   f"iter{epc:03d}", "source_" + mode, "Subj_\\d+_", args.dataset + mode + '/GT')
        dice_3d_t, dice_3d_sd_t = dice3d(args.workdir,   f"iter{epc:03d}", mode, "Subj_\\d+_", args.target_dataset + mode + '/GT')
        if epc == 0:
            # NOTE(review): the argument order here differs from the calls
            # above (first arg is target_dataset, not workdir) — verify
            # against the dice3d signature.
            dice_3d_baseline, dice_3d_sd_baseline = dice3d(args.target_dataset, mode, 'Wat_on_Inn_n', "Subj_\\d+_", args.target_dataset + mode + '/GT')
            baseline_target_vec = [dice_3d_baseline, dice_3d_sd_baseline]

    stat_dict = {"dice 3D source": dice_3d_s,
                 "dice 3D target": dice_3d_t}
    nice_dict = {k: f"{v:.4f}" for (k, v) in stat_dict.items()}

    print(f"{desc} " + ', '.join(f"{k}={v}" for (k, v) in nice_dict.items()))

    # Keep metrics in vectors
    losses_vec = [loss_seg_log, loss_adv_log, loss_inf_log, loss_cons_log, loss_D_log]
    source_vec = [dices_log_s, dice_posim_log_s, dice_negim_log_s, dice_3d_s, dice_3d_sd_s, haussdorf_log_s]
    target_vec = [dices_log_t, dice_posim_log_t, dice_negim_log_t, dice_3d_t, dice_3d_sd_t, haussdorf_log_t]

    return losses_vec, source_vec, target_vec, baseline_target_vec