Example 1
def vis_smpl_iuv(image, cam_pred, vert_pred, face, pred_uv, vert_errors_batch,
                 image_name, save_path, opt):
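    """Render the predicted SMPL meshes (and, when given, the predicted IUV maps)
    for a batch of images and save the composited results under save_path.
    """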

    # save_path = os.path.join('./notebooks/output/demo_results-wild', ids[f_id][0])
    if not os.path.exists(save_path):
        os.makedirs(save_path)
    dr_render = opendr_render(ratio=opt.ratio, color=opt.color)

    focal_length = 5000.
    orig_size = 224.

    if pred_uv is not None:
        iuv_img = iuv_map2img(*pred_uv)

    for draw_i in range(len(cam_pred)):
        err_val = '{:06d}_'.format(int(10 * vert_errors_batch[draw_i]))
        draw_name = err_val + image_name[draw_i]
        # cam_pred_opendr = cam_pred[draw_i]
        # cam_pred_opendr[:, 2] = - cam_pred_opendr[:, 2]
        K = np.array([[focal_length, 0., orig_size / 2.],
                      [0., focal_length, orig_size / 2.], [0., 0., 1.]])

        img_orig, img_resized, img_smpl, render_smpl_rgba = dr_render.render(
            image[draw_i], cam_pred[draw_i], K, vert_pred[draw_i], face,
            draw_name[:-4])

        ones_img = np.ones(img_smpl.shape[:2]) * 255
        ones_img = ones_img[:, :, None]
        img_smpl_rgba = np.concatenate((img_smpl * 255, ones_img), axis=2)
        img_resized_rgba = np.concatenate((img_resized * 255, ones_img),
                                          axis=2)

        render_img = np.concatenate(
            (img_resized_rgba, img_smpl_rgba, render_smpl_rgba * 255), axis=1)
        render_img[render_img < 0] = 0
        render_img[render_img > 255] = 255
        matplotlib.image.imsave(
            os.path.join(save_path, draw_name[:-4] + '.png'),
            render_img.astype(np.uint8))

        if pred_uv is not None:
            # estimated global IUV
            global_iuv = iuv_img[draw_i].cpu().numpy()
            global_iuv = np.transpose(global_iuv, (1, 2, 0))
            global_iuv = resize(global_iuv, img_resized.shape[:2])
            global_iuv[global_iuv > 1] = 1
            global_iuv[global_iuv < 0] = 0
            matplotlib.image.imsave(
                os.path.join(save_path, 'pred_uv_' + draw_name[:-4] + '.png'),
                global_iuv)
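# --- Usage sketch (added for illustration; not part of the original example) ---
# A minimal call to vis_smpl_iuv; every shape, the `opt` fields and the output
# directory below are assumptions made only to show the expected arguments.
#
#   import numpy as np
#   from types import SimpleNamespace
#
#   images = np.random.rand(2, 224, 224, 3)             # cropped inputs in [0, 1]
#   cams = np.random.rand(2, 3)                          # weak-perspective cameras
#   verts = np.random.rand(2, 6890, 3)                   # SMPL vertices per sample
#   faces = np.zeros((13776, 3), dtype=np.int64)         # SMPL triangle faces
#   errors = np.random.rand(2) * 100                     # per-sample vertex error (mm)
#   names = ['img_000.jpg', 'img_001.jpg']
#   opt = SimpleNamespace(ratio=1, color='white')        # renderer options (assumed)
#
#   vis_smpl_iuv(images, cams, verts, faces, None, errors, names, './demo_out', opt)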
Example 2
    def _forward(self, in_dict):
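        """Predict IUV maps from the input images, regress SMPL parameters from
        them, and collect losses, metrics, predictions and visualizations in one
        dict.
        """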

        if not isinstance(in_dict, dict):
            in_dict = {'img': in_dict, 'pretrain_mode': False, 'vis_on': False, 'dataset_name': ''}

        image = in_dict['img']
        gt_pose = in_dict['opt_pose'] if 'opt_pose' in in_dict else None  # SMPL pose parameters
        gt_betas = in_dict['opt_betas'] if 'opt_betas' in in_dict else None  # SMPL beta parameters
        target_kps = in_dict['target_kps'] if 'target_kps' in in_dict else None
        target_kps3d = in_dict['target_kps3d'] if 'target_kps3d' in in_dict else None
        has_iuv = in_dict['has_iuv'].byte() if 'has_iuv' in in_dict else None
        has_dp = in_dict['has_dp'].byte() if 'has_dp' in in_dict else None
        has_kp3d = in_dict['has_pose_3d'].byte() if 'has_pose_3d' in in_dict else None  # flag that indicates whether 3D pose is valid
        target_smpl_kps = in_dict['target_smpl_kps'] if 'target_smpl_kps' in in_dict else None
        target_verts = in_dict['target_verts'] if 'target_verts' in in_dict else None
        valid_fit = in_dict['valid_fit'] if 'valid_fit' in in_dict else None

        batch_size = image.shape[0]

        if gt_pose is not None:
            gt_rotmat = batch_rodrigues(gt_pose.view(-1, 3)).view(-1, 24 * 3 * 3)
            target_cam = in_dict['target_cam']
            target = torch.cat([target_cam, gt_betas, gt_rotmat], dim=1)
            uv_image_gt = torch.zeros((batch_size, 3, cfg.DANET.HEATMAP_SIZE, cfg.DANET.HEATMAP_SIZE)).to(image.device)
            if torch.sum(has_iuv) > 0:
                uv_image_gt[has_iuv] = self.iuv2smpl.verts2uvimg(target_verts[has_iuv], cam=target_cam[has_iuv])  # [B, 3, 56, 56]
        else:
            target = None
            uv_image_gt = None  # no SMPL ground truth for this batch

        # target_iuv_dp = in_dict['target_iuv_dp'] if 'target_iuv_dp' in in_dict else None
        target_iuv_dp = in_dict['dp_dict'] if 'dp_dict' in in_dict else None

        if 'target_kps_coco' in in_dict:
            target_kps = in_dict['target_kps_coco']

        return_dict = {}
        return_dict['losses'] = {}
        return_dict['metrics'] = {}
        return_dict['visualization'] = {}
        return_dict['prediction'] = {}

        if cfg.DANET.INPUT_MODE == 'iuv_gt':
            if cfg.DANET.DECOMPOSED:
                uv_return_dict = self.img2iuv(image, uv_image_gt, target_smpl_kps, pretrained=in_dict['pretrain_mode'], uvia_dp_gt=target_iuv_dp)
                uv_return_dict['uvia_pred'] = iuv_img2map(uv_image_gt)
            else:
                uv_return_dict = {}
                uv_return_dict['uvia_pred'] = iuv_img2map(uv_image_gt)
        elif cfg.DANET.INPUT_MODE in ['iuv_gt_feat']:
            uv_return_dict = self.img2iuv(image, uv_image_gt, target_smpl_kps, pretrained=in_dict['pretrain_mode'], uvia_dp_gt=target_iuv_dp)
            uv_return_dict['uvia_pred'] = iuv_img2map(uv_image_gt)
        elif cfg.DANET.INPUT_MODE in ['feat']:
            uv_return_dict = self.img2iuv(image, None, target_smpl_kps, pretrained=in_dict['pretrain_mode'], uvia_dp_gt=target_iuv_dp)
        else:
            uv_return_dict = self.img2iuv(image, uv_image_gt, target_smpl_kps, pretrained=in_dict['pretrain_mode'], uvia_dp_gt=target_iuv_dp, has_iuv=has_iuv, has_dp=has_dp)

        u_pred, v_pred, index_pred, ann_pred = uv_return_dict['uvia_pred']
        if self.training and cfg.DANET.PART_IUV_ZERO > 0:
            # randomly zero out a subset of the 24 part channels for robustness
            # (channel 0 is the background, hence the +1 offset; see the sketch below)
            zero_idxs = []
            for bs in range(u_pred.shape[0]):
                zero_idxs.append([int(i) + 1 for i in torch.nonzero(torch.rand(24) < cfg.DANET.PART_IUV_ZERO)])
                u_pred[bs, zero_idxs[bs]] *= 0
                v_pred[bs, zero_idxs[bs]] *= 0
                index_pred[bs, zero_idxs[bs]] *= 0

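        # --- Sketch (added for illustration) of the part dropout applied above ---
        # A random subset of the 24 part channels is zeroed out during training;
        # channel 0 is the background, hence the +1 offset. The dropout probability
        # below is an assumed value, not from the original code.
        #
        #   drop = torch.nonzero(torch.rand(24) < 0.2).flatten() + 1
        #   u_pred[0, drop] = 0
        #   v_pred[0, drop] = 0
        #   index_pred[0, drop] = 0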
        u_pred_cl, v_pred_cl, index_pred_cl, ann_pred_cl = iuvmap_clean(u_pred, v_pred, index_pred, ann_pred)

        iuv_pred_clean = [u_pred_cl.detach(), v_pred_cl.detach(), index_pred_cl.detach(), ann_pred_cl.detach()]
        return_dict['visualization']['iuv_pred'] = iuv_pred_clean

        if in_dict['vis_on']:
            uvi_pred_clean = [u_pred_cl.detach(), v_pred_cl.detach(), index_pred_cl.detach(), ann_pred_cl.detach()]
            return_dict['visualization']['pred_uv'] = iuv_map2img(*uvi_pred_clean)
            return_dict['visualization']['gt_uv'] = uv_image_gt
            if 'stn_kps_pred' in uv_return_dict:
                return_dict['visualization']['stn_kps_pred'] = uv_return_dict['stn_kps_pred']

            # index_pred_cl shape:  2, 25, 56, 56
            return_dict['visualization']['index_sum'] = [torch.sum(index_pred_cl[:, 1:].detach()).unsqueeze(0), np.prod(index_pred_cl[:, 0].shape)]

            for key in ['skps_hm_pred', 'skps_hm_gt']:
                if key in uv_return_dict:
                    return_dict['visualization'][key] = torch.max(uv_return_dict[key], dim=1)[0].unsqueeze(1)
                    return_dict['visualization'][key][return_dict['visualization'][key] > 1] = 1.
                    skps_hm_vis = uv_return_dict[key]
                    skps_hm_vis = skps_hm_vis.reshape((skps_hm_vis.shape[0], skps_hm_vis.shape[1], -1))
                    skps_hm_vis = F.softmax(skps_hm_vis, 2)
                    skps_hm_vis = skps_hm_vis.reshape(skps_hm_vis.shape[0], skps_hm_vis.shape[1],
                                                      cfg.DANET.HEATMAP_SIZE, cfg.DANET.HEATMAP_SIZE)
                    return_dict['visualization'][key + '_soft'] = torch.sum(skps_hm_vis, dim=1).unsqueeze(1)
            # for key in ['part_uvi_pred', 'part_uvi_gt']:
            for key in ['part_uvi_gt']:
                if key in uv_return_dict:
                    part_uvi_pred_vis = uv_return_dict[key][0]
                    p_uvi_vis = []
                    for i in range(part_uvi_pred_vis.size(0)):
                        p_u_vis, p_v_vis, p_i_vis = [part_uvi_pred_vis[i, uvi].unsqueeze(0) for uvi in range(3)]
                        if p_u_vis.size(1) == 25:
                            p_uvi_vis_i = iuv_map2img(p_u_vis.detach(), p_v_vis.detach(), p_i_vis.detach())
                        else:
                            p_uvi_vis_i = iuv_map2img(p_u_vis.detach(), p_v_vis.detach(), p_i_vis.detach(),
                                                         ind_mapping=[0] + self.img2iuv.dp2smpl_mapping[i])
                        # p_uvi_vis_i = uvmap_vis(p_u_vis.detach(), p_v_vis.detach(), p_i_vis.detach(), self.img2iuv.dp2smpl_mapping[i])
                        p_uvi_vis.append(p_uvi_vis_i)
                    return_dict['visualization'][key] = torch.cat(p_uvi_vis, dim=0)

        if not in_dict['pretrain_mode']:

            iuv_map = torch.cat([u_pred_cl, v_pred_cl, index_pred_cl], dim=1)

            if cfg.DANET.INPUT_MODE in ['iuv_gt', 'iuv_gt_feat'] and 'part_iuv_gt' in uv_return_dict:
                part_iuv_map = uv_return_dict['part_iuv_gt']
                if self.training and cfg.DANET.PART_IUV_ZERO > 0:
                    for bs in range(len(zero_idxs)):
                        zero_channel = []
                        for zero_i in zero_idxs[bs]:
                            zero_channel.extend(
                                [(i, m_i + 1) for i, mapping in enumerate(self.img2iuv.dp2smpl_mapping) for m_i, map_idx in
                                 enumerate(mapping) if map_idx == zero_i])
                        zero_dp_i = [iterm[0] for iterm in zero_channel]
                        zero_p_i = [iterm[1] for iterm in zero_channel]
                        part_iuv_map[bs, zero_dp_i, :, zero_p_i] *= 0

                part_index_map = part_iuv_map[:, :, 2]
            elif 'part_iuv_pred' in uv_return_dict:
                part_iuv_pred = uv_return_dict['part_iuv_pred']
                if self.training and cfg.DANET.PART_IUV_ZERO > 0:
                    for bs in range(len(zero_idxs)):
                        zero_channel = []
                        for zero_i in zero_idxs[bs]:
                            zero_channel.extend(
                                [(i, m_i + 1) for i, mapping in enumerate(self.img2iuv.dp2smpl_mapping) for m_i, map_idx in
                                 enumerate(mapping) if map_idx == zero_i])
                        zero_dp_i = [iterm[0] for iterm in zero_channel]
                        zero_p_i = [iterm[1] for iterm in zero_channel]
                        part_iuv_pred[bs, zero_dp_i, :, zero_p_i] *= 0

                part_iuv_map = []
                for p_ind in range(part_iuv_pred.size(1)):
                    p_u_pred, p_v_pred, p_index_pred = [part_iuv_pred[:, p_ind, iuv] for iuv in range(3)]
                    p_u_map, p_v_map, p_i_map, _ = iuvmap_clean(p_u_pred, p_v_pred, p_index_pred)
                    p_iuv_map = torch.stack([p_u_map, p_v_map, p_i_map], dim=1)
                    part_iuv_map.append(p_iuv_map)
                part_iuv_map = torch.stack(part_iuv_map, dim=1)
                part_index_map = part_iuv_map[:, :, 2]

            else:
                part_iuv_map = None
                part_index_map = None

            return_dict['visualization']['part_iuv_pred'] = part_iuv_map

            if 'part_featmaps' in uv_return_dict:
                part_feat_map = uv_return_dict['part_featmaps']
            else:
                part_feat_map = None

            if cfg.DANET.INPUT_MODE == 'feat':
                smpl_return_dict = self.iuv2smpl({'iuv_map': {'feat': uv_return_dict['global_featmaps']},
                                                 'part_iuv_map': {'pfeat': part_feat_map},
                                                 'target': target,
                                                 'target_kps': target_kps,
                                                 'target_verts': target_verts,
                                                 'target_kps3d': target_kps3d,
                                                 'has_kp3d': has_kp3d
                                                  })
            elif cfg.DANET.INPUT_MODE in ['iuv_feat', 'iuv_gt_feat']:
                smpl_return_dict = self.iuv2smpl({'iuv_map': {'iuv': iuv_map, 'feat': uv_return_dict['global_featmaps']},
                                                 'part_iuv_map': {'piuv': part_iuv_map, 'pfeat': part_feat_map},
                                                 'target': target,
                                                 'target_kps': target_kps,
                                                 'target_verts': target_verts,
                                                 'target_kps3d': target_kps3d,
                                                 'has_kp3d': has_kp3d
                                                  })
            elif cfg.DANET.INPUT_MODE in ['iuv', 'iuv_gt']:
                smpl_return_dict = self.iuv2smpl({'iuv_map': iuv_map,
                                                 'part_iuv_map': part_iuv_map,
                                                 'target': target,
                                                 'target_kps': target_kps,
                                                 'target_verts': target_verts,
                                                 'target_kps3d': target_kps3d,
                                                 'has_kp3d': has_kp3d,
                                                 'has_smpl': valid_fit
                                                  })
            elif cfg.DANET.INPUT_MODE == 'seg':
                # REMOVE _.detach
                smpl_return_dict = self.iuv2smpl({'iuv_map': {'index': index_pred_cl},
                                                 'part_iuv_map': {'pindex': part_index_map},
                                                 'target': target,
                                                 'target_kps': target_kps,
                                                 'target_verts': target_verts,
                                                 'target_kps3d': target_kps3d,
                                                 'has_kp3d': has_kp3d
                                                  })

            if in_dict['vis_on'] and part_index_map is not None:
                # part_index_map: 2, 24, 7, 56, 56
                return_dict['visualization']['p_index_sum'] = [torch.sum(part_index_map[:, :, 1:].detach()).unsqueeze(0),
                                                               np.prod(part_index_map[:, :, 0].shape)]

            if in_dict['vis_on'] and part_iuv_map is not None:
                part_uvi_pred_vis = part_iuv_map[0]
                p_uvi_vis = []
                for i in range(part_uvi_pred_vis.size(0)):
                    p_u_vis, p_v_vis, p_i_vis = [part_uvi_pred_vis[i, uvi].unsqueeze(0) for uvi in range(3)]
                    if p_u_vis.size(1) == 25:
                        p_uvi_vis_i = iuv_map2img(p_u_vis.detach(), p_v_vis.detach(), p_i_vis.detach())
                    else:
                        p_uvi_vis_i = iuv_map2img(p_u_vis.detach(), p_v_vis.detach(), p_i_vis.detach(),
                                                     ind_mapping=[0] + self.img2iuv.dp2smpl_mapping[i])
                    # p_uvi_vis_i = uvmap_vis(p_u_vis.detach(), p_v_vis.detach(), p_i_vis.detach(), self.img2iuv.dp2smpl_mapping[i])
                    p_uvi_vis.append(p_uvi_vis_i)
                return_dict['visualization']['part_uvi_pred'] = torch.cat(p_uvi_vis, dim=0)

        for key_name in ['losses', 'metrics', 'visualization', 'prediction']:
            if key_name in uv_return_dict:
                return_dict[key_name].update(uv_return_dict[key_name])
            if not in_dict['pretrain_mode']:
                return_dict[key_name].update(smpl_return_dict[key_name])

        # PyTorch 0.4 has a bug when gathering scalar (0-dim) tensors across GPUs,
        # so promote them to 1-dim tensors
        for k, v in return_dict['losses'].items():
            if len(v.shape) == 0:
                return_dict['losses'][k] = v.unsqueeze(0)
        for k, v in return_dict['metrics'].items():
            if len(v.shape) == 0:
                return_dict['metrics'][k] = v.unsqueeze(0)

        return return_dict
Example 3
    def visualize(self, it, target, stage, preds, losses=None):
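        """Build a grid of input images, rendered meshes and (optionally) IUV maps
        for up to 16 samples of the batch and write it to TensorBoard.
        """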

        smpl_out = preds['smpl_out'][-1]
        theta = smpl_out['theta']
        pred_verts = smpl_out['verts'].cpu().numpy() if 'verts' in smpl_out else None
        cam_pred = theta[:, :3].detach()

        dp_out = preds['dp_out'][-1] if cfg.MODEL.PyMAF.AUX_SUPV_ON else None

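        # undo the ImageNet normalization so the images can be visualized
        # (see the sketch after this method)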
        images = target['img']
        images = images * torch.tensor(
            [0.229, 0.224, 0.225], device=images.device).reshape(1, 3, 1, 1)
        images = images + torch.tensor(
            [0.485, 0.456, 0.406], device=images.device).reshape(1, 3, 1, 1)
        imgs_np = images.cpu().numpy()

        vis_img_full = []
        vis_n = min(len(theta), 16)
        vis_img = []
        for b in range(vis_n):
            cam_t = cam_pred[b].cpu().numpy()
            smpl_verts = target['verts'][b].cpu().numpy()
            smpl_verts_pred = pred_verts[b] if pred_verts is not None else None

            render_imgs = []

            img_vis = np.transpose(imgs_np[b], (1, 2, 0)) * 255
            img_vis = img_vis.astype(np.uint8)

            render_imgs.append(img_vis)

            render_imgs.append(
                self.renderer(smpl_verts,
                              self.smpl.faces,
                              image=img_vis,
                              cam=cam_t,
                              addlight=True))

            if cfg.MODEL.PyMAF.AUX_SUPV_ON:
                if stage == 'train':
                    iuv_image_gt = target['iuv_image_gt'][b].detach().cpu().numpy()
                    iuv_image_gt = np.transpose(iuv_image_gt, (1, 2, 0)) * 255
                    iuv_image_gt_resized = resize(
                        iuv_image_gt, (img_vis.shape[0], img_vis.shape[1]),
                        preserve_range=True,
                        anti_aliasing=True)
                    render_imgs.append(iuv_image_gt_resized.astype(np.uint8))

                pred_iuv_list = [dp_out['predict_u'][b:b + 1], dp_out['predict_v'][b:b + 1],
                                 dp_out['predict_uv_index'][b:b + 1], dp_out['predict_ann_index'][b:b + 1]]
                iuv_image_pred = iuv_map2img(
                    *pred_iuv_list)[0].detach().cpu().numpy()
                iuv_image_pred = np.transpose(iuv_image_pred, (1, 2, 0)) * 255
                iuv_image_pred_resized = resize(
                    iuv_image_pred, (img_vis.shape[0], img_vis.shape[1]),
                    preserve_range=True,
                    anti_aliasing=True)
                render_imgs.append(iuv_image_pred_resized.astype(np.uint8))

            if smpl_verts_pred is not None:
                render_imgs.append(
                    self.renderer(smpl_verts_pred,
                                  self.smpl.faces,
                                  image=img_vis,
                                  cam=cam_t,
                                  addlight=True))

            img = np.concatenate(render_imgs, axis=1)
            img = np.transpose(img, (2, 0, 1))
            vis_img.append(img)

        vis_img_full.append(np.concatenate(vis_img, axis=1))

        vis_img_full = np.concatenate(vis_img_full, axis=-1)
        if stage == 'train':
            self.summary_writer.add_image('{}/mesh_pred'.format(stage),
                                          vis_img_full, it)
        else:
            self.summary_writer.add_image('{}/mesh_pred_{}'.format(stage, it),
                                          vis_img_full, self.epoch_count)
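    # --- Sketch (added for illustration): the de-normalization above as a helper ---
    # The constants are the standard ImageNet mean/std used in this method.
    #
    #   def denormalize(images):
    #       std = torch.tensor([0.229, 0.224, 0.225], device=images.device).reshape(1, 3, 1, 1)
    #       mean = torch.tensor([0.485, 0.456, 0.406], device=images.device).reshape(1, 3, 1, 1)
    #       return images * std + mean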
Example 4
    def forward(self, body_iuv, limb_iuv):
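        """Regress the global camera/shape parameters from the full-body IUV input
        and per-joint rotations from the 24 part-wise inputs, optionally refined by
        the configured LSTM or GCN strategy; the two are returned concatenated as
        'para'.
        """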

        return_dict = {}
        return_dict['visualization'] = {}
        return_dict['losses'] = {}

        if cfg.DANET.INPUT_MODE == 'rgb':
            # each of U, V, I occupies one third of the input channels
            map_channels = body_iuv.size(1) // 3
            body_u = body_iuv[:, :map_channels]
            body_v = body_iuv[:, map_channels:2 * map_channels]
            body_i = body_iuv[:, 2 * map_channels:3 * map_channels]
            global_para, global_feat = self.body_net(
                iuv_map2img(body_u, body_v, body_i))
        elif cfg.DANET.INPUT_MODE in ['iuv', 'iuv_gt']:
            global_para, global_feat = self.body_net(body_iuv)
        elif cfg.DANET.INPUT_MODE in ['iuv_feat', 'iuv_gt_feat']:
            global_para, global_feat = self.body_net(
                torch.cat([body_iuv['iuv'], body_iuv['feat']], dim=1))
        elif cfg.DANET.INPUT_MODE == 'feat':
            global_para, global_feat = self.body_net(body_iuv['feat'])
        elif cfg.DANET.INPUT_MODE == 'seg':
            global_para, global_feat = self.body_net(body_iuv['index'])

        global_para += self.mean_cam_shape

        if cfg.DANET.INPUT_MODE in ['iuv_feat', 'feat', 'iuv_gt_feat']:
            nbs = limb_iuv['pfeat'].size(0)
            limb_mapsize = limb_iuv['pfeat'].size(-1)
        elif cfg.DANET.INPUT_MODE in ['seg']:
            nbs = limb_iuv['pindex'].size(0)
            limb_mapsize = limb_iuv['pindex'].size(-1)
        else:
            nbs = limb_iuv.size(0)
            limb_mapsize = limb_iuv.size(-1)

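        # Fold the 24 per-part maps into the batch dimension so the shared limb
        # network processes every part in a single pass (see the sketch after
        # this method).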
        if cfg.DANET.INPUT_MODE in ['iuv_feat', 'iuv_gt_feat']:
            limb_iuv_stacked = limb_iuv['piuv'].view(nbs * 24, -1,
                                                     limb_mapsize,
                                                     limb_mapsize)
            limb_feat_stacked = limb_iuv['pfeat'].view(nbs * 24, -1,
                                                       limb_mapsize,
                                                       limb_mapsize)
            _, limb_feat = self.limb_net(
                torch.cat([limb_iuv_stacked, limb_feat_stacked], dim=1))
        elif cfg.DANET.INPUT_MODE in ['iuv', 'iuv_gt']:
            limb_iuv_stacked = limb_iuv.view(nbs * 24, -1, limb_mapsize,
                                             limb_mapsize)
            _, limb_feat = self.limb_net(limb_iuv_stacked)
        elif cfg.DANET.INPUT_MODE == 'feat':
            limb_feat_stacked = limb_iuv['pfeat'].view(nbs * 24, -1,
                                                       limb_mapsize,
                                                       limb_mapsize)
            _, limb_feat = self.limb_net(limb_feat_stacked)
        elif cfg.DANET.INPUT_MODE == 'seg':
            limb_feat_stacked = limb_iuv['pindex'].view(
                nbs * 24, -1, limb_mapsize, limb_mapsize)
            _, limb_feat = self.limb_net(limb_feat_stacked)

        limb_feat = limb_feat['x4']
        limb_feat = self.limb_reslayer(
            limb_feat.view(nbs, -1, limb_feat.size(-2), limb_feat.size(-1)))

        rot_feats = limb_feat.view(nbs, 24, -1, limb_feat.size(-2),
                                   limb_feat.size(-1))

        if cfg.DANET.REFINE_STRATEGY == 'lstm_direct':

            return_dict['joint_rotation'] = []

            local_para = self.pose_regressors[0](rot_feats.view(
                rot_feats.size(0), 24 * rot_feats.size(2), 1,
                1)).view(nbs, 24, -1)
            smpl_pose = local_para.view(local_para.size(0), -1)
            return_dict['joint_rotation'].append(smpl_pose)

            for s_i in range(cfg.DANET.REFINEMENT.STACK_NUM):
                pos_feats = {}
                for i in range(24):
                    pos_feats[i] = rot_feats[:, i]

                pos_feats_refined = {}
                for br in range(len(self.limb_branch_lstm)):
                    pos_feat_in = torch.stack(
                        [pos_feats[ind] for ind in self.limb_branch_lstm[br]],
                        dim=1)
                    pos_feat_in = pos_feat_in.squeeze(-1).squeeze(-1)
                    if br == 0:
                        lstm_out, hidden_feat = self.limb_lstm[s_i][0](
                            pos_feat_in)
                    elif br == 1:
                        lstm_out, _ = self.limb_lstm[s_i][0](pos_feat_in,
                                                             hidden_feat)
                    elif br in [2, 3]:
                        lstm_out, _ = self.limb_lstm[s_i][br - 1](pos_feat_in,
                                                                  hidden_feat)
                    else:
                        lstm_out, _ = self.limb_lstm[s_i][br - 1](pos_feat_in)
                    for i, ind in enumerate(self.limb_branch_lstm[br]):
                        if ind == 0 and br != 0:
                            continue
                        pos_feats_refined[ind] = lstm_out[:, i].unsqueeze(
                            -1).unsqueeze(-1)

                # update
                for i in range(24):
                    pos_feats[i] = pos_feats[i].repeat(
                        1, 2, 1, 1) + pos_feats_refined[i]

                refined_feat = torch.stack([pos_feats[i] for i in range(24)],
                                           dim=1)
                part_feats = refined_feat.view(refined_feat.size(0),
                                               24 * refined_feat.size(2), 1, 1)

                local_para = self.pose_regressors[s_i + 1](part_feats).view(
                    nbs, 24, -1)
                smpl_pose = local_para.view(local_para.size(0), -1)

        elif cfg.DANET.REFINE_STRATEGY == 'lstm':

            return_dict['joint_position'] = []
            return_dict['joint_rotation'] = []

            if self.training:
                local_para = self.pose_regressors[0](rot_feats.view(
                    rot_feats.size(0), 24 * rot_feats.size(2), 1,
                    1)).view(nbs, 24, -1)
                smpl_pose = local_para.view(local_para.size(0), -1)
                smpl_pose += self.mean_pose
                if cfg.DANET.USE_6D_ROT:
                    smpl_pose = rot6d_to_rotmat(smpl_pose).view(
                        local_para.size(0), -1)
                return_dict['joint_rotation'].append(smpl_pose)

            rot_feats_before = rot_feats

            for s_i in range(cfg.DANET.REFINEMENT.STACK_NUM):
                pos_feats = {}

                pos_feats[0] = rot_feats_before[:, 0]
                for br in range(len(self.limb_branch)):
                    for ind in self.limb_branch[br]:
                        p_ind = self.smpl_parents[0][ind]
                        pos_rot_feat_cat = torch.cat(
                            [pos_feats[p_ind], rot_feats_before[:, p_ind]],
                            dim=1)
                        pos_feats[ind] = self.rot2pos[s_i][ind](
                            pos_rot_feat_cat)

                if self.training:
                    if cfg.DANET.JOINT_POSITION_WEIGHTS > 0 and cfg.DANET.REFINEMENT.POS_INTERSUPV:
                        coord_feats = torch.cat(
                            [pos_feats[i] for i in range(24)], dim=1)
                        smpl_coord = self.coord_regressors[s_i](
                            coord_feats).view(nbs, 24, -1)
                        return_dict['joint_position'].append(smpl_coord)

                pos_feats_refined = {}
                for br in range(len(self.limb_branch_lstm)):
                    pos_feat_in = torch.stack(
                        [pos_feats[ind] for ind in self.limb_branch_lstm[br]],
                        dim=1)
                    pos_feat_in = pos_feat_in.squeeze(-1).squeeze(-1)
                    if br == 0:
                        lstm_out, hidden_feat = self.limb_lstm[s_i][0](
                            pos_feat_in)
                    elif br == 1:
                        lstm_out, _ = self.limb_lstm[s_i][0](pos_feat_in,
                                                             hidden_feat)
                    elif br in [2, 3]:
                        lstm_out, _ = self.limb_lstm[s_i][br - 1](pos_feat_in,
                                                                  hidden_feat)
                    else:
                        lstm_out, _ = self.limb_lstm[s_i][br - 1](pos_feat_in)
                    for i, ind in enumerate(self.limb_branch_lstm[br]):
                        if ind == 0 and br != 0:
                            continue
                        pos_feats_refined[ind] = lstm_out[:, i].unsqueeze(
                            -1).unsqueeze(-1)

                # update
                for i in range(24):
                    pos_feats[i] = pos_feats[i].repeat(
                        1, 2, 1, 1) + pos_feats_refined[i]

                if self.training:
                    if cfg.DANET.JOINT_POSITION_WEIGHTS > 0 and cfg.DANET.REFINEMENT.POS_INTERSUPV:
                        coord_feats = torch.cat(
                            [pos_feats[i] for i in range(24)], dim=1)
                        smpl_coord = self.coord_regressors[s_i + 1](
                            coord_feats).view(nbs, 24, -1)
                        return_dict['joint_position'].append(smpl_coord)

                tri_pos_feats = [
                    torch.cat([
                        pos_feats[self.smpl_parents[0][i]], pos_feats[i],
                        pos_feats[self.smpl_children[1][i]]
                    ],
                              dim=1) for i in range(24)
                ]
                tri_pos_feats = torch.cat(tri_pos_feats, dim=0)
                tran_rot_feats = self.pos2rot[s_i](tri_pos_feats)
                tran_rot_feats = tran_rot_feats.view(24, nbs, -1,
                                                     tran_rot_feats.size(-2),
                                                     tran_rot_feats.size(-1))
                tran_rot_feats = tran_rot_feats.transpose(0, 1)

                part_feats = tran_rot_feats.contiguous().view(
                    tran_rot_feats.size(0), 24 * tran_rot_feats.size(2), 1, 1)

                local_para = self.pose_regressors[s_i + 1](part_feats).view(
                    nbs, 24, -1)
                smpl_pose = local_para.view(local_para.size(0), -1)
                smpl_pose += self.mean_pose
                if cfg.DANET.USE_6D_ROT:
                    smpl_pose = rot6d_to_rotmat(smpl_pose).view(
                        local_para.size(0), -1)

        elif cfg.DANET.REFINE_STRATEGY == 'gcn':

            return_dict['joint_position'] = []
            return_dict['joint_rotation'] = []

            if self.training:
                local_para = self.pose_regressors[0](rot_feats.view(
                    rot_feats.size(0), 24 * rot_feats.size(2), 1,
                    1)).view(nbs, 24, -1)
                smpl_pose = local_para.view(local_para.size(0), -1)
                smpl_pose += self.mean_pose
                if cfg.DANET.USE_6D_ROT:
                    smpl_pose = rot6d_to_rotmat(smpl_pose).view(
                        local_para.size(0), -1)
                return_dict['joint_rotation'].append(smpl_pose)

            rot_feats_before = rot_feats

            rot_feats_init = rot_feats_before.squeeze(-1).squeeze(-1)
            pos_feats_init = self.r2p_gcn(rot_feats_init, self.r2p_A[0])

            if self.training:
                if cfg.DANET.JOINT_POSITION_WEIGHTS > 0 and cfg.DANET.REFINEMENT.POS_INTERSUPV:
                    coord_feats0 = pos_feats_init.unsqueeze(2).view(
                        pos_feats_init.size(0),
                        pos_feats_init.size(-1) * 24, 1, 1)
                    smpl_coord0 = self.coord_regressors[0](coord_feats0).view(
                        nbs, 24, -1)
                    return_dict['joint_position'].append(smpl_coord0)

            if cfg.DANET.REFINEMENT.REFINE_ON:
                graph_A = self.A_mask * self.edge_act(self.edge_importance)
                norm_graph_A = normalize_undigraph(self.I_n[0] + graph_A)[0]
                l_pos_feat = self.refine_gcn(pos_feats_init, norm_graph_A)
                l_pos_feat = pos_feats_init + l_pos_feat

                pos_feats_refined = l_pos_feat

                if self.training:
                    if cfg.DANET.JOINT_POSITION_WEIGHTS > 0 and cfg.DANET.REFINEMENT.POS_INTERSUPV:
                        coord_feats1 = pos_feats_refined.unsqueeze(2).view(
                            pos_feats_refined.size(0),
                            pos_feats_refined.size(-1) * 24, 1, 1)
                        smpl_coord1 = self.coord_regressors[1](
                            coord_feats1).view(nbs, 24, -1)
                        return_dict['joint_position'].append(smpl_coord1)
            else:
                pos_feats_refined = pos_feats_init

            rot_feats_refined = self.p2r_gcn(pos_feats_refined, self.p2r_A[0])

            tran_rot_feats = rot_feats_refined.unsqueeze(-1).unsqueeze(-1)
            part_feats = tran_rot_feats.view(tran_rot_feats.size(0),
                                             24 * tran_rot_feats.size(2), 1, 1)

            local_para = self.pose_regressors[-1](part_feats).view(nbs, 24, -1)
            smpl_pose = local_para.view(local_para.size(0), -1)
            smpl_pose += self.mean_pose
            if cfg.DANET.USE_6D_ROT:
                smpl_pose = rot6d_to_rotmat(smpl_pose).view(
                    local_para.size(0), -1)

        elif cfg.DANET.REFINE_STRATEGY == 'gcn_direct':

            return_dict['joint_position'] = []
            return_dict['joint_rotation'] = []

            local_para = self.pose_regressors[0](rot_feats.view(
                rot_feats.size(0), 24 * rot_feats.size(2), 1,
                1)).view(nbs, 24, -1)
            smpl_pose = local_para.view(local_para.size(0), -1)

            if cfg.DANET.REFINEMENT.REFINE_ON:
                return_dict['joint_rotation'].append(smpl_pose)

                pos_feats_init = rot_feats.squeeze(-1).squeeze(-1)

                graph_A = self.A_mask * self.edge_act(self.edge_importance)
                norm_graph_A = normalize_undigraph(self.I_n[0] + graph_A)[0]
                l_pos_feat = self.refine_gcn(pos_feats_init, norm_graph_A)
                l_pos_feat = pos_feats_init + l_pos_feat

                pos_feats_refined = l_pos_feat
                tran_rot_feats = pos_feats_refined.unsqueeze(-1).unsqueeze(-1)

                part_feats = tran_rot_feats.view(tran_rot_feats.size(0),
                                                 24 * tran_rot_feats.size(2),
                                                 1, 1)

                local_para = self.pose_regressors[-1](part_feats).view(
                    nbs, 24, -1)
                smpl_pose = local_para.view(local_para.size(0), -1)

        para = torch.cat([global_para, smpl_pose], dim=1)

        return_dict['para'] = para

        return return_dict
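    # --- Sketch (added for illustration) of the part-batching trick used above ---
    # A (B, 24, C, H, W) stack of per-part maps is folded into the batch dimension
    # so the shared limb network processes all 24 parts in one forward pass, then
    # the features are unfolded back into a per-part layout. Sizes are assumptions.
    #
    #   part_maps = torch.randn(B, 24, C, H, W)
    #   stacked = part_maps.view(B * 24, C, H, W)
    #   feats = limb_net(stacked)                 # e.g. (B * 24, C2, h, w)
    #   per_part = feats.view(B, 24, -1, h, w)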
Example 5
def main():
    """Main function"""
    args = parse_args()
    args.batch_size = 1

    cfg_from_file(args.cfg_file)

    cfg.DANET.REFINEMENT = EasyDict(cfg.DANET.REFINEMENT)
    cfg.MSRES_MODEL.EXTRA = EasyDict(cfg.MSRES_MODEL.EXTRA)

    device = torch.device(
        'cuda') if torch.cuda.is_available() else torch.device('cpu')

    if cfg.DANET.SMPL_MODEL_TYPE == 'male':
        smpl_male = SMPL(path_config.SMPL_MODEL_DIR,
                         gender='male',
                         create_transl=False).to(device)
        smpl = smpl_male
    elif cfg.DANET.SMPL_MODEL_TYPE == 'neutral':
        smpl_neutral = SMPL(path_config.SMPL_MODEL_DIR,
                            create_transl=False).to(device)
        smpl = smpl_neutral
    elif cfg.DANET.SMPL_MODEL_TYPE == 'female':
        smpl_female = SMPL(path_config.SMPL_MODEL_DIR,
                           gender='female',
                           create_transl=False).to(device)
        smpl = smpl_female

    if args.use_opendr:
        from utils.renderer import opendr_render
        dr_render = opendr_render()

    # IUV renderer
    iuv_renderer = IUV_Renderer()

    if not os.path.exists(args.out_dir):
        os.makedirs(args.out_dir)

    ### Model ###
    model = DaNet(args, path_config.SMPL_MEAN_PARAMS,
                  pretrained=False).to(device)

    checkpoint = torch.load(args.checkpoint)
    model.load_state_dict(checkpoint['model'], strict=False)
    model.eval()

    img_path_list = [
        os.path.join(args.img_dir, name) for name in os.listdir(args.img_dir)
        if name.endswith('.jpg')
    ]
    for path in img_path_list:

        image = Image.open(path).convert('RGB')
        img_id = os.path.splitext(os.path.basename(path))[0]

        image_tensor = torchvision.transforms.ToTensor()(image).unsqueeze(0).to(device)

        # run inference (no gradients needed at demo time)
        with torch.no_grad():
            pred_dict = model.infer_net(image_tensor)
        para_pred = pred_dict['para']
        camera_pred = para_pred[:, 0:3].contiguous()
        betas_pred = para_pred[:, 3:13].contiguous()
        rotmat_pred = para_pred[:, 13:].contiguous().view(-1, 24, 3, 3)

        # input image
        image_np = image_tensor[0].cpu().numpy()
        image_np = np.transpose(image_np, (1, 2, 0))

        ones_np = np.ones(image_np.shape[:2]) * 255
        ones_np = ones_np[:, :, None]

        image_in_rgba = np.concatenate((image_np, ones_np), axis=2)

        # estimated global IUV
        global_iuv = iuv_map2img(
            *pred_dict['visualization']['iuv_pred'])[0].cpu().numpy()
        global_iuv = np.transpose(global_iuv, (1, 2, 0))
        global_iuv = resize(global_iuv, image_np.shape[:2])
        global_iuv_rgba = np.concatenate((global_iuv, ones_np), axis=2)

        # estimated part-wise IUV
        part_iuv_pred = pred_dict['visualization']['part_iuv_pred'][0]
        p_iuv_vis = []
        for i in range(part_iuv_pred.size(0)):
            p_u_vis, p_v_vis, p_i_vis = [
                part_iuv_pred[i, iuv].unsqueeze(0) for iuv in range(3)
            ]
            if p_u_vis.size(1) == 25:
                p_iuv_vis_i = iuv_map2img(p_u_vis.detach(), p_v_vis.detach(),
                                          p_i_vis.detach())
            else:
                p_iuv_vis_i = iuv_map2img(p_u_vis.detach(),
                                          p_v_vis.detach(),
                                          p_i_vis.detach(),
                                          ind_mapping=[0] +
                                          model.img2iuv.dp2smpl_mapping[i])
            p_iuv_vis.append(p_iuv_vis_i)
        part_iuv = torch.cat(p_iuv_vis, dim=0)
        part_iuv = make_grid(part_iuv, nrow=6, padding=0).cpu().numpy()
        part_iuv = np.transpose(part_iuv, (1, 2, 0))
        part_iuv_rgba = np.concatenate(
            (part_iuv, np.ones(part_iuv.shape[:2])[:, :, None] * 255), axis=2)

        # rendered IUV of the predicted SMPL model
        smpl_output = smpl(betas=betas_pred,
                           body_pose=rotmat_pred[:, 1:],
                           global_orient=rotmat_pred[:, 0].unsqueeze(1),
                           pose2rot=False)
        verts_pred = smpl_output.vertices
        render_iuv = iuv_renderer.verts2uvimg(verts_pred, camera_pred)
        render_iuv = render_iuv[0].cpu().numpy()

        render_iuv = np.transpose(render_iuv, (1, 2, 0))
        render_iuv = resize(render_iuv, image_np.shape[:2])

        img_render_iuv = image_np.copy()
        img_render_iuv[render_iuv > 0] = render_iuv[render_iuv > 0]

        img_render_iuv_rgba = np.concatenate((img_render_iuv, ones_np), axis=2)

        img_vis_list = [
            image_in_rgba, global_iuv_rgba, part_iuv_rgba, img_render_iuv_rgba
        ]

        if args.use_opendr:
            # visualize the predicted SMPL model using the opendr renderer
            K = iuv_renderer.K[0].cpu().numpy()
            _, _, img_smpl, smpl_rgba = dr_render.render(
                image_tensor[0].cpu().numpy(), camera_pred[0].cpu().numpy(), K,
                verts_pred.cpu().numpy(), smpl.faces)

            img_smpl_rgba = np.concatenate((img_smpl, ones_np), axis=2)
            img_vis_list.extend([img_smpl_rgba, smpl_rgba])

        img_vis = np.concatenate(img_vis_list, axis=1)
        img_vis[img_vis < 0.0] = 0.0
        img_vis[img_vis > 1.0] = 1.0
        imsave(os.path.join(args.out_dir, img_id + '_result.png'), img_vis)

    print('Demo results have been saved in {}.'.format(args.out_dir))
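# --- Sketch (added for illustration): layout of the predicted parameter vector ---
# The slicing in main() assumes 3 camera values, 10 shape (beta) values and
# 24 rotation matrices flattened to 216 values, i.e. 229 parameters in total.
#
#   para = torch.randn(1, 3 + 10 + 24 * 9)
#   camera = para[:, 0:3]                        # weak-perspective camera
#   betas = para[:, 3:13]                        # SMPL shape coefficients
#   rotmat = para[:, 13:].reshape(-1, 24, 3, 3)  # per-joint rotation matrices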