Ejemplo n.º 1
0
def forward_step(th_scan_meshes, smpl, th_pose_3d=None):
    """
    Evaluate the SMPL layer once and compute the registration losses against the scans.

    :param th_scan_meshes: list of scan meshes (one per batch element), each exposing
        a ``vertices`` attribute.
    :param smpl: batched SMPL layer; calling it returns a 4-tuple whose first entry
        holds the posed vertices, one set per batch element.
    :param th_pose_3d: optional 3D keypoint data; when given, a 'pose_obj' term is added.
    :return: dict with keys 's2m', 'm2s', 'betas', 'pose_pr' and, when ``th_pose_3d``
        is not None, 'pose_obj'.
    """
    pose_prior = get_prior(smpl.gender)

    posed_verts, _, _, _ = smpl()
    smpl_meshes = [tm.from_tensors(vertices=vert, faces=smpl.faces)
                   for vert in posed_verts]

    scan_points = [mesh.vertices for mesh in th_scan_meshes]
    model_points = [mesh.vertices for mesh in smpl_meshes]

    # Dict literal entries are evaluated left to right, i.e. in the original order.
    loss = {
        's2m': batch_point_to_surface(scan_points, smpl_meshes),
        'm2s': batch_point_to_surface(model_points, th_scan_meshes),
        'betas': torch.mean(smpl.betas ** 2, axis=1),
        'pose_pr': pose_prior(smpl.pose),
    }
    if th_pose_3d is None:
        return loss

    loss['pose_obj'] = batch_get_pose_obj(th_pose_3d, smpl)
    return loss
Ejemplo n.º 2
0
def forward_step_pose_only(smpl, th_pose_3d, prior_weight):
    """
    Compute the pose-only objectives for the current SMPL parameters.

    No scan meshes are involved: only the weighted pose prior and the
    3D-keypoint pose objective are evaluated.

    :param smpl: batched SMPL layer exposing ``gender`` and ``pose``.
    :param th_pose_3d: 3D keypoint data forwarded to ``batch_get_pose_obj``.
    :param prior_weight: weighting forwarded to the pose prior.
    :return: dict with keys 'pose_pr' and 'pose_obj'.
    """
    pose_prior = get_prior(smpl.gender)
    return {
        'pose_pr': pose_prior(smpl.pose, prior_weight),
        'pose_obj': batch_get_pose_obj(th_pose_3d, smpl, init_pose=False),
    }
Ejemplo n.º 3
0
def forward_step(th_scan_meshes,
                 smplx,
                 scan_part_labels,
                 smplx_part_labels,
                 search_tree=None,
                 pen_distance=None,
                 tri_filtering_module=None):
    """
    Performs a forward step, given smplx and scan meshes, then computes the losses.

    :param th_scan_meshes: list of scan meshes, one per batch element.
    :param smplx: batched SMPL-X layer; calling it returns the posed vertices,
        one set per batch element.
    :param scan_part_labels: per-scan tensors holding one part label per scan vertex.
    :param smplx_part_labels: per-batch-element part labels for the SMPL-X vertices.
    :param search_tree: forwarded to ``interpenetration_loss``.
    :param pen_distance: forwarded to ``interpenetration_loss``.
    :param tri_filtering_module: forwarded to ``interpenetration_loss``.
    :return: dict with keys 's2m', 'm2s', 'betas', 'interpenetration' and 'part'
        ('part' is the per-sample part Chamfer loss, stacked along dim 0).
    """
    # NOTE: the pose-prior term is disabled for SMPL-X, so the prior is no longer
    # loaded here (it used to be fetched on every step and never used).

    # forward
    verts = smplx()
    th_smplx_meshes = [
        tm.from_tensors(vertices=v, faces=smplx.faces) for v in verts
    ]

    scan_verts = [sm.vertices for sm in th_scan_meshes]
    smplx_verts = [sm.vertices for sm in th_smplx_meshes]

    # losses
    loss = dict()
    loss['s2m'] = batch_point_to_surface(scan_verts, th_smplx_meshes)
    loss['m2s'] = batch_point_to_surface(smplx_verts, th_scan_meshes)
    loss['betas'] = torch.mean(smplx.betas**2, axis=1)
    loss['interpenetration'] = interpenetration_loss(verts, smplx.faces,
                                                     search_tree, pen_distance,
                                                     tri_filtering_module, 1.0)

    part_losses = []
    for n, (sc_v, sc_l) in enumerate(zip(scan_verts, scan_part_labels)):
        sm_l = smplx_part_labels[n]
        tot = 0
        for i in range(NUM_PARTS):
            # Parts absent from the scan contribute nothing, but the average
            # below still divides by NUM_PARTS.
            if i not in sc_l:
                continue
            sc_part_points = sc_v[torch.where(sc_l == i)[0]].unsqueeze(0)
            sm_part_points = smplx_verts[n][torch.where(sm_l == i)[0]].unsqueeze(0)
            dist = chamfer_distance(sc_part_points,
                                    sm_part_points,
                                    w1=1.,
                                    w2=1.)
            # Out-of-place add: avoids mutating the first chamfer output in place.
            tot = tot + dist
        part_losses.append(tot / NUM_PARTS)
    loss['part'] = torch.stack(part_losses)
    return loss
