import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.autograd import Variable  # legacy no-op wrapper on modern PyTorch, kept for compatibility

# Discriminator, Generator, Constants, and Utils are project-local modules
# assumed to be importable from this repository.


class GAN_Manager:
    def __init__(self, discriminator_in_nodes, generator_out_nodes,
                 ps_model, ps_model_type, device):
        self.discriminator = Discriminator(in_nodes=discriminator_in_nodes).to(device)
        self.discriminator.apply(self.__weights_init)
        self.generator = Generator(out_nodes=generator_out_nodes).to(device)
        self.generator.apply(self.__weights_init)
        self.loss = nn.BCELoss()
        self.ps_model = ps_model
        self.ps_model_type = ps_model_type

    def get_generator(self):
        return self.generator

    def train_GAN(self, train_parameters, device):
        epochs = train_parameters["epochs"]
        train_set = train_parameters["train_set"]
        lr = train_parameters["lr"]
        shuffle = train_parameters["shuffle"]
        batch_size = train_parameters["batch_size"]
        BETA = train_parameters["BETA"]

        data_loader_train = torch.utils.data.DataLoader(train_set,
                                                        batch_size=batch_size,
                                                        shuffle=shuffle)
        g_optimizer = optim.Adam(self.generator.parameters(), lr=lr)
        d_optimizer = optim.Adam(self.discriminator.parameters(), lr=lr)

        for epoch in range(epochs):
            epoch += 1
            total_G_loss = 0
            total_D_loss = 0
            total_prop_loss = 0
            total_d_pred_real = 0
            total_d_pred_fake = 0

            for batch in data_loader_train:
                covariates_X_control, ps_score_control, y_f = batch
                covariates_X_control = covariates_X_control.to(device)
                covariates_X_control_size = covariates_X_control.size(0)
                ps_score_control = ps_score_control.squeeze().to(device)

                # 1. Train the discriminator on real (control) data versus a
                # detached batch of generated data.
                real_data = covariates_X_control
                fake_data = self.generator(
                    self.__noise(covariates_X_control_size)).detach()
                d_error, d_pred_real, d_pred_fake = self.__train_discriminator(
                    d_optimizer, real_data, fake_data)
                total_D_loss += d_error
                total_d_pred_real += d_pred_real
                total_d_pred_fake += d_pred_fake

                # 2. Train the generator on a fresh batch of generated data.
                fake_data = self.generator(
                    self.__noise(covariates_X_control_size))
                error_g, prop_loss = self.__train_generator(
                    g_optimizer, fake_data, BETA, ps_score_control, device)
                total_G_loss += error_g
                total_prop_loss += prop_loss

            if epoch % 1000 == 0:
                print("Epoch: {0}, D_loss: {1}, D_score_real: {2}, "
                      "D_score_fake: {3}, G_loss: {4}, Prop_loss: {5}"
                      .format(epoch, total_D_loss, total_d_pred_real,
                              total_d_pred_fake, total_G_loss, total_prop_loss))

    def eval_GAN(self, eval_size, device):
        treated_g = self.generator(self.__noise(eval_size))
        ps_score_list_treated = self.__get_propensity_score(treated_g, device)
        return treated_g, ps_score_list_treated

    def __cal_propensity_loss(self, ps_score_control, gen_treated, device):
        # Squared difference between the propensity scores of the generated
        # (treated) samples and the control samples.
        ps_score_list_treated = self.__get_propensity_score(gen_treated, device)
        ps_score_treated = torch.tensor(ps_score_list_treated).to(device)
        ps_score_control = ps_score_control.to(device)
        prop_loss = torch.sum(
            (torch.sub(ps_score_treated.float(), ps_score_control.float())) ** 2)
        return prop_loss

    def __get_propensity_score(self, gen_treated, device):
        if self.ps_model_type == Constants.PS_MODEL_NN:
            return self.__get_propensity_score_NN(gen_treated, device)
        else:
            return self.__get_propensity_score_LR(gen_treated)

    def __get_propensity_score_LR(self, gen_treated):
        ps_score_list_treated = self.ps_model.predict_proba(
            gen_treated.cpu().detach().numpy())[:, -1].tolist()
        return ps_score_list_treated

    def __get_propensity_score_NN(self, gen_treated, device):
        # Label every generated sample as treated.
        Y = np.ones(gen_treated.size(0))
        eval_set = Utils.convert_to_tensor(gen_treated.cpu().detach().numpy(), Y)
        ps_eval_parameters_NN = {"eval_set": eval_set}
        ps_score_list_treated = self.ps_model.eval(ps_eval_parameters_NN, device,
                                                   eval_from_GAN=True)
        return ps_score_list_treated
    @staticmethod
    def __noise(_size):
        # Sample a standard-normal latent batch for the generator.
        n = Variable(torch.normal(mean=0, std=1,
                                  size=(_size, Constants.GAN_GENERATOR_IN_NODES)))
        if torch.cuda.is_available():
            return n.cuda()
        return n

    @staticmethod
    def __weights_init(m):
        if type(m) == nn.Linear:
            nn.init.xavier_uniform_(m.weight)
            torch.nn.init.zeros_(m.bias)

    @staticmethod
    def __real_data_target(size):
        # Target label 1 for real samples.
        data = Variable(torch.ones(size, 1))
        if torch.cuda.is_available():
            return data.cuda()
        return data

    @staticmethod
    def __fake_data_target(size):
        # Target label 0 for generated samples.
        data = Variable(torch.zeros(size, 1))
        if torch.cuda.is_available():
            return data.cuda()
        return data

    def __train_discriminator(self, optimizer, real_data, fake_data):
        # Reset gradients
        optimizer.zero_grad()

        # 1.1 Train on real data
        prediction_real = self.discriminator(real_data)
        real_score = torch.mean(prediction_real).item()
        # Calculate error and backpropagate
        error_real = self.loss(prediction_real,
                               self.__real_data_target(real_data.size(0)))
        error_real.backward()

        # 1.2 Train on fake data
        prediction_fake = self.discriminator(fake_data)
        fake_score = torch.mean(prediction_fake).item()
        # Calculate error and backpropagate
        error_fake = self.loss(prediction_fake,
                               self.__fake_data_target(real_data.size(0)))
        error_fake.backward()

        # 1.3 Update weights with the accumulated gradients
        optimizer.step()
        loss_D = error_real + error_fake
        return loss_D.item(), real_score, fake_score

    def __train_generator(self, optimizer, fake_data, BETA,
                          ps_score_control, device):
        # 2. Train Generator
        # Reset gradients
        optimizer.zero_grad()

        # Score the generated batch with the discriminator.
        predicted_D = self.discriminator(fake_data)

        # Adversarial loss plus BETA-weighted propensity-score matching loss.
        ps_score_control = ps_score_control.to(device)
        fake_data = fake_data.to(device)
        error_g = self.loss(predicted_D,
                            self.__real_data_target(predicted_D.size(0)))
        prop_loss = self.__cal_propensity_loss(ps_score_control,
                                               fake_data, device)
        error = error_g + (BETA * prop_loss)
        error.backward()

        # Update weights with gradients
        optimizer.step()
        return error_g.item(), prop_loss.item()
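
# ---------------------------------------------------------------------------
# A minimal usage sketch of GAN_Manager (illustration only). Assumptions not
# taken from the code above: the covariate dimensionality (25), the synthetic
# control-group tensors, the fitted scikit-learn LogisticRegression standing
# in for the propensity model, and the "LR" type flag -- any value other than
# Constants.PS_MODEL_NN routes scoring through predict_proba.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from sklearn.linear_model import LogisticRegression

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Control-group covariates, precomputed propensity scores, factual outcomes.
    X_control = torch.randn(500, 25)
    ps_control = torch.rand(500, 1)
    y_f = torch.randn(500, 1)
    train_set = torch.utils.data.TensorDataset(X_control, ps_control, y_f)

    # Stand-in propensity model; the binary labels here are synthetic.
    ps_model = LogisticRegression().fit(
        X_control.numpy(), (ps_control.numpy().ravel() > 0.5).astype(int))

    gan = GAN_Manager(discriminator_in_nodes=25, generator_out_nodes=25,
                      ps_model=ps_model, ps_model_type="LR", device=device)
    gan.train_GAN({"epochs": 1000, "train_set": train_set, "lr": 1e-4,
                   "shuffle": True, "batch_size": 64, "BETA": 1.0}, device)
    gen_treated, ps_treated = gan.eval_GAN(eval_size=500, device=device)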

# --- CycleGAN-style training-script fragment. The models, `opt`, and imports
# --- such as `itertools` are defined earlier in the script. The `if`
# --- condition is not present in the excerpt; `opt.resume` is assumed here.
if opt.resume:
    # Load state dicts
    netG_A2B.load_state_dict(torch.load(opt.generator_A2B))
    netG_B2A.load_state_dict(torch.load(opt.generator_B2A))
    netD_A.load_state_dict(torch.load(opt.discriminator_A))
    netD_B.load_state_dict(torch.load(opt.discriminator_B))

    # Set models to test mode
    netG_A2B.eval()
    netG_B2A.eval()
    netD_A.eval()
    netD_B.eval()
else:
    netG_A2B.apply(weights_init_normal)
    netG_B2A.apply(weights_init_normal)
    netD_A.apply(weights_init_normal)
    netD_B.apply(weights_init_normal)

# Losses
criterion_GAN = torch.nn.MSELoss()
criterion_cycle = torch.nn.L1Loss()
criterion_identity = torch.nn.L1Loss()
criterion_BCE = torch.nn.BCEWithLogitsLoss()

# Optimizers & LR schedulers
optimizer_G = torch.optim.Adam(itertools.chain(netG_A2B.parameters(),
                                               netG_B2A.parameters()),
                               lr=opt.lr, betas=(0.5, 0.999))
# The excerpt truncates the next call mid-argument; betas are assumed to
# match optimizer_G.
optimizer_D_A = torch.optim.Adam(netD_A.parameters(), lr=opt.lr,
                                 betas=(0.5, 0.999))
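
# The excerpt is cut off before the LR schedulers that the section comment
# above announces. As a sketch only (opt.n_epochs, opt.epoch and
# opt.decay_epoch are assumed argument-parser attributes, not taken from this
# excerpt), a common pairing is a LambdaLR with linear decay to zero:
def lambda_rule(epoch):
    # Constant LR until decay_epoch, then linear decay towards 0 at n_epochs.
    return 1.0 - max(0, epoch + opt.epoch - opt.decay_epoch) / float(
        opt.n_epochs - opt.decay_epoch)

lr_scheduler_G = torch.optim.lr_scheduler.LambdaLR(optimizer_G,
                                                   lr_lambda=lambda_rule)
lr_scheduler_D_A = torch.optim.lr_scheduler.LambdaLR(optimizer_D_A,
                                                     lr_lambda=lambda_rule)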