import torch
from torch.autograd import Variable
from collections import OrderedDict

import ipdb

# Project-level dependencies. The import paths below are assumptions about the repository
# layout (the referenced classes are used throughout this file); adjust them as needed.
from models.models import BaseModel
from models.bodymesh.body_recovery_flow import BodyRecoveryFlow
from networks.networks import NetworksFactory, Vgg19, VGGLoss, StyleLoss, FaceLoss
import utils.util as util


def create_criterion(self):
    # NOTE: this helper reads `self._opt`, so it is meant to be bound to a model
    # instance (the original fragment omitted the `self` parameter).
    face_criterion = FaceLoss(pretrained_path=self._opt.face_model).cuda()
    idt_criterion = torch.nn.L1Loss()
    mask_criterion = torch.nn.BCELoss()
    return face_criterion, idt_criterion, mask_criterion
def _init_losses(self):
    # define loss functions
    multi_gpus = len(self._gpu_ids) > 1
    self._crt_l1 = torch.nn.L1Loss()

    if self._opt.mask_bce:
        self._crt_mask = torch.nn.BCELoss()
    else:
        self._crt_mask = torch.nn.MSELoss()

    vgg_net = Vgg19()

    if self._opt.use_vgg:
        self._criterion_vgg = VGGLoss(vgg=vgg_net)
        if multi_gpus:
            self._criterion_vgg = torch.nn.DataParallel(self._criterion_vgg)
        self._criterion_vgg.cuda()

    if self._opt.use_style:
        self._criterion_style = StyleLoss(feat_extractors=vgg_net)
        if multi_gpus:
            self._criterion_style = torch.nn.DataParallel(self._criterion_style)
        self._criterion_style.cuda()

    if self._opt.use_face:
        self._criterion_face = FaceLoss(pretrained_path=self._opt.face_model)
        if multi_gpus:
            self._criterion_face = torch.nn.DataParallel(self._criterion_face)
        self._criterion_face.cuda()

    # init losses G
    self._loss_g_l1 = self._Tensor([0])
    self._loss_g_vgg = self._Tensor([0])
    self._loss_g_style = self._Tensor([0])
    self._loss_g_face = self._Tensor([0])
    self._loss_g_adv = self._Tensor([0])
    self._loss_g_smooth = self._Tensor([0])
    self._loss_g_mask = self._Tensor([0])
    self._loss_g_mask_smooth = self._Tensor([0])

    # init losses D
    self._d_real = self._Tensor([0])
    self._d_fake = self._Tensor([0])
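
# Minimal, self-contained sketch of how the mask criterion configured in `_init_losses`
# (BCELoss when `mask_bce` is set, MSELoss otherwise) and the total-variation smoothness
# term behave on a toy mask. The tensors and this helper are illustrative stand-ins,
# not project code.
def _demo_mask_loss_sketch(mask_bce=True):
    crt_mask = torch.nn.BCELoss() if mask_bce else torch.nn.MSELoss()

    pred_mask = torch.rand(2, 1, 8, 8)                      # predicted mask values in [0, 1)
    target_mask = (torch.rand(2, 1, 8, 8) > 0.5).float()    # binary target mask

    # mask regression term
    mask_loss = crt_mask(pred_mask, target_mask)

    # total-variation smoothness term (same form as `_compute_loss_smooth` below)
    smooth_loss = torch.mean(torch.abs(pred_mask[:, :, :, :-1] - pred_mask[:, :, :, 1:])) + \
                  torch.mean(torch.abs(pred_mask[:, :, :-1, :] - pred_mask[:, :, 1:, :]))

    return mask_loss, smooth_loss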
class Impersonator(BaseModel):
    def __init__(self, opt):
        super(Impersonator, self).__init__(opt)
        self._name = 'Impersonator'

        # create networks
        self._init_create_networks()

        # init train variables and losses
        if self._is_train:
            self._init_train_vars()
            self._init_losses()

        # load networks and optimizers
        if not self._is_train or self._opt.load_epoch > 0:
            self.load()
        elif self._opt.load_path != 'None':
            # ipdb.set_trace()
            self._load_params(self._G, self._opt.load_path, need_module=len(self._gpu_ids) > 1)

        # prefetch variables
        self._init_prefetch_inputs()

    def _init_create_networks(self):
        multi_gpus = len(self._gpu_ids) > 1

        # body recovery flow
        self._bdr = BodyRecoveryFlow(opt=self._opt)
        if multi_gpus:
            self._bdr = torch.nn.DataParallel(self._bdr)
        self._bdr.eval()
        self._bdr.cuda()

        # generator network
        self._G = self._create_generator()
        self._G.init_weights()
        if multi_gpus:
            self._G = torch.nn.DataParallel(self._G)
        self._G.cuda()

        # discriminator network
        self._D = self._create_discriminator()
        self._D.init_weights()
        if multi_gpus:
            self._D = torch.nn.DataParallel(self._D)
        self._D.cuda()

    def _create_generator(self):
        return NetworksFactory.get_by_name(self._opt.gen_name,
                                           bg_dim=4,
                                           src_dim=3 + self._G_cond_nc,
                                           tsf_dim=3 + self._G_cond_nc,
                                           repeat_num=self._opt.repeat_num)

    def _create_discriminator(self):
        return NetworksFactory.get_by_name('discriminator_patch_gan',
                                           input_nc=3 + self._D_cond_nc,
                                           norm_type=self._opt.norm_type,
                                           ndf=64,
                                           n_layers=4,
                                           use_sigmoid=False,
                                           sn=self._opt.spectral_norm)

    def _init_train_vars(self):
        print("---------- Generator LR:{0} ---------- DISCRIMINATOR LR:{1} ----------".format(
            self._opt.lr_G, self._opt.lr_D))
        self._current_lr_G = self._opt.lr_G
        self._current_lr_D = self._opt.lr_D

        # initialize optimizers
        self._optimizer_G = torch.optim.Adam(self._G.parameters(), lr=self._current_lr_G,
                                             betas=(self._opt.G_adam_b1, self._opt.G_adam_b2))
        self._optimizer_D = torch.optim.Adam(self._D.parameters(), lr=self._current_lr_D,
                                             betas=(self._opt.D_adam_b1, self._opt.D_adam_b2))

    def _init_prefetch_inputs(self):
        self._real_src = None
        self._real_tsf = None
        self._bg_mask = None
        self._input_src = None
        self._input_G_bg = None
        self._input_G_src = None
        self._input_G_tsf = None
        self._T = None
        self._body_bbox = None
        self._head_bbox = None

    def _init_losses(self):
        # define loss functions
        multi_gpus = len(self._gpu_ids) > 1
        self._crt_l1 = torch.nn.L1Loss()

        if self._opt.mask_bce:
            self._crt_mask = torch.nn.BCELoss()
        else:
            self._crt_mask = torch.nn.MSELoss()

        vgg_net = Vgg19()

        if self._opt.use_vgg:
            self._crt_tsf = VGGLoss(vgg=vgg_net)
            if multi_gpus:
                self._crt_tsf = torch.nn.DataParallel(self._crt_tsf)
            self._crt_tsf.cuda()

        if self._opt.use_style:
            self._crt_style = StyleLoss(feat_extractors=vgg_net)
            if multi_gpus:
                self._crt_style = torch.nn.DataParallel(self._crt_style)
            self._crt_style.cuda()

        if self._opt.use_face:
            self._criterion_face = FaceLoss(pretrained_path=self._opt.face_model)
            if multi_gpus:
                self._criterion_face = torch.nn.DataParallel(self._criterion_face)
            self._criterion_face.cuda()

        # init losses G
        self._loss_g_rec = self._Tensor([0])
        self._loss_g_tsf = self._Tensor([0])
        self._loss_g_style = self._Tensor([0])
        self._loss_g_face = self._Tensor([0])
        self._loss_g_adv = self._Tensor([0])
        self._loss_g_smooth = self._Tensor([0])
        self._loss_g_mask = self._Tensor([0])
        self._loss_g_mask_smooth = self._Tensor([0])

        # init losses D
        self._d_real = self._Tensor([0])
        self._d_fake = self._Tensor([0])
        self._d_real_loss = self._Tensor([0])
        self._d_fake_loss = self._Tensor([0])

    def set_input(self, input):
        with torch.no_grad():
            images = input['images']
            smpls = input['smpls']
            src_img = images[:, 0, ...].cuda()
            src_smpl = smpls[:, 0, ...].cuda()
            tsf_img = images[:, 1, ...].cuda()
            tsf_smpl = smpls[:, 1, ...].cuda()

            input_G_src_bg, input_G_tsf_bg, input_G_src, input_G_tsf, T, src_crop_mask, \
                tsf_crop_mask, head_bbox, body_bbox = self._bdr(src_img, tsf_img, src_smpl, tsf_smpl)

            self._real_src = src_img
            self._real_tsf = tsf_img
            self._bg_mask = torch.cat((src_crop_mask, tsf_crop_mask), dim=0)

            if self._opt.bg_both:
                self._input_G_bg = torch.cat([input_G_src_bg, input_G_tsf_bg], dim=0)
            else:
                self._input_G_bg = input_G_src_bg

            self._input_G_src = input_G_src
            self._input_G_tsf = input_G_tsf
            self._T = T
            self._head_bbox = head_bbox
            self._body_bbox = body_bbox

    def set_train(self):
        self._G.train()
        self._D.train()
        self._is_train = True

    def set_eval(self):
        self._G.eval()
        self._is_train = False

    def forward(self, keep_data_for_visuals=False, return_estimates=False):
        # generate fake images
        fake_bg, fake_src_color, fake_src_mask, fake_tsf_color, fake_tsf_mask = \
            self._G.forward(self._input_G_bg, self._input_G_src, self._input_G_tsf, T=self._T)

        bs = fake_src_color.shape[0]
        fake_src_bg = fake_bg[0:bs]

        if self._opt.bg_both:
            fake_tsf_bg = fake_bg[bs:]
            fake_src_imgs = fake_src_mask * fake_src_bg + (1 - fake_src_mask) * fake_src_color
            fake_tsf_imgs = fake_tsf_mask * fake_tsf_bg + (1 - fake_tsf_mask) * fake_tsf_color
        else:
            fake_src_imgs = fake_src_mask * fake_src_bg + (1 - fake_src_mask) * fake_src_color
            fake_tsf_imgs = fake_tsf_mask * fake_src_bg + (1 - fake_tsf_mask) * fake_tsf_color

        fake_masks = torch.cat([fake_src_mask, fake_tsf_mask], dim=0)

        # keep data for visualization
        if keep_data_for_visuals:
            self.visual_imgs(fake_bg, fake_src_imgs, fake_tsf_imgs, fake_masks)

        return fake_bg, fake_src_imgs, fake_tsf_imgs, fake_masks

    def optimize_parameters(self, trainable=True, keep_data_for_visuals=False):
        if self._is_train:
            # run inference
            fake_bg, fake_src_imgs, fake_tsf_imgs, fake_masks = self.forward(
                keep_data_for_visuals=keep_data_for_visuals)

            loss_G = self._optimize_G(fake_bg, fake_src_imgs, fake_tsf_imgs, fake_masks)
            self._optimizer_G.zero_grad()
            loss_G.backward()
            self._optimizer_G.step()

            # train D
            if trainable:
                loss_D = self._optimize_D(fake_tsf_imgs)
                self._optimizer_D.zero_grad()
                loss_D.backward(retain_graph=True)
                self._optimizer_D.step()

    def _optimize_G(self, fake_bg, fake_src_imgs, fake_tsf_imgs, fake_masks):
        fake_input_D = torch.cat([fake_tsf_imgs, self._input_G_tsf[:, 3:]], dim=1)
        d_fake_outs = self._D.forward(fake_input_D)
        self._loss_g_adv = self._compute_loss_D(d_fake_outs, 0) * self._opt.lambda_D_prob

        self._loss_g_rec = self._crt_l1(fake_src_imgs, self._real_src) * self._opt.lambda_rec

        if self._opt.use_vgg:
            self._loss_g_tsf = torch.mean(self._crt_tsf(fake_tsf_imgs, self._real_tsf)) * self._opt.lambda_tsf
        else:
            # `self._crt_tsf` is only built when `use_vgg` is set, so fall back to the
            # pixel-wise L1 criterion for the transfer term here.
            self._loss_g_tsf = self._crt_l1(fake_tsf_imgs, self._real_tsf) * self._opt.lambda_tsf

        if self._opt.use_style:
            self._loss_g_style = torch.mean(self._crt_style(fake_tsf_imgs, self._real_tsf)) * self._opt.lambda_style

        if self._opt.use_face:
            self._loss_g_face = torch.mean(self._criterion_face(
                fake_tsf_imgs, self._real_tsf,
                bbox1=self._head_bbox, bbox2=self._head_bbox)) * self._opt.lambda_face

        # loss mask
        self._loss_g_mask = self._crt_mask(fake_masks, self._bg_mask) * self._opt.lambda_mask
        if self._opt.lambda_mask_smooth != 0:
            self._loss_g_mask_smooth = self._compute_loss_smooth(fake_masks) * self._opt.lambda_mask_smooth

        # combine losses
        return self._loss_g_adv + self._loss_g_rec + self._loss_g_tsf + self._loss_g_style + \
            self._loss_g_face + self._loss_g_mask + self._loss_g_mask_smooth
    def _optimize_D(self, fake_tsf_imgs):
        tsf_cond = self._input_G_tsf[:, 3:]
        fake_input_D = torch.cat([fake_tsf_imgs.detach(), tsf_cond], dim=1)
        real_input_D = torch.cat([self._real_tsf, tsf_cond], dim=1)

        d_real_outs = self._D.forward(real_input_D)
        d_fake_outs = self._D.forward(fake_input_D)

        if self._opt.label_smooth:
            _loss_d_real = self._compute_loss_D(d_real_outs, 0.9) * self._opt.lambda_D_prob
        else:
            _loss_d_real = self._compute_loss_D(d_real_outs, 1) * self._opt.lambda_D_prob
        _loss_d_fake = self._compute_loss_D(d_fake_outs, -1) * self._opt.lambda_D_prob

        self._d_real_loss = _loss_d_real
        self._d_fake_loss = _loss_d_fake

        self._d_real = torch.mean(d_real_outs)
        self._d_fake = torch.mean(d_fake_outs)

        # Gradient Penalty - Puneet
        if self._opt.gradient_penalty != 0:
            gp_weight = 2  # penalty weight (assumed default; `gradient_penalty` only gates the term)
            alpha = torch.rand(real_input_D.shape[0], 1, 1, 1)
            alpha = alpha.expand_as(real_input_D).cuda()
            interp_images = Variable(alpha * real_input_D.data + (1 - alpha) * fake_input_D.data,
                                     requires_grad=True).cuda()
            d_interp_outs = self._D.forward(interp_images)

            gradients = torch.autograd.grad(outputs=d_interp_outs, inputs=interp_images,
                                            grad_outputs=torch.ones(d_interp_outs.size()).cuda(),
                                            create_graph=True, retain_graph=True)[0]
            gradients = gradients.view(real_input_D.shape[0], -1)
            gradients_norm = torch.sqrt(torch.sum(gradients ** 2, dim=1) + 1e-12)
            gp = gp_weight * gradients_norm.mean()
        else:
            gp = 0

        # combine losses
        return _loss_d_real + _loss_d_fake + gp

    def _compute_loss_D(self, x, y):
        return torch.mean((x - y) ** 2)

    def _compute_loss_smooth(self, mat):
        return torch.mean(torch.abs(mat[:, :, :, :-1] - mat[:, :, :, 1:])) + \
               torch.mean(torch.abs(mat[:, :, :-1, :] - mat[:, :, 1:, :]))

    def get_current_errors(self):
        loss_dict = OrderedDict([('g_rec', self._loss_g_rec.item()),
                                 ('g_tsf', self._loss_g_tsf.item()),
                                 ('g_style', self._loss_g_style.item()),
                                 ('g_face', self._loss_g_face.item()),
                                 ('g_adv', self._loss_g_adv.item()),
                                 ('g_mask', self._loss_g_mask.item()),
                                 ('g_mask_smooth', self._loss_g_mask_smooth.item()),
                                 ('d_real', self._d_real.item()),
                                 ('d_fake', self._d_fake.item()),
                                 ('d_real_loss', self._d_real_loss.item()),
                                 ('d_fake_loss', self._d_fake_loss.item())])
        return loss_dict

    def get_current_scalars(self):
        return OrderedDict([('lr_G', self._current_lr_G), ('lr_D', self._current_lr_D)])

    def get_current_visuals(self):
        # visuals return dictionary
        visuals = OrderedDict()

        # inputs
        visuals['1_real_img'] = self._vis_input
        visuals['2_input_tsf'] = self._vis_tsf
        visuals['3_fake_bg'] = self._vis_fake_bg

        # outputs
        visuals['4_fake_tsf'] = self._vis_fake_tsf
        visuals['5_fake_src'] = self._vis_fake_src
        visuals['6_fake_mask'] = self._vis_mask

        # batch outputs
        visuals['7_batch_real_img'] = self._vis_batch_real
        visuals['8_batch_fake_img'] = self._vis_batch_fake

        return visuals

    @torch.no_grad()
    def visual_imgs(self, fake_bg, fake_src_imgs, fake_tsf_imgs, fake_masks):
        ids = fake_masks.shape[0] // 2
        self._vis_input = util.tensor2im(self._real_src)
        self._vis_tsf = util.tensor2im(self._input_G_tsf[0, 0:3])
        self._vis_fake_bg = util.tensor2im(fake_bg)
        self._vis_fake_src = util.tensor2im(fake_src_imgs)
        self._vis_fake_tsf = util.tensor2im(fake_tsf_imgs)
        self._vis_mask = util.tensor2maskim(fake_masks[ids])

        self._vis_batch_real = util.tensor2im(self._real_tsf, idx=-1)
        self._vis_batch_fake = util.tensor2im(fake_tsf_imgs, idx=-1)

    def save(self, label):
        # save networks
        self._save_network(self._G, 'G', label)
        self._save_network(self._D, 'D', label)

        # save optimizers
        self._save_optimizer(self._optimizer_G, 'G', label)
        self._save_optimizer(self._optimizer_D, 'D', label)

    def load(self):
        load_epoch = self._opt.load_epoch

        # load G
        self._load_network(self._G, 'G', load_epoch, need_module=True)

        if self._is_train:
            # load D
            self._load_network(self._D, 'D', load_epoch, need_module=True)

            # load optimizers
            self._load_optimizer(self._optimizer_G, 'G', load_epoch)
            self._load_optimizer(self._optimizer_D, 'D', load_epoch)

    def update_learning_rate(self):
        # update learning rate G
        final_lr = self._opt.final_lr

        lr_decay_G = (self._opt.lr_G - final_lr) / self._opt.nepochs_decay
        self._current_lr_G -= lr_decay_G
        for param_group in self._optimizer_G.param_groups:
            param_group['lr'] = self._current_lr_G
        print('update G learning rate: %f -> %f' % (self._current_lr_G + lr_decay_G, self._current_lr_G))

        # update learning rate D
        lr_decay_D = (self._opt.lr_D - final_lr) / self._opt.nepochs_decay
        self._current_lr_D -= lr_decay_D
        for param_group in self._optimizer_D.param_groups:
            param_group['lr'] = self._current_lr_D
        print('update D learning rate: %f -> %f' % (self._current_lr_D + lr_decay_D, self._current_lr_D))
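
# Minimal, self-contained sketch of the two discriminator terms used in `_optimize_D` above:
# the least-squares target loss from `_compute_loss_D` and the WGAN-GP style gradient penalty.
# `toy_D`, the random tensors, and this helper are illustrative stand-ins, not project code.
def _demo_discriminator_penalty_sketch():
    toy_D = torch.nn.Sequential(
        torch.nn.Conv2d(3, 8, kernel_size=3, padding=1),
        torch.nn.LeakyReLU(0.2),
        torch.nn.Conv2d(8, 1, kernel_size=3, padding=1),
    )
    real = torch.rand(4, 3, 16, 16)
    fake = torch.rand(4, 3, 16, 16)

    # least-squares GAN targets: +1 for real, -1 for fake (same form as `_compute_loss_D`)
    loss_d_real = torch.mean((toy_D(real) - 1) ** 2)
    loss_d_fake = torch.mean((toy_D(fake) + 1) ** 2)

    # gradient penalty on interpolated samples (same structure as the `gradient_penalty` branch)
    gp_weight = 2  # illustrative penalty weight
    alpha = torch.rand(real.shape[0], 1, 1, 1).expand_as(real)
    interp = (alpha * real + (1 - alpha) * fake).requires_grad_(True)
    d_interp = toy_D(interp)
    grads = torch.autograd.grad(outputs=d_interp, inputs=interp,
                                grad_outputs=torch.ones_like(d_interp),
                                create_graph=True, retain_graph=True)[0]
    grads = grads.reshape(real.shape[0], -1)
    gp = gp_weight * torch.sqrt(torch.sum(grads ** 2, dim=1) + 1e-12).mean()

    return loss_d_real + loss_d_fake + gp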
# Trainer variant that augments the background and scores fakes with a global-local discriminator.
class Impersonator(BaseModel):
    def __init__(self, opt):
        super(Impersonator, self).__init__(opt)
        self._name = 'Impersonator'

        # create networks
        self._init_create_networks()

        # init train variables and losses
        if self._is_train:
            self._init_train_vars()
            self._init_losses()

        # load networks and optimizers
        if not self._is_train or self._opt.load_epoch > 0:
            self.load()

        # prefetch variables
        self._init_prefetch_inputs()

    def _init_create_networks(self):
        multi_gpus = len(self._gpu_ids) > 1

        # body recovery flow
        self._bdr = BodyRecoveryFlow(opt=self._opt)
        if multi_gpus:
            self._bdr = torch.nn.DataParallel(self._bdr)
        self._bdr.eval()
        self._bdr.cuda()

        # generator network
        self._G = self._create_generator()
        self._G.init_weights()
        self._G = torch.nn.DataParallel(self._G)
        self._G.cuda()

        # discriminator network
        self._D = self._create_discriminator()
        self._D.init_weights()
        self._D = torch.nn.DataParallel(self._D)
        self._D.cuda()

    def _create_generator(self):
        return NetworksFactory.get_by_name(self._opt.gen_name,
                                           bg_dim=4,
                                           src_dim=3 + self._G_cond_nc,
                                           tsf_dim=3 + self._G_cond_nc,
                                           repeat_num=self._opt.repeat_num)

    def _create_discriminator(self):
        return NetworksFactory.get_by_name('global_local',
                                           input_nc=3 + self._D_cond_nc // 2,
                                           norm_type=self._opt.norm_type,
                                           ndf=64,
                                           n_layers=4,
                                           use_sigmoid=False)

    def _init_train_vars(self):
        self._current_lr_G = self._opt.lr_G
        self._current_lr_D = self._opt.lr_D

        # initialize optimizers
        self._optimizer_G = torch.optim.Adam(self._G.parameters(), lr=self._current_lr_G,
                                             betas=(self._opt.G_adam_b1, self._opt.G_adam_b2))
        self._optimizer_D = torch.optim.Adam(self._D.parameters(), lr=self._current_lr_D,
                                             betas=(self._opt.D_adam_b1, self._opt.D_adam_b2))

    def _init_prefetch_inputs(self):
        self._real_bg = None
        self._real_src = None
        self._real_tsf = None
        self._bg_mask = None
        self._input_G_aug_bg = None
        self._input_G_src = None
        self._input_G_tsf = None
        self._head_bbox = None
        self._body_bbox = None
        self._T = None

    def _init_losses(self):
        # define loss functions
        multi_gpus = len(self._gpu_ids) > 1
        self._crt_l1 = torch.nn.L1Loss()

        if self._opt.mask_bce:
            self._crt_mask = torch.nn.BCELoss()
        else:
            self._crt_mask = torch.nn.MSELoss()

        vgg_net = Vgg19()

        if self._opt.use_vgg:
            self._crt_vgg = VGGLoss(vgg=vgg_net)
            if multi_gpus:
                self._crt_vgg = torch.nn.DataParallel(self._crt_vgg)
            self._crt_vgg.cuda()

        if self._opt.use_style:
            self._crt_sty = StyleLoss(feat_extractors=vgg_net)
            if multi_gpus:
                self._crt_sty = torch.nn.DataParallel(self._crt_sty)
            self._crt_sty.cuda()

        if self._opt.use_face:
            self._crt_face = FaceLoss(pretrained_path=self._opt.face_model)
            if multi_gpus:
                self._crt_face = torch.nn.DataParallel(self._crt_face)
            self._crt_face.cuda()

        # init losses G
        self._g_l1 = self._Tensor([0])
        self._g_vgg = self._Tensor([0])
        self._g_style = self._Tensor([0])
        self._g_face = self._Tensor([0])
        self._g_adv = self._Tensor([0])
        self._g_smooth = self._Tensor([0])
        self._g_mask = self._Tensor([0])
        self._g_mask_smooth = self._Tensor([0])

        # init losses D
        self._d_real = self._Tensor([0])
        self._d_fake = self._Tensor([0])

    @torch.no_grad()
    def set_input(self, input):
        images = input['images']
        smpls = input['smpls']
        aug_bg = input['bg'].cuda()

        src_img = images[:, 0, ...].contiguous().cuda()
        src_smpl = smpls[:, 0, ...].contiguous().cuda()
        tsf_img = images[:, 1, ...].contiguous().cuda()
        tsf_smpl = smpls[:, 1, ...].contiguous().cuda()

        input_G_aug_bg, input_G_bg, input_G_src, input_G_tsf, T, bg_mask, head_bbox, body_bbox = \
            self._bdr(aug_bg, src_img, src_smpl, tsf_smpl)

        self._input_G_aug_bg = torch.cat([input_G_bg, input_G_aug_bg], dim=0)
        self._input_G_src = input_G_src
        self._input_G_tsf = input_G_tsf
        self._bg_mask = bg_mask
        self._T = T
        self._head_bbox = head_bbox
        self._body_bbox = body_bbox

        self._real_src = src_img
        self._real_tsf = tsf_img
        self._real_bg = aug_bg

    def set_train(self):
        self._G.train()
        self._D.train()
        self._is_train = True

    def set_eval(self):
        self._G.eval()
        self._is_train = False

    def forward(self, keep_data_for_visuals=False, return_estimates=False):
        # generate fake images
        fake_aug_bg, fake_src_color, fake_src_mask, fake_tsf_color, fake_tsf_mask = \
            self._G.forward(self._input_G_aug_bg, self._input_G_src, self._input_G_tsf, T=self._T)

        bs = fake_src_color.shape[0]
        fake_bg = fake_aug_bg[0:bs]
        fake_src_imgs = fake_src_mask * fake_bg + (1 - fake_src_mask) * fake_src_color
        fake_tsf_imgs = fake_tsf_mask * fake_bg + (1 - fake_tsf_mask) * fake_tsf_color
        fake_masks = torch.cat([fake_src_mask, fake_tsf_mask], dim=0)

        # keep data for visualization
        if keep_data_for_visuals:
            self.visual_imgs(fake_bg, fake_aug_bg, fake_src_imgs, fake_tsf_imgs, fake_masks)
            # self.visualizer.vis_named_img('fake_aug_bg', fake_aug_bg)
            # self.visualizer.vis_named_img('fake_aug_bg_input', self._input_G_aug_bg[:, 0:3])
            # self.visualizer.vis_named_img('real_bg', self._real_bg)

        return fake_aug_bg[bs:], fake_src_imgs, fake_tsf_imgs, fake_masks

    def optimize_parameters(self, trainable=True, keep_data_for_visuals=False):
        if self._is_train:
            # convert tensor to variables
            fake_aug_bg, fake_src_imgs, fake_tsf_imgs, fake_masks = self.forward(
                keep_data_for_visuals=keep_data_for_visuals)

            loss_G = self._optimize_G(fake_aug_bg, fake_src_imgs, fake_tsf_imgs, fake_masks)
            self._optimizer_G.zero_grad()
            loss_G.backward()
            self._optimizer_G.step()

            # train D
            if trainable:
                loss_D = self._optimize_D(fake_aug_bg, fake_tsf_imgs)
                self._optimizer_D.zero_grad()
                loss_D.backward()
                self._optimizer_D.step()

    def _optimize_G(self, fake_aug_bg, fake_src_imgs, fake_tsf_imgs, fake_masks):
        bs = fake_tsf_imgs.shape[0]
        fake_global = torch.cat([fake_aug_bg, self._input_G_aug_bg[bs:, -1:]], dim=1)
        fake_local = torch.cat([fake_tsf_imgs, self._input_G_tsf[:, 3:]], dim=1)

        d_fake_outs = self._D.forward(fake_global, fake_local, self._body_bbox)
        self._g_adv = self._compute_loss_D(d_fake_outs, 0) * self._opt.lambda_D_prob

        self._g_l1 = self._crt_l1(fake_src_imgs, self._real_src) * self._opt.lambda_lp

        if self._opt.use_vgg:
            self._g_vgg = torch.mean(
                self._crt_vgg(fake_tsf_imgs, self._real_tsf) +
                self._crt_vgg(fake_aug_bg, self._real_bg)) * self._opt.lambda_vgg

        if self._opt.use_style:
            self._g_style = torch.mean(
                self._crt_sty(fake_tsf_imgs, self._real_tsf) +
                self._crt_sty(fake_aug_bg, self._real_bg)) * self._opt.lambda_style

        if self._opt.use_face:
            self._g_face = torch.mean(
                self._crt_face(fake_tsf_imgs, self._real_tsf,
                               bbox1=self._head_bbox, bbox2=self._head_bbox)) * self._opt.lambda_face

        # loss mask
        self._g_mask = self._crt_mask(fake_masks, self._bg_mask) * self._opt.lambda_mask
        if self._opt.lambda_mask_smooth != 0:
            self._g_mask_smooth = self._compute_loss_smooth(fake_masks) * self._opt.lambda_mask_smooth

        # combine losses
        return self._g_adv + self._g_l1 + self._g_vgg + self._g_style + \
            self._g_face + self._g_mask + self._g_mask_smooth

    def _optimize_D(self, fake_aug_bg, fake_tsf_imgs):
        bs = fake_tsf_imgs.shape[0]
        fake_global = torch.cat([fake_aug_bg.detach(), self._input_G_aug_bg[bs:, -1:]], dim=1)
        fake_local = torch.cat([fake_tsf_imgs.detach(), self._input_G_tsf[:, 3:]], dim=1)

        real_global = torch.cat([self._real_bg, self._input_G_aug_bg[bs:, -1:]], dim=1)
        real_local = torch.cat([self._real_tsf, self._input_G_tsf[:, 3:]], dim=1)
        d_real_outs = self._D.forward(real_global, real_local, self._body_bbox)
        d_fake_outs = self._D.forward(fake_global, fake_local, self._body_bbox)

        _loss_d_real = self._compute_loss_D(d_real_outs, 1) * self._opt.lambda_D_prob
        _loss_d_fake = self._compute_loss_D(d_fake_outs, -1) * self._opt.lambda_D_prob

        self._d_real = torch.mean(d_real_outs)
        self._d_fake = torch.mean(d_fake_outs)

        # combine losses
        return _loss_d_real + _loss_d_fake

    def _compute_loss_D(self, x, y):
        return torch.mean((x - y) ** 2)

    def _compute_loss_smooth(self, mat):
        return torch.mean(torch.abs(mat[:, :, :, :-1] - mat[:, :, :, 1:])) + \
               torch.mean(torch.abs(mat[:, :, :-1, :] - mat[:, :, 1:, :]))

    def get_current_errors(self):
        loss_dict = OrderedDict([('g_l1', self._g_l1.item()),
                                 ('g_vgg', self._g_vgg.item()),
                                 ('g_face', self._g_face.item()),
                                 ('g_adv', self._g_adv.item()),
                                 ('g_mask', self._g_mask.item()),
                                 ('g_mask_smooth', self._g_mask_smooth.item()),
                                 ('d_real', self._d_real.item()),
                                 ('d_fake', self._d_fake.item())])
        return loss_dict

    def get_current_scalars(self):
        return OrderedDict([('lr_G', self._current_lr_G), ('lr_D', self._current_lr_D)])

    def get_current_visuals(self):
        # visuals return dictionary
        visuals = OrderedDict()

        # inputs
        visuals['1_real_img'] = self._vis_input
        visuals['2_input_tsf'] = self._vis_tsf
        visuals['3_fake_bg'] = self._vis_fake_bg

        # outputs
        visuals['4_fake_tsf'] = self._vis_fake_tsf
        visuals['5_fake_src'] = self._vis_fake_src
        visuals['6_fake_mask'] = self._vis_mask

        # batch outputs
        visuals['7_batch_real_img'] = self._vis_batch_real
        visuals['8_batch_fake_img'] = self._vis_batch_fake

        return visuals

    @torch.no_grad()
    def visual_imgs(self, fake_bg, fake_aug_bg, fake_src_imgs, fake_tsf_imgs, fake_masks):
        ids = fake_masks.shape[0] // 2
        self._vis_input = util.tensor2im(self._real_src)
        self._vis_tsf = util.tensor2im(self._input_G_tsf[0, 0:3])
        self._vis_fake_bg = util.tensor2im(fake_bg)
        self._vis_fake_src = util.tensor2im(fake_src_imgs)
        self._vis_fake_tsf = util.tensor2im(fake_tsf_imgs)
        self._vis_mask = util.tensor2maskim(fake_masks[ids])

        self._vis_batch_real = util.tensor2im(torch.cat([self._real_tsf, self._real_bg], dim=0), idx=-1)
        self._vis_batch_fake = util.tensor2im(torch.cat([fake_tsf_imgs, fake_aug_bg], dim=0), idx=-1)

    def save(self, label):
        # save networks
        self._save_network(self._G, 'G', label)
        self._save_network(self._D, 'D', label)

        # save optimizers
        self._save_optimizer(self._optimizer_G, 'G', label)
        self._save_optimizer(self._optimizer_D, 'D', label)

    def load(self):
        load_epoch = self._opt.load_epoch

        # load G
        self._load_network(self._G, 'G', load_epoch, need_module=True)

        if self._is_train:
            # load D
            self._load_network(self._D, 'D', load_epoch, need_module=True)

            # load optimizers
            self._load_optimizer(self._optimizer_G, 'G', load_epoch)
            self._load_optimizer(self._optimizer_D, 'D', load_epoch)

    def update_learning_rate(self):
        # update learning rate G
        final_lr = self._opt.final_lr

        lr_decay_G = (self._opt.lr_G - final_lr) / self._opt.nepochs_decay
        self._current_lr_G -= lr_decay_G
        for param_group in self._optimizer_G.param_groups:
            param_group['lr'] = self._current_lr_G
        print('update G learning rate: %f -> %f' % (self._current_lr_G + lr_decay_G, self._current_lr_G))

        # update learning rate D
        lr_decay_D = (self._opt.lr_D - final_lr) / self._opt.nepochs_decay
        self._current_lr_D -= lr_decay_D
        for param_group in self._optimizer_D.param_groups:
            param_group['lr'] = self._current_lr_D
        print('update D learning rate: %f -> %f' % (self._current_lr_D + lr_decay_D, self._current_lr_D))
    def debug(self, visualizer):
        visualizer.vis_named_img('bg_inputs', self._input_G_aug_bg[:, 0:3])
        ipdb.set_trace()
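
# Minimal, self-contained sketch of the linear learning-rate decay applied in
# `update_learning_rate` above: the rate steps from `lr_start` down to `final_lr` over
# `nepochs_decay` calls. The toy parameter, optimizer, and this helper are illustrative
# stand-ins, not project code.
def _demo_linear_lr_decay_sketch(lr_start=2e-4, final_lr=1e-6, nepochs_decay=10):
    params = [torch.nn.Parameter(torch.zeros(1))]
    optimizer = torch.optim.Adam(params, lr=lr_start)

    current_lr = lr_start
    lr_decay = (lr_start - final_lr) / nepochs_decay
    schedule = []
    for _ in range(nepochs_decay):
        current_lr -= lr_decay
        for param_group in optimizer.param_groups:
            param_group['lr'] = current_lr
        schedule.append(current_lr)

    # after `nepochs_decay` steps the rate reaches `final_lr` (up to floating-point error)
    return schedule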