Example #1
    def __init__(
            self,
            trainer,
            replay_buffer,
            demo_train_buffer,
            demo_test_buffer,
            model=None,
            model_path=None,
            input_model=None,
            input_model_path=None,
            reward_fn=None,
            env=None,
            demo_paths=[],  # list of dicts
            normalize=False,
            demo_train_split=0.9,
            demo_data_split=1,
            add_demos_to_replay_buffer=True,
            condition_input_encoding=False,
            bc_num_pretrain_steps=0,
            bc_batch_size=64,
            bc_weight=1.0,
            rl_weight=1.0,
            q_num_pretrain_steps=0,
            weight_decay=0,
            eval_policy=None,
            recompute_reward=False,
            object_list=None,
            env_info_key=None,
            obs_key=None,
            load_terminals=True,
            delete_after_loading=False,
            data_filter_fn=lambda x: True,  # return True to add a path, False to ignore it
            **kwargs):
        super().__init__(trainer, replay_buffer, demo_train_buffer,
                         demo_test_buffer, demo_paths, demo_train_split,
                         demo_data_split, add_demos_to_replay_buffer,
                         bc_num_pretrain_steps, bc_batch_size, bc_weight,
                         rl_weight, q_num_pretrain_steps, weight_decay,
                         eval_policy, recompute_reward, env_info_key, obs_key,
                         load_terminals, delete_after_loading, data_filter_fn,
                         **kwargs)

        if model is None:
            self.model = load_local_or_remote_file(
                model_path, delete_after_loading=delete_after_loading)
        else:
            self.model = model

        if input_model is None:
            self.input_model = load_local_or_remote_file(
                input_model_path, delete_after_loading=delete_after_loading)
        else:
            self.input_model = input_model
        self.condition_input_encoding = condition_input_encoding
        self.reward_fn = reward_fn
        self.normalize = normalize
        self.object_list = object_list
        self.env = env
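
The model / model_path pair above illustrates a pattern that recurs in these examples: accept either a live object or a checkpoint path and fall back to load_local_or_remote_file. A minimal standalone sketch of that fallback (the helper name resolve is hypothetical, not part of rlkit):

from rlkit.util.io import load_local_or_remote_file

def resolve(obj=None, path=None, delete_after_loading=False):
    # Prefer an already-constructed object; otherwise load it from a local or remote file.
    if obj is None:
        return load_local_or_remote_file(path, delete_after_loading=delete_after_loading)
    return obj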
Example #2
    def __init__(
        self,
        wrapped_env,
        clusterer,
        mode='train',
        reward_params=None,
    ):
        self.quick_init(locals())
        super().__init__(wrapped_env)
        if isinstance(clusterer, str):
            self.clusterer = load_local_or_remote_file(clusterer)
        else:
            self.clusterer = clusterer
        self._num_clusters = self.clusterer.num_clusters
        self.task = {}
        self.reward_params = reward_params or {}
        self.reward_type = self.reward_params.get('type', 's_given_z')

        spaces = copy.deepcopy(self.wrapped_env.observation_space.spaces)
        spaces['context'] = Discrete(n=self._num_clusters)
        self.observation_space = Dict(spaces)

        assert self.reward_type in ('wrapped_env', 's_given_z')

        self._set_clusterer_attributes()
Example #3
    def __init__(
        self,
        wrapped_env,
        disc,
        mode='train',
        reward_params=None,
        unsupervised_reward_weight=0.,
        reward_weight=0.,
        noise_scale=0.,
    ):
        self.quick_init(locals())
        super().__init__(wrapped_env)
        if isinstance(disc, str):
            self.disc = load_local_or_remote_file(disc)
        else:
            self.disc = disc
        self._num_skills = self.disc.num_skills
        self._p_z = np.full(self._num_skills, 1.0 / self._num_skills)
        self.task = {'context': -1}
        self.reward_params = reward_params or {}
        self.reward_type = self.reward_params.get('type', 'diayn')
        # TODO: Check that TIAYN reward is set properly.
        spaces = copy.deepcopy(self.wrapped_env.observation_space.spaces)
        spaces['context'] = Discrete(n=self._num_skills)
        self.observation_space = Dict(spaces)
        self.unsupervised_reward_weight = unsupervised_reward_weight
        self.reward_weight = reward_weight
        self.noise_scale = noise_scale

        assert self.reward_type in (
            'wrapped_env', 'diayn', 'wrapped_env + diayn', 'tiayn',
            'wrapped_env + tiayn')
Example #4
def main():
    debug = True
    dry = False
    mode = 'here_no_doodad'
    suffix = ''
    nseeds = 1
    gpu = True

    path_parts = __file__.split('/')
    suffix = '--{}'.format(suffix) if suffix else ''
    exp_name = 'pearl-awac-{}--{}{}'.format(
        path_parts[-2].replace('_', '-'),
        path_parts[-1].split('.')[0].replace('_', '-'),
        suffix,
    )

    if debug or dry:
        exp_name = 'dev--' + exp_name
        mode = 'here_no_doodad'
        nseeds = 1

    if dry:
        mode = 'here_no_doodad'

    print(exp_name)

    task_data = load_local_or_remote_file(
        "examples/smac/ant_tasks.joblib",  # TODO: update to point to correct file
        file_type='joblib')
    tasks = task_data['tasks']
    search_space = {
        'seed': list(range(nseeds)),
    }
    variant = DEFAULT_PEARL_CONFIG.copy()
    variant["env_name"] = "ant-dir"
    # variant["train_task_idxs"] = list(range(100))
    # variant["eval_task_idxs"] = list(range(100, 120))
    variant["env_params"]["fixed_tasks"] = [t['goal'] for t in tasks]
    variant["env_params"]["direction_in_degrees"] = True
    variant["trainer_kwargs"]["train_context_decoder"] = True
    variant["trainer_kwargs"]["backprop_q_loss_into_encoder"] = True
    variant["saved_tasks_path"] = "examples/smac/ant_tasks.joblib"  # TODO: update to point to correct file

    sweeper = hyp.DeterministicHyperparameterSweeper(
        search_space,
        default_parameters=variant,
    )
    for exp_id, variant in enumerate(sweeper.iterate_hyperparameters()):
        variant['exp_id'] = exp_id
        run_experiment(
            pearl_experiment,
            unpack_variant=True,
            exp_prefix=exp_name,
            mode=mode,
            variant=variant,
            time_in_mins=3 * 24 * 60 - 1,
            use_gpu=gpu,
        )
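
From the way it is indexed here (tasks[i]['goal']) and in the PEARL/SMAC examples below (train_task_indices, eval_task_indices), the ant_tasks.joblib file is assumed to be a joblib-pickled dict; a hypothetical sketch of writing a compatible file:

import joblib

# Hypothetical layout matching the keys read above and in the saved_tasks_path
# branches of pearl_experiment / smac_experiment further down.
joblib.dump(
    dict(
        tasks=[dict(goal=0.0), dict(goal=90.0)],  # direction_in_degrees=True
        train_task_indices=[0],
        eval_task_indices=[1],
    ),
    "ant_tasks.joblib",
)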
Example #5
def resume(variant):
    data = load_local_or_remote_file(variant.get("pretrained_algorithm_path"), map_location="cuda")
    algo = data['algorithm']

    algo.num_epochs = variant['num_epochs']

    post_pretrain_hyperparams = variant["trainer_kwargs"].get("post_pretrain_hyperparams", {})
    algo.trainer.set_algorithm_weights(**post_pretrain_hyperparams)

    algo.train()
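
A minimal sketch of the variant dict that resume() above dereferences; the checkpoint path and values are hypothetical placeholders:

variant = dict(
    pretrained_algorithm_path="data/pretrain_algorithm.pt",  # hypothetical checkpoint path
    num_epochs=500,
    trainer_kwargs=dict(
        post_pretrain_hyperparams=dict(),  # forwarded to trainer.set_algorithm_weights
    ),
)
resume(variant)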
Example #6
    def load_demo_path(self, demo_path, on_policy=True):
        data = list(load_local_or_remote_file(demo_path))
        N = int(len(data) * self.demo_train_split)
        print("using", N, "paths for training")

        if self.add_demos_to_replay_buffer:
            for path in data[:N]:
                self.load_path(path, self.replay_buffer)
        if on_policy:
            for path in data[:N]:
                self.load_path(path, self.demo_train_buffer)
            for path in data[N:]:
                self.load_path(path, self.demo_test_buffer)
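
The train/test split above is driven entirely by demo_train_split; a standalone illustration of the same slicing with made-up numbers:

paths = list(range(10))          # stand-in for the loaded demo paths
demo_train_split = 0.9
N = int(len(paths) * demo_train_split)
train_paths, test_paths = paths[:N], paths[N:]
print(len(train_paths), len(test_paths))  # 9 1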
Example #7
    def __init__(
            self,
            datapath,
            representation_size,
            initialize_encodings=True,  # set to True if you plan to re-encode presampled images
    ):
        self._presampled_goals = load_local_or_remote_file(datapath)
        self.representation_size = representation_size
        self._num_presampled_goals = self._presampled_goals[list(
            self._presampled_goals)[0]].shape[0]
        if initialize_encodings:
            self._presampled_goals['initial_latent_state'] = np.zeros(
                (self._num_presampled_goals, self.representation_size))
        self._set_spaces()
Example #8
def encode_dataset(dataset_path):
    data = load_local_or_remote_file(dataset_path)
    data = data.item()

    all_data = []

    vqvae.to('cpu')
    for i in tqdm(range(data["observations"].shape[0])):
        obs = ptu.from_numpy(data["observations"][i] / 255.0)
        latent = vqvae.encode(obs, cont=False)
        all_data.append(latent)
    vqvae.to('cuda')

    encodings = ptu.get_numpy(torch.cat(all_data, dim=0))
    return encodings
Example #9
    def encode_dataset(path, object_list):
        data = load_local_or_remote_file(path)
        data = data.item()
        data = data_filter_fn(data)

        all_data = []
        n = min(data["observations"].shape[0], data_size)

        for i in tqdm(range(n)):
            obs = ptu.from_numpy(data["observations"][i] / 255.0)
            latent = vqvae.encode(obs, cont=False)
            all_data.append(latent)

        encodings = ptu.get_numpy(torch.stack(all_data, dim=0))
        return encodings
Example #10
File: common.py  Project: anair13/rlkit
def concatenate_datasets(data_list):
    from rlkit.util.io import load_local_or_remote_file
    obs, envs, dataset = [], [], {}
    for path in data_list:
        curr_data = load_local_or_remote_file(path)
        curr_data = curr_data.item()
        n_random_steps = curr_data['observations'].shape[1]
        imlength = curr_data['observations'].shape[2]
        curr_data['env'] = np.repeat(curr_data['env'], n_random_steps, axis=0)
        curr_data['observations'] = curr_data['observations'].reshape(
            -1, 1, imlength)
        obs.append(curr_data['observations'])
        envs.append(curr_data['env'])
    dataset['observations'] = np.concatenate(obs, axis=0)
    dataset['env'] = np.concatenate(envs, axis=0)
    return dataset
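
A hypothetical call to concatenate_datasets() above; the .npy paths are placeholders, and each file is expected to np.load into an object array whose .item() is a dict with 'observations' of shape (n_trajs, n_random_steps, imlength) and a matching 'env' array:

dataset = concatenate_datasets([
    'data/rollouts_part1.npy',   # hypothetical paths
    'data/rollouts_part2.npy',
])
print(dataset['observations'].shape)  # (total_steps, 1, imlength)
print(dataset['env'].shape)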
Example #11
    def load_demo_path(self,
                       path,
                       is_demo,
                       obs_dict,
                       train_split=None,
                       data_split=None,
                       sync_dir=None):
        print("loading off-policy path", path)

        if sync_dir is not None:
            sync_down_folder(sync_dir)
            paths = glob.glob(get_absolute_path(path))
        else:
            paths = [path]

        data = []

        for filename in paths:
            data.extend(
                list(
                    load_local_or_remote_file(
                        filename,
                        delete_after_loading=self.delete_after_loading)))

        # if not is_demo:
        # data = [data]
        # random.shuffle(data)

        if train_split is None:
            train_split = self.demo_train_split

        if data_split is None:
            data_split = self.demo_data_split

        M = int(len(data) * train_split * data_split)
        N = int(len(data) * data_split)
        print("using", N, "paths for training")

        if self.add_demos_to_replay_buffer:
            for path in data[:M]:
                self.load_path(path, self.replay_buffer, obs_dict=obs_dict)

        if is_demo:
            for path in data[:M]:
                self.load_path(path, self.demo_train_buffer, obs_dict=obs_dict)
            for path in data[M:N]:
                self.load_path(path, self.demo_test_buffer, obs_dict=obs_dict)
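
An illustration of the M/N indices computed above, with made-up numbers: train_split keeps the first fraction of the retained data for the training buffers, while data_split caps how much of the file is used at all:

data_len, train_split, data_split = 100, 0.9, 0.5
M = int(data_len * train_split * data_split)   # 45 -> replay / demo_train buffers
N = int(data_len * data_split)                 # paths 45..49 -> demo_test buffer
print(M, N)  # 45 50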
Example #12
    def __init__(
            self,
            wrapped_env,
            disc,
            mode='train',
            reward_params=None,
    ):
        self.quick_init(locals())
        super().__init__(wrapped_env)
        if isinstance(disc, str):
            self.disc = load_local_or_remote_file(disc)
        else:
            self.disc = disc
        self._num_skills = self.disc.num_skills
        self._p_z = np.full(self._num_skills, 1.0 / self._num_skills)
        self.task = {'context': -1}
        self.reward_params = reward_params or {}
        self.reward_type = self.reward_params.get('type', 'diayn')

        spaces = copy.deepcopy(self.wrapped_env.observation_space.spaces)
        spaces['context'] = Discrete(n=self._num_skills)
        self.observation_space = Dict(spaces)

        assert self.reward_type in ('wrapped_env', 'diayn')
Example #13
def experiment(variant):
    if variant.get("pretrained_algorithm_path", False):
        resume(variant)
        return

    normalize_env = variant.get("normalize_env", True)
    env_id = variant.get("env_id", None)
    env_params = ENV_PARAMS.get(env_id, {})
    variant.update(env_params)
    env_class = variant.get("env_class", None)
    env_kwargs = variant.get("env_kwargs", {})

    expl_env = make(env_id, env_class, env_kwargs, normalize_env)
    eval_env = make(env_id, env_class, env_kwargs, normalize_env)

    if variant.get("add_env_demos", False):
        variant["path_loader_kwargs"]["demo_paths"].append(
            variant["env_demo_path"])
    if variant.get("add_env_offpolicy_data", False):
        variant["path_loader_kwargs"]["demo_paths"].append(
            variant["env_offpolicy_data_path"])

    path_loader_kwargs = variant.get("path_loader_kwargs", {})
    stack_obs = path_loader_kwargs.get("stack_obs", 1)
    if stack_obs > 1:
        expl_env = StackObservationEnv(expl_env, stack_obs=stack_obs)
        eval_env = StackObservationEnv(eval_env, stack_obs=stack_obs)

    obs_dim = expl_env.observation_space.low.size
    action_dim = eval_env.action_space.low.size

    if hasattr(expl_env, "info_sizes"):
        env_info_sizes = expl_env.info_sizes
    else:
        env_info_sizes = dict()

    qf_kwargs = variant.get("qf_kwargs", {})
    qf1 = ConcatMlp(input_size=obs_dim + action_dim,
                    output_size=1,
                    **qf_kwargs)
    qf2 = ConcatMlp(input_size=obs_dim + action_dim,
                    output_size=1,
                    **qf_kwargs)
    target_qf1 = ConcatMlp(input_size=obs_dim + action_dim,
                           output_size=1,
                           **qf_kwargs)
    target_qf2 = ConcatMlp(input_size=obs_dim + action_dim,
                           output_size=1,
                           **qf_kwargs)

    policy_class = variant.get("policy_class", TanhGaussianPolicy)
    policy_kwargs = variant["policy_kwargs"]
    policy_path = variant.get("policy_path", False)
    if policy_path:
        policy = load_local_or_remote_file(policy_path)
    else:
        policy = policy_class(
            obs_dim=obs_dim,
            action_dim=action_dim,
            **policy_kwargs,
        )
    buffer_policy_path = variant.get("buffer_policy_path", False)
    if buffer_policy_path:
        buffer_policy = load_local_or_remote_file(buffer_policy_path)
    else:
        buffer_policy_class = variant.get("buffer_policy_class", policy_class)
        buffer_policy = buffer_policy_class(
            obs_dim=obs_dim,
            action_dim=action_dim,
            **variant.get("buffer_policy_kwargs", policy_kwargs),
        )

    eval_policy = MakeDeterministic(policy)
    eval_path_collector = MdpPathCollector(
        eval_env,
        eval_policy,
    )

    expl_policy = policy
    exploration_kwargs = variant.get("exploration_kwargs", {})
    if exploration_kwargs:
        if exploration_kwargs.get("deterministic_exploration", False):
            expl_policy = MakeDeterministic(policy)

        exploration_strategy = exploration_kwargs.get("strategy", None)
        if exploration_strategy is None:
            pass
        elif exploration_strategy == "ou":
            es = OUStrategy(
                action_space=expl_env.action_space,
                max_sigma=exploration_kwargs["noise"],
                min_sigma=exploration_kwargs["noise"],
            )
            expl_policy = PolicyWrappedWithExplorationStrategy(
                exploration_strategy=es,
                policy=expl_policy,
            )
        elif exploration_strategy == "gauss_eps":
            es = GaussianAndEpsilonStrategy(
                action_space=expl_env.action_space,
                max_sigma=exploration_kwargs["noise"],
                min_sigma=exploration_kwargs["noise"],  # constant sigma
                epsilon=0,
            )
            expl_policy = PolicyWrappedWithExplorationStrategy(
                exploration_strategy=es,
                policy=expl_policy,
            )
        else:
            raise ValueError(
                "Unknown exploration strategy: {}".format(exploration_strategy))

    main_replay_buffer_kwargs = dict(
        max_replay_buffer_size=variant["replay_buffer_size"],
        env=expl_env,
    )
    replay_buffer_kwargs = dict(
        max_replay_buffer_size=variant["replay_buffer_size"],
        env=expl_env,
    )

    replay_buffer = variant.get("replay_buffer_class",
                                EnvReplayBuffer)(**main_replay_buffer_kwargs, )
    if variant.get("use_validation_buffer", False):
        train_replay_buffer = replay_buffer
        validation_replay_buffer = variant.get(
            "replay_buffer_class",
            EnvReplayBuffer)(**main_replay_buffer_kwargs, )
        replay_buffer = SplitReplayBuffer(train_replay_buffer,
                                          validation_replay_buffer, 0.9)

    trainer_class = variant.get("trainer_class", AWACTrainer)
    trainer = trainer_class(
        env=eval_env,
        policy=policy,
        qf1=qf1,
        qf2=qf2,
        target_qf1=target_qf1,
        target_qf2=target_qf2,
        buffer_policy=buffer_policy,
        **variant["trainer_kwargs"],
    )
    if variant["collection_mode"] == "online":
        expl_path_collector = MdpStepCollector(
            expl_env,
            policy,
        )
        algorithm = TorchOnlineRLAlgorithm(
            trainer=trainer,
            exploration_env=expl_env,
            evaluation_env=eval_env,
            exploration_data_collector=expl_path_collector,
            evaluation_data_collector=eval_path_collector,
            replay_buffer=replay_buffer,
            max_path_length=variant["max_path_length"],
            batch_size=variant["batch_size"],
            num_epochs=variant["num_epochs"],
            num_eval_steps_per_epoch=variant["num_eval_steps_per_epoch"],
            num_expl_steps_per_train_loop=variant[
                "num_expl_steps_per_train_loop"],
            num_trains_per_train_loop=variant["num_trains_per_train_loop"],
            min_num_steps_before_training=variant[
                "min_num_steps_before_training"],
        )
    else:
        expl_path_collector = MdpPathCollector(
            expl_env,
            expl_policy,
        )
        algorithm = TorchBatchRLAlgorithm(
            trainer=trainer,
            exploration_env=expl_env,
            evaluation_env=eval_env,
            exploration_data_collector=expl_path_collector,
            evaluation_data_collector=eval_path_collector,
            replay_buffer=replay_buffer,
            max_path_length=variant["max_path_length"],
            batch_size=variant["batch_size"],
            num_epochs=variant["num_epochs"],
            num_eval_steps_per_epoch=variant["num_eval_steps_per_epoch"],
            num_expl_steps_per_train_loop=variant[
                "num_expl_steps_per_train_loop"],
            num_trains_per_train_loop=variant["num_trains_per_train_loop"],
            min_num_steps_before_training=variant[
                "min_num_steps_before_training"],
        )
    algorithm.to(ptu.device)

    demo_train_buffer = EnvReplayBuffer(**replay_buffer_kwargs, )
    demo_test_buffer = EnvReplayBuffer(**replay_buffer_kwargs, )

    if variant.get("save_video", False):
        if variant.get("presampled_goals", None):
            variant["image_env_kwargs"][
                "presampled_goals"] = load_local_or_remote_file(
                    variant["presampled_goals"]).item()

        def get_img_env(env):
            renderer = EnvRenderer(**variant["renderer_kwargs"])
            return InsertImageEnv(GymToMultiEnv(env), renderer=renderer)

        image_eval_env = ImageEnv(GymToMultiEnv(eval_env),
                                  **variant["image_env_kwargs"])
        # image_eval_env = get_img_env(eval_env)
        image_eval_path_collector = ObsDictPathCollector(
            image_eval_env,
            eval_policy,
            observation_key="state_observation",
        )
        image_expl_env = ImageEnv(GymToMultiEnv(expl_env),
                                  **variant["image_env_kwargs"])
        # image_expl_env = get_img_env(expl_env)
        image_expl_path_collector = ObsDictPathCollector(
            image_expl_env,
            expl_policy,
            observation_key="state_observation",
        )
        video_func = VideoSaveFunction(
            image_eval_env,
            variant,
            image_expl_path_collector,
            image_eval_path_collector,
        )
        algorithm.post_train_funcs.append(video_func)
    if variant.get("save_paths", False):
        algorithm.post_train_funcs.append(save_paths)
    if variant.get("load_demos", False):
        path_loader_class = variant.get("path_loader_class", MDPPathLoader)
        path_loader = path_loader_class(
            trainer,
            replay_buffer=replay_buffer,
            demo_train_buffer=demo_train_buffer,
            demo_test_buffer=demo_test_buffer,
            **path_loader_kwargs,
        )
        path_loader.load_demos()
    if variant.get("load_env_dataset_demos", False):
        path_loader_class = variant.get("path_loader_class", HDF5PathLoader)
        path_loader = path_loader_class(
            trainer,
            replay_buffer=replay_buffer,
            demo_train_buffer=demo_train_buffer,
            demo_test_buffer=demo_test_buffer,
            **path_loader_kwargs,
        )
        path_loader.load_demos(expl_env.get_dataset())
    if variant.get("save_initial_buffers", False):
        buffers = dict(
            replay_buffer=replay_buffer,
            demo_train_buffer=demo_train_buffer,
            demo_test_buffer=demo_test_buffer,
        )
        buffer_path = osp.join(logger.get_snapshot_dir(), "buffers.p")
        pickle.dump(buffers, open(buffer_path, "wb"))
    if variant.get("pretrain_buffer_policy", False):
        trainer.pretrain_policy_with_bc(
            buffer_policy,
            replay_buffer.train_replay_buffer,
            replay_buffer.validation_replay_buffer,
            10000,
            label="buffer",
        )
    if variant.get("pretrain_policy", False):
        trainer.pretrain_policy_with_bc(
            policy,
            demo_train_buffer,
            demo_test_buffer,
            trainer.bc_num_pretrain_steps,
        )
    if variant.get("pretrain_rl", False):
        trainer.pretrain_q_with_bc_data()
    if variant.get("save_pretrained_algorithm", False):
        p_path = osp.join(logger.get_snapshot_dir(), "pretrain_algorithm.p")
        pt_path = osp.join(logger.get_snapshot_dir(), "pretrain_algorithm.pt")
        data = algorithm._get_snapshot()
        data["algorithm"] = algorithm
        torch.save(data, open(pt_path, "wb"))
        torch.save(data, open(p_path, "wb"))
    if variant.get("train_rl", True):
        algorithm.train()
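
A hedged sketch of the variant keys that experiment() above dereferences unconditionally; every value below is a placeholder chosen for illustration, not a recommended setting:

variant = dict(
    env_id="HalfCheetah-v2",                 # hypothetical env id
    policy_kwargs=dict(hidden_sizes=[256, 256]),
    qf_kwargs=dict(hidden_sizes=[256, 256]),
    replay_buffer_size=int(1e6),
    trainer_kwargs=dict(),
    collection_mode="batch",
    max_path_length=1000,
    batch_size=256,
    num_epochs=100,
    num_eval_steps_per_epoch=1000,
    num_expl_steps_per_train_loop=1000,
    num_trains_per_train_loop=1000,
    min_num_steps_before_training=1000,
)
experiment(variant)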
Example #14
def generate_LSTM_vae_only_dataset(variant,
                                   segmented=False,
                                   segmentation_method='color'):
    from multiworld.core.image_env import ImageEnv, unormalize_image

    env_id = variant.get('env_id', None)
    N = variant.get('N', 500)
    test_p = variant.get('test_p', 0.9)
    imsize = variant.get('imsize', 48)
    num_channels = variant.get('num_channels', 3)
    init_camera = variant.get('init_camera', None)
    occlusion_prob = variant.get('occlusion_prob', 0)
    occlusion_level = variant.get('occlusion_level', 0.5)
    segmentation_kwargs = variant.get('segmentation_kwargs', {})
    if segmentation_kwargs.get('segment') is not None:
        segmented = segmentation_kwargs.get('segment')

    assert env_id is not None, 'you must provide an env id!'

    obj = 'puck-pos'
    if env_id == 'SawyerDoorHookResetFreeEnv-v1':
        obj = 'door-angle'

    pjhome = os.environ['PJHOME']
    if segmented:
        if 'unet' in segmentation_method:
            seg_name = 'seg-unet'
        else:
            seg_name = 'seg-' + segmentation_method
    else:
        seg_name = 'no-seg'

    if env_id == 'SawyerDoorHookResetFreeEnv-v1':
        seg_name += '-2'

    data_file_path = osp.join(
        pjhome, 'data/local/pre-train-lstm',
        'vae-only-{}-{}-{}-{}-{}.npy'.format(env_id, seg_name, N,
                                             occlusion_prob, occlusion_level))
    obj_state_path = osp.join(
        pjhome, 'data/local/pre-train-lstm',
        'vae-only-{}-{}-{}-{}-{}-{}.npy'.format(env_id, seg_name, N,
                                                occlusion_prob,
                                                occlusion_level, obj))

    print(data_file_path)
    if osp.exists(data_file_path):
        all_data = np.load(data_file_path)
        if len(all_data) >= N:
            print("load stored data at: ", data_file_path)
            n = int(len(all_data) * test_p)
            train_dataset = all_data[:n]
            test_dataset = all_data[n:]
            obj_states = np.load(obj_state_path)
            info = {'obj_state': obj_states}
            return train_dataset, test_dataset, info

    if segmented:
        print(
            "generating lstm vae pretrain only dataset with segmented images using method: ",
            segmentation_method)
        if segmentation_method == 'unet':
            segment_func = segment_image_unet
        else:
            raise NotImplementedError
    else:
        print("generating lstm vae pretrain only dataset with original images")

    info = {}
    dataset = np.zeros((N, imsize * imsize * num_channels), dtype=np.uint8)
    imgs = []
    obj_states = None

    if env_id == 'SawyerDoorHookResetFreeEnv-v1':
        from rlkit.util.io import load_local_or_remote_file
        pjhome = os.environ['PJHOME']
        pre_sampled_goal_path = osp.join(
            pjhome, 'data/local/pre-train-vae/door_original_dataset.npy')
        goal_dict = np.load(pre_sampled_goal_path, allow_pickle=True).item()
        imgs = goal_dict['image_desired_goal']
        door_angles = goal_dict['state_desired_goal'][:, -1]
        obj_states = door_angles[:, np.newaxis]
    elif env_id == 'SawyerPickupEnvYZEasy-v0':
        from rlkit.util.io import load_local_or_remote_file
        pjhome = os.environ['PJHOME']
        pre_sampled_goal_path = osp.join(
            pjhome, 'data/local/pre-train-vae/pickup-original-dataset.npy')
        goal_dict = load_local_or_remote_file(pre_sampled_goal_path).item()
        imgs = goal_dict['image_desired_goal']
        puck_pos = goal_dict['state_desired_goal'][:, 3:]
        obj_states = puck_pos

    else:
        import gym
        import multiworld
        multiworld.register_all_envs()
        env = gym.make(env_id)

        if not isinstance(env, ImageEnv):
            env = ImageEnv(
                env,
                imsize,
                init_camera=init_camera,
                transpose=True,
                normalize=True,
            )
        env.reset()
        info['env'] = env

        puck_pos = np.zeros((N, 2), dtype=np.float64)
        for i in range(N):
            print("lstm vae pretrain only dataset generation, number: ", i)
            if env_id == 'SawyerPushHurdle-v0':
                obs, puck_p = _generate_sawyerhurdle_dataset(
                    env, return_puck_pos=True, segmented=segmented)
            elif env_id == 'SawyerPushHurdleMiddle-v0':
                obs, puck_p = _generate_sawyerhurdlemiddle_dataset(
                    env, return_puck_pos=True)
            elif env_id == 'SawyerPushNIPSEasy-v0':
                obs, puck_p = _generate_sawyerpushnipseasy_dataset(
                    env, return_puck_pos=True)
            elif env_id == 'SawyerPushHurdleResetFreeEnv-v0':
                obs, puck_p = _generate_sawyerhurldeblockresetfree_dataset(
                    env, return_puck_pos=True)
            else:
                raise NotImplementedError
            img = obs['image_observation']  # NOTE: this is already a normalized image, of dtype np.float64.
            imgs.append(img)
            puck_pos[i] = puck_p

        obj_states = puck_pos

    # now we segment the images
    for i in range(N):
        print("segmenting image ", i)
        img = imgs[i]
        if segmented:
            dataset[i, :] = segment_func(img,
                                         normalize=False,
                                         **segmentation_kwargs)
            # manually drop some images to create occlusions
            p = np.random.rand()
            if p < occlusion_prob:
                mask = (np.random.uniform(low=0, high=1, size=(imsize, imsize))
                        > occlusion_level).astype(np.uint8)
                img = dataset[i].reshape(3, imsize, imsize).transpose()
                img[mask < 1] = 0
                dataset[i] = img.transpose().flatten()

        else:
            dataset[i, :] = unormalize_image(img)

    # add the trajectory dimension
    dataset = dataset[:, np.newaxis, :]  # batch_size x traj_len = 1 x imlen
    obj_states = obj_states[:, np.newaxis, :]  # batch_size x traj_len = 1 x state_dim
    info['obj_state'] = obj_states

    n = int(N * test_p)
    train_dataset = dataset[:n]
    test_dataset = dataset[n:]

    if N >= 500:
        print('save data to: ', data_file_path)
        all_data = np.concatenate([train_dataset, test_dataset], axis=0)
        np.save(data_file_path, all_data)
        np.save(obj_state_path, obj_states)

    return train_dataset, test_dataset, info
Example #15
def train_pixelcnn(
    vqvae=None,
    vqvae_path=None,
    num_epochs=100,
    batch_size=32,
    n_layers=15,
    dataset_path=None,
    save=True,
    save_period=10,
    cached_dataset_path=False,
    trainer_kwargs=None,
    model_kwargs=None,
    data_filter_fn=lambda x: x,
    debug=False,
    data_size=float('inf'),
    num_train_batches_per_epoch=None,
    num_test_batches_per_epoch=None,
    train_img_loader=None,
    test_img_loader=None,
):
    trainer_kwargs = {} if trainer_kwargs is None else trainer_kwargs
    model_kwargs = {} if model_kwargs is None else model_kwargs

    # Load VQVAE + Define Args
    if vqvae is None:
        vqvae = load_local_or_remote_file(vqvae_path)
        vqvae.to(ptu.device)
        vqvae.eval()

    root_len = vqvae.root_len
    num_embeddings = vqvae.num_embeddings
    embedding_dim = vqvae.embedding_dim
    cond_size = vqvae.num_embeddings
    imsize = vqvae.imsize
    discrete_size = root_len * root_len
    representation_size = embedding_dim * discrete_size
    input_channels = vqvae.input_channels
    imlength = imsize * imsize * input_channels

    log_dir = logger.get_snapshot_dir()

    # Define data loading info
    new_path = osp.join(log_dir, 'pixelcnn_data.npy')

    def prep_sample_data(cached_path):
        data = load_local_or_remote_file(cached_path).item()
        train_data = data['train']
        test_data = data['test']
        return train_data, test_data

    def encode_dataset(path, object_list):
        data = load_local_or_remote_file(path)
        data = data.item()
        data = data_filter_fn(data)

        all_data = []
        n = min(data["observations"].shape[0], data_size)

        for i in tqdm(range(n)):
            obs = ptu.from_numpy(data["observations"][i] / 255.0)
            latent = vqvae.encode(obs, cont=False)
            all_data.append(latent)

        encodings = ptu.get_numpy(torch.stack(all_data, dim=0))
        return encodings

    if train_img_loader:
        _, test_loader, test_batch_loader = create_conditional_data_loader(
            test_img_loader, 80, vqvae, "test2")  # 80
        _, train_loader, train_batch_loader = create_conditional_data_loader(
            train_img_loader, 2000, vqvae, "train2")  # 2000
    else:
        if cached_dataset_path:
            train_data, test_data = prep_sample_data(cached_dataset_path)
        else:
            train_data = encode_dataset(dataset_path['train'],
                                        None)  # object_list)
            test_data = encode_dataset(dataset_path['test'], None)
        dataset = {'train': train_data, 'test': test_data}
        np.save(new_path, dataset)

        _, _, train_loader, test_loader, _ = \
            rlkit.torch.vae.pixelcnn_utils.load_data_and_data_loaders(new_path, 'COND_LATENT_BLOCK', batch_size)

    #train_dataset = InfiniteBatchLoader(train_loader)
    #test_dataset = InfiniteBatchLoader(test_loader)

    print("Finished loading data")

    model = GatedPixelCNN(num_embeddings,
                          root_len**2,
                          n_classes=representation_size,
                          **model_kwargs).to(ptu.device)
    trainer = PixelCNNTrainer(
        model,
        vqvae,
        batch_size=batch_size,
        **trainer_kwargs,
    )

    print("Starting training")

    BEST_LOSS = 999
    for epoch in range(num_epochs):
        should_save = (epoch % save_period == 0) and (epoch > 0)
        trainer.train_epoch(epoch, train_loader, num_train_batches_per_epoch)
        trainer.test_epoch(epoch, test_loader, num_test_batches_per_epoch)

        if train_img_loader:
            test_data = test_batch_loader.random_batch(batch_size)["x"]
            train_data = train_batch_loader.random_batch(batch_size)["x"]
            trainer.dump_samples(epoch, test_data, test=True)
            trainer.dump_samples(epoch, train_data, test=False)

        if should_save:
            logger.save_itr_params(epoch, model)

        stats = trainer.get_diagnostics()

        cur_loss = stats["test/loss"]
        if cur_loss < BEST_LOSS:
            BEST_LOSS = cur_loss
            vqvae.set_pixel_cnn(model)
            logger.save_extra_data(vqvae, 'best_vqvae', mode='torch')
        else:
            return vqvae

        for k, v in stats.items():
            logger.record_tabular(k, v)
        logger.dump_tabular()
        trainer.end_epoch(epoch)

    return vqvae
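
A hedged sketch of invoking train_pixelcnn() above from a launcher; the checkpoint and dataset paths are hypothetical, and dataset_path must be a dict with 'train' and 'test' entries of the form read by encode_dataset():

vqvae = train_pixelcnn(
    vqvae_path="data/vqvae.pt",                       # hypothetical VQ-VAE checkpoint
    dataset_path=dict(train="data/train_obs.npy",     # hypothetical observation files
                      test="data/test_obs.npy"),
    num_epochs=50,
    batch_size=32,
)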
Example #16
def generate_vae_dataset(variant):
    """
    If not provided a pre-train vae dataset generation function, this function will be used to collect
    the dataset for training vae.
    """
    import rlkit.torch.pytorch_util as ptu
    import gym
    import multiworld
    multiworld.register_all_envs()

    print("generating vae dataset with original images")

    env_class = variant.get('env_class', None)
    env_kwargs = variant.get('env_kwargs', None)
    env_id = variant.get('env_id', None)
    N = variant.get('N', 10000)
    test_p = variant.get('test_p', 0.9)
    use_cached = variant.get('use_cached', True)
    imsize = variant.get('imsize', 84)
    num_channels = variant.get('num_channels', 3)
    show = variant.get('show', False)
    init_camera = variant.get('init_camera', None)
    dataset_path = variant.get('dataset_path', None)
    oracle_dataset_using_set_to_goal = variant.get(
        'oracle_dataset_using_set_to_goal', False)
    random_rollout_data = variant.get('random_rollout_data', False)
    random_and_oracle_policy_data = variant.get(
        'random_and_oracle_policy_data', False)
    random_and_oracle_policy_data_split = variant.get(
        'random_and_oracle_policy_data_split', 0)
    policy_file = variant.get('policy_file', None)
    n_random_steps = variant.get('n_random_steps', 100)
    vae_dataset_specific_env_kwargs = variant.get(
        'vae_dataset_specific_env_kwargs', None)
    save_file_prefix = variant.get('save_file_prefix', None)
    non_presampled_goal_img_is_garbage = variant.get(
        'non_presampled_goal_img_is_garbage', None)
    tag = variant.get('tag', '')

    info = {}
    if dataset_path is not None:
        print('load vae training dataset from: ', dataset_path)
        pjhome = os.environ['PJHOME']
        dataset = np.load(osp.join(pjhome, dataset_path),
                          allow_pickle=True).item()
        if isinstance(dataset, dict):
            dataset = dataset['image_desired_goal']
        dataset = unormalize_image(dataset)
        N = dataset.shape[0]
    else:
        if env_kwargs is None:
            env_kwargs = {}
        if save_file_prefix is None:
            save_file_prefix = env_id
        if save_file_prefix is None:
            save_file_prefix = env_class.__name__
        filename = "/tmp/{}_N{}_{}_imsize{}_random_oracle_split_{}{}.npy".format(
            save_file_prefix,
            str(N),
            init_camera.__name__ if init_camera else '',
            imsize,
            random_and_oracle_policy_data_split,
            tag,
        )
        if use_cached and osp.isfile(filename):
            dataset = np.load(filename)
            print("loaded data from saved file", filename)
        else:
            now = time.time()

            if env_id is not None:
                import gym
                import multiworld
                multiworld.register_all_envs()
                env = gym.make(env_id)
            else:
                if vae_dataset_specific_env_kwargs is None:
                    vae_dataset_specific_env_kwargs = {}
                for key, val in env_kwargs.items():
                    if key not in vae_dataset_specific_env_kwargs:
                        vae_dataset_specific_env_kwargs[key] = val
                env = env_class(**vae_dataset_specific_env_kwargs)
            if not isinstance(env, ImageEnv):
                env = ImageEnv(
                    env,
                    imsize,
                    init_camera=init_camera,
                    transpose=True,
                    normalize=True,
                    non_presampled_goal_img_is_garbage=
                    non_presampled_goal_img_is_garbage,
                )
            else:
                imsize = env.imsize
                env.non_presampled_goal_img_is_garbage = non_presampled_goal_img_is_garbage
            env.reset()
            info['env'] = env
            if random_and_oracle_policy_data:
                policy_file = load_local_or_remote_file(policy_file)
                policy = policy_file['policy']
                policy.to(ptu.device)
            if random_rollout_data:
                from rlkit.exploration_strategies.ou_strategy import OUStrategy
                policy = OUStrategy(env.action_space)

            dataset = np.zeros((N, imsize * imsize * num_channels),
                               dtype=np.uint8)

            for i in range(N):
                if random_and_oracle_policy_data:
                    num_random_steps = int(N *
                                           random_and_oracle_policy_data_split)
                    if i < num_random_steps:
                        env.reset()
                        for _ in range(n_random_steps):
                            obs = env.step(env.action_space.sample())[0]
                    else:
                        obs = env.reset()
                        policy.reset()
                        for _ in range(n_random_steps):
                            policy_obs = np.hstack((
                                obs['state_observation'],
                                obs['state_desired_goal'],
                            ))
                            action, _ = policy.get_action(policy_obs)
                            obs, _, _, _ = env.step(action)
                elif oracle_dataset_using_set_to_goal:
                    print(i)
                    goal = env.sample_goal()
                    env.set_to_goal(goal)
                    obs = env._get_obs()

                elif random_rollout_data:
                    if i % n_random_steps == 0:
                        g = dict(
                            state_desired_goal=env.sample_goal_for_rollout())
                        env.set_to_goal(g)
                        policy.reset()
                        # env.reset()
                    u = policy.get_action_from_raw_action(
                        env.action_space.sample())
                    obs = env.step(u)[0]
                else:
                    print("using totally random rollouts")
                    for _ in range(n_random_steps):
                        obs = env.step(env.action_space.sample())[0]

                img = obs['image_observation']  # NOTE (yufei): this is already a normalized image, of dtype np.float64.

                dataset[i, :] = unormalize_image(img)

            np.save(filename, dataset)

    n = int(N * test_p)
    train_dataset = dataset[:n, :]
    test_dataset = dataset[n:, :]
    return train_dataset, test_dataset, info
Example #17
    def prep_sample_data(cached_path):
        data = load_local_or_remote_file(cached_path).item()
        train_data = data['train']
        test_data = data['test']
        return train_data, test_data
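
The cached file consumed by prep_sample_data() above is assumed to be a dict stored with np.save (so that np.load(...).item() works, as train_pixelcnn itself does when it saves pixelcnn_data.npy); a minimal sketch of writing a compatible file with dummy arrays:

import numpy as np

# Dummy arrays only; the real entries are the VQ-VAE latent encodings produced
# by encode_dataset() in the train_pixelcnn example above.
np.save('pixelcnn_data.npy', dict(train=np.zeros((100, 144), dtype=np.int64),
                                  test=np.zeros((10, 144), dtype=np.int64)))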
Example #18
def pearl_experiment(
    qf_kwargs=None,
    vf_kwargs=None,
    trainer_kwargs=None,
    algo_kwargs=None,
    context_encoder_kwargs=None,
    context_decoder_kwargs=None,
    policy_kwargs=None,
    env_name=None,
    env_params=None,
    latent_dim=None,
    # video/debug
    debug=False,
    _debug_do_not_sqrt=False,
    # PEARL
    n_train_tasks=0,
    n_eval_tasks=0,
    use_next_obs_in_context=False,
    saved_tasks_path=None,
    tags=None,
):
    del tags
    register_pearl_envs()
    env_params = env_params or {}
    context_encoder_kwargs = context_encoder_kwargs or {}
    context_decoder_kwargs = context_decoder_kwargs or {}
    trainer_kwargs = trainer_kwargs or {}
    base_env = ENVS[env_name](**env_params)
    if saved_tasks_path:
        task_data = load_local_or_remote_file(saved_tasks_path,
                                              file_type='joblib')
        tasks = task_data['tasks']
        train_task_idxs = task_data['train_task_indices']
        eval_task_idxs = task_data['eval_task_indices']
        base_env.tasks = tasks
    else:
        tasks = base_env.tasks
        task_indices = base_env.get_all_task_idx()
        train_task_idxs = list(task_indices[:n_train_tasks])
        eval_task_idxs = list(task_indices[-n_eval_tasks:])
    if hasattr(base_env, 'task_to_vec'):
        train_tasks = [base_env.task_to_vec(tasks[i]) for i in train_task_idxs]
        eval_tasks = [base_env.task_to_vec(tasks[i]) for i in eval_task_idxs]
    else:
        train_tasks = [tasks[i] for i in train_task_idxs]
        eval_tasks = [tasks[i] for i in eval_task_idxs]
    expl_env = NormalizedBoxEnv(base_env)
    eval_env = NormalizedBoxEnv(ENVS[env_name](**env_params))
    eval_env.tasks = expl_env.tasks
    reward_dim = 1

    if debug:
        algo_kwargs['max_path_length'] = 50
        algo_kwargs['batch_size'] = 5
        algo_kwargs['num_epochs'] = 5
        algo_kwargs['num_eval_steps_per_epoch'] = 100
        algo_kwargs['num_expl_steps_per_train_loop'] = 100
        algo_kwargs['num_trains_per_train_loop'] = 10
        algo_kwargs['min_num_steps_before_training'] = 100

    obs_dim = expl_env.observation_space.low.size
    action_dim = eval_env.action_space.low.size

    if use_next_obs_in_context:
        context_encoder_input_dim = 2 * obs_dim + action_dim + reward_dim
    else:
        context_encoder_input_dim = obs_dim + action_dim + reward_dim
    context_encoder_output_dim = latent_dim * 2

    def create_qf():
        return ConcatMlp(input_size=obs_dim + action_dim + latent_dim,
                         output_size=1,
                         **qf_kwargs)

    qf1 = create_qf()
    qf2 = create_qf()
    vf = ConcatMlp(input_size=obs_dim + latent_dim, output_size=1, **vf_kwargs)

    policy = TanhGaussianPolicy(
        obs_dim=obs_dim + latent_dim,
        action_dim=action_dim,
        **policy_kwargs,
    )
    context_encoder = MlpEncoder(input_size=context_encoder_input_dim,
                                 output_size=context_encoder_output_dim,
                                 hidden_sizes=[200, 200, 200],
                                 **context_encoder_kwargs)
    context_decoder = MlpDecoder(input_size=obs_dim + action_dim + latent_dim,
                                 output_size=1,
                                 **context_decoder_kwargs)
    reward_predictor = context_decoder
    agent = SmacAgent(
        latent_dim,
        context_encoder,
        policy,
        reward_predictor,
        use_next_obs_in_context=use_next_obs_in_context,
        _debug_do_not_sqrt=_debug_do_not_sqrt,
    )
    trainer = PEARLSoftActorCriticTrainer(latent_dim=latent_dim,
                                          agent=agent,
                                          qf1=qf1,
                                          qf2=qf2,
                                          vf=vf,
                                          reward_predictor=reward_predictor,
                                          context_encoder=context_encoder,
                                          context_decoder=context_decoder,
                                          **trainer_kwargs)
    algorithm = MetaRLAlgorithm(
        agent=agent,
        env=expl_env,
        trainer=trainer,
        train_task_indices=train_task_idxs,
        eval_task_indices=eval_task_idxs,
        train_tasks=train_tasks,
        eval_tasks=eval_tasks,
        use_next_obs_in_context=use_next_obs_in_context,
        env_info_sizes=get_env_info_sizes(expl_env),
        **algo_kwargs)
    saved_path = logger.save_extra_data(
        data=dict(
            tasks=expl_env.tasks,
            train_task_indices=train_task_idxs,
            eval_task_indices=eval_task_idxs,
            train_tasks=train_tasks,
            eval_tasks=eval_tasks,
        ),
        file_name='tasks_description',
    )
    print('saved tasks description to', saved_path)
    saved_path = logger.save_extra_data(
        expl_env.tasks,
        file_name='tasks',
        mode='pickle',
    )
    print('saved raw tasks to', saved_path)

    algorithm.to(ptu.device)
    algorithm.train()
Example #19
def smac_experiment(
    trainer_kwargs=None,
    algo_kwargs=None,
    qf_kwargs=None,
    policy_kwargs=None,
    context_encoder_kwargs=None,
    context_decoder_kwargs=None,
    env_name=None,
    env_params=None,
    path_loader_kwargs=None,
    latent_dim=None,
    policy_class="TanhGaussianPolicy",
    # video/debug
    debug=False,
    use_dummy_encoder=False,
    networks_ignore_context=False,
    use_ground_truth_context=False,
    save_video=False,
    save_video_period=False,
    # Pre-train params
    pretrain_rl=False,
    pretrain_offline_algo_kwargs=None,
    pretrain_buffer_kwargs=None,
    load_buffer_kwargs=None,
    saved_tasks_path=None,
    macaw_format_base_path=None,  # overrides saved_tasks_path and load_buffer_kwargs
    load_macaw_buffer_kwargs=None,
    train_task_idxs=None,
    eval_task_idxs=None,
    relabel_offline_dataset=False,
    skip_initial_data_collection_if_pretrained=False,
    relabel_kwargs=None,
    # PEARL
    n_train_tasks=0,
    n_eval_tasks=0,
    use_next_obs_in_context=False,
    tags=None,
    online_trainer_kwargs=None,
):
    if not skip_initial_data_collection_if_pretrained:
        raise NotImplementedError("deprecated! make sure to skip it!")
    if relabel_kwargs is None:
        relabel_kwargs = {}
    del tags
    pretrain_buffer_kwargs = pretrain_buffer_kwargs or {}
    context_decoder_kwargs = context_decoder_kwargs or {}
    pretrain_offline_algo_kwargs = pretrain_offline_algo_kwargs or {}
    online_trainer_kwargs = online_trainer_kwargs or {}
    register_pearl_envs()
    env_params = env_params or {}
    context_encoder_kwargs = context_encoder_kwargs or {}
    trainer_kwargs = trainer_kwargs or {}
    path_loader_kwargs = path_loader_kwargs or {}
    load_macaw_buffer_kwargs = load_macaw_buffer_kwargs or {}

    base_env = ENVS[env_name](**env_params)
    if saved_tasks_path:
        task_data = load_local_or_remote_file(saved_tasks_path,
                                              file_type='joblib')
        tasks = task_data['tasks']
        train_task_idxs = task_data['train_task_indices']
        eval_task_idxs = task_data['eval_task_indices']
        base_env.tasks = tasks
    elif macaw_format_base_path is not None:
        tasks = pickle.load(
            open('{}/tasks.pkl'.format(macaw_format_base_path), 'rb'))
        base_env.tasks = tasks
    else:
        tasks = base_env.tasks
        task_indices = base_env.get_all_task_idx()
        train_task_idxs = list(task_indices[:n_train_tasks])
        eval_task_idxs = list(task_indices[-n_eval_tasks:])
    if hasattr(base_env, 'task_to_vec'):
        train_tasks = [base_env.task_to_vec(tasks[i]) for i in train_task_idxs]
        eval_tasks = [base_env.task_to_vec(tasks[i]) for i in eval_task_idxs]
    else:
        train_tasks = [tasks[i] for i in train_task_idxs]
        eval_tasks = [tasks[i] for i in eval_task_idxs]
    if use_ground_truth_context:
        latent_dim = len(train_tasks[0])
    expl_env = NormalizedBoxEnv(base_env)

    reward_dim = 1

    if debug:
        algo_kwargs['max_path_length'] = 50
        algo_kwargs['batch_size'] = 5
        algo_kwargs['num_epochs'] = 5
        algo_kwargs['num_eval_steps_per_epoch'] = 100
        algo_kwargs['num_expl_steps_per_train_loop'] = 100
        algo_kwargs['num_trains_per_train_loop'] = 10
        algo_kwargs['min_num_steps_before_training'] = 100

    obs_dim = expl_env.observation_space.low.size
    action_dim = expl_env.action_space.low.size

    if use_next_obs_in_context:
        context_encoder_input_dim = 2 * obs_dim + action_dim + reward_dim
    else:
        context_encoder_input_dim = obs_dim + action_dim + reward_dim
    context_encoder_output_dim = latent_dim * 2

    def create_qf():
        return ConcatMlp(input_size=obs_dim + action_dim + latent_dim,
                         output_size=1,
                         **qf_kwargs)

    qf1 = create_qf()
    qf2 = create_qf()
    target_qf1 = create_qf()
    target_qf2 = create_qf()

    if isinstance(policy_class, str):
        policy_class = policy_class_from_str(policy_class)
    policy = policy_class(
        obs_dim=obs_dim + latent_dim,
        action_dim=action_dim,
        **policy_kwargs,
    )
    encoder_class = DummyMlpEncoder if use_dummy_encoder else MlpEncoder
    context_encoder = encoder_class(
        input_size=context_encoder_input_dim,
        output_size=context_encoder_output_dim,
        hidden_sizes=[200, 200, 200],
        use_ground_truth_context=use_ground_truth_context,
        **context_encoder_kwargs)
    context_decoder = MlpDecoder(input_size=obs_dim + action_dim + latent_dim,
                                 output_size=1,
                                 **context_decoder_kwargs)
    reward_predictor = context_decoder
    agent = SmacAgent(
        latent_dim,
        context_encoder,
        policy,
        reward_predictor,
        use_next_obs_in_context=use_next_obs_in_context,
        _debug_ignore_context=networks_ignore_context,
        _debug_use_ground_truth_context=use_ground_truth_context,
    )
    trainer = SmacTrainer(
        agent=agent,
        env=expl_env,
        latent_dim=latent_dim,
        qf1=qf1,
        qf2=qf2,
        target_qf1=target_qf1,
        target_qf2=target_qf2,
        reward_predictor=reward_predictor,
        context_encoder=context_encoder,
        context_decoder=context_decoder,
        _debug_ignore_context=networks_ignore_context,
        _debug_use_ground_truth_context=use_ground_truth_context,
        **trainer_kwargs)
    algorithm = MetaRLAlgorithm(
        agent=agent,
        env=expl_env,
        trainer=trainer,
        train_task_indices=train_task_idxs,
        eval_task_indices=eval_task_idxs,
        train_tasks=train_tasks,
        eval_tasks=eval_tasks,
        use_next_obs_in_context=use_next_obs_in_context,
        use_ground_truth_context=use_ground_truth_context,
        env_info_sizes=get_env_info_sizes(expl_env),
        **algo_kwargs)

    if macaw_format_base_path:
        load_macaw_buffer_onto_algo(algo=algorithm,
                                    base_directory=macaw_format_base_path,
                                    train_task_idxs=train_task_idxs,
                                    **load_macaw_buffer_kwargs)
    elif load_buffer_kwargs:
        load_buffer_onto_algo(algorithm, **load_buffer_kwargs)
    if relabel_offline_dataset:
        relabel_offline_data(algorithm,
                             tasks=tasks,
                             env=expl_env.wrapped_env,
                             **relabel_kwargs)
    if path_loader_kwargs:
        replay_buffer = algorithm.replay_buffer.task_buffers[0]
        enc_replay_buffer = algorithm.enc_replay_buffer.task_buffers[0]
        demo_test_buffer = EnvReplayBuffer(env=expl_env,
                                           **pretrain_buffer_kwargs)
        path_loader = MDPPathLoader(trainer,
                                    replay_buffer=replay_buffer,
                                    demo_train_buffer=enc_replay_buffer,
                                    demo_test_buffer=demo_test_buffer,
                                    **path_loader_kwargs)
        path_loader.load_demos()

    if pretrain_rl:
        eval_pearl_fn = EvalPearl(algorithm, train_task_idxs, eval_task_idxs)
        pretrain_algo = OfflineMetaRLAlgorithm(
            meta_replay_buffer=algorithm.meta_replay_buffer,
            replay_buffer=algorithm.replay_buffer,
            task_embedding_replay_buffer=algorithm.enc_replay_buffer,
            trainer=trainer,
            train_tasks=train_task_idxs,
            extra_eval_fns=[eval_pearl_fn],
            use_meta_learning_buffer=algorithm.use_meta_learning_buffer,
            **pretrain_offline_algo_kwargs)
        pretrain_algo.to(ptu.device)
        logger.remove_tabular_output('progress.csv',
                                     relative_to_snapshot_dir=True)
        logger.add_tabular_output('pretrain.csv',
                                  relative_to_snapshot_dir=True)
        pretrain_algo.train()
        logger.remove_tabular_output('pretrain.csv',
                                     relative_to_snapshot_dir=True)
        logger.add_tabular_output(
            'progress.csv',
            relative_to_snapshot_dir=True,
        )
        if skip_initial_data_collection_if_pretrained:
            algorithm.num_initial_steps = 0

    algorithm.trainer.configure(**online_trainer_kwargs)
    algorithm.to(ptu.device)
    algorithm.train()
Example #20
def generate_vae_dataset(cfgs):
    env_id = cfgs.ENV.id
    img_size = cfgs.ENV.imsize
    init_camera = cfgs.ENV.init_camera

    N = cfgs.GENERATE_VAE_DATASET.N
    use_cached = cfgs.GENERATE_VAE_DATASET.use_cached
    n_random_steps = cfgs.GENERATE_VAE_DATASET.n_random_steps
    dataset_path = cfgs.GENERATE_VAE_DATASET.dataset_path  # FIXME
    non_presampled_goal_img_is_garbage = cfgs.GENERATE_VAE_DATASET.non_presampled_goal_img_is_garbage
    random_and_oracle_policy_data_split = cfgs.GENERATE_VAE_DATASET.random_and_oracle_policy_data_split
    random_and_oracle_policy_data = cfgs.GENERATE_VAE_DATASET.random_and_oracle_policy_data
    random_rollout_data = cfgs.GENERATE_VAE_DATASET.random_rollout_data
    oracle_dataset_using_set_to_goal = cfgs.GENERATE_VAE_DATASET.oracle_dataset_using_set_to_goal

    num_channels = cfgs.VAE.input_channels
    policy_file = cfgs.POLICY.model_path

    from roworld.core.image_env import ImageEnv, unormalize_image
    import rlkit.torch.pytorch_util as ptu

    info = {}
    if dataset_path is not None:
        dataset = load_local_or_remote_file(dataset_path)
        N = dataset.shape[0]
    else:
        filename = "/tmp/{}_N{}_{}_size{}_random_oracle_split_{}.npy".format(
            env_id,
            str(N),
            init_camera.__name__ if init_camera else '',
            img_size,
            random_and_oracle_policy_data_split,
        )
        if use_cached and osp.isfile(filename):
            dataset = np.load(filename)
            print("loaded data from saved file", filename)
        else:
            now = time.time()

            assert env_id is not None
            import gym
            import roworld
            roworld.register_all_envs()
            env = gym.make(env_id)

            if not isinstance(env, ImageEnv):
                env = ImageEnv(
                    env,
                    img_size,
                    init_camera=init_camera,
                    transpose=True,
                    normalize=True,
                    non_presampled_goal_img_is_garbage=
                    non_presampled_goal_img_is_garbage,
                )
            else:
                env.imsize = img_size
                env.non_presampled_goal_img_is_garbage = non_presampled_goal_img_is_garbage

            env.reset()
            info['env'] = env
            if random_and_oracle_policy_data:
                policy_file = load_local_or_remote_file(policy_file)
                policy = policy_file['policy']
                policy.to(ptu.device)
            if random_rollout_data:
                from rlkit.exploration_strategies.ou_strategy import OUStrategy
                policy = OUStrategy(env.action_space)

            dataset = np.zeros((N, img_size * img_size * num_channels),
                               dtype=np.uint8)
            obs = env.reset()
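            # Collect N images, choosing one of four modes per sample: mixed
            # random/oracle-policy data, oracle set_to_goal states, random
            # rollouts, or plain random-step resets (the final else branch).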
            for i in range(N):
                if random_and_oracle_policy_data:
                    num_random_steps = int(N *
                                           random_and_oracle_policy_data_split)
                    if i < num_random_steps:
                        # Randomly obtain observation
                        env.reset()
                        for _ in range(n_random_steps):
                            obs = env.step(env.action_space.sample())[0]
                    else:
                        # Obtain observation with policy
                        obs = env.reset()
                        policy.reset()
                        for _ in range(n_random_steps):
                            policy_obs = np.hstack((
                                obs['state_observation'],
                                obs['state_desired_goal'],
                            ))
                            action, _ = policy.get_action(policy_obs)
                            obs, _, _, _ = env.step(action)
                elif oracle_dataset_using_set_to_goal:
                    goal = env.sample_goal()
                    env.set_to_goal(goal)
                    obs = env._get_obs()
                elif random_rollout_data:
                    if i % n_random_steps == 0:
                        g = dict(
                            state_desired_goal=env.sample_goal_for_rollout())
                        env.set_to_goal(g)
                        policy.reset()
                        # env.reset()
                    u = policy.get_action_from_raw_action(
                        env.action_space.sample())
                    obs = env.step(u)[0]
                else:
                    env.reset()
                    # The output obs will be the last observation after stepping n_random_steps
                    for _ in range(n_random_steps):
                        obs = env.step(env.action_space.sample())[0]

                img = obs['image_observation']
                dataset[i, :] = unormalize_image(img)

                if cfgs.GENERATE_VAE_DATASET.show:
                    img = img.reshape(3, img_size, img_size).transpose()
                    img = img[::-1, :, ::-1]
                    cv2.imshow('img', img)
                    cv2.waitKey(1)
                    # radius = input('waiting...')
            print("Done making training data", filename, time.time() - now)
            np.save(filename, dataset)

    n = int(N * cfgs.GENERATE_VAE_DATASET.ratio)
    train_dataset = dataset[:n, :]
    test_dataset = dataset[n:, :]
    return train_dataset, test_dataset, info
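A minimal sketch of how this function might be invoked, assuming an attribute-style config and that the function's module-level imports are available; the field names mirror the cfgs.* accesses above, and the env id and sizes are placeholders:

# Hypothetical usage of the generate_vae_dataset(cfgs) above; all values are
# placeholders and the env id is only an example.
from types import SimpleNamespace

cfgs = SimpleNamespace(
    ENV=SimpleNamespace(id='SawyerPushNIPSEasy-v0', imsize=48, init_camera=None),
    GENERATE_VAE_DATASET=SimpleNamespace(
        N=1000, use_cached=False, n_random_steps=100, dataset_path=None,
        non_presampled_goal_img_is_garbage=False,
        random_and_oracle_policy_data_split=0,
        random_and_oracle_policy_data=False,
        random_rollout_data=False,
        oracle_dataset_using_set_to_goal=True,
        show=False, ratio=0.9),
    VAE=SimpleNamespace(input_channels=3),
    POLICY=SimpleNamespace(model_path=None),
)
train_data, test_data, info = generate_vae_dataset(cfgs)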
Example #21
def get_envs(cfgs):
    from roworld.core.image_env import ImageEnv
    from rlkit.envs.vae_wrapper import VAEWrappedEnv
    from rlkit.util.io import load_local_or_remote_file

    render = cfgs.get('render', False)
    reward_params = cfgs.get("reward_params", dict())
    do_state_exp = cfgs.get("do_state_exp", False)  # TODO

    vae_path = cfgs.VAE_TRAINER.get("vae_path", None)
    init_camera = cfgs.ENV.get("init_camera", None)

    presample_goals = cfgs.SKEW_FIT.get('presample_goals', False)
    presample_image_goals_only = cfgs.SKEW_FIT.get(
        'presample_image_goals_only', False)
    presampled_goals_path = cfgs.SKEW_FIT.get('presampled_goals_path', None)

    vae = load_local_or_remote_file(
        vae_path) if type(vae_path) is str else vae_path
    if cfgs.ENV.id:
        import gym
        import roworld
        roworld.register_all_envs()
        env = gym.make(cfgs.ENV.id)
    else:
        env = cfgs.ENV.cls(**cfgs.ENV.kwargs)

    if not do_state_exp:
        if isinstance(env, ImageEnv):
            image_env = env
        else:
            image_env = ImageEnv(
                env,
                cfgs.ENV.imsize,
                init_camera=init_camera,
                transpose=True,
                normalize=True,
            )
        if presample_goals:
            """
            This will fail for online-parallel as presampled_goals will not be
            serialized. Also don't use this for online-vae.
            """
            if presampled_goals_path is None:
                image_env.non_presampled_goal_img_is_garbage = True
                vae_env = VAEWrappedEnv(image_env,
                                        vae,
                                        imsize=image_env.imsize,
                                        decode_goals=render,
                                        render_goals=render,
                                        render_rollouts=render,
                                        reward_params=reward_params,
                                        **cfgs.get('vae_wrapped_env_kwargs',
                                                   {}))
                presampled_goals = cfgs['generate_goal_dataset_fctn'](
                    env=vae_env,
                    env_id=cfgs.get('env_id', None),
                    **cfgs['goal_generation_kwargs'])
                del vae_env
            else:
                presampled_goals = load_local_or_remote_file(
                    presampled_goals_path).item()
            del image_env
            image_env = ImageEnv(
                env,
                cfgs.ENV.get('imsize'),
                init_camera=init_camera,
                transpose=True,
                normalize=True,
                presampled_goals=presampled_goals,
            )
            vae_env = VAEWrappedEnv(
                image_env,
                vae,
                imsize=image_env.imsize,
                decode_goals=render,
                render_goals=render,
                render_rollouts=render,
                reward_params=reward_params,
                presampled_goals=presampled_goals,
                sample_from_true_prior=True,
            )
            print("Pre sampling all goals only")
        else:
            vae_env = VAEWrappedEnv(
                image_env,
                vae,
                imsize=image_env.imsize,
                decode_goals=render,
                render_goals=render,
                render_rollouts=render,
                reward_params=reward_params,
                goal_sampling_mode='vae_prior',
                sample_from_true_prior=True,
            )
            if presample_image_goals_only:
                presampled_goals = cfgs['generate_goal_dataset_fctn'](
                    image_env=vae_env.wrapped_env,
                    **cfgs['goal_generation_kwargs'])
                image_env.set_presampled_goals(presampled_goals)
                print("Pre sampling image goals only")
            else:
                print("Not using presampled goals")

        env = vae_env
    return env
Example #22
    def __init__(self, filename):
        self.data = load_local_or_remote_file(filename)
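A self-contained sketch of the pattern this fragment illustrates; the class and method names are illustrative only, not from the source:

# Hypothetical expansion of the fragment above: cache the loaded file once and
# expose simple indexed access.
class LoadedDataset:
    def __init__(self, filename):
        # filename may be a local path or a remote URI understood by rlkit's io helper
        self.data = load_local_or_remote_file(filename)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx]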
Example #23
def generate_vae_dataset(variant):
    env_class = variant.get('env_class', None)
    env_kwargs = variant.get('env_kwargs', None)
    env_id = variant.get('env_id', None)
    N = variant.get('N', 10000)
    test_p = variant.get('test_p', 0.9)
    use_cached = variant.get('use_cached', True)
    imsize = variant.get('imsize', 84)
    num_channels = variant.get('num_channels', 3)
    show = variant.get('show', False)
    init_camera = variant.get('init_camera', None)
    dataset_path = variant.get('dataset_path', None)
    oracle_dataset_using_set_to_goal = variant.get(
        'oracle_dataset_using_set_to_goal', False)
    random_rollout_data = variant.get('random_rollout_data', False)
    random_and_oracle_policy_data = variant.get(
        'random_and_oracle_policy_data', False)
    random_and_oracle_policy_data_split = variant.get(
        'random_and_oracle_policy_data_split', 0)
    policy_file = variant.get('policy_file', None)
    n_random_steps = variant.get('n_random_steps', 100)
    vae_dataset_specific_env_kwargs = variant.get(
        'vae_dataset_specific_env_kwargs', None)
    save_file_prefix = variant.get('save_file_prefix', None)
    non_presampled_goal_img_is_garbage = variant.get(
        'non_presampled_goal_img_is_garbage', None)
    tag = variant.get('tag', '')
    from multiworld.core.image_env import ImageEnv, unormalize_image
    import rlkit.torch.pytorch_util as ptu
    info = {}
    if dataset_path is not None:
        dataset = load_local_or_remote_file(dataset_path)
        N = dataset.shape[0]
    else:
        if env_kwargs is None:
            env_kwargs = {}
        if save_file_prefix is None:
            save_file_prefix = env_id
        if save_file_prefix is None:
            save_file_prefix = env_class.__name__
        filename = "/tmp/{}_N{}_{}_imsize{}_random_oracle_split_{}{}.npy".format(
            save_file_prefix,
            str(N),
            init_camera.__name__ if init_camera else '',
            imsize,
            random_and_oracle_policy_data_split,
            tag,
        )
        if use_cached and osp.isfile(filename):
            dataset = np.load(filename)
            print("loaded data from saved file", filename)
        else:
            now = time.time()

            if env_id is not None:
                import gym
                import multiworld
                multiworld.register_all_envs()
                env = gym.make(env_id)
            else:
                if vae_dataset_specific_env_kwargs is None:
                    vae_dataset_specific_env_kwargs = {}
                for key, val in env_kwargs.items():
                    if key not in vae_dataset_specific_env_kwargs:
                        vae_dataset_specific_env_kwargs[key] = val
                env = env_class(**vae_dataset_specific_env_kwargs)
            if not isinstance(env, ImageEnv):
                env = ImageEnv(
                    env,
                    imsize,
                    init_camera=init_camera,
                    transpose=True,
                    normalize=True,
                    non_presampled_goal_img_is_garbage=
                    non_presampled_goal_img_is_garbage,
                )
            else:
                imsize = env.imsize
                env.non_presampled_goal_img_is_garbage = non_presampled_goal_img_is_garbage
            env.reset()
            info['env'] = env
            if random_and_oracle_policy_data:
                policy_file = load_local_or_remote_file(policy_file)
                policy = policy_file['policy']
                policy.to(ptu.device)
            if random_rollout_data:
                from rlkit.exploration_strategies.ou_strategy import OUStrategy
                policy = OUStrategy(env.action_space)
            dataset = np.zeros((N, imsize * imsize * num_channels),
                               dtype=np.uint8)
            for i in range(N):
                if random_and_oracle_policy_data:
                    num_random_steps = int(N *
                                           random_and_oracle_policy_data_split)
                    if i < num_random_steps:
                        env.reset()
                        for _ in range(n_random_steps):
                            obs = env.step(env.action_space.sample())[0]
                    else:
                        obs = env.reset()
                        policy.reset()
                        for _ in range(n_random_steps):
                            policy_obs = np.hstack((
                                obs['state_observation'],
                                obs['state_desired_goal'],
                            ))
                            action, _ = policy.get_action(policy_obs)
                            obs, _, _, _ = env.step(action)
                elif oracle_dataset_using_set_to_goal:
                    print(i)
                    goal = env.sample_goal()
                    env.set_to_goal(goal)
                    obs = env._get_obs()
                elif random_rollout_data:
                    if i % n_random_steps == 0:
                        g = dict(
                            state_desired_goal=env.sample_goal_for_rollout())
                        env.set_to_goal(g)
                        policy.reset()
                        # env.reset()
                    u = policy.get_action_from_raw_action(
                        env.action_space.sample())
                    obs = env.step(u)[0]
                else:
                    env.reset()
                    for _ in range(n_random_steps):
                        obs = env.step(env.action_space.sample())[0]
                img = obs['image_observation']
                dataset[i, :] = unormalize_image(img)
                if show:
                    img = img.reshape(3, imsize, imsize).transpose()
                    img = img[::-1, :, ::-1]
                    cv2.imshow('img', img)
                    cv2.waitKey(1)
                    # radius = input('waiting...')
            print("done making training data", filename, time.time() - now)
            np.save(filename, dataset)

    n = int(N * test_p)
    train_dataset = dataset[:n, :]
    test_dataset = dataset[n:, :]
    return train_dataset, test_dataset, info
Example #24
def getdata(variant):
    skewfit_variant = variant['skewfit_variant']
    print('-------------------------------')
    skewfit_preprocess_variant(skewfit_variant)
    skewfit_variant['render'] = True
    vae_environment = get_envs(skewfit_variant)
    print('done loading vae_env')

    env_class = variant.get('env_class', None)
    env_kwargs = variant.get('env_kwargs', None)
    env_id = variant.get('env_id', None)
    N = variant.get('N', 10000)
    test_p = variant.get('test_p', 0.9)
    use_cached = variant.get('use_cached', True)
    imsize = variant.get('imsize', 84)
    num_channels = variant.get('num_channels', 3)
    show = variant.get('show', False)
    init_camera = variant.get('init_camera', None)
    dataset_path = variant.get('dataset_path', None)
    oracle_dataset_using_set_to_goal = variant.get(
        'oracle_dataset_using_set_to_goal', False)
    random_rollout_data = variant.get('random_rollout_data', False)
    random_and_oracle_policy_data = variant.get(
        'random_and_oracle_policy_data', False)
    random_and_oracle_policy_data_split = variant.get(
        'random_and_oracle_policy_data_split', 0)
    policy_file = variant.get('policy_file', None)
    n_random_steps = variant.get('n_random_steps', 100)
    vae_dataset_specific_env_kwargs = variant.get(
        'vae_dataset_specific_env_kwargs', None)
    save_file_prefix = variant.get('save_file_prefix', None)
    non_presampled_goal_img_is_garbage = variant.get(
        'non_presampled_goal_img_is_garbage', None)
    tag = variant.get('tag', '')
    from multiworld.core.image_env import ImageEnv, unormalize_image
    import rlkit.torch.pytorch_util as ptu
    info = {}
    if dataset_path is not None:
        dataset = load_local_or_remote_file(dataset_path)
        N = dataset.shape[0]
    else:
        if env_kwargs is None:
            env_kwargs = {}
        if save_file_prefix is None:
            save_file_prefix = env_id
        if save_file_prefix is None:
            save_file_prefix = env_class.__name__
        filename = "/tmp/{}_N{}_{}_imsize{}_random_oracle_split_{}{}.npy".format(
            save_file_prefix,
            str(N),
            init_camera.__name__ if init_camera else '',
            imsize,
            random_and_oracle_policy_data_split,
            tag,
        )
        if True:
            now = time.time()

            if env_id is not None:
                import gym
                import multiworld
                multiworld.register_all_envs()
                env = gym.make(env_id)
            else:
                if vae_dataset_specific_env_kwargs is None:
                    vae_dataset_specific_env_kwargs = {}
                for key, val in env_kwargs.items():
                    if key not in vae_dataset_specific_env_kwargs:
                        vae_dataset_specific_env_kwargs[key] = val
                env = env_class(**vae_dataset_specific_env_kwargs)
            if not isinstance(env, ImageEnv):
                print("using(ImageEnv)")
                env = ImageEnv(
                    env,
                    imsize,
                    init_camera=init_camera,
                    transpose=True,
                    normalize=True,
                    non_presampled_goal_img_is_garbage=
                    non_presampled_goal_img_is_garbage,
                )
            else:
                imsize = env.imsize
                env.non_presampled_goal_img_is_garbage = non_presampled_goal_img_is_garbage
            env.reset()
            info['env'] = env
            if random_and_oracle_policy_data:
                policy_file = load_local_or_remote_file(policy_file)
                policy = policy_file['policy']
                policy.to(ptu.device)
            if random_rollout_data:
                from rlkit.exploration_strategies.ou_strategy import OUStrategy
                policy = OUStrategy(env.action_space)
            dataset = np.zeros((N, imsize * imsize * num_channels),
                               dtype=np.uint8)

            for i in range(10):
                NP = []
                if True:
                    print(i)
                    #print('th step')
                    goal = env.sample_goal()
                    # print("goal___________________________")
                    # print(goal)
                    # print("goal___________________________")
                    env.set_to_goal(goal)
                    obs = env._get_obs()
                    #img = img.reshape(3, imsize, imsize).transpose()
                    # img = img[::-1, :, ::-1]
                    # cv2.imshow('img', img)
                    # cv2.waitKey(1)
                    img_1 = obs['image_observation']
                    img_1 = img_1.reshape(3, imsize, imsize).transpose()
                    NP.append(img_1)
                    if i % 3 == 0:
                        cv2.imshow('img1', img_1)
                        cv2.waitKey(1)
                    #img_1_reconstruct = vae_environment._reconstruct_img(obs['image_observation']).transpose()
                    encoded_1 = vae_environment._get_encoded(
                        obs['image_observation'])
                    print(encoded_1)
                    NP.append(encoded_1)
                    img_1_reconstruct = vae_environment._get_img(
                        encoded_1).transpose()
                    NP.append(img_1_reconstruct)
                    #dataset[i, :] = unormalize_image(img)
                    # img_1 = img_1.reshape(3, imsize, imsize).transpose()
                    if i % 3 == 0:
                        cv2.imshow('img1_reconstruction', img_1_reconstruct)
                        cv2.waitKey(1)
                    env.reset()
                    instr = env.generate_new_state(goal)
                    if i % 3 == 0:
                        print(instr)
                    obs = env._get_obs()
                    # obs = env._get_obs()
                    img_2 = obs['image_observation']
                    img_2 = img_2.reshape(3, imsize, imsize).transpose()
                    NP.append(img_2)
                    if i % 3 == 0:
                        cv2.imshow('img2', img_2)
                        cv2.waitKey(1)
                    #img_2_reconstruct = vae_environment._reconstruct_img(obs['image_observation']).transpose()
                    encoded_2 = vae_environment._get_encoded(
                        obs['image_observation'])
                    NP.append(encoded_2)
                    img_2_reconstruct = vae_environment._get_img(
                        encoded_2).transpose()
                    NP.append(img_2_reconstruct)
                    NP.append(instr)
                    # img_2 = img_2.reshape(3, imsize, imsize).transpose()
                    if i % 3 == 0:
                        cv2.imshow('img2_reconstruct', img_2_reconstruct)
                        cv2.waitKey(1)
                    NP = np.array(NP)
                    idx = str(i)
                    name = "/home/xiaomin/Downloads/IFIG_DATA_1/" + idx + ".npy"
                    np.save(open(name, 'wb'), NP)
                    # radius = input('waiting...')

                # #get the in between functions
            import dill
            import pickle
            get_encoded = dill.dumps(vae_environment._get_encoded)
            with open(
                    "/home/xiaomin/Downloads/IFIG_encoder_decoder/get_encoded_1000_epochs_one_puck.txt",
                    "wb") as fp:
                pickle.dump(get_encoded, fp)
            with open(
                    "/home/xiaomin/Downloads/IFIG_encoder_decoder/get_encoded_1000_epochs_one_puck.txt",
                    "rb") as fp:
                b = pickle.load(fp)
            func_get_encoded = dill.loads(b)
            encoded = func_get_encoded(obs['image_observation'])
            print(encoded)
            print('------------------------------')
            get_img = dill.dumps(vae_environment._get_img)
            with open(
                    "/home/xiaomin/Downloads/IFIG_encoder_decoder/get_img_1000_epochs_one_puck.txt",
                    "wb") as fp:
                pickle.dump(get_img, fp)
            with open(
                    "/home/xiaomin/Downloads/IFIG_encoder_decoder/get_img_1000_epochs_one_puck.txt",
                    "rb") as fp:
                c = pickle.load(fp)
            func_get_img = dill.loads(c)

            img_1_reconstruct = func_get_img(encoded).transpose()
            print(img_1_reconstruct)
            #dataset[i, :] = unormalize_image(img)
            # img_1 = img_1.reshape(3, imsize, imsize).transpose()
            cv2.imshow('test', img_1_reconstruct)
            cv2.waitKey(0)

            print("done making training data", filename, time.time() - now)
            np.save(filename, dataset)
Example #25
def train_vae_and_update_variant(variant):  # actually pretrains the VAE and ROLL
    skewfit_variant = variant['skewfit_variant']
    train_vae_variant = variant['train_vae_variant']

    # prepare the background subtractor needed to perform segmentation
    if 'unet' in skewfit_variant['segmentation_method']:
        print("training opencv background model!")
        v = train_vae_variant['generate_lstm_dataset_kwargs']
        env_id = v.get('env_id', None)
        env_id_invis = invisiable_env_id[env_id]
        import gym
        import multiworld
        multiworld.register_all_envs()
        obj_invisible_env = gym.make(env_id_invis)
        init_camera = v.get('init_camera', None)

        presampled_goals = None
        if skewfit_variant.get("presampled_goals_path") is not None:
            presampled_goals = load_local_or_remote_file(
                skewfit_variant['presampled_goals_path']).item()
            print("presampled goal path is: ",
                  skewfit_variant['presampled_goals_path'])

        obj_invisible_env = ImageEnv(
            obj_invisible_env,
            v.get('imsize'),
            init_camera=init_camera,
            transpose=True,
            normalize=True,
            presampled_goals=presampled_goals,
        )

        train_num = 2000 if 'Push' in env_id else 4000
        train_bgsb(obj_invisible_env, train_num=train_num)

    if skewfit_variant.get('vae_path', None) is None:  # train new vaes
        logger.remove_tabular_output('progress.csv',
                                     relative_to_snapshot_dir=True)

        vaes, vae_train_datas, vae_test_datas = train_vae(
            train_vae_variant,
            skewfit_variant=skewfit_variant,
            return_data=True)  # one original vae, one segmented ROLL.
        if skewfit_variant.get('save_vae_data', False):
            skewfit_variant['vae_train_data'] = vae_train_datas
            skewfit_variant['vae_test_data'] = vae_test_datas

        logger.add_tabular_output(
            'progress.csv',
            relative_to_snapshot_dir=True,
        )
        skewfit_variant['vae_path'] = vaes  # just pass the VAE directly
    else:  # load pre-trained vaes
        print("load pretrain scene-/objce-VAE from: {}".format(
            skewfit_variant['vae_path']))
        data = torch.load(osp.join(skewfit_variant['vae_path'], 'params.pkl'))
        vae_original = data['vae_original']
        vae_segmented = data['lstm_segmented']
        skewfit_variant['vae_path'] = [vae_segmented, vae_original]

        generate_vae_dataset_fctn = train_vae_variant.get(
            'generate_vae_data_fctn', generate_vae_dataset)
        generate_lstm_dataset_fctn = train_vae_variant.get(
            'generate_lstm_data_fctn')
        assert generate_lstm_dataset_fctn is not None, "Must provide a custom generate lstm pretraining dataset function!"

        train_data_lstm, test_data_lstm, info_lstm = generate_lstm_dataset_fctn(
            train_vae_variant['generate_lstm_dataset_kwargs'],
            segmented=True,
            segmentation_method=skewfit_variant['segmentation_method'])

        train_data_ori, test_data_ori, info_ori = generate_vae_dataset_fctn(
            train_vae_variant['generate_vae_dataset_kwargs'])

        train_datas = [train_data_lstm, train_data_ori]
        test_datas = [test_data_lstm, test_data_ori]

        if skewfit_variant.get('save_vae_data', False):
            skewfit_variant['vae_train_data'] = train_datas
            skewfit_variant['vae_test_data'] = test_datas
Example #26
    def __init__(
        self,
        wrapped_env,
        vae,
        vae_input_key_prefix='image',
        sample_from_true_prior=False,
        decode_goals=False,
        decode_goals_on_reset=True,
        render_goals=False,
        render_rollouts=False,
        reward_params=None,
        goal_sampling_mode="vae_prior",
        presampled_goals_path=None,
        num_goals_to_presample=0,
        imsize=84,
        obs_size=None,
        norm_order=2,
        epsilon=20,
        presampled_goals=None,
    ):
        if reward_params is None:
            reward_params = dict()
        super().__init__(wrapped_env)
        if type(vae) is str:
            self.vae = load_local_or_remote_file(vae)
        else:
            self.vae = vae
        self.representation_size = self.vae.representation_size
        self.input_channels = self.vae.input_channels
        self.sample_from_true_prior = sample_from_true_prior
        self._decode_goals = decode_goals
        self.render_goals = render_goals
        self.render_rollouts = render_rollouts
        self.default_kwargs = dict(
            decode_goals=decode_goals,
            render_goals=render_goals,
            render_rollouts=render_rollouts,
        )
        self.imsize = imsize
        self.reward_params = reward_params
        self.reward_type = self.reward_params.get("type", 'latent_distance')
        self.norm_order = self.reward_params.get("norm_order", norm_order)
        self.epsilon = self.reward_params.get("epsilon", epsilon)
        self.reward_min_variance = self.reward_params.get("min_variance", 0)
        self.decode_goals_on_reset = decode_goals_on_reset

        latent_space = Box(
            -10 * np.ones(obs_size or self.representation_size),
            10 * np.ones(obs_size or self.representation_size),
            dtype=np.float32,
        )

        spaces = self.wrapped_env.observation_space.spaces
        spaces['observation'] = latent_space
        spaces['desired_goal'] = latent_space
        spaces['achieved_goal'] = latent_space
        spaces['latent_observation'] = latent_space
        spaces['latent_desired_goal'] = latent_space
        spaces['latent_achieved_goal'] = latent_space
        self.observation_space = Dict(spaces)

        self._presampled_goals = presampled_goals

        if num_goals_to_presample > 0:
            self._presampled_goals = self.wrapped_env.sample_goals(
                num_goals_to_presample)
        if presampled_goals_path is not None:
            self._presampled_goals = load_local_or_remote_file(
                presampled_goals_path)

        if self._presampled_goals is None:
            self.num_goals_presampled = 0
        else:
            self.num_goals_presampled = self._presampled_goals[random.choice(
                list(self._presampled_goals))].shape[0]

        self.vae_input_key_prefix = vae_input_key_prefix
        assert vae_input_key_prefix in {'image', 'image_proprio'}
        self.vae_input_observation_key = vae_input_key_prefix + '_observation'
        self.vae_input_achieved_goal_key = vae_input_key_prefix + '_achieved_goal'
        self.vae_input_desired_goal_key = vae_input_key_prefix + '_desired_goal'
        self._mode_map = {}
        self.desired_goal = {'latent_desired_goal': latent_space.sample()}
        self._initial_obs = None
        self._custom_goal_sampler = None
        self._goal_sampling_mode = goal_sampling_mode
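A construction sketch for this wrapper, assuming an already-built ImageEnv instance (image_env) and a VAE checkpoint path; both are placeholders, and the keyword names follow the signature above:

# Hypothetical construction of the VAEWrappedEnv defined above.
vae = load_local_or_remote_file('path/to/vae.pkl')  # or pass a VAE object directly
env = VAEWrappedEnv(
    image_env,                         # placeholder: an ImageEnv-wrapped gym env
    vae,
    imsize=image_env.imsize,
    reward_params=dict(type='latent_distance'),
    goal_sampling_mode='vae_prior',
)
obs = env.reset()                      # obs should include the latent_* keys declared above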
Example #27
def get_envs(variant):

    render = variant.get('render', False)
    vae_path = variant.get("vae_path", None)
    reward_params = variant.get("reward_params", dict())
    init_camera = variant.get("init_camera", None)
    presample_goals = variant.get('presample_goals', False)
    presample_image_goals_only = variant.get('presample_image_goals_only',
                                             False)
    presampled_goals_path = variant.get('presampled_goals_path', None)

    vae = load_local_or_remote_file(
        vae_path) if type(vae_path) is str else vae_path
    if 'env_id' in variant:
        env = gym.make(variant['env_id'])
    else:
        env = variant["env_class"](**variant['env_kwargs'])

    if isinstance(env, ImageEnv):
        image_env = env
    else:
        image_env = ImageEnv(
            env,
            variant.get('imsize'),
            init_camera=init_camera,
            transpose=True,
            normalize=True,
        )
    if presample_goals:
        """
        This will fail for online-parallel as presampled_goals will not be
        serialized. Also don't use this for online-vae.
        """
        if presampled_goals_path is None:
            image_env.non_presampled_goal_img_is_garbage = True
            vae_env = VAEWrappedEnv(image_env,
                                    vae,
                                    imsize=image_env.imsize,
                                    decode_goals=render,
                                    render_goals=render,
                                    render_rollouts=render,
                                    reward_params=reward_params,
                                    **variant.get('vae_wrapped_env_kwargs',
                                                  {}))
            presampled_goals = variant['generate_goal_dataset_fctn'](
                env=vae_env,
                env_id=variant.get('env_id', None),
                **variant['goal_generation_kwargs'])
            del vae_env
        else:
            presampled_goals = load_local_or_remote_file(
                presampled_goals_path).item()
        del image_env
        image_env = ImageEnv(env,
                             variant.get('imsize'),
                             init_camera=init_camera,
                             transpose=True,
                             normalize=True,
                             presampled_goals=presampled_goals,
                             **variant.get('image_env_kwargs', {}))
        vae_env = VAEWrappedEnv(image_env,
                                vae,
                                imsize=image_env.imsize,
                                decode_goals=render,
                                render_goals=render,
                                render_rollouts=render,
                                reward_params=reward_params,
                                presampled_goals=presampled_goals,
                                **variant.get('vae_wrapped_env_kwargs', {}))
        print("Presampling all goals only")
    else:
        vae_env = VAEWrappedEnv(image_env,
                                vae,
                                imsize=image_env.imsize,
                                decode_goals=render,
                                render_goals=render,
                                render_rollouts=render,
                                reward_params=reward_params,
                                **variant.get('vae_wrapped_env_kwargs', {}))
        if presample_image_goals_only:
            presampled_goals = variant['generate_goal_dataset_fctn'](
                image_env=vae_env.wrapped_env,
                **variant['goal_generation_kwargs'])
            image_env.set_presampled_goals(presampled_goals)
            print("Presampling image goals only")
        else:
            print("Not using presampled goals")

    env = vae_env

    training_mode = variant.get("training_mode", "train")
    testing_mode = variant.get("testing_mode", "test")
    env.add_mode('eval', testing_mode)
    env.add_mode('train', training_mode)
    env.add_mode('relabeling', training_mode)
    # relabeling_env.disable_render()
    env.add_mode("video_vae", 'video_vae')
    env.add_mode("video_env", 'video_env')
    return env
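A sketch of a variant this get_envs() might consume, assuming a registered env id; the keys mirror the variant.get() lookups above and all values are placeholders:

# Hypothetical variant for the get_envs(variant) above.
variant = dict(
    env_id='SawyerPushNIPSEasy-v0',
    imsize=48,
    init_camera=None,
    render=False,
    vae_path='path/to/vae.pkl',        # or a VAE object
    reward_params=dict(type='latent_distance'),
    presample_goals=False,
    presample_image_goals_only=False,
    presampled_goals_path=None,
    vae_wrapped_env_kwargs=dict(),
    training_mode='train',
    testing_mode='test',
)
env = get_envs(variant)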
Example #28
    def __init__(
            self,
            wrapped_env,
            vae,
            vae_input_key_prefix='image',
            use_vae_goals=True,
            sample_from_true_prior=False,
            decode_goals=False,
            render_goals=False,
            render_rollouts=False,
            reward_params=None,
            mode="train",
            imsize=84,
            obs_size=None,
            epsilon=20,
            presampled_goals=None,
    ):
        self.quick_init(locals())
        if reward_params is None:
            reward_params = dict()
        super().__init__(wrapped_env)
        if type(vae) is str:
            self.vae = load_local_or_remote_file(vae)
        else:
            self.vae = vae
        self.representation_size = self.vae.representation_size
        self.input_channels = self.vae.input_channels
        self._use_vae_goals = use_vae_goals
        self.sample_from_true_prior = sample_from_true_prior
        self.decode_goals = decode_goals
        self.render_goals = render_goals
        self.render_rollouts = render_rollouts
        self.default_kwargs = dict(
            decode_goals=decode_goals,
            render_goals=render_goals,
            render_rollouts=render_rollouts,
        )
        self.imsize = imsize
        self.reward_params = reward_params
        self.reward_type = self.reward_params.get("type", 'latent_distance')
        self.epsilon = self.reward_params.get("epsilon", epsilon)
        self.reward_min_variance = self.reward_params.get("min_variance", 0)
        latent_space = Box(
            -10 * np.ones(obs_size or self.representation_size),
            10 * np.ones(obs_size or self.representation_size),
            dtype=np.float32,
        )
        spaces = copy.deepcopy(self.wrapped_env.observation_space.spaces)
        spaces['observation'] = latent_space
        spaces['desired_goal'] = latent_space
        spaces['achieved_goal'] = latent_space
        spaces['latent_observation'] = latent_space
        spaces['latent_desired_goal'] = latent_space
        spaces['latent_achieved_goal'] = latent_space
        self.observation_space = Dict(spaces)
        self.mode(mode)
        self._presampled_goals = presampled_goals
        if self._presampled_goals is None:
            self.num_goals_presampled = 0
        else:
            self.num_goals_presampled = (
                presampled_goals[list(presampled_goals)[0]].shape[0]
            )

        self.vae_input_key_prefix = vae_input_key_prefix
        assert vae_input_key_prefix in {'image'}
        self.vae_input_observation_key = vae_input_key_prefix + '_observation'
        self.vae_input_achieved_goal_key = vae_input_key_prefix + '_achieved_goal'
        self.vae_input_desired_goal_key = vae_input_key_prefix + '_desired_goal'
        self._mode_map = {}
        self.desired_goal = {'latent_desired_goal': latent_space.sample()}
        self._initial_obs = None
Example #29
def get_envs(variant):
    from multiworld.core.image_env import ImageEnv
    from rlkit.envs.vae_wrapper import VAEWrappedEnv
    from rlkit.util.io import load_local_or_remote_file

    render = variant.get('render', False)
    vae_path = variant.get("vae_path", None)
    reward_params = variant.get("reward_params", dict())
    init_camera = variant.get("init_camera", None)
    do_state_exp = variant.get("do_state_exp", False)
    presample_goals = variant.get('presample_goals', False)
    presample_image_goals_only = variant.get('presample_image_goals_only',
                                             False)
    presampled_goals_path = variant.get('presampled_goals_path', None)

    vae = load_local_or_remote_file(
        vae_path) if type(vae_path) is str else vae_path
    if 'env_id' in variant:
        import gym
        import multiworld
        multiworld.register_all_envs()
        env = gym.make(variant['env_id'])
    else:
        env = variant["env_class"](**variant['env_kwargs'])
    if not do_state_exp:
        if isinstance(env, ImageEnv):
            image_env = env
        else:
            image_env = ImageEnv(
                env,
                variant.get('imsize'),
                init_camera=init_camera,
                transpose=True,
                normalize=True,
            )
        if presample_goals:
            """
            This will fail for online-parallel as presampled_goals will not be
            serialized. Also don't use this for online-vae.
            """
            if presampled_goals_path is None:
                image_env.non_presampled_goal_img_is_garbage = True
                vae_env = VAEWrappedEnv(image_env,
                                        vae,
                                        imsize=image_env.imsize,
                                        decode_goals=render,
                                        render_goals=render,
                                        render_rollouts=render,
                                        reward_params=reward_params,
                                        **variant.get('vae_wrapped_env_kwargs',
                                                      {}))
                presampled_goals = variant['generate_goal_dataset_fctn'](
                    env=vae_env,
                    env_id=variant.get('env_id', None),
                    **variant['goal_generation_kwargs'])
                del vae_env
            else:
                presampled_goals = load_local_or_remote_file(
                    presampled_goals_path).item()
            del image_env
            image_env = ImageEnv(env,
                                 variant.get('imsize'),
                                 init_camera=init_camera,
                                 transpose=True,
                                 normalize=True,
                                 presampled_goals=presampled_goals,
                                 **variant.get('image_env_kwargs', {}))
            vae_env = VAEWrappedEnv(image_env,
                                    vae,
                                    imsize=image_env.imsize,
                                    decode_goals=render,
                                    render_goals=render,
                                    render_rollouts=render,
                                    reward_params=reward_params,
                                    presampled_goals=presampled_goals,
                                    **variant.get('vae_wrapped_env_kwargs',
                                                  {}))
            print("Presampling all goals only")
        else:
            vae_env = VAEWrappedEnv(image_env,
                                    vae,
                                    imsize=image_env.imsize,
                                    decode_goals=render,
                                    render_goals=render,
                                    render_rollouts=render,
                                    reward_params=reward_params,
                                    **variant.get('vae_wrapped_env_kwargs',
                                                  {}))
            if presample_image_goals_only:
                presampled_goals = variant['generate_goal_dataset_fctn'](
                    image_env=vae_env.wrapped_env,
                    **variant['goal_generation_kwargs'])
                image_env.set_presampled_goals(presampled_goals)
                print("Presampling image goals only")
            else:
                print("Not using presampled goals")

        env = vae_env

    return env
Example #30
File: common.py  Project: anair13/rlkit
def generate_vae_dataset(variant):
    print(variant)
    from tqdm import tqdm
    env_class = variant.get('env_class', None)
    env_kwargs = variant.get('env_kwargs', None)
    env_id = variant.get('env_id', None)
    N = variant.get('N', 10000)
    batch_size = variant.get('batch_size', 128)
    test_p = variant.get('test_p', 0.9)
    use_cached = variant.get('use_cached', True)
    imsize = variant.get('imsize', 84)
    num_channels = variant.get('num_channels', 3)
    show = variant.get('show', False)
    init_camera = variant.get('init_camera', None)
    dataset_path = variant.get('dataset_path', None)
    augment_data = variant.get('augment_data', False)
    data_filter_fn = variant.get('data_filter_fn', lambda x: x)
    delete_after_loading = variant.get('delete_after_loading', False)
    oracle_dataset_using_set_to_goal = variant.get(
        'oracle_dataset_using_set_to_goal', False)
    random_rollout_data = variant.get('random_rollout_data', False)
    random_rollout_data_set_to_goal = variant.get(
        'random_rollout_data_set_to_goal', True)
    random_and_oracle_policy_data = variant.get(
        'random_and_oracle_policy_data', False)
    random_and_oracle_policy_data_split = variant.get(
        'random_and_oracle_policy_data_split', 0)
    policy_file = variant.get('policy_file', None)
    n_random_steps = variant.get('n_random_steps', 100)
    vae_dataset_specific_env_kwargs = variant.get(
        'vae_dataset_specific_env_kwargs', None)
    save_file_prefix = variant.get('save_file_prefix', None)
    non_presampled_goal_img_is_garbage = variant.get(
        'non_presampled_goal_img_is_garbage', None)

    conditional_vae_dataset = variant.get('conditional_vae_dataset', False)
    use_env_labels = variant.get('use_env_labels', False)
    use_linear_dynamics = variant.get('use_linear_dynamics', False)
    enviorment_dataset = variant.get('enviorment_dataset', False)
    save_trajectories = variant.get('save_trajectories', False)
    save_trajectories = save_trajectories or use_linear_dynamics or conditional_vae_dataset
    tag = variant.get('tag', '')

    assert N % n_random_steps == 0, "Fix N/horizon or dataset generation will fail"

    from multiworld.core.image_env import ImageEnv, unormalize_image
    import rlkit.torch.pytorch_util as ptu
    from rlkit.util.io import load_local_or_remote_file
    from rlkit.data_management.dataset import (
        TrajectoryDataset, ImageObservationDataset, InitialObservationDataset,
        EnvironmentDataset, ConditionalDynamicsDataset,
        InitialObservationNumpyDataset, InfiniteBatchLoader,
        InitialObservationNumpyJitteringDataset)
    info = {}
    use_test_dataset = False
    if dataset_path is not None:
        if type(dataset_path) == str:
            dataset = load_local_or_remote_file(
                dataset_path, delete_after_loading=delete_after_loading)
            dataset = dataset.item()
            N = dataset['observations'].shape[0] * dataset[
                'observations'].shape[1]
            n_random_steps = dataset['observations'].shape[1]
        if isinstance(dataset_path, list):
            dataset = concatenate_datasets(dataset_path)
            N = dataset['observations'].shape[0] * dataset[
                'observations'].shape[1]
            n_random_steps = dataset['observations'].shape[1]
        if isinstance(dataset_path, dict):

            if type(dataset_path['train']) == str:
                dataset = load_local_or_remote_file(
                    dataset_path['train'],
                    delete_after_loading=delete_after_loading)
                dataset = dataset.item()
            elif isinstance(dataset_path['train'], list):
                dataset = concatenate_datasets(dataset_path['train'])

            if type(dataset_path['test']) == str:
                test_dataset = load_local_or_remote_file(
                    dataset_path['test'],
                    delete_after_loading=delete_after_loading)
                test_dataset = test_dataset.item()
            elif isinstance(dataset_path['test'], list):
                test_dataset = concatenate_datasets(dataset_path['test'])

            N = dataset['observations'].shape[0] * dataset[
                'observations'].shape[1]
            n_random_steps = dataset['observations'].shape[1]
            use_test_dataset = True
    else:
        if env_kwargs is None:
            env_kwargs = {}
        if save_file_prefix is None:
            save_file_prefix = env_id
        if save_file_prefix is None:
            save_file_prefix = env_class.__name__
        filename = "/tmp/{}_N{}_{}_imsize{}_random_oracle_split_{}{}.npy".format(
            save_file_prefix,
            str(N),
            init_camera.__name__
            if init_camera and hasattr(init_camera, '__name__') else '',
            imsize,
            random_and_oracle_policy_data_split,
            tag,
        )
        if use_cached and osp.isfile(filename):
            dataset = load_local_or_remote_file(
                filename, delete_after_loading=delete_after_loading)
            if conditional_vae_dataset:
                dataset = dataset.item()
            print("loaded data from saved file", filename)
        else:
            now = time.time()
            if env_id is not None:
                import gym
                import multiworld
                multiworld.register_all_envs()
                env = gym.make(env_id)
            else:
                if vae_dataset_specific_env_kwargs is None:
                    vae_dataset_specific_env_kwargs = {}
                for key, val in env_kwargs.items():
                    if key not in vae_dataset_specific_env_kwargs:
                        vae_dataset_specific_env_kwargs[key] = val
                env = env_class(**vae_dataset_specific_env_kwargs)
            if not isinstance(env, ImageEnv):
                env = ImageEnv(
                    env,
                    imsize,
                    init_camera=init_camera,
                    transpose=True,
                    normalize=True,
                    non_presampled_goal_img_is_garbage=
                    non_presampled_goal_img_is_garbage,
                )
            else:
                imsize = env.imsize
                env.non_presampled_goal_img_is_garbage = non_presampled_goal_img_is_garbage
            env.reset()
            info['env'] = env
            if random_and_oracle_policy_data:
                policy_file = load_local_or_remote_file(policy_file)
                policy = policy_file['policy']
                policy.to(ptu.device)
            if random_rollout_data:
                from rlkit.exploration_strategies.ou_strategy import OUStrategy
                policy = OUStrategy(env.action_space)

            if save_trajectories:
                dataset = {
                    'observations':
                    np.zeros((N // n_random_steps, n_random_steps,
                              imsize * imsize * num_channels),
                             dtype=np.uint8),
                    'actions':
                    np.zeros((N // n_random_steps, n_random_steps,
                              env.action_space.shape[0]),
                             dtype=np.float64),
                    'env':
                    np.zeros(
                        (N // n_random_steps, imsize * imsize * num_channels),
                        dtype=np.uint8),
                }
            else:
                dataset = np.zeros((N, imsize * imsize * num_channels),
                                   dtype=np.uint8)
            labels = []
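            # Collect N frames (grouped into N // n_random_steps trajectories
            # when save_trajectories is set), using one of the data-collection
            # branches below for each step.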
            for i in tqdm(range(N)):
                if random_and_oracle_policy_data:
                    num_random_steps = int(N *
                                           random_and_oracle_policy_data_split)
                    if i < num_random_steps:
                        env.reset()
                        for _ in range(n_random_steps):
                            obs = env.step(env.action_space.sample())[0]
                    else:
                        obs = env.reset()
                        policy.reset()
                        for _ in range(n_random_steps):
                            policy_obs = np.hstack((
                                obs['state_observation'],
                                obs['state_desired_goal'],
                            ))
                            action, _ = policy.get_action(policy_obs)
                            obs, _, _, _ = env.step(action)
                elif random_rollout_data:  #ADD DATA WHERE JUST PUCK MOVES
                    if i % n_random_steps == 0:
                        env.reset()
                        policy.reset()
                        env_img = env._get_obs()['image_observation']
                        if random_rollout_data_set_to_goal:
                            env.set_to_goal(env.get_goal())
                    obs = env._get_obs()
                    u = policy.get_action_from_raw_action(
                        env.action_space.sample())
                    env.step(u)
                elif oracle_dataset_using_set_to_goal:
                    print(i)

                    goal = env.sample_goal()
                    env.set_to_goal(goal)
                    obs = env._get_obs()
                else:
                    env.reset()
                    for _ in range(n_random_steps):
                        obs = env.step(env.action_space.sample())[0]

                img = obs['image_observation']
                if use_env_labels:
                    labels.append(obs['label'])
                if save_trajectories:
                    dataset['observations'][
                        i // n_random_steps,
                        i % n_random_steps, :] = unormalize_image(img)
                    dataset['actions'][i // n_random_steps,
                                       i % n_random_steps, :] = u
                    dataset['env'][i // n_random_steps, :] = unormalize_image(
                        env_img)
                else:
                    dataset[i, :] = unormalize_image(img)

                if show:
                    img = img.reshape(3, imsize, imsize).transpose()
                    img = img[::-1, :, ::-1]
                    cv2.imshow('img', img)
                    cv2.waitKey(1)
                    # radius = input('waiting...')
            print("done making training data", filename, time.time() - now)
            np.save(filename, dataset)
            #np.save(filename[:-4] + 'labels.npy', np.array(labels))

    info['train_labels'] = []
    info['test_labels'] = []

    dataset = data_filter_fn(dataset)
    if use_linear_dynamics and conditional_vae_dataset:
        num_trajectories = N // n_random_steps
        n = int(num_trajectories * test_p)
        train_dataset = ConditionalDynamicsDataset({
            'observations':
            dataset['observations'][:n, :, :],
            'actions':
            dataset['actions'][:n, :, :],
            'env':
            dataset['env'][:n, :]
        })
        test_dataset = ConditionalDynamicsDataset({
            'observations':
            dataset['observations'][n:, :, :],
            'actions':
            dataset['actions'][n:, :, :],
            'env':
            dataset['env'][n:, :]
        })

        num_trajectories = N // n_random_steps
        n = int(num_trajectories * test_p)
        indices = np.arange(num_trajectories)
        np.random.shuffle(indices)
        train_i, test_i = indices[:n], indices[n:]

        try:
            train_dataset = ConditionalDynamicsDataset({
                'observations':
                dataset['observations'][train_i, :, :],
                'actions':
                dataset['actions'][train_i, :, :],
                'env':
                dataset['env'][train_i, :]
            })
            test_dataset = ConditionalDynamicsDataset({
                'observations':
                dataset['observations'][test_i, :, :],
                'actions':
                dataset['actions'][test_i, :, :],
                'env':
                dataset['env'][test_i, :]
            })
        except KeyError:  # the dataset may not contain an 'env' entry
            train_dataset = ConditionalDynamicsDataset({
                'observations':
                dataset['observations'][train_i, :, :],
                'actions':
                dataset['actions'][train_i, :, :],
            })
            test_dataset = ConditionalDynamicsDataset({
                'observations':
                dataset['observations'][test_i, :, :],
                'actions':
                dataset['actions'][test_i, :, :],
            })
    elif use_linear_dynamics:
        num_trajectories = N // n_random_steps
        n = int(num_trajectories * test_p)
        train_dataset = TrajectoryDataset({
            'observations':
            dataset['observations'][:n, :, :],
            'actions':
            dataset['actions'][:n, :, :]
        })
        test_dataset = TrajectoryDataset({
            'observations':
            dataset['observations'][n:, :, :],
            'actions':
            dataset['actions'][n:, :, :]
        })
    elif enviorment_dataset:
        n = int(n_random_steps * test_p)
        train_dataset = EnvironmentDataset({
            'observations':
            dataset['observations'][:, :n, :],
        })
        test_dataset = EnvironmentDataset({
            'observations':
            dataset['observations'][:, n:, :],
        })
    elif conditional_vae_dataset:
        num_trajectories = N // n_random_steps
        n = int(num_trajectories * test_p)
        indices = np.arange(num_trajectories)
        np.random.shuffle(indices)
        train_i, test_i = indices[:n], indices[n:]

        if augment_data:
            dataset_class = InitialObservationNumpyJitteringDataset
        else:
            dataset_class = InitialObservationNumpyDataset

        if 'env' not in dataset:
            dataset['env'] = dataset['observations'][:, 0]
        if use_test_dataset and ('env' not in test_dataset):
            test_dataset['env'] = test_dataset['observations'][:, 0]

        if use_test_dataset:
            train_dataset = dataset_class({
                'observations':
                dataset['observations'],
                'env':
                dataset['env']
            })

            test_dataset = dataset_class({
                'observations':
                test_dataset['observations'],
                'env':
                test_dataset['env']
            })
        else:
            train_dataset = dataset_class({
                'observations':
                dataset['observations'][train_i, :, :],
                'env':
                dataset['env'][train_i, :]
            })

            test_dataset = dataset_class({
                'observations':
                dataset['observations'][test_i, :, :],
                'env':
                dataset['env'][test_i, :]
            })

        train_batch_loader_kwargs = variant.get(
            'train_batch_loader_kwargs',
            dict(
                batch_size=batch_size,
                num_workers=0,
            ))
        test_batch_loader_kwargs = variant.get(
            'test_batch_loader_kwargs',
            dict(
                batch_size=batch_size,
                num_workers=0,
            ))

        train_data_loader = data.DataLoader(train_dataset,
                                            shuffle=True,
                                            drop_last=True,
                                            **train_batch_loader_kwargs)
        test_data_loader = data.DataLoader(test_dataset,
                                           shuffle=True,
                                           drop_last=True,
                                           **test_batch_loader_kwargs)

        train_dataset = InfiniteBatchLoader(train_data_loader)
        test_dataset = InfiniteBatchLoader(test_data_loader)
    else:
        n = int(N * test_p)
        train_dataset = ImageObservationDataset(dataset[:n, :])
        test_dataset = ImageObservationDataset(dataset[n:, :])
    return train_dataset, test_dataset, info
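A sketch of a variant that exercises the conditional (trajectory) path of this function; the keys follow the variant.get() calls above, and the env id and sizes are placeholders:

# Hypothetical variant for the generate_vae_dataset(variant) above; every value
# is a placeholder. In this configuration save_trajectories is forced on, and
# the returned train/test objects are InfiniteBatchLoader wrappers.
variant = dict(
    env_id='SawyerPushNIPSEasy-v0',    # example multiworld env id
    N=10000,                           # must be a multiple of n_random_steps
    n_random_steps=100,
    test_p=0.9,
    imsize=48,
    num_channels=3,
    use_cached=False,
    show=False,
    random_rollout_data=True,
    random_rollout_data_set_to_goal=True,
    conditional_vae_dataset=True,
    batch_size=32,
)
train_dataset, test_dataset, info = generate_vae_dataset(variant)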