def __init__(self, config):
    """Initialize the DCGAN agent: models, data, loss, optimizers, device.

    NOTE(review): this module-level ``__init__`` duplicates
    ``DCGANAgent.__init__`` below and sits outside any class, so its
    zero-argument ``super()`` call cannot resolve — it is almost certainly
    an orphaned copy-paste and a candidate for deletion. Kept here (with
    the same fixes as the class version) to preserve the file layout.
    """
    super().__init__(config)

    # define models (generator and discriminator)
    self.netG = Generator(self.config)
    self.netD = Discriminator(self.config)

    # define dataloader
    self.dataloader = CelebADataLoader(self.config)

    # define loss
    self.loss = BinaryCrossEntropy()

    # define optimizers for both generator and discriminator
    self.optimG = torch.optim.Adam(self.netG.parameters(),
                                   lr=self.config.learning_rate,
                                   betas=(self.config.beta1, self.config.beta2))
    self.optimD = torch.optim.Adam(self.netD.parameters(),
                                   lr=self.config.learning_rate,
                                   betas=(self.config.beta1, self.config.beta2))

    # initialize counters
    self.current_epoch = 0
    self.current_iteration = 0
    self.best_valid_mean_iou = 0

    # fixed noise vector reused every epoch to visualize generator progress
    self.fixed_noise = Variable(
        torch.randn(self.config.batch_size, self.config.g_input_size, 1, 1))
    self.real_label = 1
    self.fake_label = 0

    # set cuda flag
    self.is_cuda = torch.cuda.is_available()
    if self.is_cuda and not self.config.cuda:
        self.logger.info(
            "WARNING: You have a CUDA device, so you should probably enable CUDA")
    # FIX: logical `and`, not bitwise `&`, for combining boolean flags
    self.cuda = self.is_cuda and self.config.cuda

    # set the manual seed for reproducibility
    self.manual_seed = random.randint(1, 10000)
    # FIX: logging uses lazy %-style arguments; the original passed the seed
    # as an extra positional argument with no placeholder in the message.
    self.logger.info("seed: %d", self.manual_seed)
    random.seed(self.manual_seed)
    if self.cuda:
        self.device = torch.device("cuda")
        torch.cuda.set_device(self.config.gpu_device)
        torch.cuda.manual_seed_all(self.manual_seed)
        self.logger.info("Program will run on *****GPU-CUDA***** ")
        print_cuda_statistics()
    else:
        self.device = torch.device("cpu")
        torch.manual_seed(self.manual_seed)
        self.logger.info("Program will run on *****CPU***** ")

    self.netG = self.netG.to(self.device)
    self.netD = self.netD.to(self.device)
    self.loss = self.loss.to(self.device)
    self.fixed_noise = self.fixed_noise.to(self.device)

    # Model loading from the latest checkpoint; if not found, start from scratch.
    self.load_checkpoint(self.config.checkpoint_file)

    # Summary writer for TensorBoard logging
    self.summary_writer = SummaryWriter(log_dir=self.config.summary_dir,
                                        comment='DCGAN')
