예제 #1
0
    def personalize(self, src_path, src_smpl=None):
        """Build and return the source-info dict for a single source image.

        Everything runs under ``torch.no_grad()``.

        Args:
            src_path: path of the source image, read with ``cv_utils.read_cv2_img``.
            src_smpl: optional precomputed SMPL parameters. When ``None`` they are
                estimated by ``self.hmr`` from a 224-pixel copy of the image;
                otherwise they go through ``to_tensor``, are moved to the GPU and
                receive a leading batch dimension.

        Returns:
            The dict produced by ``self.hmr.get_details`` extended with the keys
            'bc_f2pts', 'image', 'tex', 'cond', 'fim', 'part', 'bg' and 'feats'.
        """
        with torch.no_grad():
            raw_img = cv_utils.read_cv2_img(src_path)

            # Network-resolution copy rescaled with x * 2 - 1 (assumes
            # transform_img yields [0, 1], giving [-1, 1]), batched on the GPU.
            net_img = cv_utils.transform_img(raw_img, self._opt.image_size, transpose=True) * 2 - 1.0
            img = torch.FloatTensor(net_img).cuda()[None, ...]

            if src_smpl is not None:
                smpl = to_tensor(src_smpl).cuda()[None, ...]
            else:
                hmr_img = cv_utils.transform_img(raw_img, 224, transpose=True) * 2 - 1.0
                smpl = self.hmr(torch.FloatTensor(hmr_img).cuda()[None, ...])[-1]

            # keys: 'theta', 'cam', 'pose', 'shape', 'verts', 'j2d', 'j3d'
            info = self.hmr.get_details(smpl)
            cam, verts = info['cam'], info['verts']

            info['bc_f2pts'] = self.get_src_bc_f2pts(cam, verts)
            info['image'] = img
            # Only the second output (texture) of the renderer is kept.
            _, info['tex'] = self.render.forward(cam, verts, img, is_uv_sampler=False,
                                                 reverse_yz=True, get_fim=False)
            info['cond'], info['fim'] = self.render.encode_fim(cam, verts, transpose=True)
            info['part'] = self.render.encode_front_fim(info['fim'], transpose=True)

            # The last channel of the pose condition drives both masks below.
            last_cond = info['cond'][:, -1:, :, :]

            # Background inpainting: erode (kernel size 15), then feed the masked
            # image concatenated with the mask to the background model.
            bg_mask = self.morph(last_cond, ks=15, mode='erode')
            info['bg'] = self.model.bg_model(torch.cat([img * bg_mask, bg_mask], dim=1))

            # Source identity features: image with the eroded (ks=3) region
            # zeroed out, concatenated with the full pose condition.
            crop_mask = self.morph(last_cond, ks=3, mode='erode')
            info['feats'] = self.model.src_model.inference(
                torch.cat([img * (1 - crop_mask), info['cond']], dim=1))

            return info
예제 #2
0
    def inference_by_smpls(self,
                           tgt_smpls,
                           cam_strategy='smooth',
                           output_dir='',
                           visualizer=None):
        """Run the transfer network once per target SMPL entry and collect predictions.

        Args:
            tgt_smpls: sized iterable of target SMPL parameters; each item is
                passed unchanged to ``self.transfer_params_by_smpl``.
            cam_strategy: camera strategy forwarded to
                ``transfer_params_by_smpl``; also used in the visualizer tag.
            output_dir: when non-empty, each prediction is saved there as
                ``pred_%.8d.jpg`` (zero-padded frame index).
            visualizer: optional object with ``vis_named_img``; receives the
                prediction and, when ``self.tsf_info`` holds an 'image', the
                ground-truth frame.

        Returns:
            List of per-frame predictions as numpy arrays, taken from the first
            batch element of ``self.forward``'s output with axes permuted
            (1, 2, 0).
        """
        outputs = []
        for t, tgt_smpl in enumerate(tqdm(tgt_smpls)):
            tsf_inputs = self.transfer_params_by_smpl(tgt_smpl, cam_strategy)
            preds = self.forward(tsf_inputs, self.tsf_info['T'])

            if visualizer is not None:
                visualizer.vis_named_img('pred_' + cam_strategy, preds)
                # The SMPL-only path has no source frame of its own; show the
                # ground truth only when one was attached to tsf_info.
                gt_img = self.tsf_info.get('image')
                if gt_img is not None:
                    gt = cv_utils.transform_img(gt_img,
                                                image_size=self._opt.image_size,
                                                transpose=True)
                    visualizer.vis_named_img('gt',
                                             gt[None, ...],
                                             denormalize=False)

            preds = preds[0].permute(1, 2, 0).cpu().numpy()
            outputs.append(preds)

            if output_dir:
                cv_utils.save_cv2_img(preds,
                                      os.path.join(output_dir,
                                                   'pred_%.8d.jpg' % t),
                                      normalize=True)

        return outputs
