Beispiel #1
0
def fit_SMPLD(scans, smpl_pkl=None, gender='male', save_path=None, display=False):
    """
    Fit SMPL+D (SMPL plus per-model offsets) to a batch of scans.

    :param scans: list of scan paths, loaded with tm.from_obj
    :param smpl_pkl: optional list of pickle paths (one per scan), each storing
        'pose', 'betas' and 'trans'. When it is None, empty, or its first entry
        is None, fit_SMPL is run first to obtain the initialisation.
    :param gender: gender passed to SmplPaths (and to fit_SMPL)
    :param save_path: optional output directory for meshes and parameters
    :param display: only forwarded to fit_SMPL
    :return: pose, betas, trans and offsets of the fitted model as numpy arrays
    """
    # Get SMPL faces (indices are integral, so cast through an integer dtype)
    sp = SmplPaths(gender=gender)
    smpl_faces = sp.get_faces()
    th_faces = torch.tensor(smpl_faces.astype('int64'), dtype=torch.long).cuda()

    # Batch size
    batch_sz = len(scans)

    # Create the output directory up front: scans are written into it while
    # they are being loaded, i.e. before the final save step below.
    if save_path is not None:
        os.makedirs(save_path, exist_ok=True)

    # Init SMPL
    if smpl_pkl is None or len(smpl_pkl) == 0 or smpl_pkl[0] is None:
        print('SMPL not specified, fitting SMPL now')
        pose, betas, trans = fit_SMPL(scans, None, gender, save_path, display)
    else:
        pose, betas, trans = [], [], []
        for spkl in smpl_pkl:
            with open(spkl, 'rb') as f:
                smpl_dict = pkl.load(f, encoding='latin-1')
            p, b, t = smpl_dict['pose'], smpl_dict['betas'], smpl_dict['trans']
            pose.append(p)
            if len(b) == 10:
                # Zero-pad 10 shape coefficients to the 300 used by the model here
                temp = np.zeros((300,))
                temp[:10] = b
                b = temp.astype('float32')
            betas.append(b)
            trans.append(t)
        pose, betas, trans = np.array(pose), np.array(betas), np.array(trans)

    betas, pose, trans = torch.tensor(betas), torch.tensor(pose), torch.tensor(trans)
    smpl = th_batch_SMPL(batch_sz, betas, pose, trans, faces=th_faces).cuda()

    # Detached copy of the initial fit, handed to the offset optimisation
    verts, _, _, _ = smpl()
    init_smpl_meshes = [tm.from_tensors(vertices=v.clone().detach(),
                                        faces=smpl.faces) for v in verts]

    # Load scans
    th_scan_meshes = []
    for scan in scans:
        th_scan = tm.from_obj(scan)
        if save_path is not None:
            th_scan.save_mesh(join(save_path, split(scan)[1]))
        th_scan.vertices = th_scan.vertices.cuda()
        th_scan.faces = th_scan.faces.cuda()
        th_scan.vertices.requires_grad = False
        th_scan_meshes.append(th_scan)

    # Optimize
    optimize_offsets(th_scan_meshes, smpl, init_smpl_meshes, 5, 10)
    print('Done')

    verts, _, _, _ = smpl()
    th_smpl_meshes = [tm.from_tensors(vertices=v,
                                      faces=smpl.faces) for v in verts]

    pose_np = smpl.pose.cpu().detach().numpy()
    betas_np = smpl.betas.cpu().detach().numpy()
    trans_np = smpl.trans.cpu().detach().numpy()
    offsets_np = smpl.offsets.cpu().detach().numpy()

    if save_path is not None:
        names = [split(s)[1] for s in scans]

        # Save meshes
        save_meshes(th_smpl_meshes, [join(save_path, n.replace('.obj', '_smpld.obj')) for n in names])
        save_meshes(th_scan_meshes, [join(save_path, n) for n in names])
        # Save params, one pickle per scan
        for p, b, t, d, n in zip(pose_np, betas_np, trans_np, offsets_np, names):
            smpl_dict = {'pose': p, 'betas': b, 'trans': t, 'offsets': d}
            with open(join(save_path, n.replace('.obj', '_smpld.pkl')), 'wb') as f:
                pkl.dump(smpl_dict, f)

    return pose_np, betas_np, trans_np, offsets_np
