class GenerationModel(BaseModel):
    def name(self):
        return 'Generation model: pix2pix | pix2pixHD'

    def __init__(self, opt):
        self.t0 = time()
        BaseModel.__init__(self, opt)
        self.train_mode = opt.train_mode
        # resume of networks
        resume_gmm = opt.resume_gmm
        resume_G_parse = opt.resume_G_parse
        resume_D_parse = opt.resume_D_parse
        resume_G_appearance = opt.resume_G_app
        resume_D_appearance = opt.resume_D_app
        resume_G_face = opt.resume_G_face
        resume_D_face = opt.resume_D_face
        # define network
        self.gmm_model = torch.nn.DataParallel(GMM(opt)).cuda()
        self.generator_parsing = Define_G(opt.input_nc_G_parsing,
                                          opt.output_nc_parsing, opt.ndf,
                                          opt.netG_parsing, opt.norm,
                                          not opt.no_dropout, opt.init_type,
                                          opt.init_gain, opt.gpu_ids)
        self.discriminator_parsing = Define_D(opt.input_nc_D_parsing, opt.ndf,
                                              opt.netD_parsing, opt.n_layers_D,
                                              opt.norm, opt.init_type,
                                              opt.init_gain, opt.gpu_ids)

        self.generator_appearance = Define_G(opt.input_nc_G_app,
                                             opt.output_nc_app,
                                             opt.ndf,
                                             opt.netG_app,
                                             opt.norm,
                                             not opt.no_dropout,
                                             opt.init_type,
                                             opt.init_gain,
                                             opt.gpu_ids,
                                             with_tanh=False)
        self.discriminator_appearance = Define_D(opt.input_nc_D_app, opt.ndf,
                                                 opt.netD_app, opt.n_layers_D,
                                                 opt.norm, opt.init_type,
                                                 opt.init_gain, opt.gpu_ids)

        self.generator_face = Define_G(opt.input_nc_D_face, opt.output_nc_face,
                                       opt.ndf, opt.netG_face, opt.norm,
                                       not opt.no_dropout, opt.init_type,
                                       opt.init_gain, opt.gpu_ids)
        self.discriminator_face = Define_D(opt.input_nc_D_face, opt.ndf,
                                           opt.netD_face, opt.n_layers_D,
                                           opt.norm, opt.init_type,
                                           opt.init_gain, opt.gpu_ids)

        if opt.train_mode == 'gmm':
            setattr(self, 'generator', self.gmm_model)
        else:
            setattr(self, 'generator',
                    getattr(self, 'generator_' + self.train_mode))
            setattr(self, 'discriminator',
                    getattr(self, 'discriminator_' + self.train_mode))

        # load networks
        self.networks_name = [
            'gmm', 'parsing', 'parsing', 'appearance', 'appearance', 'face',
            'face'
        ]
        self.networks_model = [
            self.gmm_model, self.generator_parsing, self.discriminator_parsing,
            self.generator_appearance, self.discriminator_appearance,
            self.generator_face, self.discriminator_face
        ]

        self.networks = dict(zip(self.networks_name, self.networks_model))

        self.resume_path = [
            resume_gmm, resume_G_parse, resume_D_parse, resume_G_appearance,
            resume_D_appearance, resume_G_face, resume_D_face
        ]
        for network, resume in zip(self.networks_model, self.resume_path):
            if network != [] and resume != '':
                assert (osp.exists(resume), 'the resume not exits')
                print('loading...')
                self.load_network(network, resume, ifprint=False)

        # define optimizer
        self.optimizer_gmm = torch.optim.Adam(self.gmm_model.parameters(),
                                              lr=opt.lr,
                                              betas=(0.5, 0.999))

        self.optimizer_parsing_G = torch.optim.Adam(
            self.generator_parsing.parameters(),
            lr=opt.lr,
            betas=[opt.beta1, 0.999])
        self.optimizer_parsing_D = torch.optim.Adam(
            self.discriminator_parsing.parameters(),
            lr=opt.lr,
            betas=[opt.beta1, 0.999])

        self.optimizer_appearance_G = torch.optim.Adam(
            self.generator_appearance.parameters(),
            lr=opt.lr,
            betas=[opt.beta1, 0.999])
        self.optimizer_appearance_D = torch.optim.Adam(
            self.discriminator_appearance.parameters(),
            lr=opt.lr,
            betas=[opt.beta1, 0.999])

        self.optimizer_face_G = torch.optim.Adam(
            self.generator_face.parameters(),
            lr=opt.lr,
            betas=[opt.beta1, 0.999])
        self.optimizer_face_D = torch.optim.Adam(
            self.discriminator_face.parameters(),
            lr=opt.lr,
            betas=[opt.beta1, 0.999])

        if opt.train_mode == 'gmm':
            self.optimizer_G = self.optimizer_gmm

        elif opt.joint_all:
            self.optimizer_G = [
                self.optimizer_parsing_G, self.optimizer_appearance_G,
                self.optimizer_face_G
            ]
            setattr(self, 'optimizer_D',
                    getattr(self, 'optimizer_' + self.train_mode + '_D'))

        else:
            setattr(self, 'optimizer_G',
                    getattr(self, 'optimizer_' + self.train_mode + '_G'))
            setattr(self, 'optimizer_D',
                    getattr(self, 'optimizer_' + self.train_mode + '_D'))

        self.t1 = time()

    def set_input(self, opt, result):

        self.t2 = time()

        self.source_pose_embedding = result['source_pose_embedding'].float(
        ).cuda()
        self.target_pose_embedding = result['target_pose_embedding'].float(
        ).cuda()
        self.source_image = result['source_image'].float().cuda()
        self.target_image = result['target_image'].float().cuda()
        self.source_parse = result['source_parse'].float().cuda()
        self.target_parse = result['target_parse'].float().cuda()
        self.cloth_image = result['cloth_image'].float().cuda()
        self.cloth_parse = result['cloth_parse'].float().cuda()
        self.warped_cloth = result['warped_cloth_image'].float().cuda(
        )  # preprocess warped image from gmm model
        self.target_parse_cloth = result['target_parse_cloth'].float().cuda()
        self.target_pose_img = result['target_pose_img'].float().cuda()
        self.image_without_cloth = create_part(self.source_image,
                                               self.source_parse,
                                               'image_without_cloth', False)

        self.im_c = result['im_c'].float().cuda()  # target warped cloth

        index = [x for x in list(range(20)) if x != 5 and x != 6 and x != 7]
        real_s_ = torch.index_select(self.source_parse, 1,
                                     torch.tensor(index).cuda())
        self.input_parsing = torch.cat(
            (real_s_, self.target_pose_embedding, self.cloth_parse), 1).cuda()

        if opt.train_mode == 'gmm':
            self.im_h = result['im_h'].float().cuda()
            self.source_parse_shape = result['source_parse_shape'].float(
            ).cuda()
            self.agnostic = torch.cat((self.source_parse_shape, self.im_h,
                                       self.target_pose_embedding),
                                      dim=1)

        elif opt.train_mode == 'parsing':
            self.real_s = self.input_parsing
            self.source_parse_vis = result['source_parse_vis'].float().cuda()
            self.target_parse_vis = result['target_parse_vis'].float().cuda()

        elif opt.train_mode == 'appearance':

            if opt.joint_all:
                self.generated_parsing = F.softmax(
                    self.generator_parsing(self.input_parsing), 1)
            else:
                with torch.no_grad():
                    self.generated_parsing = F.softmax(
                        self.generator_parsing(self.input_parsing), 1)
            self.input_appearance = torch.cat(
                (self.image_without_cloth, self.warped_cloth,
                 self.generated_parsing), 1).cuda()

            "attention please"
            generated_parsing_ = torch.argmax(self.generated_parsing,
                                              1,
                                              keepdim=True)
            self.generated_parsing_argmax = torch.Tensor()

            for _ in range(20):
                self.generated_parsing_argmax = torch.cat([
                    self.generated_parsing_argmax.float().cuda(),
                    (generated_parsing_ == _).float()
                ],
                                                          dim=1)
            self.warped_cloth_parse = (
                (generated_parsing_ == 5) + (generated_parsing_ == 6) +
                (generated_parsing_ == 7)).float().cuda()

            if opt.save_time:
                self.generated_parsing_vis = torch.Tensor([0]).expand_as(
                    self.target_image)
            else:
                # decode labels cost much time
                _generated_parsing = torch.argmax(self.generated_parsing,
                                                  1,
                                                  keepdim=True)
                _generated_parsing = _generated_parsing.permute(
                    0, 2, 3, 1).contiguous().int()
                self.generated_parsing_vis = pose_utils.decode_labels(
                    _generated_parsing)  #array

            self.real_s = self.source_image

        elif opt.train_mode == 'face':
            if opt.joint_all:  # opt.joint
                generated_parsing = F.softmax(
                    self.generator_parsing(self.input_parsing), 1)
                self.generated_parsing_face = F.softmax(
                    self.generator_parsing(self.input_parsing), 1)
            else:
                generated_parsing = F.softmax(
                    self.generator_parsing(self.input_parsing), 1)

                "attention please"
                generated_parsing_ = torch.argmax(generated_parsing,
                                                  1,
                                                  keepdim=True)
                self.generated_parsing_argmax = torch.Tensor()

                for _ in range(20):
                    self.generated_parsing_argmax = torch.cat([
                        self.generated_parsing_argmax.float().cuda(),
                        (generated_parsing_ == _).float()
                    ],
                                                              dim=1)

                # self.generated_parsing_face = generated_parsing_c
                self.generated_parsing_face = self.target_parse

            self.input_appearance = torch.cat(
                (self.image_without_cloth, self.warped_cloth,
                 generated_parsing), 1).cuda()

            with torch.no_grad():
                self.generated_inter = self.generator_appearance(
                    self.input_appearance)
                p_rendered, m_composite = torch.split(self.generated_inter, 3,
                                                      1)
                p_rendered = F.tanh(p_rendered)
                m_composite = F.sigmoid(m_composite)
                self.generated_image = self.warped_cloth * m_composite + p_rendered * (
                    1 - m_composite)

            self.source_face = create_part(self.source_image,
                                           self.source_parse, 'face', False)
            self.target_face_real = create_part(self.target_image,
                                                self.generated_parsing_face,
                                                'face', False)
            self.target_face_fake = create_part(self.generated_image,
                                                self.generated_parsing_face,
                                                'face', False)
            self.generated_image_without_face = self.generated_image - self.target_face_fake

            self.input_face = torch.cat(
                (self.source_face, self.target_face_fake), 1).cuda()
            self.real_s = self.source_face

        elif opt.train_mode == 'joint':
            self.input_joint = torch.cat(
                (self.image_without_cloth, self.warped_cloth,
                 self.generated_parsing), 1).cuda()

        self.t3 = time()

        # setattr(self, 'input', getattr(self, 'input_' + self.train_mode))

    def forward(self, opt):
        self.t4 = time()

        if self.train_mode == 'gmm':
            self.grid, self.theta = self.gmm_model(self.agnostic,
                                                   self.cloth_image)
            self.warped_cloth_predict = F.grid_sample(self.cloth_image,
                                                      self.grid)

        if opt.train_mode == 'parsing':
            self.fake_t = F.softmax(self.generator_parsing(self.input_parsing),
                                    dim=1)
            self.real_t = self.target_parse

        if opt.train_mode == 'appearance':
            generated_inter = self.generator_appearance(self.input_appearance)
            p_rendered, m_composite = torch.split(generated_inter, 3, 1)
            p_rendered = F.tanh(p_rendered)
            self.m_composite = F.sigmoid(m_composite)
            p_tryon = self.warped_cloth * self.m_composite + p_rendered * (
                1 - self.m_composite)
            self.fake_t = p_tryon
            self.real_t = self.target_image

            if opt.joint_all:

                generate_face = create_part(self.fake_t,
                                            self.generated_parsing_argmax,
                                            'face', False)
                generate_image_without_face = self.fake_t - generate_face

                real_s_face = create_part(self.source_image, self.source_parse,
                                          'face', False)
                real_t_face = create_part(self.target_image,
                                          self.generated_parsing_argmax,
                                          'face', False)
                input = torch.cat((real_s_face, generate_face), dim=1)

                fake_t_face = self.generator_face(input)
                ###residual learning
                r"""attention
                """
                # fake_t_face = create_part(fake_t_face, self.generated_parsing, 'face', False)
                # fake_t_face = generate_face + fake_t_face
                fake_t_face = create_part(fake_t_face,
                                          self.generated_parsing_argmax,
                                          'face', False)
                ### fake image
                self.fake_t = generate_image_without_face + fake_t_face

        if opt.train_mode == 'face':
            self.fake_t = self.generator_face(self.input_face)

            if opt.face_residual:
                self.fake_t = create_part(self.fake_t,
                                          self.generated_parsing_face, 'face',
                                          False)
                self.fake_t = self.target_face_fake + self.fake_t

            self.fake_t = create_part(self.fake_t, self.generated_parsing_face,
                                      'face', False)
            self.refined_image = self.generated_image_without_face + self.fake_t
            self.real_t = create_part(self.target_image,
                                      self.generated_parsing_face, 'face',
                                      False)

        self.t5 = time()

    def backward_G(self, opt):
        self.t6 = time()

        if opt.train_mode == 'gmm':
            self.loss = self.criterionL1(self.warped_cloth_predict, self.im_c)
            self.loss.backward()
            self.t7 = time()
            return

        fake_st = torch.cat((self.real_s, self.fake_t), 1)
        pred_fake = self.discriminator(fake_st)

        if opt.train_mode == 'parsing':
            self.loss_G_GAN = self.criterionGAN(pred_fake, True)
            self.loss_G_BCE = self.criterionBCE_re(self.fake_t,
                                                   self.real_t) * opt.lambda_L1

            self.loss_G = self.loss_G_GAN + self.loss_G_BCE
            self.loss_G.backward()

        if opt.train_mode == 'appearance':
            self.loss_G_GAN = self.criterionGAN(pred_fake, True) * opt.G_GAN
            # vgg_loss
            loss_vgg1, _ = self.criterion_vgg(self.fake_t, self.real_t,
                                              self.target_parse, False, True,
                                              False)
            loss_vgg2, _ = self.criterion_vgg(self.fake_t, self.real_t,
                                              self.target_parse, False, False,
                                              False)
            self.loss_G_vgg = (loss_vgg1 + loss_vgg2) * opt.G_VGG
            self.loss_G_mask = self.criterionL1(
                self.m_composite, self.warped_cloth_parse) * opt.mask
            if opt.mask_tvloss:
                self.loss_G_mask_tv = self.criterion_tv(self.m_composite)
            else:
                self.loss_G_mask_tv = torch.Tensor([0]).cuda()
            self.loss_G_L1 = self.criterion_smooth_L1(
                self.fake_t, self.real_t) * opt.lambda_L1

            if opt.joint_all and opt.joint_parse_loss:
                self.loss_G_parsing = self.criterionBCE_re(
                    self.generated_parsing,
                    self.target_parse) * opt.joint_G_parsing
                self.loss_G = self.loss_G_GAN + self.loss_G_L1 + self.loss_G_vgg + self.loss_G_mask + self.loss_G_parsing
            else:
                self.loss_G = self.loss_G_GAN + self.loss_G_L1 + self.loss_G_vgg + self.loss_G_mask + self.loss_G_mask_tv
            self.loss_G.backward()

        if opt.train_mode == 'face':
            _, self.loss_G_vgg = self.criterion_vgg(
                self.fake_t, self.real_t, self.generated_parsing_face, False,
                False, False)  # part, gram, neareast
            self.loss_G_vgg = self.loss_G_vgg * opt.face_vgg
            self.loss_G_L1 = self.criterionL1(self.fake_t,
                                              self.real_t) * opt.face_L1
            self.loss_G_GAN = self.criterionGAN(pred_fake, True) * opt.face_gan
            self.loss_G_refine = self.criterionL1(
                self.refined_image, self.target_image) * opt.face_img_L1

            self.loss_G = self.loss_G_vgg + self.loss_G_L1 + self.loss_G_GAN + self.loss_G_refine
            self.loss_G.backward()

        self.t7 = time()

    def backward_D(self, opt):
        self.t8 = time()

        fake_st = torch.cat((self.real_s, self.fake_t), 1)
        real_st = torch.cat((self.real_s, self.real_t), 1)
        pred_fake = self.discriminator(fake_st.detach())
        pred_real = self.discriminator(real_st)  # batch_size,1, 30,30

        self.loss_D_fake = self.criterionGAN(pred_fake, False)
        self.loss_D_real = self.criterionGAN(pred_real, True)
        self.loss_D = (self.loss_D_real + self.loss_D_fake) * 0.5

        self.loss_D.backward()

        self.t9 = time()

    def optimize_parameters(self, opt):

        self.t10 = time()
        self.forward(opt)  # compute fake images: G(A)

        if opt.train_mode == 'gmm':
            self.optimizer_G.zero_grad()  # set G's gradients to zero
            self.backward_G(opt)  # calculate graidents for G
            self.optimizer_G.step()  # udpate G's weights
            self.t11 = time()
            return

        # update D
        self.set_requires_grad(self.discriminator,
                               True)  # enable backprop for D
        self.optimizer_D.zero_grad()  # set D's gradients to zero
        self.backward_D(opt)  # calculate gradients for D
        self.optimizer_D.step()  # update D's weights
        # update G
        self.set_requires_grad(
            self.discriminator,
            False)  # D requires no gradients when optimizing G
        if opt.joint_all:
            for _ in self.optimizer_G:
                _.zero_grad()

            self.backward_G(opt)

            for _ in self.optimizer_G:
                _.step()
        else:
            self.optimizer_G.zero_grad()  # set G's gradients to zero
            self.backward_G(opt)  # calculate graidents for G
            self.optimizer_G.step()  # udpate G's weights

        self.t11 = time()

    def save_result(self, opt, epoch, iteration):
        if opt.train_mode == 'gmm':
            images = [self.cloth_image, self.warped_cloth.detach(), self.im_c]

        if opt.train_mode == 'parsing':
            fake_t_vis = pose_utils.decode_labels(
                torch.argmax(self.fake_t, dim=1,
                             keepdim=True).permute(0, 2, 3, 1).contiguous())
            images = [
                self.source_parse_vis, self.target_parse_vis,
                self.target_pose_img, self.cloth_parse, fake_t_vis
            ]

        if opt.train_mode == 'appearance':
            images = [
                self.image_without_cloth, self.warped_cloth,
                self.warped_cloth_parse, self.target_image, self.cloth_image,
                self.generated_parsing_vis,
                self.fake_t.detach()
            ]

        if opt.train_mode == 'face':
            images = [
                self.generated_image.detach(),
                self.refined_image.detach(), self.source_image,
                self.target_image, self.real_t,
                self.fake_t.detach()
            ]

        pose_utils.save_img(
            images,
            os.path.join(self.vis_path,
                         str(epoch) + '_' + str(iteration) + '.jpg'))

    def save_model(self, opt, epoch):
        if opt.train_mode == 'gmm':
            model_G = osp.join(
                self.save_dir, 'generator',
                'checkpoint_G_epoch_%d_loss_%0.5f_pth.tar' %
                (epoch, self.loss_G))

            if not osp.exists(osp.join(self.save_dir, 'generator')):
                os.makedirs(osp.join(self.save_dir, 'generator'))

            torch.save(self.generator.state_dict(), model_G)

        elif not opt.joint_all:
            model_G = osp.join(
                self.save_dir, 'generator',
                'checkpoint_G_epoch_%d_loss_%0.5f_pth.tar' %
                (epoch, self.loss_G))
            model_D = osp.join(
                self.save_dir, 'dicriminator',
                'checkpoint_D_epoch_%d_loss_%0.5f_pth.tar' %
                (epoch, self.loss_D))
            if not osp.exists(osp.join(self.save_dir, 'generator')):
                os.makedirs(osp.join(self.save_dir, 'generator'))
            if not osp.exists(osp.join(self.save_dir, 'dicriminator')):
                os.makedirs(osp.join(self.save_dir, 'dicriminator'))

            torch.save(self.generator.state_dict(), model_G)
            torch.save(self.discriminator.state_dict(), model_D)
        else:
            model_G_parsing = osp.join(
                self.save_dir, 'generator_parsing',
                'checkpoint_G_epoch_%d_loss_%0.5f_pth.tar' %
                (epoch, self.loss_G))
            model_D_parsing = osp.join(
                self.save_dir, 'dicriminator_parsing',
                'checkpoint_D_epoch_%d_loss_%0.5f_pth.tar' %
                (epoch, self.loss_D))

            model_G_appearance = osp.join(
                self.save_dir, 'generator_appearance',
                'checkpoint_G_epoch_%d_loss_%0.5f_pth.tar' %
                (epoch, self.loss_G))
            model_D_appearance = osp.join(
                self.save_dir, 'dicriminator_appearance',
                'checkpoint_D_epoch_%d_loss_%0.5f_pth.tar' %
                (epoch, self.loss_D))

            model_G_face = osp.join(
                self.save_dir, 'generator_face',
                'checkpoint_G_epoch_%d_loss_%0.5f_pth.tar' %
                (epoch, self.loss_G))
            model_D_face = osp.join(
                self.save_dir, 'dicriminator_face',
                'checkpoint_D_epoch_%d_loss_%0.5f_pth.tar' %
                (epoch, self.loss_D))

            joint_save_dirs = [
                osp.join(self.save_dir, 'generator_parsing'),
                osp.join(self.save_dir, 'dicriminator_parsing'),
                osp.join(self.save_dir, 'generator_appearance'),
                osp.join(self.save_dir, 'dicriminator_appearance'),
                osp.join(self.save_dir, 'generator_face'),
                osp.join(self.save_dir, 'dicriminator_face')
            ]
            for _ in joint_save_dirs:
                if not osp.exists(_):
                    os.makedirs(_)
            torch.save(self.generator_parsing.state_dict(), model_G_parsing)
            torch.save(self.generator_appearance.state_dict(),
                       model_G_appearance)
            torch.save(self.generator_face.state_dict(), model_G_face)
            torch.save(self.discriminator_appearance.state_dict(),
                       model_D_appearance)

    def print_current_errors(self, opt, epoch, i):
        if opt.train_mode == 'gmm':
            errors = {'loss_L1': self.loss.item()}

        if opt.train_mode == 'appearance':
            errors = {
                'loss_G': self.loss_G.item(),
                'loss_G_GAN': self.loss_G_GAN.item(),
                'loss_G_vgg': self.loss_G_vgg.item(),
                'loss_G_mask': self.loss_G_mask.item(),
                'loss_G_L1': self.loss_G_L1.item(),
                'loss_D': self.loss_D.item(),
                'loss_D_real': self.loss_D_real.item(),
                'loss_D_fake': self.loss_D_real.item(),
                'loss_G_mask_tv': self.loss_G_mask_tv.item()
            }

            if opt.joint_all and opt.joint_parse_loss:
                errors = {
                    'loss_G': self.loss_G.item(),
                    'loss_G_GAN': self.loss_G_GAN.item(),
                    'loss_G_vgg': self.loss_G_vgg.item(),
                    'loss_G_mask': self.loss_G_mask.item(),
                    'loss_G_L1': self.loss_G_L1.item(),
                    'loss_D': self.loss_D.item(),
                    'loss_D_real': self.loss_D_real.item(),
                    'loss_D_fake': self.loss_D_real.item(),
                    'loss_G_parsing': self.loss_G_parsing.item()
                }

        if opt.train_mode == 'parsing':
            errors = {
                'loss_G': self.loss_G.item(),
                'loss_G_GAN': self.loss_G_GAN.item(),
                'loss_G_BCE': self.loss_G_BCE.item(),
                'loss_D': self.loss_D.item(),
                'loss_D_real': self.loss_D_real.item(),
                'loss_D_fake': self.loss_D_real.item()
            }

        if opt.train_mode == 'face':
            errors = {
                'loss_G': self.loss_G.item(),
                'loss_G_GAN': self.loss_G_GAN.item(),
                'loss_G_vgg': self.loss_G_vgg.item(),
                'loss_G_refine': self.loss_G_refine.item(),
                'loss_G_L1': self.loss_G_L1.item(),
                'loss_D': self.loss_D.item(),
                'loss_D_real': self.loss_D_real.item(),
                'loss_D_fake': self.loss_D_real.item()
            }

        t = self.t11 - self.t2
        message = '(epoch: %d, iters: %d, time: %.3f) ' % (epoch, i, t)
        for k, v in sorted(errors.items()):
            if v != 0:
                message += '%s: %.3f ' % (k, v)

        print(message)
        with open(self.log_name, "a") as log_file:
            log_file.write('%s\n' % message)