Ejemplo n.º 4
0
def forward_step_SMPL(th_scan_meshes, smpl, scan_part_labels, smpl_part_labels, args):
    """
    Evaluate SMPL once and compute the scan-registration losses.

    :param th_scan_meshes: list of scan meshes, one per batch element.
    :param smpl: batched SMPL layer; calling it returns a 4-tuple whose first entry
        holds the posed vertices.
    :param scan_part_labels: per-scan tensors holding one part label per scan vertex.
    :param smpl_part_labels: per-batch-element part labels for the SMPL vertices.
    :param args: options object; only ``args.use_parts`` is read here.
    :return: dict with 's2m', 'm2s', 'betas', 'pose_pr' and, when ``args.use_parts``
        is truthy, 'part' (per-sample part Chamfer loss stacked along dim 0).
    """
    num_parts = 14  # part losses always use the 14-part labelling
    pose_prior = get_prior(smpl.gender, precomputed=True)

    posed_verts, _, _, _ = smpl()
    smpl_meshes = [tm.from_tensors(vertices=vert, faces=smpl.faces)
                   for vert in posed_verts]

    scan_verts = [mesh.vertices for mesh in th_scan_meshes]
    smpl_verts = [mesh.vertices for mesh in smpl_meshes]

    loss = {
        's2m': batch_point_to_surface(scan_verts, smpl_meshes),
        'm2s': batch_point_to_surface(smpl_verts, th_scan_meshes),
        'betas': torch.mean(smpl.betas ** 2, axis=1),
        'pose_pr': pose_prior(smpl.pose),
    }
    if not args.use_parts:
        return loss

    part_terms = []
    for n, (scan_pts, scan_lbl) in enumerate(zip(scan_verts, scan_part_labels)):
        total = 0
        for part in range(num_parts):
            if part not in scan_lbl:
                continue
            scan_sel = scan_pts[torch.where(scan_lbl == part)[0]].unsqueeze(0)
            model_sel = smpl_verts[n][torch.where(smpl_part_labels[n] == part)[0]].unsqueeze(0)
            total += chamfer_distance(scan_sel, model_sel, w1=1., w2=1.)
        # Averaged over all parts, including those missing from this scan.
        part_terms.append(total / num_parts)
    loss['part'] = torch.stack(part_terms)
    return loss