Beispiel #2
0
def fit_SMPLX(scans,
              pose_files=None,
              gender='male',
              save_path=None,
              display=None):
    """
    Fit SMPL pose, shape and translation to a batch of scans.

    :param scans: list of scan paths, loaded with tm.from_obj
    :param pose_files: optional list of JSON files with 3D keypoints
        ('pose_keypoints_3d', 'face_keypoints_3d', 'hand_right_keypoints_3d',
        'hand_left_keypoints_3d'). When given, a pose-only optimisation is run
        before the joint pose and shape optimisation.
    :param gender: gender passed to SmplPaths and get_prior
    :param save_path: optional output directory for meshes and parameters
    :param display: None disables display; any other value is forwarded as 0
    :return: pose, betas and trans of the fitted model as numpy arrays
        (returned whether or not save_path is given)
    """
    # Get SMPL faces (indices are integral, so cast through an integer dtype)
    sp = SmplPaths(gender=gender)
    smpl_faces = sp.get_faces()
    th_faces = torch.tensor(smpl_faces.astype('int64'),
                            dtype=torch.long).cuda()

    # Batch size
    batch_sz = len(scans)

    # Set optimization hyper parameters
    iterations, pose_iterations, steps_per_iter, pose_steps_per_iter = 3, 2, 30, 30

    # Init SMPL, pose with mean smpl pose, as in ch.registration
    prior = get_prior(gender=gender)
    pose_init = torch.zeros((batch_sz, 72))
    pose_init[:, 3:] = prior.mean
    betas, pose, trans = torch.zeros(
        (batch_sz, 300)), pose_init, torch.zeros((batch_sz, 3))
    smpl = th_batch_SMPL(batch_sz, betas, pose, trans, faces=th_faces).cuda()

    # Load scans (centering of the scans is currently disabled)
    th_scan_meshes = []
    for scan in scans:
        print('scan path ...', scan)
        th_scan = tm.from_obj(scan)
        th_scan.vertices = th_scan.vertices.cuda()
        th_scan.faces = th_scan.faces.cuda()
        th_scan.vertices.requires_grad = False
        th_scan.cuda()
        th_scan_meshes.append(th_scan)

    display_arg = None if display is None else 0

    # Load pose information if pose file is given
    # Bharat: Shouldn't we structure th_pose_3d as [key][batch, ...] as opposed to current [batch][key]? See batch_get_pose_obj() in body_objectives.py
    th_pose_3d = None
    if pose_files is not None:
        keypoint_keys = ('pose_keypoints_3d', 'face_keypoints_3d',
                         'hand_right_keypoints_3d', 'hand_left_keypoints_3d')
        th_no_right_hand_visible, th_no_left_hand_visible, th_pose_3d = [], [], []
        for pose_file in pose_files:
            with open(pose_file) as f:
                pose_3d = json.load(f)

            # A hand counts as not visible when the largest value in the 4th
            # column of its keypoints is below HAND_VISIBLE.
            th_no_right_hand_visible.append(
                np.max(
                    np.array(pose_3d['hand_right_keypoints_3d']).reshape(
                        -1, 4)[:, 3]) < HAND_VISIBLE)
            th_no_left_hand_visible.append(
                np.max(
                    np.array(pose_3d['hand_left_keypoints_3d']).reshape(
                        -1, 4)[:, 3]) < HAND_VISIBLE)

            # Convert every keypoint list into an (N, 4) float32 tensor
            for key in keypoint_keys:
                pose_3d[key] = torch.from_numpy(
                    np.array(pose_3d[key]).astype(np.float32).reshape(-1, 4))
            th_pose_3d.append(pose_3d)

        prior_weight = get_prior_weight(th_no_right_hand_visible,
                                        th_no_left_hand_visible).cuda()

        # Optimize pose first
        optimize_pose_only(th_scan_meshes,
                           smpl,
                           pose_iterations,
                           pose_steps_per_iter,
                           th_pose_3d,
                           prior_weight,
                           display=display_arg)

    # Optimize pose and shape
    optimize_pose_shape(th_scan_meshes,
                        smpl,
                        iterations,
                        steps_per_iter,
                        th_pose_3d,
                        display=display_arg)

    verts, _, _, _ = smpl()
    th_smpl_meshes = [
        tm.from_tensors(vertices=v, faces=smpl.faces) for v in verts
    ]

    pose_np = smpl.pose.cpu().detach().numpy()
    betas_np = smpl.betas.cpu().detach().numpy()
    trans_np = smpl.trans.cpu().detach().numpy()

    if save_path is not None:
        os.makedirs(save_path, exist_ok=True)

        names = [split(s)[1] for s in scans]

        # Save meshes
        save_meshes(
            th_smpl_meshes,
            [join(save_path, n.replace('.obj', '_smpl.obj')) for n in names])
        save_meshes(th_scan_meshes, [join(save_path, n) for n in names])

        # Save params, one pickle per scan
        for p, b, t, n in zip(pose_np, betas_np, trans_np, names):
            smpl_dict = {'pose': p, 'betas': b, 'trans': t}
            with open(join(save_path, n.replace('.obj', '_smpl.pkl')), 'wb') as f:
                pkl.dump(smpl_dict, f)

    # Return the fit even when nothing was saved, so callers can always unpack it
    return pose_np, betas_np, trans_np
