def __init__(self, opt, resume_epoch=0):
    self.opt = opt
    self.pix2pix_model = Pix2PixModel(opt)
    if len(opt.gpu_ids) > 1:
        self.pix2pix_model = DataParallelWithCallback(
            self.pix2pix_model, device_ids=opt.gpu_ids)
        self.pix2pix_model_on_one_gpu = self.pix2pix_model.module
    else:
        self.pix2pix_model.to(opt.gpu_ids[0])
        self.pix2pix_model_on_one_gpu = self.pix2pix_model

    if opt.use_ema:
        self.netG_ema = EMA(opt.ema_beta)
        for name, param in self.pix2pix_model_on_one_gpu.net[
                'netG'].named_parameters():
            if param.requires_grad:
                self.netG_ema.register(name, param.data)
        self.netCorr_ema = EMA(opt.ema_beta)
        for name, param in self.pix2pix_model_on_one_gpu.net[
                'netCorr'].named_parameters():
            if param.requires_grad:
                self.netCorr_ema.register(name, param.data)

    self.generated = None
    if opt.isTrain:
        self.optimizer_G, self.optimizer_D = \
            self.pix2pix_model_on_one_gpu.create_optimizers(opt)
        self.old_lr = opt.lr
        if opt.continue_train and opt.which_epoch == 'latest':
            checkpoint = torch.load(
                os.path.join(opt.checkpoints_dir, opt.name, 'optimizer.pth'))
            self.optimizer_G.load_state_dict(checkpoint['G'])
            self.optimizer_D.load_state_dict(checkpoint['D'])
    self.last_data, self.last_netCorr, self.last_netG, \
        self.last_optimizer_G = None, None, None, None
def initialize_networks(self, opt):
    self.netG = networks.define_G(opt)
    self.netD = networks.define_D(opt)

    # set requires-grad flags
    if self.isTrain:
        self.set_requires_grad([self.netG, self.netD], True)
    else:
        self.set_requires_grad([self.netG, self.netD], False)

    if self.use_gpu:
        self.netG = DataParallelWithCallback(self.netG,
                                             device_ids=opt['gpu_ids'])
        self.netD = DataParallelWithCallback(self.netD,
                                             device_ids=opt['gpu_ids'])
    self.train_nets = [self.netG, self.netD]
def __init__(self, opt):
    self.opt = opt
    self.pix2pix_model = Pix2PixModel(opt)
    if len(opt.gpu_ids) > 0:
        self.pix2pix_model = DataParallelWithCallback(
            self.pix2pix_model, device_ids=opt.gpu_ids)
        self.pix2pix_model_on_one_gpu = self.pix2pix_model.module
    else:
        self.pix2pix_model_on_one_gpu = self.pix2pix_model

    self.generated = None
    if opt.isTrain:
        if not opt.unpairTrain:
            (
                self.optimizer_G,
                self.optimizer_D,
            ) = self.pix2pix_model_on_one_gpu.create_optimizers(opt)
        else:
            (
                self.optimizer_G,
                self.optimizer_D,
                self.optimizer_D2,
            ) = self.pix2pix_model_on_one_gpu.create_optimizers(opt)
        self.old_lr = opt.lr
    self.d_losses = {}
    self.nanCount = 0
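# Note: every constructor in this section keeps a handle to the unwrapped
# network (`.module`) because helpers such as create_optimizers and save live
# on the model itself, not on the parallel wrapper. DataParallelWithCallback
# (from the synchronized-batchnorm utilities) exposes `.module` the same way
# torch.nn.DataParallel does; a minimal sketch of the pattern with the stock
# wrapper (illustrative, not this project's own code):
import torch.nn as nn

net = nn.Linear(4, 4)
wrapped = nn.DataParallel(net)
assert wrapped.module is net  # the underlying model stays reachable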
def initialize_networks(self, opt):
    self.netGA = networks.define_G(opt, opt['netGA'])
    self.netGB = networks.define_G(opt, opt['netGB'])
    self.netDA = networks.define_D(opt, opt['netDA'])
    self.netDB = networks.define_D(opt, opt['netDB'])
    self.netEA, self.netHairA = networks.define_RES(
        opt, opt['input_nc_A'], opt['netEDA'])
    self.netEB, self.netHairB = networks.define_RES(
        opt, opt['input_nc_B'], opt['netEDB'])

    if self.opt['pretrain']:
        self.train_nets = [
            self.netGA, self.netGB, self.netDA, self.netDB,
            self.netEA, self.netHairA, self.netEB, self.netHairB
        ]
    else:
        self.train_nets = [self.netEA, self.netHairA]

    # set requires-grad flags
    if self.isTrain:
        self.set_requires_grad(self.train_nets, True)
    else:
        self.set_requires_grad(self.train_nets, False)

    if self.use_gpu:
        for i in range(len(self.train_nets)):
            self.train_nets[i] = DataParallelWithCallback(
                self.train_nets[i], device_ids=opt['gpu_ids'])
        # rebind the named attributes to the wrapped networks
        if self.opt['pretrain']:
            (self.netGA, self.netGB, self.netDA, self.netDB, self.netEA,
             self.netHairA, self.netEB, self.netHairB) = self.train_nets
        else:
            self.netEA, self.netHairA = self.train_nets
def __init__(self, opt):
    self.opt = opt
    self.pix2pix_model = create_model(opt)
    if len(opt.gpu_ids) > 0:
        self.pix2pix_model = DataParallelWithCallback(
            self.pix2pix_model,
            device_ids=opt.gpu_ids,
            output_device=opt.gpu_ids[-1],
            chunk_size=opt.chunk_size)
        self.pix2pix_model_on_one_gpu = self.pix2pix_model.module
    else:
        self.pix2pix_model_on_one_gpu = self.pix2pix_model
    # self.Render = networks.Render(opt, render_size=opt.crop_size)

    self.generated = None
    if opt.isTrain:
        self.optimizer_G, self.optimizer_D = \
            self.pix2pix_model_on_one_gpu.create_optimizers(opt)
        self.old_lr = opt.lr
def __init__(self, opt):
    self.opt = opt
    self.pix2pix_model = Pix2PixModel(opt)
    if len(opt.gpu_ids) > 0:
        self.pix2pix_model = DataParallelWithCallback(
            self.pix2pix_model, device_ids=opt.gpu_ids)
        self.pix2pix_model.cuda()
        self.pix2pix_model_on_one_gpu = self.pix2pix_model.module
    else:
        self.pix2pix_model_on_one_gpu = self.pix2pix_model

    self.generated = None
    if opt.isTrain:
        self.optimizer_G, self.optimizer_D = \
            self.pix2pix_model_on_one_gpu.create_optimizers(opt)
        self.old_lr = opt.lr
def __init__(self, opt):
    self.opt = opt
    if self.opt.model == 'pix2pix':
        self.pix2pix_model = Pix2pixModel(opt)
    elif self.opt.model == 'smis':
        self.pix2pix_model = SmisModel(opt)
    print(self.pix2pix_model)
    with open(os.path.join(opt.checkpoints_dir, opt.name, 'model.txt'),
              'w') as f:
        f.write(self.pix2pix_model.__str__())
    if len(opt.gpu_ids) > 0:
        self.pix2pix_model = DataParallelWithCallback(
            self.pix2pix_model, device_ids=opt.gpu_ids)
        self.pix2pix_model_on_one_gpu = self.pix2pix_model.module
    else:
        self.pix2pix_model_on_one_gpu = self.pix2pix_model

    self.generated = None
    if opt.isTrain:
        self.optimizer_G, self.optimizer_D = \
            self.pix2pix_model_on_one_gpu.create_optimizers(opt)
        self.old_lr = opt.lr
def __init__(self, opt):
    self.opt = opt
    self.pix2pix_model = Pix2PixModel(opt)
    # self.pix2pix_model = torch.nn.parallel.DistributedDataParallel(
    #     self.pix2pix_model, device_ids=[opt.gpu], find_unused_parameters=True)
    self.pix2pix_model = DataParallelWithCallback(self.pix2pix_model,
                                                  device_ids=opt.gpu_ids)
    self.pix2pix_model_on_one_gpu = self.pix2pix_model.module

    self.generated = None
    if opt.isTrain:
        self.optimizer_G, self.optimizer_D = \
            self.pix2pix_model_on_one_gpu.create_optimizers(opt)
        self.old_lr = opt.lr
def __init__(self, opt, model):
    super(MyModel, self).__init__()
    self.opt = opt
    model = model.cuda(opt.gpu_ids[0])
    self.module = model
    self.model = DataParallelWithCallback(model, device_ids=opt.gpu_ids)
    if opt.batch_for_first_gpu != -1:
        self.bs_per_gpu = (opt.batchSize - opt.batch_for_first_gpu) // (
            len(opt.gpu_ids) - 1)  # batch size for each GPU
    else:
        self.bs_per_gpu = int(
            np.ceil(float(opt.batchSize) /
                    len(opt.gpu_ids)))  # batch size for each GPU
    self.pad_bs = self.bs_per_gpu * len(opt.gpu_ids) - opt.batchSize
def __init__(self, opt):
    self.opt = opt
    self.seg_inpaint_model = SegInpaintModel(opt)
    if len(opt.gpu_ids) > 0:
        self.seg_inpaint_model = DataParallelWithCallback(
            self.seg_inpaint_model, device_ids=opt.gpu_ids)
        self.seg_inpaint_model_on_one_gpu = self.seg_inpaint_model.module
    else:
        self.seg_inpaint_model_on_one_gpu = self.seg_inpaint_model

    self.generated = None
    self.optimizer_SPNet, self.optimizer_SGNet, self.optimizer_D_seg, \
        self.optimizer_D_img = \
        self.seg_inpaint_model_on_one_gpu.create_optimizers(opt)
    self.old_lr = opt.lr
def __init__(self, opt, model):
    super(MyModel, self).__init__()
    self.opt = opt
    model = model.cuda(opt.gpu_ids[0])
    self.module = model
    if opt.distributed:
        self.model = nn.parallel.DistributedDataParallel(
            model, find_unused_parameters=True)
    else:
        # self.model = nn.DataParallel(model, device_ids=opt.gpu_ids)
        self.model = DataParallelWithCallback(model, device_ids=opt.gpu_ids)
    if opt.batch_for_first_gpu != -1:
        self.bs_per_gpu = (opt.batchSize - opt.batch_for_first_gpu) // (
            len(opt.gpu_ids) - 1)  # batch size for each GPU
    else:
        self.bs_per_gpu = int(
            np.ceil(float(opt.batchSize) /
                    len(opt.gpu_ids)))  # batch size for each GPU
    self.pad_bs = self.bs_per_gpu * len(opt.gpu_ids) - opt.batchSize
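# Both MyModel variants above pad the global batch so it splits evenly across
# GPUs, optionally reserving a smaller chunk for GPU 0. A standalone sketch of
# that arithmetic (the function name and the values are illustrative):
import numpy as np

def per_gpu_batch(batch_size, n_gpus, batch_for_first_gpu=-1):
    if batch_for_first_gpu != -1:
        # the remaining samples are split over the other GPUs
        bs_per_gpu = (batch_size - batch_for_first_gpu) // (n_gpus - 1)
    else:
        bs_per_gpu = int(np.ceil(float(batch_size) / n_gpus))
    pad_bs = bs_per_gpu * n_gpus - batch_size  # samples to pad per iteration
    return bs_per_gpu, pad_bs

print(per_gpu_batch(10, 4))  # (3, 2): 3 samples per GPU, batch padded by 2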
def __init__(self, opt):
    self.opt = opt
    self.pix2pix_model = Pix2PixModel(opt)
    if len(opt.gpu_ids) > 0:
        self.pix2pix_model = DataParallelWithCallback(
            self.pix2pix_model, device_ids=opt.gpu_ids)
        self.pix2pix_model_on_one_gpu = self.pix2pix_model.module
    else:
        self.pix2pix_model_on_one_gpu = self.pix2pix_model

    self.generated = None
    if opt.isTrain:
        self.optimizer_G, self.optimizer_D = \
            self.pix2pix_model_on_one_gpu.create_optimizers(opt)
        self.old_lr = opt.lr

    self.amp = True if AMP and opt.use_amp else False
    if self.amp:
        self.scaler_G = GradScaler()
        self.scaler_D = GradScaler()
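# The two GradScaler instances above imply AMP-wrapped optimizer steps
# elsewhere in this trainer. A minimal sketch of how such a generator step
# could look with torch.cuda.amp (assumed usage; the trainer's actual step
# functions are not shown in this section):
import torch
from torch.cuda.amp import GradScaler, autocast

def amp_generator_step(pix2pix_model, optimizer_G, scaler_G, data):
    optimizer_G.zero_grad()
    with autocast():  # run the forward pass in mixed precision
        g_losses, generated = pix2pix_model(data, mode='generator')
        g_loss = sum(g_losses.values()).mean()
    scaler_G.scale(g_loss).backward()  # scale loss to avoid fp16 underflow
    scaler_G.step(optimizer_G)         # unscales gradients, then steps
    scaler_G.update()                  # adjust the scale for next iteration
    return g_losses, generated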
def __init__(self, opt):
    self.opt = opt
    if self.opt.dual:
        from models.pix2pix_dualmodel import Pix2PixModel
    elif self.opt.dual_segspade:
        from models.pix2pix_dual_segspademodel import Pix2PixModel
    elif opt.box_unpair:
        from models.pix2pix_dualunpair import Pix2PixModel
    else:
        from models.pix2pix_model import Pix2PixModel
    self.pix2pix_model = Pix2PixModel(opt)
    if len(opt.gpu_ids) > 0:
        self.pix2pix_model = DataParallelWithCallback(
            self.pix2pix_model, device_ids=opt.gpu_ids)
        self.pix2pix_model_on_one_gpu = self.pix2pix_model.module
    else:
        self.pix2pix_model_on_one_gpu = self.pix2pix_model

    self.generated = None
    if opt.isTrain:
        self.optimizer_G, self.optimizer_D = \
            self.pix2pix_model_on_one_gpu.create_optimizers(opt)
        self.old_lr = opt.lr
class DxdyModel(BaseModel):
    def name(self):
        return 'DxdyModel'

    @staticmethod
    def modify_commandline_options(parser, is_train=True):
        pass

    def initialize(self, opt):
        BaseModel.initialize(self, opt)
        # set networks
        self.initialize_networks(opt)
        # set loss functions
        self.initialize_loss(opt)
        # set optimizer
        self.initialize_optimizer(opt)
        self.initialize_other(opt)
        self.model_dict = {
            'netG': {
                'model': self.netG.module if self.use_gpu else self.netG,
                'optimizer': self.optimizer_G
            },
            'netD': {
                'model': self.netD.module if self.use_gpu else self.netD,
                'optimizer': self.optimizer_D
            }
        }
        self.opt = opt

    def initialize_networks(self, opt):
        self.netG = networks.define_G(opt)
        self.netD = networks.define_D(opt)

        # set requires-grad flags
        if self.isTrain:
            self.set_requires_grad([self.netG, self.netD], True)
        else:
            self.set_requires_grad([self.netG, self.netD], False)

        if self.use_gpu:
            self.netG = DataParallelWithCallback(self.netG,
                                                 device_ids=opt['gpu_ids'])
            self.netD = DataParallelWithCallback(self.netD,
                                                 device_ids=opt['gpu_ids'])
        self.train_nets = [self.netG, self.netD]

    def initialize_optimizer(self, opt):
        G_params = list(self.netG.parameters())
        D_params = list(self.netD.parameters())
        beta1, beta2 = opt['beta1'], opt['beta2']
        G_lr, D_lr = opt.get('lr_G', opt['lr']), opt.get('lr_D', opt['lr'])
        self.optimizer_G = torch.optim.Adam(G_params,
                                            lr=G_lr,
                                            betas=(beta1, beta2))
        self.optimizer_D = torch.optim.Adam(D_params,
                                            lr=D_lr,
                                            betas=(beta1, beta2))
        self.old_lr = opt['lr']

    def initialize_loss(self, opt):
        self.criterionGAN = networks.GANLoss(opt['gan_mode'],
                                             tensor=self.FloatTensor,
                                             opt=opt)
        # if self.use_gpu:
        #     self.criterionGAN = DataParallelWithCallback(
        #         self.criterionGAN, device_ids=opt['gpu_ids'])
        self.criterionReg = torch.nn.L1Loss()

    def initialize_other(self, opt):
        full_body_mesh_vert_pos, full_body_mesh_face_inds = \
            tools.load_body_mesh()
        self.full_body_mesh_vert_pos = full_body_mesh_vert_pos.unsqueeze(0)
        self.full_body_mesh_face_inds = full_body_mesh_face_inds.unsqueeze(0)
        sample_dataset = haya_data.Hair3D10KConvDataOnly()
        self.sample_loader = torch.utils.data.DataLoader(
            sample_dataset,
            batch_size=opt['batch_size'],
            shuffle=False,
            num_workers=opt['workers'],
            drop_last=True)
        self.sample_iter = iter(self.sample_loader)
        assert len(sample_dataset) > 0
        print(f'{len(sample_dataset)} samples loaded')

    def set_input(self, data):
        self.image = data['image'].to(self.device)
        self.mask = data['mask'].to(self.device)
        self.intensity = data['intensity'].to(self.device)
        self.gt_dxdy = data['dxdy'].to(torch.float).to(self.device)
        try:
            sample_data = next(self.sample_iter)
        except StopIteration:
            self.sample_iter = iter(self.sample_loader)
            sample_data = next(self.sample_iter)
        convdata = sample_data['convdata'].to(self.device)
        strands = convdata.permute(
            0, 2, 3, 4, 1)[:, :3, :, :, :].contiguous()  # b x 3 x 32 x 32 x 300
        body_mesh_vert_pos = self.full_body_mesh_vert_pos.expand(
            strands.size(0), -1, -1).to(strands.device)
        body_mesh_face_inds = self.full_body_mesh_face_inds.expand(
            strands.size(0), -1, -1).to(strands.device)
        # generate random mvps
        mvps, _, _ = tools.generate_random_mvps(strands.size(0),
                                                strands.device)
        # render the 2D information
        self.strand_dxdy, self.strand_mask, body_mask, _, strand_vis, mvps, _ = \
            tools.render(mvps,
                         strands,
                         body_mesh_vert_pos,
                         body_mesh_face_inds,
                         self.opt['im_size'],
                         self.opt['expansion'],
                         align_face=self.opt['align_face'],
                         target_face_scale=self.opt['target_face_scale'])

    def forward(self):
        mask_ = self.mask.unsqueeze(1).type(self.image.dtype)
        strand_mask_ = self.strand_mask.unsqueeze(1).type(
            self.strand_dxdy.dtype)
        # for G
        self.pred_dxdy = self.netG(torch.cat([self.image, mask_], dim=1))
        if self.pred_dxdy.size(-1) != mask_.size(-1):
            mask_ = torch.nn.functional.interpolate(
                mask_, size=self.pred_dxdy.shape[-2:], mode='nearest')
        fake_sample = self.pred_dxdy * mask_.type(self.pred_dxdy.dtype)
        self.g_fake_score = self.netD(fake_sample)
        # for D
        fake_sample = self.pred_dxdy.detach() * mask_
        fake_sample.requires_grad_()
        real_sample = self.strand_dxdy * strand_mask_
        self.d_real_score, self.d_fake_score = self.netD(
            real_sample), self.netD(fake_sample)
        # for vis
        self.mask_ = mask_
        # scale the size of everything
        if self.pred_dxdy.size(-1) != self.mask.size(-1):
            self.mask = torch.nn.functional.interpolate(
                self.mask.unsqueeze(1),
                size=self.pred_dxdy.shape[-2:],
                mode='nearest').squeeze()
            self.intensity = torch.nn.functional.interpolate(
                self.intensity.unsqueeze(1),
                size=self.pred_dxdy.shape[-2:],
                mode='nearest').squeeze()
            self.gt_dxdy = torch.nn.functional.interpolate(
                self.gt_dxdy, size=self.pred_dxdy.shape[-2:], mode='nearest')

    def update_visuals(self):
        masked_pred_dxdy = torch.where(self.mask_ > 0., self.pred_dxdy,
                                       -torch.ones_like(self.pred_dxdy))
        masked_gt_dxdy = torch.where(self.mask_ > 0., self.gt_dxdy,
                                     -torch.ones_like(self.gt_dxdy))
        self.vis_dict['image'] = data_utils.make_grid_n(self.image[:6])
        self.vis_dict['gt_dxdy'] = data_utils.vis_orient(self.gt_dxdy[:6])
        self.vis_dict['pred_dxdy'] = data_utils.vis_orient(self.pred_dxdy[:6])
        self.vis_dict['masked_pred_dxdy'] = data_utils.vis_orient(
            masked_pred_dxdy[:6])
        self.vis_dict['masked_gt_dxdy'] = data_utils.vis_orient(
            masked_gt_dxdy[:6])
        self.vis_dict['render_dxdy'] = data_utils.vis_orient(
            self.strand_dxdy[:6])

    def backward_G(self):
        reg_loss = self.dxdy_reg_loss(self.pred_dxdy,
                                      self.gt_dxdy) * self.intensity
        reg_loss = (self.mask.expand_as(reg_loss).float() *
                    reg_loss).mean() * self.opt.get('lambda_reg', 1.)
        g_loss = self.criterionGAN(self.g_fake_score,
                                   True,
                                   for_discriminator=False)
        sum([g_loss, reg_loss]).mean().backward()
        self.loss_dict['loss_reg'] = reg_loss.item()
        self.loss_dict['loss_g'] = g_loss.item()

    def backward_D(self):
        d_fake = self.criterionGAN(self.d_fake_score, False)
        d_real = self.criterionGAN(self.d_real_score, True)
        sum([d_fake, d_real]).mean().backward()
        self.loss_dict['loss_d_fake'] = d_fake.item()
        self.loss_dict['loss_d_real'] = d_real.item()

    def optimize_parameters(self):
        for net in self.train_nets:
            net.train()
        self.forward()
        self.optimizer_G.zero_grad()
        self.backward_G()
        self.optimizer_G.step()
        self.optimizer_D.zero_grad()
        self.backward_D()
        self.optimizer_D.step()

    ##################################################################
    # Helper functions
    ##################################################################

    def update_learning_rate(self, epoch):
        if epoch > self.opt['niter']:
            lrd = self.opt['lr'] / self.opt['niter_decay']
            new_lr = self.old_lr - lrd
        else:
            new_lr = self.old_lr

        if new_lr != self.old_lr:
            new_lr_G = new_lr / 2
            new_lr_D = new_lr * 2
            for param_group in self.optimizer_D.param_groups:
                param_group['lr'] = new_lr_D
            for param_group in self.optimizer_G.param_groups:
                param_group['lr'] = new_lr_G
            print('update learning rate: %f -> %f' % (self.old_lr, new_lr))
            self.old_lr = new_lr

    def dxdy_reg_loss(self, y_hat, y):
        '''
        y_hat, y: B x 2 x H x W
        return:   B x H x W
        '''
        y_norm = y_hat / (torch.norm(y_hat, dim=1, keepdim=True) + 0.0000001)
        cos = torch.abs(torch.sum(y_norm * y, dim=1, keepdim=False))
        norm = torch.abs(
            torch.norm(y_hat, dim=1, keepdim=False) - torch.ones_like(cos))
        return 1 - cos + norm

    def discriminate(self, fake_image, real_image):
        fake_concat = torch.cat([fake_image], dim=1)
        real_concat = torch.cat([real_image], dim=1)
        # In Batch Normalization, the fake and real images are
        # recommended to be in the same batch to avoid disparate
        # statistics in fake and real images.
        # So both fake and real images are fed to D all at once.
        fake_and_real = torch.cat([fake_concat, real_concat], dim=0)
        discriminator_out = self.netD(fake_and_real)
        pred_fake, pred_real = self.divide_pred(discriminator_out)
        return pred_fake, pred_real

    # Take the prediction of fake and real images from the combined batch
    def divide_pred(self, pred):
        # the prediction contains the intermediate outputs of multiscale GAN,
        # so it's usually a list
        if type(pred) == list:
            fake = []
            real = []
            for p in pred:
                fake.append([tensor[:tensor.size(0) // 2] for tensor in p])
                real.append([tensor[tensor.size(0) // 2:] for tensor in p])
        else:
            fake = pred[:pred.size(0) // 2]
            real = pred[pred.size(0) // 2:]
        return fake, real

    def inference(self, data):
        with torch.no_grad():
            image = data['image'].to(self.device)
            mask = data['mask'].to(self.device)
            mask_ = mask.unsqueeze(1).type(image.dtype)
            pred_dxdy = self.netG(torch.cat([image, mask_], dim=1))
            masked_pred_dxdy = torch.where(mask_ > 0., pred_dxdy,
                                           -torch.ones_like(pred_dxdy))
        return {'image': image, 'mask': mask, 'pred_dxdy': pred_dxdy}
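# dxdy_reg_loss above combines two terms: 1 - |cos| penalizes angular
# deviation (the absolute value makes orientations sign-invariant, which
# suits hair-strand directions), and | ||y_hat|| - 1 | pulls the predicted
# vector toward unit length. A quick standalone check of both terms,
# assuming unit-norm targets:
import torch

def dxdy_reg_loss(y_hat, y, eps=1e-7):
    y_norm = y_hat / (torch.norm(y_hat, dim=1, keepdim=True) + eps)
    cos = torch.abs(torch.sum(y_norm * y, dim=1))
    norm = torch.abs(torch.norm(y_hat, dim=1) - torch.ones_like(cos))
    return 1 - cos + norm

y = torch.tensor([[[[0.]], [[1.]]]])  # B x 2 x H x W unit target
print(dxdy_reg_loss(y, y))       # ~0: aligned and unit-norm
print(dxdy_reg_loss(-y, y))      # ~0: flipped orientation scores the same
print(dxdy_reg_loss(2 * y, y))   # ~1: norm penalty |2 - 1|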
class Pix2PixTrainer():
    """
    Trainer creates the model and optimizers, and uses them to update the
    weights of the network while reporting losses and the latest visuals
    to visualize the progress in training.
    """

    def __init__(self, opt):
        self.opt = opt
        self.pix2pix_model = Pix2PixModel(opt)
        if len(opt.gpu_ids) > 0:
            self.pix2pix_model = DataParallelWithCallback(
                self.pix2pix_model, device_ids=opt.gpu_ids)
            self.pix2pix_model.cuda()
            self.pix2pix_model_on_one_gpu = self.pix2pix_model.module
        else:
            self.pix2pix_model_on_one_gpu = self.pix2pix_model

        self.generated = None
        if opt.isTrain:
            self.optimizer_G, self.optimizer_D = \
                self.pix2pix_model_on_one_gpu.create_optimizers(opt)
            self.old_lr = opt.lr
        # print(self.pix2pix_model_on_one_gpu.netG)
        # print(self.pix2pix_model_on_one_gpu.netD)

    def run_generator_one_step(self, data):
        self.optimizer_G.zero_grad()
        g_losses, generated, masked, semantics = self.pix2pix_model(
            data, mode='generator')
        g_loss = sum(g_losses.values()).mean()
        g_loss.backward()
        self.optimizer_G.step()
        self.g_losses = g_losses
        self.generated = generated
        self.masked = masked
        self.semantics = semantics

    def run_discriminator_one_step(self, data):
        self.optimizer_D.zero_grad()
        d_losses = self.pix2pix_model(data, mode='discriminator')
        d_loss = sum(d_losses.values()).mean()
        d_loss.backward()
        self.optimizer_D.step()
        self.d_losses = d_losses

    def get_latest_losses(self):
        return {**self.g_losses, **self.d_losses}

    def get_latest_generated(self):
        return self.generated

    def get_latest_real(self):
        return self.pix2pix_model_on_one_gpu.real_shape

    def get_semantics(self):
        return self.semantics

    def get_mask(self):
        if self.masked.shape[1] == 3:
            return self.masked
        else:
            return self.masked[:, :3]

    def save(self, epoch):
        self.pix2pix_model_on_one_gpu.save(epoch)

    ##################################################################
    # Helper functions
    ##################################################################

    def update_learning_rate(self, epoch):
        if epoch > self.opt.niter:
            lrd = self.opt.lr / self.opt.niter_decay
            new_lr = self.old_lr - lrd
        else:
            new_lr = self.old_lr

        if new_lr != self.old_lr:
            if self.opt.no_TTUR:
                new_lr_G = new_lr
                new_lr_D = new_lr
            else:
                new_lr_G = new_lr / 2
                new_lr_D = new_lr * 2
            for param_group in self.optimizer_D.param_groups:
                param_group['lr'] = new_lr_D
            for param_group in self.optimizer_G.param_groups:
                param_group['lr'] = new_lr_G
            print('update learning rate: %f -> %f' % (self.old_lr, new_lr))
            self.old_lr = new_lr
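# The trainers in this section are meant to be driven by an outer loop that
# alternates the G and D steps. A minimal sketch of such a driver, written as
# a function so the trainer, data loader, and opt namespace (all defined
# elsewhere in the project) are passed in explicitly:
def train(trainer, dataloader, opt):
    for epoch in range(1, opt.niter + opt.niter_decay + 1):
        for data in dataloader:
            trainer.run_generator_one_step(data)
            trainer.run_discriminator_one_step(data)
        print(epoch, trainer.get_latest_losses())
        trainer.update_learning_rate(epoch)
        trainer.save('latest')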
class RotateTrainer(object):
    """
    Trainer creates the model and optimizers, and uses them to update the
    weights of the network while reporting losses and the latest visuals
    to visualize the progress in training.
    """

    def __init__(self, opt):
        self.opt = opt
        self.pix2pix_model = create_model(opt)
        if len(opt.gpu_ids) > 0:
            self.pix2pix_model = DataParallelWithCallback(
                self.pix2pix_model,
                device_ids=opt.gpu_ids,
                output_device=opt.gpu_ids[-1],
                chunk_size=opt.chunk_size)
            self.pix2pix_model_on_one_gpu = self.pix2pix_model.module
        else:
            self.pix2pix_model_on_one_gpu = self.pix2pix_model
        # self.Render = networks.Render(opt, render_size=opt.crop_size)

        self.generated = None
        if opt.isTrain:
            self.optimizer_G, self.optimizer_D = \
                self.pix2pix_model_on_one_gpu.create_optimizers(opt)
            self.old_lr = opt.lr

    def use_gpu(self):
        return len(self.opt.gpu_ids) > 0

    def run_generator_one_step(self, data):
        self.optimizer_G.zero_grad()
        g_losses, generated = self.pix2pix_model.forward(data=data,
                                                         mode='generator')
        if not self.opt.train_rotate:
            with torch.no_grad():
                g_rotate_losses, generated_rotate = self.pix2pix_model.forward(
                    data=data, mode='generator_rotated')
        else:
            g_rotate_losses, generated_rotate = self.pix2pix_model.forward(
                data=data, mode='generator_rotated')
            g_losses['GAN_rotate'] = g_rotate_losses['GAN']
        g_loss = sum(g_losses.values()).mean()
        g_loss.backward()
        # g_rotate_loss = sum(g_rotate_losses.values()).mean()
        # g_rotate_loss.backward()
        self.optimizer_G.step()
        self.g_losses = g_losses
        # self.g_rotate_losses = g_rotate_losses
        self.generated = generated
        self.generated_rotate = generated_rotate

    def run_discriminator_one_step(self, data):
        self.optimizer_D.zero_grad()
        d_losses = self.pix2pix_model.forward(data=data, mode='discriminator')
        if self.opt.train_rotate:
            d_rotated_losses = self.pix2pix_model.forward(
                data=data, mode='discriminator_rotated')
            d_losses['D_rotate_Fake'] = d_rotated_losses['D_Fake']
            d_losses['D_rotate_real'] = d_rotated_losses['D_real']
        d_loss = sum(d_losses.values()).mean()
        d_loss.backward()
        self.optimizer_D.step()
        self.d_losses = d_losses

    def get_latest_generated(self):
        return self.generated

    def get_latest_generated_rotate(self):
        return self.generated_rotate

    def get_latest_losses(self):
        return {**self.g_losses, **self.d_losses}

    def get_current_visuals(self, data):
        return OrderedDict([('input_mesh', data['mesh']),
                            ('input_rotated_mesh', data['rotated_mesh']),
                            ('synthesized_image', self.get_latest_generated()),
                            ('synthesized_rotated_image',
                             self.get_latest_generated_rotate()),
                            ('real_image', data['image'])])

    def save(self, epoch):
        self.pix2pix_model_on_one_gpu.save(epoch)

    ##################################################################
    # Helper functions
    ##################################################################

    def update_learning_rate(self, epoch):
        if epoch > self.opt.niter:
            lrd = self.opt.lr / self.opt.niter_decay
            new_lr = self.old_lr - lrd
        else:
            new_lr = self.old_lr

        if new_lr != self.old_lr:
            if self.opt.no_TTUR:
                new_lr_G = new_lr
                new_lr_D = new_lr
            else:
                new_lr_G = new_lr / 2
                new_lr_D = new_lr * 2
            for param_group in self.optimizer_D.param_groups:
                param_group['lr'] = new_lr_D
            for param_group in self.optimizer_G.param_groups:
                param_group['lr'] = new_lr_G
            print('update learning rate: %f -> %f' % (self.old_lr, new_lr))
            self.old_lr = new_lr
class Pix2PixTrainer():
    """
    Trainer creates the model and optimizers, and uses them to update the
    weights of the network while reporting losses and the latest visuals
    to visualize the progress in training.
    """

    def __init__(self, opt, resume_epoch=0):
        self.opt = opt
        self.pix2pix_model = Pix2PixModel(opt)
        if len(opt.gpu_ids) > 1:
            self.pix2pix_model = DataParallelWithCallback(
                self.pix2pix_model, device_ids=opt.gpu_ids)
            self.pix2pix_model_on_one_gpu = self.pix2pix_model.module
        else:
            self.pix2pix_model.to(opt.gpu_ids[0])
            self.pix2pix_model_on_one_gpu = self.pix2pix_model

        if opt.use_ema:
            self.netG_ema = EMA(opt.ema_beta)
            for name, param in self.pix2pix_model_on_one_gpu.net[
                    'netG'].named_parameters():
                if param.requires_grad:
                    self.netG_ema.register(name, param.data)
            self.netCorr_ema = EMA(opt.ema_beta)
            for name, param in self.pix2pix_model_on_one_gpu.net[
                    'netCorr'].named_parameters():
                if param.requires_grad:
                    self.netCorr_ema.register(name, param.data)

        self.generated = None
        if opt.isTrain:
            self.optimizer_G, self.optimizer_D = \
                self.pix2pix_model_on_one_gpu.create_optimizers(opt)
            self.old_lr = opt.lr
            if opt.continue_train and opt.which_epoch == 'latest':
                checkpoint = torch.load(
                    os.path.join(opt.checkpoints_dir, opt.name,
                                 'optimizer.pth'))
                self.optimizer_G.load_state_dict(checkpoint['G'])
                self.optimizer_D.load_state_dict(checkpoint['D'])
        self.last_data, self.last_netCorr, self.last_netG, \
            self.last_optimizer_G = None, None, None, None

    def run_generator_one_step(self, data, alpha=1):
        self.optimizer_G.zero_grad()
        g_losses, out = self.pix2pix_model(data, mode='generator', alpha=alpha)
        g_loss = sum(g_losses.values()).mean()
        g_loss.backward()
        self.optimizer_G.step()
        self.g_losses = g_losses
        self.out = out
        if self.opt.use_ema:
            self.netG_ema(self.pix2pix_model_on_one_gpu.net['netG'])
            self.netCorr_ema(self.pix2pix_model_on_one_gpu.net['netCorr'])

    def run_discriminator_one_step(self, data):
        self.optimizer_D.zero_grad()
        GforD = {}
        GforD['fake_image'] = self.out['fake_image']
        GforD['adaptive_feature_seg'] = self.out['adaptive_feature_seg']
        GforD['adaptive_feature_img'] = self.out['adaptive_feature_img']
        d_losses = self.pix2pix_model(data, mode='discriminator', GforD=GforD)
        d_loss = sum(d_losses.values()).mean()
        d_loss.backward()
        self.optimizer_D.step()
        self.d_losses = d_losses

    def get_latest_losses(self):
        return {**self.g_losses, **self.d_losses}

    def get_latest_generated(self):
        return self.out['fake_image']

    def save(self, epoch):
        self.pix2pix_model_on_one_gpu.save(epoch)
        if self.opt.use_ema:
            self.netG_ema.assign(self.pix2pix_model_on_one_gpu.net['netG'])
            util.save_network(self.pix2pix_model_on_one_gpu.net['netG'],
                              'G_ema', epoch, self.opt)
            self.netG_ema.resume(self.pix2pix_model_on_one_gpu.net['netG'])
            self.netCorr_ema.assign(
                self.pix2pix_model_on_one_gpu.net['netCorr'])
            util.save_network(self.pix2pix_model_on_one_gpu.net['netCorr'],
                              'netCorr_ema', epoch, self.opt)
            self.netCorr_ema.resume(
                self.pix2pix_model_on_one_gpu.net['netCorr'])
        if epoch == 'latest':
            torch.save(
                {
                    'G': self.optimizer_G.state_dict(),
                    'D': self.optimizer_D.state_dict(),
                    'lr': self.old_lr,
                },
                os.path.join(self.opt.checkpoints_dir, self.opt.name,
                             'optimizer.pth'))

    ##################################################################
    # Helper functions
    ##################################################################

    def update_learning_rate(self, epoch):
        if epoch > self.opt.niter:
            lrd = self.opt.lr / self.opt.niter_decay
            new_lr = self.old_lr - lrd
        else:
            new_lr = self.old_lr

        if new_lr != self.old_lr:
            if self.opt.no_TTUR:
                new_lr_G = new_lr
                new_lr_D = new_lr
            else:
                new_lr_G = new_lr / 2
                new_lr_D = new_lr * 2
            for param_group in self.optimizer_D.param_groups:
                param_group['lr'] = new_lr_D
            for param_group in self.optimizer_G.param_groups:
                param_group['lr'] = new_lr_G
            print('update learning rate: %f -> %f' % (self.old_lr, new_lr))
            self.old_lr = new_lr

    def update_fixed_params(self):
        for param in self.pix2pix_model_on_one_gpu.net['netCorr'].parameters():
            param.requires_grad = True
        G_params = [{
            'params': self.pix2pix_model_on_one_gpu.net['netG'].parameters(),
            'lr': self.opt.lr * 0.5
        }]
        G_params += [{
            'params': self.pix2pix_model_on_one_gpu.net['netCorr'].parameters(),
            'lr': self.opt.lr * 0.5
        }]
        if self.opt.no_TTUR:
            beta1, beta2 = self.opt.beta1, self.opt.beta2
            G_lr = self.opt.lr
        else:
            beta1, beta2 = 0, 0.9
            G_lr = self.opt.lr / 2
        self.optimizer_G = torch.optim.Adam(G_params,
                                            lr=G_lr,
                                            betas=(beta1, beta2),
                                            eps=1e-3)
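# The EMA helper used by the trainer above is not shown in this section. A
# minimal sketch consistent with how it is called (register at init, update
# via __call__ after each G step, assign before saving the *_ema weights,
# resume to restore the live training weights); the real implementation may
# differ:
import torch

class EMA:
    """Exponential moving average of parameters (sketch; assumed interface)."""

    def __init__(self, beta):
        self.beta = beta
        self.shadow = {}   # EMA copies of the weights
        self.backup = {}   # live weights stashed during assign()

    def register(self, name, val):
        self.shadow[name] = val.clone()

    def __call__(self, module):
        # shadow <- beta * shadow + (1 - beta) * param
        for name, param in module.named_parameters():
            if param.requires_grad:
                self.shadow[name] = (self.beta * self.shadow[name] +
                                     (1.0 - self.beta) * param.data)

    def assign(self, module):
        # swap the EMA weights in (e.g. right before saving G_ema)
        for name, param in module.named_parameters():
            if param.requires_grad:
                self.backup[name] = param.data.clone()
                param.data.copy_(self.shadow[name])

    def resume(self, module):
        # restore the live training weights
        for name, param in module.named_parameters():
            if param.requires_grad:
                param.data.copy_(self.backup[name])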
class Pix2PixTrainer():
    """
    Trainer creates the model and optimizers, and uses them to update the
    weights of the network while reporting losses and the latest visuals
    to visualize the progress in training.
    """

    def __init__(self, opt):
        self.opt = opt
        if self.opt.model == 'pix2pix':
            self.pix2pix_model = Pix2pixModel(opt)
        elif self.opt.model == 'smis':
            self.pix2pix_model = SmisModel(opt)
        print(self.pix2pix_model)
        with open(os.path.join(opt.checkpoints_dir, opt.name, 'model.txt'),
                  'w') as f:
            f.write(self.pix2pix_model.__str__())
        if len(opt.gpu_ids) > 0:
            self.pix2pix_model = DataParallelWithCallback(
                self.pix2pix_model, device_ids=opt.gpu_ids)
            self.pix2pix_model_on_one_gpu = self.pix2pix_model.module
        else:
            self.pix2pix_model_on_one_gpu = self.pix2pix_model

        self.generated = None
        if opt.isTrain:
            self.optimizer_G, self.optimizer_D = \
                self.pix2pix_model_on_one_gpu.create_optimizers(opt)
            self.old_lr = opt.lr

    def run_generator_one_step(self, data):
        self.optimizer_G.zero_grad()
        g_losses, generated = self.pix2pix_model(data, mode='generator')
        g_loss = sum(g_losses.values()).mean()
        g_loss.backward()
        self.optimizer_G.step()
        self.g_losses = g_losses
        self.generated = generated

    def run_discriminator_one_step(self, data):
        self.optimizer_D.zero_grad()
        d_losses = self.pix2pix_model(data, mode='discriminator')
        d_loss = sum(d_losses.values()).mean()
        d_loss.backward()
        self.optimizer_D.step()
        self.d_losses = d_losses

    def clean_grad(self):
        self.optimizer_D.zero_grad()
        self.optimizer_G.zero_grad()

    def get_latest_losses(self):
        return {**self.g_losses, **self.d_losses}

    def get_latest_generated(self):
        return self.generated

    def save(self, epoch):
        self.pix2pix_model_on_one_gpu.save(epoch)

    ##################################################################
    # Helper functions
    ##################################################################

    def update_learning_rate(self, epoch):
        if epoch > self.opt.niter:
            lrd = self.opt.lr / self.opt.niter_decay
            new_lr = self.old_lr - lrd
        else:
            new_lr = self.old_lr

        if new_lr != self.old_lr:
            if self.opt.no_TTUR:
                new_lr_G = new_lr
                new_lr_D = new_lr
            else:
                new_lr_G = new_lr / 2
                new_lr_D = new_lr * 2
            for param_group in self.optimizer_D.param_groups:
                param_group['lr'] = new_lr_D
            for param_group in self.optimizer_G.param_groups:
                param_group['lr'] = new_lr_G
            print('update learning rate: %f -> %f' % (self.old_lr, new_lr))
            self.old_lr = new_lr
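# All update_learning_rate helpers above implement the same schedule: hold
# opt.lr for the first `niter` epochs, then decay linearly to zero over
# `niter_decay` epochs, splitting into G = lr/2 and D = 2*lr (TTUR) unless
# no_TTUR is set. A closed form of that stateful loop, assuming exactly one
# update per epoch (illustrative values):
def lr_at_epoch(epoch, lr=2e-4, niter=100, niter_decay=100):
    decay_steps = max(0, epoch - niter)
    return lr * max(0.0, 1.0 - decay_steps / niter_decay)

base = lr_at_epoch(150)          # halfway through decay: 1e-4
lr_G, lr_D = base / 2, base * 2  # TTUR split applied on each update
print(base, lr_G, lr_D)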