class RNN(object):
    """Encoder-decoder sequence model trained with teacher forcing.

    Relies on project-local `Encoder` / `Decoder` modules. Token id 0 is
    start-of-sequence (SOS) and token id 1 is end-of-sequence (EOS).
    """

    def __init__(self, input_size, output_size):
        super(RNN, self).__init__()
        self.encoder = Encoder(input_size)
        self.decoder = Decoder(output_size)
        self.loss = nn.CrossEntropyLoss()
        self.encoder_optimizer = optim.Adam(self.encoder.parameters())
        self.decoder_optimizer = optim.Adam(self.decoder.parameters())
        # SOS/EOS markers stored as 1x1 LongTensors so they can be fed to
        # the decoder like any other token.
        sos, eos = torch.LongTensor(1, 1).zero_(), torch.LongTensor(1, 1).zero_()
        sos[0, 0], eos[0, 0] = 0, 1
        self.sos, self.eos = sos, eos

    def train(self, input, target):
        """Run one optimization step over a single (input, target) pair.

        Returns (mean_loss, outputs) where outputs is a list of per-step
        argmax predictions.
        """
        # Bug fix: build the decoder sequence locally instead of mutating
        # the caller's `target` list — the original insert/append added an
        # extra SOS/EOS on every call with the same list.
        sequence = [self.sos] + list(target) + [self.eos]

        self.encoder_optimizer.zero_grad()
        self.decoder_optimizer.zero_grad()

        # Encoder: fold the whole input sequence into a hidden state.
        hidden_state = self.encoder.first_hidden()
        for ivec in input:
            _, hidden_state = self.encoder.forward(Variable(ivec), hidden_state)

        # Decoder: teacher forcing — feed token i, score against token i+1.
        total_loss, outputs = 0, []
        for i in range(len(sequence) - 1):
            _, softmax, hidden_state = self.decoder.forward(Variable(sequence[i]),
                                                            hidden_state)
            outputs.append(np.argmax(softmax.data.numpy(), 1)[:, np.newaxis])
            total_loss += self.loss(softmax, Variable(sequence[i + 1][0]))

        total_loss /= len(outputs)
        total_loss.backward()

        self.decoder_optimizer.step()
        self.encoder_optimizer.step()

        # .item() requires PyTorch >= 0.4 — consistent with the rest of
        # this file (which uses .item(), torch.device, requires_grad_).
        return total_loss.item(), outputs

    def eval(self, input, max_length=1000):
        """Greedy decoding until EOS (token id 1) or `max_length` tokens.

        `max_length` guards against a model that never emits EOS, which
        previously made this loop run forever.
        """
        hidden_state = self.encoder.first_hidden()

        # Encoder
        for ivec in input:
            _, hidden_state = self.encoder.forward(Variable(ivec), hidden_state)

        sentence = []
        token = self.sos

        # Decoder
        while token.data[0, 0] != 1 and len(sentence) < max_length:
            output, _, hidden_state = self.decoder.forward(token, hidden_state)
            word = np.argmax(output.data.numpy()).reshape((1, 1))
            token = Variable(torch.LongTensor(word))
            sentence.append(word)

        return sentence

    def save(self):
        """Checkpoint both sub-networks under models/."""
        torch.save(self.encoder.state_dict(), "models/encoder.ckpt")
        torch.save(self.decoder.state_dict(), "models/decoder.ckpt")