Beispiel #3
0
def fit_SMPL(scans, scan_labels, gender='male', save_path=None, scale_file=None, display=None):
    """
    Fit SMPL pose, shape and translation to a batch of scans using part labels.

    :param scans: list of scan paths, loaded with Mesh(filename=...)
    :param scan_labels: list of .npy files with the part labels of each scan
    :param gender: gender passed to SmplPaths and get_prior
    :param save_path: optional output directory for meshes and parameters
    :param scale_file: optional list of files loaded with np.load; entry [1] is
        added to and entry [0] is multiplied onto the scan vertices
    :param display: None disables display; any other value is forwarded as 0
    :return: pose, betas and trans of the fitted model as numpy arrays
        (returned whether or not save_path is given)
    """
    # Get SMPL faces (indices are integral, so cast through an integer dtype)
    sp = SmplPaths(gender=gender)
    smpl_faces = sp.get_faces()
    th_faces = torch.tensor(smpl_faces.astype('int64'), dtype=torch.long).to(DEVICE)

    # Load SMPL parts: each key of the pickle indexes the vertices of one part,
    # and the part id is the position of that key in the dict.
    # NOTE(review): this absolute path is machine specific.
    with open('/BS/bharat-3/work/IPNet/assets/smpl_parts_dense.pkl', 'rb') as f:
        part_labels = pkl.load(f)
    labels = np.zeros((6890,), dtype='int32')
    for n, k in enumerate(part_labels):
        labels[part_labels[k]] = n
    labels = torch.tensor(labels).unsqueeze(0).to(DEVICE)

    # Load scan parts
    scan_part_labels = [torch.tensor(np.load(sc_l).astype('int32')).to(DEVICE)
                        for sc_l in scan_labels]

    # Batch size
    batch_sz = len(scans)

    # Set optimization hyper parameters
    iterations, pose_iterations, steps_per_iter, pose_steps_per_iter = 3, 2, 30, 30

    prior = get_prior(gender=gender, precomputed=True)
    pose_init = torch.zeros((batch_sz, 72))
    pose_init[:, 3:] = prior.mean
    betas, pose, trans = torch.zeros((batch_sz, 300)), pose_init, torch.zeros((batch_sz, 3))

    # Init SMPL, pose with mean smpl pose, as in ch.registration
    smpl = th_batch_SMPL(batch_sz, betas, pose, trans, faces=th_faces).to(DEVICE)
    smpl_part_labels = torch.cat([labels] * batch_sz, axis=0)

    th_scan_meshes = []
    for scan in scans:
        print('scan path ...', scan)
        temp = Mesh(filename=scan)
        th_scan = tm.from_tensors(torch.tensor(temp.v.astype('float32'), requires_grad=False, device=DEVICE),
                                  torch.tensor(temp.f.astype('int32'), requires_grad=False, device=DEVICE).long())
        th_scan_meshes.append(th_scan)

    if scale_file is not None:
        # Shift first, then scale
        for n, sc in enumerate(scale_file):
            dat = np.load(sc, allow_pickle=True)
            th_scan_meshes[n].vertices += torch.tensor(dat[1]).to(DEVICE)
            th_scan_meshes[n].vertices *= torch.tensor(dat[0]).to(DEVICE)

    display_arg = None if display is None else 0

    # Optimize pose first
    optimize_pose_only(th_scan_meshes, smpl, pose_iterations, pose_steps_per_iter, scan_part_labels, smpl_part_labels,
                       display=display_arg)

    # Optimize pose and shape
    optimize_pose_shape(th_scan_meshes, smpl, iterations, steps_per_iter, scan_part_labels, smpl_part_labels,
                        display=display_arg)

    verts, _, _, _ = smpl()
    th_smpl_meshes = [tm.from_tensors(vertices=v, faces=smpl.faces) for v in verts]

    pose_np = smpl.pose.cpu().detach().numpy()
    betas_np = smpl.betas.cpu().detach().numpy()
    trans_np = smpl.trans.cpu().detach().numpy()

    if save_path is not None:
        os.makedirs(save_path, exist_ok=True)

        names = [split(s)[1] for s in scans]

        # Save meshes
        save_meshes(th_smpl_meshes, [join(save_path, n.replace('.ply', '_smpl.obj')) for n in names])
        save_meshes(th_scan_meshes, [join(save_path, n) for n in names])

        # Save params, one pickle per scan
        for p, b, t, n in zip(pose_np, betas_np, trans_np, names):
            smpl_dict = {'pose': p, 'betas': b, 'trans': t}
            with open(join(save_path, n.replace('.ply', '_smpl.pkl')), 'wb') as f:
                pkl.dump(smpl_dict, f)

    # Return the fit even when nothing was saved, so callers can always unpack it
    return pose_np, betas_np, trans_np