Ejemplo n.º 5
0
def fit_SMPLX(scans,
              pose_files=None,
              gender='male',
              save_path=None,
              display=None):
    """
    Register a batch of scans with the batched SMPL layer (``th_batch_SMPL``).

    The model starts from zero shape/translation and the prior's mean pose. When
    ``pose_files`` are given, the pose is first optimised alone against the 3D
    keypoints; then pose and shape are optimised jointly against the scans.

    :param scans: list of scan paths (read with ``tm.from_obj``, so .obj files).
    :param pose_files: optional list of JSON keypoint files, one per scan, with keys
        'pose_keypoints_3d', 'face_keypoints_3d', 'hand_right_keypoints_3d' and
        'hand_left_keypoints_3d' (each a flat list reshaped into rows of 4 values).
    :param gender: gender used for the model faces and the pose prior.
    :param save_path: if given, the fitted meshes, the scans and one parameter pickle
        per scan are written to this directory (created if missing).
    :param display: if not None, display index 0 is passed to the optimisers.
    :return: tuple of numpy arrays ``(pose, betas, trans)`` for the fitted batch.
        (Previously this was only returned when ``save_path`` was given.)
    """
    # Model faces
    sp = SmplPaths(gender=gender)
    smpl_faces = sp.get_faces()
    th_faces = torch.tensor(smpl_faces.astype('float32'),
                            dtype=torch.long).cuda()

    batch_sz = len(scans)

    # Optimization hyper parameters
    iterations, pose_iterations, steps_per_iter, pose_steps_per_iter = 3, 2, 30, 30

    # Init with zero shape/translation and the prior's mean pose, as in ch.registration
    prior = get_prior(gender=gender)
    pose = torch.zeros((batch_sz, 72))
    pose[:, 3:] = prior.mean
    betas = torch.zeros((batch_sz, 300))
    trans = torch.zeros((batch_sz, 3))
    smpl = th_batch_SMPL(batch_sz, betas, pose, trans, faces=th_faces).cuda()

    # Load scans. They are used in their original coordinate frame (no centering),
    # so the 3D joints/landmarks need no adjustment.
    th_scan_meshes = []
    for scan in scans:
        print('scan path ...', scan)
        th_scan = tm.from_obj(scan)
        th_scan.vertices = th_scan.vertices.cuda()
        th_scan.faces = th_scan.faces.cuda()
        th_scan.vertices.requires_grad = False
        th_scan.cuda()
        th_scan_meshes.append(th_scan)

    # Load pose information if pose files are given.
    # Bharat: Shouldn't we structure th_pose_3d as [key][batch, ...] as opposed to current [batch][key]? See batch_get_pose_obj() in body_objectives.py
    th_pose_3d = None
    if pose_files is not None:
        keypoint_keys = ('pose_keypoints_3d', 'face_keypoints_3d',
                         'hand_right_keypoints_3d', 'hand_left_keypoints_3d')
        th_no_right_hand_visible, th_no_left_hand_visible, th_pose_3d = [], [], []
        for pose_file in pose_files:
            with open(pose_file) as f:
                pose_3d = json.load(f)
            # A hand is treated as not visible when the maximum of its 4th keypoint
            # column (presumably a confidence score) is below HAND_VISIBLE.
            th_no_right_hand_visible.append(
                np.max(
                    np.array(pose_3d['hand_right_keypoints_3d']).reshape(
                        -1, 4)[:, 3]) < HAND_VISIBLE)
            th_no_left_hand_visible.append(
                np.max(
                    np.array(pose_3d['hand_left_keypoints_3d']).reshape(
                        -1, 4)[:, 3]) < HAND_VISIBLE)
            for key in keypoint_keys:
                pose_3d[key] = torch.from_numpy(
                    np.array(pose_3d[key]).astype(np.float32).reshape(-1, 4))
            th_pose_3d.append(pose_3d)

        prior_weight = get_prior_weight(th_no_right_hand_visible,
                                        th_no_left_hand_visible).cuda()

        # Optimize pose first
        optimize_pose_only(th_scan_meshes,
                           smpl,
                           pose_iterations,
                           pose_steps_per_iter,
                           th_pose_3d,
                           prior_weight,
                           display=None if display is None else 0)

    # Optimize pose and shape
    optimize_pose_shape(th_scan_meshes,
                        smpl,
                        iterations,
                        steps_per_iter,
                        th_pose_3d,
                        display=None if display is None else 0)

    pose_np = smpl.pose.cpu().detach().numpy()
    betas_np = smpl.betas.cpu().detach().numpy()
    trans_np = smpl.trans.cpu().detach().numpy()

    if save_path is not None:
        os.makedirs(save_path, exist_ok=True)

        verts, _, _, _ = smpl()
        th_smpl_meshes = [
            tm.from_tensors(vertices=v, faces=smpl.faces) for v in verts
        ]
        names = [split(s)[1] for s in scans]

        # Save meshes
        save_meshes(
            th_smpl_meshes,
            [join(save_path, n.replace('.obj', '_smpl.obj')) for n in names])
        save_meshes(th_scan_meshes, [join(save_path, n) for n in names])

        # Save params; the context manager closes each pickle file.
        for p, b, t, n in zip(pose_np, betas_np, trans_np, names):
            smpl_dict = {'pose': p, 'betas': b, 'trans': t}
            with open(join(save_path, n.replace('.obj', '_smpl.pkl')), 'wb') as f:
                pkl.dump(smpl_dict, f)

    return pose_np, betas_np, trans_np
