def test_train_steps(self):
    real_batch = common.load_images(self.N, size=self.H)

    # Setup optimizers
    optD = optim.Adam(self.netD.parameters(), 2e-4, betas=(0.0, 0.9))
    optG = optim.Adam(self.netG.parameters(), 2e-4, betas=(0.0, 0.9))

    # Log statistics to check
    log_data = metric_log.MetricLog()

    # Test D train step
    log_data = self.netD.train_step(real_batch=real_batch,
                                    netG=self.netG,
                                    optD=optD,
                                    device='cpu',
                                    log_data=log_data)

    # Test G train step
    log_data = self.netG.train_step(real_batch=real_batch,
                                    netD=self.netD,
                                    optG=optG,
                                    log_data=log_data,
                                    device='cpu')

    for name, metric_dict in log_data.items():
        assert isinstance(name, str)
        assert isinstance(metric_dict['value'], float)
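# For reference, a minimal sketch of the MetricLog contract the assertions
# above rely on: add_metric() records a named float, and items() yields
# (name, metric_dict) pairs where metric_dict['value'] holds the float.
# This is an illustrative assumption, not the library's implementation.
class _MetricLogSketch:
    """Hypothetical stand-in mirroring the interface used in these tests."""

    def __init__(self):
        self._metrics = {}

    def add_metric(self, name, value, group=None, precision=4):
        # Store the scalar with the metadata seen at the call sites above.
        self._metrics[name] = {
            'value': float(value),
            'group': group if group is not None else name,
            'precision': precision,
        }

    def items(self):
        return self._metrics.items()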
def test_no_decay(self):
    optD = optim.Adam(self.netD.parameters(), self.lr_D, betas=(0.0, 0.9))
    optG = optim.Adam(self.netG.parameters(), self.lr_G, betas=(0.0, 0.9))

    lr_scheduler = scheduler.LRScheduler(lr_decay='None',
                                         optD=optD,
                                         optG=optG,
                                         num_steps=self.num_steps)
    log_data = metric_log.MetricLog()

    for step in range(1, self.num_steps + 1):
        lr_scheduler.step(log_data, step)

        assert self.lr_D == self.get_lr(optD)
        assert self.lr_G == self.get_lr(optG)
def test_linear_decay(self):
    optD = optim.Adam(self.netD.parameters(), self.lr_D, betas=(0.0, 0.9))
    optG = optim.Adam(self.netG.parameters(), self.lr_G, betas=(0.0, 0.9))

    lr_scheduler = scheduler.LRScheduler(lr_decay='linear',
                                         optD=optD,
                                         optG=optG,
                                         num_steps=self.num_steps)
    log_data = metric_log.MetricLog()

    for step in range(1, self.num_steps + 1):
        lr_scheduler.step(log_data, step)

        # Compare against the expected linearly decayed value. Use abs()
        # so the check fails in both directions, not only when the
        # scheduler's lr overshoots the expected one.
        curr_lr = (1 - step / self.num_steps) * self.lr_D
        assert abs(curr_lr - self.get_lr(optD)) < 1e-5
        assert abs(curr_lr - self.get_lr(optG)) < 1e-5
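# The linear schedule verified above follows lr_t = lr_0 * (1 - t / T),
# annealing the learning rate to zero at t == T. A minimal sketch of that
# rule (the helper name is illustrative, not part of the scheduler module):
def _linear_decay_lr(base_lr, step, num_steps):
    """Return the linearly decayed learning rate at the given step."""
    return base_lr * (1.0 - step / num_steps)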
def test_print_log(self):
    log_data = metric_log.MetricLog()
    global_step = 10

    # Populate log data with some values
    for scalar in self.scalars:
        if scalar == 'img':
            continue
        log_data.add_metric(scalar, 1.0)

    printed = self.logger.print_log(global_step=global_step,
                                    log_data=log_data,
                                    time_taken=10)

    assert printed == (
        'INFO: [Epoch 1/1][Global Step: 10/100] ' +
        '\n| D(G(z)): 1.0\n| D(x): 1.0\n| errD: 1.0\n| errG: 1.0' +
        '\n| lr_D: 1.0\n| lr_G: 1.0\n| (10.0000 sec/idx)')
def train(self):
    """
    Runs the training pipeline with all given parameters in Trainer.
    """
    # Restore models
    global_step = self._restore_models_and_step()
    print("INFO: Starting training from global step {}...".format(
        global_step))

    try:
        start_time = time.time()

        # Iterate through data
        iter_dataloader = iter(self.dataloader)
        while global_step < self.num_steps:
            log_data = metric_log.MetricLog()  # log data for tensorboard

            # -------------------------
            # One Training Step
            # -------------------------
            # Update n_dis times for D
            for i in range(self.n_dis):
                iter_dataloader, real_batch = self._fetch_data(
                    iter_dataloader=iter_dataloader)

                # -----------------------
                # Update G Network
                # -----------------------
                # Update G, but only once.
                if i == 0:
                    log_data = self.netG.train_step(
                        real_batch=real_batch,
                        netD=self.netD,
                        optG=self.optG,
                        global_step=global_step,
                        log_data=log_data,
                        device=self.device)

                # -----------------------
                # Update D Network
                # -----------------------
                log_data = self.netD.train_step(real_batch=real_batch,
                                                netG=self.netG,
                                                optD=self.optD,
                                                log_data=log_data,
                                                global_step=global_step,
                                                device=self.device)

            # --------------------------------
            # Update Training Variables
            # --------------------------------
            global_step += 1

            log_data = self.scheduler.step(log_data=log_data,
                                           global_step=global_step)

            # -------------------------
            # Logging and Metrics
            # -------------------------
            if global_step % self.log_steps == 0:
                self.logger.write_summaries(log_data=log_data,
                                            global_step=global_step)

            if global_step % self.print_steps == 0:
                curr_time = time.time()
                self.logger.print_log(global_step=global_step,
                                      log_data=log_data,
                                      time_taken=(curr_time - start_time) /
                                      self.print_steps)
                start_time = curr_time

            if global_step % self.vis_steps == 0:
                self.logger.vis_images(netG=self.netG,
                                       global_step=global_step)

            if global_step % self.save_steps == 0:
                print("INFO: Saving checkpoints...")
                self.netG.save_checkpoint(directory=self.netG_ckpt_dir,
                                          global_step=global_step,
                                          optimizer=self.optG)
                self.netD.save_checkpoint(directory=self.netD_ckpt_dir,
                                          global_step=global_step,
                                          optimizer=self.optD)

        # Save models at the very end of training
        if self.save_when_end:
            print("INFO: Saving final checkpoints...")
            self.netG.save_checkpoint(directory=self.netG_ckpt_dir,
                                      global_step=global_step,
                                      optimizer=self.optG)
            self.netD.save_checkpoint(directory=self.netD_ckpt_dir,
                                      global_step=global_step,
                                      optimizer=self.optD)

    except KeyboardInterrupt:
        print("INFO: Saving checkpoints from keyboard interrupt...")
        self.netG.save_checkpoint(directory=self.netG_ckpt_dir,
                                  global_step=global_step,
                                  optimizer=self.optG)
        self.netD.save_checkpoint(directory=self.netD_ckpt_dir,
                                  global_step=global_step,
                                  optimizer=self.optD)

    finally:
        self.logger.close_writers()

    print("INFO: Training Ended.")
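# A hypothetical end-to-end driver for the training loop above. The Trainer
# constructor signature, num_steps, n_dis, and log_dir values here are
# assumptions for illustration; the optimizer settings mirror the ones used
# in the tests earlier.
import torch.optim as optim

def _example_training_run(netD, netG, dataloader, device='cpu'):
    """Wire up optimizers and run the training pipeline once."""
    optD = optim.Adam(netD.parameters(), 2e-4, betas=(0.0, 0.9))
    optG = optim.Adam(netG.parameters(), 2e-4, betas=(0.0, 0.9))
    trainer = Trainer(netD=netD, netG=netG, optD=optD, optG=optG,
                      dataloader=dataloader, num_steps=100000,
                      n_dis=5, log_dir='./log', device=device)
    trainer.train()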
def train(self, upload_path=None):
    """
    Runs the training pipeline with all given parameters in Trainer.
    """
    # Restore models
    global_step = self._restore_models_and_step()
    print("INFO: Starting training from global step {}...".format(
        global_step))

    try:
        start_time = time.time()

        # Iterate through data
        iter_dataloader = iter(self.dataloader)
        while global_step < self.num_steps:
            log_data = metric_log.MetricLog()  # log data for tensorboard

            # -------------------------
            # One Training Step
            # -------------------------
            # Update n_dis times for D
            for i in range(self.n_dis):
                iter_dataloader, real_batch = self._fetch_data(
                    iter_dataloader=iter_dataloader)

                # -----------------------
                # Update D Network
                # -----------------------
                log_data = self.netD.train_step(
                    real_batch=real_batch,
                    netG=self.netG,
                    optD=self.optD,
                    log_data=log_data,
                    global_step=global_step,
                    device=self.device,
                )

                # -----------------------
                # Update G Network
                # -----------------------
                # Update G, but only once.
                if i == (self.n_dis - 1):
                    log_data = self.netG.train_step(
                        real_batch=real_batch,
                        netD=self.netD,
                        optG=self.optG,
                        global_step=global_step,
                        log_data=log_data,
                        device=self.device,
                    )

            # --------------------------------
            # Update Training Variables
            # --------------------------------
            global_step += 1

            log_data = self.scheduler.step(log_data=log_data,
                                           global_step=global_step)

            # -------------------------
            # Logging and Metrics
            # -------------------------
            if global_step % self.log_steps == 0:
                self.logger.write_summaries(log_data=log_data,
                                            global_step=global_step)

            if global_step % self.print_steps == 0:
                curr_time = time.time()
                self.logger.print_log(
                    global_step=global_step,
                    log_data=log_data,
                    time_taken=(curr_time - start_time) / self.print_steps,
                )
                start_time = curr_time

            if global_step % self.vis_steps == 0:
                self.logger.vis_images(netG=self.netG,
                                       global_step=global_step)

            if global_step % self.save_steps == 0:
                print("INFO: Saving checkpoints...")
                self._save_model_checkpoints(global_step)

                if upload_path is not None:
                    # Mirror the log directory to the upload destination.
                    # Note: use == for integer comparison; `is` checks
                    # identity and is unreliable for ints.
                    call_string = "rsync --update -arq '{}' '{}'".format(
                        self.log_dir, upload_path)
                    result = subprocess.call(call_string, shell=True)
                    if result == 0:
                        print("INFO: Upload checkpoints SUCCESSFUL")
                    else:
                        print("INFO: Upload checkpoints FAILED")

        print("INFO: Saving final checkpoints...")
        self._save_model_checkpoints(global_step)

    except KeyboardInterrupt:
        print("INFO: Saving checkpoints from keyboard interrupt...")
        self._save_model_checkpoints(global_step)

    finally:
        self.logger.close_writers()

    print("INFO: Training Ended.")
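# Passing rsync its arguments as a list avoids shell=True and the manual
# quoting of paths containing spaces. A minimal standard-library sketch of
# the same upload step (the helper name is illustrative, not part of the
# codebase):
import subprocess

def _upload_checkpoints(log_dir, upload_path):
    """Mirror log_dir to upload_path; returns rsync's exit status."""
    return subprocess.call(
        ["rsync", "--update", "-arq", log_dir, upload_path])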
def train(self):
    """
    Runs the training pipeline with all given parameters in Trainer.
    """
    # Restore models
    global_step = self._restore_models_and_step()
    print("INFO: Starting training from global step {}...".format(
        global_step))

    try:
        start_time = time.time()

        # Iterate through data
        iter_dataloader = iter(self.dataloader)
        while global_step < self.num_steps:
            log_data = metric_log.MetricLog()  # log data for tensorboard

            # -------------------------
            # One Training Step
            # -------------------------
            # Update n_dis times for D
            # R = np.random.randint(18, 22, 1)
            for i in range(self.n_dis):
                iter_dataloader, real_batch = self._fetch_data(
                    iter_dataloader=iter_dataloader)

                # -----------------------
                # Update D Network
                # -----------------------
                # Normal (non-adversarial) training:
                # log_data = self.netD.train_step(real_batch=real_batch,
                #                                 netG=self.netG,
                #                                 optD=self.optD,
                #                                 log_data=log_data,
                #                                 global_step=global_step,
                #                                 # radius=R,
                #                                 device=self.device)

                # DAT: discriminator adversarial training step.
                log_data = self.netD.advtrain_step(real_batch=real_batch,
                                                   netG=self.netG,
                                                   optD=self.optD,
                                                   log_data=log_data,
                                                   global_step=global_step,
                                                   device=self.device)

                # -----------------------
                # Update G Network
                # -----------------------
                # Update G, but only once.
                if i == (self.n_dis - 1):
                    log_data = self.netG.train_step(
                        real_batch=real_batch,
                        netD=self.netD,
                        optG=self.optG,
                        global_step=global_step,
                        log_data=log_data,
                        # radius=R,
                        device=self.device)

            # --------------------------------
            # Update Training Variables
            # --------------------------------
            global_step += 1

            log_data = self.scheduler.step(log_data=log_data,
                                           global_step=global_step)

            # -------------------------
            # Logging and Metrics
            # -------------------------
            if global_step % self.log_steps == 0:
                self.logger.write_summaries(log_data=log_data,
                                            global_step=global_step)

            if global_step % self.print_steps == 0:
                curr_time = time.time()
                self.logger.print_log(global_step=global_step,
                                      log_data=log_data,
                                      time_taken=(curr_time - start_time) /
                                      self.print_steps)
                start_time = curr_time

            if global_step % self.vis_steps == 0:
                self.logger.vis_images(netG=self.netG,
                                       global_step=global_step)
                self.logger.summary_fid(netG=self.netG,
                                        dataset=self.dataset,
                                        global_step=global_step)
                self.logger.summary_IS(netG=self.netG,
                                       global_step=global_step)
                self.logger.summary_KID(netG=self.netG,
                                        dataset=self.dataset,
                                        global_step=global_step)

            if global_step % self.save_steps == 0:
                print("INFO: Saving checkpoints...")
                self._save_model_checkpoints(global_step)

        print("INFO: Saving final checkpoints...")
        self._save_model_checkpoints(global_step)

    except KeyboardInterrupt:
        print("INFO: Saving checkpoints from keyboard interrupt...")
        self._save_model_checkpoints(global_step)

    finally:
        self.logger.close_writers()

    print("INFO: Training Ended.")
def train(self):
    """
    Runs the training pipeline with all given parameters in Trainer.
    """
    # Restore models
    global_step = self._restore_models_and_step()

    if self.gold and global_step >= self.gold_step:
        self.netD.use_gold = True

    print("INFO: Starting training from global step {}...".format(
        global_step))

    logit_save_num = 0
    self.logit_results = defaultdict(dict)

    try:
        start_time = time.time()

        # Mixed precision
        if self.amp:
            print("INFO: Using mixed precision training...")
            scaler = torch.cuda.amp.GradScaler()
        else:
            scaler = None

        # Iterate through data
        iter_dataloader = iter(self.dataloader)
        if self.train_drs:
            iter_dataloader_drs = iter(self.dataloader_drs)

        while global_step < self.num_steps:
            log_data = metric_log.MetricLog()  # log data for tensorboard

            if self.topk:
                self.netG.decay_topk_rate(global_step,
                                          epoch_steps=len(self.dataloader))

            if self.gold and global_step == self.gold_step:
                self.netD.use_gold = True

            # -------------------------
            # One Training Step
            # -------------------------
            # Update n_dis times for D
            for i in range(self.n_dis):
                iter_dataloader, real_batch = self._fetch_data(
                    iter_dataloader=iter_dataloader)

                # -----------------------
                # Update D Network
                # -----------------------
                log_data = self.netD.train_step(
                    real_batch=real_batch,
                    netG=self.netG,
                    optD=self.optD,
                    log_data=log_data,
                    global_step=global_step,
                    device=self.device,
                    scaler=scaler)

                # Train netD_drs for DRS
                if self.train_drs:
                    iter_dataloader_drs, real_batch_drs = self._fetch_data(
                        iter_dataloader=iter_dataloader_drs)
                    log_data = self.netD_drs.train_step(
                        real_batch=real_batch_drs,
                        netG=self.netG,
                        optD=self.optD_drs,
                        log_data=log_data,
                        global_step=global_step,
                        device=self.device,
                        scaler=scaler)

                # -----------------------
                # Update G Network
                # -----------------------
                # Update G, but only once.
                if i == (self.n_dis - 1):
                    log_data = self.netG.train_step(
                        real_batch=real_batch,
                        netD=self.netD,
                        optG=self.optG,
                        global_step=global_step,
                        log_data=log_data,
                        device=self.device,
                        scaler=scaler)

            # --------------------------------
            # Update Training Variables
            # --------------------------------
            global_step += 1

            log_data = self.scheduler.step(log_data=log_data,
                                           global_step=global_step)

            # -------------------------
            # Logging and Metrics
            # -------------------------
            if global_step % self.log_steps == 0:
                self.logger.write_summaries(log_data=log_data,
                                            global_step=global_step)

            if global_step % self.print_steps == 0:
                curr_time = time.time()
                topk_rate = (self.netG.topk_rate
                             if hasattr(self.netG, 'topk_rate') else 1)
                log_data.add_metric('topk_rate',
                                    topk_rate,
                                    group='topk_rate',
                                    precision=6)
                self.logger.print_log(global_step=global_step,
                                      log_data=log_data,
                                      time_taken=(curr_time - start_time) /
                                      self.print_steps)
                start_time = curr_time

            if global_step % self.vis_steps == 0:
                if 'gaussian' in self.log_dir:
                    plot_gaussian_samples(netG=self.netG,
                                          global_step=global_step,
                                          log_dir=self.log_dir,
                                          device=self.device)
                else:
                    self.logger.vis_images(netG=self.netG,
                                           global_step=global_step)

            # Periodically save discriminator logits for later analysis.
            if (self.save_logits
                    and global_step % self.logit_save_steps == 0
                    and self.save_logit_after <= global_step <=
                    self.stop_save_logit_after):
                if self.train_drs:
                    netD = self.netD_drs
                    netD_name = 'netD_drs'
                else:
                    netD = self.netD
                    netD_name = 'netD'
                mode = 'eval' if self.save_eval_logits else 'train'
                print(f"INFO: logit saving {mode} netD: {netD_name}...")
                logit_list = self._get_logit(netD=netD,
                                             eval_mode=(mode == 'eval'))
                self.logit_results[f'{netD_name}_{mode}'][
                    global_step] = logit_list
                logit_save_num += 1

            if global_step % self.save_steps == 0:
                print("INFO: Saving checkpoints...")
                self._save_model_checkpoints(global_step)
                if self.save_logits and global_step >= self.save_logit_after:
                    self._save_logit(self.logit_results)

        print("INFO: Saving final checkpoints...")
        self._save_model_checkpoints(global_step)
        if self.save_logits and global_step >= self.save_logit_after:
            self._save_logit(self.logit_results)

    except KeyboardInterrupt:
        print("INFO: Saving checkpoints from keyboard interrupt...")
        self._save_model_checkpoints(global_step)
        if self.save_logits and global_step >= self.save_logit_after:
            self._save_logit(self.logit_results)

    finally:
        self.logger.close_writers()

    print("INFO: Training Ended.")
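# One plausible reading of decay_topk_rate used above, following the common
# top-k GAN training recipe of multiplicatively annealing the kept fraction
# from 1.0 down to a floor once per epoch. The decay factor and floor here
# (0.99, 0.5) are illustrative assumptions, not values from this codebase.
def _decayed_topk_rate(topk_rate, global_step, epoch_steps,
                       decay=0.99, floor=0.5):
    """Return the annealed top-k rate for the current step."""
    if global_step % epoch_steps == 0:
        topk_rate = max(topk_rate * decay, floor)
    return topk_rate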
def setup(self):
    self.log_data = metric_log.MetricLog()