Example #1
    def __init__(self, cfg):

        self.dtype = torch.float32

        makepath(cfg.work_dir, isfile=False)
        logger = makelogger(makepath(os.path.join(cfg.work_dir, 'V00.log'), isfile=True)).info
        self.logger = logger

        use_cuda = torch.cuda.is_available()
        if use_cuda:
            torch.cuda.empty_cache()
        self.device = torch.device("cuda" if use_cuda else "cpu")

        self.coarse_net = CoarseNet().to(self.device)
        self.refine_net = RefineNet().to(self.device)

        self.cfg = cfg
        self.coarse_net.cfg = cfg

        if cfg.best_cnet is not None:
            self._get_cnet_model().load_state_dict(torch.load(cfg.best_cnet, map_location=self.device), strict=False)
            logger('Restored CoarseNet model from %s' % cfg.best_cnet)
        if cfg.best_rnet is not None:
            self._get_rnet_model().load_state_dict(torch.load(cfg.best_rnet, map_location=self.device), strict=False)
            logger('Restored RefineNet model from %s' % cfg.best_rnet)

        self.bps = torch.from_numpy(np.load(cfg.bps_dir)['basis']).to(self.dtype)
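A minimal sketch of how this constructor might be driven. The config fields are inferred from the attribute accesses above (work_dir, best_cnet, best_rnet, bps_dir); the class name Tester and all paths are hypothetical.

from types import SimpleNamespace

cfg = SimpleNamespace(
    work_dir='logs/grabnet_test',   # hypothetical output directory for logs
    best_cnet=None,                 # or a path to a pretrained CoarseNet checkpoint (.pt)
    best_rnet=None,                 # or a path to a pretrained RefineNet checkpoint (.pt)
    bps_dir='configs/bps.npz',      # hypothetical .npz file containing a 'basis' array
)
tester = Tester(cfg)                # hypothetical name of the class this __init__ belongs to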
Example #2
def vis_results(dorig, coarse_net, refine_net, rh_model, save=False, save_dir=None):

    with torch.no_grad():
        imw, imh = 1920, 780
        cols = len(dorig['bps_object'])
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

        mvs = MeshViewers(window_width=imw, window_height=imh, shape=[1, cols], keepalive=True)

        drec_cnet = coarse_net.sample_poses(dorig['bps_object'])
        verts_rh_gen_cnet = rh_model(**drec_cnet).vertices

        _, h2o, _ = point2point_signed(verts_rh_gen_cnet, dorig['verts_object'].to(device))

        drec_cnet['trans_rhand_f'] = drec_cnet['transl']
        drec_cnet['global_orient_rhand_rotmat_f'] = aa2rotmat(drec_cnet['global_orient']).view(-1, 3, 3)
        drec_cnet['fpose_rhand_rotmat_f'] = aa2rotmat(drec_cnet['hand_pose']).view(-1, 15, 3, 3)
        drec_cnet['verts_object'] = dorig['verts_object'].to(device)
        drec_cnet['h2o_dist'] = h2o.abs()

        drec_rnet = refine_net(**drec_cnet)
        verts_rh_gen_rnet = rh_model(**drec_rnet).vertices


        for cId in range(len(dorig['bps_object'])):
            try:
                from copy import deepcopy
                meshes = deepcopy(dorig['mesh_object'])
                obj_mesh = meshes[cId]
            except Exception:
                # fall back to a sphere point cloud when no object mesh is provided
                obj_mesh = points_to_spheres(to_cpu(dorig['verts_object'][cId]), radius=0.002, vc=name_to_rgb['green'])

            hand_mesh_gen_cnet = Mesh(v=to_cpu(verts_rh_gen_cnet[cId]), f=rh_model.faces, vc=name_to_rgb['pink'])
            hand_mesh_gen_rnet = Mesh(v=to_cpu(verts_rh_gen_rnet[cId]), f=rh_model.faces, vc=name_to_rgb['gray'])

            if 'rotmat' in dorig:
                rotmat = dorig['rotmat'][cId].T
                obj_mesh = obj_mesh.rotate_vertices(rotmat)
                hand_mesh_gen_cnet.rotate_vertices(rotmat)
                hand_mesh_gen_rnet.rotate_vertices(rotmat)

            hand_mesh_gen_cnet.reset_face_normals()
            hand_mesh_gen_rnet.reset_face_normals()

            # mvs[0][cId].set_static_meshes([hand_mesh_gen_cnet] + obj_mesh, blocking=True)
            mvs[0][cId].set_static_meshes([hand_mesh_gen_rnet, obj_mesh], blocking=True)

            if save:
                save_path = os.path.join(save_dir, str(cId))
                makepath(save_path)
                hand_mesh_gen_rnet.write_ply(filename=save_path + '/rh_mesh_gen_%d.ply' % cId)
                obj_mesh.write_ply(filename=save_path + '/obj_mesh_%d.ply' % cId)
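A hedged usage sketch for the function above: the dictionary keys follow the accesses in the code, but the tensors themselves (names and shapes) are assumptions, not part of the example.

dorig = {
    'bps_object': bps_object,       # assumed [N, n_bps] BPS encoding of the objects
    'verts_object': verts_object,   # assumed [N, n_points, 3] sampled object points
    'mesh_object': object_meshes,   # list of N object meshes; spheres are used as a fallback
    # 'rotmat': rotmats,            # optional [N, 3, 3] rotations applied before display
}
vis_results(dorig, coarse_net, refine_net, rh_model, save=True, save_dir='out/vis')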
Example #3
    def eval(self):
        self.coarse_net.eval()
        self.refine_net.eval()
        ds_name = self.cfg.dataset_dir.split('/')[-1]

        total_error_cnet = {}
        total_error_rnet = {}
        for split, ds in [('val', self.ds_val), ('test', self.ds_test),
                          ('train', self.ds_train)]:

            mean_error_cnet = []
            mean_error_rnet = []
            with torch.no_grad():
                for dorig in ds:

                    dorig = {k: dorig[k].to(self.device) for k in dorig.keys()}

                    MESH_SCALER = 1000

                    drec_cnet = self.coarse_net(**dorig)
                    verts_hand_cnet = self.rhm_train(**drec_cnet).vertices

                    mean_error_cnet.append(
                        torch.mean(
                            torch.abs(dorig['verts_rhand'] - verts_hand_cnet) *
                            MESH_SCALER))

                    ########## refine net
                    params_rnet = self.params_rnet(dorig)
                    dorig.update(params_rnet)
                    drec_rnet = self.refine_net(**dorig)
                    verts_hand_mano = self.rhm_train(**drec_rnet).vertices

                    mean_error_rnet.append(
                        torch.mean(
                            torch.abs(dorig['verts_rhand'] - verts_hand_mano) *
                            MESH_SCALER))

            total_error_cnet[split] = {
                'v2v_mae': float(to_cpu(torch.stack(mean_error_cnet).mean()))
            }
            total_error_rnet[split] = {
                'v2v_mae': float(to_cpu(torch.stack(mean_error_rnet).mean()))
            }

        outpath = makepath(
            os.path.join(self.cfg.work_dir, 'evaluations', 'ds_%s' % ds_name,
                         os.path.basename(self.cfg.best_cnet).replace('.pt', '_CoarseNet.json')),
            isfile=True)

        with open(outpath, 'w') as f:
            json.dump(total_error_cnet, f)

        with open(outpath.replace('_CoarseNet.json', '_RefineNet.json'), 'w') as f:
            json.dump(total_error_rnet, f)

        return total_error_cnet, total_error_rnet
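For reference, MESH_SCALER = 1000 presumably converts the MANO vertex error from metres to millimetres, so the returned dictionaries hold a per-split mean vertex-to-vertex error in mm. A hedged sketch of consuming them (the trainer instance is hypothetical):

total_error_cnet, total_error_rnet = trainer.eval()
print('CoarseNet val v2v MAE: %.2f mm' % total_error_cnet['val']['v2v_mae'])
print('RefineNet val v2v MAE: %.2f mm' % total_error_rnet['val']['v2v_mae'])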
Example #4
def get_meshes(dorig, coarse_net, refine_net, rh_model, save=False, save_dir=None):
    with torch.no_grad():

        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

        drec_cnet = coarse_net.sample_poses(dorig['bps_object'])
        verts_rh_gen_cnet = rh_model(**drec_cnet).vertices

        _, h2o, _ = point2point_signed(verts_rh_gen_cnet, dorig['verts_object'].to(device))

        drec_cnet['trans_rhand_f'] = drec_cnet['transl']
        drec_cnet['global_orient_rhand_rotmat_f'] = aa2rotmat(drec_cnet['global_orient']).view(-1, 3, 3)
        drec_cnet['fpose_rhand_rotmat_f'] = aa2rotmat(drec_cnet['hand_pose']).view(-1, 15, 3, 3)
        drec_cnet['verts_object'] = dorig['verts_object'].to(device)
        drec_cnet['h2o_dist'] = h2o.abs()

        drec_rnet = refine_net(**drec_cnet)
        verts_rh_gen_rnet = rh_model(**drec_rnet).vertices

        gen_meshes = []
        for cId in range(len(dorig['bps_object'])):
            try:
                obj_mesh = dorig['mesh_object'][cId]
            except Exception:
                # fall back to a sphere point cloud when no object mesh is provided
                obj_mesh = points2sphere(points=to_cpu(dorig['verts_object'][cId]), radius=0.002, vc=name_to_rgb['yellow'])

            hand_mesh_gen_rnet = Mesh(vertices=to_cpu(verts_rh_gen_rnet[cId]), faces=rh_model.faces, vc=[245, 191, 177])

            if 'rotmat' in dorig:
                rotmat = dorig['rotmat'][cId].T
                obj_mesh = obj_mesh.rotate_vertices(rotmat)
                hand_mesh_gen_rnet.rotate_vertices(rotmat)

            gen_meshes.append([obj_mesh, hand_mesh_gen_rnet])
            if save:
                save_path = os.path.join(save_dir, str(cId))
                makepath(save_path)
                hand_mesh_gen_rnet.export(filename=save_path + '/rh_mesh_gen_%d.ply' % cId)
                obj_mesh.export(filename=save_path + '/obj_mesh_%d.ply' % cId)

        return gen_meshes
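A short usage sketch: each returned entry pairs the object mesh with the refined hand mesh, so the result can be consumed as below (the loop body is illustrative only).

gen_meshes = get_meshes(dorig, coarse_net, refine_net, rh_model)
for cId, (obj_mesh, hand_mesh) in enumerate(gen_meshes):
    # render, measure, or export the object/hand pair as needed
    print('grasp %d: object mesh + generated hand mesh' % cId)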
Example #5
    def fit(self, n_epochs=None, message=None):

        starttime = datetime.now().replace(microsecond=0)
        if n_epochs is None:
            n_epochs = self.cfg.n_epochs

        self.logger(
            'Started Training at %s for %d epochs' %
            (datetime.strftime(starttime, '%Y-%m-%d_%H:%M:%S'), n_epochs))
        if message is not None:
            self.logger(message)

        prev_lr_cnet = np.inf
        prev_lr_rnet = np.inf
        self.fit_cnet = True
        self.fit_rnet = True

        lr_scheduler_cnet = torch.optim.lr_scheduler.ReduceLROnPlateau(
            self.optimizer_cnet, 'min')
        lr_scheduler_rnet = torch.optim.lr_scheduler.ReduceLROnPlateau(
            self.optimizer_rnet, 'min')
        early_stopping_cnet = EarlyStopping(patience=8, trace_func=self.logger)
        early_stopping_rnet = EarlyStopping(patience=8, trace_func=self.logger)

        for epoch_num in range(1, n_epochs + 1):
            self.logger('--- starting Epoch # %03d' % epoch_num)

            train_loss_dict_cnet, train_loss_dict_rnet = self.train()
            eval_loss_dict_cnet, eval_loss_dict_rnet = self.evaluate()

            if self.fit_cnet:

                lr_scheduler_cnet.step(eval_loss_dict_cnet['loss_total'])
                cur_lr_cnet = self.optimizer_cnet.param_groups[0]['lr']

                if cur_lr_cnet != prev_lr_cnet:
                    self.logger(
                        '--- CoarseNet learning rate changed from %.2e to %.2e ---'
                        % (prev_lr_cnet, cur_lr_cnet))
                    prev_lr_cnet = cur_lr_cnet

                with torch.no_grad():
                    eval_msg = Trainer.create_loss_message(
                        eval_loss_dict_cnet,
                        expr_ID=self.cfg.expr_ID,
                        epoch_num=self.epochs_completed,
                        it=len(self.ds_val),
                        model_name='CoarseNet',
                        try_num=self.try_num,
                        mode='evald')
                    if eval_loss_dict_cnet['loss_total'] < self.best_loss_cnet:

                        self.cfg.best_cnet = makepath(os.path.join(
                            self.cfg.work_dir, 'snapshots',
                            'TR%02d_E%03d_cnet.pt' %
                            (self.try_num, self.epochs_completed)),
                                                      isfile=True)
                        self.save_cnet()
                        self.logger(eval_msg + ' ** ')
                        self.best_loss_cnet = eval_loss_dict_cnet['loss_total']

                    else:
                        self.logger(eval_msg)

                    self.swriter.add_scalars(
                        'total_loss_cnet/scalars', {
                            'train_loss_total':
                            train_loss_dict_cnet['loss_total'],
                            'evald_loss_total':
                            eval_loss_dict_cnet['loss_total'],
                        }, self.epochs_completed)

                if early_stopping_cnet(eval_loss_dict_cnet['loss_total']):
                    self.fit_cnet = False
                    self.logger('Early stopping CoarseNet training!')

            if self.fit_rnet:

                lr_scheduler_rnet.step(eval_loss_dict_rnet['loss_total'])
                cur_lr_rnet = self.optimizer_rnet.param_groups[0]['lr']

                if cur_lr_rnet != prev_lr_rnet:
                    self.logger(
                        '--- RefineNet learning rate changed from %.2e to %.2e ---'
                        % (prev_lr_rnet, cur_lr_rnet))
                    prev_lr_rnet = cur_lr_rnet

                with torch.no_grad():
                    eval_msg = Trainer.create_loss_message(
                        eval_loss_dict_rnet,
                        expr_ID=self.cfg.expr_ID,
                        epoch_num=self.epochs_completed,
                        it=len(self.ds_val),
                        model_name='RefineNet',
                        try_num=self.try_num,
                        mode='evald')
                    if eval_loss_dict_rnet['loss_total'] < self.best_loss_rnet:

                        self.cfg.best_rnet = makepath(os.path.join(
                            self.cfg.work_dir, 'snapshots',
                            'TR%02d_E%03d_rnet.pt' %
                            (self.try_num, self.epochs_completed)),
                                                      isfile=True)
                        self.save_rnet()
                        self.logger(eval_msg + ' ** ')
                        self.best_loss_rnet = eval_loss_dict_rnet['loss_total']

                    else:
                        self.logger(eval_msg)

                    self.swriter.add_scalars(
                        'total_loss_rnet/scalars', {
                            'train_loss_total':
                            train_loss_dict_rnet['loss_total'],
                            'evald_loss_total':
                            eval_loss_dict_rnet['loss_total'],
                        }, self.epochs_completed)

                if early_stopping_rnet(eval_loss_dict_rnet['loss_total']):
                    self.fit_rnet = False
                    self.logger('Early stopping RefineNet training!')

            self.epochs_completed += 1

            if not self.fit_cnet and not self.fit_rnet:
                self.logger('Stopping the training!')
                break

        endtime = datetime.now().replace(microsecond=0)

        self.logger('Finished Training at %s\n' %
                    (datetime.strftime(endtime, '%Y-%m-%d_%H:%M:%S')))
        self.logger(
            'Training done in %s! Best CoarseNet val total loss achieved: %.2e\n'
            % (endtime - starttime, self.best_loss_cnet))
        self.logger('Best CoarseNet model path: %s\n' % self.cfg.best_cnet)

        self.logger('Best RefineNet val total loss achieved: %.2e\n' %
                    (self.best_loss_rnet))
        self.logger('Best RefineNet model path: %s\n' % self.cfg.best_rnet)
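A hedged sketch of driving the loop above; the Trainer class name and the config object are assumptions carried over from the other examples.

trainer = Trainer(cfg)                          # hypothetical class owning this fit() method
trainer.fit(n_epochs=100, message='baseline')   # n_epochs=None falls back to cfg.n_epochs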
Example #6
    def __init__(self, cfg, inference=False, evaluate=False):

        self.dtype = torch.float32

        torch.manual_seed(cfg.seed)

        starttime = datetime.now().replace(microsecond=0)
        makepath(cfg.work_dir, isfile=False)
        logger = makelogger(
            makepath(os.path.join(cfg.work_dir, '%s.log' % (cfg.expr_ID)),
                     isfile=True)).info
        self.logger = logger

        summary_logdir = os.path.join(cfg.work_dir, 'summaries')
        self.swriter = SummaryWriter(log_dir=summary_logdir)
        logger('[%s] - Started training GrabNet at %s' %
               (cfg.expr_ID, starttime))
        logger('tensorboard --logdir=%s' % summary_logdir)
        logger('Torch Version: %s\n' % torch.__version__)
        logger('Base dataset_dir is %s' % cfg.dataset_dir)

        # shutil.copy2(os.path.basename(sys.argv[0]), cfg.work_dir)

        use_cuda = torch.cuda.is_available()
        if use_cuda:
            torch.cuda.empty_cache()
        self.device = torch.device(
            "cuda:%d" % cfg.cuda_id if torch.cuda.is_available() else "cpu")

        gpu_brand = torch.cuda.get_device_name(
            cfg.cuda_id) if use_cuda else None
        gpu_count = torch.cuda.device_count() if cfg.use_multigpu else 1
        if use_cuda:
            logger('Using %d CUDA cores [%s] for training!' %
                   (gpu_count, gpu_brand))

        self.data_info = {}
        self.load_data(cfg, inference)

        with torch.no_grad():
            self.rhm_train = mano.load(model_path=cfg.rhm_path,
                                       model_type='mano',
                                       num_pca_comps=45,
                                       batch_size=cfg.batch_size // gpu_count,
                                       flat_hand_mean=True).to(self.device)

        self.coarse_net = CoarseNet().to(self.device)
        self.refine_net = RefineNet().to(self.device)

        self.LossL1 = torch.nn.L1Loss(reduction='mean')
        self.LossL2 = torch.nn.MSELoss(reduction='mean')

        if cfg.use_multigpu:
            self.coarse_net = nn.DataParallel(self.coarse_net)
            self.refine_net = nn.DataParallel(self.refine_net)
            logger("Training on Multiple GPU's")

        vars_cnet = [var[1] for var in self.coarse_net.named_parameters()]
        vars_rnet = [var[1] for var in self.refine_net.named_parameters()]

        cnet_n_params = sum(p.numel() for p in vars_cnet if p.requires_grad)
        rnet_n_params = sum(p.numel() for p in vars_rnet if p.requires_grad)
        logger('Total Trainable Parameters for CoarseNet is %2.2f M.' %
               ((cnet_n_params) * 1e-6))
        logger('Total Trainable Parameters for RefineNet is %2.2f M.' %
               ((rnet_n_params) * 1e-6))

        self.optimizer_cnet = optim.Adam(vars_cnet,
                                         lr=cfg.base_lr,
                                         weight_decay=cfg.reg_coef)
        self.optimizer_rnet = optim.Adam(vars_rnet,
                                         lr=cfg.base_lr,
                                         weight_decay=cfg.reg_coef)

        self.best_loss_cnet = np.inf
        self.best_loss_rnet = np.inf

        self.try_num = cfg.try_num
        self.epochs_completed = 0
        self.cfg = cfg
        self.coarse_net.cfg = cfg

        if cfg.best_cnet is not None:
            self._get_cnet_model().load_state_dict(torch.load(
                cfg.best_cnet, map_location=self.device),
                                                   strict=False)
            logger('Restored CoarseNet model from %s' % cfg.best_cnet)
        if cfg.best_rnet is not None:
            self._get_rnet_model().load_state_dict(torch.load(
                cfg.best_rnet, map_location=self.device),
                                                   strict=False)
            logger('Restored RefineNet model from %s' % cfg.best_rnet)

        # weights for contact, penetration and distance losses
        self.vpe = torch.from_numpy(np.load(cfg.vpe_path)).to(self.device).to(
            torch.long)
        rh_f = torch.from_numpy(self.rhm_train.faces.astype(np.int32)).view(
            1, -1, 3)
        self.rh_f = rh_f.repeat(self.cfg.batch_size, 1,
                                1).to(self.device).to(torch.long)

        v_weights = torch.from_numpy(np.load(cfg.c_weights_path)).to(
            torch.float32).to(self.device)
        v_weights2 = torch.pow(v_weights, 1.0 / 2.5)
        self.refine_net.v_weights = v_weights
        self.refine_net.v_weights2 = v_weights2
        self.refine_net.rhm_train = self.rhm_train

        self.v_weights = v_weights
        self.v_weights2 = v_weights2

        self.w_dist = torch.ones([self.cfg.batch_size,
                                  self.n_obj_verts]).to(self.device)
        self.contact_v = v_weights > 0.8
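Here contact_v is a boolean mask over the MANO hand vertices whose contact weight (loaded from cfg.c_weights_path) exceeds 0.8. A minimal sketch, under the assumption that verts_rhand is a [B, V, 3] tensor of hand vertices produced by self.rhm_train:

verts_rhand = self.rhm_train(**drec).vertices       # drec is a hypothetical batch of MANO parameters
contact_verts = verts_rhand[:, self.contact_v, :]   # keep only vertices flagged as contact candidates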
Example #7
def vis_results(ho,
                dorig,
                coarse_net,
                refine_net,
                rh_model,
                save=False,
                save_dir=None,
                rh_model_pkl=None,
                vis=True):

    # with torch.no_grad():
    imw, imh = 1920, 780
    cols = len(dorig['bps_object'])
    # device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    device = torch.device('cpu')

    if vis:
        mvs = MeshViewers(window_width=imw,
                          window_height=imh,
                          shape=[1, cols],
                          keepalive=True)

    # drec_cnet = coarse_net.sample_poses(dorig['bps_object'])
    #
    # for k in drec_cnet.keys():
    #     print('drec cnet', k, drec_cnet[k].shape)

    # verts_rh_gen_cnet = rh_model(**drec_cnet).vertices

    drec_cnet = {}

    hand_pose_in = torch.Tensor(ho.hand_pose[3:]).unsqueeze(0)
    mano_out_1 = rh_model_pkl(hand_pose=hand_pose_in)
    hand_pose_in = mano_out_1.hand_pose

    mTc = torch.Tensor(ho.hand_mTc)
    approx_global_orient = rotmat2aa(mTc[:3, :3].unsqueeze(0))

    if torch.isnan(approx_global_orient).any():  # Using honnotate?
        approx_global_orient = torch.Tensor(ho.hand_pose[:3]).unsqueeze(0)

    approx_global_orient = approx_global_orient.squeeze(1).squeeze(1)
    approx_trans = mTc[:3, 3].unsqueeze(0)

    target_verts = torch.Tensor(ho.hand_verts).unsqueeze(0)

    pose, trans, rot = util.opt_hand(rh_model, target_verts, hand_pose_in,
                                     approx_trans, approx_global_orient)

    # drec_cnet['hand_pose'] = torch.einsum('bi,ij->bj', [hand_pose_in, rh_model_pkl.hand_components])
    drec_cnet['transl'] = trans
    drec_cnet['global_orient'] = rot
    drec_cnet['hand_pose'] = pose

    verts_rh_gen_cnet = rh_model(**drec_cnet).vertices

    _, h2o, _ = point2point_signed(verts_rh_gen_cnet,
                                   dorig['verts_object'].to(device))

    drec_cnet['trans_rhand_f'] = drec_cnet['transl']
    drec_cnet['global_orient_rhand_rotmat_f'] = aa2rotmat(
        drec_cnet['global_orient']).view(-1, 3, 3)
    drec_cnet['fpose_rhand_rotmat_f'] = aa2rotmat(drec_cnet['hand_pose']).view(
        -1, 15, 3, 3)
    drec_cnet['verts_object'] = dorig['verts_object'].to(device)
    drec_cnet['h2o_dist'] = h2o.abs()

    print(
        'Hand fitting err',
        np.linalg.norm(
            verts_rh_gen_cnet.squeeze().detach().numpy() - ho.hand_verts, 2,
            1).mean())
    orig_obj = dorig['mesh_object'][0].v
    # print(orig_obj.shape, orig_obj)
    # print('Obj fitting err', np.linalg.norm(orig_obj - ho.obj_verts, 2, 1).mean())

    drec_rnet = refine_net(**drec_cnet)
    mano_out = rh_model(**drec_rnet)
    verts_rh_gen_rnet = mano_out.vertices
    joints_out = mano_out.joints

    if vis:
        for cId in range(len(dorig['bps_object'])):
            try:
                from copy import deepcopy
                meshes = deepcopy(dorig['mesh_object'])
                obj_mesh = meshes[cId]
            except Exception:
                # fall back to a sphere point cloud when no object mesh is provided
                obj_mesh = points_to_spheres(to_cpu(dorig['verts_object'][cId]),
                                             radius=0.002,
                                             vc=name_to_rgb['green'])

            hand_mesh_gen_cnet = Mesh(v=to_cpu(verts_rh_gen_cnet[cId]),
                                      f=rh_model.faces,
                                      vc=name_to_rgb['pink'])
            hand_mesh_gen_rnet = Mesh(v=to_cpu(verts_rh_gen_rnet[cId]),
                                      f=rh_model.faces,
                                      vc=name_to_rgb['gray'])

            if 'rotmat' in dorig:
                rotmat = dorig['rotmat'][cId].T
                obj_mesh = obj_mesh.rotate_vertices(rotmat)
                hand_mesh_gen_cnet.rotate_vertices(rotmat)
                hand_mesh_gen_rnet.rotate_vertices(rotmat)
                # print('rotmat', rotmat)

            hand_mesh_gen_cnet.reset_face_normals()
            hand_mesh_gen_rnet.reset_face_normals()

            # mvs[0][cId].set_static_meshes([hand_mesh_gen_cnet] + obj_mesh, blocking=True)
            # mvs[0][cId].set_static_meshes([hand_mesh_gen_rnet,obj_mesh], blocking=True)
            mvs[0][cId].set_static_meshes(
                [hand_mesh_gen_rnet, hand_mesh_gen_cnet, obj_mesh],
                blocking=True)

            if save:
                save_path = os.path.join(save_dir, str(cId))
                makepath(save_path)
                hand_mesh_gen_rnet.write_ply(filename=save_path +
                                             '/rh_mesh_gen_%d.ply' % cId)
                obj_mesh.write_ply(filename=save_path +
                                   '/obj_mesh_%d.ply' % cId)

    return verts_rh_gen_rnet, joints_out
Example #8
def get_meshes(dorig,
               coarse_net,
               refine_net,
               rh_model,
               save=False,
               save_dir=None):
    with torch.no_grad():

        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

        drec_cnet = coarse_net.sample_poses(dorig['bps_object'])
        output = rh_model(**drec_cnet)
        verts_rh_gen_cnet = output.vertices

        _, h2o, _ = point2point_signed(verts_rh_gen_cnet,
                                       dorig['verts_object'].to(device))

        drec_cnet['trans_rhand_f'] = drec_cnet['transl']
        drec_cnet['global_orient_rhand_rotmat_f'] = aa2rotmat(
            drec_cnet['global_orient']).view(-1, 3, 3)
        drec_cnet['fpose_rhand_rotmat_f'] = aa2rotmat(
            drec_cnet['hand_pose']).view(-1, 15, 3, 3)
        drec_cnet['verts_object'] = dorig['verts_object'].to(device)
        drec_cnet['h2o_dist'] = h2o.abs()

        drec_rnet = refine_net(**drec_cnet)
        output = rh_model(**drec_rnet)
        print("hand shape {} should be idtenty".format(output.betas))
        verts_rh_gen_rnet = output.vertices

        # Reorder joints to match visualization utilities (joint_mapper) (TODO)
        joints_rh_gen_rnet = output.joints  # [:, [0, 13, 14, 15, 16, 1, 2, 3, 17, 4, 5, 6, 18, 10, 11, 12, 19, 7, 8, 9, 20]]
        transforms_rh_gen_rnet = output.transforms  # [:, [0, 13, 14, 15, 1, 2, 3, 4, 5, 6, 10, 11, 12, 7, 8, 9]]
        joints_rh_gen_rnet = to_cpu(joints_rh_gen_rnet)
        transforms_rh_gen_rnet = to_cpu(transforms_rh_gen_rnet)

        gen_meshes = []
        for cId in range(len(dorig['bps_object'])):
            try:
                obj_mesh = dorig['mesh_object'][cId]
            except Exception:
                # fall back to a sphere point cloud when no object mesh is provided
                obj_mesh = points2sphere(points=to_cpu(dorig['verts_object'][cId]),
                                         radius=0.002,
                                         vc=[145, 191, 219])

            hand_mesh_gen_rnet = Mesh(vertices=to_cpu(verts_rh_gen_rnet[cId]),
                                      faces=rh_model.faces,
                                      vc=[145, 191, 219])
            hand_joint_gen_rnet = joints_rh_gen_rnet[cId]
            hand_transform_gen_rnet = transforms_rh_gen_rnet[cId]

            if 'rotmat' in dorig:
                rotmat = dorig['rotmat'][cId].T
                obj_mesh = obj_mesh.rotate_vertices(rotmat)
                hand_mesh_gen_rnet.rotate_vertices(rotmat)

                hand_joint_gen_rnet = hand_joint_gen_rnet @ rotmat.T
                hand_transform_gen_rnet[:, :, :3, :3] = np.matmul(
                    rotmat[None, ...], hand_transform_gen_rnet[:, :, :3, :3])

            gen_meshes.append([obj_mesh, hand_mesh_gen_rnet])
            if save:
                makepath(save_dir)
                print("saving dir {}".format(save_dir))
                np.save(save_dir + '/joints_%d.npy' % cId, hand_joint_gen_rnet)
                np.save(save_dir + '/trans_%d.npy' % cId,
                        hand_transform_gen_rnet)

        return gen_meshes