class GenerationModel(BaseModel):
    # Returns the name of the Network
    def name(self):
        return 'Generation model: pix2pix | pix2pixHD'

    # Init Function
    def __init__(self, opt):
        self.t0 = time()

        BaseModel.__init__(self, opt)
        self.train_mode = opt.train_mode

        # Resume of networks
        resume_gmm = opt.resume_gmm
        resume_G_parse = opt.resume_G_parse
        resume_D_parse = opt.resume_D_parse
        resume_G_appearance = opt.resume_G_app
        resume_D_appearance = opt.resume_D_app
        resume_G_face = opt.resume_G_face
        resume_D_face = opt.resume_D_face

        # Define network
        self.gmm_model = torch.nn.DataParallel(GMM(opt)).cuda()
        self.generator_parsing = Define_G(opt.input_nc_G_parsing,
                                          opt.output_nc_parsing, opt.ndf,
                                          opt.netG_parsing, opt.norm,
                                          not opt.no_dropout, opt.init_type,
                                          opt.init_gain, opt.gpu_ids)
        self.discriminator_parsing = Define_D(opt.input_nc_D_parsing, opt.ndf,
                                              opt.netD_parsing, opt.n_layers_D,
                                              opt.norm, opt.init_type,
                                              opt.init_gain, opt.gpu_ids)
        self.generator_appearance = Define_G(opt.input_nc_G_app,
                                             opt.output_nc_app,
                                             opt.ndf,
                                             opt.netG_app,
                                             opt.norm,
                                             not opt.no_dropout,
                                             opt.init_type,
                                             opt.init_gain,
                                             opt.gpu_ids,
                                             with_tanh=False)
        self.discriminator_appearance = Define_D(opt.input_nc_D_app, opt.ndf,
                                                 opt.netD_app, opt.n_layers_D,
                                                 opt.norm, opt.init_type,
                                                 opt.init_gain, opt.gpu_ids)
        self.generator_face = Define_G(opt.input_nc_D_face, opt.output_nc_face,
                                       opt.ndf, opt.netG_face, opt.norm,
                                       not opt.no_dropout, opt.init_type,
                                       opt.init_gain, opt.gpu_ids)
        self.discriminator_face = Define_D(opt.input_nc_D_face, opt.ndf,
                                           opt.netD_face, opt.n_layers_D,
                                           opt.norm, opt.init_type,
                                           opt.init_gain, opt.gpu_ids)

        ######################################
        # HELPER COMMENTS
        # when we train each network seperately we use generator and discriminator according to train_mode
        # when we train jointly we train appearance network generator and discriminator
        # when we train jointly parsing/ face generator network gets trained only
        # self.generator/discriminator optimizer also follows the same strategy and  comes from train mode variable
        ######################################
        print("Train Mode is --- ", opt.train_mode)

        if opt.train_mode == 'gmm':
            setattr(self, 'generator', self.gmm_model)
        else:
            setattr(self, 'generator',
                    getattr(self, 'generator_' + self.train_mode))
            setattr(self, 'discriminator',
                    getattr(self, 'discriminator_' + self.train_mode))

        ######################################
        # Load Networks
        ######################################
        self.networks_name = [
            'gmm', 'G_parsing', 'D_parsing', 'G_appearance', 'D_appearance',
            'G_face', 'D_face'
        ]
        self.networks_model = [
            self.gmm_model, self.generator_parsing, self.discriminator_parsing,
            self.generator_appearance, self.discriminator_appearance,
            self.generator_face, self.discriminator_face
        ]
        self.networks = dict(zip(self.networks_name, self.networks_model))
        self.resume_path = [
            resume_gmm, resume_G_parse, resume_D_parse, resume_G_appearance,
            resume_D_appearance, resume_G_face, resume_D_face
        ]

        for index, (network, resume) in enumerate(
                zip(self.networks_model, self.resume_path)):
            if osp.exists(resume):
                assert osp.exists(resume), 'the resume not exits'
                print('loading...{}'.format(self.networks_name[index]))
                self.load_network(network, resume, ifprint=False)

        ######################################
        # HELPER COMMENTS
        # optimizer_G/optimizer_D gets set according to train mode
        # for joint training optimizer_G is parsing,appearance,face network and optimizer_D is the appearance network
        ######################################

        self.optimizer_gmm = torch.optim.Adam(self.gmm_model.parameters(),
                                              lr=opt.lr,
                                              betas=(0.5, 0.999))

        self.optimizer_parsing_G = torch.optim.Adam(
            self.generator_parsing.parameters(),
            lr=opt.lr,
            betas=[opt.beta1, 0.999])
        self.optimizer_parsing_D = torch.optim.Adam(
            self.discriminator_parsing.parameters(),
            lr=opt.lr,
            betas=[opt.beta1, 0.999])

        self.optimizer_appearance_G = torch.optim.Adam(
            self.generator_appearance.parameters(),
            lr=opt.lr,
            betas=[opt.beta1, 0.999])
        self.optimizer_appearance_D = torch.optim.Adam(
            self.discriminator_appearance.parameters(),
            lr=opt.lr,
            betas=[opt.beta1, 0.999])

        self.optimizer_face_G = torch.optim.Adam(
            self.generator_face.parameters(),
            lr=opt.lr,
            betas=[opt.beta1, 0.999])
        self.optimizer_face_D = torch.optim.Adam(
            self.discriminator_face.parameters(),
            lr=opt.lr,
            betas=[opt.beta1, 0.999])

        if opt.train_mode == 'gmm':
            self.optimizer_G = self.optimizer_gmm

        elif opt.joint_all:
            self.optimizer_G = [
                self.optimizer_parsing_G, self.optimizer_appearance_G,
                self.optimizer_face_G
            ]
            setattr(self, 'optimizer_D',
                    getattr(self, 'optimizer_' + self.train_mode + '_D'))

        else:
            setattr(self, 'optimizer_G',
                    getattr(self, 'optimizer_' + self.train_mode + '_G'))
            setattr(self, 'optimizer_D',
                    getattr(self, 'optimizer_' + self.train_mode + '_D'))

        # Tensorboard

        if opt.train_mode == 'gmm':
            if not osp.exists(osp.join(self.save_dir, 'GMM_tboard')):
                os.makedirs(osp.join(self.save_dir, 'GMM_tboard'))
            self.writer = SummaryWriter(osp.join(self.save_dir, 'GMM_tboard'))

        elif opt.train_mode == 'parsing':
            if not osp.exists(osp.join(self.save_dir, 'parsing_tboard')):
                os.makedirs(osp.join(self.save_dir, 'parsing_tboard'))
            self.writer = SummaryWriter(
                osp.join(self.save_dir, 'parsing_tboard'))

        elif opt.train_mode == 'appearance':

            if opt.joint_all:
                if not osp.exists(osp.join(self.save_dir, 'joint_tboard')):
                    os.makedirs(osp.join(self.save_dir, 'joint_tboard'))
                self.writer = SummaryWriter(
                    osp.join(self.save_dir, 'joint_tboard'))
            else:
                if not osp.exists(osp.join(self.save_dir,
                                           'appearance_tboard')):
                    os.makedirs(osp.join(self.save_dir, 'appearance_tboard'))
                self.writer = SummaryWriter(
                    osp.join(self.save_dir, 'appearance_tboard'))

        elif opt.train_mode == 'face':
            if not osp.exists(osp.join(self.save_dir, 'face_tboard')):
                os.makedirs(osp.join(self.save_dir, 'face_tboard'))
            self.writer = SummaryWriter(osp.join(self.save_dir, 'face_tboard'))

        self.t1 = time()

    # Set the inputs according to the models
    def set_input(self, opt, result):

        self.t2 = time()

        # Input data returned by dataloader
        self.source_pose_embedding = result['source_pose_embedding'].float(
        ).cuda()
        self.target_pose_embedding = result['target_pose_embedding'].float(
        ).cuda()
        self.source_densepose_data = result['source_densepose_data'].float(
        ).cuda()
        self.target_densepose_data = result['target_densepose_data'].float(
        ).cuda()
        self.source_image = result['source_image'].float().cuda()
        self.target_image = result['target_image'].float().cuda()
        self.source_parse = result['source_parse'].float().cuda()
        self.target_parse = result['target_parse'].float().cuda()
        self.cloth_image = result['cloth_image'].float().cuda()
        self.cloth_parse = result['cloth_parse'].float().cuda()
        # self.warped_cloth = result['warped_cloth_image'].float().cuda() # preprocess warped image from gmm model
        self.target_parse_cloth = result['target_parse_cloth'].float().cuda()
        self.target_pose_img = result['target_pose_img']
        self.image_without_cloth = create_part(self.source_image,
                                               self.source_parse,
                                               'image_without_cloth', False)
        self.im_c = result['im_c'].float().cuda()  # target warped cloth

        # input_parsing input to the parsing transformation network
        index = [x for x in list(range(20)) if x != 5 and x != 6 and x != 7]
        real_s_ = torch.index_select(self.source_parse, 1,
                                     torch.tensor(index).cuda())
        self.input_parsing = torch.cat(
            (real_s_, self.target_densepose_data, self.cloth_parse), 1).cuda()

        if opt.train_mode != 'parsing' and opt.train_mode != 'gmm':
            self.warped_cloth = warped_image(self.gmm_model, result)

        ######################################
        # Part 1 GMM
        ######################################
        # For GMM training we need agnostic cloth_represent(source_head, densepose) original_cloth (from dataloader)
        if opt.train_mode == 'gmm':
            self.im_h = result['im_h'].float().cuda()
            self.source_parse_shape = result['source_parse_shape'].float(
            ).cuda()
            self.agnostic = torch.cat((self.source_parse_shape, self.im_h,
                                       self.target_pose_embedding),
                                      dim=1)

        ######################################
        # Part 2 PARSING
        ######################################
        # For parsing training
        # Input  input_parsing
        # output is the target parse
        elif opt.train_mode == 'parsing':
            self.real_s = self.input_parsing
            self.source_parse_vis = result['source_parse_vis'].float().cuda()
            self.target_parse_vis = result['target_parse_vis'].float().cuda()

        ######################################
        # Part 3 APPEARANCE
        ######################################
        # For appearance training
        # Input generated parse + warped_cloth + generated_parsing
        # Output corse render image(compare with target image) and composition mask (compare with warped_cloth_parse(this is generated from parsing network))

        elif opt.train_mode == 'appearance':

            # If join all training then train flow gradients else don't flow
            if opt.joint_all:
                self.generated_parsing = F.softmax(
                    self.generator_parsing(self.input_parsing), 1)
            else:
                with torch.no_grad():
                    self.generated_parsing = F.softmax(
                        self.generator_parsing(self.input_parsing), 1)

            # Input to the generated appearance network
            self.input_appearance = torch.cat(
                (self.image_without_cloth, self.warped_cloth,
                 self.generated_parsing), 1).cuda()
            "attention please"
            generated_parsing_ = torch.argmax(self.generated_parsing,
                                              1,
                                              keepdim=True)

            # input to the generator appearance
            self.generated_parsing_argmax = torch.Tensor()

            # create the warped_cloth_parse from the parsing network
            for _ in range(20):
                self.generated_parsing_argmax = torch.cat([
                    self.generated_parsing_argmax.float().cuda(),
                    (generated_parsing_ == _).float()
                ],
                                                          dim=1)
            self.warped_cloth_parse = (
                (generated_parsing_ == 5) + (generated_parsing_ == 6) +
                (generated_parsing_ == 7)).float().cuda()

            # For visualization
            if opt.save_time:
                self.generated_parsing_vis = torch.Tensor([0]).expand_as(
                    self.target_image)
            else:
                # decode labels cost much time
                _generated_parsing = torch.argmax(self.generated_parsing,
                                                  1,
                                                  keepdim=True)
                _generated_parsing = _generated_parsing.permute(
                    0, 2, 3, 1).contiguous().int()
                self.generated_parsing_vis = pose_utils.decode_labels(
                    _generated_parsing)  # array

            # For gan training
            self.real_s = self.source_image

        ######################################
        # Part 4 FACE
        ######################################

        elif opt.train_mode == 'face':
            if opt.joint_all:
                generated_parsing = F.softmax(
                    self.generator_parsing(self.input_parsing), 1)
                self.generated_parsing_face = F.softmax(
                    self.generator_parsing(self.input_parsing), 1)
            else:
                generated_parsing = F.softmax(
                    self.generator_parsing(self.input_parsing), 1)

                "attention please"
                generated_parsing_ = torch.argmax(generated_parsing,
                                                  1,
                                                  keepdim=True)
                self.generated_parsing_argmax = torch.Tensor()

                for _ in range(20):
                    self.generated_parsing_argmax = torch.cat([
                        self.generated_parsing_argmax.float().cuda(),
                        (generated_parsing_ == _).float()
                    ],
                                                              dim=1)

                # self.generated_parsing_face = generated_parsing_c
                self.generated_parsing_face = self.target_parse

            self.input_appearance = torch.cat(
                (self.image_without_cloth, self.warped_cloth,
                 generated_parsing), 1).cuda()

            with torch.no_grad():
                self.generated_inter = self.generator_appearance(
                    self.input_appearance)
                p_rendered, m_composite = torch.split(self.generated_inter, 3,
                                                      1)
                p_rendered = F.tanh(p_rendered)
                m_composite = F.sigmoid(m_composite)
                self.generated_image = self.warped_cloth * \
                    m_composite + p_rendered * (1 - m_composite)

            self.source_face = create_part(self.source_image,
                                           self.source_parse, 'face', False)
            self.target_face_real = create_part(self.target_image,
                                                self.generated_parsing_face,
                                                'face', False)
            self.target_face_fake = create_part(self.generated_image,
                                                self.generated_parsing_face,
                                                'face', False)
            self.generated_image_without_face = self.generated_image - self.target_face_fake

            self.input_face = torch.cat(
                (self.source_face, self.target_face_fake), 1).cuda()
            self.real_s = self.source_face

        self.t3 = time()

    # All Forward operations of the networks
    def forward(self, opt):
        self.t4 = time()

        ######################################
        # Part 1 GMM Forward
        ######################################
        if self.train_mode == 'gmm':
            self.grid, self.theta = self.gmm_model(self.agnostic,
                                                   self.cloth_image)
            self.warped_cloth_predict = F.grid_sample(self.cloth_image,
                                                      self.grid)

        ######################################
        # Part 2 PARSING Forward
        ######################################
        if opt.train_mode == 'parsing':
            self.fake_t = F.softmax(self.generator_parsing(self.input_parsing),
                                    dim=1)
            self.real_t = self.target_parse

        ######################################
        # Part 3 APPEARANCE Forward
        ######################################
        if opt.train_mode == 'appearance':
            generated_inter = self.generator_appearance(self.input_appearance)
            p_rendered, m_composite = torch.split(generated_inter, 3, 1)
            p_rendered = F.tanh(p_rendered)
            self.m_composite = F.sigmoid(m_composite)
            p_tryon = self.warped_cloth * self.m_composite + \
                p_rendered * (1 - self.m_composite)
            self.fake_t = p_tryon
            self.real_t = self.target_image

            if opt.joint_all:

                generate_face = create_part(self.fake_t,
                                            self.generated_parsing_argmax,
                                            'face', False)
                generate_image_without_face = self.fake_t - generate_face

                real_s_face = create_part(self.source_image, self.source_parse,
                                          'face', False)
                real_t_face = create_part(self.target_image,
                                          self.generated_parsing_argmax,
                                          'face', False)
                input = torch.cat((real_s_face, generate_face), dim=1)

                fake_t_face = self.generator_face(input)
                # residual learning
                r"""attention
                """
                # fake_t_face = create_part(fake_t_face, self.generated_parsing, 'face', False)
                # fake_t_face = generate_face + fake_t_face
                fake_t_face = create_part(fake_t_face,
                                          self.generated_parsing_argmax,
                                          'face', False)
                # fake image
                self.fake_t = generate_image_without_face + fake_t_face

        ######################################
        # Part 4 FACE Forward
        ######################################
        if opt.train_mode == 'face':
            self.fake_t = self.generator_face(self.input_face)

            if opt.face_residual:
                self.fake_t = create_part(self.fake_t,
                                          self.generated_parsing_face, 'face',
                                          False)
                self.fake_t = self.target_face_fake + self.fake_t

            self.fake_t = create_part(self.fake_t, self.generated_parsing_face,
                                      'face', False)
            self.refined_image = self.generated_image_without_face + self.fake_t
            self.real_t = create_part(self.target_image,
                                      self.generated_parsing_face, 'face',
                                      False)

        self.t5 = time()

    # All back propagation and loss operations of the GMM and Generator networks
    def backward_G(self, opt):
        self.t6 = time()

        ######################################
        # Part 1 GMM Loss Function
        ######################################

        if opt.train_mode == 'gmm':
            self.loss = self.criterionL1(self.warped_cloth_predict, self.im_c)
            self.loss.backward()
            self.t7 = time()
            return

        else:
            fake_st = torch.cat((self.real_s, self.fake_t), 1)
            pred_fake = self.discriminator(fake_st)

            ######################################
            # Part 2 PARSING Loss Function
            ######################################
            # gan loss + binary cross entropy loss

            if opt.train_mode == 'parsing':
                self.loss_G_GAN = self.criterionGAN(pred_fake, True)
                self.loss_G_BCE = self.criterionBCE_re(
                    self.fake_t, self.real_t) * opt.lambda_L1

                self.loss_G = self.loss_G_GAN + self.loss_G_BCE
                self.loss_G.backward()

            ######################################
            # Part 3 APPEARANCE Loss Function
            ######################################
            # loss l1 + loss gan + loss mask + loss vgg
            # loss l1 + loss gan + loss mask + loss vgg  + loss parsing if join_all training mode
            if opt.train_mode == 'appearance':
                self.loss_G_GAN = self.criterionGAN(pred_fake,
                                                    True) * opt.G_GAN
                # vgg_loss
                loss_vgg1, _ = self.criterion_vgg(self.fake_t, self.real_t,
                                                  self.target_parse, False,
                                                  True, False)
                loss_vgg2, _ = self.criterion_vgg(self.fake_t, self.real_t,
                                                  self.target_parse, False,
                                                  False, False)
                self.loss_G_vgg = (loss_vgg1 + loss_vgg2) * opt.G_VGG
                self.loss_G_mask = self.criterionL1(
                    self.m_composite, self.warped_cloth_parse) * opt.mask
                if opt.mask_tvloss:
                    self.loss_G_mask_tv = self.criterion_tv(self.m_composite)
                else:
                    self.loss_G_mask_tv = torch.Tensor([0]).cuda()
                self.loss_G_L1 = self.criterion_smooth_L1(
                    self.fake_t, self.real_t) * opt.lambda_L1

                if opt.joint_all and opt.joint_parse_loss:
                    self.loss_G_parsing = self.criterionBCE_re(
                        self.generated_parsing,
                        self.target_parse) * opt.joint_G_parsing
                    self.loss_G = self.loss_G_GAN + self.loss_G_L1 + \
                        self.loss_G_vgg + self.loss_G_mask + self.loss_G_parsing
                else:
                    self.loss_G = self.loss_G_GAN + self.loss_G_L1 + \
                        self.loss_G_vgg + self.loss_G_mask + self.loss_G_mask_tv
                self.loss_G.backward()

            ######################################
            # Part 4 FACE Loss Function
            ######################################
            if opt.train_mode == 'face':
                _, self.loss_G_vgg = self.criterion_vgg(
                    self.fake_t, self.real_t, self.generated_parsing_face,
                    False, False, False)  # part, gram, neareast
                self.loss_G_vgg = self.loss_G_vgg * opt.face_vgg
                self.loss_G_L1 = self.criterionL1(self.fake_t,
                                                  self.real_t) * opt.face_L1
                self.loss_G_GAN = self.criterionGAN(pred_fake,
                                                    True) * opt.face_gan
                self.loss_G_refine = self.criterionL1(
                    self.refined_image, self.target_image) * opt.face_img_L1

                self.loss_G = self.loss_G_vgg + self.loss_G_L1 + \
                    self.loss_G_GAN + self.loss_G_refine
                self.loss_G.backward()

        self.t7 = time()

