Code Example #1
# Imports needed to run this snippet; the rlkit/skew-fit module paths below are
# assumptions based on that codebase's usual layout.
import time

import cv2
import numpy as np

from rlkit.util.io import load_local_or_remote_file
from rlkit.launchers.skewfit_experiments import (
    get_envs,
    skewfit_preprocess_variant,
)


def getdata(variant):
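    """Generate paired samples (images, VAE latents, reconstructions, instruction).

    Loads the skew-fit VAE-wrapped environment; for each sample it sets the env to a
    sampled goal (img_1), generates a new state from that goal (img_2) plus an
    instruction string, records the VAE encodings and reconstructions, and writes one
    .npy file per sample. Finally the VAE's encode/decode functions are serialized.
    """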
    skewfit_variant = variant['skewfit_variant']
    print('-------------------------------')
    skewfit_preprocess_variant(skewfit_variant)
    skewfit_variant['render'] = True
    vae_environment = get_envs(skewfit_variant)
    print('done loading vae_env')

    env_class = variant.get('env_class', None)
    env_kwargs = variant.get('env_kwargs', None)
    env_id = variant.get('env_id', None)
    N = variant.get('N', 10000)
    test_p = variant.get('test_p', 0.9)
    use_cached = variant.get('use_cached', True)
    imsize = variant.get('imsize', 84)
    num_channels = variant.get('num_channels', 3)
    show = variant.get('show', False)
    init_camera = variant.get('init_camera', None)
    dataset_path = variant.get('dataset_path', None)
    oracle_dataset_using_set_to_goal = variant.get(
        'oracle_dataset_using_set_to_goal', False)
    random_rollout_data = variant.get('random_rollout_data', False)
    random_and_oracle_policy_data = variant.get(
        'random_and_oracle_policy_data', False)
    random_and_oracle_policy_data_split = variant.get(
        'random_and_oracle_policy_data_split', 0)
    policy_file = variant.get('policy_file', None)
    n_random_steps = variant.get('n_random_steps', 100)
    vae_dataset_specific_env_kwargs = variant.get(
        'vae_dataset_specific_env_kwargs', None)
    save_file_prefix = variant.get('save_file_prefix', None)
    non_presampled_goal_img_is_garbage = variant.get(
        'non_presampled_goal_img_is_garbage', None)
    tag = variant.get('tag', '')
    from multiworld.core.image_env import ImageEnv, unormalize_image
    import rlkit.torch.pytorch_util as ptu
    info = {}
    if dataset_path is not None:
        dataset = load_local_or_remote_file(dataset_path)
        N = dataset.shape[0]
    else:
        if env_kwargs is None:
            env_kwargs = {}
        if save_file_prefix is None:
            save_file_prefix = env_id
        if save_file_prefix is None:  # still None when no env_id was given
            save_file_prefix = env_class.__name__
        filename = "/tmp/{}_N{}_{}_imsize{}_random_oracle_split_{}{}.npy".format(
            save_file_prefix,
            str(N),
            init_camera.__name__ if init_camera else '',
            imsize,
            random_and_oracle_policy_data_split,
            tag,
        )
        if True:  # cache check removed; this version always regenerates the data
            now = time.time()

            if env_id is not None:
                import gym
                import multiworld
                multiworld.register_all_envs()
                env = gym.make(env_id)
            else:
                if vae_dataset_specific_env_kwargs is None:
                    vae_dataset_specific_env_kwargs = {}
                for key, val in env_kwargs.items():
                    if key not in vae_dataset_specific_env_kwargs:
                        vae_dataset_specific_env_kwargs[key] = val
                env = env_class(**vae_dataset_specific_env_kwargs)
            if not isinstance(env, ImageEnv):
                print("using(ImageEnv)")
                env = ImageEnv(
                    env,
                    imsize,
                    init_camera=init_camera,
                    transpose=True,
                    normalize=True,
                    non_presampled_goal_img_is_garbage=
                    non_presampled_goal_img_is_garbage,
                )
            else:
                imsize = env.imsize
                env.non_presampled_goal_img_is_garbage = non_presampled_goal_img_is_garbage
            env.reset()
            info['env'] = env
            if random_and_oracle_policy_data:
                policy_file = load_local_or_remote_file(policy_file)
                policy = policy_file['policy']
                policy.to(ptu.device)
            if random_rollout_data:
                from rlkit.exploration_strategies.ou_strategy import OUStrategy
                policy = OUStrategy(env.action_space)
            dataset = np.zeros((N, imsize * imsize * num_channels),
                               dtype=np.uint8)

            for i in range(10):  # only 10 sample pairs are generated in this example
                NP = []
                if True:  # condition removed; always take the oracle set-to-goal path
                    print(i)
                    #print('th step')
                    goal = env.sample_goal()
                    # print("goal___________________________")
                    # print(goal)
                    # print("goal___________________________")
                    env.set_to_goal(goal)
                    obs = env._get_obs()
                    img_1 = obs['image_observation']
                    img_1 = img_1.reshape(3, imsize, imsize).transpose()
                    NP.append(img_1)
                    if i % 3 == 0:
                        cv2.imshow('img1', img_1)
                        cv2.waitKey(1)
                    #img_1_reconstruct = vae_environment._reconstruct_img(obs['image_observation']).transpose()
                    encoded_1 = vae_environment._get_encoded(
                        obs['image_observation'])
                    print(encoded_1)
                    NP.append(encoded_1)
                    img_1_reconstruct = vae_environment._get_img(
                        encoded_1).transpose()
                    NP.append(img_1_reconstruct)
                    #dataset[i, :] = unormalize_image(img)
                    # img_1 = img_1.reshape(3, imsize, imsize).transpose()
                    if i % 3 == 0:
                        cv2.imshow('img1_reconstruction', img_1_reconstruct)
                        cv2.waitKey(1)
                    env.reset()
                    instr = env.generate_new_state(goal)
                    if i % 3 == 0:
                        print(instr)
                    obs = env._get_obs()
                    img_2 = obs['image_observation']
                    img_2 = img_2.reshape(3, imsize, imsize).transpose()
                    NP.append(img_2)
                    if i % 3 == 0:
                        cv2.imshow('img2', img_2)
                        cv2.waitKey(1)
                    #img_2_reconstruct = vae_environment._reconstruct_img(obs['image_observation']).transpose()
                    encoded_2 = vae_environment._get_encoded(
                        obs['image_observation'])
                    NP.append(encoded_2)
                    img_2_reconstruct = vae_environment._get_img(
                        encoded_2).transpose()
                    NP.append(img_2_reconstruct)
                    NP.append(instr)
                    # img_2 = img_2.reshape(3, imsize, imsize).transpose()
                    if i % 3 == 0:
                        cv2.imshow('img2_reconstruct', img_2_reconstruct)
                        cv2.waitKey(1)
                    NP = np.array(NP, dtype=object)  # mixed shapes and types, so store as an object array
                    idx = str(i)
                    name = "/home/xiaomin/Downloads/IFIG_DATA_1/" + idx + ".npy"
                    np.save(name, NP)  # np.save opens the file itself
                    # radius = input('waiting...')

            # Serialize the VAE's encode/decode functions with dill so they can be
            # reloaded and used outside this process.
            import dill
            import pickle
            get_encoded = dill.dumps(vae_environment._get_encoded)
            with open(
                    "/home/xiaomin/Downloads/IFIG_encoder_decoder/get_encoded_1000_epochs_one_puck.txt",
                    "wb") as fp:
                pickle.dump(get_encoded, fp)
            with open(
                    "/home/xiaomin/Downloads/IFIG_encoder_decoder/get_encoded_1000_epochs_one_puck.txt",
                    "rb") as fp:
                b = pickle.load(fp)
            func_get_encoded = dill.loads(b)
            encoded = func_get_encoded(obs['image_observation'])
            print(encoded)
            print('------------------------------')
            get_img = dill.dumps(vae_environment._get_img)
            with open(
                    "/home/xiaomin/Downloads/IFIG_encoder_decoder/get_img_1000_epochs_one_puck.txt",
                    "wb") as fp:
                pickle.dump(get_img, fp)
            with open(
                    "/home/xiaomin/Downloads/IFIG_encoder_decoder/get_img_1000_epochs_one_puck.txt",
                    "rb") as fp:
                c = pickle.load(fp)
            func_get_img = dill.loads(c)

            img_1_reconstruct = func_get_img(encoded).transpose()
            print(img_1_reconstruct)
            #dataset[i, :] = unormalize_image(img)
            # img_1 = img_1.reshape(3, imsize, imsize).transpose()
            cv2.imshow('test', img_1_reconstruct)
            cv2.waitKey(0)

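            # NOTE: `dataset` is still the zero-filled buffer allocated above; the
            # generated samples live only in the per-sample .npy files written in the loop.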
            print("done making training data", filename, time.time() - now)
            np.save(filename, dataset)
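
For reference only (not part of the original code): a minimal sketch of reading back
one of the per-sample .npy files written above. The element order follows the NP.append
calls in the loop, and allow_pickle is required because the array holds objects of
mixed shape.

import numpy as np

sample = np.load("/home/xiaomin/Downloads/IFIG_DATA_1/0.npy", allow_pickle=True)
img_1, encoded_1, img_1_rec, img_2, encoded_2, img_2_rec, instr = sample
print(img_1.shape, img_2.shape, instr)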
Code Example #2
# Imports needed to run this snippet; the rlkit module path below is an assumption.
import os.path as osp
import time

import cv2
import numpy as np

from rlkit.util.io import load_local_or_remote_file


def generate_vae_dataset(variant):
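    """Build (or load from cache) an image dataset for VAE training.

    Each loop iteration samples a goal image, generates a follow-up state and an
    instruction, and writes them to one .npy file per sample. Returns
    (train_dataset, test_dataset, info).
    """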
    env_class = variant.get('env_class', None)
    env_kwargs = variant.get('env_kwargs', None)
    env_id = variant.get('env_id', None)
    N = variant.get('N', 10000)
    test_p = variant.get('test_p', 0.9)
    use_cached = variant.get('use_cached', True)
    imsize = variant.get('imsize', 84)
    num_channels = variant.get('num_channels', 3)
    show = variant.get('show', False)
    init_camera = variant.get('init_camera', None)
    dataset_path = variant.get('dataset_path', None)
    oracle_dataset_using_set_to_goal = variant.get(
        'oracle_dataset_using_set_to_goal', False)
    random_rollout_data = variant.get('random_rollout_data', False)
    random_and_oracle_policy_data = variant.get(
        'random_and_oracle_policy_data', False)
    random_and_oracle_policy_data_split = variant.get(
        'random_and_oracle_policy_data_split', 0)
    policy_file = variant.get('policy_file', None)
    n_random_steps = variant.get('n_random_steps', 100)
    vae_dataset_specific_env_kwargs = variant.get(
        'vae_dataset_specific_env_kwargs', None)
    save_file_prefix = variant.get('save_file_prefix', None)
    non_presampled_goal_img_is_garbage = variant.get(
        'non_presampled_goal_img_is_garbage', None)
    tag = variant.get('tag', '')
    from multiworld.core.image_env import ImageEnv, unormalize_image
    import rlkit.torch.pytorch_util as ptu
    info = {}
    if dataset_path is not None:
        dataset = load_local_or_remote_file(dataset_path)
        N = dataset.shape[0]
    else:
        if env_kwargs is None:
            env_kwargs = {}
        if save_file_prefix is None:
            save_file_prefix = env_id
        if save_file_prefix is None:  # still None when no env_id was given
            save_file_prefix = env_class.__name__
        filename = "/tmp/{}_N{}_{}_imsize{}_random_oracle_split_{}{}.npy".format(
            save_file_prefix,
            str(N),
            init_camera.__name__ if init_camera else '',
            imsize,
            random_and_oracle_policy_data_split,
            tag,
        )
        if use_cached and osp.isfile(filename):
            dataset = np.load(filename)
            print("loaded data from saved file", filename)
        else:
            now = time.time()

            if env_id is not None:
                import gym
                import multiworld
                multiworld.register_all_envs()
                env = gym.make(env_id)
            else:
                if vae_dataset_specific_env_kwargs is None:
                    vae_dataset_specific_env_kwargs = {}
                for key, val in env_kwargs.items():
                    if key not in vae_dataset_specific_env_kwargs:
                        vae_dataset_specific_env_kwargs[key] = val
                env = env_class(**vae_dataset_specific_env_kwargs)
            if not isinstance(env, ImageEnv):
                env = ImageEnv(
                    env,
                    imsize,
                    init_camera=init_camera,
                    transpose=True,
                    normalize=True,
                    non_presampled_goal_img_is_garbage=
                    non_presampled_goal_img_is_garbage,
                )
            else:
                imsize = env.imsize
                env.non_presampled_goal_img_is_garbage = non_presampled_goal_img_is_garbage
            env.reset()
            info['env'] = env
            if random_and_oracle_policy_data:
                policy_file = load_local_or_remote_file(policy_file)
                policy = policy_file['policy']
                policy.to(ptu.device)
            if random_rollout_data:
                from rlkit.exploration_strategies.ou_strategy import OUStrategy
                policy = OUStrategy(env.action_space)
            dataset = np.zeros((N, imsize * imsize * num_channels),
                               dtype=np.uint8)
            for i in range(10000):  # note: the sample count is hard-coded rather than taken from N
                NP = []
                if oracle_dataset_using_set_to_goal:
                    print(i)
                    #print('th step')
                    goal = env.sample_goal()
                    env.set_to_goal(goal)
                    obs = env._get_obs()
                    img_1 = obs['image_observation']
                    NP.append(img_1)
                    #dataset[i, :] = unormalize_image(img)
                    img_1 = img_1.reshape(3, imsize, imsize).transpose()
                    if i % 3 == 0:
                        cv2.imshow('img1', img_1)
                        cv2.waitKey(1)
                    env.reset()
                    instr = env.generate_new_state(goal)
                    if i % 3 == 0:
                        print(instr)
                    obs = env._get_obs()
                    img_2 = obs['image_observation']
                    NP.append(img_2)
                    NP.append(instr)
                    img_2 = img_2.reshape(3, imsize, imsize).transpose()
                    if i % 3 == 0:
                        cv2.imshow('img2', img_2)
                        cv2.waitKey(1)
                    NP = np.array(NP, dtype=object)  # mixed shapes and types, so store as an object array
                    print(NP)
                    idx = str(i)
                    name = "/home/xiaomin/Downloads/IFIG_DATA_1/" + idx + ".npy"
                    np.save(name, NP)  # np.save opens the file itself
                    # radius = input('waiting...')
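            # NOTE: as in code example #1, `dataset` is never filled inside the loop;
            # the file saved below still holds the zero-filled buffer, while the real
            # samples live in the per-iteration .npy files.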
            print("done making training data", filename, time.time() - now)
            np.save(filename, dataset)

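    # note: despite its name, test_p is used as the *training* fraction
    # (the default 0.9 gives a 90% train / 10% test split)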
    n = int(N * test_p)
    train_dataset = dataset[:n, :]
    test_dataset = dataset[n:, :]
    return train_dataset, test_dataset, info
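
As a usage sketch (not part of the original code): generate_vae_dataset is driven
entirely by the variant dictionary read with variant.get(...) at the top of the
function. The env id below is only an illustrative multiworld id, and the environment
must also expose the custom generate_new_state method called inside the loop.

variant = dict(
    env_id='SawyerPushNIPSEasy-v0',   # illustrative; any multiworld-registered id works
    N=1000,
    test_p=0.9,          # fraction kept for the training split (see note above)
    use_cached=False,
    imsize=84,
    num_channels=3,
    oracle_dataset_using_set_to_goal=True,
    tag='_demo',
)
train_data, test_data, info = generate_vae_dataset(variant)
print(train_data.shape, test_data.shape)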