# Imports required by the code below. The encoder/decoder and discriminator
# classes (MnistCNNEncoder, MnistCNNDecoder, RNNEncoder, RNNDecoder,
# FinetunedEncoder, FinetunedDecoder, ExprDiffEncoder, ExprDiffDecoder,
# FCDiscriminator) are assumed to come from this repository's own modules.
import numpy as np
import torch
import torch.nn as nn
import pytorch_lightning as pl
from torch.optim.lr_scheduler import StepLR


class BiAAE(pl.LightningModule):
    """Adversarial autoencoder for paired data (x, y).

    Each encoder output is split into an exclusive part (z) and a shared part
    (s); the shared parts are tied with an L2 term, the concatenated code is
    adversarially matched to N(0, I), and two extra discriminators encourage
    independence between the exclusive parts and the other modality's code.
    """

    def __init__(self, dataset='paired_mnist'):
        super(BiAAE, self).__init__()

        self.dataset = dataset

        if self.dataset == 'paired_mnist':
            self.z_dim = 16
            self.joint_dim = 4

            self.loss_rec_lambda_x = 10
            self.loss_rec_lambda_y = 10
            self.loss_normal_lambda = 0.3
            self.loss_indep_lambda_x = 1
            self.loss_indep_lambda_y = 1
            self.loss_mse_lambda = 0.1

            self.discr_steps = 1
            self.gen_steps = 1

            self.enc_x = MnistCNNEncoder(out_dim=self.z_dim)
            self.enc_y = MnistCNNEncoder(out_dim=self.z_dim)

            self.dec_x = MnistCNNDecoder(in_dim=self.z_dim)
            self.dec_y = MnistCNNDecoder(in_dim=self.z_dim)

            self.discr = FCDiscriminator(
                in_dim=2 * self.z_dim - self.joint_dim, use_sigmoid=False)
            self.discr_indep_x = FCDiscriminator(
                in_dim=2 * self.z_dim - self.joint_dim, use_sigmoid=False)
            self.discr_indep_y = FCDiscriminator(
                in_dim=2 * self.z_dim - self.joint_dim, use_sigmoid=False)

        elif self.dataset == 'lincs_rnn':
            self.z_dim = 20
            self.joint_dim = 10

            self.loss_rec_lambda_x = 5
            self.loss_rec_lambda_y = 1
            self.loss_normal_lambda = 0.5
            self.loss_indep_lambda_x = 0.5
            self.loss_indep_lambda_y = 0.5
            self.loss_mse_lambda = 0.5

            self.discr_steps = 3
            self.gen_steps = 1

            rnn_1 = RNNEncoder(out_dim=88)
            rnn_1.load_state_dict(
                torch.load('../saved_models/rnn_enc.ckpt',
                           map_location='cuda:1'))
            self.enc_x = FinetunedEncoder(rnn_1, out_dim=self.z_dim)
            self.enc_y = ExprDiffEncoder(out_dim=self.z_dim)

            rnn_2 = RNNDecoder(in_dim=44)
            rnn_2.load_state_dict(
                torch.load('../saved_models/rnn_dec.ckpt',
                           map_location='cuda:1'))
            self.dec_x = FinetunedDecoder(rnn_2, in_dim=self.z_dim)
            self.dec_y = ExprDiffDecoder(in_dim=self.z_dim)

            self.discr = FCDiscriminator(
                in_dim=2 * self.z_dim - self.joint_dim, use_sigmoid=False)
            self.discr_indep_x = FCDiscriminator(
                in_dim=2 * self.z_dim - self.joint_dim, use_sigmoid=False)
            self.discr_indep_y = FCDiscriminator(
                in_dim=2 * self.z_dim - self.joint_dim, use_sigmoid=False)

        elif self.dataset == 'lincs_rnn_reverse':
            self.z_dim = 20
            self.joint_dim = 10

            self.loss_rec_lambda_x = 1
            self.loss_rec_lambda_y = 0.2
            self.loss_normal_lambda = 0.5
            self.loss_indep_lambda_x = 0.5
            self.loss_indep_lambda_y = 0.5
            self.loss_mse_lambda = 0.5

            self.discr_steps = 3
            self.gen_steps = 1

            rnn_1 = RNNEncoder(out_dim=88)
            rnn_1.load_state_dict(
                torch.load('../saved_models/rnn_enc.ckpt',
                           map_location='cuda:1'))
            self.enc_y = FinetunedEncoder(rnn_1, out_dim=self.z_dim)
            self.enc_x = ExprDiffEncoder(out_dim=self.z_dim)

            rnn_2 = RNNDecoder(in_dim=44)
            rnn_2.load_state_dict(
                torch.load('../saved_models/rnn_dec.ckpt',
                           map_location='cuda:1'))
            self.dec_y = FinetunedDecoder(rnn_2, in_dim=self.z_dim)
            self.dec_x = ExprDiffDecoder(in_dim=self.z_dim)

            self.discr = FCDiscriminator(
                in_dim=2 * self.z_dim - self.joint_dim, use_sigmoid=False)
            self.discr_indep_x = FCDiscriminator(
                in_dim=2 * self.z_dim - self.joint_dim, use_sigmoid=False)
            self.discr_indep_y = FCDiscriminator(
                in_dim=2 * self.z_dim - self.joint_dim, use_sigmoid=False)

    # ------------------------------------------------------------------------
    # TRAINING
    def get_latents(self, batch):
        # pair of objects
        x, y = batch

        # build a latent code for x from y: shared part from enc_y,
        # exclusive part from the prior
        z_y, s_y = torch.split(self.enc_y(y), self.z_dim - self.joint_dim, -1)
        z_x = torch.randn_like(z_y)
        return torch.cat((z_x, s_y), 1)

    def get_log_p_x_by_y(self, batch):
        # log-likelihood of x under the decoder, conditioned on y
        return self.dec_x.get_log_prob(batch[0], self.get_latents(batch))

    def restore(self, batch):
        # pair of objects
        x, y = batch

        # compute encoder outputs and split them into exclusive and shared parts
        z_x, s_x = torch.split(self.enc_x(x), self.z_dim - self.joint_dim, -1)
        z_y, s_y = torch.split(self.enc_y(y), self.z_dim - self.joint_dim, -1)

        x_rest = self.dec_x.sample(torch.cat((z_x, s_x), 1))
        y_rest = self.dec_y.sample(torch.cat((z_y, s_y), 1))
        return (x_rest, y_rest)

    def sample(self, y):
        # sample x conditioned on y
        z_y, s_y = torch.split(self.enc_y(y), self.z_dim - self.joint_dim, -1)
        z_x = torch.randn_like(z_y)
        sampled_x = self.dec_x.sample(z=torch.cat((z_x, s_y), 1))
        return sampled_x

    def sample_y(self, x):
        # sample y conditioned on x
        z_x, s_x = torch.split(self.enc_x(x), self.z_dim - self.joint_dim, -1)
        z_y = torch.randn_like(z_x)
        sampled_y = self.dec_y.sample(z=torch.cat((z_y, s_x), 1))
        return sampled_y

    def training_step(self, batch, batch_nb, optimizer_i):
        # pair of objects
        x, y = batch

        # compute encoder outputs and split them into exclusive and shared parts
        z_x, s_x = torch.split(self.enc_x(x), self.z_dim - self.joint_dim, -1)
        z_y, s_y = torch.split(self.enc_y(y), self.z_dim - self.joint_dim, -1)

        if optimizer_i == 0:  # GENERATOR LOSS
            # Reconstruction losses (each object is reconstructed from its own
            # shared part and from the other object's shared part)
            loss_x_rec = -(
                self.dec_x.get_log_prob(x, torch.cat((z_x, s_y), 1)).mean() +
                self.dec_x.get_log_prob(x, torch.cat((z_x, s_x), 1)).mean()) / 2

            loss_y_rec = -(
                self.dec_y.get_log_prob(y, torch.cat((z_y, s_y), 1)).mean() +
                self.dec_y.get_log_prob(y, torch.cat((z_y, s_x), 1)).mean()) / 2

            # L2 distance between the shared (common) parts
            loss_mse = torch.norm(s_x - s_y, p=2, dim=-1).mean()

            # run discriminators
            discr_outputs = self.discr(
                torch.cat((torch.cat((z_x, s_x, z_y), dim=-1),
                           torch.cat((z_x, s_y, z_y), dim=-1)), dim=0))
            loss_norm = nn.BCEWithLogitsLoss()(
                discr_outputs, torch.ones_like(discr_outputs))

            discr_outputs_x = self.discr_indep_x(
                torch.cat((z_x, s_y.detach(), z_y.detach()), dim=-1))
            loss_indep_x = nn.BCEWithLogitsLoss()(
                discr_outputs_x, torch.ones_like(discr_outputs_x))

            discr_outputs_y = self.discr_indep_y(
                torch.cat((z_x.detach(), s_x.detach(), z_y), dim=-1))
            loss_indep_y = nn.BCEWithLogitsLoss()(
                discr_outputs_y, torch.ones_like(discr_outputs_y))

            g_loss = (loss_x_rec * self.loss_rec_lambda_x +
                      loss_y_rec * self.loss_rec_lambda_y +
                      loss_norm * self.loss_normal_lambda +
                      loss_indep_x * self.loss_indep_lambda_x +
                      loss_indep_y * self.loss_indep_lambda_y +
                      loss_mse * self.loss_mse_lambda)

            return {
                'loss': g_loss,
                'log': {
                    'loss_g': g_loss,
                    'x_rec': loss_x_rec,
                    'y_rec': loss_y_rec,
                    'loss_norm': loss_norm,
                    'loss_indep_x': loss_indep_x,
                    'loss_indep_y': loss_indep_y,
                    'loss_mse': loss_mse
                }
            }

        elif optimizer_i == 1:  # DISCRIMINATOR LOSS
            z_x = z_x.detach()
            s_x = s_x.detach()
            s_y = s_y.detach()
            z_y = z_y.detach()

            # normal noise loss: <z_x, s, z_y> vs samples from N(0, I)
            real_inputs_1 = torch.cat((torch.cat((z_x, s_x, z_y), dim=-1),
                                       torch.cat((z_x, s_y, z_y), dim=-1)),
                                      dim=0)
            real_dec_out_1 = self.discr(real_inputs_1)

            fake_inputs_1 = torch.randn_like(real_inputs_1)
            fake_dec_out_1 = self.discr(fake_inputs_1)

            probs_1 = torch.cat((real_dec_out_1, fake_dec_out_1), 0)
            targets_1 = torch.cat((torch.zeros_like(real_dec_out_1),
                                   torch.ones_like(fake_dec_out_1)), 0)
            d_loss_normal = nn.BCEWithLogitsLoss()(probs_1, targets_1)

            # independence loss for z_x: joint samples vs samples with z_x
            # shuffled across the batch
            real_inputs_2 = torch.cat((z_x, s_y, z_y), dim=-1)
            real_dec_out_2 = self.discr_indep_x(real_inputs_2)

            real_input_shuffled_2 = torch.cat(
                (z_x[np.random.permutation(z_x.shape[0])], s_y, z_y), dim=-1)
            fake_dec_out_2 = self.discr_indep_x(real_input_shuffled_2)

            probs_2 = torch.cat((real_dec_out_2, fake_dec_out_2), 0)
            targets_2 = torch.cat((torch.zeros_like(real_dec_out_2),
                                   torch.ones_like(fake_dec_out_2)), 0)
            d_loss_indep_x = nn.BCEWithLogitsLoss()(probs_2, targets_2)

            # independence loss for z_y: same construction with z_y shuffled
            real_inputs_3 = torch.cat((z_x, s_x, z_y), dim=-1)
            real_dec_out_3 = self.discr_indep_y(real_inputs_3)

            real_input_shuffled_3 = torch.cat(
                (z_x, s_x, z_y[np.random.permutation(z_y.shape[0])]), dim=-1)
            fake_dec_out_3 = self.discr_indep_y(real_input_shuffled_3)

            probs_3 = torch.cat((real_dec_out_3, fake_dec_out_3), 0)
            targets_3 = torch.cat((torch.zeros_like(real_dec_out_3),
                                   torch.ones_like(fake_dec_out_3)), 0)
            d_loss_indep_y = nn.BCEWithLogitsLoss()(probs_3, targets_3)

            return {
                'loss': d_loss_normal + d_loss_indep_x + d_loss_indep_y,
                'log': {
                    'loss_d_normal': d_loss_normal,
                    'loss_d_indep_x': d_loss_indep_x,
                    'loss_d_indep_y': d_loss_indep_y
                }
            }

    def configure_optimizers(self):
        gen_params = torch.nn.ModuleList(
            [self.enc_x, self.dec_x, self.enc_y, self.dec_y])
        discr_params = torch.nn.ModuleList(
            [self.discr_indep_x, self.discr_indep_y, self.discr])

        gen_optim = torch.optim.Adam(gen_params.parameters(),
                                     lr=3e-4, betas=(0.5, 0.9))
        discr_optim = torch.optim.Adam(discr_params.parameters(),
                                       lr=3e-4, betas=(0.5, 0.9))

        discriminator_sched = StepLR(discr_optim, step_size=1000, gamma=0.99)

        return [gen_optim, discr_optim], [discriminator_sched]

    def zero_grad(self):
        self.enc_x.zero_grad()
        self.dec_x.zero_grad()
        self.enc_y.zero_grad()
        self.dec_y.zero_grad()
        self.discr.zero_grad()
        self.discr_indep_x.zero_grad()
        self.discr_indep_y.zero_grad()

    def optimizer_step(self, current_epoch, batch_nb, optimizer, optimizer_i,
                       optimizer_closure):
        # alternate generator/discriminator updates according to
        # self.gen_steps and self.discr_steps
        discr_step = (batch_nb % (self.discr_steps + self.gen_steps)) < \
            self.discr_steps
        gen_step = (not discr_step)

        if optimizer_i == 0:
            if gen_step:
                optimizer.step()
            optimizer.zero_grad()
            self.zero_grad()
        if optimizer_i == 1:
            if discr_step:
                optimizer.step()
            optimizer.zero_grad()
            self.zero_grad()
        if optimizer_i > 1:
            optimizer.step()
            optimizer.zero_grad()
            self.zero_grad()
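
# Minimal smoke-test sketch (not part of the original file): build a BiAAE for
# the paired-MNIST configuration and run one generator and one discriminator
# step on a random batch to inspect the logged loss terms. The 1x28x28 input
# shape is an assumption about what MnistCNNEncoder/MnistCNNDecoder expect; in
# practice a pl.Trainer with a DataLoader of real (x, y) pairs drives training,
# using the two optimizers returned by configure_optimizers.
def _biaae_smoke_test():
    model = BiAAE(dataset='paired_mnist')

    # random stand-in batch of paired images (assumed 1x28x28)
    x = torch.randn(16, 1, 28, 28)
    y = torch.randn(16, 1, 28, 28)

    gen_out = model.training_step((x, y), batch_nb=0, optimizer_i=0)
    discr_out = model.training_step((x, y), batch_nb=0, optimizer_i=1)
    print(gen_out['log'])
    print(discr_out['log'])
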
class Lat_SAAE(pl.LightningModule):
    """Conditional adversarial autoencoder baseline.

    x is reconstructed from the concatenation (z_x, z_y), y is reconstructed
    from z_y alone, and z_x is adversarially matched to N(0, I) so that new x
    can be sampled for a given y.
    """

    def __init__(self, dataset='paired_mnist'):
        super(Lat_SAAE, self).__init__()

        self.dataset = dataset

        if self.dataset == 'paired_mnist':
            self.z_dim = 8

            self.loss_rec_lambda_x = 10
            self.loss_rec_lambda_y = 10
            self.loss_latent_lambda = 2

            self.discr_steps = 1
            self.gen_steps = 1

            self.enc_x = MnistCNNEncoder(out_dim=self.z_dim // 2)
            self.enc_y = MnistCNNEncoder(out_dim=self.z_dim // 2)

            self.dec_x = MnistCNNDecoder(in_dim=self.z_dim)
            self.dec_y = MnistCNNDecoder(in_dim=self.z_dim // 2)

            self.discr = FCDiscriminator(in_dim=self.z_dim // 2,
                                         use_sigmoid=False)

        elif self.dataset == 'lincs_rnn':
            self.z_dim = 20

            self.loss_rec_lambda_x = 5
            self.loss_rec_lambda_y = 1
            self.loss_latent_lambda = 1

            self.discr_steps = 1
            self.gen_steps = 3

            rnn_1 = RNNEncoder(out_dim=88)
            rnn_1.load_state_dict(
                torch.load('../saved_models/rnn_enc.ckpt',
                           map_location='cuda:1'))
            self.enc_x = FinetunedEncoder(rnn_1, out_dim=self.z_dim // 2)
            self.enc_y = ExprDiffEncoder(out_dim=self.z_dim // 2)

            rnn_2 = RNNDecoder(in_dim=44)
            rnn_2.load_state_dict(
                torch.load('../saved_models/rnn_dec.ckpt',
                           map_location='cuda:1'))
            self.dec_x = FinetunedDecoder(rnn_2, in_dim=self.z_dim)
            self.dec_y = ExprDiffDecoder(in_dim=self.z_dim // 2)

            self.discr = FCDiscriminator(in_dim=self.z_dim // 2,
                                         use_sigmoid=False)

        elif self.dataset == 'lincs_rnn_reverse':
            self.z_dim = 20

            self.loss_rec_lambda_x = 1
            self.loss_rec_lambda_y = 0.2
            self.loss_latent_lambda = 1

            self.discr_steps = 1
            self.gen_steps = 3

            rnn_1 = RNNEncoder(out_dim=88)
            rnn_1.load_state_dict(
                torch.load('../saved_models/rnn_enc.ckpt',
                           map_location='cuda:1'))
            self.enc_y = FinetunedEncoder(rnn_1, out_dim=self.z_dim // 2)
            self.enc_x = ExprDiffEncoder(out_dim=self.z_dim // 2)

            rnn_2 = RNNDecoder(in_dim=44)
            rnn_2.load_state_dict(
                torch.load('../saved_models/rnn_dec.ckpt',
                           map_location='cuda:1'))
            self.dec_y = FinetunedDecoder(rnn_2, in_dim=self.z_dim // 2)
            self.dec_x = ExprDiffDecoder(in_dim=self.z_dim)

            self.discr = FCDiscriminator(in_dim=self.z_dim // 2,
                                         use_sigmoid=False)

    # ------------------------------------------------------------------------
    # TRAINING
    def get_latents(self, batch):
        # pair of objects
        x, y = batch

        # latent code for x: encoded y plus noise from the prior
        z_y = self.enc_y(y)
        z_x = torch.randn_like(z_y)
        return torch.cat((z_x, z_y), 1)

    def get_log_p_x_by_y(self, batch):
        # log-likelihood of x under the decoder, conditioned on y
        return self.dec_x.get_log_prob(batch[0], self.get_latents(batch))

    def restore(self, batch):
        # pair of objects
        x, y = batch

        # compute encoder outputs
        z_x = self.enc_x(x)
        z_y = self.enc_y(y)

        x_rest = self.dec_x.sample(torch.cat((z_x, z_y), 1))
        y_rest = self.dec_y.sample(z_y)
        return (x_rest, y_rest)

    def sample(self, y):
        # sample x conditioned on y
        z_y = self.enc_y(y)
        z_x = torch.randn_like(z_y)
        sampled_x = self.dec_x.sample(z=torch.cat((z_x, z_y), 1))
        return sampled_x

    def training_step(self, batch, batch_nb, optimizer_i):
        # pair of objects
        x, y = batch

        z_x = self.enc_x(x)
        z_y = self.enc_y(y)

        if optimizer_i == 0:  # GENERATOR LOSS
            rec_x = -self.dec_x.get_log_prob(
                x, torch.cat((z_x, z_y), 1)).mean()
            rec_y = -self.dec_y.get_log_prob(y, z_y).mean()

            # run discriminator on the x-specific code
            discr_outputs = self.discr(z_x)
            latent_loss = nn.BCEWithLogitsLoss()(
                discr_outputs, torch.ones_like(discr_outputs))

            g_loss = (rec_x * self.loss_rec_lambda_x +
                      rec_y * self.loss_rec_lambda_y +
                      latent_loss * self.loss_latent_lambda)

            return {
                'loss': g_loss,
                'log': {
                    'loss_g': g_loss,
                    'x_rec': rec_x,
                    'y_rec': rec_y,
                    'loss_norm': latent_loss
                }
            }

        elif optimizer_i == 1:  # DISCRIMINATOR LOSS
            z_x = z_x.detach()

            # compare z_x against samples from N(0, I)
            real_inputs = z_x
            real_dec_out = self.discr(real_inputs)

            fake_inputs = torch.randn_like(real_inputs)
            fake_dec_out = self.discr(fake_inputs)

            probs = torch.cat((real_dec_out, fake_dec_out), 0)
            targets = torch.cat((torch.zeros_like(real_dec_out),
                                 torch.ones_like(fake_dec_out)), 0)
            d_loss = nn.BCEWithLogitsLoss()(probs, targets)

            return {'loss': d_loss, 'log': {'loss_d_normal': d_loss}}

    def configure_optimizers(self):
        gen_params = torch.nn.ModuleList(
            [self.enc_x, self.dec_x, self.enc_y, self.dec_y])
        discr_params = torch.nn.ModuleList([self.discr])

        gen_optim = torch.optim.Adam(gen_params.parameters(),
                                     lr=3e-4, betas=(0.5, 0.9))
        discr_optim = torch.optim.Adam(discr_params.parameters(),
                                       lr=3e-4, betas=(0.5, 0.9))

        discriminator_sched = StepLR(discr_optim, step_size=5000, gamma=0.5)

        return [gen_optim, discr_optim], [discriminator_sched]

    def zero_grad(self):
        self.enc_x.zero_grad()
        self.dec_x.zero_grad()
        self.enc_y.zero_grad()
        self.dec_y.zero_grad()
        self.discr.zero_grad()

    def optimizer_step(self, current_epoch, batch_nb, optimizer, optimizer_i,
                       optimizer_closure):
        # alternate generator/discriminator updates according to
        # self.gen_steps and self.discr_steps
        discr_step = (batch_nb % (self.discr_steps + self.gen_steps)) < \
            self.discr_steps
        gen_step = (not discr_step)

        if optimizer_i == 0:
            if gen_step:
                optimizer.step()
            optimizer.zero_grad()
            self.zero_grad()
        if optimizer_i == 1:
            if discr_step:
                optimizer.step()
            optimizer.zero_grad()
            self.zero_grad()
        if optimizer_i > 1:
            optimizer.step()
            optimizer.zero_grad()
            self.zero_grad()
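
# Illustrative sampling sketch (not part of the original file): after training,
# both BiAAE and Lat_SAAE generate x conditioned on y by encoding y, drawing
# the remaining latent dimensions from N(0, I), and decoding via `sample`. The
# random 1x28x28 tensor below is a stand-in for a real batch of y and assumes
# the paired-MNIST encoders; for the LINCS configurations y would instead be a
# batch of expression-difference vectors.
def _conditional_sampling_example(model):
    model.eval()
    with torch.no_grad():
        y = torch.randn(8, 1, 28, 28)  # stand-in batch of conditioning objects
        x_generated = model.sample(y)  # x samples conditioned on y
    return x_generated
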