예제 #3
0
    def inference(self, tgt_paths, tgt_smpls=None, cam_strategy='smooth', output_dir='', visualizer=None, verbose=True):
        """Transfer the current source onto every target image and collect predictions.

        Args:
            tgt_paths: indexable sequence of target image paths.
            tgt_smpls: optional sequence of SMPL parameters aligned with
                ``tgt_paths``; ``None`` lets ``transfer_params`` estimate them.
            cam_strategy: camera strategy forwarded to ``transfer_params``.
            output_dir: when non-empty, writes ``pred_<name>`` and ``gt_<name>``
                images there for each target.
            visualizer: optional object with ``vis_named_img``.
            verbose: wrap the frame loop in a ``tqdm`` progress bar.

        Returns:
            List of numpy predictions, one per target, from the first batch
            element of ``self.forward``'s output with axes permuted (1, 2, 0).
        """
        frame_ids = range(len(tgt_paths))
        if verbose:
            frame_ids = tqdm(frame_ids)

        results = []
        for idx in frame_ids:
            path = tgt_paths[idx]
            smpl = None if tgt_smpls is None else tgt_smpls[idx]

            net_inputs = self.transfer_params(path, smpl, cam_strategy, t=idx)
            out = self.forward(net_inputs, self.tsf_info['T'])

            if visualizer is not None:
                gt = cv_utils.transform_img(self.tsf_info['image'], image_size=self._opt.image_size, transpose=True)
                visualizer.vis_named_img('pred_' + cam_strategy, out)
                visualizer.vis_named_img('gt', gt[None, ...], denormalize=False)

            frame = out[0].permute(1, 2, 0).cpu().numpy()
            results.append(frame)

            if output_dir:
                name = os.path.split(path)[-1]
                cv_utils.save_cv2_img(frame, os.path.join(output_dir, 'pred_' + name), normalize=True)
                cv_utils.save_cv2_img(self.tsf_info['image'], os.path.join(output_dir, 'gt_' + name),
                                      image_size=self._opt.image_size)

        return results
예제 #4
0
    def _extract_smpls(self, input_file):
        """Estimate SMPL parameters for one image file with ``self.hmr``.

        Args:
            input_file: path of the image to read.

        Returns:
            The last element of ``self.hmr``'s output for a batch of one image.
        """
        # hmr receives [-1, 1] input at 224 pixels.
        scaled = cv_utils.transform_img(cv_utils.read_cv2_img(input_file), image_size=224) * 2 - 1.0
        # Move the last axis to the front (presumably HWC -> CHW), add a batch
        # dimension and ship to the GPU.
        batch = torch.FloatTensor(scaled.transpose((2, 0, 1))).cuda()[None, ...]
        return self.hmr(batch)[-1]
