示例#1
0
def grab_new_objs(grabnet, objs_path, rot=True, n_samples=10, scale=1.,
                  out_path='data/grabnet_data/meshes.pt'):
    """Generate GrabNet grasps for one or more object meshes and save them.

    For each object path, ``n_samples`` copies of the object are loaded with
    ``load_obj_verts``, each with its own random Euler rotation. The copies are
    encoded as BPS distances against ``grabnet.bps`` and handed, together with
    the coarse and refine networks, to ``get_meshes``. The result is written
    with ``torch.save``.

    Args:
        grabnet: GrabNet wrapper. Its ``coarse_net`` and ``refine_net`` are put
            in eval mode. ``refine_net.rhm_train`` is replaced by a freshly
            loaded MANO model. ``cfg``, ``device``, ``bps`` and ``logger`` are
            read.
        objs_path: One object path (str / bytes / path-like) or an iterable
            of paths (list, tuple, ...).
        rot: Forwarded to ``load_obj_verts`` as ``rndrotate``.
        n_samples: Number of rotated copies per object. It is also the MANO
            batch size. Must be at least 1.
        scale: Forwarded to ``load_obj_verts``.
        out_path: Destination for ``torch.save``.
            - With a single object, the file is written exactly here.
            - With several objects, each one is written to
              ``<root>_<object stem><ext>``, so a later object never
              overwrites an earlier object's results.

    Raises:
        ValueError: If ``n_samples`` is smaller than 1 (there would be
            nothing to concatenate into a batch).
    """
    if n_samples < 1:
        raise ValueError(f'n_samples must be at least 1, got {n_samples}')

    grabnet.coarse_net.eval()
    grabnet.refine_net.eval()

    rh_model = mano.load(model_path=grabnet.cfg.rhm_path,
                         model_type='mano',
                         num_pca_comps=45,
                         batch_size=n_samples,
                         flat_hand_mean=True).to(grabnet.device)

    grabnet.refine_net.rhm_train = rh_model

    grabnet.logger(f'################# \n'
                   f'Grabbing the object!'
                   )

    bps = bps_torch(custom_basis=grabnet.bps)

    # Wrap a single path. Materialise any other iterable (e.g. a tuple) so it
    # is not mistaken for one path and can be counted below.
    if isinstance(objs_path, (str, bytes, os.PathLike)):
        objs_path = [objs_path]
    else:
        objs_path = list(objs_path)

    # These do not depend on the object being processed, so compute them once.
    save_dir = os.path.join(grabnet.cfg.work_dir, 'grab_new_objects')
    out_root, out_ext = os.path.splitext(out_path)
    single_object = len(objs_path) == 1

    for new_obj in objs_path:
        obj_name = os.path.basename(new_obj)

        # One random rotation per sample: each of the 3 angles in [0, 360).
        rand_rotdeg = np.random.random([n_samples, 3]) * np.array([360, 360, 360])
        rand_rotmat = euler(rand_rotdeg)

        dorig = {'bps_object': [],
                 'verts_object': [],
                 'mesh_object': [],
                 'rotmat': []}

        for sample_idx in range(n_samples):
            verts_obj, mesh_obj, rotmat = load_obj_verts(
                new_obj, rand_rotmat[sample_idx], rndrotate=rot, scale=scale)

            bps_object = bps.encode(verts_obj, feature_type='dists')['dists']

            dorig['bps_object'].append(bps_object.to(grabnet.device))
            dorig['verts_object'].append(
                torch.from_numpy(verts_obj.astype(np.float32)).unsqueeze(0))
            dorig['mesh_object'].append(mesh_obj)
            dorig['rotmat'].append(rotmat)

        # Concatenate the per-sample tensors along dim 0.
        dorig['bps_object'] = torch.cat(dorig['bps_object'])
        dorig['verts_object'] = torch.cat(dorig['verts_object'])

        grabnet.logger(f'#################\n'
                       f'                   \n'
                       f'Saving results for the {obj_name.upper()}'
                       f'                      \n')

        gen_meshes = get_meshes(dorig=dorig,
                                coarse_net=grabnet.coarse_net,
                                refine_net=grabnet.refine_net,
                                rh_model=rh_model,
                                save=False,
                                save_dir=save_dir
                                )

        # A fixed path inside the loop would keep only the last object's
        # results, so each object gets its own file when there are several.
        if single_object:
            target = out_path
        else:
            obj_stem = os.path.splitext(os.fsdecode(obj_name))[0]
            target = f'{out_root}_{obj_stem}{out_ext}'

        # Create the parent directory if missing so the save cannot fail on it.
        target_dir = os.path.dirname(target)
        if target_dir:
            os.makedirs(target_dir, exist_ok=True)
        torch.save(gen_meshes, target)
示例#2
0
def _translate_ho(ho, offset):
    """Shift a hand-object sample by ``offset`` (rebinding its attributes).

    Adds ``offset`` to ``ho.obj_verts``, ``ho.hand_verts`` and the
    translation part (``[:3, 3]``) of ``ho.hand_mTc``. Each attribute is first
    converted with ``np.array``, which copies it. The original arrays are
    therefore not modified; new arrays are assigned back onto ``ho``.
    """
    ho.obj_verts = np.array(ho.obj_verts) + offset
    ho.hand_verts = np.array(ho.hand_verts) + offset
    hand_mTc = np.array(ho.hand_mTc)
    hand_mTc[:3, 3] = hand_mTc[:3, 3] + offset
    ho.hand_mTc = hand_mTc


