def init(self):
    opt = self.args
    if not os.path.exists(opt.saved_dir):
        os.makedirs(opt.saved_dir)
    # create image buffers to store previously generated images
    self.fake_A_pool = ImagePool(opt.pool_size)
    self.fake_B_pool = ImagePool(opt.pool_size)
    self.crit_cycle = torch.nn.L1Loss()
    self.crit_idt = torch.nn.L1Loss()
    self.crit_gan = GANLoss(opt.gan_mode).cuda()
    self.cam_loss = CAMLoss()
    self.optim_G = torch.optim.Adam(
        itertools.chain(self.model.G_A.parameters(), self.model.G_B.parameters()),
        lr=opt.lr, betas=(opt.beta1, 0.999))
    self.optim_D = torch.optim.Adam(
        itertools.chain(self.model.D_A.parameters(), self.model.D_B.parameters()),
        lr=opt.lr, betas=(opt.beta1, 0.999))  # beta1 default: 0.5
    self.optimizers = [self.optim_G, self.optim_D]
    self.schedulers = [get_scheduler(optimizer, self.args) for optimizer in self.optimizers]
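The ImagePool objects above act as a replay buffer of previously generated fakes, so the discriminator also sees samples from earlier generator states. A minimal sketch of that idea is below, assuming a query() interface like the common CycleGAN implementations; the actual class used here may differ.

import random
import torch

class ImagePool:
    """Minimal sketch of a history buffer for generated images (illustrative only)."""

    def __init__(self, pool_size):
        self.pool_size = pool_size
        self.images = []

    def query(self, images):
        # With probability 0.5, swap an incoming fake for a randomly stored older one.
        if self.pool_size == 0:
            return images
        out = []
        for img in images:
            img = img.detach().unsqueeze(0)
            if len(self.images) < self.pool_size:
                self.images.append(img)
                out.append(img)
            elif random.random() > 0.5:
                idx = random.randrange(self.pool_size)
                old = self.images[idx].clone()
                self.images[idx] = img
                out.append(old)
            else:
                out.append(img)
        return torch.cat(out, dim=0)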
def define_loss(self):
    # ------------------------------------
    # G_loss
    # ------------------------------------
    if self.opt_train['G_lossfn_weight'] > 0:
        G_lossfn_type = self.opt_train['G_lossfn_type']
        if G_lossfn_type == 'l1':
            self.G_lossfn = nn.L1Loss().to(self.device)
        elif G_lossfn_type == 'l2':
            self.G_lossfn = nn.MSELoss().to(self.device)
        elif G_lossfn_type == 'l2sum':
            self.G_lossfn = nn.MSELoss(reduction='sum').to(self.device)
        elif G_lossfn_type == 'ssim':
            self.G_lossfn = SSIMLoss().to(self.device)
        else:
            raise NotImplementedError('Loss type [{:s}] is not found.'.format(G_lossfn_type))
        self.G_lossfn_weight = self.opt_train['G_lossfn_weight']
    else:
        print('Do not use pixel loss.')
        self.G_lossfn = None

    # ------------------------------------
    # F_loss
    # ------------------------------------
    if self.opt_train['F_lossfn_weight'] > 0:
        F_lossfn_type = self.opt_train['F_lossfn_type']
        F_use_input_norm = self.opt_train['F_use_input_norm']
        F_feature_layer = self.opt_train['F_feature_layer']
        if self.opt['dist']:
            self.F_lossfn = PerceptualLoss(feature_layer=F_feature_layer,
                                           use_input_norm=F_use_input_norm,
                                           lossfn_type=F_lossfn_type).to(self.device)
        else:
            self.F_lossfn = PerceptualLoss(feature_layer=F_feature_layer,
                                           use_input_norm=F_use_input_norm,
                                           lossfn_type=F_lossfn_type)
            self.F_lossfn.vgg = self.model_to_device(self.F_lossfn.vgg)
            self.F_lossfn.lossfn = self.F_lossfn.lossfn.to(self.device)
        self.F_lossfn_weight = self.opt_train['F_lossfn_weight']
    else:
        print('Do not use feature loss.')
        self.F_lossfn = None

    # ------------------------------------
    # D_loss
    # ------------------------------------
    self.D_lossfn = GANLoss(self.opt_train['gan_type'], 1.0, 0.0).to(self.device)
    self.D_lossfn_weight = self.opt_train['D_lossfn_weight']

    self.D_update_ratio = self.opt_train['D_update_ratio'] if self.opt_train['D_update_ratio'] else 1
    self.D_init_iters = self.opt_train['D_init_iters'] if self.opt_train['D_init_iters'] else 0
def define_loss(self):
    # ------------------------------------
    # G_loss
    # ------------------------------------
    if self.opt_train['G_lossfn_weight'] > 0:
        G_lossfn_type = self.opt_train['G_lossfn_type']
        if G_lossfn_type == 'l1':
            self.G_lossfn = nn.L1Loss().to(self.device)
        elif G_lossfn_type == 'l2':
            self.G_lossfn = nn.MSELoss().to(self.device)
        elif G_lossfn_type == 'l2sum':
            self.G_lossfn = nn.MSELoss(reduction='sum').to(self.device)
        elif G_lossfn_type == 'ssim':
            self.G_lossfn = SSIMLoss().to(self.device)
        else:
            raise NotImplementedError('Loss type [{:s}] is not found.'.format(G_lossfn_type))
        self.G_lossfn_weight = self.opt_train['G_lossfn_weight']
    else:
        print('Do not use pixel loss.')
        self.G_lossfn = None

    # ------------------------------------
    # F_loss
    # ------------------------------------
    if self.opt_train['F_lossfn_weight'] > 0:
        F_lossfn_type = self.opt_train['F_lossfn_type']
        if F_lossfn_type == 'l1':
            self.F_lossfn = nn.L1Loss().to(self.device)
        elif F_lossfn_type == 'l2':
            self.F_lossfn = nn.MSELoss().to(self.device)
        else:
            raise NotImplementedError('Loss type [{:s}] not recognized.'.format(F_lossfn_type))
        self.F_lossfn_weight = self.opt_train['F_lossfn_weight']
        # self.netF = define_F(self.opt, use_bn=False).to(self.device)
    else:
        print('Do not use feature loss.')
        self.F_lossfn = None

    # ------------------------------------
    # D_loss
    # ------------------------------------
    self.D_lossfn = GANLoss(self.opt_train['gan_type'], 1.0, 0.0).to(self.device)
    self.D_lossfn_weight = self.opt_train['D_lossfn_weight']

    self.D_update_ratio = self.opt_train['D_update_ratio'] if self.opt_train['D_update_ratio'] else 1
    self.D_init_iters = self.opt_train['D_init_iters'] if self.opt_train['D_init_iters'] else 0
def get_criterion(self, mode, opt):
    if mode == 'pix':
        loss_type = opt['pixel_criterion']
        if loss_type == 'l1':
            criterion = nn.L1Loss(reduction=opt['reduction']).to(self.device)
        elif loss_type == 'l2':
            criterion = nn.MSELoss(reduction=opt['reduction']).to(self.device)
        elif loss_type == 'cb':
            criterion = CharbonnierLoss(reduction=opt['reduction']).to(self.device)
        else:
            raise NotImplementedError('Loss type [{:s}] is not recognized for pixel'.format(loss_type))
        weight = opt['pixel_weight']
    elif mode == 'gan':
        criterion = GANLoss(opt['gan_type'], 1.0, 0.0).to(self.device)
        weight = opt['gan_weight']
    else:
        raise TypeError('Unknown type: {} for criterion'.format(mode))
    return criterion, weight
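CharbonnierLoss, used for the 'cb' branch above, is a smooth L1 variant. A minimal self-contained sketch is below; the eps value and reduction handling are assumptions, not this project's exact implementation.

import torch
import torch.nn as nn

class CharbonnierLoss(nn.Module):
    """Minimal sketch: differentiable L1 variant, sqrt((x - y)^2 + eps^2)."""

    def __init__(self, eps=1e-6, reduction='mean'):
        super().__init__()
        self.eps = eps
        self.reduction = reduction

    def forward(self, x, y):
        diff = torch.sqrt((x - y) ** 2 + self.eps ** 2)
        return diff.mean() if self.reduction == 'mean' else diff.sum()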
def __init__(self, args):
    super(PPONModel, self).__init__(args)

    # define networks and load pre-trained models
    self.netG = networks.define_G(args).cuda()
    if self.is_train:
        if args.which_model == 'perceptual':
            self.netD = networks.define_D().cuda()
            self.netD.train()
        self.netG.train()
    self.load()  # load G and D if needed

    # define losses, optimizer and scheduler
    if self.is_train:
        # G pixel loss
        if args.pixel_weight > 0:
            l_pix_type = args.pixel_criterion  # loss pixel type
            if l_pix_type == 'l1':
                self.cri_pix = nn.L1Loss().cuda()
            elif l_pix_type == 'l2':
                self.cri_pix = nn.MSELoss().cuda()
            else:
                raise NotImplementedError('Loss type [{:s}] not recognized.'.format(l_pix_type))
            self.l_pix_w = args.pixel_weight
        else:
            print('Remove pixel loss.')
            self.cri_pix = None  # critic pixel

        # G structure loss
        if args.structure_weight > 0:
            self.cri_msssim = pytorch_msssim.MS_SSIM(data_range=args.rgb_range).cuda()
            self.cri_ml1 = MultiscaleL1Loss().cuda()
        else:
            print('Remove structure loss.')
            self.cri_msssim = None
            self.cri_ml1 = None

        # G feature loss
        if args.feature_weight > 0:
            l_fea_type = args.feature_criterion
            if l_fea_type == 'l1':
                self.cri_fea = nn.L1Loss().cuda()
            elif l_fea_type == 'l2':
                self.cri_fea = nn.MSELoss().cuda()
            else:
                raise NotImplementedError('Loss type [{:s}] not recognized.'.format(l_fea_type))
            self.l_fea_w = args.feature_weight
        else:
            print('Remove feature loss.')
            self.cri_fea = None

        if self.cri_fea:  # load VGG perceptual loss
            self.vgg = networks.define_F().cuda()

        # gan loss
        if args.gan_weight > 0:
            self.cri_gan = GANLoss(args.gan_type, 1.0, 0.0).cuda()
            self.l_gan_w = args.gan_weight
        else:
            self.cri_gan = None

        # optimizers
        # G
        if args.which_model == 'structure':
            for param in self.netG.CFEM.parameters():
                param.requires_grad = False
            for param in self.netG.CRM.parameters():
                param.requires_grad = False

        if args.which_model == 'perceptual':
            for param in self.netG.CFEM.parameters():
                param.requires_grad = False
            for param in self.netG.CRM.parameters():
                param.requires_grad = False
            for param in self.netG.SFEM.parameters():
                param.requires_grad = False
            for param in self.netG.SRM.parameters():
                param.requires_grad = False

        optim_params = []
        for k, v in self.netG.named_parameters():
            if v.requires_grad:
                optim_params.append(v)
            else:
                print('Warning: params [{:s}] will not optimize.'.format(k))
        self.optimizer_G = torch.optim.Adam(optim_params, lr=args.lr_G)
        self.optimizers.append(self.optimizer_G)

        # D
        if args.which_model == 'perceptual':
            self.optimizer_D = torch.optim.Adam(self.netD.parameters(), lr=args.lr_D)
            self.optimizers.append(self.optimizer_D)

        # schedulers
        if args.lr_scheme == 'MultiStepLR':
            for optimizer in self.optimizers:
                self.schedulers.append(lr_scheduler.MultiStepLR(optimizer, args.lr_steps, args.lr_gamma))
        else:
            raise NotImplementedError('MultiStepLR learning rate scheme is enough.')

        self.log_dict = OrderedDict()

    print('------------- Model initialized -------------')
    self.print_network()
    print('---------------------------------------------')
def __init__(self, opt):
    super(SRGANModel, self).__init__(opt)
    if opt['dist']:
        self.rank = torch.distributed.get_rank()
    else:
        self.rank = -1  # non dist training
    train_opt = opt['train']

    # define networks and load pretrained models
    self.netG = networks.define_G(opt).to(self.device)
    if opt['dist']:
        self.netG = DistributedDataParallel(self.netG, device_ids=[torch.cuda.current_device()])
    else:
        self.netG = DataParallel(self.netG)
    if self.is_train:
        self.netD = networks.define_D(opt).to(self.device)
        if opt['dist']:
            self.netD = DistributedDataParallel(self.netD, device_ids=[torch.cuda.current_device()])
        else:
            self.netD = DataParallel(self.netD)

        self.netG.train()
        self.netD.train()

    # define losses, optimizer and scheduler
    if self.is_train:
        # G pixel loss
        if train_opt['pixel_weight'] > 0:
            l_pix_type = train_opt['pixel_criterion']
            if l_pix_type == 'l1':
                self.cri_pix = nn.L1Loss().to(self.device)
            elif l_pix_type == 'l2':
                self.cri_pix = nn.MSELoss().to(self.device)
            else:
                raise NotImplementedError('Loss type [{:s}] not recognized.'.format(l_pix_type))
            self.l_pix_w = train_opt['pixel_weight']
        else:
            logger.info('Remove pixel loss.')
            self.cri_pix = None

        # G feature loss
        if train_opt['feature_weight'] > 0:
            l_fea_type = train_opt['feature_criterion']
            if l_fea_type == 'l1':
                self.cri_fea = nn.L1Loss().to(self.device)
            elif l_fea_type == 'l2':
                self.cri_fea = nn.MSELoss().to(self.device)
            else:
                raise NotImplementedError('Loss type [{:s}] not recognized.'.format(l_fea_type))
            self.l_fea_w = train_opt['feature_weight']
        else:
            logger.info('Remove feature loss.')
            self.cri_fea = None
        if self.cri_fea:  # load VGG perceptual loss
            self.netF = networks.define_F(opt, use_bn=False).to(self.device)
            if opt['dist']:
                self.netF = DistributedDataParallel(self.netF, device_ids=[torch.cuda.current_device()])
            else:
                self.netF = DataParallel(self.netF)

        # G Rank-content loss
        if train_opt['R_weight'] > 0:
            self.l_R_w = train_opt['R_weight']  # load rank-content loss
            self.R_bias = train_opt['R_bias']
            self.netR = networks.define_R(opt).to(self.device)
            if opt['dist']:
                self.netR = DistributedDataParallel(self.netR, device_ids=[torch.cuda.current_device()])
            else:
                self.netR = DataParallel(self.netR)
        else:
            logger.info('Remove rank-content loss.')

        # GD gan loss
        self.cri_gan = GANLoss(train_opt['gan_type'], 1.0, 0.0).to(self.device)
        self.l_gan_w = train_opt['gan_weight']
        # D_update_ratio and D_init_iters
        self.D_update_ratio = train_opt['D_update_ratio'] if train_opt['D_update_ratio'] else 1
        self.D_init_iters = train_opt['D_init_iters'] if train_opt['D_init_iters'] else 0

        # optimizers
        # G
        wd_G = train_opt['weight_decay_G'] if train_opt['weight_decay_G'] else 0
        optim_params = []
        for k, v in self.netG.named_parameters():  # can optimize for a part of the model
            if v.requires_grad:
                optim_params.append(v)
            else:
                if self.rank <= 0:
                    logger.warning('Params [{:s}] will not optimize.'.format(k))
        self.optimizer_G = torch.optim.Adam(optim_params, lr=train_opt['lr_G'],
                                            weight_decay=wd_G,
                                            betas=(train_opt['beta1_G'], train_opt['beta2_G']))
        self.optimizers.append(self.optimizer_G)
        # D
        wd_D = train_opt['weight_decay_D'] if train_opt['weight_decay_D'] else 0
        self.optimizer_D = torch.optim.Adam(self.netD.parameters(), lr=train_opt['lr_D'],
                                            weight_decay=wd_D,
                                            betas=(train_opt['beta1_D'], train_opt['beta2_D']))
        self.optimizers.append(self.optimizer_D)

        # schedulers
        if train_opt['lr_scheme'] == 'MultiStepLR':
            for optimizer in self.optimizers:
                self.schedulers.append(
                    lr_scheduler.MultiStepLR_Restart(optimizer, train_opt['lr_steps'],
                                                     restarts=train_opt['restarts'],
                                                     weights=train_opt['restart_weights'],
                                                     gamma=train_opt['lr_gamma'],
                                                     clear_state=train_opt['clear_state']))
        elif train_opt['lr_scheme'] == 'CosineAnnealingLR_Restart':
            for optimizer in self.optimizers:
                self.schedulers.append(
                    lr_scheduler.CosineAnnealingLR_Restart(
                        optimizer, train_opt['T_period'], eta_min=train_opt['eta_min'],
                        restarts=train_opt['restarts'], weights=train_opt['restart_weights']))
        else:
            raise NotImplementedError('MultiStepLR learning rate scheme is enough.')

        self.log_dict = OrderedDict()

    self.print_network()  # print network
    self.load()  # load G and D if needed
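D_update_ratio and D_init_iters are only read here; they take effect in the optimization step. The sketch below shows the usual pattern, assuming a step counter supplied by the training loop; var_L, var_H, var_ref and fake_H are illustrative attribute names, not taken from the source.

def optimize_parameters(self, step):
    # generate SR image (var_L / var_H / var_ref are illustrative names)
    self.fake_H = self.netG(self.var_L)

    # G: update every D_update_ratio steps, only after D_init_iters warm-up iterations
    for p in self.netD.parameters():
        p.requires_grad = False
    self.optimizer_G.zero_grad()
    if step % self.D_update_ratio == 0 and step > self.D_init_iters:
        l_g_total = 0
        if self.cri_pix:
            l_g_total += self.l_pix_w * self.cri_pix(self.fake_H, self.var_H)
        l_g_total += self.l_gan_w * self.cri_gan(self.netD(self.fake_H), True)
        l_g_total.backward()
        self.optimizer_G.step()

    # D: updated every step on real and detached fake predictions
    for p in self.netD.parameters():
        p.requires_grad = True
    self.optimizer_D.zero_grad()
    l_d_total = self.cri_gan(self.netD(self.var_ref), True) \
              + self.cri_gan(self.netD(self.fake_H.detach()), False)
    l_d_total.backward()
    self.optimizer_D.step()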
optimizers.append(optimizer_D)

# Handle multi-gpu if desired
if (device1.type == 'cuda') and (config['ngpu'] > 1):
    netG = nn.DataParallel(netG, list(range(config['ngpu'])))
    netD = nn.DataParallel(netD, list(range(config['ngpu'])))
    netF = nn.DataParallel(netF, list(range(config['ngpu'])))

# summary(netG, input_size=(3, input_shape, input_shape), device="cuda")
# summary(netD, input_size=(3, output_shape, output_shape), device="cuda")

# G pixel loss
cri_pix = nn.L1Loss().to(device1)
# G feature loss
cri_fea = nn.L1Loss().to(device1)
# GD gan loss
cri_gan = GANLoss("vanilla", 1.0, 0.0).to(device1)

# schedulers
schedulers = list()
for optimizer in optimizers:
    schedulers.append(lr_scheduler.MultiStepLR(optimizer, [50, 75, 100, 200], 0.5))

log_dict = OrderedDict()
global_step = config['n_epoch_start'] * len(train_loader)

for i in range(config['n_epoch_start']):
    for scheduler in schedulers:
        scheduler.step()

for epoch in trange(config['n_epoch_start'], config['n_epoch_end']):
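GANLoss(gan_type, real_label_val, fake_label_val) is called throughout these snippets but never defined. A minimal sketch of such a criterion for the 'vanilla' and 'lsgan' modes is below; it is an assumption about the interface implied by the call sites, not the projects' actual implementation.

import torch
import torch.nn as nn

class GANLoss(nn.Module):
    """Minimal sketch: BCE-with-logits ('vanilla') or MSE ('lsgan') against
    a target tensor filled with the real/fake label value."""

    def __init__(self, gan_type, real_label_val=1.0, fake_label_val=0.0):
        super().__init__()
        self.real_label_val = real_label_val
        self.fake_label_val = fake_label_val
        if gan_type == 'vanilla':
            self.loss = nn.BCEWithLogitsLoss()
        elif gan_type == 'lsgan':
            self.loss = nn.MSELoss()
        else:
            raise NotImplementedError('GAN type [{:s}] not sketched here.'.format(gan_type))

    def forward(self, prediction, target_is_real):
        val = self.real_label_val if target_is_real else self.fake_label_val
        target = torch.full_like(prediction, val)
        return self.loss(prediction, target)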
def __init__(self, cfg, local_cfg):
    self.cfg = cfg
    self.local_cfg = local_cfg
    self.device = torch.device(self.local_cfg.gpu)

    # setup models
    self.cfg.model.gen.shape = self.cfg.dataset.shape
    self.cfg.model.dis.shape = self.cfg.dataset.shape
    self.G = define_G(self.cfg)
    self.D = define_D(self.cfg)
    self.G_ema = define_G(self.cfg)
    self.G_ema.eval()
    ema_inplace(self.G_ema, self.G, 0.0)
    self.A = DiffAugment(policy=self.cfg.solver.augment)
    self.lidar = LiDAR(
        num_ring=cfg.dataset.shape[0],
        num_points=cfg.dataset.shape[1],
        min_depth=cfg.dataset.min_depth,
        max_depth=cfg.dataset.max_depth,
        angle_file=osp.join(cfg.dataset.root, "angles.pt"),
    )
    self.lidar.eval()
    self.G.to(self.device)
    self.D.to(self.device)
    self.G_ema.to(self.device)
    self.A.to(self.device)
    self.lidar.to(self.device)
    self.G = DDP(self.G, device_ids=[self.local_cfg.gpu], broadcast_buffers=False)
    self.D = DDP(self.D, device_ids=[self.local_cfg.gpu], broadcast_buffers=False)
    if dist.get_rank() == 0:
        print("minibatch size per gpu:", self.local_cfg.batch_size)
        print("number of gradient accumulation:", self.cfg.solver.num_accumulation)
    self.ema_decay = 0.5 ** (self.cfg.solver.batch_size /
                             (self.cfg.solver.smoothing_kimg * 1000))

    # training dataset
    self.dataset = define_dataset(self.cfg.dataset, phase="train")
    self.loader = torch.utils.data.DataLoader(
        self.dataset,
        batch_size=self.local_cfg.batch_size,
        shuffle=False,
        num_workers=self.local_cfg.num_workers,
        pin_memory=self.cfg.pin_memory,
        sampler=torch.utils.data.distributed.DistributedSampler(self.dataset),
        drop_last=True,
    )
    self.loader = cycle(self.loader)

    # validation dataset
    self.val_dataset = define_dataset(self.cfg.dataset, phase="val")
    self.val_loader = torch.utils.data.DataLoader(
        self.val_dataset,
        batch_size=self.local_cfg.batch_size,
        shuffle=True,
        num_workers=self.local_cfg.num_workers,
        pin_memory=self.cfg.pin_memory,
        drop_last=False,
    )

    # loss criterion
    self.loss_weight = dict(self.cfg.solver.loss)
    self.criterion = {}
    self.criterion["gan"] = GANLoss(self.cfg.solver.gan_mode).to(self.device)
    if "gp" in self.loss_weight and self.loss_weight["gp"] > 0.0:
        self.criterion["gp"] = True
    if "pl" in self.loss_weight and self.loss_weight["pl"] > 0.0:
        self.criterion["pl"] = True
        self.pl_ema = torch.tensor(0.0).to(self.device)
    if dist.get_rank() == 0:
        print("loss: {}".format(tuple(self.criterion.keys())))

    # optimizer
    self.optim_G = optim.Adam(
        params=self.G.parameters(),
        lr=self.cfg.solver.lr.alpha.gen,
        betas=(self.cfg.solver.lr.beta1, self.cfg.solver.lr.beta2),
    )
    self.optim_D = optim.Adam(
        params=self.D.parameters(),
        lr=self.cfg.solver.lr.alpha.dis,
        betas=(self.cfg.solver.lr.beta1, self.cfg.solver.lr.beta2),
    )

    # automatic mixed precision
    self.enable_amp = cfg.enable_amp
    self.scaler = torch.cuda.amp.GradScaler(enabled=self.enable_amp)
    if dist.get_rank() == 0 and self.enable_amp:
        print("amp enabled")

    # resume from checkpoints
    self.start_iteration = 0
    if self.cfg.resume is not None:
        state_dict = torch.load(self.cfg.resume, map_location="cpu")
        self.start_iteration = state_dict["step"] // self.cfg.solver.batch_size
        self.G.module.load_state_dict(state_dict["G"])
        self.D.module.load_state_dict(state_dict["D"])
        self.G_ema.load_state_dict(state_dict["G_ema"])
        self.optim_G.load_state_dict(state_dict["optim_G"])
        self.optim_D.load_state_dict(state_dict["optim_D"])
        if "pl" in self.criterion:
            self.criterion["pl"].pl_ema = state_dict["pl_ema"].to(self.device)

    # for visual validation
    self.fixed_noise = torch.randn(self.local_cfg.batch_size,
                                   cfg.model.gen.in_ch,
                                   device=self.device)
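The ema_decay formula above follows the StyleGAN2 convention: 0.5 ** (batch_size / (smoothing_kimg * 1000)) is the per-step decay whose half-life is smoothing_kimg thousand training images. A minimal sketch of what an ema_inplace-style helper could look like is below; the project's own helper may differ.

import torch

@torch.no_grad()
def ema_inplace(model_ema, model, decay):
    # Exponential moving average of parameters: ema = decay * ema + (1 - decay) * current.
    # decay = 0.0 (as used right after construction) simply copies the current weights.
    for p_ema, p in zip(model_ema.parameters(), model.parameters()):
        p_ema.copy_(p.lerp(p_ema, decay))
    for b_ema, b in zip(model_ema.buffers(), model.buffers()):
        b_ema.copy_(b)

# Example: batch_size=64, smoothing_kimg=10 gives decay = 0.5 ** (64 / 10000) ≈ 0.9956,
# so the EMA forgets half of its state roughly every 10k images seen.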
def __init__(self, opt):
    super(InpaintingModel, self).__init__(opt)
    train_opt = opt['train']

    # define networks and load pretrained model
    self.netG = networks.define_G(opt).to(self.device)
    if self.is_train:
        self.netD = networks.define_D(opt).to(self.device)
        self.netG.train()
        self.netD.train()
    self.load()  # load G and D

    # define losses, optimizer and scheduler
    if self.is_train:
        # G pixel loss
        if train_opt['pixel_weight'] > 0:
            l_pix_type = train_opt['pixel_criterion']
            if l_pix_type == 'l1':
                self.cri_pix = nn.L1Loss().to(self.device)
            elif l_pix_type == 'l2':
                self.cri_pix = nn.MSELoss().to(self.device)
            elif l_pix_type == 'ml1':
                self.cri_pix = MultiscaleL1Loss().to(self.device)
            else:
                raise NotImplementedError('Unsupported loss type: {}'.format(l_pix_type))
            self.l_pix_w = train_opt['pixel_weight']
        else:
            self.cri_pix = None

        # G feature loss
        if train_opt['feature_weight'] > 0:
            l_fea_type = train_opt['feature_criterion']
            if l_fea_type == 'l1':
                self.cri_fea = nn.L1Loss().to(self.device)
            elif l_fea_type == 'l2':
                self.cri_fea = nn.MSELoss().to(self.device)
            else:
                raise NotImplementedError('Unsupported loss type: {}'.format(l_fea_type))
            self.l_fea_w = train_opt['feature_weight']
            self.guided_cri_fea = MaskedL1Loss().to(self.device)
        else:
            self.cri_fea = None

        if self.cri_fea:  # load VGG model
            # self.vgg = Vgg19()
            # self.vgg.load_state_dict(torch.load(vgg_model))
            # for param in self.vgg.parameters():
            #     param.requires_grad = False
            self.vgg = networks.define_F(opt)
            self.vgg.to(self.device)
            self.vgg_layers = ['r11', 'r21', 'r31', 'r41', 'r51']
            self.vgg_weights = [1e3 / n ** 2 for n in [64, 128, 256, 512, 512]]
            self.vgg_fns = [self.cri_fea] * len(self.vgg_layers)

        # discriminator features
        if train_opt['dis_feature_weight'] > 0:
            l_dis_fea_type = train_opt['dis_feature_criterion']
            if l_dis_fea_type == 'l1':
                self.cri_dis_fea = nn.L1Loss().to(self.device)
            elif l_dis_fea_type == 'l2':
                self.cri_dis_fea = nn.MSELoss().to(self.device)
            else:
                raise NotImplementedError('Unsupported loss type: {}'.format(l_dis_fea_type))
            self.l_dis_fea_w = train_opt['dis_feature_weight']
        else:
            self.cri_dis_fea = None

        if self.cri_dis_fea:
            self.dis_weights = [1e3 / n ** 2 for n in [64, 128, 256, 512, 512]]
            self.dis_fns = [self.cri_dis_fea] * len(self.dis_weights)

        # center loss weight
        if train_opt['center_weight'] > 0:
            self.l_center_w = train_opt['center_weight']
        else:
            self.l_center_w = 0

        # G & D gan loss
        self.cri_gan = GANLoss(train_opt['gan_type'], 1.0, 0.0).to(self.device)
        self.l_gan_w = train_opt['gan_weight']

        # optimizers
        # G
        optim_params = []
        for k, v in self.netG.named_parameters():
            if v.requires_grad:
                optim_params.append(v)
            else:
                print('Params [{:s}] will not optimize.'.format(k))
        self.optimizer_G = torch.optim.Adam(optim_params, lr=train_opt['lr_G'], betas=(0.5, 0.999))
        self.optimizers.append(self.optimizer_G)
        # D
        self.optimizer_D = torch.optim.Adam(self.netD.parameters(), lr=train_opt['lr_D'], betas=(0.5, 0.999))
        self.optimizers.append(self.optimizer_D)

        # schedulers
        if train_opt['lr_policy'] == 'MultiStepLR':
            for optimizer in self.optimizers:
                self.schedulers.append(
                    lr_scheduler.MultiStepLR(optimizer, train_opt['lr_steps'], train_opt['lr_gamma']))
        else:
            raise NotImplementedError('Unsupported learning scheme: {}'.format(train_opt['lr_policy']))

        self.log_dict = OrderedDict()

    # print network
    self.print_network()
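The vgg_layers / vgg_weights / vgg_fns triplets suggest a per-layer weighted perceptual loss. A hedged sketch of how they would typically be combined is below; it assumes the VGG feature extractor returns a dict of feature maps keyed by layer name, which may not match this project's define_F output.

# Illustrative only: combine per-layer feature distances with the 1e3 / n^2 weights.
def perceptual_loss(self, fake, real):
    feat_fake = self.vgg(fake)   # assumed: dict keyed by 'r11', 'r21', ...
    feat_real = self.vgg(real)
    loss = 0
    for layer, weight, fn in zip(self.vgg_layers, self.vgg_weights, self.vgg_fns):
        loss += weight * fn(feat_fake[layer], feat_real[layer].detach())
    return self.l_fea_w * loss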
def main(args):
    opt = BaseOptions.parse(args)
    make_dir(opt.checkpoints_dir)
    BaseOptions.print_options(opt)
    torch.backends.cudnn.benchmark = True
    device = torch.device("cuda:{}".format(opt.gpu_ids[0]) if opt.gpu_ids else "cpu")

    net_D = Discriminator(opt.input_nc, opt.conv_dim_d, opt.n_layers_d, opt.use_sigmoid).to(device)
    net_G = Generator(opt.input_nc, opt.conv_dim_g, opt.n_blocks_g, opt.use_bias).to(device)
    if opt.resume_iters:
        load_net(net_D, opt.resume_iters, "D", device)
        load_net(net_G, opt.resume_iters, "G", device)
    else:
        init_weights(net_D, opt.init_type)
        init_weights(net_G, opt.init_type)
    if len(opt.gpu_ids) > 1:
        net_D = torch.nn.DataParallel(net_D, device_ids=opt.gpu_ids)
        net_G = torch.nn.DataParallel(net_G, device_ids=opt.gpu_ids)
    print_network(net_D, "net_D")
    print_network(net_G, "net_G")

    optimizer_D = torch.optim.Adam(net_D.parameters(), lr=opt.lr, betas=(0.5, 0.999))
    optimizer_G = torch.optim.Adam(net_G.parameters(), lr=opt.lr, betas=(0.5, 0.999))
    optimizers = [optimizer_D, optimizer_G]
    schedulers = [get_scheduler(optimizer, opt) for optimizer in optimizers]

    criterionGAN = GANLoss(no_lsgan=True).to(device)
    criterionL1 = torch.nn.L1Loss()

    dataset = Dataset(opt.json_file, opt.aligned)
    print("#training images = {}".format(len(dataset)))
    data_loader = get_loader(dataset, opt.batch_size, True, opt.workers)

    data_time = 0.0
    total_time = 0.0
    data_iter = iter(data_loader)
    logger = Logger(opt.checkpoints_dir)

    for curr_iters in range(opt.start_iters, opt.start_iters + opt.train_iters):
        start_time = time.time()
        try:
            real_A, real_B = next(data_iter)
        except Exception:
            data_iter = iter(data_loader)
            real_A, real_B = next(data_iter)
        real_A = real_A.to(device)
        real_B = real_B.to(device)
        data_time += time.time() - start_time

        fake_B = net_G(real_A)

        # update D
        set_requires_grad(net_D, True)
        optimizer_D.zero_grad()
        pred_fake = net_D(fake_B.detach())
        loss_D_fake = criterionGAN(pred_fake, False)
        pred_real = net_D(real_B)
        loss_D_real = criterionGAN(pred_real, True)
        loss_D = loss_D_fake + loss_D_real
        loss_D.backward()
        optimizer_D.step()

        # update G
        set_requires_grad(net_D, False)
        optimizer_G.zero_grad()
        pred_fake = net_D(fake_B)
        loss_G_GAN = criterionGAN(pred_fake, True)
        loss_G_L1 = criterionL1(fake_B, real_B)
        loss_G = loss_G_GAN + loss_G_L1 * 100
        loss_G.backward()
        optimizer_G.step()

        total_time += time.time() - start_time
        logger.add(loss_D_fake=loss_D_fake.mean().item(),
                   loss_D_real=loss_D_real.mean().item(),
                   loss_G_GAN=loss_G_GAN.mean().item(),
                   loss_G_L1=loss_G_L1.mean().item())

        for scheduler in schedulers:
            scheduler.step()

        if curr_iters % opt.model_save == 0:
            print("saving the model: iters = {}, lr = {}".format(
                curr_iters, optimizers[0].param_groups[0]["lr"]))
            save_net(net_D, curr_iters, "D", opt)
            save_net(net_G, curr_iters, "G", opt)
        if curr_iters % opt.display_freq == 0:
            print("#iters[{}]: data time {}, total time {}".format(
                curr_iters, data_time, total_time))
            data_time = 0.0
            total_time = 0.0
            logger.save(curr_iters)
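set_requires_grad is used above to freeze the discriminator during the generator update so its parameters accumulate no gradients. A minimal sketch of such a helper is below; the project's own version may differ.

def set_requires_grad(nets, requires_grad=False):
    # Accept a single module or a list of modules and toggle gradient tracking
    # for all their parameters, so frozen nets contribute nothing to .backward().
    if not isinstance(nets, (list, tuple)):
        nets = [nets]
    for net in nets:
        if net is not None:
            for param in net.parameters():
                param.requires_grad = requires_grad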
def __init__(self, opt, dataset=None):
    super(SRGANModel, self).__init__(opt)
    self.cri_text = False  # enable the text-recognition loss only when a dataset is provided
    if dataset:
        self.cri_text = True
    if opt['dist']:
        self.rank = torch.distributed.get_rank()
    else:
        self.rank = -1  # non dist training
    train_opt = opt['train']

    # define networks and load pretrained models
    self.netG = networks.define_G(opt).to(self.device)
    if opt['dist']:
        self.netG = DistributedDataParallel(self.netG, device_ids=[torch.cuda.current_device()])
    else:
        self.netG = DataParallel(self.netG)
    if self.is_train:
        self.netD = networks.define_D(opt).to(self.device)
        if opt['dist']:
            self.netD = DistributedDataParallel(self.netD, device_ids=[torch.cuda.current_device()])
        else:
            self.netD = DataParallel(self.netD)

        self.netG.train()
        self.netD.train()

    # define losses, optimizer and scheduler
    if self.is_train:
        # G pixel loss
        if train_opt['pixel_weight'] > 0:
            l_pix_type = train_opt['pixel_criterion']
            if l_pix_type == 'l1':
                self.cri_pix = nn.L1Loss().to(self.device)
            elif l_pix_type == 'l2':
                self.cri_pix = nn.MSELoss().to(self.device)
            else:
                raise NotImplementedError('Loss type [{:s}] not recognized.'.format(l_pix_type))
            self.l_pix_w = train_opt['pixel_weight']
        else:
            logger.info('Remove pixel loss.')
            self.cri_pix = None

        # G feature loss
        if train_opt['feature_weight'] > 0:
            l_fea_type = train_opt['feature_criterion']
            if l_fea_type == 'l1':
                self.cri_fea = nn.L1Loss().to(self.device)
            elif l_fea_type == 'l2':
                self.cri_fea = nn.MSELoss().to(self.device)
            else:
                raise NotImplementedError('Loss type [{:s}] not recognized.'.format(l_fea_type))
            self.l_fea_w = train_opt['feature_weight']
        else:
            logger.info('Remove feature loss.')
            self.cri_fea = None
        if self.cri_fea:  # load VGG perceptual loss
            self.netF = networks.define_F(opt, use_bn=False).to(self.device)
            if opt['dist']:
                pass  # do not need to use DistributedDataParallel for netF
            else:
                self.netF = DataParallel(self.netF)

        if self.cri_text:
            from lib.models.model_builder import ModelBuilder
            self.netT = ModelBuilder(
                arch="ResNet_ASTER",
                rec_num_classes=dataset.rec_num_classes,
                sDim=512,
                attDim=512,
                max_len_labels=100,
                eos=dataset.char2id[dataset.EOS],
                STN_ON=True).to(self.device)
            self.netT = DataParallel(self.netT)
            self.netT.eval()
            from lib.util.serialization import load_checkpoint
            checkpoint = load_checkpoint(train_opt['text_model'])
            self.netT.load_state_dict(checkpoint['state_dict'])

        # GD gan loss
        self.cri_gan = GANLoss(train_opt['gan_type'], 1.0, 0.0).to(self.device)
        self.l_gan_w = train_opt['gan_weight']
        # D_update_ratio and D_init_iters
        self.D_update_ratio = train_opt['D_update_ratio'] if train_opt['D_update_ratio'] else 1
        self.D_init_iters = train_opt['D_init_iters'] if train_opt['D_init_iters'] else 0

        # optimizers
        # G
        wd_G = train_opt['weight_decay_G'] if train_opt['weight_decay_G'] else 0
        optim_params = []
        for k, v in self.netG.named_parameters():  # can optimize for a part of the model
            if v.requires_grad:
                optim_params.append(v)
            else:
                if self.rank <= 0:
                    logger.warning('Params [{:s}] will not optimize.'.format(k))
        self.optimizer_G = torch.optim.Adam(optim_params, lr=train_opt['lr_G'],
                                            weight_decay=wd_G,
                                            betas=(train_opt['beta1_G'], train_opt['beta2_G']))
        self.optimizers.append(self.optimizer_G)
        # D
        wd_D = train_opt['weight_decay_D'] if train_opt['weight_decay_D'] else 0
        self.optimizer_D = torch.optim.Adam(self.netD.parameters(), lr=train_opt['lr_D'],
                                            weight_decay=wd_D,
                                            betas=(train_opt['beta1_D'], train_opt['beta2_D']))
        self.optimizers.append(self.optimizer_D)

        # schedulers
        if train_opt['lr_scheme'] == 'MultiStepLR':
            for optimizer in self.optimizers:
                self.schedulers.append(
                    lr_scheduler.MultiStepLR_Restart(optimizer, train_opt['lr_steps'],
                                                     restarts=train_opt['restarts'],
                                                     weights=train_opt['restart_weights'],
                                                     gamma=train_opt['lr_gamma'],
                                                     clear_state=train_opt['clear_state']))
        elif train_opt['lr_scheme'] == 'CosineAnnealingLR_Restart':
            for optimizer in self.optimizers:
                self.schedulers.append(
                    lr_scheduler.CosineAnnealingLR_Restart(
                        optimizer, train_opt['T_period'], eta_min=train_opt['eta_min'],
                        restarts=train_opt['restarts'], weights=train_opt['restart_weights']))
        else:
            raise NotImplementedError('MultiStepLR learning rate scheme is enough.')

        self.log_dict = OrderedDict()

    self.print_network()  # print network
    self.load()  # load G and D if needed