예제 #5
0
    def transfer(self, tgt_path, tgt_smpl=None, cam_strategy='smooth', t=0, visualizer=None):
        """Synthesize the source person in the target pose.

        Body parts listed in ``self.PART_IDS['body']`` are warped from
        ``self.ref_info``; the remaining parts come from ``self.src_info``. The
        blended warp plus the target pose condition is decoded by
        ``self.forward2``. Runs under ``torch.no_grad()``.

        Args:
            tgt_path: target image path; only read when ``tgt_smpl`` is ``None``
                (it is then fed to ``self.hmr`` to estimate SMPL parameters).
            tgt_smpl: optional target SMPL parameters, converted with
                ``to_tensor`` and given a leading batch dimension.
            cam_strategy: forwarded to ``self.swap_smpl``; with 'smooth' and
                ``t == 0`` the target's first three SMPL values are stored in
                ``self.first_cam``.
            t: frame index.
            visualizer: optional object with ``vis_named_img`` for debug images.

        Returns:
            The output of ``self.forward2``.
        """
        with torch.no_grad():
            # 1. get source info
            src_info = self.src_info

            if tgt_smpl is None:
                # The target image is needed only for HMR estimation, so it is
                # read and decoded only on this path.
                ori_img = cv_utils.read_cv2_img(tgt_path)
                img_hmr = cv_utils.transform_img(ori_img, 224, transpose=True) * 2 - 1.0
                img_hmr = torch.FloatTensor(img_hmr).cuda()[None, ...]
                tgt_smpl = self.hmr(img_hmr)[-1]
            else:
                tgt_smpl = to_tensor(tgt_smpl).cuda()[None, ...]

            if t == 0 and cam_strategy == 'smooth':
                self.first_cam = tgt_smpl[:, 0:3].clone()

            # 2. compute tsf smpl
            tsf_smpl = self.swap_smpl(src_info['cam'], src_info['shape'], tgt_smpl, cam_strategy=cam_strategy)
            tsf_info = self.hmr.get_details(tsf_smpl)
            # add pose condition and face index map into transfer info
            tsf_info['cond'], tsf_info['fim'] = self.render.encode_fim(tsf_info['cam'],
                                                                       tsf_info['verts'], transpose=True)
            # add part condition into transfer info
            tsf_info['part'] = self.render.encode_front_fim(tsf_info['fim'], transpose=True)

            # 3. calculate syn front image and transformation flows
            ref_info = self.ref_info
            selected_part_id = self.PART_IDS['body']
            left_id = [i for i in self.PART_IDS['all'] if i not in selected_part_id]

            # Non-body parts are taken from the source, body parts from the reference.
            src_part_mask = (torch.sum(tsf_info['part'][:, left_id, ...], dim=1) != 0).byte()
            ref_part_mask = (torch.sum(tsf_info['part'][:, selected_part_id, ...], dim=1) != 0).byte()

            T_s = self.calculate_trans(src_info['bc_f2pts'], src_info['fim'], tsf_info['fim'], src_part_mask)
            T_r = self.calculate_trans(ref_info['bc_f2pts'], ref_info['fim'], tsf_info['fim'], ref_part_mask)

            tsf_s = self.model.transform(src_info['image'], T_s)
            tsf_r = self.model.transform(ref_info['image'], T_r)

            tsf_img = tsf_s * src_part_mask.float() + tsf_r * ref_part_mask.float()
            tsf_inputs = torch.cat([tsf_img, tsf_info['cond']], dim=1)

            preds = self.forward2(tsf_inputs, src_info['feats'], T_s, ref_info['feats'], T_r, src_info['bg'])

            if visualizer is not None:
                visualizer.vis_named_img('src', src_info['image'])
                visualizer.vis_named_img('ref', ref_info['image'])
                visualizer.vis_named_img('src_cond', src_info['cond'])
                visualizer.vis_named_img('ref_cond', ref_info['cond'])
                visualizer.vis_named_img('tsf_cond', tsf_info['cond'])
                visualizer.vis_named_img('tsf_s', tsf_s)
                visualizer.vis_named_img('tsf_r', tsf_r)
                visualizer.vis_named_img('tsf_img', tsf_img)
                visualizer.vis_named_img('preds', preds)
                visualizer.vis_named_img('src_part_mask', src_part_mask)
                visualizer.vis_named_img('ref_part_mask', ref_part_mask)

            return preds