class Trainer(object):
    """Training harness for the CelebA attribute-editing networks.

    Wires together an Encoder (E), three attribute Transformers (hair /
    gender / smiling), a Reconstructor (R) and three Discriminators in a
    WGAN-GP setup, and runs the alternating D / (E,T,R) training loop.
    """

    def __init__(self, celeba_loader, config):
        # Miscellaneous
        self.use_tensorboard = config.use_tensorboard
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

        # Data loader
        self.dataload = celeba_loader

        # Model configurations
        self.c64 = config.c64
        self.c256 = config.c256
        self.c2048 = config.c2048
        self.rb6 = config.rb6
        self.attr_dim = config.attr_dim
        self.hair_dim = config.hair_dim

        # Training configurations
        self.selected_attrs = config.selected_attrs
        self.train_iters = config.train_iters
        self.num_iters_decay = config.num_iters_decay
        self.n_critic = config.n_critic
        self.d_lr = config.d_lr
        self.r_lr = config.r_lr
        self.t_lr = config.t_lr
        self.e_lr = config.e_lr
        self.decay_rate = config.decay_rate
        self.beta1 = config.beta1
        self.beta2 = config.beta2
        self.lambda_cls = config.lambda_cls
        self.lambda_cyc = config.lambda_cyc
        self.lambda_gp = config.lambda_gp

        # Test configurations
        self.test_iters = config.test_iters

        # Directories
        self.sample_dir = config.sample_dir
        self.model_save_dir = config.model_save_dir
        self.result_dir = config.result_dir
        self.log_dir = config.log_dir

        # Step sizes
        self.log_step = config.log_step
        self.sample_step = config.sample_step
        self.model_save_step = config.model_save_step
        self.lr_update_step = config.lr_update_step

        # Initial models
        self.build_models()
        if self.use_tensorboard:
            self.build_tensorboard()

    def build_models(self):
        """Instantiate all networks and optimizers and move them to device."""
        self.E = Encoder(self.c64, self.rb6)
        self.T_Hair = Transformer(self.hair_dim, self.c256, self.rb6)
        self.T_Gender = Transformer(self.attr_dim, self.c256, self.rb6)
        # NOTE(review): "Smailing" is a misspelling of "Smiling"; the public
        # attribute names are kept so external callers keep working.
        self.T_Smailing = Transformer(self.attr_dim, self.c256, self.rb6)
        self.R = Reconstructor(self.c256)
        self.D_Hair = Discriminator(self.hair_dim, self.c64)
        self.D_Gender = Discriminator(self.attr_dim, self.c64)
        self.D_Smailing = Discriminator(self.attr_dim, self.c64)

        self.e_optim = torch.optim.Adam(self.E.parameters(), self.e_lr,
                                        [self.beta1, self.beta2])
        self.th_optim = torch.optim.Adam(self.T_Hair.parameters(), self.t_lr,
                                         [self.beta1, self.beta2])
        self.tg_optim = torch.optim.Adam(self.T_Gender.parameters(), self.t_lr,
                                         [self.beta1, self.beta2])
        self.ts_optim = torch.optim.Adam(self.T_Smailing.parameters(), self.t_lr,
                                         [self.beta1, self.beta2])
        self.r_optim = torch.optim.Adam(self.R.parameters(), self.r_lr,
                                        [self.beta1, self.beta2])
        self.dh_optim = torch.optim.Adam(self.D_Hair.parameters(), self.d_lr,
                                         [self.beta1, self.beta2])
        self.dg_optim = torch.optim.Adam(self.D_Gender.parameters(), self.d_lr,
                                         [self.beta1, self.beta2])
        self.ds_optim = torch.optim.Adam(self.D_Smailing.parameters(), self.d_lr,
                                         [self.beta1, self.beta2])

        self.print_network(self.E, 'Encoder')
        self.print_network(self.T_Hair, 'Transformer for Hair Color')
        self.print_network(self.T_Gender, 'Transformer for Gender')
        self.print_network(self.T_Smailing, 'Transformer for Smailing')
        self.print_network(self.R, 'Reconstructor')
        self.print_network(self.D_Hair, 'D for Hair Color')
        self.print_network(self.D_Gender, 'D for Gender')
        self.print_network(self.D_Smailing, 'D for Smailing')

        self.E.to(self.device)
        self.T_Hair.to(self.device)
        self.T_Gender.to(self.device)
        self.T_Smailing.to(self.device)
        self.R.to(self.device)
        self.D_Gender.to(self.device)
        self.D_Smailing.to(self.device)
        self.D_Hair.to(self.device)

    def print_network(self, model, name):
        """Print out the network information."""
        num_params = 0
        for p in model.parameters():
            num_params += p.numel()
        print(name)
        print("The number of parameters: {}".format(num_params))
        print(model)

    def build_tensorboard(self):
        """Build a tensorboard logger."""
        from logger import Logger
        self.logger = Logger(self.log_dir)

    def gradient_penalty(self, y, x):
        """Compute gradient penalty: (L2_norm(dy/dx) - 1)**2."""
        weight = torch.ones(y.size()).to(self.device)
        dydx = torch.autograd.grad(outputs=y,
                                   inputs=x,
                                   grad_outputs=weight,
                                   retain_graph=True,
                                   create_graph=True,
                                   only_inputs=True)[0]
        dydx = dydx.view(dydx.size(0), -1)
        dydx_l2norm = torch.sqrt(torch.sum(dydx ** 2, dim=1))
        return torch.mean((dydx_l2norm - 1) ** 2)

    def reset_grad(self):
        """Zero the gradient buffers of every optimizer."""
        self.e_optim.zero_grad()
        self.th_optim.zero_grad()
        self.tg_optim.zero_grad()
        self.ts_optim.zero_grad()
        self.r_optim.zero_grad()
        self.dh_optim.zero_grad()
        self.dg_optim.zero_grad()
        self.ds_optim.zero_grad()

    def update_lr(self, e_lr, d_lr, r_lr, t_lr):
        """Decay learning rates of the generator and discriminator."""
        for param_group in self.e_optim.param_groups:
            param_group['lr'] = e_lr
        for param_group in self.dh_optim.param_groups:
            param_group['lr'] = d_lr
        for param_group in self.dg_optim.param_groups:
            param_group['lr'] = d_lr
        for param_group in self.ds_optim.param_groups:
            param_group['lr'] = d_lr
        for param_group in self.r_optim.param_groups:
            param_group['lr'] = r_lr
        for param_group in self.th_optim.param_groups:
            param_group['lr'] = t_lr
        for param_group in self.tg_optim.param_groups:
            param_group['lr'] = t_lr
        for param_group in self.ts_optim.param_groups:
            param_group['lr'] = t_lr

    def create_labels(self, c_org, c_dim=5, selected_attrs=None):
        """Generate target domain labels for debugging and testing.

        Hair-color attributes are mutually exclusive (one-hot); every other
        attribute is toggled independently.
        """
        # Get hair color indices.
        hair_color_indices = []
        for i, attr_name in enumerate(selected_attrs):
            if attr_name in ['Black_Hair', 'Blond_Hair', 'Brown_Hair']:
                hair_color_indices.append(i)

        c_trg_list = []
        for i in range(c_dim):
            c_trg = c_org.clone()
            if i in hair_color_indices:
                # Set one hair color to 1 and the rest to 0.
                c_trg[:, i] = 1
                for j in hair_color_indices:
                    if j != i:
                        c_trg[:, j] = 0
            else:
                c_trg[:, i] = (c_trg[:, i] == 0)  # Reverse attribute value.
            c_trg_list.append(c_trg.to(self.device))
        return c_trg_list

    def denorm(self, x):
        """Convert the range from [-1, 1] to [0, 1]."""
        out = (x + 1) / 2
        return out.clamp_(0, 1)

    def train(self):
        """Run the full alternating D / (E, T, R) training loop."""
        data_loader = self.dataload

        # Fetch fixed inputs for debugging.
        data_iter = iter(data_loader)
        x_fixed, c_org = next(data_iter)
        x_fixed = x_fixed.to(self.device)
        c_fixed_list = self.create_labels(c_org, 5, self.selected_attrs)

        # Learning rates decayed in place below.
        d_lr = self.d_lr
        r_lr = self.r_lr
        t_lr = self.t_lr
        e_lr = self.e_lr

        # Start training
        print('Starting point==============================')
        start_time = time.time()
        for i in range(0, self.train_iters):

            # =============================================================== #
            # 1. Preprocess input data                                        #
            # =============================================================== #

            # Fetch real images and labels; restart the iterator when the
            # epoch is exhausted.
            try:
                x_real, label_real = next(data_iter)
            except StopIteration:
                # Bug fix: was a bare `except:` that silently swallowed any
                # error raised by the data loader, not just exhaustion.
                data_iter = iter(data_loader)
                x_real, label_real = next(data_iter)

            # Target ("fake") labels are a random permutation of the batch.
            rand_idx = torch.randperm(label_real.size(0))
            label_fake = label_real[rand_idx]

            x_real = x_real.to(self.device)

            # Labels for hair color
            label_h_real = label_real[:, 0:3]
            label_h_fake = label_fake[:, 0:3]
            # Labels for gender
            label_g_real = label_real[:, 3:4]
            label_g_fake = label_fake[:, 3:4]
            # Labels for smiling
            label_s_real = label_real[:, 4:]
            label_s_fake = label_fake[:, 4:]

            label_h_real = label_h_real.to(self.device)
            label_h_fake = label_h_fake.to(self.device)
            label_g_real = label_g_real.to(self.device)
            label_g_fake = label_g_fake.to(self.device)
            label_s_real = label_s_real.to(self.device)
            label_s_fake = label_s_fake.to(self.device)

            # =============================================================== #
            # 2. Train the discriminator                                      #
            # =============================================================== #

            # Compute loss with real images.
            h_src, h_cls = self.D_Hair(x_real)
            d_h_loss_real = -torch.mean(h_src)
            d_h_loss_cls = F.binary_cross_entropy_with_logits(
                h_cls, label_h_real, reduction='sum') / h_cls.size(0)

            g_src, g_cls = self.D_Gender(x_real)
            d_g_loss_real = -torch.mean(g_src)
            d_g_loss_cls = F.binary_cross_entropy_with_logits(
                g_cls, label_g_real, reduction='sum') / g_cls.size(0)

            s_src, s_cls = self.D_Smailing(x_real)
            d_s_loss_real = -torch.mean(s_src)
            d_s_loss_cls = F.binary_cross_entropy_with_logits(
                s_cls, label_s_real, reduction='sum') / s_cls.size(0)

            # Generate fake images and compute loss.
            # Retrieve features of real image; detach so D updates do not
            # back-propagate into E / T / R.
            features = self.E(x_real)
            t_h_features = self.T_Hair(features.detach(), label_h_fake)
            t_g_features = self.T_Gender(features.detach(), label_g_fake)
            t_s_features = self.T_Smailing(features.detach(), label_s_fake)
            # Reconstruct images from transformed attributes.
            x_h_fake = self.R(t_h_features.detach())
            x_g_fake = self.R(t_g_features.detach())
            x_s_fake = self.R(t_s_features.detach())

            # Compute loss with fake images.
            h_src, h_cls = self.D_Hair(x_h_fake.detach())
            d_h_loss_fake = torch.mean(h_src)
            g_src, g_cls = self.D_Gender(x_g_fake.detach())
            d_g_loss_fake = torch.mean(g_src)
            s_src, s_cls = self.D_Smailing(x_s_fake.detach())
            d_s_loss_fake = torch.mean(s_src)

            # Compute loss for gradient penalty (WGAN-GP interpolates).
            alpha = torch.rand(x_real.size(0), 1, 1, 1).to(self.device)
            x_h_hat = (alpha * x_real.data + (1 - alpha) * x_h_fake.data).requires_grad_(True)
            x_g_hat = (alpha * x_real.data + (1 - alpha) * x_g_fake.data).requires_grad_(True)
            x_s_hat = (alpha * x_real.data + (1 - alpha) * x_s_fake.data).requires_grad_(True)

            out_src, _ = self.D_Hair(x_h_hat)
            d_h_loss_gp = self.gradient_penalty(out_src, x_h_hat)
            out_src, _ = self.D_Gender(x_g_hat)
            d_g_loss_gp = self.gradient_penalty(out_src, x_g_hat)
            out_src, _ = self.D_Smailing(x_s_hat)
            d_s_loss_gp = self.gradient_penalty(out_src, x_s_hat)

            # Backward and optimize.
            d_loss = d_h_loss_real + d_g_loss_real + d_s_loss_real + \
                d_h_loss_fake + d_g_loss_fake + d_s_loss_fake + \
                self.lambda_gp * (d_h_loss_gp + d_g_loss_gp + d_s_loss_gp) + \
                self.lambda_cls * (d_h_loss_cls + d_g_loss_cls + d_s_loss_cls)
            self.reset_grad()
            d_loss.backward()
            self.dh_optim.step()
            self.dg_optim.step()
            self.ds_optim.step()

            # Logging
            loss = {}
            loss['D/h_loss_real'] = d_h_loss_real.item()
            loss['D/g_loss_real'] = d_g_loss_real.item()
            loss['D/s_loss_real'] = d_s_loss_real.item()
            loss['D/h_loss_fake'] = d_h_loss_fake.item()
            loss['D/g_loss_fake'] = d_g_loss_fake.item()
            loss['D/s_loss_fake'] = d_s_loss_fake.item()
            loss['D/h_loss_cls'] = d_h_loss_cls.item()
            loss['D/g_loss_cls'] = d_g_loss_cls.item()
            loss['D/s_loss_cls'] = d_s_loss_cls.item()
            loss['D/h_loss_gp'] = d_h_loss_gp.item()
            loss['D/g_loss_gp'] = d_g_loss_gp.item()
            loss['D/s_loss_gp'] = d_s_loss_gp.item()

            # =============================================================== #
            # 3. Train the encoder, transformer and reconstructor             #
            # =============================================================== #
            if (i + 1) % self.n_critic == 0:
                # Generate fake images and compute loss (no detach here —
                # gradients must flow back into E, T and R).
                features = self.E(x_real)
                t_h_features = self.T_Hair(features, label_h_fake)
                t_g_features = self.T_Gender(features, label_g_fake)
                t_s_features = self.T_Smailing(features, label_s_fake)
                x_h_fake = self.R(t_h_features)
                x_g_fake = self.R(t_g_features)
                x_s_fake = self.R(t_s_features)

                # Adversarial and classification losses on fake images.
                h_src, h_cls = self.D_Hair(x_h_fake)
                etr_h_loss_fake = -torch.mean(h_src)
                etr_h_loss_cls = F.binary_cross_entropy_with_logits(
                    h_cls, label_h_fake, reduction='sum') / h_cls.size(0)
                g_src, g_cls = self.D_Gender(x_g_fake)
                etr_g_loss_fake = -torch.mean(g_src)
                etr_g_loss_cls = F.binary_cross_entropy_with_logits(
                    g_cls, label_g_fake, reduction='sum') / g_cls.size(0)
                s_src, s_cls = self.D_Smailing(x_s_fake)
                etr_s_loss_fake = -torch.mean(s_src)
                etr_s_loss_cls = F.binary_cross_entropy_with_logits(
                    s_cls, label_s_fake, reduction='sum') / s_cls.size(0)

                # Real - Encoder - Reconstructor - Real loss.
                x_re = self.R(features)
                er_loss_cyc = torch.mean(torch.abs(x_re - x_real))

                # Feature-space cycle losses:
                # Real - E - T  vs  Real - E - T - R - E.
                h_fake_features = self.E(x_h_fake)
                g_fake_features = self.E(x_g_fake)
                s_fake_features = self.E(x_s_fake)
                etr_h_loss_cyc = torch.mean(torch.abs(t_h_features - h_fake_features))
                etr_g_loss_cyc = torch.mean(torch.abs(t_g_features - g_fake_features))
                etr_s_loss_cyc = torch.mean(torch.abs(t_s_features - s_fake_features))

                # Backward and optimize.
                etr_loss = etr_h_loss_fake + etr_g_loss_fake + etr_s_loss_fake + \
                    self.lambda_cls * (etr_h_loss_cls + etr_g_loss_cls + etr_s_loss_cls) + \
                    self.lambda_cyc * (er_loss_cyc + etr_h_loss_cyc + etr_g_loss_cyc + etr_s_loss_cyc)
                self.reset_grad()
                etr_loss.backward()
                self.e_optim.step()
                self.th_optim.step()
                self.tg_optim.step()
                self.ts_optim.step()
                self.r_optim.step()

                # Logging.
                loss['ETR/h_loss_fake'] = etr_h_loss_fake.item()
                loss['ETR/g_loss_fake'] = etr_g_loss_fake.item()
                loss['ETR/s_loss_fake'] = etr_s_loss_fake.item()
                loss['ETR/h_loss_cls'] = etr_h_loss_cls.item()
                loss['ETR/g_loss_cls'] = etr_g_loss_cls.item()
                loss['ETR/s_loss_cls'] = etr_s_loss_cls.item()
                loss['ER/er_loss_cyc'] = er_loss_cyc.item()
                loss['ETR/h_loss_cyc'] = etr_h_loss_cyc.item()
                loss['ETR/g_loss_cyc'] = etr_g_loss_cyc.item()
                loss['ETR/s_loss_cyc'] = etr_s_loss_cyc.item()

            # =============================================================== #
            # 4. Miscellaneous                                                #
            # =============================================================== #

            # Translate fixed images for debugging.
            if (i + 1) % self.sample_step == 0:
                with torch.no_grad():
                    x_fake_list = [x_fixed]
                    for c_fixed in c_fixed_list:
                        # NOTE(review): the transformers are chained (hair ->
                        # gender -> smiling) here, unlike the training pass
                        # which applies each to the raw features — confirm
                        # this is intentional.
                        xf = self.E(x_fixed)
                        xth = self.T_Hair(xf, c_fixed[:, 0:3])
                        xtg = self.T_Gender(xth, c_fixed[:, 3:4])
                        xts = self.T_Smailing(xtg, c_fixed[:, 4:5])
                        x_fake_list.append(self.R(xts))
                    x_concat = torch.cat(x_fake_list, dim=3)
                    sample_path = os.path.join(self.sample_dir, '{}-images.jpg'.format(i + 1))
                    save_image(self.denorm(x_concat.data.cpu()), sample_path, nrow=1, padding=0)
                    print('Saved real and fake images into {}...'.format(sample_path))

            # Print out training information.
            if (i + 1) % self.log_step == 0:
                et = time.time() - start_time
                et = str(datetime.timedelta(seconds=et))[:-7]
                log = "Elapsed [{}], Iteration [{}/{}]".format(et, i + 1, self.train_iters)
                for tag, value in loss.items():
                    log += ", {}: {:.4f}".format(tag, value)
                print(log)
                if self.use_tensorboard:
                    for tag, value in loss.items():
                        self.logger.scalar_summary(tag, value, i + 1)

            # Save model checkpoints.
            if (i + 1) % self.model_save_step == 0:
                E_path = os.path.join(self.model_save_dir, '{}-E.ckpt'.format(i + 1))
                D_h_path = os.path.join(self.model_save_dir, '{}-D_h.ckpt'.format(i + 1))
                D_g_path = os.path.join(self.model_save_dir, '{}-D_g.ckpt'.format(i + 1))
                D_s_path = os.path.join(self.model_save_dir, '{}-D_s.ckpt'.format(i + 1))
                R_path = os.path.join(self.model_save_dir, '{}-R.ckpt'.format(i + 1))
                T_h_path = os.path.join(self.model_save_dir, '{}-T_h.ckpt'.format(i + 1))
                T_g_path = os.path.join(self.model_save_dir, '{}-T_g.ckpt'.format(i + 1))
                T_s_path = os.path.join(self.model_save_dir, '{}-T_s.ckpt'.format(i + 1))
                torch.save(self.E.state_dict(), E_path)
                torch.save(self.D_Hair.state_dict(), D_h_path)
                torch.save(self.D_Gender.state_dict(), D_g_path)
                torch.save(self.D_Smailing.state_dict(), D_s_path)
                torch.save(self.R.state_dict(), R_path)
                torch.save(self.T_Hair.state_dict(), T_h_path)
                torch.save(self.T_Gender.state_dict(), T_g_path)
                torch.save(self.T_Smailing.state_dict(), T_s_path)
                print('Saved model checkpoints into {}...'.format(self.model_save_dir))

            # Decay learning rates.
            if (i + 1) % self.lr_update_step == 0 and (i + 1) > self.num_iters_decay:
                e_lr -= (self.e_lr / float(self.decay_rate))
                d_lr -= (self.d_lr / float(self.decay_rate))
                r_lr -= (self.r_lr / float(self.decay_rate))
                t_lr -= (self.t_lr / float(self.decay_rate))
                self.update_lr(e_lr, d_lr, r_lr, t_lr)
                print('Decayed learning rates, e_lr: {}, d_lr: {}, r_lr: {}, t_lr: {}.'.format(e_lr, d_lr, r_lr, t_lr))