class DCGANAgent(BaseAgent):
    """DCGAN training agent.

    Wires together the generator/discriminator pair, the CelebA dataloader,
    a BCE loss, Adam optimizers, checkpointing and TensorBoard summaries,
    and drives the standard alternating GAN training loop.
    """

    def __init__(self, config):
        super().__init__(config)

        # define models (generator and discriminator)
        self.netG = Generator(self.config)
        self.netD = Discriminator(self.config)

        # define dataloader
        self.dataloader = CelebADataLoader(self.config)

        # define loss
        self.loss = BinaryCrossEntropy()

        # define optimizers for both generator and discriminator
        self.optimG = torch.optim.Adam(self.netG.parameters(),
                                       lr=self.config.learning_rate,
                                       betas=(self.config.beta1, self.config.beta2))
        self.optimD = torch.optim.Adam(self.netD.parameters(),
                                       lr=self.config.learning_rate,
                                       betas=(self.config.beta1, self.config.beta2))

        # initialize counters
        self.current_epoch = 0
        self.current_iteration = 0
        self.best_valid_mean_iou = 0

        # fixed noise vector reused every epoch to visualize generator progress
        self.fixed_noise = Variable(
            torch.randn(self.config.batch_size, self.config.g_input_size, 1, 1))
        self.real_label = 1
        self.fake_label = 0

        # set cuda flag
        self.is_cuda = torch.cuda.is_available()
        if self.is_cuda and not self.config.cuda:
            self.logger.info(
                "WARNING: You have a CUDA device, so you should probably enable CUDA")
        # FIX: logical `and`, not bitwise `&`, for combining boolean flags
        self.cuda = self.is_cuda and self.config.cuda

        # set the manual seed for reproducibility
        self.manual_seed = random.randint(1, 10000)
        # FIX: logging uses lazy %-style arguments; the original passed the
        # seed as an extra positional argument with no placeholder.
        self.logger.info("seed: %d", self.manual_seed)
        random.seed(self.manual_seed)
        if self.cuda:
            self.device = torch.device("cuda")
            torch.cuda.set_device(self.config.gpu_device)
            torch.cuda.manual_seed_all(self.manual_seed)
            self.logger.info("Program will run on *****GPU-CUDA***** ")
            print_cuda_statistics()
        else:
            self.device = torch.device("cpu")
            torch.manual_seed(self.manual_seed)
            self.logger.info("Program will run on *****CPU***** ")

        self.netG = self.netG.to(self.device)
        self.netD = self.netD.to(self.device)
        self.loss = self.loss.to(self.device)
        self.fixed_noise = self.fixed_noise.to(self.device)

        # Model loading from the latest checkpoint; if not found, start from scratch.
        self.load_checkpoint(self.config.checkpoint_file)

        # Summary writer for TensorBoard logging
        self.summary_writer = SummaryWriter(log_dir=self.config.summary_dir,
                                            comment='DCGAN')

    def load_checkpoint(self, file_name):
        """Restore training state from ``checkpoint_dir + file_name``.

        Restores epoch/iteration counters, both networks, both optimizers,
        the fixed noise vector and the manual seed. A missing checkpoint
        file is logged and ignored so a fresh run can start from scratch.
        """
        filename = self.config.checkpoint_dir + file_name
        try:
            self.logger.info("Loading checkpoint '{}'".format(filename))
            # FIX: map_location remaps tensors to the current device so a
            # GPU-saved checkpoint can be loaded on a CPU-only machine.
            checkpoint = torch.load(filename, map_location=self.device)
            self.current_epoch = checkpoint['epoch']
            self.current_iteration = checkpoint['iteration']
            self.netG.load_state_dict(checkpoint['G_state_dict'])
            self.optimG.load_state_dict(checkpoint['G_optimizer'])
            self.netD.load_state_dict(checkpoint['D_state_dict'])
            self.optimD.load_state_dict(checkpoint['D_optimizer'])
            self.fixed_noise = checkpoint['fixed_noise']
            self.manual_seed = checkpoint['manual_seed']
            self.logger.info(
                "Checkpoint loaded successfully from '{}' at (epoch {}) at (iteration {})\n"
                .format(self.config.checkpoint_dir, checkpoint['epoch'],
                        checkpoint['iteration']))
        except OSError:
            self.logger.info(
                "No checkpoint exists from '{}'. Skipping...".format(
                    self.config.checkpoint_dir))
            self.logger.info("**First time to train**")

    def save_checkpoint(self, file_name="checkpoint.pth.tar", is_best=0):
        """Persist the full training state.

        :param file_name: checkpoint file name inside ``checkpoint_dir``
        :param is_best: truthy when this is the best model so far; a copy is
            then written to ``model_best.pth.tar``
        """
        state = {
            'epoch': self.current_epoch,
            'iteration': self.current_iteration,
            'G_state_dict': self.netG.state_dict(),
            'G_optimizer': self.optimG.state_dict(),
            'D_state_dict': self.netD.state_dict(),
            'D_optimizer': self.optimD.state_dict(),
            'fixed_noise': self.fixed_noise,
            'manual_seed': self.manual_seed
        }
        # Save the state
        torch.save(state, self.config.checkpoint_dir + file_name)
        # If it is the best, copy it to another file 'model_best.pth.tar'
        if is_best:
            shutil.copyfile(self.config.checkpoint_dir + file_name,
                            self.config.checkpoint_dir + 'model_best.pth.tar')

    def run(self):
        """Entry point: run training, tolerating a Ctrl+C interruption."""
        try:
            self.train()
        except KeyboardInterrupt:
            # FIX: the message was split by a stray line break in the source
            self.logger.info("You have entered CTRL+C.. Wait to finalize")

    def train(self):
        """Main training loop: one checkpoint per epoch."""
        for epoch in range(self.current_epoch, self.config.max_epoch):
            self.current_epoch = epoch
            self.train_one_epoch()
            self.save_checkpoint()

    def train_one_epoch(self):
        """Run one epoch of alternating D/G updates over the dataloader.

        For each batch: maximize log(D(x)) + log(1 - D(G(z))) for the
        discriminator, then maximize log(D(G(z))) for the generator;
        losses go to TensorBoard along with a sample image per epoch.
        """
        # initialize tqdm batch
        tqdm_batch = tqdm(self.dataloader.loader,
                          total=self.dataloader.num_iterations,
                          desc="epoch-{}-".format(self.current_epoch))

        self.netG.train()
        self.netD.train()
        epoch_lossG = AverageMeter()
        epoch_lossD = AverageMeter()

        for curr_it, x in enumerate(tqdm_batch):
            x = x[0]
            y = torch.randn(x.size(0), )
            fake_noise = torch.randn(x.size(0), self.config.g_input_size, 1, 1)

            # FIX: Tensor.cuda(flag) passed the async flag as the *device*
            # positional argument; .to(device, non_blocking=...) is the
            # correct, device-agnostic equivalent.
            x = x.to(self.device, non_blocking=self.config.async_loading)
            y = y.to(self.device, non_blocking=self.config.async_loading)
            fake_noise = fake_noise.to(self.device,
                                       non_blocking=self.config.async_loading)

            x = Variable(x)
            y = Variable(y)
            fake_noise = Variable(fake_noise)

            ####################
            # Update D network: maximize log(D(x)) + log(1 - D(G(z)))
            # train with real
            self.netD.zero_grad()
            D_real_out = self.netD(x)
            y.fill_(self.real_label)
            loss_D_real = self.loss(D_real_out, y)
            loss_D_real.backward()

            # train with fake (detach so G gets no gradient from the D step)
            G_fake_out = self.netG(fake_noise)
            y.fill_(self.fake_label)
            D_fake_out = self.netD(G_fake_out.detach())
            loss_D_fake = self.loss(D_fake_out, y)
            loss_D_fake.backward()

            loss_D = loss_D_fake + loss_D_real
            self.optimD.step()

            ####################
            # Update G network: maximize log(D(G(z)))
            self.netG.zero_grad()
            y.fill_(self.real_label)  # fake labels are "real" for the G loss
            D_out = self.netD(G_fake_out)
            loss_G = self.loss(D_out, y)
            loss_G.backward()
            self.optimG.step()

            epoch_lossD.update(loss_D.item())
            epoch_lossG.update(loss_G.item())
            self.current_iteration += 1

            self.summary_writer.add_scalar("epoch/Generator_loss",
                                           epoch_lossG.val,
                                           self.current_iteration)
            self.summary_writer.add_scalar("epoch/Discriminator_loss",
                                           epoch_lossD.val,
                                           self.current_iteration)

        # visualize generator progress on the fixed noise vector
        gen_out = self.netG(self.fixed_noise)
        out_img = self.dataloader.plot_samples_per_epoch(
            gen_out.data, self.current_iteration)
        self.summary_writer.add_image('train/generated_image', out_img,
                                      self.current_iteration)
        tqdm_batch.close()

        self.logger.info("Training at epoch-" + str(self.current_epoch) +
                         " | " + "Discriminator loss: " +
                         str(epoch_lossD.val) + " - Generator Loss-: " +
                         str(epoch_lossG.val))

    def validate(self):
        """No validation step for this agent."""
        pass

    def finalize(self):
        """
        Finalize all the operations of the 2 Main classes of the process
        the operator and the data loader
        :return:
        """
        self.logger.info(
            "Please wait while finalizing the operation.. Thank you")
        self.save_checkpoint()
        self.summary_writer.export_scalars_to_json("{}all_scalars.json".format(
            self.config.summary_dir))
        self.summary_writer.close()
        self.dataloader.finalize()