Ejemplo n.º 6
0
def fit_SMPL(scans, scan_labels, gender='male', save_path=None, scale_file=None, display=None):
    """
    Register a batch of scans with SMPL, using per-vertex part labels as extra correspondences.

    The model starts from zero shape/translation and the prior's mean pose. The pose
    is optimised first, then pose and shape are optimised jointly.

    :param scans: list of scan paths (read with ``Mesh(filename=...)``).
    :param scan_labels: list of files loadable with ``np.load``, one per scan, each
        holding one integer part label per scan vertex.
    :param gender: gender used for the model faces and the pose prior.
    :param save_path: if given, the fitted meshes, the scans and one parameter pickle
        per scan are written to this directory (scan names are expected to end in '.ply').
    :param scale_file: optional list of files loadable with ``np.load``, one per scan.
        Entry 1 is added to the scan vertices, then the result is multiplied by entry 0.
    :param display: if not None, display index 0 is passed to the optimisers.
    :return: tuple of numpy arrays ``(pose, betas, trans)`` for the fitted batch.
        (Previously this was only returned when ``save_path`` was given.)
    """
    # Get SMPL faces
    sp = SmplPaths(gender=gender)
    smpl_faces = sp.get_faces()
    th_faces = torch.tensor(smpl_faces.astype('float32'), dtype=torch.long).to(DEVICE)

    # SMPL part labels: vertices listed under the n-th key get label n.
    with open('/BS/bharat-3/work/IPNet/assets/smpl_parts_dense.pkl', 'rb') as f:
        part_labels = pkl.load(f)
    labels = np.zeros((6890,), dtype='int32')
    for n, k in enumerate(part_labels):
        labels[part_labels[k]] = n
    labels = torch.tensor(labels).unsqueeze(0).to(DEVICE)

    # Scan part labels
    scan_part_labels = [torch.tensor(np.load(sc_l).astype('int32')).to(DEVICE)
                        for sc_l in scan_labels]

    batch_sz = len(scans)

    # Optimization hyper parameters
    iterations, pose_iterations, steps_per_iter, pose_steps_per_iter = 3, 2, 30, 30

    prior = get_prior(gender=gender, precomputed=True)
    pose_init = torch.zeros((batch_sz, 72))
    pose_init[:, 3:] = prior.mean
    betas, pose, trans = torch.zeros((batch_sz, 300)), pose_init, torch.zeros((batch_sz, 3))

    # Init SMPL, pose with mean smpl pose, as in ch.registration
    smpl = th_batch_SMPL(batch_sz, betas, pose, trans, faces=th_faces).to(DEVICE)
    smpl_part_labels = torch.cat([labels] * batch_sz, axis=0)

    th_scan_meshes = []
    for scan in scans:
        print('scan path ...', scan)
        temp = Mesh(filename=scan)
        th_scan = tm.from_tensors(torch.tensor(temp.v.astype('float32'), requires_grad=False, device=DEVICE),
                                  torch.tensor(temp.f.astype('int32'), requires_grad=False, device=DEVICE).long())
        th_scan_meshes.append(th_scan)

    if scale_file is not None:
        for n, sc in enumerate(scale_file):
            dat = np.load(sc, allow_pickle=True)
            th_scan_meshes[n].vertices += torch.tensor(dat[1]).to(DEVICE)
            th_scan_meshes[n].vertices *= torch.tensor(dat[0]).to(DEVICE)

    # Optimize pose first
    optimize_pose_only(th_scan_meshes, smpl, pose_iterations, pose_steps_per_iter, scan_part_labels, smpl_part_labels,
                       display=None if display is None else 0)

    # Optimize pose and shape
    optimize_pose_shape(th_scan_meshes, smpl, iterations, steps_per_iter, scan_part_labels, smpl_part_labels,
                        display=None if display is None else 0)

    pose_np = smpl.pose.cpu().detach().numpy()
    betas_np = smpl.betas.cpu().detach().numpy()
    trans_np = smpl.trans.cpu().detach().numpy()

    if save_path is not None:
        os.makedirs(save_path, exist_ok=True)

        verts, _, _, _ = smpl()
        th_smpl_meshes = [tm.from_tensors(vertices=v, faces=smpl.faces) for v in verts]
        names = [split(s)[1] for s in scans]

        # Save meshes
        save_meshes(th_smpl_meshes, [join(save_path, n.replace('.ply', '_smpl.obj')) for n in names])
        save_meshes(th_scan_meshes, [join(save_path, n) for n in names])

        # Save params; the context manager closes each pickle file.
        for p, b, t, n in zip(pose_np, betas_np, trans_np, names):
            smpl_dict = {'pose': p, 'betas': b, 'trans': t}
            with open(join(save_path, n.replace('.ply', '_smpl.pkl')), 'wb') as f:
                pkl.dump(smpl_dict, f)

    return pose_np, betas_np, trans_np