예제 #6
0
    def inference(self,
                  tgt_paths,
                  tgt_smpls=None,
                  cam_strategy='smooth',
                  output_dir='',
                  visualizer=None,
                  verbose=True):
        """Generate one composited prediction per target image.

        For each target, ``self.generator.inference`` yields a color image and
        a mask. The prediction is ``mask * bg + (1 - mask) * color``, using the
        source background ``self.src_info['bg']``. It is optionally refined by
        ``self.warp_front`` when ``self._opt.front_warp`` is set.

        Args:
            tgt_paths: indexable sequence of target image paths.
            tgt_smpls: optional aligned sequence of SMPL parameters.
            cam_strategy: forwarded to ``transfer_params``.
            output_dir: when non-empty, writes ``pred_<name>`` / ``gt_<name>``.
            visualizer: optional object with ``vis_named_img``.
            verbose: show a ``tqdm`` progress bar.

        Returns:
            List of numpy predictions (first batch element, axes permuted
            (1, 2, 0)).
        """
        num_frames = len(tgt_paths)

        collected = []
        background = self.src_info['bg']
        enc_outs, res_outs = self.src_info['feats']

        frame_iter = range(num_frames)
        if verbose:
            frame_iter = tqdm(frame_iter)

        for idx in frame_iter:
            path = tgt_paths[idx]
            smpl = None if tgt_smpls is None else tgt_smpls[idx]

            net_inputs = self.transfer_params(path, smpl, cam_strategy, t=idx)

            color, mask = self.generator.inference(enc_outs, res_outs, net_inputs,
                                                   self.tsf_info['T'])
            # mask selects the source background, (1 - mask) the generated color
            out = mask * background + (1 - mask) * color

            if self._opt.front_warp:
                out = self.warp_front(out, self.tsf_info['tsf_img'],
                                      self.tsf_info['fim'], mask)

            if visualizer is not None:
                gt = cv_utils.transform_img(self.tsf_info['image'],
                                            image_size=self._opt.image_size,
                                            transpose=True)
                visualizer.vis_named_img('pred_' + cam_strategy, out)
                visualizer.vis_named_img('gt', gt[None, ...], denormalize=False)

            frame = out[0].permute(1, 2, 0).cpu().numpy()
            collected.append(frame)

            if output_dir:
                name = os.path.split(path)[-1]
                cv_utils.save_cv2_img(frame,
                                      os.path.join(output_dir, 'pred_' + name),
                                      normalize=True)
                cv_utils.save_cv2_img(self.tsf_info['image'],
                                      os.path.join(output_dir, 'gt_' + name),
                                      image_size=self._opt.image_size)

        return collected
예제 #7
0
    def load_init_preds(self, pred_path):
        """Load the saved prediction matching *pred_path*'s file name.

        The file ``pred_<basename>`` is read from ``self._opt.preds_img_folder``,
        resized to ``self._opt.image_size`` with ``transpose=True`` and mapped
        through ``x * 2 - 1`` (presumably [0, 1] -> [-1, 1]).
        """
        base_name = os.path.basename(pred_path)
        saved_path = os.path.join(self._opt.preds_img_folder, 'pred_' + base_name)

        resized = cv_utils.transform_img(cv_utils.read_cv2_img(saved_path),
                                         self._opt.image_size,
                                         transpose=True)
        return resized * 2 - 1
 def load_images(self, im_pairs):
     """Read, resize and stack images into a single numpy batch.

     Each path is read with ``cv_utils.read_cv2_img``, resized to
     ``self._opt.image_size`` with ``transpose=True`` and mapped through
     ``x * 2 - 1`` (presumably [0, 1] -> [-1, 1]). ``np.stack`` raises
     ``ValueError`` when *im_pairs* is empty.
     """
     batch = [
         cv_utils.transform_img(cv_utils.read_cv2_img(path),
                                self._opt.image_size,
                                transpose=True) * 2 - 1
         for path in im_pairs
     ]
     return np.stack(batch)