class RNN(object):
    """Encoder-decoder sequence model, optionally resuming from checkpoints.

    Token id 0 is start-of-sequence (SOS) and token id 1 is end-of-sequence
    (EOS); both are needed by eval().
    """

    def __init__(self, input_size, output_size, resume=False):
        super(RNN, self).__init__()
        self.encoder = Encoder(input_size)
        self.decoder = Decoder(output_size)
        self.loss = nn.CrossEntropyLoss()
        self.encoder_optimizer = optim.Adam(self.encoder.parameters())
        self.decoder_optimizer = optim.Adam(self.decoder.parameters())
        # Bug fix: eval() starts decoding from self.sos and stops on the EOS
        # id, but this __init__ never defined them — eval() raised
        # AttributeError. Restore the SOS/EOS markers.
        sos, eos = torch.LongTensor(1, 1).zero_(), torch.LongTensor(1, 1).zero_()
        sos[0, 0], eos[0, 0] = 0, 1
        self.sos, self.eos = sos, eos
        if resume:
            self.encoder.load_state_dict(torch.load("models/encoder.ckpt"))
            self.decoder.load_state_dict(torch.load("models/decoder.ckpt"))

    def train(self, input, target):
        """Run one optimization step; expects `target` to already carry its
        SOS/EOS markers. Returns (mean_loss, per-step argmax predictions)."""
        self.encoder_optimizer.zero_grad()
        self.decoder_optimizer.zero_grad()

        # Encoder: fold the input sequence into a hidden state.
        hidden_state = self.encoder.first_hidden()
        for ivec in input:
            _, hidden_state = self.encoder.forward(ivec, hidden_state)

        # Decoder: teacher forcing — feed token i, score against token i+1.
        total_loss, outputs = 0, []
        for i in range(len(target) - 1):
            _, softmax, hidden_state = self.decoder.forward(
                target[i], hidden_state)
            outputs.append(np.argmax(softmax.data.numpy(), 1)[:, np.newaxis])
            total_loss += self.loss(softmax, target[i + 1].squeeze(1))

        total_loss /= len(outputs)
        total_loss.backward()

        self.decoder_optimizer.step()
        self.encoder_optimizer.step()

        # .item() requires PyTorch >= 0.4, consistent with the rest of this
        # file (the original .data[0] indexing is deprecated there).
        return total_loss.item(), outputs

    def eval(self, input, max_length=1000):
        """Greedy decoding until EOS (token id 1) or `max_length` tokens.

        `max_length` guards against a model that never emits EOS, which
        previously made this loop run forever.
        """
        hidden_state = self.encoder.first_hidden()

        # Encoder
        for ivec in input:
            _, hidden_state = self.encoder.forward(Variable(ivec), hidden_state)

        sentence = []
        token = self.sos

        # Decoder
        while token.data[0, 0] != 1 and len(sentence) < max_length:
            output, _, hidden_state = self.decoder.forward(token, hidden_state)
            word = np.argmax(output.data.numpy()).reshape((1, 1))
            token = Variable(torch.LongTensor(word))
            sentence.append(word)

        return sentence

    def save(self):
        """Checkpoint both sub-networks under models/."""
        torch.save(self.encoder.state_dict(), "models/encoder.ckpt")
        torch.save(self.decoder.state_dict(), "models/decoder.ckpt")