Ejemplo n.º 7
0
def _sample_location(dataset, idx):
    """
    Return ``(folder_name, filebase)`` for dataset sample ``idx``.

    ``folder_name`` is ``subset/subject/sequence`` from the sample's model dict.
    ``filebase`` is the basename of its ``data_path`` with the last 4 characters
    removed (presumably a 4-character extension such as '.npz' — confirm).
    """
    model_dict = dataset.get_model_dict(idx)
    folder_name = os.path.join(model_dict['subset'], model_dict['subject'], model_dict['sequence'])
    filebase = os.path.basename(model_dict['data_path'])[:-4]
    return folder_name, filebase


def _trimesh_to_th(mesh, device, divisor=None):
    """
    Convert a loaded trimesh mesh into a ``tm`` mesh on ``device``.

    Vertices are cast to float32 (and divided by ``divisor`` when given). Faces are
    cast to int64. Neither tensor requires gradients.
    """
    vertices = mesh.vertices.astype('float32')
    if divisor is not None:
        vertices = vertices / divisor
    return tm.from_tensors(torch.tensor(vertices, requires_grad=False, device=device),
                           torch.tensor(mesh.faces.astype('int64'), requires_grad=False, device=device))


def _load_registered_vertices(path, device):
    """Load the vertices of a previously exported registration as a float32 tensor on ``device``."""
    registered_mesh = trimesh.load(path, process=False)
    return torch.tensor(registered_mesh.vertices.astype(np.float32), requires_grad=False, device=device)


def _export_registration(vertices, i, path):
    """
    Export ``vertices[i]`` with the SMPL faces to ``path`` unless that file already exists.

    ``vertices`` is only indexed when the file is missing, so it may be ``None`` when
    every output already exists.
    """
    if os.path.exists(path):
        return
    registered_mesh = trimesh.Trimesh(vertices[i].detach().cpu().numpy().astype(np.float64), smpl_faces, process=False)
    registered_mesh.export(path)