예제 #9
0
    def transfer_params(self,
                        tgt_path,
                        tgt_smpl=None,
                        cam_strategy='smooth',
                        t=0):
        """Prepare generator inputs for one target frame.

        The source image ``self.src_info['img']`` is warped to the transferred
        pose with ``F.grid_sample``, and the result is concatenated with the
        target pose condition. Side effects: sets ``self.T``, ``self.tsf_info``
        and, for ``cam_strategy == 'smooth'`` at ``t == 0``, ``self.first_cam``.

        Args:
            tgt_path: target image path; always read, stored as
                ``tsf_info['image']``.
            tgt_smpl: optional target SMPL parameters (array-like or tensor).
                They are copied to a float32 CUDA tensor with a leading batch
                dimension. When ``None``, they are estimated by ``self.hmr``.
            cam_strategy: forwarded to ``self.swap_smpl``.
            t: frame index.

        Returns:
            ``torch.cat([warped_src_img, tsf_info['cond']], dim=1)``.
        """
        # get source info
        src_info = self.src_info

        ori_img = cv_utils.read_cv2_img(tgt_path)
        if tgt_smpl is None:
            img_hmr = cv_utils.transform_img(ori_img, 224,
                                             transpose=True) * 2 - 1.0
            img_hmr = torch.tensor(img_hmr, dtype=torch.float32).cuda()[None, ...]
            tgt_smpl = self.hmr(img_hmr)
        elif isinstance(tgt_smpl, torch.Tensor):
            # torch.tensor(tensor) copies but emits a UserWarning; the documented
            # equivalent is detach().clone() followed by the dtype cast.
            tgt_smpl = tgt_smpl.detach().clone().to(dtype=torch.float32).cuda()[None, ...]
        else:
            tgt_smpl = torch.tensor(tgt_smpl, dtype=torch.float32).cuda()[None, ...]

        if t == 0 and cam_strategy == 'smooth':
            self.first_cam = tgt_smpl[:, 0:3].clone()

        # get transfer smpl
        tsf_smpl = self.swap_smpl(src_info['cam'],
                                  src_info['shape'],
                                  tgt_smpl,
                                  cam_strategy=cam_strategy)
        # transfer process, {'theta', 'cam', 'pose', 'shape', 'verts', 'j2d', 'j3d'}
        tsf_info = self.hmr.get_details(tsf_smpl)

        # The first output (face-to-vertex map) is not needed here.
        _, tsf_fim, tsf_wim = self.render.render_fim_wim(
            tsf_info['cam'], tsf_info['verts'])
        tsf_info['fim'] = tsf_fim
        tsf_info['wim'] = tsf_wim
        tsf_info['cond'], _ = self.render.encode_fim(tsf_info['cam'],
                                                     tsf_info['verts'],
                                                     fim=tsf_fim,
                                                     transpose=True)

        # Sampling grid mapping transfer pixels back into the source image.
        T = self.render.cal_bc_transform(src_info['p2verts'], tsf_fim, tsf_wim)
        tsf_img = F.grid_sample(src_info['img'], T)
        tsf_inputs = torch.cat([tsf_img, tsf_info['cond']], dim=1)

        # add target image to tsf info
        tsf_info['tsf_img'] = tsf_img
        tsf_info['image'] = ori_img
        tsf_info['T'] = T

        self.T = T
        self.tsf_info = tsf_info

        return tsf_inputs
예제 #10
0
    def transfer_params(self, tgt_path, tgt_smpl=None, cam_strategy='smooth', t=0):
        """Resolve target SMPL parameters, then delegate to ``transfer_params_by_smpl``.

        The target image is always read and attached as
        ``self.tsf_info['image']`` after the delegate call. When ``tgt_smpl`` is
        ``None`` it is estimated by ``self.hmr``. A numpy array is converted to
        a float32 CUDA tensor with a leading batch dimension. Any other value is
        forwarded unchanged.

        Returns:
            Whatever ``transfer_params_by_smpl`` returns.
        """
        target_img = cv_utils.read_cv2_img(tgt_path)

        if tgt_smpl is None:
            hmr_in = cv_utils.transform_img(target_img, 224, transpose=True) * 2 - 1.0
            tgt_smpl = self.hmr(torch.tensor(hmr_in, dtype=torch.float32).cuda()[None, ...])
        elif isinstance(tgt_smpl, np.ndarray):
            tgt_smpl = torch.tensor(tgt_smpl, dtype=torch.float32).cuda()[None, ...]

        net_inputs = self.transfer_params_by_smpl(tgt_smpl=tgt_smpl, cam_strategy=cam_strategy, t=t)
        self.tsf_info['image'] = target_img
        return net_inputs