class BiGAN(nn.Module):
    """Bidirectional GAN: jointly trains a generator G, an encoder E and a
    discriminator D over (image, latent) pairs, following the BiGAN/ALI
    least-squares formulation."""

    def __init__(self, config):
        super(BiGAN, self).__init__()
        self._work_type = config.work_type
        self._epochs = config.epochs
        self._batch_size = config.batch_size

        self._encoder_lr = config.encoder_lr
        self._generator_lr = config.generator_lr
        self._discriminator_lr = config.discriminator_lr
        self._latent_dim = config.latent_dim
        self._weight_decay = config.weight_decay

        self._img_shape = (config.input_size, config.input_size)
        self._img_save_path = config.image_save_path
        self._model_save_path = config.model_save_path
        self._device = config.device

        if self._work_type == 'train':
            # Loss function (LSGAN-style adversarial objective).
            self._adversarial_criterion = torch.nn.MSELoss()

            # Initialize generator, encoder and discriminator.
            self._G = Generator(self._latent_dim, self._img_shape).to(self._device)
            self._E = Encoder(self._latent_dim, self._img_shape).to(self._device)
            self._D = Discriminator(self._latent_dim, self._img_shape).to(self._device)

            self._G.apply(self.weights_init)
            self._E.apply(self.weights_init)
            self._D.apply(self.discriminator_weights_init)

            # G and E share one optimizer: both are trained to fool D.
            self._G_optimizer = torch.optim.Adam(
                [{'params': self._G.parameters()}, {'params': self._E.parameters()}],
                lr=self._generator_lr, betas=(0.5, 0.999),
                weight_decay=self._weight_decay)
            self._D_optimizer = torch.optim.Adam(
                self._D.parameters(), lr=self._discriminator_lr, betas=(0.5, 0.999))
            self._G_scheduler = lr_scheduler.ExponentialLR(self._G_optimizer, gamma=0.99)
            self._D_scheduler = lr_scheduler.ExponentialLR(self._D_optimizer, gamma=0.99)

    def train(self, train_loader):
        """Run the full training loop over `train_loader` for self._epochs.

        NOTE(review): this shadows nn.Module.train(mode); kept because
        callers invoke it as model.train(loader).
        """
        Tensor = torch.cuda.FloatTensor if self._device == 'cuda' else torch.FloatTensor
        n_total_steps = len(train_loader)
        for epoch in range(self._epochs):
            for i, (images, _) in enumerate(train_loader):
                # Adversarial ground truths.
                valid = Variable(Tensor(images.size(0), 1).fill_(1), requires_grad=False)
                fake = Variable(Tensor(images.size(0), 1).fill_(0), requires_grad=False)

                # ---------------------
                #  Train Encoder
                # ---------------------
                # Configure input: flatten images to vectors.
                images = images.reshape(-1, np.prod(self._img_shape)).to(self._device)
                # z_ is the encoded latent vector.
                (original_img, z_) = self._E(images)
                predict_encoder = self._D(original_img, z_)

                # ---------------------
                #  Train Generator
                # ---------------------
                # Sample noise as generator input.
                z = Variable(Tensor(np.random.normal(0, 1, (images.shape[0], self._latent_dim))))
                (gen_img, z) = self._G(z)
                predict_generator = self._D(gen_img, z)

                G_loss = (self._adversarial_criterion(predict_generator, valid)
                          + self._adversarial_criterion(predict_encoder, fake)) * 0.5
                self._G_optimizer.zero_grad()
                G_loss.backward()
                self._G_optimizer.step()

                # ---------------------
                #  Train Discriminator
                # ---------------------
                # Fresh forward passes: the previous graph was freed by
                # G_loss.backward().
                z = Variable(Tensor(np.random.normal(0, 1, (images.shape[0], self._latent_dim))))
                (gen_img, z) = self._G(z)
                (original_img, z_) = self._E(images)
                predict_encoder = self._D(original_img, z_)
                predict_generator = self._D(gen_img, z)

                D_loss = (self._adversarial_criterion(predict_encoder, valid)
                          + self._adversarial_criterion(predict_generator, fake)) * 0.5
                self._D_optimizer.zero_grad()
                D_loss.backward()
                self._D_optimizer.step()

                if i % 100 == 0:
                    print(f'Epoch [{epoch+1}/{self._epochs}], Step [{i+1}/{n_total_steps}]')
                    print(f'Generator Loss: {G_loss.item():.4f} Discriminator Loss: {D_loss.item():.4f}')
                if i % 400 == 0:
                    vutils.save_image(gen_img.unsqueeze(1).cpu().data[:64, ], f'{self._img_save_path}/E{epoch}_Iteration{i}_fake.png')
                    vutils.save_image(original_img.unsqueeze(1).cpu().data[:64, ], f'{self._img_save_path}/E{epoch}_Iteration{i}_real.png')
                    print('image saved')
                    print('')

            # Bug fix: step the LR schedulers AFTER the epoch's optimizer
            # updates. The original called them at the top of the epoch
            # loop, which decayed the learning rate before any training and
            # violates the PyTorch >= 1.1 optimizer/scheduler ordering.
            self._G_scheduler.step()
            self._D_scheduler.step()

            if epoch % 100 == 0:
                torch.save(self._G.state_dict(), f'{self._model_save_path}/netG_{epoch}epoch.pth')
                torch.save(self._E.state_dict(), f'{self._model_save_path}/netE_{epoch}epoch.pth')
                torch.save(self._D.state_dict(), f'{self._model_save_path}/netD_{epoch}epoch.pth')

    def weights_init(self, m):
        """Init for G/E: BatchNorm weights ~ N(1.0, 0.02) with zero bias;
        Linear layers get only their bias zeroed (weights keep defaults)."""
        classname = m.__class__.__name__
        if classname.find('BatchNorm') != -1:
            m.weight.data.normal_(1.0, 0.02)
            m.bias.data.fill_(0)
        elif classname.find('Linear') != -1:
            m.bias.data.fill_(0)

    def discriminator_weights_init(self, m):
        """Init for D: like weights_init but BatchNorm std is 0.5."""
        classname = m.__class__.__name__
        if classname.find('BatchNorm') != -1:
            m.weight.data.normal_(1.0, 0.5)
            m.bias.data.fill_(0)
        elif classname.find('Linear') != -1:
            m.bias.data.fill_(0)