# All back propagation and loss operations Discriminator networks

    def backward_D(self, opt):
        self.t8 = time()

        fake_st = torch.cat((self.real_s, self.fake_t), 1)
        real_st = torch.cat((self.real_s, self.real_t), 1)
        pred_fake = self.discriminator(fake_st.detach())
        pred_real = self.discriminator(real_st)

        self.loss_D_fake = self.criterionGAN(pred_fake, False)
        self.loss_D_real = self.criterionGAN(pred_real, True)
        self.loss_D = (self.loss_D_real + self.loss_D_fake) * 0.5

        self.loss_D.backward()

        self.t9 = time()

    # All optimizer operations of the networks
    def optimize_parameters(self, opt):

        self.t10 = time()

        # Forward Function
        self.forward(opt)

        ######################################
        # Part 1 GMM Optimizer
        ######################################

        if opt.train_mode == 'gmm':
            self.optimizer_G.zero_grad()  # set G's gradients to zero
            self.backward_G(opt)  # calculate graidents for G
            self.optimizer_G.step()  # udpate G's weights
            self.t11 = time()
            return

        ######################################
        # Part 2 PARSING/APPEARANCE/FACE Network Optimizer For Generator And Discriminator
        ######################################
        else:
            # Update the discriminator
            # enable backprop for D
            self.set_requires_grad(self.discriminator, True)
            self.optimizer_D.zero_grad()  # set D's gradients to zero
            # calculate gradients for D
            self.backward_D(opt)
            self.optimizer_D.step()  # update D's weights

            # update the generator
            # D requires no gradients when optimizing G
            self.set_requires_grad(self.discriminator, False)
            if opt.joint_all:
                for _ in self.optimizer_G:
                    _.zero_grad()

                self.backward_G(opt)

                for _ in self.optimizer_G:
                    _.step()
            else:
                self.optimizer_G.zero_grad()  # set G's gradients to zero
                self.backward_G(opt)  # calculate graidents for G
                self.optimizer_G.step()  # udpate G's weights

        self.t11 = time()

    # Saving the images for visualization for training and testing purposes
    def save_result(self, opt, epoch, iteration):

        ######################################
        # Part 1 GMM Results
        ######################################
        if opt.train_mode == 'gmm':
            images = [
                self.target_pose_img, self.cloth_image, self.im_c,
                self.warped_cloth_predict.detach()
            ]

        ######################################
        # Part 2 PARSING Results
        ######################################
        if opt.train_mode == 'parsing':
            fake_t_vis = pose_utils.decode_labels(
                torch.argmax(self.fake_t, dim=1,
                             keepdim=True).permute(0, 2, 3, 1).contiguous())
            images = [
                self.source_parse_vis, self.target_parse_vis,
                self.target_densepose_data, self.cloth_parse, fake_t_vis
            ]

        ######################################
        # Part 3 APPEARANCE Results
        ######################################
        if opt.train_mode == 'appearance':
            images = [
                self.image_without_cloth, self.warped_cloth,
                self.warped_cloth_parse, self.target_image, self.cloth_image,
                self.generated_parsing_vis,
                self.fake_t.detach()
            ]

        ######################################
        # Part 4 FACE Results
        ######################################
        if opt.train_mode == 'face':
            images = [
                self.generated_image.detach(),
                self.refined_image.detach(), self.source_image,
                self.target_image, self.real_t,
                self.fake_t.detach()
            ]

        pose_utils.save_img(
            images,
            os.path.join(self.vis_path,
                         str(epoch) + '_' + str(iteration) + '.jpg'))

    # Save the trained models
    def save_model(self, opt, epoch):

        ######################################
        # Part 1 GMM Model Save
        ######################################
        if opt.train_mode == 'gmm':
            model_G = osp.join(
                self.save_dir, 'generator',
                'checkpoint_G_epoch_%d_loss_%0.5f_pth.tar' %
                (epoch, self.loss))

            if not osp.exists(osp.join(self.save_dir, 'generator')):
                os.makedirs(osp.join(self.save_dir, 'generator'))
            torch.save(self.generator.state_dict(), model_G)

        ######################################
        # Part 2 PARSING/APPEARANCE/FACE Model Save
        ######################################
        elif not opt.joint_all:
            model_G = osp.join(
                self.save_dir, 'generator',
                'checkpoint_G_epoch_%d_loss_%0.5f_pth.tar' %
                (epoch, self.loss_G))
            model_D = osp.join(
                self.save_dir, 'dicriminator',
                'checkpoint_D_epoch_%d_loss_%0.5f_pth.tar' %
                (epoch, self.loss_D))
            if not osp.exists(osp.join(self.save_dir, 'generator')):
                os.makedirs(osp.join(self.save_dir, 'generator'))
            if not osp.exists(osp.join(self.save_dir, 'dicriminator')):
                os.makedirs(osp.join(self.save_dir, 'dicriminator'))

            torch.save(self.generator.state_dict(), model_G)
            torch.save(self.discriminator.state_dict(), model_D)

        ######################################
        # Part 2 JOINT  Model Save
        ######################################
        else:
            model_G_parsing = osp.join(
                self.save_dir, 'generator_parsing',
                'checkpoint_G_epoch_%d_loss_%0.5f_pth.tar' %
                (epoch, self.loss_G))
            model_D_parsing = osp.join(
                self.save_dir, 'dicriminator_parsing',
                'checkpoint_D_epoch_%d_loss_%0.5f_pth.tar' %
                (epoch, self.loss_D))

            model_G_appearance = osp.join(
                self.save_dir, 'generator_appearance',
                'checkpoint_G_epoch_%d_loss_%0.5f_pth.tar' %
                (epoch, self.loss_G))
            model_D_appearance = osp.join(
                self.save_dir, 'dicriminator_appearance',
                'checkpoint_D_epoch_%d_loss_%0.5f_pth.tar' %
                (epoch, self.loss_D))

            model_G_face = osp.join(
                self.save_dir, 'generator_face',
                'checkpoint_G_epoch_%d_loss_%0.5f_pth.tar' %
                (epoch, self.loss_G))
            model_D_face = osp.join(
                self.save_dir, 'dicriminator_face',
                'checkpoint_D_epoch_%d_loss_%0.5f_pth.tar' %
                (epoch, self.loss_D))

            joint_save_dirs = [
                osp.join(self.save_dir, 'generator_parsing'),
                osp.join(self.save_dir, 'dicriminator_parsing'),
                osp.join(self.save_dir, 'generator_appearance'),
                osp.join(self.save_dir, 'dicriminator_appearance'),
                osp.join(self.save_dir, 'generator_face'),
                osp.join(self.save_dir, 'dicriminator_face')
            ]
            for _ in joint_save_dirs:
                if not osp.exists(_):
                    os.makedirs(_)
            torch.save(self.generator_parsing.state_dict(), model_G_parsing)
            torch.save(self.generator_appearance.state_dict(),
                       model_G_appearance)
            torch.save(self.generator_face.state_dict(), model_G_face)
            torch.save(self.discriminator_appearance.state_dict(),
                       model_D_appearance)

    # Print the logs while training
    def print_current_errors(self, opt, epoch, i, iteration):

        ######################################
        # Part 1 GMM Print Logs
        ######################################

        if opt.train_mode == 'gmm':

            errors = {'loss_L1': self.loss.item()}

            for key in errors:
                self.writer.add_scalar('Loss/GMM/loss_L1/' + str(key),
                                       errors[key], iteration)
        ######################################
        # Part 2 PARSING Print Logs
        ######################################
        if opt.train_mode == 'parsing':

            errors = {
                'loss_G': self.loss_G.item(),
                'loss_G_GAN': self.loss_G_GAN.item(),
                'loss_G_BCE': self.loss_G_BCE.item(),
                'loss_D': self.loss_D.item(),
                'loss_D_real': self.loss_D_real.item(),
                'loss_D_fake': self.loss_D_fake.item()
            }

            for key in errors:
                self.writer.add_scalar('Loss/PARSING/' + str(key), errors[key],
                                       iteration)

        ######################################
        # Part 3 APPEARANCE Print Logs
        ######################################
        if opt.train_mode == 'appearance':

            if opt.joint_all and opt.joint_parse_loss:
                errors = {
                    'loss_G': self.loss_G.item(),
                    'loss_G_GAN': self.loss_G_GAN.item(),
                    'loss_G_vgg': self.loss_G_vgg.item(),
                    'loss_G_mask': self.loss_G_mask.item(),
                    'loss_G_L1': self.loss_G_L1.item(),
                    'loss_D': self.loss_D.item(),
                    'loss_D_real': self.loss_D_real.item(),
                    'loss_D_fake': self.loss_D_real.item(),
                    'loss_G_parsing': self.loss_G_parsing.item()
                }
                for key in errors:
                    self.writer.add_scalar('Loss/JOINTALL/' + str(key),
                                           errors[key], iteration)
            else:
                errors = {
                    'loss_G': self.loss_G.item(),
                    'loss_G_GAN': self.loss_G_GAN.item(),
                    'loss_G_vgg': self.loss_G_vgg.item(),
                    'loss_G_mask': self.loss_G_mask.item(),
                    'loss_G_L1': self.loss_G_L1.item(),
                    'loss_D': self.loss_D.item(),
                    'loss_D_real': self.loss_D_real.item(),
                    'loss_D_fake': self.loss_D_real.item(),
                    'loss_G_mask_tv': self.loss_G_mask_tv.item()
                }

                for key in errors:
                    self.writer.add_scalar('Loss/APPEARANCE/' + str(key),
                                           errors[key], iteration)

        ######################################
        # Part 4 FACE Print Logs
        ######################################
        if opt.train_mode == 'face':

            errors = {
                'loss_G': self.loss_G.item(),
                'loss_G_GAN': self.loss_G_GAN.item(),
                'loss_G_vgg': self.loss_G_vgg.item(),
                'loss_G_refine': self.loss_G_refine.item(),
                'loss_G_L1': self.loss_G_L1.item(),
                'loss_D': self.loss_D.item(),
                'loss_D_real': self.loss_D_real.item(),
                'loss_D_fake': self.loss_D_real.item()
            }

            for key in errors:
                self.writer.add_scalar('Loss/FACE/' + str(key), errors[key],
                                       iteration)

        # Print the errors
        t = self.t11 - self.t2
        message = '(epoch: %d, iters: %d, time: %.3f) ' % (epoch, i, t)
        for k, v in sorted(errors.items()):
            if v != 0:
                message += '%s: %.3f ' % (k, v)
        print(message)

        # Save logs
        with open(self.log_name, "a") as log_file:
            log_file.write('%s\n' % message)