def grab_new_objs(grabnet, pkl_path, rot=True, n_samples=5, scale=1.,
                  out_file='fitted_grabnet.pkl'):
    """Run GrabNet on the hand-object samples stored in a pickle and save them.

    ``pkl_path`` is unpickled into a list of dicts. Each dict has at least
    ``'ho_aug'`` (the input hand-object, exposing ``obj_verts``,
    ``hand_verts`` and ``hand_mTc``) and ``'ho_gt'``. If ``args.vis`` is set,
    the list is shuffled first. Only the first ``args.num`` entries are
    processed. ``args`` is a module-level global, not a parameter.

    Each entry is processed as follows:
    1. The sample is shifted so the mean of its object vertices is at the
       origin.
    2. ``n_samples`` copies of the object are BPS-encoded and passed to
       ``vis_results``.
    3. The shift is undone.
    4. The predicted hand vertices and joints are translated back into the
       original frame.

    Args:
        grabnet: GrabNet wrapper. Its ``coarse_net`` and ``refine_net`` are put
            in eval mode. ``refine_net.rhm_train`` is replaced by a freshly
            loaded 45-PCA MANO model. ``cfg``, ``device``, ``bps`` and
            ``logger`` are read.
        pkl_path: Input pickle. It must come from a trusted source, because
            unpickling can execute arbitrary code.
        rot: Forwarded to ``load_obj_verts`` as ``rndrotate``. The rotation
            matrices supplied are built by ``euler`` from all-zero angles.
        n_samples: Number of object copies per entry. It is also the batch
            size of both MANO models.
        scale: Forwarded to ``load_obj_verts``.
        out_file: Output pickle path. It receives a list of dicts with keys
            ``'gt_ho'``, ``'in_ho'``, ``'out_verts'`` and ``'out_joints'``.
    """
    grabnet.coarse_net.eval()
    grabnet.refine_net.eval()

    rh_model = mano.load(model_path=grabnet.cfg.rhm_path,
                         model_type='mano',
                         num_pca_comps=45,
                         batch_size=n_samples,
                         flat_hand_mean=True).to(grabnet.device)

    # Second MANO model (15 PCA comps, non-flat mean) passed to vis_results.
    # Presumably it matches the hand parameterisation of the pickled samples
    # (TODO confirm).
    rh_model_pkl = mano.load(model_path=grabnet.cfg.rhm_path,
                             model_type='mano',
                             num_pca_comps=15,
                             batch_size=n_samples,
                             flat_hand_mean=False).to(grabnet.device)

    grabnet.refine_net.rhm_train = rh_model

    grabnet.logger(f'################# \n'
                   f'Colors Guide:'
                   f'                   \n'
                   f'Gray  --->  GrabNet generated grasp\n')

    bps = bps_torch(custom_basis=grabnet.bps)

    # SECURITY: pickle.load can run arbitrary code; only load trusted files.
    with open(pkl_path, 'rb') as f:
        all_samples = pickle.load(f)

    if args.vis:
        print('Shuffling!!!')
        random.shuffle(all_samples)

    all_samples = all_samples[:args.num]
    save_dir = os.path.join(grabnet.cfg.work_dir, 'grab_new_objects')
    all_data = []

    for idx, new_obj in enumerate(tqdm(all_samples)):
        print('idx', idx)
        ho = new_obj['ho_aug']

        # Work in an object-centred frame; undone right after inference.
        # asarray makes .mean work even if obj_verts is not yet an ndarray.
        obj_centroid = np.asarray(ho.obj_verts).mean(0)
        _translate_ho(ho, -obj_centroid)

        # Every angle is zero (the draw is multiplied by 0). The draw is kept
        # so NumPy's global RNG advances exactly as it always has.
        rand_rotdeg = np.random.random([n_samples, 3]) * np.array([0, 0, 0])
        rand_rotmat = euler(rand_rotdeg)

        dorig = {
            'bps_object': [],
            'verts_object': [],
            'mesh_object': [],
            'rotmat': []
        }

        for sample_idx in range(n_samples):
            verts_obj, mesh_obj, rotmat = load_obj_verts(ho,
                                                         rand_rotmat[sample_idx],
                                                         rndrotate=rot,
                                                         scale=scale)

            bps_object = bps.encode(verts_obj, feature_type='dists')['dists']

            dorig['bps_object'].append(bps_object.to(grabnet.device))
            dorig['verts_object'].append(
                torch.from_numpy(verts_obj.astype(np.float32)).unsqueeze(0))
            dorig['mesh_object'].append(mesh_obj)
            dorig['rotmat'].append(rotmat)

        # Concatenate the per-sample tensors along dim 0.
        dorig['bps_object'] = torch.cat(dorig['bps_object'])
        dorig['verts_object'] = torch.cat(dorig['verts_object'])

        verts_out, joints_out = vis_results(ho,
                                            dorig=dorig,
                                            coarse_net=grabnet.coarse_net,
                                            refine_net=grabnet.refine_net,
                                            rh_model=rh_model,
                                            save=False,
                                            save_dir=save_dir,
                                            rh_model_pkl=rh_model_pkl,
                                            vis=args.vis)

        # Move the input sample back to its original frame.
        _translate_ho(ho, obj_centroid)

        # .numpy() only works on CPU tensors; .cpu() is a no-op otherwise.
        verts_out = verts_out.detach().cpu().squeeze().numpy() + obj_centroid
        joints_out = joints_out.detach().cpu().squeeze().numpy() + obj_centroid

        all_data.append({
            'gt_ho': new_obj['ho_gt'],
            'in_ho': new_obj['ho_aug'],
            'out_verts': verts_out,
            'out_joints': joints_out
        })

    print('Saving to {}. Len {}'.format(out_file, len(all_data)))
    # The context manager guarantees the file is flushed and closed.
    with open(out_file, 'wb') as f:
        pickle.dump(all_data, f)