def _init_losses(self):
        # define loss functions
        multi_gpus = len(self._gpu_ids) > 1
        self._crt_l1 = torch.nn.L1Loss()

        if self._opt.mask_bce:
            self._crt_mask = torch.nn.BCELoss()
        else:
            self._crt_mask = torch.nn.MSELoss()

        vgg_net = Vgg19()
        if self._opt.use_vgg:
            self._criterion_vgg = VGGLoss(vgg=vgg_net)
            if multi_gpus:
                self._criterion_vgg = torch.nn.DataParallel(
                    self._criterion_vgg)
            self._criterion_vgg.cuda()

        if self._opt.use_style:
            self._criterion_style = StyleLoss(feat_extractors=vgg_net)
            if multi_gpus:
                self._criterion_style = torch.nn.DataParallel(
                    self._criterion_style)
            self._criterion_style.cuda()

        if self._opt.use_face:
            self._criterion_face = FaceLoss(
                pretrained_path=self._opt.face_model)
            if multi_gpus:
                self._criterion_face = torch.nn.DataParallel(
                    self._criterion_face)
            self._criterion_face.cuda()

        # init losses G
        self._loss_g_l1 = self._Tensor([0])
        self._loss_g_vgg = self._Tensor([0])
        self._loss_g_style = self._Tensor([0])
        self._loss_g_face = self._Tensor([0])
        self._loss_g_adv = self._Tensor([0])
        self._loss_g_smooth = self._Tensor([0])
        self._loss_g_mask = self._Tensor([0])
        self._loss_g_mask_smooth = self._Tensor([0])

        # init losses D
        self._d_real = self._Tensor([0])
        self._d_fake = self._Tensor([0])
class Impersonator(BaseModel):
    def __init__(self, opt):
        super(Impersonator, self).__init__(opt)
        self._name = 'Impersonator'

        # create networks
        self._init_create_networks()

        # init train variables and losses
        if self._is_train:
            self._init_train_vars()
            self._init_losses()

        # load networks and optimizers
        if not self._is_train or self._opt.load_epoch > 0:
            self.load()
        elif self._opt.load_path != 'None':
            self._load_params(self._G, self._opt.load_path, need_module=len(self._gpu_ids) > 1)

        # prefetch variables
        self._init_prefetch_inputs()

    def _init_create_networks(self):
        multi_gpus = len(self._gpu_ids) > 1

        # body recovery Flow
        self._bdr = BodyRecoveryFlow(opt=self._opt)
        if multi_gpus:
            self._bdr = torch.nn.DataParallel(self._bdr)

        self._bdr.eval()
        self._bdr.cuda()

        # generator network
        self._G = self._create_generator()
        self._G.init_weights()
        if multi_gpus:
            self._G = torch.nn.DataParallel(self._G)
        self._G.cuda()

        # discriminator network
        self._D = self._create_discriminator()
        self._D.init_weights()
        if multi_gpus:
            self._D = torch.nn.DataParallel(self._D)
        self._D.cuda()

    def _create_generator(self):
        return NetworksFactory.get_by_name(self._opt.gen_name, bg_dim=4, src_dim=3+self._G_cond_nc,
                                           tsf_dim=3+self._G_cond_nc, repeat_num=self._opt.repeat_num)

    def _create_discriminator(self):
        return NetworksFactory.get_by_name('discriminator_patch_gan', input_nc=3 + self._D_cond_nc,
                                           norm_type=self._opt.norm_type, ndf=64, n_layers=4,
                                           use_sigmoid=False, sn=self._opt.spectral_norm)

    def _init_train_vars(self):
        print("---------- Generator LR:{0} ---------- DISCRIMINATOR LR:{1} ----------".format(self._opt.lr_G, self._opt.lr_D))
        self._current_lr_G = self._opt.lr_G
        self._current_lr_D = self._opt.lr_D

        # initialize optimizers
        self._optimizer_G = torch.optim.Adam(self._G.parameters(), lr=self._current_lr_G,
                                             betas=(self._opt.G_adam_b1, self._opt.G_adam_b2))
        self._optimizer_D = torch.optim.Adam(self._D.parameters(), lr=self._current_lr_D,
                                             betas=(self._opt.D_adam_b1, self._opt.D_adam_b2))

    def _init_prefetch_inputs(self):
        self._real_src = None
        self._real_tsf = None
        self._bg_mask = None
        self._input_src = None
        self._input_G_bg = None
        self._input_G_src = None
        self._input_G_tsf = None
        self._T = None
        self._body_bbox = None
        self._head_bbox = None

    def _init_losses(self):
        # define loss functions
        multi_gpus = len(self._gpu_ids) > 1
        self._crt_l1 = torch.nn.L1Loss()

        if self._opt.mask_bce:
            self._crt_mask = torch.nn.BCELoss()
        else:
            self._crt_mask = torch.nn.MSELoss()

        vgg_net = Vgg19()
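        # a single VGG-19 feature extractor is shared by the perceptual (VGG) and
        # style losses; on multi-GPU runs the loss modules are wrapped in DataParallel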
        if self._opt.use_vgg:
            self._crt_tsf = VGGLoss(vgg=vgg_net)
            if multi_gpus:
                self._crt_tsf = torch.nn.DataParallel(self._crt_tsf)
            self._crt_tsf.cuda()

        if self._opt.use_style:
            self._crt_style = StyleLoss(feat_extractors=vgg_net)
            if multi_gpus:
                self._crt_style = torch.nn.DataParallel(self._crt_style)
            self._crt_style.cuda()

        if self._opt.use_face:
            self._criterion_face = FaceLoss(pretrained_path=self._opt.face_model)
            if multi_gpus:
                self._criterion_face = torch.nn.DataParallel(self._criterion_face)
            self._criterion_face.cuda()

        # init losses G
        self._loss_g_rec = self._Tensor([0])
        self._loss_g_tsf = self._Tensor([0])
        self._loss_g_style = self._Tensor([0])
        self._loss_g_face = self._Tensor([0])
        self._loss_g_adv = self._Tensor([0])
        self._loss_g_smooth = self._Tensor([0])
        self._loss_g_mask = self._Tensor([0])
        self._loss_g_mask_smooth = self._Tensor([0])

        # init losses D
        self._d_real = self._Tensor([0])
        self._d_fake = self._Tensor([0])
        self._d_real_loss = self._Tensor([0])
        self._d_fake_loss = self._Tensor([0])

    def set_input(self, input):

        with torch.no_grad():
            images = input['images']
            smpls = input['smpls']
            src_img = images[:, 0, ...].cuda()
            src_smpl = smpls[:, 0, ...].cuda()
            tsf_img = images[:, 1, ...].cuda()
            tsf_smpl = smpls[:, 1, ...].cuda()

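            # run the frozen body-recovery flow: it builds the generator inputs for the
            # background, source and transfer branches, the warping transform T, the
            # crop masks and the head/body bounding boxes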
            input_G_src_bg, input_G_tsf_bg, input_G_src, input_G_tsf, T, src_crop_mask, \
                tsf_crop_mask, head_bbox, body_bbox = self._bdr(src_img, tsf_img, src_smpl, tsf_smpl)

            self._real_src = src_img
            self._real_tsf = tsf_img

            self._bg_mask = torch.cat((src_crop_mask, tsf_crop_mask), dim=0)
            if self._opt.bg_both:
                self._input_G_bg = torch.cat([input_G_src_bg, input_G_tsf_bg], dim=0)
            else:
                self._input_G_bg = input_G_src_bg
            self._input_G_src = input_G_src
            self._input_G_tsf = input_G_tsf
            self._T = T
            self._head_bbox = head_bbox
            self._body_bbox = body_bbox

    def set_train(self):
        self._G.train()
        self._D.train()
        self._is_train = True

    def set_eval(self):
        self._G.eval()
        self._is_train = False

    def forward(self, keep_data_for_visuals=False, return_estimates=False):
        # generate fake images
        fake_bg, fake_src_color, fake_src_mask, fake_tsf_color, fake_tsf_mask = \
            self._G.forward(self._input_G_bg, self._input_G_src, self._input_G_tsf, T=self._T)

        bs = fake_src_color.shape[0]
        fake_src_bg = fake_bg[0:bs]
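        # composite: the predicted soft mask selects the generated background,
        # (1 - mask) keeps the generated foreground color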
        if self._opt.bg_both:
            fake_tsf_bg = fake_bg[bs:]
            fake_src_imgs = fake_src_mask * fake_src_bg + (1 - fake_src_mask) * fake_src_color
            fake_tsf_imgs = fake_tsf_mask * fake_tsf_bg + (1 - fake_tsf_mask) * fake_tsf_color
        else:
            fake_src_imgs = fake_src_mask * fake_src_bg + (1 - fake_src_mask) * fake_src_color
            fake_tsf_imgs = fake_tsf_mask * fake_src_bg + (1 - fake_tsf_mask) * fake_tsf_color

        fake_masks = torch.cat([fake_src_mask, fake_tsf_mask], dim=0)

        # keep data for visualization
        if keep_data_for_visuals:
            self.visual_imgs(fake_bg, fake_src_imgs, fake_tsf_imgs, fake_masks)

        return fake_bg, fake_src_imgs, fake_tsf_imgs, fake_masks

    def optimize_parameters(self, trainable=True, keep_data_for_visuals=False):
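        # one training step: update G on the current batch, then (if trainable)
        # update D on the detached fake transferred images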
        if self._is_train:

            # run inference
            fake_bg, fake_src_imgs, fake_tsf_imgs, fake_masks = self.forward(keep_data_for_visuals=keep_data_for_visuals)

            loss_G = self._optimize_G(fake_bg, fake_src_imgs, fake_tsf_imgs, fake_masks)

            self._optimizer_G.zero_grad()
            loss_G.backward()
            self._optimizer_G.step()

            # train D
            if trainable:
                loss_D = self._optimize_D(fake_tsf_imgs)
                self._optimizer_D.zero_grad()
                loss_D.backward(retain_graph=True)
                self._optimizer_D.step()

    def _optimize_G(self, fake_bg, fake_src_imgs, fake_tsf_imgs, fake_masks):
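        # generator objective: least-squares adversarial term on the transferred
        # image, L1 reconstruction of the source, optional perceptual/style/face
        # terms, plus mask regression and (optional) mask smoothness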
        fake_input_D = torch.cat([fake_tsf_imgs, self._input_G_tsf[:, 3:]], dim=1)
        d_fake_outs = self._D.forward(fake_input_D)
        self._loss_g_adv = self._compute_loss_D(d_fake_outs, 0) * self._opt.lambda_D_prob

        self._loss_g_rec = self._crt_l1(fake_src_imgs, self._real_src) * self._opt.lambda_rec

        if self._opt.use_vgg:
            # perceptual (VGG) loss on the transferred image
            self._loss_g_tsf = torch.mean(self._crt_tsf(fake_tsf_imgs, self._real_tsf)) * self._opt.lambda_tsf

        if self._opt.use_style:
            self._loss_g_style = torch.mean(self._crt_style(
                fake_tsf_imgs, self._real_tsf)) * self._opt.lambda_style

        if self._opt.use_face:
            self._loss_g_face = torch.mean(self._criterion_face(
                fake_tsf_imgs, self._real_tsf, bbox1=self._head_bbox, bbox2=self._head_bbox)) * self._opt.lambda_face
        # loss mask
        self._loss_g_mask = self._crt_mask(fake_masks, self._bg_mask) * self._opt.lambda_mask

        if self._opt.lambda_mask_smooth != 0:
            self._loss_g_mask_smooth = self._compute_loss_smooth(fake_masks) * self._opt.lambda_mask_smooth

        # combine losses
        return self._loss_g_adv + self._loss_g_rec + self._loss_g_tsf + self._loss_g_style + self._loss_g_face + \
               self._loss_g_mask + self._loss_g_mask_smooth

    def _optimize_D(self, fake_tsf_imgs):
        tsf_cond = self._input_G_tsf[:, 3:]
        fake_input_D = torch.cat([fake_tsf_imgs.detach(), tsf_cond], dim=1)
        real_input_D = torch.cat([self._real_tsf, tsf_cond], dim=1)

        d_real_outs = self._D.forward(real_input_D)
        d_fake_outs = self._D.forward(fake_input_D)

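        # least-squares targets: real pairs are pushed towards 1 (0.9 with label
        # smoothing), fake pairs towards -1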
        if self._opt.label_smooth:
            _loss_d_real = self._compute_loss_D(d_real_outs, 0.9) * self._opt.lambda_D_prob
        else:
            _loss_d_real = self._compute_loss_D(d_real_outs, 1) * self._opt.lambda_D_prob

        _loss_d_fake = self._compute_loss_D(d_fake_outs, -1) * self._opt.lambda_D_prob

        self._d_real_loss = _loss_d_real
        self._d_fake_loss = _loss_d_fake

        self._d_real = torch.mean(d_real_outs)
        self._d_fake = torch.mean(d_fake_outs)

        # gradient penalty - Puneet
        if self._opt.gradient_penalty != 0:
            # interpolate between real and fake discriminator inputs
            alpha = torch.rand(real_input_D.shape[0], 1, 1, 1)
            alpha = alpha.expand_as(real_input_D).cuda()
            interp_images = (alpha * real_input_D.data + (1 - alpha) * fake_input_D.data).requires_grad_(True)
            d_interp_outs = self._D.forward(interp_images)
            gradients = torch.autograd.grad(outputs=d_interp_outs, inputs=interp_images,
                                            grad_outputs=torch.ones(d_interp_outs.size()).cuda(),
                                            create_graph=True, retain_graph=True)[0]
            gradients = gradients.view(real_input_D.shape[0], -1)
            gradients_norm = torch.sqrt(torch.sum(gradients ** 2, dim=1) + 1e-12)
            # the gradient_penalty option also serves as the penalty weight
            gp = self._opt.gradient_penalty * gradients_norm.mean()
        else:
            gp = 0
        # combine losses
        return _loss_d_real + _loss_d_fake + gp

    def _compute_loss_D(self, x, y):
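        # least-squares (LSGAN-style) criterion: mean squared distance between the
        # discriminator outputs x and the scalar target y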
        return torch.mean((x - y) ** 2)

    def _compute_loss_smooth(self, mat):
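        # total-variation smoothness: mean absolute difference between neighbouring
        # entries along the width and height dimensions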
        return torch.mean(torch.abs(mat[:, :, :, :-1] - mat[:, :, :, 1:])) + \
               torch.mean(torch.abs(mat[:, :, :-1, :] - mat[:, :, 1:, :]))

    def get_current_errors(self):
        loss_dict = OrderedDict([('g_rec', self._loss_g_rec.item()),
                                 ('g_tsf', self._loss_g_tsf.item()),
                                 ('g_style', self._loss_g_style.item()),
                                 ('g_face', self._loss_g_face.item()),
                                 ('g_adv', self._loss_g_adv.item()),
                                 ('g_mask', self._loss_g_mask.item()),
                                 ('g_mask_smooth', self._loss_g_mask_smooth.item()),
                                 ('d_real', self._d_real.item()),
                                 ('d_fake', self._d_fake.item()),
                                 ('d_real_loss', self._d_real_loss.item()),
                                 ('d_fake_loss', self._d_fake_loss.item())])

        return loss_dict

    def get_current_scalars(self):
        return OrderedDict([('lr_G', self._current_lr_G), ('lr_D', self._current_lr_D)])

    def get_current_visuals(self):
        # visuals return dictionary
        visuals = OrderedDict()

        # inputs
        visuals['1_real_img'] = self._vis_input
        visuals['2_input_tsf'] = self._vis_tsf
        visuals['3_fake_bg'] = self._vis_fake_bg

        # outputs
        visuals['4_fake_tsf'] = self._vis_fake_tsf
        visuals['5_fake_src'] = self._vis_fake_src
        visuals['6_fake_mask'] = self._vis_mask

        # batch outputs
        visuals['7_batch_real_img'] = self._vis_batch_real
        visuals['8_batch_fake_img'] = self._vis_batch_fake

        return visuals

    @torch.no_grad()
    def visual_imgs(self, fake_bg, fake_src_imgs, fake_tsf_imgs, fake_masks):
        ids = fake_masks.shape[0] // 2
        self._vis_input = util.tensor2im(self._real_src)
        self._vis_tsf = util.tensor2im(self._input_G_tsf[0, 0:3])
        self._vis_fake_bg = util.tensor2im(fake_bg)
        self._vis_fake_src = util.tensor2im(fake_src_imgs)
        self._vis_fake_tsf = util.tensor2im(fake_tsf_imgs)
        self._vis_mask = util.tensor2maskim(fake_masks[ids])

        self._vis_batch_real = util.tensor2im(self._real_tsf, idx=-1)
        self._vis_batch_fake = util.tensor2im(fake_tsf_imgs, idx=-1)

    def save(self, label):
        # save networks
        self._save_network(self._G, 'G', label)
        self._save_network(self._D, 'D', label)

        # save optimizers
        self._save_optimizer(self._optimizer_G, 'G', label)
        self._save_optimizer(self._optimizer_D, 'D', label)

    def load(self):
        load_epoch = self._opt.load_epoch

        # load G
        self._load_network(self._G, 'G', load_epoch, need_module=True)

        if self._is_train:
            # load D
            self._load_network(self._D, 'D', load_epoch, need_module=True)

            # load optimizers
            self._load_optimizer(self._optimizer_G, 'G', load_epoch)
            self._load_optimizer(self._optimizer_D, 'D', load_epoch)

    def update_learning_rate(self):
        # update learning rate G
        final_lr = self._opt.final_lr

        lr_decay_G = (self._opt.lr_G - final_lr) / self._opt.nepochs_decay
        self._current_lr_G -= lr_decay_G
        for param_group in self._optimizer_G.param_groups:
            param_group['lr'] = self._current_lr_G
        print('update G learning rate: %f -> %f' % (self._current_lr_G + lr_decay_G, self._current_lr_G))

        # update learning rate D
        lr_decay_D = (self._opt.lr_D - final_lr) / self._opt.nepochs_decay
        self._current_lr_D -= lr_decay_D
        for param_group in self._optimizer_D.param_groups:
            param_group['lr'] = self._current_lr_D
        print('update D learning rate: %f -> %f' % (self._current_lr_D + lr_decay_D, self._current_lr_D))
class Impersonator(BaseModel):
    def __init__(self, opt):
        super(Impersonator, self).__init__(opt)
        self._name = 'Impersonator'

        # create networks
        self._init_create_networks()

        # init train variables and losses
        if self._is_train:
            self._init_train_vars()
            self._init_losses()

        # load networks and optimizers
        if not self._is_train or self._opt.load_epoch > 0:
            self.load()

        # prefetch variables
        self._init_prefetch_inputs()

    def _init_create_networks(self):
        multi_gpus = len(self._gpu_ids) > 1

        # body recovery Flow
        self._bdr = BodyRecoveryFlow(opt=self._opt)
        if multi_gpus:
            self._bdr = torch.nn.DataParallel(self._bdr)

        self._bdr.eval()
        self._bdr.cuda()

        # generator network
        self._G = self._create_generator()
        self._G.init_weights()
        self._G = torch.nn.DataParallel(self._G)
        self._G.cuda()

        # discriminator network
        self._D = self._create_discriminator()
        self._D.init_weights()
        self._D = torch.nn.DataParallel(self._D)
        self._D.cuda()

    def _create_generator(self):
        return NetworksFactory.get_by_name(self._opt.gen_name,
                                           bg_dim=4,
                                           src_dim=3 + self._G_cond_nc,
                                           tsf_dim=3 + self._G_cond_nc,
                                           repeat_num=self._opt.repeat_num)

    def _create_discriminator(self):
        return NetworksFactory.get_by_name('global_local',
                                           input_nc=3 + self._D_cond_nc // 2,
                                           norm_type=self._opt.norm_type,
                                           ndf=64,
                                           n_layers=4,
                                           use_sigmoid=False)

    def _init_train_vars(self):
        self._current_lr_G = self._opt.lr_G
        self._current_lr_D = self._opt.lr_D

        # initialize optimizers
        self._optimizer_G = torch.optim.Adam(self._G.parameters(),
                                             lr=self._current_lr_G,
                                             betas=(self._opt.G_adam_b1,
                                                    self._opt.G_adam_b2))
        self._optimizer_D = torch.optim.Adam(self._D.parameters(),
                                             lr=self._current_lr_D,
                                             betas=(self._opt.D_adam_b1,
                                                    self._opt.D_adam_b2))

    def _init_prefetch_inputs(self):

        self._real_bg = None
        self._real_src = None
        self._real_tsf = None

        self._bg_mask = None
        self._input_G_aug_bg = None
        self._input_G_src = None
        self._input_G_tsf = None
        self._head_bbox = None
        self._body_bbox = None

        self._T = None

    def _init_losses(self):
        # define loss functions
        multi_gpus = len(self._gpu_ids) > 1
        self._crt_l1 = torch.nn.L1Loss()

        if self._opt.mask_bce:
            self._crt_mask = torch.nn.BCELoss()
        else:
            self._crt_mask = torch.nn.MSELoss()

        vgg_net = Vgg19()
        if self._opt.use_vgg:
            self._crt_vgg = VGGLoss(vgg=vgg_net)
            if multi_gpus:
                self._crt_vgg = torch.nn.DataParallel(self._crt_vgg)
            self._crt_vgg.cuda()

        if self._opt.use_style:
            self._crt_sty = StyleLoss(feat_extractors=vgg_net)
            if multi_gpus:
                self._crt_sty = torch.nn.DataParallel(self._crt_sty)
            self._crt_sty.cuda()

        if self._opt.use_face:
            self._crt_face = FaceLoss(pretrained_path=self._opt.face_model)
            if multi_gpus:
                self._crt_face = torch.nn.DataParallel(self._crt_face)
            self._crt_face.cuda()

        # init losses G
        self._g_l1 = self._Tensor([0])
        self._g_vgg = self._Tensor([0])
        self._g_style = self._Tensor([0])
        self._g_face = self._Tensor([0])
        self._g_adv = self._Tensor([0])
        self._g_smooth = self._Tensor([0])
        self._g_mask = self._Tensor([0])
        self._g_mask_smooth = self._Tensor([0])

        # init losses D
        self._d_real = self._Tensor([0])
        self._d_fake = self._Tensor([0])

    @torch.no_grad()
    def set_input(self, input):

        images = input['images']
        smpls = input['smpls']
        aug_bg = input['bg'].cuda()
        src_img = images[:, 0, ...].contiguous().cuda()
        src_smpl = smpls[:, 0, ...].contiguous().cuda()
        tsf_img = images[:, 1, ...].contiguous().cuda()
        tsf_smpl = smpls[:, 1, ...].contiguous().cuda()

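        # body-recovery flow with an augmented background: it returns generator inputs
        # for the real and augmented backgrounds, the source and transfer branches,
        # the warping transform T, the background mask and the head/body bboxes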
        input_G_aug_bg, input_G_bg, input_G_src, input_G_tsf, T, bg_mask, head_bbox, body_bbox = \
            self._bdr(aug_bg, src_img, src_smpl, tsf_smpl)

        self._input_G_aug_bg = torch.cat([input_G_bg, input_G_aug_bg], dim=0)
        self._input_G_src = input_G_src
        self._input_G_tsf = input_G_tsf
        self._bg_mask = bg_mask
        self._T = T
        self._head_bbox = head_bbox
        self._body_bbox = body_bbox
        self._real_src = src_img
        self._real_tsf = tsf_img
        self._real_bg = aug_bg

    def set_train(self):
        self._G.train()
        self._D.train()
        self._is_train = True

    def set_eval(self):
        self._G.eval()
        self._is_train = False

    def forward(self, keep_data_for_visuals=False, return_estimates=False):
        # generate fake images
        fake_aug_bg, fake_src_color, fake_src_mask, fake_tsf_color, fake_tsf_mask = \
            self._G.forward(self._input_G_aug_bg, self._input_G_src, self._input_G_tsf, T=self._T)

        bs = fake_src_color.shape[0]
        fake_bg = fake_aug_bg[0:bs]
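        # composite with the predicted soft masks: mask selects the generated
        # background, (1 - mask) keeps the generated foreground color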
        fake_src_imgs = fake_src_mask * fake_bg + (
            1 - fake_src_mask) * fake_src_color
        fake_tsf_imgs = fake_tsf_mask * fake_bg + (
            1 - fake_tsf_mask) * fake_tsf_color

        fake_masks = torch.cat([fake_src_mask, fake_tsf_mask], dim=0)

        # keep data for visualization
        if keep_data_for_visuals:
            self.visual_imgs(fake_bg, fake_aug_bg, fake_src_imgs,
                             fake_tsf_imgs, fake_masks)
            # self.visualizer.vis_named_img('fake_aug_bg', fake_aug_bg)
            # self.visualizer.vis_named_img('fake_aug_bg_input', self._input_G_aug_bg[:, 0:3])
            # self.visualizer.vis_named_img('real_bg', self._real_bg)

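        # only the second half of fake_aug_bg (the reconstruction of the augmented
        # background) is returned; it is supervised against the real background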
        return fake_aug_bg[bs:], fake_src_imgs, fake_tsf_imgs, fake_masks

    def optimize_parameters(self, trainable=True, keep_data_for_visuals=False):
        if self._is_train:
            # convert tensor to variables
            fake_aug_bg, fake_src_imgs, fake_tsf_imgs, fake_masks = self.forward(
                keep_data_for_visuals=keep_data_for_visuals)

            loss_G = self._optimize_G(fake_aug_bg, fake_src_imgs,
                                      fake_tsf_imgs, fake_masks)

            self._optimizer_G.zero_grad()
            loss_G.backward()
            self._optimizer_G.step()

            # train D
            if trainable:
                loss_D = self._optimize_D(fake_aug_bg, fake_tsf_imgs)
                self._optimizer_D.zero_grad()
                loss_D.backward()
                self._optimizer_D.step()

    def _optimize_G(self, fake_aug_bg, fake_src_imgs, fake_tsf_imgs,
                    fake_masks):
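        # the global-local discriminator scores the generated augmented background
        # (global branch) and the transferred image (local branch), each concatenated
        # with its conditioning channels; body_bbox locates the person region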
        bs = fake_tsf_imgs.shape[0]

        fake_global = torch.cat([fake_aug_bg, self._input_G_aug_bg[bs:, -1:]],
                                dim=1)
        fake_local = torch.cat([fake_tsf_imgs, self._input_G_tsf[:, 3:]],
                               dim=1)
        d_fake_outs = self._D.forward(fake_global, fake_local, self._body_bbox)
        self._g_adv = self._compute_loss_D(d_fake_outs,
                                           0) * self._opt.lambda_D_prob

        self._g_l1 = self._crt_l1(fake_src_imgs,
                                  self._real_src) * self._opt.lambda_lp

        if self._opt.use_vgg:
            self._g_vgg = torch.mean(
                self._crt_vgg(fake_tsf_imgs, self._real_tsf) + self._crt_vgg(
                    fake_aug_bg, self._real_bg)) * self._opt.lambda_vgg

        if self._opt.use_style:
            self._g_style = torch.mean(
                self._crt_sty(fake_tsf_imgs, self._real_tsf) + self._crt_sty(
                    fake_aug_bg, self._real_bg)) * self._opt.lambda_style

        if self._opt.use_face:
            self._g_face = torch.mean(
                self._crt_face(fake_tsf_imgs,
                               self._real_tsf,
                               bbox1=self._head_bbox,
                               bbox2=self._head_bbox)) * self._opt.lambda_face
        # loss mask
        self._g_mask = self._crt_mask(fake_masks,
                                      self._bg_mask) * self._opt.lambda_mask

        if self._opt.lambda_mask_smooth != 0:
            self._g_mask_smooth = self._compute_loss_smooth(
                fake_masks) * self._opt.lambda_mask_smooth

        # combine losses
        return self._g_adv + self._g_l1 + self._g_vgg + self._g_style + self._g_face + self._g_mask + self._g_mask_smooth

    def _optimize_D(self, fake_aug_bg, fake_tsf_imgs):
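        # least-squares discriminator update: real pairs target 1, fake pairs
        # target -1, using the same global/local conditioning as the generator pass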
        bs = fake_tsf_imgs.shape[0]
        fake_global = torch.cat(
            [fake_aug_bg.detach(), self._input_G_aug_bg[bs:, -1:]], dim=1)
        fake_local = torch.cat(
            [fake_tsf_imgs.detach(), self._input_G_tsf[:, 3:]], dim=1)
        real_global = torch.cat(
            [self._real_bg, self._input_G_aug_bg[bs:, -1:]], dim=1)
        real_local = torch.cat([self._real_tsf, self._input_G_tsf[:, 3:]],
                               dim=1)

        d_real_outs = self._D.forward(real_global, real_local, self._body_bbox)
        d_fake_outs = self._D.forward(fake_global, fake_local, self._body_bbox)

        _loss_d_real = self._compute_loss_D(d_real_outs,
                                            1) * self._opt.lambda_D_prob
        _loss_d_fake = self._compute_loss_D(d_fake_outs,
                                            -1) * self._opt.lambda_D_prob

        self._d_real = torch.mean(d_real_outs)
        self._d_fake = torch.mean(d_fake_outs)

        # combine losses
        return _loss_d_real + _loss_d_fake

    def _compute_loss_D(self, x, y):
        return torch.mean((x - y)**2)

    def _compute_loss_smooth(self, mat):
        return torch.mean(torch.abs(mat[:, :, :, :-1] - mat[:, :, :, 1:])) + \
               torch.mean(torch.abs(mat[:, :, :-1, :] - mat[:, :, 1:, :]))

    def get_current_errors(self):
        loss_dict = OrderedDict([('g_l1', self._g_l1.item()),
                                 ('g_vgg', self._g_vgg.item()),
                                 ('g_style', self._g_style.item()),
                                 ('g_face', self._g_face.item()),
                                 ('g_adv', self._g_adv.item()),
                                 ('g_mask', self._g_mask.item()),
                                 ('g_mask_smooth', self._g_mask_smooth.item()),
                                 ('d_real', self._d_real.item()),
                                 ('d_fake', self._d_fake.item())])

        return loss_dict

    def get_current_scalars(self):
        return OrderedDict([('lr_G', self._current_lr_G),
                            ('lr_D', self._current_lr_D)])

    def get_current_visuals(self):
        # visuals return dictionary
        visuals = OrderedDict()

        # inputs
        visuals['1_real_img'] = self._vis_input
        visuals['2_input_tsf'] = self._vis_tsf
        visuals['3_fake_bg'] = self._vis_fake_bg

        # outputs
        visuals['4_fake_tsf'] = self._vis_fake_tsf
        visuals['5_fake_src'] = self._vis_fake_src
        visuals['6_fake_mask'] = self._vis_mask

        # batch outputs
        visuals['7_batch_real_img'] = self._vis_batch_real
        visuals['8_batch_fake_img'] = self._vis_batch_fake

        return visuals

    @torch.no_grad()
    def visual_imgs(self, fake_bg, fake_aug_bg, fake_src_imgs, fake_tsf_imgs,
                    fake_masks):
        ids = fake_masks.shape[0] // 2
        self._vis_input = util.tensor2im(self._real_src)
        self._vis_tsf = util.tensor2im(self._input_G_tsf[0, 0:3])
        self._vis_fake_bg = util.tensor2im(fake_bg)
        self._vis_fake_src = util.tensor2im(fake_src_imgs)
        self._vis_fake_tsf = util.tensor2im(fake_tsf_imgs)
        self._vis_mask = util.tensor2maskim(fake_masks[ids])

        self._vis_batch_real = util.tensor2im(torch.cat(
            [self._real_tsf, self._real_bg], dim=0),
                                              idx=-1)
        self._vis_batch_fake = util.tensor2im(torch.cat(
            [fake_tsf_imgs, fake_aug_bg], dim=0),
                                              idx=-1)

    def save(self, label):
        # save networks
        self._save_network(self._G, 'G', label)
        self._save_network(self._D, 'D', label)

        # save optimizers
        self._save_optimizer(self._optimizer_G, 'G', label)
        self._save_optimizer(self._optimizer_D, 'D', label)

    def load(self):
        load_epoch = self._opt.load_epoch

        # load G
        self._load_network(self._G, 'G', load_epoch, need_module=True)

        if self._is_train:
            # load D
            self._load_network(self._D, 'D', load_epoch, need_module=True)

            # load optimizers
            self._load_optimizer(self._optimizer_G, 'G', load_epoch)
            self._load_optimizer(self._optimizer_D, 'D', load_epoch)

    def update_learning_rate(self):
        # update learning rate G
        final_lr = self._opt.final_lr

        lr_decay_G = (self._opt.lr_G - final_lr) / self._opt.nepochs_decay
        self._current_lr_G -= lr_decay_G
        for param_group in self._optimizer_G.param_groups:
            param_group['lr'] = self._current_lr_G
        print('update G learning rate: %f -> %f' %
              (self._current_lr_G + lr_decay_G, self._current_lr_G))

        # update learning rate D
        lr_decay_D = (self._opt.lr_D - final_lr) / self._opt.nepochs_decay
        self._current_lr_D -= lr_decay_D
        for param_group in self._optimizer_D.param_groups:
            param_group['lr'] = self._current_lr_D
        print('update D learning rate: %f -> %f' %
              (self._current_lr_D + lr_decay_D, self._current_lr_D))

    def debug(self, visualizer):
        visualizer.vis_named_img('bg_inputs', self._input_G_aug_bg[:, 0:3])
        ipdb.set_trace()