def get_modules(opt):
    modules = {}
    disc = Discriminator()
    gen = Generator()
    clf = Classifier()
    if opt.cuda:
        disc = disc.cuda()
        gen = gen.cuda()
        clf = clf.cuda()
    modules['Discriminator'] = disc
    modules['Generator'] = gen
    modules['Classifier'] = clf
    return modules
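# Usage sketch (an assumption, not part of the original source): `opt` only
# needs a boolean `cuda` attribute, e.g. an argparse.Namespace; the module
# classes come from this repo.
import argparse
import torch

opt = argparse.Namespace(cuda=torch.cuda.is_available())
modules = get_modules(opt)
disc, gen, clf = (modules['Discriminator'], modules['Generator'],
                  modules['Classifier'])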
class Solver(object):
    def __init__(self, data_loader, config):
        self.data_loader = data_loader
        self.noise_n = config.noise_n
        self.G_last_act = last_act(config.G_last_act)
        self.D_out_n = config.D_out_n
        self.D_last_act = last_act(config.D_last_act)
        self.G_lr = config.G_lr
        self.D_lr = config.D_lr
        self.beta1 = config.beta1
        self.beta2 = config.beta2
        self.epoch = config.epoch
        self.batch_size = config.batch_size
        self.D_train_step = config.D_train_step
        self.save_image_step = config.save_image_step
        self.log_step = config.log_step
        self.model_save_step = config.model_save_step
        self.model_save_path = config.model_save_path
        self.log_save_path = config.log_save_path
        self.image_save_path = config.image_save_path
        self.use_tensorboard = config.use_tensorboard
        self.pretrained_model = config.pretrained_model

        self.build_model()
        if self.use_tensorboard is not None:
            self.build_tensorboard()
        if self.pretrained_model is not None:
            if len(self.pretrained_model) != 2:
                # Raising a bare string is invalid in Python 3; raise an exception.
                raise ValueError("must have both G and D pretrained parameters, "
                                 "and G is first, D is second")
            self.load_pretrained_model()

    def build_model(self):
        self.G = Generator(self.noise_n, self.G_last_act)
        self.D = Discriminator(self.D_out_n, self.D_last_act)
        self.G_optimizer = torch.optim.Adam(self.G.parameters(), self.G_lr,
                                            [self.beta1, self.beta2])
        self.D_optimizer = torch.optim.Adam(self.D.parameters(), self.D_lr,
                                            [self.beta1, self.beta2])
        if torch.cuda.is_available():
            self.G.cuda()
            self.D.cuda()

    def build_tensorboard(self):
        from commons.logger import Logger
        self.logger = Logger(self.log_save_path)

    def load_pretrained_model(self):
        self.G.load_state_dict(torch.load(self.pretrained_model[0]))
        self.D.load_state_dict(torch.load(self.pretrained_model[1]))

    def reset_grad(self):
        self.G_optimizer.zero_grad()
        self.D_optimizer.zero_grad()

    def to_var(self, x, volatile=False):
        if torch.cuda.is_available():
            x = x.cuda()
        return Variable(x, volatile=volatile)

    def train(self):
        bce_loss = nn.BCELoss()
        print(len(self.data_loader))
        for e in range(self.epoch):
            for i, batch_images in enumerate(self.data_loader):
                batch_size = batch_images.size(0)
                real_x = self.to_var(batch_images)
                noise_x = self.to_var(
                    torch.FloatTensor(noise_vector(batch_size, self.noise_n)))
                real_label = self.to_var(torch.FloatTensor(batch_size).fill_(1.))
                fake_label = self.to_var(torch.FloatTensor(batch_size).fill_(0.))

                # Train D
                fake_x = self.G(noise_x)
                real_out = self.D(real_x)
                fake_out = self.D(fake_x.detach())
                D_real = bce_loss(real_out, real_label)
                D_fake = bce_loss(fake_out, fake_label)
                D_loss = D_real + D_fake
                self.reset_grad()
                D_loss.backward()
                self.D_optimizer.step()

                # Log
                loss = {}
                loss['D/loss_real'] = D_real.data[0]
                loss['D/loss_fake'] = D_fake.data[0]
                loss['D/loss'] = D_loss.data[0]

                # Train G
                if (i + 1) % self.D_train_step == 0:
                    # noise_x = self.to_var(torch.FloatTensor(noise_vector(batch_size, self.noise_n)))
                    fake_out = self.D(self.G(noise_x))
                    G_loss = bce_loss(fake_out, real_label)
                    self.reset_grad()
                    G_loss.backward()
                    self.G_optimizer.step()
                    loss['G/loss'] = G_loss.data[0]

                # Print log
                if (i + 1) % self.log_step == 0:
                    log = "Epoch: {}/{}, Iter: {}/{}".format(
                        e + 1, self.epoch, i + 1, len(self.data_loader))
                    for tag, value in loss.items():
                        log += ", {}: {:.4f}".format(tag, value)
                    print(log)
                    if self.use_tensorboard:
                        for tag, value in loss.items():
                            self.logger.scalar_summary(
                                tag, value, e * len(self.data_loader) + i + 1)

            # Save images
            if (e + 1) % self.save_image_step == 0:
                noise_x = self.to_var(
                    torch.FloatTensor(noise_vector(32, self.noise_n)))
                fake_image = self.G(noise_x)
                save_image(
                    fake_image.data,
                    os.path.join(self.image_save_path,
                                 "{}_fake.png".format(e + 1)))

            # Save models
            if (e + 1) % self.model_save_step == 0:
                torch.save(
                    self.G.state_dict(),
                    os.path.join(self.model_save_path, "{}_G.pth".format(e + 1)))
                torch.save(
                    self.D.state_dict(),
                    os.path.join(self.model_save_path, "{}_D.pth".format(e + 1)))
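# `noise_vector` is used above but not defined in this file. A minimal sketch
# of what it presumably returns (an assumption, not the original helper): a
# (batch_size, noise_n) array of standard-normal latent codes, which the
# caller wraps in torch.FloatTensor.
import numpy as np

def noise_vector(batch_size, noise_n):
    # One row of N(0, 1) samples per generated image.
    return np.random.randn(batch_size, noise_n).astype('float32')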
class ModuleTrain:
    def __init__(self, opt, best_loss=0.2):
        self.opt = opt
        self.best_loss = best_loss  # the model is only saved when the loss beats this value
        self.netd = Discriminator(self.opt)
        self.netg = Generator(self.opt)
        self.use_gpu = False

        # load pretrained models if present
        if os.path.exists(self.opt.netd_path):
            self.load_netd(self.opt.netd_path)
        else:
            print('[Load model] error: %s not exist !!!' % self.opt.netd_path)
        if os.path.exists(self.opt.netg_path):
            self.load_netg(self.opt.netg_path)
        else:
            print('[Load model] error: %s not exist !!!' % self.opt.netg_path)

        # DataLoader initialisation
        self.transform_train = T.Compose([
            T.Resize((self.opt.img_size, self.opt.img_size)),
            T.ToTensor(),
            T.Normalize(mean=[.5, .5, .5], std=[.5, .5, .5]),
        ])
        train_dataset = ImageFolder(root=self.opt.data_path,
                                    transform=self.transform_train)
        self.train_loader = DataLoader(dataset=train_dataset,
                                       batch_size=self.opt.batch_size,
                                       shuffle=True,
                                       num_workers=self.opt.num_workers,
                                       drop_last=True)

        # optimizers and loss function
        # self.optimizer = optim.SGD(self.model.parameters(), lr=self.lr, momentum=0.5)
        self.optimizer_g = optim.Adam(self.netg.parameters(),
                                      lr=self.opt.lr1,
                                      betas=(self.opt.beta1, 0.999))
        self.optimizer_d = optim.Adam(self.netd.parameters(),
                                      lr=self.opt.lr2,
                                      betas=(self.opt.beta1, 0.999))
        self.criterion = torch.nn.BCELoss()

        self.true_labels = Variable(torch.ones(self.opt.batch_size))
        self.fake_labels = Variable(torch.zeros(self.opt.batch_size))
        self.fix_noises = Variable(
            torch.randn(self.opt.batch_size, self.opt.nz, 1, 1))
        self.noises = Variable(
            torch.randn(self.opt.batch_size, self.opt.nz, 1, 1))

        # gpu or cpu
        self.use_gpu = bool(self.opt.use_gpu and torch.cuda.is_available())
        if self.use_gpu:
            print('[use gpu] ...')
            self.netd.cuda()
            self.netg.cuda()
            self.criterion.cuda()
            self.true_labels = self.true_labels.cuda()
            self.fake_labels = self.fake_labels.cuda()
            self.fix_noises = self.fix_noises.cuda()
            self.noises = self.noises.cuda()
        else:
            print('[use cpu] ...')

    def train(self, save_best=True):
        print('[train] epoch: %d' % self.opt.max_epoch)
        for epoch_i in range(self.opt.max_epoch):
            loss_netd = 0.0
            loss_netg = 0.0
            print('================================================')
            for ii, (img, target) in enumerate(self.train_loader):
                real_img = Variable(img)
                if self.use_gpu:
                    real_img = real_img.cuda()

                # train the discriminator
                if (ii + 1) % self.opt.d_every == 0:
                    self.optimizer_d.zero_grad()
                    # push the discriminator to label real images as 1
                    output = self.netd(real_img)
                    error_d_real = self.criterion(output, self.true_labels)
                    error_d_real.backward()
                    # push the discriminator to label fake images as 0
                    self.noises.data.copy_(
                        torch.randn(self.opt.batch_size, self.opt.nz, 1, 1))
                    fake_img = self.netg(self.noises).detach()  # generate fakes from noise
                    fake_output = self.netd(fake_img)
                    error_d_fake = self.criterion(fake_output, self.fake_labels)
                    error_d_fake.backward()
                    self.optimizer_d.step()
                    loss_netd += (error_d_real.item() + error_d_fake.item())

                # train the generator
                if (ii + 1) % self.opt.g_every == 0:
                    self.optimizer_g.zero_grad()
                    self.noises.data.copy_(
                        torch.randn(self.opt.batch_size, self.opt.nz, 1, 1))
                    fake_img = self.netg(self.noises)
                    fake_output = self.netd(fake_img)
                    # push the discriminator to label the fakes as 1
                    error_g = self.criterion(fake_output, self.true_labels)
                    error_g.backward()
                    self.optimizer_g.step()
                    # accumulate a float, not a graph-holding tensor
                    loss_netg += error_g.item()

            loss_netd /= (len(self.train_loader) * 2)
            loss_netg /= len(self.train_loader)
            print('[Train] Epoch: {} \tNetD Loss: {:.6f} \tNetG Loss: {:.6f}'.
                  format(epoch_i, loss_netd, loss_netg))

            if save_best is True:
                if (loss_netg + loss_netd) / 2 < self.best_loss:
                    self.best_loss = (loss_netg + loss_netd) / 2
                    self.save(self.netd, self.opt.best_netd_path)  # save the best models
                    self.save(self.netg, self.opt.best_netg_path)
                    print('[save best] ...')

            # self.vis()
            if (epoch_i + 1) % 5 == 0:
                self.image_gan()

        self.save(self.netd, self.opt.netd_path)  # save the final models
        self.save(self.netg, self.opt.netg_path)

    def vis(self):
        # assumes a live visdom connection, e.g. `visdom = visdom.Visdom()`
        fix_fake_imgs = self.netg(self.fix_noises)
        visdom.images(fix_fake_imgs.data.cpu().numpy()[:64] * 0.5 + 0.5,
                      win='fixfake')

    def image_gan(self):
        noises = torch.randn(self.opt.gen_search_num, self.opt.nz, 1,
                             1).normal_(self.opt.gen_mean, self.opt.gen_std)
        with torch.no_grad():
            noises = Variable(noises)
            if self.use_gpu:
                noises = noises.cuda()
            fake_img = self.netg(noises)
            scores = self.netd(fake_img).data
        # keep the gen_num fakes the discriminator scores highest
        indices = scores.topk(self.opt.gen_num)[1]
        result = [fake_img.data[ii] for ii in indices]
        torchvision.utils.save_image(torch.stack(result),
                                     self.opt.gen_img,
                                     normalize=True,
                                     range=(-1, 1))

    def test(self):
        # NOTE: expects classifier-style attributes (self.model, self.loss,
        # self.test_loader) that are not set in __init__; kept as in the source.
        test_loss = 0.0
        correct = 0
        time_start = time.time()
        for data, target in self.test_loader:
            data, target = Variable(data), Variable(target)
            if self.use_gpu:
                data = data.cuda()
                target = target.cuda()
            output = self.model(data)
            loss = self.loss(output, target)  # sum up batch loss
            test_loss += loss.item()
            predict = torch.argmax(output, 1)
            correct += (predict == target).sum().data
        time_end = time.time()
        time_avg = float(time_end - time_start) / float(
            len(self.test_loader.dataset))
        test_loss /= len(self.test_loader)
        acc = float(correct) / float(len(self.test_loader.dataset))
        print('[Test] set: Test loss: {:.6f}\t Acc: {:.6f}\t time: {:.6f} \n'.
              format(test_loss, acc, time_avg))
        return acc

    def load_netd(self, name):
        print('[Load model netd] %s ...' % name)
        self.netd.load_state_dict(torch.load(name))

    def load_netg(self, name):
        print('[Load model netg] %s ...' % name)
        self.netg.load_state_dict(torch.load(name))

    def save(self, model, name):
        print('[Save model] %s ...' % name)
        torch.save(model.state_dict(), name)
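# The trainer above expects an `opt` object carrying all hyperparameters. A
# hedged sketch of such a config: the attribute names are taken from the
# usages above, but every default value here is an illustrative assumption.
class Config:
    data_path = './data'                 # ImageFolder root
    netd_path = './ckpt/netd.pth'        # discriminator checkpoint
    netg_path = './ckpt/netg.pth'        # generator checkpoint
    best_netd_path = './ckpt/netd_best.pth'
    best_netg_path = './ckpt/netg_best.pth'
    img_size = 96
    batch_size = 64
    num_workers = 4
    nz = 100                             # latent dimension
    lr1, lr2 = 2e-4, 2e-4                # G and D learning rates
    beta1 = 0.5
    max_epoch = 200
    d_every, g_every = 1, 5              # train D every batch, G every 5th
    use_gpu = True
    gen_search_num, gen_num = 512, 64    # sample pool and top-k for image_gan
    gen_mean, gen_std = 0, 1
    gen_img = './result.png'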
class Train(object):
    """
    Main GAN trainer. Responsible for training the GAN and pre-training the
    generator autoencoder.
    """

    def __init__(self, config):
        """
        Construct a new GAN trainer.

        :param Config config: The parsed network configuration.
        """
        self.config = config
        LOG.info("CUDA version: {0}".format(version.cuda))
        LOG.info("Creating data loader from path {0}".format(config.FILENAME))
        self.data_loader = Data(
            config.FILENAME,
            config.BATCH_SIZE,
            polarisations=config.POLARISATIONS,  # Polarisations to use
            frequencies=config.FREQUENCIES,      # Frequencies to use
            max_inputs=config.MAX_SAMPLES,       # Max inputs per polarisation and frequency
            normalise=config.NORMALISE)          # Normalise inputs
        shape = self.data_loader.get_input_shape()
        width = shape[1]
        LOG.info("Creating models with input shape {0}".format(shape))
        self._autoencoder = Autoencoder(width)
        self._discriminator = Discriminator(width)
        # TODO: Get correct input and output widths for generator
        self._generator = Generator(width, width)
        if config.USE_CUDA:
            LOG.info("Using CUDA")
            self.autoencoder = self._autoencoder.cuda()
            self.discriminator = self._discriminator.cuda()
            self.generator = self._generator.cuda()
        else:
            LOG.info("Using CPU")
            self.autoencoder = self._autoencoder
            self.discriminator = self._discriminator
            self.generator = self._generator

    def check_requeue(self, epochs_complete):
        """
        Check and re-queue the training script if it has completed the desired
        number of training epochs per session.

        :param int epochs_complete: Number of epochs completed
        :return: True if the script has been requeued, False if not
        :rtype bool
        """
        if self.config.REQUEUE_EPOCHS > 0:
            if epochs_complete >= self.config.REQUEUE_EPOCHS:
                # We've completed enough epochs for this instance.
                # We need to kill it and requeue.
                LOG.info("REQUEUE_EPOCHS of {0} met, calling REQUEUE_SCRIPT".format(
                    self.config.REQUEUE_EPOCHS))
                subprocess.call(self.config.REQUEUE_SCRIPT,
                                shell=True,
                                cwd=os.path.dirname(self.config.REQUEUE_SCRIPT))
                return True  # Requeue performed
        return False  # No requeue needed

    def load_state(self, checkpoint, module, optimiser=None):
        """
        Load the provided checkpoint into the provided module and optimiser.
        This function checks whether the load threw an exception, and logs it
        to the user.

        :param Checkpoint checkpoint: The checkpoint to load
        :param module: The pytorch module to load the checkpoint into.
        :param optimiser: The pytorch optimiser to load the checkpoint into.
        :return: None if the load failed, int number of epochs in the
                 checkpoint if the load succeeded
        """
        try:
            module.load_state_dict(checkpoint.module_state)
            if optimiser is not None:
                optimiser.load_state_dict(checkpoint.optimiser_state)
            return checkpoint.epoch
        except RuntimeError:
            LOG.exception(
                "Error loading module state. This is most likely an input size "
                "mismatch. Please delete the old module saved state, or change "
                "the input size")
            return None

    def close(self):
        """
        Close the data loader used by the trainer.
        """
        self.data_loader.close()

    def generate_labels(self, num_samples, pattern):
        """
        Generate labels for the discriminator.

        :param int num_samples: Number of input samples to generate labels for.
        :param list pattern: Pattern to generate. Should be either [1, 0] or [0, 1]
        :return: New labels for the discriminator
        """
        var = torch.FloatTensor([pattern] * num_samples)
        return var.cuda() if self.config.USE_CUDA else var

    def _train_autoencoder(self):
        """
        Main training loop for the autoencoder.

        This function will return False if:
        - Loading the autoencoder succeeded, but the NN model did not load the
          state dicts correctly.
        - The script needs to be re-queued because the NN has been trained for
          REQUEUE_EPOCHS.

        :return: True if training was completed, False if training needs to continue.
        :rtype bool
        """
        criterion = nn.SmoothL1Loss()
        # NOTE: the source optimised self.generator.parameters() here, which
        # would leave the autoencoder untrained; optimising the autoencoder's
        # own parameters is almost certainly what was intended.
        optimiser = optim.Adam(self.autoencoder.parameters(),
                               lr=0.00003,
                               betas=(0.5, 0.999))

        checkpoint = Checkpoint("autoencoder")
        epoch = 0
        if checkpoint.load():
            epoch = self.load_state(checkpoint, self.autoencoder, optimiser)
            if epoch is None:
                epoch = 0  # state failed to load; train from the start
            elif epoch >= self.config.MAX_AUTOENCODER_EPOCHS:
                LOG.info("Autoencoder already trained")
                return True
            else:
                LOG.info("Autoencoder training beginning from epoch {0}".format(epoch))
        else:
            LOG.info('Autoencoder checkpoint not found. Training from start')

        # Train autoencoder
        self._autoencoder.set_mode(Autoencoder.Mode.AUTOENCODER)
        vis_path = os.path.join(
            os.path.splitext(self.config.FILENAME)[0], "autoencoder",
            str(datetime.now()))
        with Visualiser(vis_path) as vis:
            epochs_complete = 0
            while epoch < self.config.MAX_AUTOENCODER_EPOCHS:
                if self.check_requeue(epochs_complete):
                    return False  # Requeue needed and training not complete

                for step, (data, _, _) in enumerate(self.data_loader):
                    if self.config.USE_CUDA:
                        data = data.cuda()

                    if self.config.ADD_DROPOUT:
                        # Drop out parts of the input, but compute loss on the full input.
                        out = self.autoencoder(nn.functional.dropout(data, 0.5))
                    else:
                        out = self.autoencoder(data)

                    loss = criterion(out.cpu(), data.cpu())
                    self.autoencoder.zero_grad()
                    loss.backward()
                    optimiser.step()

                    vis.step_autoencoder(loss.item())

                # Report data and save checkpoint
                fmt = "Epoch [{0}/{1}], Step[{2}/{3}], loss: {4:.4f}"
                LOG.info(fmt.format(epoch + 1, self.config.MAX_AUTOENCODER_EPOCHS,
                                    step, len(self.data_loader), loss))

                epoch += 1
                epochs_complete += 1

                checkpoint.set(self.autoencoder.state_dict(),
                               optimiser.state_dict(), epoch).save()

                LOG.info("Plotting autoencoder progress")
                vis.plot_training(epoch)
                data, _, _ = iter(self.data_loader).__next__()
                vis.test_autoencoder(epoch, self.autoencoder, data.cuda())

        LOG.info("Autoencoder training complete")
        return True  # Training complete

    def _train_gan(self):
        """
        TODO: Add in autoencoder to perform dimensionality reduction on data
        TODO: Not working yet - trying to work out good autoencoder model first
        :return:
        """
        criterion = nn.BCELoss()

        discriminator_optimiser = optim.Adam(self.discriminator.parameters(),
                                             lr=0.003,
                                             betas=(0.5, 0.999))
        discriminator_scheduler = optim.lr_scheduler.LambdaLR(
            discriminator_optimiser, lambda epoch: 0.97**epoch)
        discriminator_checkpoint = Checkpoint("discriminator")
        discriminator_epoch = 0
        if discriminator_checkpoint.load():
            discriminator_epoch = self.load_state(discriminator_checkpoint,
                                                  self.discriminator,
                                                  discriminator_optimiser)
        else:
            LOG.info('Discriminator checkpoint not found')

        generator_optimiser = optim.Adam(self.generator.parameters(),
                                         lr=0.003,
                                         betas=(0.5, 0.999))
        generator_scheduler = optim.lr_scheduler.LambdaLR(
            generator_optimiser, lambda epoch: 0.97**epoch)
        generator_checkpoint = Checkpoint("generator")
        generator_epoch = 0
        if generator_checkpoint.load():
            generator_epoch = self.load_state(generator_checkpoint,
                                              self.generator,
                                              generator_optimiser)
        else:
            LOG.info('Generator checkpoint not found')

        if discriminator_epoch is None or generator_epoch is None:
            epoch = 0
            LOG.info("Discriminator or generator failed to load, training from start")
        else:
            epoch = min(generator_epoch, discriminator_epoch)
            LOG.info("Generator loaded at epoch {0}".format(generator_epoch))
            LOG.info("Discriminator loaded at epoch {0}".format(discriminator_epoch))
            LOG.info("Training from lowest epoch {0}".format(epoch))

        vis_path = os.path.join(
            os.path.splitext(self.config.FILENAME)[0], "gan", str(datetime.now()))
        with Visualiser(vis_path) as vis:
            real_labels = None  # all 1s
            fake_labels = None  # all 0s
            epochs_complete = 0
            while epoch < self.config.MAX_EPOCHS:
                if self.check_requeue(epochs_complete):
                    return  # Requeue needed and training not complete

                for step, (data, noise1, noise2) in enumerate(self.data_loader):
                    batch_size = data.size(0)
                    if real_labels is None or real_labels.size(0) != batch_size:
                        real_labels = self.generate_labels(batch_size, [1.0])
                    if fake_labels is None or fake_labels.size(0) != batch_size:
                        fake_labels = self.generate_labels(batch_size, [0.0])

                    if self.config.USE_CUDA:
                        data = data.cuda()
                        noise1 = noise1.cuda()
                        noise2 = noise2.cuda()

                    # ============= Train the discriminator =============
                    # Pass real noise through first - ideally the discriminator
                    # will return 1 [1, 0]
                    d_output_real = self.discriminator(data)
                    # Pass generated noise through - ideally the discriminator
                    # will return 0 [0, 1]
                    d_output_fake1 = self.discriminator(self.generator(noise1))

                    # Determine the loss of the discriminator by adding up the
                    # real and fake loss, and backpropagate
                    d_loss_real = criterion(d_output_real, real_labels)   # how good D is on real input
                    d_loss_fake = criterion(d_output_fake1, fake_labels)  # how good D is on fake input
                    d_loss = d_loss_real + d_loss_fake
                    self.discriminator.zero_grad()
                    d_loss.backward()
                    discriminator_optimiser.step()

                    # =============== Train the generator ===============
                    # Pass in fake noise to the generator and get it to generate
                    # "real" noise; judge how good this noise is with the discriminator.
                    d_output_fake2 = self.discriminator(self.generator(noise2))
                    # Determine the loss of the generator using the discriminator
                    # and backpropagate
                    g_loss = criterion(d_output_fake2, real_labels)
                    self.discriminator.zero_grad()
                    self.generator.zero_grad()
                    g_loss.backward()
                    generator_optimiser.step()

                    vis.step(d_loss_real.item(), d_loss_fake.item(), g_loss.item())

                    # Report data and save checkpoint
                    fmt = ("Epoch [{0}/{1}], Step[{2}/{3}], d_loss_real: {4:.4f}, "
                           "d_loss_fake: {5:.4f}, g_loss: {6:.4f}")
                    LOG.info(fmt.format(epoch + 1, self.config.MAX_EPOCHS,
                                        step + 1, len(self.data_loader),
                                        d_loss_real, d_loss_fake, g_loss))

                epoch += 1
                epochs_complete += 1

                discriminator_checkpoint.set(self.discriminator.state_dict(),
                                             discriminator_optimiser.state_dict(),
                                             epoch).save()
                generator_checkpoint.set(self.generator.state_dict(),
                                         generator_optimiser.state_dict(),
                                         epoch).save()

                vis.plot_training(epoch)
                data, noise1, _ = iter(self.data_loader).__next__()
                if self.config.USE_CUDA:
                    data = data.cuda()
                    noise1 = noise1.cuda()
                vis.test(epoch, self.data_loader.get_input_size_first(),
                         self.discriminator, self.generator, noise1, data)

                generator_scheduler.step(epoch)
                discriminator_scheduler.step(epoch)
                LOG.info("Learning rates: d {0} g {1}".format(
                    discriminator_optimiser.param_groups[0]["lr"],
                    generator_optimiser.param_groups[0]["lr"]))

        LOG.info("GAN Training complete")

    def __call__(self):
        """
        Main training loop for the GAN. The training process is interruptable;
        the model and optimiser states are saved to disk each epoch, and the
        latest states are restored when the trainer is resumed.

        If the script is not able to load the generator's saved state, it will
        attempt to load the pre-trained generator autoencoder from the
        generator_decoder_complete checkpoint (if it exists). If this also
        fails, the generator is pre-trained as an autoencoder. This training is
        also interruptable, and will produce the generator_decoder_complete
        checkpoint on completion.

        On successfully restoring generator and discriminator state, the
        trainer will proceed from the earliest restored epoch. For example, if
        the generator is restored from epoch 7 and the discriminator is
        restored from epoch 5, training will proceed from epoch 5.

        Visualisation plots are produced each epoch and stored in
        /path_to_input_file_directory/{gan/generator_auto_encoder}/{timestamp}/{epoch}.
        Each time the trainer is run, it creates a new timestamp directory
        using the current time.
        """
        # Load the autoencoder, and train it if needed.
        if not self._train_autoencoder():
            # Autoencoder training incomplete
            return
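# The Checkpoint class used throughout Train is not shown. A minimal sketch of
# the interface the call sites imply (load() / set(...).save() / module_state /
# optimiser_state / epoch) -- an illustrative assumption, not the project's
# actual implementation:
import os
import torch

class Checkpoint(object):
    def __init__(self, name, directory="checkpoints"):
        self.path = os.path.join(directory, name + ".pt")
        self.module_state = None
        self.optimiser_state = None
        self.epoch = 0

    def load(self):
        # Returns True and populates the state fields if a saved file exists.
        if not os.path.isfile(self.path):
            return False
        state = torch.load(self.path)
        self.module_state = state["module"]
        self.optimiser_state = state["optimiser"]
        self.epoch = state["epoch"]
        return True

    def set(self, module_state, optimiser_state, epoch):
        self.module_state = module_state
        self.optimiser_state = optimiser_state
        self.epoch = epoch
        return self  # enables checkpoint.set(...).save()

    def save(self):
        os.makedirs(os.path.dirname(self.path), exist_ok=True)
        torch.save({"module": self.module_state,
                    "optimiser": self.optimiser_state,
                    "epoch": self.epoch}, self.path)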
    def decode(self, z):
        h3 = self.decoder(z)
        return h3

    def forward(self, x):
        mu, logvar = self.encode(x)
        z = self.reparameterize(mu, logvar)
        # z = z.view(args.batch_size, z_dim, 1, 1)
        z = z.view(args.batch_size, z_dim)
        return self.decode(z), mu, logvar


model = VAE()
discrim = Discriminator()
if use_cuda:
    model.cuda()
    discrim.cuda()

optimizer = optim.Adam(model.parameters(), lr=args.lr)
discrim_optimizer = optim.Adam(discrim.parameters(), lr=args.discrim_lr)

# Reconstruction + KL divergence losses summed over all elements and batch
A, B, C = 224, 224, 3
image_size = A * B * C


def loss_function(recon_x, x, mu, logvar):
    # BCE = F.binary_cross_entropy(recon_x.view(-1, image_size), x.view(-1, image_size), size_average=False)
    # BCE = F.binary_cross_entropy(recon_x, x)
    # Define the GAN loss here: score the reconstruction with the
    # discriminator against an all-ones ("real") label.
    label = Variable(torch.ones(args.batch_size).type(Tensor))  # TODO: see if it needs to be a Variable
    discrim_output = discrim(recon_x)
    BCE = F.binary_cross_entropy(discrim_output, label)
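# loss_function above is cut off before its KL term and return. A hedged
# sketch of the conventional ending (the analytic KL divergence between
# N(mu, sigma^2) and N(0, I), per Kingma & Welling, 2014); this is an
# assumption about the missing code, not the original:
import torch
import torch.nn.functional as F

def vae_gan_loss_sketch(discrim_output, real_label, mu, logvar):
    # adversarial reconstruction term: D should call the reconstruction real
    BCE = F.binary_cross_entropy(discrim_output, real_label)
    # KL(q(z|x) || N(0, I)) = -0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2)
    KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    return BCE + KLD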
data_itr_src = get_data_iter("MNIST", train=True)
data_itr_tgt = get_data_iter("USPS", train=True)

pos_labels = Variable(torch.Tensor([1]))
neg_labels = Variable(torch.Tensor([-1]))

g_step = 0
g_loss_durations = []
d_loss_durations = []
c_loss_durations = []

# move modules and variables onto the GPU
if use_cuda:
    generator.cuda()
    generator_larger.cuda()
    critic.cuda()
    classifier.cuda()
    pos_labels = pos_labels.cuda()
    neg_labels = neg_labels.cuda()

for epoch in range(params.num_epochs):
    # train the critic
    # enable gradients for the critic's parameters
    for p in critic.parameters():
        p.requires_grad = True

    # set the number of critic training steps:
    # this helps to start with the critic at optimum
    # even in the first iterations.
    if g_step < 25 or g_step % 500 == 0:
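# The snippet cuts off at the critic-steps schedule. In the WGAN recipe this
# pattern follows (Arjovsky et al., 2017), the branch picks many critic
# iterations early on and a handful later, with weight clipping after each
# critic step. A hedged, self-contained sketch; all names here are
# illustrative assumptions:
def critic_steps_for(g_step):
    # 100 critic steps while warming up (or periodically), otherwise 5
    return 100 if g_step < 25 or g_step % 500 == 0 else 5

def clip_critic_weights(critic, c=0.01):
    # WGAN weight clipping keeps the critic approximately 1-Lipschitz
    for p in critic.parameters():
        p.data.clamp_(-c, c)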
class Solver(object):
    def __init__(self, data_loader, config):
        self.data_loader = data_loader
        self.noise_n = config.noise_n
        self.G_last_act = last_act(config.G_last_act)
        self.D_out_n = config.D_out_n
        self.D_last_act = last_act(config.D_last_act)
        self.G_lr = config.G_lr
        self.D_lr = config.D_lr
        self.beta1 = config.beta1
        self.beta2 = config.beta2
        self.epoch = config.epoch
        self.batch_size = config.batch_size
        self.D_train_step = config.D_train_step
        self.save_image_step = config.save_image_step
        self.log_step = config.log_step
        self.model_save_step = config.model_save_step
        self.clip_value = config.clip_value
        self.lambda_gp = config.lambda_gp
        self.model_save_path = config.model_save_path
        self.log_save_path = config.log_save_path
        self.image_save_path = config.image_save_path
        self.use_tensorboard = config.use_tensorboard
        self.pretrained_model = config.pretrained_model

        self.build_model()
        if self.use_tensorboard is not None:
            self.build_tensorboard()
        if self.pretrained_model is not None:
            if len(self.pretrained_model) != 2:
                # Raising a bare string is invalid in Python 3; raise an exception.
                raise ValueError("must have both G and D pretrained parameters, "
                                 "and G is first, D is second")
            self.load_pretrained_model()

    def build_model(self):
        self.G = Generator(self.noise_n, self.G_last_act)
        self.D = Discriminator(self.D_out_n, self.D_last_act)
        self.G_optimizer = torch.optim.Adam(self.G.parameters(), self.G_lr,
                                            [self.beta1, self.beta2])
        self.D_optimizer = torch.optim.Adam(self.D.parameters(), self.D_lr,
                                            [self.beta1, self.beta2])
        if torch.cuda.is_available():
            self.G.cuda()
            self.D.cuda()

    def build_tensorboard(self):
        from commons.logger import Logger
        self.logger = Logger(self.log_save_path)

    def load_pretrained_model(self):
        self.G.load_state_dict(torch.load(self.pretrained_model[0]))
        self.D.load_state_dict(torch.load(self.pretrained_model[1]))

    def denorm(self, x):
        out = (x + 1) / 2
        return out.clamp_(0, 1)

    def reset_grad(self):
        self.G_optimizer.zero_grad()
        self.D_optimizer.zero_grad()

    def to_var(self, x, volatile=False):
        if torch.cuda.is_available():
            x = x.cuda()
        return Variable(x, volatile=volatile)

    def train(self):
        print(len(self.data_loader))
        for e in range(self.epoch):
            for i, batch_images in enumerate(self.data_loader):
                batch_size = batch_images.size(0)
                real_x = self.to_var(batch_images)
                noise_x = self.to_var(
                    torch.FloatTensor(noise_vector(batch_size, self.noise_n)))

                # Train D: Wasserstein critic loss
                fake_x = self.G(noise_x)
                real_out = self.D(real_x)
                fake_out = self.D(fake_x.detach())
                D_real = -torch.mean(real_out)
                D_fake = torch.mean(fake_out)
                D_loss = D_real + D_fake
                self.reset_grad()
                D_loss.backward()
                self.D_optimizer.step()

                # Log
                loss = {}
                loss['D/loss_real'] = D_real.data[0]
                loss['D/loss_fake'] = D_fake.data[0]
                loss['D/loss'] = D_loss.data[0]

                # Choose one of the two Lipschitz constraints below.
                # Clip weights of D (plain WGAN):
                # for p in self.D.parameters():
                #     p.data.clamp_(-self.clip_value, self.clip_value)

                # Gradient penalty (WGAN-GP):
                alpha = torch.rand(real_x.size(0), 1, 1,
                                   1).cuda().expand_as(real_x)
                interpolated = Variable(alpha * real_x.data +
                                        (1 - alpha) * fake_x.data,
                                        requires_grad=True)
                gp_out = self.D(interpolated)
                grad = torch.autograd.grad(outputs=gp_out,
                                           inputs=interpolated,
                                           grad_outputs=torch.ones(
                                               gp_out.size()).cuda(),
                                           retain_graph=True,
                                           create_graph=True,
                                           only_inputs=True)[0]
                grad = grad.view(grad.size(0), -1)
                grad_l2norm = torch.sqrt(torch.sum(grad**2, dim=1))
                d_loss_gp = torch.mean((grad_l2norm - 1)**2)

                # Backward + optimize (a second D step, for the penalty alone)
                d_loss = self.lambda_gp * d_loss_gp
                self.reset_grad()
                d_loss.backward()
                self.D_optimizer.step()

                # Train G
                if (i + 1) % self.D_train_step == 0:
                    fake_out = self.D(self.G(noise_x))
                    G_loss = -torch.mean(fake_out)
                    self.reset_grad()
                    G_loss.backward()
                    self.G_optimizer.step()
                    loss['G/loss'] = G_loss.data[0]

                # Print log
                if (i + 1) % self.log_step == 0:
                    log = "Epoch: {}/{}, Iter: {}/{}".format(
                        e + 1, self.epoch, i + 1, len(self.data_loader))
                    for tag, value in loss.items():
                        log += ", {}: {:.4f}".format(tag, value)
                    print(log)
                    if self.use_tensorboard:
                        for tag, value in loss.items():
                            self.logger.scalar_summary(
                                tag, value, e * len(self.data_loader) + i + 1)

            # Save images
            if (e + 1) % self.save_image_step == 0:
                noise_x = self.to_var(
                    torch.FloatTensor(noise_vector(16, self.noise_n)))
                fake_image = self.G(noise_x)
                save_image(
                    self.denorm(fake_image.data),
                    os.path.join(self.image_save_path,
                                 "{}_fake.png".format(e + 1)))

            # Save models
            if (e + 1) % self.model_save_step == 0:
                torch.save(
                    self.G.state_dict(),
                    os.path.join(self.model_save_path, "{}_G.pth".format(e + 1)))
                torch.save(
                    self.D.state_dict(),
                    os.path.join(self.model_save_path, "{}_D.pth".format(e + 1)))
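# Design note: the WGAN-GP solver above takes two optimiser steps per batch
# (one for the critic loss, one for the scaled penalty). The WGAN-GP paper
# (Gulrajani et al., 2017) folds both into a single objective and a single
# step; a hedged sketch of that variant as a drop-in for the D update above,
# reusing the names from the class:
#
#     D_loss = -torch.mean(real_out) + torch.mean(fake_out) \
#              + self.lambda_gp * d_loss_gp
#     self.reset_grad()
#     D_loss.backward()
#     self.D_optimizer.step()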
                              transform=T.ToTensor(),
                              remove_alpha=True)
test_dataset = TestFromFolder(os.path.join(all_datasets, 'stage1_test/loc.csv'),
                              transform=T.ToTensor(),
                              remove_alpha=True)

"""
-----------------
----- Model -----
-----------------
"""
generator = UNet(3, 1)
discriminator = Discriminator(4, 1)

generator.cuda()
discriminator.cuda()

# lr = 0.001 seems to work WITHOUT PRETRAINING
g_optim = optim.Adam(generator.parameters(), lr=0.001)
d_optim = optim.Adam(discriminator.parameters(), lr=0.001)
# g_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(g_optim, factor=0.1, verbose=True, patience=5)
# d_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(d_optim, factor=0.1, verbose=True, patience=5)

gan = GAN(
    g=generator,
    d=discriminator,
    g_optim=g_optim,
    d_optim=d_optim,
    g_loss=nn.MSELoss().cuda(),
    d_loss=nn.MSELoss().cuda(),
    # g_scheduler=g_scheduler, d_scheduler=d_scheduler
)
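# The GAN wrapper class is defined elsewhere in this project. A minimal sketch
# of the interface the constructor call above implies (an assumption; the real
# class presumably also owns the training loop and the optional schedulers):
class GAN(object):
    def __init__(self, g, d, g_optim, d_optim, g_loss, d_loss,
                 g_scheduler=None, d_scheduler=None):
        self.g, self.d = g, d
        self.g_optim, self.d_optim = g_optim, d_optim
        self.g_loss, self.d_loss = g_loss, d_loss
        self.g_scheduler, self.d_scheduler = g_scheduler, d_scheduler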
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--epoch', type=int, default=0, help='starting epoch')
    parser.add_argument('--n_epochs', type=int, default=400,
                        help='number of epochs of training')
    parser.add_argument('--batchSize', type=int, default=10,
                        help='size of the batches')
    parser.add_argument('--dataroot', type=str, default='datasets/genderchange/',
                        help='root directory of the dataset')
    parser.add_argument('--lr', type=float, default=0.0002,
                        help='initial learning rate')
    parser.add_argument('--decay_epoch', type=int, default=100,
                        help='epoch to start linearly decaying the learning rate to 0')
    parser.add_argument('--size', type=int, default=256,
                        help='size of the data crop (squared assumed)')
    parser.add_argument('--input_nc', type=int, default=3,
                        help='number of channels of input data')
    parser.add_argument('--output_nc', type=int, default=3,
                        help='number of channels of output data')
    parser.add_argument('--cuda', action='store_true', help='use GPU computation')
    parser.add_argument('--n_cpu', type=int, default=8,
                        help='number of cpu threads to use during batch generation')
    opt = parser.parse_args()
    print(opt)

    if torch.cuda.is_available() and not opt.cuda:
        print("WARNING: You have a CUDA device, so you should probably run with --cuda")

    ###### Definition of variables ######
    # Networks
    netG_A2B = Generator(opt.input_nc, opt.output_nc)
    netG_B2A = Generator(opt.output_nc, opt.input_nc)
    netD_A = Discriminator(opt.input_nc)
    netD_B = Discriminator(opt.output_nc)

    if opt.cuda:
        netG_A2B.cuda()
        netG_B2A.cuda()
        netD_A.cuda()
        netD_B.cuda()

    netG_A2B.apply(weights_init_normal)
    netG_B2A.apply(weights_init_normal)
    netD_A.apply(weights_init_normal)
    netD_B.apply(weights_init_normal)

    # Losses
    criterion_GAN = torch.nn.MSELoss()
    criterion_cycle = torch.nn.L1Loss()
    criterion_identity = torch.nn.L1Loss()

    # Optimizers & LR schedulers
    optimizer_G = torch.optim.Adam(itertools.chain(netG_A2B.parameters(),
                                                   netG_B2A.parameters()),
                                   lr=opt.lr, betas=(0.5, 0.999))
    optimizer_D_A = torch.optim.Adam(netD_A.parameters(), lr=opt.lr,
                                     betas=(0.5, 0.999))
    optimizer_D_B = torch.optim.Adam(netD_B.parameters(), lr=opt.lr,
                                     betas=(0.5, 0.999))

    lr_scheduler_G = torch.optim.lr_scheduler.LambdaLR(
        optimizer_G,
        lr_lambda=LambdaLR(opt.n_epochs, opt.epoch, opt.decay_epoch).step)
    lr_scheduler_D_A = torch.optim.lr_scheduler.LambdaLR(
        optimizer_D_A,
        lr_lambda=LambdaLR(opt.n_epochs, opt.epoch, opt.decay_epoch).step)
    lr_scheduler_D_B = torch.optim.lr_scheduler.LambdaLR(
        optimizer_D_B,
        lr_lambda=LambdaLR(opt.n_epochs, opt.epoch, opt.decay_epoch).step)

    # Inputs & targets memory allocation
    Tensor = torch.cuda.FloatTensor if opt.cuda else torch.Tensor
    input_A = Tensor(opt.batchSize, opt.input_nc, opt.size, opt.size)
    input_B = Tensor(opt.batchSize, opt.output_nc, opt.size, opt.size)
    target_real = Variable(Tensor(opt.batchSize).fill_(1.0), requires_grad=False)
    target_fake = Variable(Tensor(opt.batchSize).fill_(0.0), requires_grad=False)

    fake_A_buffer = ReplayBuffer()
    fake_B_buffer = ReplayBuffer()

    # Dataset loader
    transforms_ = [
        transforms.Resize(int(opt.size * 1.2), Image.BICUBIC),
        transforms.CenterCrop(opt.size),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ]
    dataloader = DataLoader(ImageDataset(opt.dataroot, transforms_=transforms_,
                                         unaligned=True),
                            batch_size=opt.batchSize, shuffle=True,
                            num_workers=opt.n_cpu, drop_last=True)

    # Plot loss and images in Tensorboard
    experiment_dir = 'logs/{}@{}'.format(
        opt.dataroot.split('/')[1],
        datetime.now().strftime("%d.%m.%Y-%H:%M:%S"))
    os.makedirs(experiment_dir, exist_ok=True)
    writer = SummaryWriter(os.path.join(experiment_dir, "tb"))
    metric_dict = defaultdict(list)
    n_iters_total = 0

    ###### Training ######
    for epoch in range(opt.epoch, opt.n_epochs):
        for i, batch in enumerate(dataloader):
            # Set model input
            real_A = Variable(input_A.copy_(batch['A']))
            real_B = Variable(input_B.copy_(batch['B']))

            ###### Generators A2B and B2A ######
            optimizer_G.zero_grad()

            # Identity loss
            # G_A2B(B) should equal B if real B is fed
            same_B = netG_A2B(real_B)
            loss_identity_B = criterion_identity(same_B, real_B) * 5.0  # [batchSize, 3, ImgSize, ImgSize]
            # G_B2A(A) should equal A if real A is fed
            same_A = netG_B2A(real_A)
            loss_identity_A = criterion_identity(same_A, real_A) * 5.0  # [batchSize, 3, ImgSize, ImgSize]

            # GAN loss
            fake_B = netG_A2B(real_A)
            pred_fake = netD_B(fake_B).view(-1)
            loss_GAN_A2B = criterion_GAN(pred_fake, target_real)  # [batchSize]

            fake_A = netG_B2A(real_B)
            pred_fake = netD_A(fake_A).view(-1)
            loss_GAN_B2A = criterion_GAN(pred_fake, target_real)  # [batchSize]

            # Cycle loss
            recovered_A = netG_B2A(fake_B)
            loss_cycle_ABA = criterion_cycle(recovered_A, real_A) * 10.0  # [batchSize, 3, ImgSize, ImgSize]

            recovered_B = netG_A2B(fake_A)
            loss_cycle_BAB = criterion_cycle(recovered_B, real_B) * 10.0  # [batchSize, 3, ImgSize, ImgSize]

            # Total loss
            loss_G = (loss_identity_A + loss_identity_B + loss_GAN_A2B +
                      loss_GAN_B2A + loss_cycle_ABA + loss_cycle_BAB)
            loss_G.backward()
            optimizer_G.step()

            ###### Discriminator A ######
            optimizer_D_A.zero_grad()

            # Real loss
            pred_real = netD_A(real_A).view(-1)
            loss_D_real = criterion_GAN(pred_real, target_real)  # [batchSize]

            # Fake loss
            fake_A = fake_A_buffer.push_and_pop(fake_A)
            pred_fake = netD_A(fake_A.detach()).view(-1)
            loss_D_fake = criterion_GAN(pred_fake, target_fake)  # [batchSize]

            # Total loss
            loss_D_A = (loss_D_real + loss_D_fake) * 0.5
            loss_D_A.backward()
            optimizer_D_A.step()

            ###### Discriminator B ######
            optimizer_D_B.zero_grad()

            # Real loss
            pred_real = netD_B(real_B).view(-1)
            loss_D_real = criterion_GAN(pred_real, target_real)  # [batchSize]

            # Fake loss
            fake_B = fake_B_buffer.push_and_pop(fake_B)
            pred_fake = netD_B(fake_B.detach()).view(-1)
            loss_D_fake = criterion_GAN(pred_fake, target_fake)  # [batchSize]

            # Total loss
            loss_D_B = (loss_D_real + loss_D_fake) * 0.5
            loss_D_B.backward()
            optimizer_D_B.step()

            metric_dict['loss_G'].append(loss_G.item())
            metric_dict['loss_G_identity'].append(loss_identity_A.item() +
                                                  loss_identity_B.item())
            metric_dict['loss_G_GAN'].append(loss_GAN_A2B.item() +
                                             loss_GAN_B2A.item())
            metric_dict['loss_G_cycle'].append(loss_cycle_ABA.item() +
                                               loss_cycle_BAB.item())
            metric_dict['loss_D'].append(loss_D_A.item() + loss_D_B.item())

            for title, value in metric_dict.items():
                writer.add_scalar('train/{}'.format(title), value[-1],
                                  n_iters_total)
            n_iters_total += 1

        print("""
        -----------------------------------------------------------
        Epoch : {} Finished
        Loss_G : {}
        Loss_G_identity : {}
        Loss_G_GAN : {}
        Loss_G_cycle : {}
        Loss_D : {}
        -----------------------------------------------------------
        """.format(epoch, loss_G, loss_identity_A + loss_identity_B,
                   loss_GAN_A2B + loss_GAN_B2A,
                   loss_cycle_ABA + loss_cycle_BAB, loss_D_A + loss_D_B))

        # Update learning rates
        lr_scheduler_G.step()
        lr_scheduler_D_A.step()
        lr_scheduler_D_B.step()

        # Save model checkpoints
        if loss_G.item() < 2.5:
            os.makedirs(os.path.join(experiment_dir, str(epoch)), exist_ok=True)
            torch.save(netG_A2B.state_dict(),
                       '{}/{}/netG_A2B.pth'.format(experiment_dir, epoch))
            torch.save(netG_B2A.state_dict(),
                       '{}/{}/netG_B2A.pth'.format(experiment_dir, epoch))
            torch.save(netD_A.state_dict(),
                       '{}/{}/netD_A.pth'.format(experiment_dir, epoch))
            torch.save(netD_B.state_dict(),
                       '{}/{}/netD_B.pth'.format(experiment_dir, epoch))
        elif epoch > 100 and epoch % 40 == 0:
            os.makedirs(os.path.join(experiment_dir, str(epoch)), exist_ok=True)
            torch.save(netG_A2B.state_dict(),
                       '{}/{}/netG_A2B.pth'.format(experiment_dir, epoch))
            torch.save(netG_B2A.state_dict(),
                       '{}/{}/netG_B2A.pth'.format(experiment_dir, epoch))
            torch.save(netD_A.state_dict(),
                       '{}/{}/netD_A.pth'.format(experiment_dir, epoch))
            torch.save(netD_B.state_dict(),
                       '{}/{}/netD_B.pth'.format(experiment_dir, epoch))

        for title, value in metric_dict.items():
            writer.add_scalar("train/{}_epoch".format(title), np.mean(value),
                              epoch)
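# ReplayBuffer is imported from elsewhere in this repo. For reference, a
# minimal sketch of the usual CycleGAN image buffer (Zhu et al., 2017): each
# incoming fake either enters the pool or swaps with a random stored one, so
# the discriminators also see slightly stale fakes. An illustrative
# reimplementation, not necessarily the project's exact code:
import random
import torch

class ReplayBuffer(object):
    def __init__(self, max_size=50):
        self.max_size = max_size
        self.data = []

    def push_and_pop(self, batch):
        out = []
        for element in batch.data:
            element = element.unsqueeze(0)
            if len(self.data) < self.max_size:
                # pool not full yet: store and return the fresh fake
                self.data.append(element)
                out.append(element)
            elif random.uniform(0, 1) > 0.5:
                # return an old fake, replace it with the fresh one
                idx = random.randint(0, self.max_size - 1)
                out.append(self.data[idx].clone())
                self.data[idx] = element
            else:
                out.append(element)
        return torch.cat(out)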
def main(args):
    with open(args.params, "r") as f:
        params = json.load(f)

    generator = Generator(params["dim_latent"])
    discriminator = Discriminator()
    if args.device is not None:
        generator = generator.cuda(args.device)
        discriminator = discriminator.cuda(args.device)

    # dataloading
    train_dataset = datasets.MNIST(root=args.datadir,
                                   transform=transforms.ToTensor(),
                                   train=True)
    train_loader = DataLoader(train_dataset,
                              batch_size=params["batch_size"],
                              num_workers=4,
                              shuffle=True)

    # optimizer
    betas = (params["beta_1"], params["beta_2"])
    optimizer_G = optim.Adam(generator.parameters(),
                             lr=params["learning_rate"], betas=betas)
    optimizer_D = optim.Adam(discriminator.parameters(),
                             lr=params["learning_rate"], betas=betas)

    if not os.path.exists(args.modeldir):
        os.mkdir(args.modeldir)
    if not os.path.exists(args.logdir):
        os.mkdir(args.logdir)
    writer = SummaryWriter(args.logdir)

    steps_per_epoch = len(train_loader)
    msg = ["\t{0}: {1}".format(key, val) for key, val in params.items()]
    print("hyperparameters: \n" + "\n".join(msg))

    # main training loop
    for n in range(params["num_epochs"]):
        loader = iter(train_loader)
        print("epoch: {0}/{1}".format(n + 1, params["num_epochs"]))
        for i in tqdm.trange(steps_per_epoch):
            batch, _ = next(loader)
            if args.device is not None:
                batch = batch.cuda(args.device)
            loss_D = update_discriminator(batch, discriminator, generator,
                                          optimizer_D, params)
            loss_G = update_generator(discriminator, generator, optimizer_G,
                                      params, args.device)
            writer.add_scalar("loss_discriminator/train", loss_D,
                              i + n * steps_per_epoch)
            writer.add_scalar("loss_generator/train", loss_G,
                              i + n * steps_per_epoch)

        torch.save(generator.state_dict(),
                   args.o + ".generator." + str(n) + ".tmp")
        torch.save(discriminator.state_dict(),
                   args.o + ".discriminator." + str(n) + ".tmp")

        # eval
        with torch.no_grad():
            latent = torch.randn(args.num_fake_samples_eval, params["dim_latent"])
            if args.device is not None:
                # guard the .cuda() call so CPU-only runs also work
                latent = latent.cuda(args.device)
            imgs_fake = generator(latent)
            writer.add_images("generated fake images", imgs_fake, n)
            del latent, imgs_fake

    writer.close()
    torch.save(generator.state_dict(), args.o + ".generator.pt")
    torch.save(discriminator.state_dict(), args.o + ".discriminator.pt")
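# update_discriminator and update_generator are defined elsewhere in this
# project. A hedged sketch of what the call sites imply they do (standard
# non-saturating GAN updates; only the signatures come from above, the bodies
# are assumptions, including that the discriminator ends in a sigmoid and
# returns shape (N, 1)):
import torch
import torch.nn.functional as F

def update_discriminator(batch, discriminator, generator, optimizer_D, params):
    batch_size = batch.size(0)
    latent = torch.randn(batch_size, params["dim_latent"], device=batch.device)
    fake = generator(latent).detach()  # no generator gradients in the D step
    ones = torch.ones(batch_size, 1, device=batch.device)
    zeros = torch.zeros(batch_size, 1, device=batch.device)
    loss = (F.binary_cross_entropy(discriminator(batch), ones) +
            F.binary_cross_entropy(discriminator(fake), zeros))
    optimizer_D.zero_grad()
    loss.backward()
    optimizer_D.step()
    return loss.item()

def update_generator(discriminator, generator, optimizer_G, params, device):
    latent = torch.randn(params["batch_size"], params["dim_latent"])
    if device is not None:
        latent = latent.cuda(device)
    ones = torch.ones(params["batch_size"], 1, device=latent.device)
    # push D(G(z)) towards the "real" label
    loss = F.binary_cross_entropy(discriminator(generator(latent)), ones)
    optimizer_G.zero_grad()
    loss.backward()
    optimizer_G.step()
    return loss.item()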
    else:
        print('Loading target model from {}'.format(args.model_target_path))
        model_target = torch.load(args.model_target_path)

    if not (args.resume and os.path.isfile(args.model_source_path)):
        print('Creating new source model')
        model_source = VAE2()
    else:
        print('Loading source model from {}'.format(args.model_source_path))
        model_source = torch.load(args.model_source_path)

    discriminator_model = Discriminator(20, 20)

    if args.cuda:
        model_target.cuda()
        model_source.cuda()
        discriminator_model.cuda()

    # target_optimizer_encoder_params = [{'params': model_target.fc1.parameters()},
    #                                    {'params': model_target.fc2.parameters()}]
    target_optimizer = optim.Adam(model_target.parameters(), lr=args.lr)
    # target_optimizer_encoder = optim.Adam(target_optimizer_encoder_params, lr=args.lr)
    source_optimizer = optim.Adam(model_source.parameters(), lr=args.lr)
    d_optimizer = optim.Adam(discriminator_model.parameters(), lr=args.lr)
    criterion = nn.BCELoss()

    if args.source == 'mnist':
        tests = Tests(model_source, model_target, classifyMNIST, 'mnist',
                      'fashionMnist', args, graph)
    elif args.source == 'fashionMnist':
        tests = Tests(model_source, model_target, classifyMNIST,
                      'fashionMnist', 'mnist', args, graph)
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--cuda', default=False, action='store_true',
                        help='Enable CUDA')
    args = parser.parse_args()
    use_cuda = True if args.cuda and torch.cuda.is_available() else False

    netG = Generator(VOCAB_SIZE, G_EMB_SIZE, G_HIDDEN_SIZE, use_cuda)
    netD = Discriminator(VOCAB_SIZE, D_EMB_SIZE, D_NUM_CLASSES,
                         D_FILTER_SIZES, D_NUM_FILTERS, DROPOUT, use_cuda)
    oracle = Oracle(VOCAB_SIZE, G_EMB_SIZE, G_HIDDEN_SIZE, use_cuda)
    if use_cuda:
        netG, netD, oracle = netG.cuda(), netD.cuda(), oracle.cuda()
    netG.create_optim(G_LR)
    netD.create_optim(D_LR, D_L2_REG)

    # generating synthetic data
    print('Generating data...')
    generate_samples(oracle, BATCH_SIZE, GENERATED_NUM, REAL_FILE)

    # pretrain generator
    gen_set = GeneratorDataset(REAL_FILE)
    genloader = DataLoader(dataset=gen_set, batch_size=BATCH_SIZE, shuffle=True)
    print('\nPretraining generator...\n')
    for epoch in range(PRE_G_EPOCHS):
        loss = netG.pretrain(genloader)
        print('Epoch {} pretrain generator training loss: {}'.format(
            epoch + 1, loss))

        generate_samples(netG, BATCH_SIZE, GENERATED_NUM, EVAL_FILE)
        val_set = GeneratorDataset(EVAL_FILE)
        valloader = DataLoader(dataset=val_set, batch_size=BATCH_SIZE,
                               shuffle=True)
        loss = oracle.val(valloader)
        print('Epoch {} pretrain generator val loss: {}'.format(epoch + 1, loss))

    # pretrain discriminator
    print('\nPretraining discriminator...\n')
    for epoch in range(PRE_D_EPOCHS):
        generate_samples(netG, BATCH_SIZE, GENERATED_NUM, FAKE_FILE)
        dis_set = DiscriminatorDataset(REAL_FILE, FAKE_FILE)
        disloader = DataLoader(dataset=dis_set, batch_size=BATCH_SIZE,
                               shuffle=True)
        for k_step in range(K_STEPS):
            loss = netD.dtrain(disloader)
            print('Epoch {} K-step {} pretrain discriminator training loss: {}'.
                  format(epoch + 1, k_step + 1, loss))

    print('\nStarting adversarial training...')
    for epoch in range(TOTAL_EPOCHS):
        nets = [copy.deepcopy(netG) for _ in range(POPULATION_SIZE)]
        population = [(net, evaluate(net, netD)) for net in nets]
        for g_step in range(G_STEPS):
            t_start = time.time()
            population.sort(key=lambda p: p[1], reverse=True)
            rewards = [p[1] for p in population[:PARENTS_COUNT]]
            reward_mean = np.mean(rewards)
            reward_max = np.max(rewards)
            reward_std = np.std(rewards)
            print("Epoch %d step %d: reward_mean=%.2f, reward_max=%.2f, "
                  "reward_std=%.2f, time=%.2f s" %
                  (epoch, g_step, reward_mean, reward_max, reward_std,
                   time.time() - t_start))

            elite = population[0]
            # generate next population
            prev_population = population
            population = [elite]
            for _ in range(POPULATION_SIZE - 1):
                parent_idx = np.random.randint(0, PARENTS_COUNT)
                parent = prev_population[parent_idx][0]
                net = mutate_net(parent, use_cuda)
                # score the mutated child (the source scored `parent` here,
                # which would ignore the mutation entirely)
                fitness = evaluate(net, netD)
                population.append((net, fitness))
        netG = elite[0]

        for d_step in range(D_STEPS):
            # train discriminator
            generate_samples(netG, BATCH_SIZE, GENERATED_NUM, FAKE_FILE)
            dis_set = DiscriminatorDataset(REAL_FILE, FAKE_FILE)
            disloader = DataLoader(dataset=dis_set, batch_size=BATCH_SIZE,
                                   shuffle=True)
            for k_step in range(K_STEPS):
                loss = netD.dtrain(disloader)
                print('D_step {}, K-step {} adversarial discriminator training loss: {}'
                      .format(d_step + 1, k_step + 1, loss))

        generate_samples(netG, BATCH_SIZE, GENERATED_NUM, EVAL_FILE)
        val_set = GeneratorDataset(EVAL_FILE)
        valloader = DataLoader(dataset=val_set, batch_size=BATCH_SIZE,
                               shuffle=True)
        loss = oracle.val(valloader)
        print('Epoch {} adversarial generator val loss: {}'.format(
            epoch + 1, loss))
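# mutate_net is not defined in this snippet. A hedged sketch of the usual
# evolution-strategies mutation the loop above implies: deep-copy the parent
# and add Gaussian noise to every parameter. NOISE_STD is an assumed
# hyperparameter, not from the original source:
import copy
import torch

NOISE_STD = 0.01

def mutate_net(net, use_cuda):
    # use_cuda kept for signature parity; randn_like already matches the
    # parameter's device, so no explicit .cuda() is needed here.
    child = copy.deepcopy(net)
    with torch.no_grad():
        for p in child.parameters():
            p.add_(torch.randn_like(p) * NOISE_STD)
    return child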
def train(config):
    gpu_manage(config)

    train_dataset = Dataset(config.train_dir)
    val_dataset = Dataset(config.val_dir)
    training_data_loader = DataLoader(dataset=train_dataset,
                                      num_workers=config.threads,
                                      batch_size=config.batchsize,
                                      shuffle=True)
    val_data_loader = DataLoader(dataset=val_dataset,
                                 num_workers=config.threads,
                                 batch_size=config.test_batchsize,
                                 shuffle=False)

    gen = UNet(in_ch=config.in_ch, out_ch=config.out_ch, gpu_ids=config.gpu_ids)
    if config.gen_init is not None:
        param = torch.load(config.gen_init)
        gen.load_state_dict(param)
        print('load {} as pretrained model'.format(config.gen_init))

    dis = Discriminator(in_ch=config.in_ch, out_ch=config.out_ch,
                        gpu_ids=config.gpu_ids)
    if config.dis_init is not None:
        param = torch.load(config.dis_init)
        dis.load_state_dict(param)
        print('load {} as pretrained model'.format(config.dis_init))

    opt_gen = optim.Adam(gen.parameters(), lr=config.lr,
                         betas=(config.beta1, 0.999), weight_decay=0.00001)
    opt_dis = optim.Adam(dis.parameters(), lr=config.lr,
                         betas=(config.beta1, 0.999), weight_decay=0.00001)

    real_a = torch.FloatTensor(config.batchsize, config.in_ch, 256, 256)
    real_b = torch.FloatTensor(config.batchsize, config.out_ch, 256, 256)

    criterionL1 = nn.L1Loss()
    criterionMSE = nn.MSELoss()
    criterionSoftplus = nn.Softplus()

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    if config.cuda:
        gen = gen.cuda(0)
        dis = dis.cuda(0)
        criterionL1 = criterionL1.cuda(0)
        criterionMSE = criterionMSE.cuda(0)
        criterionSoftplus = criterionSoftplus.cuda(0)
        real_a = real_a.cuda(0)
        real_b = real_b.cuda(0)

    real_a = Variable(real_a)
    real_b = Variable(real_b)

    logreport = LogReport(log_dir=config.out_dir)
    testreport = TestReport(log_dir=config.out_dir)

    for epoch in range(1, config.epoch + 1):
        print('Epoch', epoch, datetime.now())
        for iteration, batch in enumerate(tqdm(training_data_loader)):
            real_a, real_b = batch[0], batch[1]
            real_a = F.interpolate(real_a, size=256).to(device)
            real_b = F.interpolate(real_b, size=256).to(device)
            fake_b = gen.forward(real_a)

            # Update D
            opt_dis.zero_grad()
            fake_ab = torch.cat((real_a, fake_b), 1)
            pred_fake = dis.forward(fake_ab.detach())
            batchsize, _, w, h = pred_fake.size()

            real_ab = torch.cat((real_a, real_b), 1)
            pred_real = dis.forward(real_ab)
            loss_d_fake = torch.sum(criterionSoftplus(pred_fake)) / batchsize / w / h
            loss_d_real = torch.sum(criterionSoftplus(-pred_real)) / batchsize / w / h
            loss_d = loss_d_fake + loss_d_real
            loss_d.backward()
            if epoch % config.minimax == 0:
                opt_dis.step()

            # Update G
            opt_gen.zero_grad()
            fake_ab = torch.cat((real_a, fake_b), 1)
            pred_fake = dis.forward(fake_ab)
            loss_g_gan = torch.sum(criterionSoftplus(-pred_fake)) / batchsize / w / h
            loss_g = loss_g_gan + criterionL1(fake_b, real_b) * config.lamb
            loss_g.backward()
            opt_gen.step()

            if iteration % 100 == 0:
                logreport({
                    'epoch': epoch,
                    'iteration': len(training_data_loader) * (epoch - 1) + iteration,
                    'gen/loss': loss_g.item(),
                    'dis/loss': loss_d.item(),
                })

        with torch.no_grad():
            log_test = test(config, val_data_loader, gen, criterionMSE, epoch)
            testreport(log_test)

        if epoch % config.snapshot_interval == 0:
            checkpoint(config, epoch, gen, dis)

        logreport.save_lossgraph()
        testreport.save_lossgraph()
    print('Done', datetime.now())
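# Design note: softplus(x) = log(1 + e^x), so softplus(-pred_real) equals
# -log(sigmoid(pred_real)) and softplus(pred_fake) equals
# -log(1 - sigmoid(pred_fake)). The D and G objectives above are therefore the
# standard non-saturating logistic GAN loss applied directly to the
# discriminator's patch-grid logits, averaged over batch and spatial
# positions. A quick numerical check of the identities:
import torch
import torch.nn.functional as F

x = torch.randn(5)
assert torch.allclose(F.softplus(-x), -torch.log(torch.sigmoid(x)), atol=1e-6)
assert torch.allclose(F.softplus(x), -torch.log(1 - torch.sigmoid(x)), atol=1e-5)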