Beispiel #4
0
def fit_SMPLD(scans, smpl_pkl, gender='male', save_path=None, scale_file=None):
    """
    Fit SMPL+D (SMPL plus per-model offsets) to a batch of scans, starting from
    previously fitted SMPL parameters.

    :param scans: list of scan paths, loaded with Mesh(filename=...)
    :param smpl_pkl: list of pickle paths (one per scan), each storing 'pose',
        'betas' and 'trans'
    :param gender: gender passed to SmplPaths
    :param save_path: optional output directory for meshes and parameters
    :param scale_file: optional list of files loaded with np.load; entry [1] is
        added to and entry [0] is multiplied onto the scan vertices
    :return: pose, betas, trans and offsets of the fitted model as numpy arrays
    """
    # Get SMPL faces. Everything is placed on DEVICE so that the model and the
    # scans (which are created on DEVICE below) live on the same device.
    sp = SmplPaths(gender=gender)
    smpl_faces = sp.get_faces()
    th_faces = torch.tensor(smpl_faces.astype('int64'), dtype=torch.long).to(DEVICE)

    # Batch size
    batch_sz = len(scans)

    # Init SMPL
    pose, betas, trans = [], [], []
    for spkl in smpl_pkl:
        with open(spkl, 'rb') as f:
            smpl_dict = pkl.load(f)
        p, b, t = smpl_dict['pose'], smpl_dict['betas'], smpl_dict['trans']
        pose.append(p)
        if len(b) == 10:
            # Zero-pad 10 shape coefficients to the 300 used by the model here
            temp = np.zeros((300,))
            temp[:10] = b
            b = temp.astype('float32')
        betas.append(b)
        trans.append(t)
    pose, betas, trans = np.array(pose), np.array(betas), np.array(trans)

    betas, pose, trans = torch.tensor(betas), torch.tensor(pose), torch.tensor(trans)
    smpl = th_batch_SMPL(batch_sz, betas, pose, trans, faces=th_faces).to(DEVICE)

    # Detached copy of the initial fit, handed to the offset optimisation
    verts, _, _, _ = smpl()
    init_smpl_meshes = [tm.from_tensors(vertices=v.clone().detach(),
                                        faces=smpl.faces) for v in verts]

    # Load scans
    th_scan_meshes = []
    for scan in scans:
        print('scan path ...', scan)
        temp = Mesh(filename=scan)
        th_scan = tm.from_tensors(torch.tensor(temp.v.astype('float32'), requires_grad=False, device=DEVICE),
                                  torch.tensor(temp.f.astype('int32'), requires_grad=False, device=DEVICE).long())
        th_scan_meshes.append(th_scan)

    if scale_file is not None:
        # Shift first, then scale
        for n, sc in enumerate(scale_file):
            dat = np.load(sc, allow_pickle=True)
            th_scan_meshes[n].vertices += torch.tensor(dat[1]).to(DEVICE)
            th_scan_meshes[n].vertices *= torch.tensor(dat[0]).to(DEVICE)

    # Optimize
    optimize_offsets(th_scan_meshes, smpl, init_smpl_meshes, 5, 10)
    print('Done')

    verts, _, _, _ = smpl()
    th_smpl_meshes = [tm.from_tensors(vertices=v,
                                      faces=smpl.faces) for v in verts]

    pose_np = smpl.pose.cpu().detach().numpy()
    betas_np = smpl.betas.cpu().detach().numpy()
    trans_np = smpl.trans.cpu().detach().numpy()
    offsets_np = smpl.offsets.cpu().detach().numpy()

    if save_path is not None:
        os.makedirs(save_path, exist_ok=True)

        names = [split(s)[1] for s in scans]

        # Save meshes
        save_meshes(th_smpl_meshes, [join(save_path, n.replace('.ply', '_smpld.obj')) for n in names])
        save_meshes(th_scan_meshes, [join(save_path, n) for n in names])
        # Save params, one pickle per scan
        for p, b, t, d, n in zip(pose_np, betas_np, trans_np, offsets_np, names):
            smpl_dict = {'pose': p, 'betas': b, 'trans': t, 'offsets': d}
            with open(join(save_path, n.replace('.ply', '_smpld.pkl')), 'wb') as f:
                pkl.dump(smpl_dict, f)

    return pose_np, betas_np, trans_np, offsets_np
