def _grow_phase(self):
    # Log the transition value here to avoid creating a misleading
    # representation on tensorboard.
    if self.transition_value is not None:
        logger.log_variable(
            "stats/transition-value", self.get_transition_value())
    self._update_transition_value()
    transition_iters = self.transition_iters
    minibatch_repeats = self.cfg.trainer.progressive.minibatch_repeats
    next_transition = self.prev_transition + transition_iters
    num_batches = (next_transition - self.global_step) / self.batch_size()
    num_batches = int(np.ceil(num_batches))
    num_repeats = int(np.ceil(num_batches / minibatch_repeats))
    logger.info(
        f"Starting grow phase for imsize={self.current_imsize()}."
        f" Training for {num_batches} batches with batch size:"
        f" {self.batch_size()}")
    for it in range(num_repeats):
        # Run up to minibatch_repeats train steps between each update of
        # the transition value.
        for _ in range(
                min(minibatch_repeats, num_batches - it * minibatch_repeats)):
            self.train_step()
        self._update_transition_value()
    # Check that the grow phase ended at the correct step.
    assert self.global_step >= self.prev_transition + transition_iters, \
        f"Global step: {self.global_step}, batch size: {self.batch_size()}," \
        f" prev_transition: {self.prev_transition}," \
        f" transition iters: {transition_iters}"
    assert self.global_step - self.batch_size() \
        <= self.prev_transition + transition_iters, \
        f"Global step: {self.global_step}, batch size: {self.batch_size()}," \
        f" prev_transition: {self.prev_transition}," \
        f" transition iters: {transition_iters}"
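# Illustrative, standalone restatement of the scheduling arithmetic above
# (not part of the trainer): the grow phase runs until global_step reaches
# prev_transition + transition_iters, in chunks of minibatch_repeats train
# steps between transition-value updates.
def _example_grow_schedule(global_step, prev_transition, transition_iters,
                           batch_size, minibatch_repeats):
    next_transition = prev_transition + transition_iters
    num_batches = int(np.ceil((next_transition - global_step) / batch_size))
    num_repeats = int(np.ceil(num_batches / minibatch_repeats))
    return num_batches, num_repeats

# Example: 600k remaining steps at batch size 64 with 4 minibatch repeats
# gives ceil(600000 / 64) = 9375 batches and ceil(9375 / 4) = 2344 chunks.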
def init_optimizer(self):
    self.loss_optimizer = loss.LossOptimizer.build_from_cfg(
        self.cfg, self.discriminator, self.generator)
    self.generator, self.discriminator = \
        self.loss_optimizer.initialize_amp()
    logger.log_variable("stats/learning_rate",
                        self.loss_optimizer._learning_rate)
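# Hedged sketch of what LossOptimizer.initialize_amp is assumed to do:
# wrap both networks (and their optimizers) with NVIDIA Apex mixed
# precision. `g_optim` and `d_optim` are hypothetical names used only for
# this illustration; requires `apex` to be installed.
def _example_initialize_amp(generator, discriminator, g_optim, d_optim):
    from apex import amp
    [generator, discriminator], [g_optim, d_optim] = amp.initialize(
        [generator, discriminator], [g_optim, d_optim], opt_level="O1")
    return generator, discriminator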
def _update_transition_value(self):
    if self._get_phase() == "stability":
        self.transition_value = 1.0
    else:
        # Linearly fade in the new resolution block over transition_iters
        # steps.
        steps_since_transition = self.global_step - self.prev_transition
        v = steps_since_transition / self.transition_iters
        assert 0 <= v <= 1
        self.transition_value = v
    self.generator.update_transition_value(self.transition_value)
    self.discriminator.update_transition_value(self.transition_value)
    self.RA_generator.update_transition_value(self.transition_value)
    logger.log_variable(
        "stats/transition-value", self.get_transition_value())
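# Standalone restatement of the fade-in above (illustrative only): the
# transition value grows linearly from 0 to 1 between prev_transition and
# prev_transition + transition_iters, then stays at 1 during stability.
def _example_transition_value(global_step, prev_transition, transition_iters):
    return min((global_step - prev_transition) / transition_iters, 1.0)

# Halfway through a 600k-step transition, the value is 0.5:
assert _example_transition_value(1_500_000, 1_200_000, 600_000) == 0.5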
def calculate_fid(self):
    logger.info("Starting calculation of FID value")
    generator = self.trainer.RA_generator
    real_images, fake_images = infer.infer_images(
        self.trainer.dataloader_val, generator,
        truncation_level=0
    )
    """
    # FID calculation disabled, as it is extremely expensive to compute.
    cfg = self.trainer.cfg
    identifier = f"{cfg.dataset_type}_{cfg.data_val.dataset.percentage}_{self.current_imsize()}"
    transition_value = self.trainer.RA_generator.transition_value
    fid_val = metric_api.fid(
        real_images, fake_images,
        batch_size=self.fid_batch_size)
    logger.log_variable("stats/fid", np.mean(fid_val),
                        log_level=logging.INFO)
    """
    l1 = metric_api.l1(real_images, fake_images)
    l2 = metric_api.l2(real_images, fake_images)
    psnr = metric_api.psnr(real_images, fake_images)
    lpips = metric_api.lpips(
        real_images, fake_images, self.lpips_batch_size)
    logger.log_variable("stats/l1", l1, log_level=logging.INFO)
    logger.log_variable("stats/l2", l2, log_level=logging.INFO)
    logger.log_variable("stats/psnr", psnr, log_level=logging.INFO)
    logger.log_variable("stats/lpips", lpips, log_level=logging.INFO)
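# Hedged sketch of what the metric_api calls above are assumed to compute
# (plain numpy reference implementations, not the project's own):
def _example_l1(real, fake):
    return np.abs(real.astype(np.float64) - fake.astype(np.float64)).mean()

def _example_psnr(real, fake, max_value=255.0):
    mse = ((real.astype(np.float64) - fake.astype(np.float64)) ** 2).mean()
    return 10 * np.log10(max_value ** 2 / mse)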
def init_models(self):
    self.discriminator = models.build_discriminator(
        self.cfg, data_parallel=torch.cuda.device_count() > 1)
    self.generator = models.build_generator(
        self.cfg, data_parallel=torch.cuda.device_count() > 1)
    # The running-average (RA) generator tracks a running average of the
    # generator weights and is initialized as an exact copy.
    self.RA_generator = models.build_generator(
        self.cfg, data_parallel=torch.cuda.device_count() > 1)
    self.RA_generator = torch_utils.to_cuda(self.RA_generator)
    self.RA_generator.load_state_dict(self.generator.state_dict())
    logger.info(str(self.generator))
    logger.info(str(self.discriminator))
    logger.log_variable(
        "stats/discriminator_parameters",
        torch_utils.number_of_parameters(self.discriminator))
    logger.log_variable(
        "stats/generator_parameters",
        torch_utils.number_of_parameters(self.generator))
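# Sketch of what torch_utils.number_of_parameters is assumed to compute:
# the standard per-tensor element count (illustrative only).
def _example_number_of_parameters(module):
    return sum(p.numel() for p in module.parameters())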
def update_beta(self):
    # Recompute the running-average decay for the current batch size.
    batch_size = self.trainer.batch_size()
    g = self.trainer.RA_generator
    g.update_beta(batch_size)
    logger.log_variable("stats/running_average_decay", g.ra_beta)
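# Hedged sketch of the exponential moving average the RA generator is assumed
# to apply, with a batch-size-dependent decay in the style of Progressive GAN
# (half-life measured in images seen, not in steps). The 10k half-life and
# both function names are illustrative assumptions, not the trainer's code.
def _example_update_beta(batch_size, ra_half_life=10_000.0):
    return 0.5 ** (batch_size / ra_half_life)

def _example_ra_update(ra_generator, generator, beta):
    for p_ra, p in zip(ra_generator.parameters(), generator.parameters()):
        p_ra.data.mul_(beta).add_(p.data, alpha=1.0 - beta)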