def SMPLD_register(args):
    """
    Fit SMPL (inner/minimal body) and SMPL+D (outer/clothed) to generated meshes.

    For each test batch, this loads per-vertex part labels and the posed minimal
    meshes from the generation directory. It then fits SMPL pose and shape to the
    minimal meshes and optimizes offsets against either the raw scans
    (``args.use_raw_scan``) or the generated clothed meshes. Registered meshes are
    written under ``<generation_dir>/registrations`` unless they already exist.
    Without raw scans, the mean per-vertex distance to the ground-truth SMPL and
    SMPL+D vertices is logged per sample and over the whole dataset.

    Loading a batch stops at the first sample whose minimal registration already
    exists. This assumes the batch size is unchanged when a job is resumed.

    :param args: parsed CLI options. Read here: ``config``, ``no_cuda``,
        ``subject_idx``, ``sequence_idx``, ``use_raw_scan``, ``init_pose`` and
        ``num_joints`` (14 or 24). ``args`` is also forwarded to the optimisers.
    :raises ValueError: if ``args.num_joints`` is not 14 or 24, or if the
        NASA+PTFs path (unposed meshes present) is hit with 14 joints.
    """
    cfg = config.load_config(args.config, 'configs/default.yaml')
    out_dir = cfg['training']['out_dir']
    generation_dir = os.path.join(out_dir, cfg['generation']['generation_dir'])
    is_cuda = (torch.cuda.is_available() and not args.no_cuda)
    device = torch.device("cuda" if is_cuda else "cpu")

    # Logger and dataset, restricted to one subject/sequence when both indices are given
    if args.subject_idx >= 0 and args.sequence_idx >= 0:
        logger, _ = create_logger(generation_dir, phase='reg_subject{}_sequence{}'.format(args.subject_idx, args.sequence_idx), create_tf_logs=False)
        dataset = config.get_dataset('test', cfg, sequence_idx=args.sequence_idx, subject_idx=args.subject_idx)
    else:
        logger, _ = create_logger(generation_dir, phase='reg_all', create_tf_logs=False)
        dataset = config.get_dataset('test', cfg)

    batch_size = cfg['generation']['batch_size']

    # Loader
    test_loader = torch.utils.data.DataLoader(
        dataset, batch_size=batch_size, num_workers=1, shuffle=False)

    # Set optimization hyper parameters
    iterations, pose_iterations, steps_per_iter, pose_steps_per_iter = 3, 2, 30, 30

    # Directories to load corresponding information from / write registrations to
    mesh_dir = os.path.join(generation_dir, 'meshes')   # posed and (optionally) unposed implicit outer/inner meshes
    label_dir = os.path.join(generation_dir, 'labels')   # part labels
    register_dir = os.path.join(generation_dir, 'registrations')   # registration outputs
    scan_dir = dataset.dataset_folder if args.use_raw_scan else None   # folder with CAPE raw scans

    # SMPL per-vertex part labels (always expressed in the 14 IPNet parts). They do
    # not depend on the batch, so they are computed once.
    parents = None
    smpl2ipnet = None
    if args.num_joints == 24:
        bm = BodyModel(bm_path='body_models/smpl/male/model.pkl', num_betas=10, batch_size=batch_size).to(device)
        parents = bm.kintree_table[0].detach().cpu().numpy()
        smpl_vertex_labels = bm.weights.argmax(1)
        # Convert 24 parts to 14 parts
        smpl2ipnet = torch.from_numpy(SMPL2IPNET_IDX).to(device)
        smpl_vertex_labels = smpl2ipnet[smpl_vertex_labels].clone().unsqueeze(0)
        del bm
    elif args.num_joints == 14:
        with open('body_models/misc/smpl_parts_dense.pkl', 'rb') as f:
            part_labels = pkl.load(f)

        smpl_vertex_labels = np.zeros((6890,), dtype=np.int64)
        for n, k in enumerate(part_labels):
            smpl_vertex_labels[part_labels[k]] = n
        smpl_vertex_labels = torch.tensor(smpl_vertex_labels).to(device).unsqueeze(0)
    else:
        raise ValueError('Got {} joints but number of joints can only be either 14 or 24'.format(args.num_joints))

    th_faces = torch.tensor(smpl_faces.astype('float32'), dtype=torch.long).to(device)

    inner_dists = []
    outer_dists = []

    for data in tqdm(test_loader):
        idxs = data['idx'].cpu().numpy()
        batch_size = idxs.shape[0]

        all_posed_minimal_meshes = []
        all_posed_cloth_meshes = []
        all_posed_vertices = []
        all_unposed_vertices = []
        scan_part_labels = []

        for idx in idxs:
            model_dict = dataset.get_model_dict(idx)

            subset = model_dict['subset']
            subject = model_dict['subject']
            sequence = model_dict['sequence']
            gender = model_dict['gender']
            filebase = os.path.basename(model_dict['data_path'])[:-4]

            folder_name = os.path.join(subset, subject, sequence)
            # TODO: we assume batch size stays the same if one resumes the job
            # can be more flexible to support different batch sizes before and
            # after resume
            register_file = os.path.join(register_dir, folder_name, filebase + 'minimal.registered.ply')
            if os.path.exists(register_file):
                # batch already computed, break
                break

            mesh_dir_ = os.path.join(mesh_dir, folder_name)
            label_dir_ = os.path.join(label_dir, folder_name)

            # Load part labels (14 or 24 parts, one per vertex); the context manager closes the npz file
            with np.load(os.path.join(label_dir_, filebase + '.minimal.npz')) as label_npz:
                part_labels_np = label_npz['part_labels']
            scan_part_labels.append(torch.tensor(part_labels_np.astype(np.int64)).to(device))

            # Load minimal implicit surfaces
            posed_mesh = trimesh.load(os.path.join(mesh_dir_, filebase + '.minimal.posed.ply'), process=False)
            all_posed_vertices.append(np.array(posed_mesh.vertices))
            all_posed_minimal_meshes.append(_trimesh_to_th(posed_mesh, device))

            unposed_path = os.path.join(mesh_dir_, filebase + '.minimal.unposed.ply')
            if os.path.exists(unposed_path) and args.init_pose:
                unposed_mesh = trimesh.load(unposed_path, process=False)
                all_unposed_vertices.append(np.array(unposed_mesh.vertices))

            if args.use_raw_scan:
                # Load raw scans; vertices are divided by 1000 (presumably mm -> m, confirm)
                scan_path = os.path.join(scan_dir, subject, sequence, filebase + '.ply')
                cloth_mesh = trimesh.load(scan_path, process=False)
                all_posed_cloth_meshes.append(_trimesh_to_th(cloth_mesh, device, divisor=1000))
            else:
                # Load clothed implicit surfaces
                cloth_mesh = trimesh.load(os.path.join(mesh_dir_, filebase + '.cloth.posed.ply'), process=False)
                all_posed_cloth_meshes.append(_trimesh_to_th(cloth_mesh, device))

        # We assume loaded meshes are properly scaled and offset to the original SMPL space
        if len(all_posed_minimal_meshes) > 0:
            if len(all_unposed_vertices) == 0:
                # IPNet optimization without vertex translation: start from the prior's mean pose
                if args.num_joints == 24:
                    scan_part_labels = [smpl2ipnet[sc_l].clone() for sc_l in scan_part_labels]

                # NOTE: gender is taken from the last loaded sample of the batch
                prior = get_prior(gender=gender, precomputed=True)
                pose_init = torch.zeros((batch_size, 72))
                pose_init[:, 3:] = prior.mean
            else:
                # NASA+PTFs optimization with vertex translations: compute poses from
                # implicit surfaces and correspondences
                # TODO: we could also compute bone-lengths if we train PTFs to predict A-pose with a global translation
                # that equals to the centroid of the pointcloud
                if args.num_joints != 24:
                    # parents / smpl2ipnet only exist for the 24-joint model
                    raise ValueError('Pose initialization from unposed meshes requires 24 joints, got {}'.format(args.num_joints))
                poses = compute_poses(all_posed_vertices, all_unposed_vertices, scan_part_labels, parents, args)
                # Convert 24 parts to 14 parts
                scan_part_labels = [smpl2ipnet[sc_l].clone() for sc_l in scan_part_labels]
                pose_init = torch.from_numpy(poses).float()

            betas, pose, trans = torch.zeros((batch_size, 10)), pose_init, torch.zeros((batch_size, 3))

            # Init SMPL with the chosen pose
            smpl = th_batch_SMPL(batch_size, betas, pose, trans, faces=th_faces, gender=gender).to(device)
            smpl_part_labels = torch.cat([smpl_vertex_labels] * batch_size, axis=0)

            # Optimize pose first
            optimize_pose_only(all_posed_minimal_meshes, smpl, pose_iterations, pose_steps_per_iter, scan_part_labels,
                               smpl_part_labels, None, args)

            # Optimize pose and shape
            optimize_pose_shape(all_posed_minimal_meshes, smpl, iterations, steps_per_iter, scan_part_labels, smpl_part_labels,
                                None, args)

            inner_vertices, _, _, _ = smpl()

            # Optimize vertices for SMPLD
            init_smpl_meshes = [tm.from_tensors(vertices=v.clone().detach(),
                                                faces=smpl.faces) for v in inner_vertices]
            optimize_offsets(all_posed_cloth_meshes, smpl, init_smpl_meshes, 5, 10, args)

            outer_vertices, _, _, _ = smpl()
        else:
            inner_vertices = outer_vertices = None

        if args.use_raw_scan:
            for i, idx in enumerate(idxs):
                folder_name, filebase = _sample_location(dataset, idx)
                register_dir_ = os.path.join(register_dir, folder_name)
                os.makedirs(register_dir_, exist_ok=True)

                _export_registration(inner_vertices, i, os.path.join(register_dir_, filebase + 'minimal.registered.ply'))
                _export_registration(outer_vertices, i, os.path.join(register_dir_, filebase + 'cloth.registered.ply'))
            continue

        # Evaluate registered mesh
        gt_smpl_mesh = data['points.minimal_smpl_vertices'].to(device)
        gt_smpld_mesh = data['points.smpl_vertices'].to(device)
        if inner_vertices is None:
            # if vertices are None, we assume they already exist due to previous runs
            inner_list, outer_list = [], []
            for idx in idxs:
                folder_name, filebase = _sample_location(dataset, idx)
                register_dir_ = os.path.join(register_dir, folder_name)
                inner_list.append(_load_registered_vertices(
                    os.path.join(register_dir_, filebase + 'minimal.registered.ply'), device))
                outer_list.append(_load_registered_vertices(
                    os.path.join(register_dir_, filebase + 'cloth.registered.ply'), device))

            inner_vertices = torch.stack(inner_list, dim=0)
            outer_vertices = torch.stack(outer_list, dim=0)

        # Mean per-vertex Euclidean distance per sample
        inner_dist = torch.norm(gt_smpl_mesh - inner_vertices, dim=2).mean(-1)
        outer_dist = torch.norm(gt_smpld_mesh - outer_vertices, dim=2).mean(-1)

        for i, idx in enumerate(idxs):
            folder_name, filebase = _sample_location(dataset, idx)
            register_dir_ = os.path.join(register_dir, folder_name)
            os.makedirs(register_dir_, exist_ok=True)

            logger.info('Inner distance for input {}: {} cm'.format(filebase, inner_dist[i].item()))
            logger.info('Outer distance for input {}: {} cm'.format(filebase, outer_dist[i].item()))

            _export_registration(inner_vertices, i, os.path.join(register_dir_, filebase + 'minimal.registered.ply'))
            _export_registration(outer_vertices, i, os.path.join(register_dir_, filebase + 'cloth.registered.ply'))

        inner_dists.extend(inner_dist.detach().cpu().numpy())
        outer_dists.extend(outer_dist.detach().cpu().numpy())

    # Nothing is evaluated with raw scans (or an empty loader); avoid logging a NaN mean.
    if inner_dists:
        logger.info('Mean inner distance: {} cm'.format(np.mean(inner_dists)))
        logger.info('Mean outer distance: {} cm'.format(np.mean(outer_dists)))