예제 #11
0
    def personalize(self, src_path, src_smpl=None, output_path='', visualizer=None):
        """Build and return the source-info dict for a single source image.

        Args:
            src_path: path of the source image.
            src_smpl: optional SMPL parameters. When ``None`` they are estimated
                with ``self.hmr`` from a 224-pixel copy of the image. Otherwise
                they are converted to a float32 CUDA tensor with a leading batch
                dimension.
            output_path: when non-empty, the raw source image is saved there,
                resized to ``self._opt.image_size``.
            visualizer: accepted for interface compatibility; currently unused.

        Returns:
            The dict from ``self.hmr.get_details`` extended with 'fim', 'wim',
            'cond', 'f2verts', 'p2verts', 'part', 'img', 'image', 'bg', 'feats',
            'src_inputs' and, when ``self._opt.bg_model == 'ORIGINAL'``,
            'bg_inputs'.
        """
        ori_img = cv_utils.read_cv2_img(src_path)

        # resize image and rescale with x * 2 - 1 to [-1, 1]
        img = cv_utils.transform_img(ori_img, self._opt.image_size, transpose=True) * 2 - 1.0
        img = torch.tensor(img, dtype=torch.float32).cuda()[None, ...]

        if src_smpl is None:
            img_hmr = cv_utils.transform_img(ori_img, 224, transpose=True) * 2 - 1.0
            img_hmr = torch.tensor(img_hmr, dtype=torch.float32).cuda()[None, ...]
            src_smpl = self.hmr(img_hmr)
        else:
            src_smpl = torch.tensor(src_smpl, dtype=torch.float32).cuda()[None, ...]

        # 1. source process, {'theta', 'cam', 'pose', 'shape', 'verts', 'j2d', 'j3d'}
        src_info = self.hmr.get_details(src_smpl)
        src_f2verts, src_fim, src_wim = self.render.render_fim_wim(src_info['cam'], src_info['verts'])
        src_info['fim'] = src_fim
        src_info['wim'] = src_wim
        src_info['cond'], _ = self.render.encode_fim(src_info['cam'], src_info['verts'], fim=src_fim, transpose=True)
        src_info['f2verts'] = src_f2verts
        # NOTE(review): basic slicing returns a view, so the in-place negation
        # below also flips channel 1 of src_info['f2verts']. Behavior kept
        # as-is; add .clone() here if callers need f2verts unmodified.
        src_info['p2verts'] = src_f2verts[:, :, :, 0:2]
        src_info['p2verts'][:, :, :, 1] *= -1

        if self._opt.only_vis:
            src_info['p2verts'] = self.render.get_vis_f2pts(src_info['p2verts'], src_fim)

        src_info['part'], _ = self.render.encode_fim(src_info['cam'], src_info['verts'],
                                                     fim=src_fim, transpose=True, map_fn=self.part_fn)
        # add image to source info: 'img' is the network tensor, 'image' the raw read
        src_info['img'] = img
        src_info['image'] = ori_img

        # 2. process the src inputs
        if self.detector is not None:
            # The detector's first output (bounding box) is not needed here.
            _, body_mask = self.detector.inference(img[0])
            bg_mask = 1 - body_mask
        else:
            bg_mask = util.morph(src_info['cond'][:, -1:, :, :], ks=self._opt.bg_ks, mode='erode')
            body_mask = 1 - bg_mask

        if self._opt.bg_model != 'ORIGINAL':
            src_info['bg'] = self.bgnet(img, masks=body_mask, only_x=True)
        else:
            incomp_img = img * bg_mask
            bg_inputs = torch.cat([incomp_img, bg_mask], dim=1)
            img_bg = self.bgnet(bg_inputs)
            src_info['bg_inputs'] = bg_inputs
            # keep known background pixels, fill the body region with bgnet output
            src_info['bg'] = incomp_img + img_bg * body_mask

        ft_mask = 1 - util.morph(src_info['cond'][:, -1:, :, :], ks=self._opt.ft_ks, mode='erode')
        src_inputs = torch.cat([img * ft_mask, src_info['cond']], dim=1)

        src_info['feats'] = self.generator.encode_src(src_inputs)
        src_info['src_inputs'] = src_inputs

        if output_path:
            cv_utils.save_cv2_img(src_info['image'], output_path, image_size=self._opt.image_size)

        return src_info