def run_pretrain(self, load=False, num_epochs=1):
    """Pretrain the D2 discriminator, or restore it from a saved checkpoint.

    Args:
        load: When true, no training happens: ``util.load_network`` is
            called on ``self.end2end_model_on_one_gpu.netD2`` with label
            ``'D2'`` and iteration ``self.opt.which_iter_D2``.
        num_epochs: Number of passes of ``training_steps`` batches used when
            pretraining (ignored when ``load`` is true). Defaults to 1, the
            value previously hard-coded here.

    Returns:
        The D2 iteration count: the number of ``step_discriminator`` calls
        made when pretraining, or ``self.opt.which_iter_D2`` when loading.
    """
    if load:
        # NOTE(review): the return value of util.load_network is discarded,
        # so this relies on it loading the weights into netD2 in place.
        util.load_network(self.end2end_model_on_one_gpu.netD2, 'D2',
                          self.opt.which_iter_D2, self.opt)
        return self.opt.which_iter_D2

    global_iteration = 0
    training_steps = int(self.n_samples / self.opt.batchSize)
    self.num_semantics = self.progressive_model.num_semantics
    self.set_data_resolution(int(self.opt.crop_size / self.opt.aspect_ratio))
    print(f"Training at resolution {self.progressive_model.generator.res}")
    dim_ind = 0
    phase = "stabilize"
    # Integer factor between the crop height and the generator resolution.
    scaling = int(self.opt.crop_size /
                  (self.opt.aspect_ratio * self.progressive_model.generator.res))
    for epoch in range(num_epochs):
        for iteration in range(training_steps):
            seg, _, im, _ = self.next_batch()
            seg, seg_mc, im = self.call_next_batch(seg, im)
            self.step_discriminator(iteration, global_iteration, dim_ind,
                                    seg_mc, seg, im, scaling, phase)
            global_iteration += 1
            if (iteration + 1) % 10 == 0:
                print(
                    f"Res {self.progressive_model.generator.res:03d}, {phase.rjust(9)}: Iteration {iteration + 1:05d}/{training_steps:05d}, epoch:{epoch + 1:05d}/{num_epochs:05d}"
                )
        # Checkpoint periodically and always after the last epoch.
        if epoch % self.opt.save_epoch_freq == 0 or epoch + 1 == num_epochs:
            util.save_network(self.end2end_model_on_one_gpu.netD2, 'D2',
                              global_iteration, self.opt)
    return global_iteration