Beispiel #5
0
def _registration_target(dataset, idx, register_dir):
    """Return ``(directory, filebase)`` under which sample ``idx`` is registered."""
    model_dict = dataset.get_model_dict(idx)
    filebase = os.path.basename(model_dict['data_path'])[:-4]
    folder_name = os.path.join(model_dict['subset'], model_dict['subject'], model_dict['sequence'])
    return os.path.join(register_dir, folder_name), filebase


def _export_if_missing(batch_vertices, i, faces, out_path):
    """Export mesh ``i`` of ``batch_vertices`` to ``out_path`` unless the file exists.

    The batch is only indexed when the file is missing, so an already exported
    batch can be skipped even when no vertices are available.
    """
    if not os.path.exists(out_path):
        registered_mesh = trimesh.Trimesh(batch_vertices[i].detach().cpu().numpy().astype(np.float64), faces, process=False)
        registered_mesh.export(out_path)


def SMPLD_register(args):
    """
    Register SMPL and SMPL+D to the generated meshes of the test set.

    For every batch the minimal (inner) posed meshes and their part labels are
    loaded, SMPL pose and shape are fitted to them, and offsets are then
    optimised against the clothed (outer) meshes or, with ``args.use_raw_scan``,
    against the raw scans. Results are written as
    ``<filebase>minimal.registered.ply`` and ``<filebase>cloth.registered.ply``
    below ``<generation_dir>/registrations``. Without ``args.use_raw_scan`` the
    registrations are additionally compared with the ground-truth vertices in
    the batch and the mean per-vertex distances are logged.

    :param args: namespace providing ``config``, ``no_cuda``, ``subject_idx``,
        ``sequence_idx``, ``use_raw_scan``, ``init_pose`` and ``num_joints``;
        it is also forwarded to the optimisation routines
    :raises ValueError: if ``args.num_joints`` is neither 14 nor 24, or if poses
        are to be initialised from unposed meshes without 24 part labels
    """
    cfg = config.load_config(args.config, 'configs/default.yaml')
    out_dir = cfg['training']['out_dir']
    generation_dir = os.path.join(out_dir, cfg['generation']['generation_dir'])
    is_cuda = (torch.cuda.is_available() and not args.no_cuda)
    device = torch.device("cuda" if is_cuda else "cpu")

    # Logger and dataset, optionally restricted to a single subject/sequence
    if args.subject_idx >= 0 and args.sequence_idx >= 0:
        logger, _ = create_logger(generation_dir, phase='reg_subject{}_sequence{}'.format(args.subject_idx, args.sequence_idx), create_tf_logs=False)
        dataset = config.get_dataset('test', cfg, sequence_idx=args.sequence_idx, subject_idx=args.subject_idx)
    else:
        logger, _ = create_logger(generation_dir, phase='reg_all', create_tf_logs=False)
        dataset = config.get_dataset('test', cfg)

    batch_size = cfg['generation']['batch_size']

    # Loader
    test_loader = torch.utils.data.DataLoader(
        dataset, batch_size=batch_size, num_workers=1, shuffle=False)

    # Set optimization hyper parameters
    iterations, pose_iterations, steps_per_iter, pose_steps_per_iter = 3, 2, 30, 30

    inner_dists = []
    outer_dists = []

    # Directories to load corresponding informations
    mesh_dir = os.path.join(generation_dir, 'meshes')   # directory for posed and (optionally) unposed implicit outer/inner meshes
    label_dir = os.path.join(generation_dir, 'labels')   # directory for part labels
    register_dir = os.path.join(generation_dir, 'registrations')   # directory for registered meshes

    if args.use_raw_scan:
        scan_dir = dataset.dataset_folder   # this is the folder that contains CAPE raw scans
    else:
        scan_dir = None

    for data in tqdm(test_loader):
        idxs = data['idx'].cpu().numpy()
        batch_size = idxs.shape[0]

        all_posed_minimal_meshes = []
        all_posed_cloth_meshes = []
        all_posed_vertices = []
        all_unposed_vertices = []
        scan_part_labels = []

        for idx in idxs:
            model_dict = dataset.get_model_dict(idx)

            subset = model_dict['subset']
            subject = model_dict['subject']
            sequence = model_dict['sequence']
            gender = model_dict['gender']
            filebase = os.path.basename(model_dict['data_path'])[:-4]

            folder_name = os.path.join(subset, subject, sequence)
            # TODO: we assume batch size stays the same if one resumes the job
            # can be more flexible to support different batch sizes before and
            # after resume
            register_file = os.path.join(register_dir, folder_name, filebase + 'minimal.registered.ply')
            if os.path.exists(register_file):
                # batch already computed, break
                break

            mesh_dir_ = os.path.join(mesh_dir, folder_name)
            label_dir_ = os.path.join(label_dir, folder_name)

            if scan_dir is not None:
                scan_dir_ = os.path.join(scan_dir, subject, sequence)

            # Load part labels and vertex translations
            label_file_name = filebase + '.minimal.npz'
            label_dict = dict(np.load(os.path.join(label_dir_, label_file_name)))
            labels = torch.tensor(label_dict['part_labels'].astype(np.int64)).to(device)   # part labels for each vertex (14 or 24)
            scan_part_labels.append(labels)

            # Load minimal implicit surfaces
            mesh_file_name = filebase + '.minimal.posed.ply'
            posed_mesh = trimesh.load(os.path.join(mesh_dir_, mesh_file_name), process=False)
            posed_vertices = np.array(posed_mesh.vertices)
            all_posed_vertices.append(posed_vertices)

            posed_mesh = tm.from_tensors(torch.tensor(posed_mesh.vertices.astype('float32'), requires_grad=False, device=device),
                    torch.tensor(posed_mesh.faces.astype('int64'), requires_grad=False, device=device))
            all_posed_minimal_meshes.append(posed_mesh)

            # Unposed meshes are only used when pose initialisation is requested
            mesh_file_name = filebase + '.minimal.unposed.ply'
            if os.path.exists(os.path.join(mesh_dir_, mesh_file_name)) and args.init_pose:
                unposed_mesh = trimesh.load(os.path.join(mesh_dir_, mesh_file_name), process=False)
                unposed_vertices = np.array(unposed_mesh.vertices)
                all_unposed_vertices.append(unposed_vertices)

            if args.use_raw_scan:
                # Load raw scans; their vertices are divided by 1000 on load
                mesh_file_name = filebase + '.ply'
                posed_mesh = trimesh.load(os.path.join(scan_dir_, mesh_file_name), process=False)

                posed_mesh = tm.from_tensors(torch.tensor(posed_mesh.vertices.astype('float32') / 1000, requires_grad=False, device=device),
                        torch.tensor(posed_mesh.faces.astype('int64'), requires_grad=False, device=device))
                all_posed_cloth_meshes.append(posed_mesh)
            else:
                # Load clothed implicit surfaces
                mesh_file_name = filebase + '.cloth.posed.ply'
                posed_mesh = trimesh.load(os.path.join(mesh_dir_, mesh_file_name), process=False)

                posed_mesh = tm.from_tensors(torch.tensor(posed_mesh.vertices.astype('float32'), requires_grad=False, device=device),
                        torch.tensor(posed_mesh.faces.astype('int64'), requires_grad=False, device=device))
                all_posed_cloth_meshes.append(posed_mesh)

        # Part label of every SMPL vertex, expressed with 14 parts
        if args.num_joints == 24:
            bm = BodyModel(bm_path='body_models/smpl/male/model.pkl', num_betas=10, batch_size=batch_size).to(device)
            parents = bm.kintree_table[0].detach().cpu().numpy()
            labels = bm.weights.argmax(1)
            # Convert 24 parts to 14 parts
            smpl2ipnet = torch.from_numpy(SMPL2IPNET_IDX).to(device)
            labels = smpl2ipnet[labels].clone().unsqueeze(0)
            del bm
        elif args.num_joints == 14:
            with open('body_models/misc/smpl_parts_dense.pkl', 'rb') as f:
                part_labels = pkl.load(f)

            labels = np.zeros((6890,), dtype=np.int64)
            for n, k in enumerate(part_labels):
                labels[part_labels[k]] = n
            labels = torch.tensor(labels).to(device).unsqueeze(0)
        else:
            raise ValueError('Got {} joints but number of joints can only be either 14 or 24'.format(args.num_joints))

        th_faces = torch.tensor(smpl_faces.astype('float32'), dtype=torch.long).to(device)

        # We assume loaded meshes are properly scaled and offsetted to the orignal SMPL space,
        if len(all_posed_minimal_meshes) > 0:
            if len(all_unposed_vertices) > 0:
                # NASA+PTFs optimization with vertex traslations
                # Compute poses from implicit surfaces and correspondences
                # TODO: we could also compute bone-lengths if we train PTFs to predict A-pose with a global translation
                # that equals to the centroid of the pointcloud
                if args.num_joints != 24:
                    # The kinematic tree and the 24->14 mapping only exist for 24 joints
                    raise ValueError('Pose initialisation from unposed meshes requires 24 joints, got {}'.format(args.num_joints))
                poses = compute_poses(all_posed_vertices, all_unposed_vertices, scan_part_labels, parents, args)
                # Convert 24 parts to 14 parts
                scan_part_labels = [smpl2ipnet[part_lbl].clone() for part_lbl in scan_part_labels]

                pose_init = torch.from_numpy(poses).float()
            else:
                # IPNet optimization without vertex traslation
                if args.num_joints == 24:
                    scan_part_labels = [smpl2ipnet[part_lbl].clone() for part_lbl in scan_part_labels]

                # Pose with mean smpl pose, as in ch.registration
                prior = get_prior(gender=gender, precomputed=True)
                pose_init = torch.zeros((batch_size, 72))
                pose_init[:, 3:] = prior.mean

            betas, pose, trans = torch.zeros((batch_size, 10)), pose_init, torch.zeros((batch_size, 3))

            # Init SMPL
            smpl = th_batch_SMPL(batch_size, betas, pose, trans, faces=th_faces, gender=gender).to(device)
            smpl_part_labels = torch.cat([labels] * batch_size, axis=0)

            # Optimize pose first
            optimize_pose_only(all_posed_minimal_meshes, smpl, pose_iterations, pose_steps_per_iter, scan_part_labels,
                               smpl_part_labels, None, args)

            # Optimize pose and shape
            optimize_pose_shape(all_posed_minimal_meshes, smpl, iterations, steps_per_iter, scan_part_labels, smpl_part_labels,
                                None, args)

            inner_vertices, _, _, _ = smpl()

            # Optimize vertices for SMPLD
            init_smpl_meshes = [tm.from_tensors(vertices=v.clone().detach(),
                                                faces=smpl.faces) for v in inner_vertices]
            optimize_offsets(all_posed_cloth_meshes, smpl, init_smpl_meshes, 5, 10, args)

            outer_vertices, _, _, _ = smpl()
        else:
            # Nothing was loaded for this batch
            inner_vertices = outer_vertices = None

        if args.use_raw_scan:
            for i, idx in enumerate(idxs):
                register_dir_, filebase = _registration_target(dataset, idx, register_dir)
                os.makedirs(register_dir_, exist_ok=True)

                _export_if_missing(inner_vertices, i, smpl_faces,
                                   os.path.join(register_dir_, filebase + 'minimal.registered.ply'))
                _export_if_missing(outer_vertices, i, smpl_faces,
                                   os.path.join(register_dir_, filebase + 'cloth.registered.ply'))
        else:
            # Evaluate registered mesh
            gt_smpl_mesh = data['points.minimal_smpl_vertices'].to(device)
            gt_smpld_mesh = data['points.smpl_vertices'].to(device)
            if inner_vertices is None:
                # if vertices are None, we assume they already exist due to previous runs
                inner_vertices = []
                outer_vertices = []
                for idx in idxs:
                    register_dir_, filebase = _registration_target(dataset, idx, register_dir)

                    registered_mesh = trimesh.load(os.path.join(register_dir_, filebase + 'minimal.registered.ply'), process=False)
                    registered_v = torch.tensor(registered_mesh.vertices.astype(np.float32), requires_grad=False, device=device)
                    inner_vertices.append(registered_v)

                    registered_mesh = trimesh.load(os.path.join(register_dir_, filebase + 'cloth.registered.ply'), process=False)
                    registered_v = torch.tensor(registered_mesh.vertices.astype(np.float32), requires_grad=False, device=device)
                    outer_vertices.append(registered_v)

                inner_vertices = torch.stack(inner_vertices, dim=0)
                outer_vertices = torch.stack(outer_vertices, dim=0)

            # Mean per-vertex distance of every sample in the batch
            inner_dist = torch.norm(gt_smpl_mesh - inner_vertices, dim=2).mean(-1)
            outer_dist = torch.norm(gt_smpld_mesh - outer_vertices, dim=2).mean(-1)

            for i, idx in enumerate(idxs):
                register_dir_, filebase = _registration_target(dataset, idx, register_dir)
                os.makedirs(register_dir_, exist_ok=True)

                logger.info('Inner distance for input {}: {} cm'.format(filebase, inner_dist[i].item()))
                logger.info('Outer distance for input {}: {} cm'.format(filebase, outer_dist[i].item()))

                _export_if_missing(inner_vertices, i, smpl_faces,
                                   os.path.join(register_dir_, filebase + 'minimal.registered.ply'))
                _export_if_missing(outer_vertices, i, smpl_faces,
                                   os.path.join(register_dir_, filebase + 'cloth.registered.ply'))

            inner_dists.extend(inner_dist.detach().cpu().numpy())
            outer_dists.extend(outer_dist.detach().cpu().numpy())

    # Distances are only collected in evaluation mode; skip the summary when
    # there is nothing to average instead of logging nan.
    if inner_dists:
        logger.info('Mean inner distance: {} cm'.format(np.mean(inner_dists)))
    if outer_dists:
        logger.info('Mean outer distance: {} cm'.format(np.mean(outer_dists)))