Ejemplo n.º 8
0
    def __init__(self,
                 model,
                 device,
                 train_loader,
                 val_loader,
                 exp_name,
                 opt_dict=None,
                 optimizer='Adam',
                 checkpoint_number=-1,
                 train_supervised=False):
        """
        Set up the model, optimizer, data loaders, SMPL assets and experiment folders.

        :param model: correspondence prediction network; moved to ``device``.
        :param device: cuda or cpu
        :param train_loader: training data loader.
        :param val_loader: validation data loader.
        :param exp_name: experiment name; outputs go to ``../experiments/<exp_name>``
            relative to this file (checkpoints in ``checkpoints/``, TensorBoard
            summaries in ``summary``).
        :param opt_dict: dict containing optimization specific parameters, passed to
            ``parse_opt_dict``; ``None`` (the default) is treated as an empty dict.
        :param optimizer: optimizer name passed to ``init_optimizer`` (default 'Adam').
        :param checkpoint_number: load a specific checkpoint, -1 => load latest
        :param train_supervised: stored on the instance as ``train_supervised``.
        """
        self.model = model.to(device)
        self.device = device
        # ``None`` default instead of a shared mutable ``{}`` across all instances.
        self.opt_dict = self.parse_opt_dict({} if opt_dict is None else opt_dict)
        self.optimizer_type = optimizer
        self.optimizer = self.init_optimizer(optimizer,
                                             self.model.parameters(),
                                             learning_rate=0.001)

        self.train_data_loader = train_loader
        self.val_data_loader = val_loader
        self.train_supervised = train_supervised
        self.checkpoint_number = checkpoint_number

        # Load vsmpl
        self.vsmpl = VolumetricSMPL(
            '/BS/bharat-2/work/LearntRegistration/test_data/volumetric_smpl_function_64',
            device, 'male')
        sp = SmplPaths(gender='male')
        self.ref_smpl = sp.get_smpl()
        # NUM_POINTS points sampled from the reference (male) SMPL surface, with a leading batch dim of 1
        self.template_points = torch.tensor(trimesh.Trimesh(
            vertices=self.ref_smpl.r,
            faces=self.ref_smpl.f).sample(NUM_POINTS).astype('float32'),
                                            requires_grad=False).unsqueeze(0)

        self.pose_prior = get_prior('male', precomputed=True)

        # Load smpl part labels: vertices listed under the n-th key get part index n
        with open(
                '/BS/bharat-2/work/LearntRegistration/test_data/smpl_parts_dense.pkl',
                'rb') as f:
            dat = pkl.load(f, encoding='latin-1')
        self.smpl_parts = np.zeros((6890, 1))
        for n, k in enumerate(dat):
            self.smpl_parts[dat[k]] = n

        self.exp_path = join(os.path.dirname(__file__),
                             '../experiments/{}'.format(exp_name))
        self.checkpoint_path = join(self.exp_path, 'checkpoints/')
        if not os.path.exists(self.checkpoint_path):
            print(self.checkpoint_path)
            os.makedirs(self.checkpoint_path)
        self.writer = SummaryWriter(join(self.exp_path, 'summary'))
        self.val_min = None