def run(self):
    """Train the end-to-end model after pretraining or loading D2.

    Flow:
      1. ``self.run_pretrain`` pretrains D2 when ``opt.pretrain_D2`` is
         truthy, otherwise loads it; its returned iteration count offsets
         the iteration label of the D2 checkpoints saved here.
      2. For ``opt.niter + opt.niter_decay`` epochs, each batch gets one
         ``step_generator_end2end`` and one ``step_discriminator_end2end``
         call. Every 100 iterations, sample grids are logged and progress
         is printed.
      3. FID ("fake" and "real") is computed at the start of every
         ``opt.eval_freq``-th epoch and once more after the final epoch.
         Checkpoints are saved every ``opt.save_epoch_freq`` epochs and
         always after the final epoch.

    Side effects: sets ``self.optimizer_D2``, ``self.num_semantics``,
    ``self.old_lr`` and ``self.opt.epochs``. Writes scalars and images
    through ``self.progressive_model.writer``. Requires CUDA for the FID
    latents.
    """
    iteration_D2 = self.run_pretrain(load=not self.opt.pretrain_D2)
    self.optimizer_D2 = self.end2end_model_on_one_gpu.create_optimizers(
        lr=self.opt.lr)
    # Fixed latents so the logged samples are comparable across iterations.
    fixed_z = torch.randn(self.opt.batchSize, 512)
    self.num_semantics = self.progressive_model.num_semantics
    global_iteration = self.progressive_model.global_iteration
    training_steps = int(self.n_samples / self.opt.batchSize)
    self.set_data_resolution(
        int(self.opt.crop_size / self.opt.aspect_ratio))
    print(f"Training at resolution {self.progressive_model.generator.res}")
    dim_ind = 0
    phase = "stabilize"
    scaling = int(
        self.opt.crop_size /
        (self.opt.aspect_ratio * self.progressive_model.generator.res))
    # Fixed FID latents so successive FID scores are comparable.
    z_fid = torch.randn(self.opt.nums_fid, 512).cuda()
    iter_counter = IterationCounter(self.opt, len(self.dataset))
    self.old_lr = self.opt.lr
    self.opt.epochs = self.opt.niter + self.opt.niter_decay
    writer = self.progressive_model.writer

    if self.opt.BN_eval:
        # Put every batch-norm module in eval mode for the whole run; the
        # rest of pix2pix_model is then never toggled between train/eval.
        for module in self.pix2pix_model.modules():
            if "BATCHNORM" in module.__class__.__name__.upper():
                print(module.__class__.__name__)
                module.eval()

    def evaluate_fid():
        """Log "fid_fake" then "fid_real" at the current global iteration."""
        if not self.opt.BN_eval:
            self.pix2pix_model.eval()
        for real_fake in ('fake', 'real'):
            fid = self.compute_FID(
                global_iteration, z_fixed=z_fid, real_fake=real_fake)
            writer.add_scalar(f"fid_{real_fake}", fid, global_iteration)
        if not self.opt.BN_eval:
            self.pix2pix_model.train()

    for epoch in range(self.opt.epochs):
        if epoch % self.opt.eval_freq == 0:
            evaluate_fid()
        iter_counter.record_epoch_start(epoch)
        for iteration in range(training_steps):
            iter_counter.record_one_iteration()
            seg, _, im, _ = self.next_batch()
            seg, seg_mc, im = self.call_next_batch(seg, im)
            self.step_generator_end2end(iteration, global_iteration, dim_ind,
                                        seg_mc, seg, im, scaling, phase)
            self.step_discriminator_end2end(iteration, global_iteration,
                                            dim_ind, seg_mc, seg, im,
                                            scaling, phase)
            global_iteration += 1
            if (iteration + 1) % 100 == 0:
                fake_seg, fake_im_f, fake_im_r = self.end2end_model(
                    iteration, global_iteration, dim_ind, fixed_z,
                    seg_mc.cpu(), seg.cpu(), im.cpu(), scaling,
                    mode='inference')
                panels = (
                    ("fake", fake_seg),
                    ("fake_im_ff", fake_im_f.cpu()),
                    ("fake_im_fr", fake_im_r.cpu()),
                    ("im_real", im.cpu()),
                    ("seg", self.progressive_model.color_transfer(seg.cpu())),
                )
                for tag, images in panels:
                    grid = make_grid(images, nrow=4, normalize=True,
                                     range=(-1, 1))
                    writer.add_image(tag, grid, global_iteration)
                print(
                    f"Res {self.progressive_model.generator.res:03d}, {phase.rjust(9)}: Iteration {iteration + 1:05d}/{training_steps:05d}, epoch:{epoch + 1:05d}/{self.opt.epochs:05d}"
                )
        # epoch is 0-indexed, so the last epoch is opt.epochs - 1; the old
        # `epoch == self.opt.epochs` test could never fire.
        if epoch % self.opt.save_epoch_freq == 0 or \
                epoch + 1 == self.opt.epochs:
            self.progressive_model.save_model(self.num_semantics,
                                              global_iteration, phase)
            self.pix2pix_model.save(
                str(int(epoch + 1) + int(self.opt.which_epoch)))
            util.save_network(self.end2end_model_on_one_gpu.netD2, 'D2',
                              global_iteration + iteration_D2, self.opt)
        self.update_learning_rate(epoch)
        iter_counter.record_epoch_end()

    # In-loop evaluation runs before each epoch's training, so the fully
    # trained model is only scored here.
    evaluate_fid()
def save(self, epoch):
    """Checkpoint the networks through ``util.save_network``, tagged ``epoch``.

    The generator is always saved under label ``'G'`` and the discriminator
    under ``'D'``. The encoder is saved under ``'E'`` only when
    ``self.opt.use_vae`` is truthy.
    """
    opt = self.opt
    for network, tag in ((self.netG, 'G'), (self.netD, 'D')):
        util.save_network(network, tag, epoch, opt)
    if opt.use_vae:
        util.save_network(self.netE, 'E', epoch, opt)