Example #1
class Swapper(BaseModel):

    PART_IDS = {
        'body': [1, 2, 3, 4, 5, 6, 7, 8, 9],
        'all': [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
    }
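    # The IDs above index the part-face groups built in __init__ via
    # mesh.get_part_face_ids: 'body' is every group except 0, while 'all'
    # covers all ten groups.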

    def __init__(self, opt):
        super(Swapper, self).__init__(opt)
        self._name = 'Swapper'

        self._create_networks()

        # attribute placeholders, filled in by personalize() and the swap calls
        self.src_info = None
        self.tsf_info = None
        self.T = None
        self.T12 = None
        self.T21 = None
        self.grid = self.render.create_meshgrid(self._opt.image_size).cuda()
        self.part_fn = torch.tensor(mesh.create_mapping('par', self._opt.uv_mapping,
                                                        contain_bg=True, fill_back=False)).float().cuda()
        self.part_faces_dict = mesh.get_part_face_ids(part_type='par', fill_back=False)
        self.part_faces = list(self.part_faces_dict.values())

    def _create_networks(self):
        # 1. create generator
        self.generator = self._create_generator().cuda()

        # 2. create bgnet
        if self._opt.bg_model != 'ORIGINAL':
            self.bgnet = self._create_bgnet().cuda()
        else:
            self.bgnet = self.generator.bg_model

        # 3. create hmr
        self.hmr = self._create_hmr().cuda()

        # 4. create render
        self.render = SMPLRenderer(image_size=self._opt.image_size, tex_size=self._opt.tex_size,
                                   has_front=self._opt.front_warp, fill_back=False).cuda()
        # 5. create pre-processor (person detector)
        if self._opt.has_detector:
            self.detector = PersonMaskRCNNDetector(ks=self._opt.bg_ks, threshold=0.5, to_gpu=True)
        else:
            self.detector = None

    def _create_bgnet(self):
        net = NetworksFactory.get_by_name('deepfillv2', c_dim=4)
        self._load_params(net, self._opt.bg_model, need_module=False)
        net.eval()
        return net

    def _create_generator(self):
        net = NetworksFactory.get_by_name(self._opt.gen_name, bg_dim=4, src_dim=3+self._G_cond_nc,
                                          tsf_dim=3+self._G_cond_nc, repeat_num=self._opt.repeat_num)

        if self._opt.load_path:
            self._load_params(net, self._opt.load_path)
        elif self._opt.load_epoch > 0:
            self._load_network(net, 'G', self._opt.load_epoch)
        else:
            raise ValueError('load_path {} is empty and load_epoch {} is 0'.format(
                self._opt.load_path, self._opt.load_epoch))

        net.eval()
        return net

    def _create_hmr(self):
        hmr = HumanModelRecovery(self._opt.smpl_model)
        saved_data = torch.load(self._opt.hmr_model)
        hmr.load_state_dict(saved_data)
        hmr.eval()
        return hmr

    @staticmethod
    def visualize(*args, **kwargs):
        visualizer = args[0]
        if visualizer is not None:
            for key, value in kwargs.items():
                visualizer.vis_named_img(key, value)

    # TODO: it does not support mini-batch inputs currently.
    @torch.no_grad()
    def personalize(self, src_path, src_smpl=None, output_path='', visualizer=None):

        ori_img = cv_utils.read_cv2_img(src_path)

        # resize the image and rescale pixel values from [0, 255] to [-1, 1]
        img = cv_utils.transform_img(ori_img, self._opt.image_size, transpose=True) * 2 - 1.0
        img = torch.tensor(img, dtype=torch.float32).cuda()[None, ...]

        if src_smpl is None:
            img_hmr = cv_utils.transform_img(ori_img, 224, transpose=True) * 2 - 1.0
            img_hmr = torch.tensor(img_hmr, dtype=torch.float32).cuda()[None, ...]
            src_smpl = self.hmr(img_hmr)
        else:
            src_smpl = torch.tensor(src_smpl, dtype=torch.float32).cuda()[None, ...]

        # source process, {'theta', 'cam', 'pose', 'shape', 'verts', 'j2d', 'j3d'}
        src_info = self.hmr.get_details(src_smpl)
        src_f2verts, src_fim, src_wim = self.render.render_fim_wim(src_info['cam'], src_info['verts'])
        # src_f2pts = src_f2verts[:, :, :, 0:2]
        src_info['fim'] = src_fim
        src_info['wim'] = src_wim
        src_info['cond'], _ = self.render.encode_fim(src_info['cam'], src_info['verts'], fim=src_fim, transpose=True)
        src_info['f2verts'] = src_f2verts
        src_info['p2verts'] = src_f2verts[:, :, :, 0:2]
        src_info['p2verts'][:, :, :, 1] *= -1

        if self._opt.only_vis:
            src_info['p2verts'] = self.render.get_vis_f2pts(src_info['p2verts'], src_fim)

        src_info['part'], _ = self.render.encode_fim(src_info['cam'], src_info['verts'],
                                                     fim=src_fim, transpose=True, map_fn=self.part_fn)
        # add image to source info
        src_info['img'] = img
        src_info['image'] = ori_img

        # 2. process the src inputs
        if self.detector is not None:
            bbox, body_mask = self.detector.inference(img[0])
            bg_mask = 1 - body_mask
        else:
            bg_mask = util.morph(src_info['cond'][:, -1:, :, :], ks=self._opt.bg_ks, mode='erode')
            body_mask = 1 - bg_mask

        if self._opt.bg_model != 'ORIGINAL':
            src_info['bg'] = self.bgnet(img, masks=body_mask, only_x=True)
        else:
            incomp_img = img * bg_mask
            bg_inputs = torch.cat([incomp_img, bg_mask], dim=1)
            img_bg = self.bgnet(bg_inputs)
            src_info['bg_inputs'] = bg_inputs
            # src_info['bg'] = img_bg
            src_info['bg'] = incomp_img + img_bg * body_mask

        ft_mask = 1 - util.morph(src_info['cond'][:, -1:, :, :], ks=self._opt.ft_ks, mode='erode')
        src_inputs = torch.cat([img * ft_mask, src_info['cond']], dim=1)

        src_info['feats'] = self.generator.encode_src(src_inputs)
        src_info['src_inputs'] = src_inputs

        # if visualizer is not None:
        #     self.visualize(visualizer, src=img, bg=src_info['bg'])

        if output_path:
            cv_utils.save_cv2_img(src_info['image'], output_path, image_size=self._opt.image_size)

        return src_info

    def _extract_smpls(self, input_file):
        img = cv_utils.read_cv2_img(input_file)
        img = cv_utils.transform_img(img, image_size=224) * 2 - 1.0  # HMR expects inputs in [-1, 1]
        img = img.transpose((2, 0, 1))
        img = torch.tensor(img, dtype=torch.float32).cuda()[None, ...]
        theta = self.hmr(img)[-1]

        return theta

    @torch.no_grad()
    def swap_smpl(self, src_cam, src_shape, tgt_smpl, preserve_scale=True):
        cam = tgt_smpl[:, 0:3].contiguous()
        pose = tgt_smpl[:, 3:75].contiguous()
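        # HMR-style theta vectors are 85-D: camera (3) + axis-angle pose (72)
        # + shape (10), hence the [:, 0:3] and [:, 3:75] slices above.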

        if preserve_scale:
            # rescale the target x-y translation by the source/target scale
            # ratio before adopting the source scale
            cam[:, 1:] = (src_cam[:, 0] / cam[:, 0]) * cam[:, 1:] + src_cam[:, 1:]
            cam[:, 0] = src_cam[:, 0]
        else:
            cam[:, 0] = src_cam[:, 0]

        tsf_smpl = torch.cat([cam, pose, src_shape], dim=1)

        return tsf_smpl

    @torch.no_grad()
    def swap_setup(self, src_path, tgt_path, src_smpl=None, tgt_smpl=None, output_dir=''):
        self.src_info = self.personalize(src_path, src_smpl)
        self.tsf_info = self.personalize(tgt_path, tgt_smpl)

    @torch.no_grad()
    def swap(self, src_info, tgt_info, target_part='body', visualizer=None):
        assert target_part in self.PART_IDS.keys()

        def merge_list(part_ids):
            faces = set()
            for i in part_ids:
                fs = set(self.part_faces[i])
                faces |= fs
            return list(faces)

        # get target selected face index map
        selected_ids = self.PART_IDS[target_part]
        left_ids = [i for i in self.PART_IDS['all'] if i not in selected_ids]

        src_part_mask = (torch.sum(src_info['part'][:, selected_ids, ...], dim=1) != 0)
        src_left_mask = (torch.sum(src_info['part'][:, left_ids, ...], dim=1) != 0)

        # selected_faces = merge_list(selected_ids)
        left_faces = merge_list(left_ids)

        T11, T21 = self.calculate_trans(src_left_mask, left_faces)

        tsf21 = self.generator.transform(tgt_info['img'], T21)
        tsf11 = self.generator.transform(src_info['img'], T11)

        src_part_mask = src_part_mask[:, None, :, :].float()
        src_left_mask = src_left_mask[:, None, :, :].float()
        tsf_img = tsf21 * src_part_mask + tsf11 * src_left_mask

        tsf_inputs = torch.cat([tsf_img, src_info['cond']], dim=1)

        preds, tsf_mask = self.forward(tsf_inputs, tgt_info['feats'], T21, src_info['feats'], T11, src_info['bg'])

        if self._opt.front_warp:
            # preds = tsf11 * src_left_mask + (1 - src_left_mask) * preds
            preds = self.warp(preds, src_info['img'], src_info['fim'], tsf_mask)

        if visualizer is not None:
            self.visualize(visualizer, src_img=src_info['img'], tgt_img=tgt_info['img'], preds=preds)

        return preds

    # TODO: it does not support mini-batch inputs currently.
    def calculate_trans(self, src_left_mask, left_faces):
        # calculate T11
        T11 = self.grid.clone()
        T11[~src_left_mask[0]] = -2
        T11.unsqueeze_(0)

        # calculate T21
        tsf_f2p = self.tsf_info['p2verts'].clone()
        tsf_f2p[0, left_faces] = -2
        T21 = self.render.cal_bc_transform(tsf_f2p, self.src_info['fim'], self.src_info['wim'])
        T21.clamp_(-2, 2)
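        # Both grids follow F.grid_sample conventions: valid coordinates lie in
        # [-1, 1], and -2 serves as an out-of-range sentinel so those positions
        # sample zeros under the default padding_mode='zeros'.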
        return T11, T21

    def warp(self, preds, tsf, fim, fake_tsf_mask):
        front_mask = self.render.encode_front_fim(fim, transpose=True)
        preds = (1 - front_mask) * preds + tsf * front_mask * (1 - fake_tsf_mask)
        # preds = torch.clamp(preds + tsf * front_mask, -1, 1)
        return preds

    def forward(self, tsf_inputs, feats21, T21, feats11, T11, bg):
        with torch.no_grad():
            # generate fake images
            src_encoder_outs21, src_resnet_outs21 = feats21
            src_encoder_outs11, src_resnet_outs11 = feats11

            tsf_color, tsf_mask = self.generator.swap(tsf_inputs, src_encoder_outs21, src_encoder_outs11,
                                                      src_resnet_outs21, src_resnet_outs11, T21, T11)
            pred_imgs = tsf_mask * bg + (1 - tsf_mask) * tsf_color
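            # tsf_mask acts as a background attention map: values near 1 keep
            # the inpainted background, values near 0 keep the generated
            # foreground color.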

        return pred_imgs, tsf_mask

    def post_personalize(self, out_dir, visualizer, verbose=True):
        from networks.networks import FaceLoss
        init_bg = torch.cat([self.src_info['bg'], self.tsf_info['bg']], dim=0)

        @torch.no_grad()
        def initialize(src_info, tsf_info):
            src_encoder_outs, src_resnet_outs = src_info['feats']
            src_f2p = src_info['p2verts']

            tsf_fim = tsf_info['fim']
            tsf_wim = tsf_info['wim']
            tsf_cond = tsf_info['cond']

            T = self.render.cal_bc_transform(src_f2p, tsf_fim, tsf_wim)
            tsf_img = F.grid_sample(src_info['img'], T)
            tsf_inputs = torch.cat([tsf_img, tsf_cond], dim=1)

            tsf_color, tsf_mask = self.generator.inference(
                src_encoder_outs, src_resnet_outs, tsf_inputs, T)

            preds = src_info['bg'] * tsf_mask + tsf_color * (1 - tsf_mask)

            if self._opt.front_warp:
                preds = self.warp(preds, tsf_img, tsf_fim, tsf_mask)

            return preds, T, tsf_inputs

        @torch.no_grad()
        def set_inputs(src_info, tsf_info):
            s2t_init_preds, s2t_T, s2t_tsf_inputs = initialize(src_info, tsf_info)
            t2s_init_preds, t2s_T, t2s_tsf_inputs = initialize(tsf_info, src_info)

            s2t_j2d = torch.cat([src_info['j2d'], tsf_info['j2d']], dim=0)
            t2s_j2d = torch.cat([tsf_info['j2d'], src_info['j2d']], dim=0)
            j2ds = torch.stack([s2t_j2d, t2s_j2d], dim=0)

            init_preds = torch.cat([s2t_init_preds, t2s_init_preds], dim=0)
            images = torch.cat([src_info['img'], tsf_info['img']], dim=0)
            T = torch.cat([s2t_T, t2s_T], dim=0)
            T_cycle = torch.cat([t2s_T, s2t_T], dim=0)
            tsf_inputs = torch.cat([s2t_tsf_inputs, t2s_tsf_inputs], dim=0)
            src_fim = torch.cat([src_info['fim'], tsf_info['fim']], dim=0)
            tsf_fim = torch.cat([tsf_info['fim'], src_info['fim']], dim=0)

            s2t_inputs = src_info['src_inputs']
            t2s_inputs = tsf_info['src_inputs']

            src_inputs = torch.cat([s2t_inputs, t2s_inputs], dim=0)

            src_mask = util.morph(src_inputs[:, -1:], ks=self._opt.ft_ks, mode='erode')
            tsf_mask = util.morph(tsf_inputs[:, -1:], ks=self._opt.ft_ks, mode='erode')

            pseudo_masks = torch.cat([src_mask, tsf_mask], dim=0)
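            # Stacking the source-to-target and target-to-source directions
            # along the batch dimension lets each optimization step adapt the
            # generator in both directions at once.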

            return src_fim, tsf_fim, j2ds, T, T_cycle, src_inputs, tsf_inputs, images, init_preds, pseudo_masks

        def set_cycle_inputs(fake_tsf_imgs, src_inputs, tsf_inputs, T_cycle):
            # set cycle bg inputs
            tsf_bg_mask = tsf_inputs[:, -1:, ...]

            # set cycle src inputs
            cycle_src_inputs = torch.cat([fake_tsf_imgs * tsf_bg_mask, tsf_inputs[:, 3:]], dim=1)

            # set cycle tsf inputs
            cycle_tsf_img = F.grid_sample(fake_tsf_imgs, T_cycle)
            cycle_tsf_inputs = torch.cat([cycle_tsf_img, src_inputs[:, 3:]], dim=1)

            return cycle_src_inputs, cycle_tsf_inputs

        def inference(src_inputs, tsf_inputs, T, T_cycle, src_fim, tsf_fim):
            fake_src_color, fake_src_mask, fake_tsf_color, fake_tsf_mask = \
                self.generator.infer_front(src_inputs, tsf_inputs, T=T)

            fake_src_imgs = fake_src_mask * init_bg + (1 - fake_src_mask) * fake_src_color
            fake_tsf_imgs = fake_tsf_mask * init_bg + (1 - fake_tsf_mask) * fake_tsf_color

            if self._opt.front_warp:
                fake_tsf_imgs = self.warp(fake_tsf_imgs, tsf_inputs[:, 0:3], tsf_fim, fake_tsf_mask)

            cycle_src_inputs, cycle_tsf_inputs = set_cycle_inputs(
                fake_tsf_imgs, src_inputs, tsf_inputs, T_cycle)

            cycle_src_color, cycle_src_mask, cycle_tsf_color, cycle_tsf_mask = \
                self.generator.infer_front(cycle_src_inputs, cycle_tsf_inputs, T=T_cycle)

            cycle_src_imgs = cycle_src_mask * init_bg + (1 - cycle_src_mask) * cycle_src_color
            cycle_tsf_imgs = cycle_tsf_mask * init_bg + (1 - cycle_tsf_mask) * cycle_tsf_color

            if self._opt.front_warp:
                cycle_tsf_imgs = self.warp(cycle_tsf_imgs, src_inputs[:, 0:3], src_fim, fake_src_mask)

            return fake_src_imgs, fake_tsf_imgs, cycle_src_imgs, cycle_tsf_imgs, fake_src_mask, fake_tsf_mask, cycle_tsf_inputs

        def create_criterion():
            face_criterion = FaceLoss(pretrained_path=self._opt.face_model).cuda()
            idt_criterion = torch.nn.L1Loss()
            mask_criterion = torch.nn.BCELoss()

            return face_criterion, idt_criterion, mask_criterion

        def print_losses(**kwargs):
            print('step = {}'.format(kwargs['step']))
            for key, value in kwargs.items():
                if key == 'step':
                    continue
                print('\t{}, {:.6f}'.format(key, value.item()))

        def update_learning_rate(optimizer, current_lr, init_lr, final_lr, nepochs_decay):
            # linearly decay the learning rate toward final_lr
            lr_decay = (init_lr - final_lr) / nepochs_decay
            current_lr -= lr_decay
            for param_group in optimizer.param_groups:
                param_group['lr'] = current_lr
            # print('update G learning rate: %f -> %f' % (current_lr + lr_decay, current_lr))
            return current_lr

        init_lr = 0.0002
        cur_lr = init_lr
        final_lr = 0.00001
        fix_iters = 25
        total_iters = 50
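        # The learning rate stays at init_lr for the first fix_iters steps and
        # then decays linearly toward final_lr over the remaining steps.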
        optimizer = torch.optim.Adam(self.generator.parameters(), lr=init_lr, betas=(0.5, 0.999))
        face_cri, idt_cri, msk_cri = create_criterion()

        # set up inputs
        src_fim, tsf_fim, j2ds, T, T_cycle, src_inputs, tsf_inputs, \
        src_imgs, init_preds, pseudo_masks = set_inputs(
            src_info=self.src_info, tsf_info=self.tsf_info
        )

        logger = tqdm(range(total_iters))
        for step in logger:

            fake_src_imgs, fake_tsf_imgs, cycle_src_imgs, cycle_tsf_imgs, \
            fake_src_mask, fake_tsf_mask, cycle_tsf_inputs = inference(src_inputs, tsf_inputs,
                                                                       T, T_cycle, src_fim, tsf_fim)

            # cycle reconstruction loss
            cycle_loss = idt_cri(src_imgs, fake_src_imgs) + idt_cri(src_imgs, cycle_tsf_imgs)

            # structure loss
            bg_mask = src_inputs[:, -1:]
            body_mask = 1.0 - bg_mask
            str_src_imgs = src_imgs * body_mask
            cycle_warp_imgs = cycle_tsf_inputs[:, 0:3]
            # back_head_mask = 1 - self.render.encode_front_fim(tsf_fim, transpose=True, front_fn=False)
            # struct_loss = idt_cri(init_preds, fake_tsf_imgs) + \
            #               2 * idt_cri(str_src_imgs * back_head_mask, cycle_warp_imgs * back_head_mask)

            struct_loss = idt_cri(init_preds, fake_tsf_imgs) + \
                          2 * idt_cri(str_src_imgs, cycle_warp_imgs)

            # fid_loss = face_cri(src_imgs, cycle_tsf_imgs, kps1=j2ds[:, 0], kps2=j2ds[:, 0]) + \
            #            face_cri(init_preds, fake_tsf_imgs, kps1=j2ds[:, 1], kps2=j2ds[:, 1])

            fid_loss = face_cri(src_imgs, cycle_tsf_imgs, kps1=j2ds[:, 0], kps2=j2ds[:, 0]) + \
                       face_cri(tsf_inputs[:, 0:3], fake_tsf_imgs, kps1=j2ds[:, 1], kps2=j2ds[:, 1])

            # mask loss
            # mask_loss = msk_cri(fake_tsf_mask, tsf_inputs[:, -1:]) + msk_cri(fake_src_mask, src_inputs[:, -1:])
            mask_loss = msk_cri(torch.cat([fake_src_mask, fake_tsf_mask], dim=0), pseudo_masks)

            loss = 10 * cycle_loss + 10 * struct_loss + fid_loss + 5 * mask_loss
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # print_losses(step=step, total=loss, cyc=cycle_loss,
            #              str=struct_loss, fid=fid_loss, msk=mask_loss)

            if verbose:
                logger.set_description(
                    (
                        f'step: {step}; '
                        f'total: {loss.item():.6f}; cyc: {cycle_loss.item():.6f}; '
                        f'str: {struct_loss.item():.6f}; fid: {fid_loss.item():.6f}; '
                        f'msk: {mask_loss.item():.6f}'
                    )
                )

            if step % 10 == 0:
                self.visualize(visualizer, input_imgs=src_imgs, tsf_imgs=fake_tsf_imgs,
                               cyc_imgs=cycle_tsf_imgs, fake_tsf_mask=fake_tsf_mask,
                               init_preds=init_preds,
                               str_src_imgs=str_src_imgs,
                               cycle_warp_imgs=cycle_warp_imgs)

            if step > fix_iters:
                cur_lr = update_learning_rate(optimizer, cur_lr, init_lr, final_lr, fix_iters)

        self.generator.eval()
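
A minimal driver for this example might look like the sketch below. It is an
illustration rather than code from the original project: the option-building
helper and the image paths are assumptions.

# Hypothetical driver for the Swapper example above (paths are placeholders).
opt = TestOptions().parse()   # assumption: however this repo builds its `opt`
swapper = Swapper(opt)

# Personalize both identities, then swap the selected parts of the target
# appearance onto the source image.
swapper.swap_setup('imgs/src.jpg', 'imgs/tgt.jpg')
preds = swapper.swap(swapper.src_info, swapper.tsf_info, target_part='body')

# Optionally fine-tune the generator on this pair first:
# swapper.post_personalize(out_dir='', visualizer=None, verbose=True)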
Example #2
class Imitator(BaseModel):
    def __init__(self, opt):
        super(Imitator, self).__init__(opt)
        self._name = 'Imitator'

        self._create_networks()

        # attribute placeholders, filled in during personalization and inference
        self.src_info = None
        self.tsf_info = None
        self.first_cam = None

        self.t = 0

    def _create_networks(self):
        # 1. create generator
        self.generator = self._create_generator().cuda()

        # 2. create bgnet
        if self._opt.bg_model != 'ORIGINAL':
            self.bgnet = self._create_bgnet().cuda()
        else:
            self.bgnet = self.generator.bg_model

        # 3. create hmr
        self.hmr = self._create_hmr().cuda()

        # 4. create render
        self.render = SMPLRenderer(image_size=self._opt.image_size,
                                   tex_size=self._opt.tex_size,
                                   has_front=self._opt.front_warp,
                                   fill_back=False).cuda()
        # 5. create pre-processor (person detector)
        if self._opt.has_detector:
            self.detector = PersonMaskRCNNDetector(ks=self._opt.bg_ks,
                                                   threshold=0.5,
                                                   to_gpu=True)
        else:
            self.detector = None

    def _create_bgnet(self):
        net = NetworksFactory.get_by_name('deepfillv2', c_dim=4)
        self._load_params(net, self._opt.bg_model, need_module=False)
        net.eval()
        return net

    def _create_generator(self):
        net = NetworksFactory.get_by_name(self._opt.gen_name,
                                          bg_dim=4,
                                          src_dim=3 + self._G_cond_nc,
                                          tsf_dim=3 + self._G_cond_nc,
                                          repeat_num=self._opt.repeat_num)

        if self._opt.load_path:
            self._load_params(net, self._opt.load_path)
        elif self._opt.load_epoch > 0:
            self._load_network(net, 'G', self._opt.load_epoch)
        else:
            raise ValueError(
                'load_path {} is empty and load_epoch {} is 0'.format(
                    self._opt.load_path, self._opt.load_epoch))

        net.eval()
        return net

    def _create_hmr(self):
        hmr = HumanModelRecovery(self._opt.smpl_model)
        saved_data = torch.load(self._opt.hmr_model)
        hmr.load_state_dict(saved_data)
        hmr.eval()
        return hmr

    def visualize(self, *args, **kwargs):
        visualizer = args[0]
        if visualizer is not None:
            for key, value in kwargs.items():
                visualizer.vis_named_img(key, value)

    @torch.no_grad()
    def personalize(self,
                    ori_img,
                    src_smpl=None,
                    output_path='',
                    visualizer=None):

        # ori_img = cv_utils.read_cv2_img(src_path)

        # resize the image and rescale pixel values from [0, 255] to [-1, 1]
        img = cv_utils.transform_img(
            ori_img, self._opt.image_size, transpose=True) * 2 - 1.0
        img = torch.tensor(img, dtype=torch.float32).cuda()[None, ...]

        if src_smpl is None:
            img_hmr = cv_utils.transform_img(ori_img, 224,
                                             transpose=True) * 2 - 1.0
            img_hmr = torch.tensor(img_hmr, dtype=torch.float32).cuda()[None,
                                                                        ...]
            src_smpl = self.hmr(img_hmr)
        else:
            src_smpl = torch.tensor(src_smpl, dtype=torch.float32).cuda()[None,
                                                                          ...]

        # source process, {'theta', 'cam', 'pose', 'shape', 'verts', 'j2d', 'j3d'}
        src_info = self.hmr.get_details(src_smpl)
        src_f2verts, src_fim, src_wim = self.render.render_fim_wim(
            src_info['cam'], src_info['verts'])
        # src_f2pts = src_f2verts[:, :, :, 0:2]
        src_info['fim'] = src_fim
        src_info['wim'] = src_wim
        src_info['cond'], _ = self.render.encode_fim(src_info['cam'],
                                                     src_info['verts'],
                                                     fim=src_fim,
                                                     transpose=True)
        src_info['f2verts'] = src_f2verts
        src_info['p2verts'] = src_f2verts[:, :, :, 0:2]
        src_info['p2verts'][:, :, :, 1] *= -1

        if self._opt.only_vis:
            src_info['p2verts'] = self.render.get_vis_f2pts(
                src_info['p2verts'], src_fim)
        # add image to source info
        src_info['img'] = img
        src_info['image'] = ori_img

        # 2. process the src inputs
        if self.detector is not None:
            bbox, body_mask = self.detector.inference(img[0])
            bg_mask = 1 - body_mask
        else:
            # background is 1, foreground is 0
            bg_mask = util.morph(src_info['cond'][:, -1:, :, :],
                                 ks=self._opt.bg_ks,
                                 mode='erode')
            body_mask = 1 - bg_mask

        if self._opt.bg_model != 'ORIGINAL':
            src_info['bg'] = self.bgnet(img, masks=body_mask, only_x=True)
        else:
            incomp_img = img * bg_mask
            bg_inputs = torch.cat([incomp_img, bg_mask], dim=1)
            img_bg = self.bgnet(bg_inputs)
            # src_info['bg'] = bg_inputs[:, 0:3] + img_bg * bg_inputs[:, -1:]
            src_info['bg'] = img_bg

        ft_mask = 1 - util.morph(
            src_info['cond'][:, -1:, :, :], ks=self._opt.ft_ks, mode='erode')
        src_inputs = torch.cat([img * ft_mask, src_info['cond']], dim=1)

        src_info['feats'] = self.generator.encode_src(src_inputs)

        self.src_info = src_info

        if visualizer is not None:
            visualizer.vis_named_img('src', img)
            visualizer.vis_named_img('bg', src_info['bg'])

        if output_path:
            cv_utils.save_cv2_img(src_info['image'],
                                  output_path,
                                  image_size=self._opt.image_size)

    @torch.no_grad()
    def _extract_smpls(self, input_file):
        img = cv_utils.read_cv2_img(input_file)
        img = cv_utils.transform_img(
            img, image_size=224) * 2 - 1.0  # HMR expects inputs in [-1, 1]
        img = img.transpose((2, 0, 1))
        img = torch.tensor(img, dtype=torch.float32).cuda()[None, ...]
        theta = self.hmr(img)[-1]

        return theta

    @torch.no_grad()
    def inference(self,
                  tgt_imgs,
                  tgt_smpls=None,
                  cam_strategy='smooth',
                  output_dir='',
                  visualizer=None,
                  verbose=True):

        length = len(tgt_imgs)

        outputs = []
        process_bar = tqdm(range(length)) if verbose else range(length)

        for t in process_bar:
            tgt_img = tgt_imgs[t]
            tgt_smpl = tgt_smpls[t] if tgt_smpls is not None else None

            tsf_inputs = self.transfer_params(tgt_img, tgt_smpl, cam_strategy)
            preds = self.forward(tsf_inputs, self.tsf_info['T'])

            preds = preds[0].permute(1, 2, 0)
            preds = preds.cpu().numpy()
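            # preds is now an HxWx3 numpy array, still in the model's [-1, 1]
            # value range (downstream saving code rescales it for display).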
            outputs.append(preds)

        return outputs

    @torch.no_grad()
    def inference_by_smpls(self,
                           tgt_smpls,
                           cam_strategy='smooth',
                           output_dir='',
                           visualizer=None):
        length = len(tgt_smpls)

        outputs = []
        for t in tqdm(range(length)):
            tgt_smpl = tgt_smpls[t] if tgt_smpls is not None else None

            tsf_inputs = self.transfer_params_by_smpl(tgt_smpl, cam_strategy)
            preds = self.forward(tsf_inputs, self.tsf_info['T'])

            if visualizer is not None:
                gt = cv_utils.transform_img(self.tsf_info['image'],
                                            image_size=self._opt.image_size,
                                            transpose=True)
                visualizer.vis_named_img('pred_' + cam_strategy, preds)
                visualizer.vis_named_img('gt',
                                         gt[None, ...],
                                         denormalize=False)

            preds = preds[0].permute(1, 2, 0)
            preds = preds.cpu().numpy()
            outputs.append(preds)

            if output_dir:
                cv_utils.save_cv2_img(preds,
                                      os.path.join(output_dir,
                                                   'pred_%.8d.jpg' % t),
                                      normalize=True)

        return outputs

    def swap_smpl(self, src_cam, src_shape, tgt_smpl, cam_strategy='smooth'):
        tgt_cam = tgt_smpl[:, 0:3].contiguous()
        pose = tgt_smpl[:, 3:75].contiguous()
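        # Three strategies: 'smooth' keeps the source camera but follows the
        # target's x-y drift relative to the first frame; 'source' keeps the
        # source camera unchanged; anything else uses the target camera as-is.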

        # TODO: the 'smooth' heuristic may need refinement
        if cam_strategy == 'smooth':

            cam = src_cam.clone()
            delta_xy = tgt_cam[:, 1:] - self.first_cam[:, 1:]
            cam[:, 1:] += delta_xy

        elif cam_strategy == 'source':
            cam = src_cam
        else:
            cam = tgt_cam

        tsf_smpl = torch.cat([cam, pose, src_shape], dim=1)

        return tsf_smpl

    def transfer_params_by_smpl(self, tgt_smpl, cam_strategy='smooth'):
        # get source info
        src_info = self.src_info

        if isinstance(tgt_smpl, np.ndarray):
            tgt_smpl = torch.tensor(tgt_smpl).float().cuda()[None, ...]

        if self.t == 0 and cam_strategy == 'smooth':
            self.first_cam = tgt_smpl[:, 0:3].clone()

        self.t += 1

        # get transfer smpl
        tsf_smpl = self.swap_smpl(src_info['cam'],
                                  src_info['shape'],
                                  tgt_smpl,
                                  cam_strategy=cam_strategy)
        # transfer process, {'theta', 'cam', 'pose', 'shape', 'verts', 'j2d', 'j3d'}
        tsf_info = self.hmr.get_details(tsf_smpl)

        tsf_f2verts, tsf_fim, tsf_wim = self.render.render_fim_wim(
            tsf_info['cam'], tsf_info['verts'])
        # src_f2pts = src_f2verts[:, :, :, 0:2]
        tsf_info['fim'] = tsf_fim
        tsf_info['wim'] = tsf_wim
        tsf_info['cond'], _ = self.render.encode_fim(tsf_info['cam'],
                                                     tsf_info['verts'],
                                                     fim=tsf_fim,
                                                     transpose=True)
        # tsf_info['sil'] = util.morph((tsf_fim != -1).float(), ks=self._opt.ft_ks, mode='dilate')

        T = self.render.cal_bc_transform(src_info['p2verts'], tsf_fim, tsf_wim)
        tsf_img = F.grid_sample(src_info['img'], T)
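        # T is a dense sampling grid; grid_sample warps the personalized source
        # image into the transferred pose.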
        tsf_inputs = torch.cat([tsf_img, tsf_info['cond']], dim=1)

        # add target image to tsf info
        tsf_info['tsf_img'] = tsf_img
        tsf_info['T'] = T

        self.tsf_info = tsf_info

        return tsf_inputs

    def transfer_params(self, ori_img, tgt_smpl=None, cam_strategy='smooth'):
        # ori_img = cv_utils.read_cv2_img(tgt_path)
        if tgt_smpl is None:
            img_hmr = cv_utils.transform_img(ori_img, 224,
                                             transpose=True) * 2 - 1.0
            img_hmr = torch.tensor(img_hmr, dtype=torch.float32).cuda()[None,
                                                                        ...]
            tgt_smpl = self.hmr(img_hmr)
        else:
            if isinstance(tgt_smpl, np.ndarray):
                tgt_smpl = torch.tensor(tgt_smpl,
                                        dtype=torch.float32).cuda()[None, ...]

        tsf_inputs = self.transfer_params_by_smpl(tgt_smpl=tgt_smpl,
                                                  cam_strategy=cam_strategy)
        self.tsf_info['image'] = ori_img

        return tsf_inputs

    def forward(self, tsf_inputs, T):
        bg_img = self.src_info['bg']
        src_encoder_outs, src_resnet_outs = self.src_info['feats']

        tsf_color, tsf_mask = self.generator.inference(src_encoder_outs,
                                                       src_resnet_outs,
                                                       tsf_inputs, T)
        pred_imgs = tsf_mask * bg_img + (1 - tsf_mask) * tsf_color

        if self._opt.front_warp:
            pred_imgs = self.warp_front(pred_imgs, tsf_mask)

        return pred_imgs

    def warp_front(self, preds, mask):
        front_mask = self.render.encode_front_fim(self.tsf_info['fim'],
                                                  transpose=True,
                                                  front_fn=True)
        preds = (1 - front_mask
                 ) * preds + self.tsf_info['tsf_img'] * front_mask * (1 - mask)
        # preds = torch.clamp(preds + self.tsf_info['tsf_img'] * front_mask, -1, 1)
        return preds

    def post_personalize(self, out_dir, data_loader, visualizer, verbose=True):
        from networks.networks import FaceLoss

        bg_inpaint = self.src_info['bg']

        @torch.no_grad()
        def set_gen_inputs(sample):
            j2ds = sample['j2d'].cuda()  # (N, 4)
            T = sample['T'].cuda()  # (N, h, w, 2)
            T_cycle = sample['T_cycle'].cuda()  # (N, h, w, 2)
            src_inputs = sample['src_inputs'].cuda()  # (N, 6, h, w)
            tsf_inputs = sample['tsf_inputs'].cuda()  # (N, 6, h, w)
            src_fim = sample['src_fim'].cuda()
            tsf_fim = sample['tsf_fim'].cuda()
            init_preds = sample['preds'].cuda()
            images = sample['images']
            images = torch.cat([images[:, 0, ...], images[:, 1, ...]],
                               dim=0).cuda()  # (2N, 3, h, w)
            pseudo_masks = sample['pseudo_masks']
            pseudo_masks = torch.cat(
                [pseudo_masks[:, 0, ...], pseudo_masks[:, 1, ...]],
                dim=0).cuda()  # (2N, 1, h, w)

            return src_fim, tsf_fim, j2ds, T, T_cycle, \
                   src_inputs, tsf_inputs, images, init_preds, pseudo_masks

        def set_cycle_inputs(fake_tsf_imgs, src_inputs, tsf_inputs, T_cycle):
            # set cycle src inputs
            cycle_src_inputs = torch.cat(
                [fake_tsf_imgs * tsf_inputs[:, -1:, ...], tsf_inputs[:, 3:]],
                dim=1)

            # set cycle tsf inputs
            cycle_tsf_img = F.grid_sample(fake_tsf_imgs, T_cycle)
            cycle_tsf_inputs = torch.cat([cycle_tsf_img, src_inputs[:, 3:]],
                                         dim=1)

            return cycle_src_inputs, cycle_tsf_inputs

        def warp(preds, tsf, fim, fake_tsf_mask):
            front_mask = self.render.encode_front_fim(fim, transpose=True)
            preds = (1 - front_mask) * preds + tsf * front_mask * (
                1 - fake_tsf_mask)
            # preds = torch.clamp(preds + tsf * front_mask, -1, 1)
            return preds

        def inference(src_inputs, tsf_inputs, T, T_cycle, src_fim, tsf_fim):
            fake_src_color, fake_src_mask, fake_tsf_color, fake_tsf_mask = \
                self.generator.infer_front(src_inputs, tsf_inputs, T=T)

            fake_src_imgs = fake_src_mask * bg_inpaint + (
                1 - fake_src_mask) * fake_src_color
            fake_tsf_imgs = fake_tsf_mask * bg_inpaint + (
                1 - fake_tsf_mask) * fake_tsf_color

            if self._opt.front_warp:
                fake_tsf_imgs = warp(fake_tsf_imgs, tsf_inputs[:, 0:3],
                                     tsf_fim, fake_tsf_mask)

            cycle_src_inputs, cycle_tsf_inputs = set_cycle_inputs(
                fake_tsf_imgs, src_inputs, tsf_inputs, T_cycle)

            cycle_src_color, cycle_src_mask, cycle_tsf_color, cycle_tsf_mask = \
                self.generator.infer_front(cycle_src_inputs, cycle_tsf_inputs, T=T_cycle)

            cycle_src_imgs = cycle_src_mask * bg_inpaint + (
                1 - cycle_src_mask) * cycle_src_color
            cycle_tsf_imgs = cycle_tsf_mask * bg_inpaint + (
                1 - cycle_tsf_mask) * cycle_tsf_color

            if self._opt.front_warp:
                cycle_tsf_imgs = warp(cycle_tsf_imgs, src_inputs[:, 0:3],
                                      src_fim, fake_src_mask)

            return fake_src_imgs, fake_tsf_imgs, cycle_src_imgs, cycle_tsf_imgs, fake_src_mask, fake_tsf_mask

        def create_criterion():
            face_criterion = FaceLoss(
                pretrained_path=self._opt.face_model).cuda()
            idt_criterion = torch.nn.L1Loss()
            mask_criterion = torch.nn.BCELoss()

            return face_criterion, idt_criterion, mask_criterion

        init_lr = 0.0002
        nodecay_epochs = 5
        optimizer = torch.optim.Adam(self.generator.parameters(),
                                     lr=init_lr,
                                     betas=(0.5, 0.999))
        face_cri, idt_cri, msk_cri = create_criterion()

        step = 0
        logger = tqdm(range(nodecay_epochs))
        for epoch in logger:
            for i, sample in enumerate(data_loader):
                src_fim, tsf_fim, j2ds, T, T_cycle, src_inputs, tsf_inputs, \
                images, init_preds, pseudo_masks = set_gen_inputs(sample)

                # print(bg_inputs.shape, src_inputs.shape, tsf_inputs.shape)
                bs = tsf_inputs.shape[0]
                src_imgs = images[0:bs]
                fake_src_imgs, fake_tsf_imgs, cycle_src_imgs, cycle_tsf_imgs, fake_src_mask, fake_tsf_mask = inference(
                    src_inputs, tsf_inputs, T, T_cycle, src_fim, tsf_fim)

                # cycle reconstruction loss
                cycle_loss = idt_cri(src_imgs, fake_src_imgs) + idt_cri(
                    src_imgs, cycle_tsf_imgs)

                # structure loss
                bg_mask = src_inputs[:, -1:]
                body_mask = 1 - bg_mask
                str_src_imgs = src_imgs * body_mask
                cycle_warp_imgs = F.grid_sample(fake_tsf_imgs, T_cycle)
                back_head_mask = 1 - self.render.encode_front_fim(
                    tsf_fim, transpose=True, front_fn=False)
                struct_loss = idt_cri(init_preds, fake_tsf_imgs) + \
                              2 * idt_cri(str_src_imgs * back_head_mask, cycle_warp_imgs * back_head_mask)

                fid_loss = face_cri(src_imgs, cycle_tsf_imgs, kps1=j2ds[:, 0], kps2=j2ds[:, 0]) + \
                           face_cri(init_preds, fake_tsf_imgs, kps1=j2ds[:, 1], kps2=j2ds[:, 1])

                # mask loss
                # mask_loss = msk_cri(fake_tsf_mask, tsf_inputs[:, -1:]) + msk_cri(fake_src_mask, src_inputs[:, -1:])
                mask_loss = msk_cri(
                    torch.cat([fake_src_mask, fake_tsf_mask], dim=0),
                    pseudo_masks)

                loss = 10 * cycle_loss + 10 * struct_loss + fid_loss + 5 * mask_loss
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                if verbose:
                    logger.set_description((
                        f'epoch: {epoch + 1}; step: {step}; '
                        f'total: {loss.item():.6f}; cyc: {cycle_loss.item():.6f}; '
                        f'str: {struct_loss.item():.6f}; fid: {fid_loss.item():.6f}; '
                        f'msk: {mask_loss.item():.6f}'))

                if verbose and step % 5 == 0:
                    self.visualize(visualizer,
                                   input_imgs=images,
                                   tsf_imgs=fake_tsf_imgs,
                                   cyc_imgs=cycle_tsf_imgs)

                step += 1

        self.generator.eval()
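
A matching driver for this example could look like the following sketch; the
option helper, file paths, and frame list are assumptions rather than code
from the original project.

# Hypothetical driver for the Imitator example above (paths are placeholders).
opt = TestOptions().parse()   # assumption: however this repo builds its `opt`
imitator = Imitator(opt)

# Personalize on a single source image.
src_img = cv_utils.read_cv2_img('imgs/source.jpg')
imitator.personalize(src_img)

# Imitate a target frame sequence; when tgt_smpls is None the SMPL parameters
# are estimated per frame by HMR.
frames = [cv_utils.read_cv2_img(p) for p in ['frames/0001.jpg', 'frames/0002.jpg']]
outputs = imitator.inference(frames, tgt_smpls=None, cam_strategy='smooth')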