import os.path as osp

import cv2
import numpy as np

from multiworld.core.image_env import ImageEnv, unormalize_image
from railrl.misc.asset_loader import load_local_or_remote_file


def generate_uniform_dataset_reacher(env_class=None,
                                     env_kwargs=None,
                                     num_imgs=1000,
                                     use_cached_dataset=False,
                                     init_camera=None,
                                     imsize=48,
                                     show=False,
                                     save_file_prefix=None,
                                     env_id=None,
                                     tag='',
                                     dataset_path=None):
    if dataset_path is not None:
        dataset = load_local_or_remote_file(dataset_path)
        return dataset
    import gym
    from gym.envs import registration
    # trigger registration
    import multiworld.envs.pygame
    import multiworld.envs.mujoco
    if not env_class or not env_kwargs:
        env = gym.make(env_id)
    else:
        env = env_class(**env_kwargs)
    env = ImageEnv(
        env,
        imsize,
        init_camera=init_camera,
        transpose=True,
        normalize=True,
    )
    env.non_presampled_goal_img_is_garbage = True
    if save_file_prefix is None and env_id is not None:
        save_file_prefix = env_id
    filename = "/tmp/{}_N{}_imsize{}uniform_images_{}.npy".format(
        save_file_prefix,
        str(num_imgs),
        env.imsize,
        tag,
    )
    if use_cached_dataset and osp.isfile(filename):
        images = np.load(filename)
        print("Loaded data from {}".format(filename))
        return images

    print('Sampling Uniform Dataset')
    dataset = np.zeros((num_imgs, 3 * env.imsize**2), dtype=np.uint8)
    for j in range(num_imgs):
        obs = env.reset()
        env.set_to_goal(env.get_goal())
        img_f = env._get_flat_img()
        if show:
            img = img_f.reshape(3, env.imsize, env.imsize).transpose()
            img = img[::-1, :, ::-1]
            cv2.imshow('img', img)
            cv2.waitKey(1)
        print(j)
        dataset[j, :] = unormalize_image(img_f)
    np.save(filename, dataset)
    print("Saving file to {}".format(filename))
    return dataset
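
A minimal usage sketch for the sampler above, assuming a registered multiworld env (the env id below is hypothetical; the camera import is the one used by the demo-collection examples later on this page):

from multiworld.envs.mujoco.cameras import sawyer_init_camera_zoomed_in

# Sample 1000 goal images and cache them under /tmp for reuse.
images = generate_uniform_dataset_reacher(
    env_id='SawyerReachXYEnv-v1',  # hypothetical env id
    num_imgs=1000,
    init_camera=sawyer_init_camera_zoomed_in,
    imsize=48,
    use_cached_dataset=True,  # load the /tmp .npy if it already exists
)
print(images.shape)  # (1000, 3 * 48 * 48), uint8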
Example #2
    def __init__(
        self,
        wrapped_env,
        vae,
        pixel_cnn=None,
        vae_input_key_prefix='image',
        sample_from_true_prior=False,
        decode_goals=False,
        decode_goals_on_reset=True,
        render_goals=False,
        render_rollouts=False,
        reward_params=None,
        goal_sampling_mode="vae_prior",
        imsize=84,
        obs_size=None,
        norm_order=2,
        epsilon=20,
        presampled_goals=None,
    ):
        if reward_params is None:
            reward_params = dict()
        super().__init__(
            wrapped_env,
            vae,
            vae_input_key_prefix,
            sample_from_true_prior,
            decode_goals,
            decode_goals_on_reset,
            render_goals,
            render_rollouts,
            reward_params,
            goal_sampling_mode,
            imsize,
            obs_size,
            norm_order,
            epsilon,
            presampled_goals,
        )

        if type(pixel_cnn) is str:
            self.pixel_cnn = load_local_or_remote_file(pixel_cnn)
        else:
            # also accept an already-loaded PixelCNN model (assumption,
            # mirroring how the vae argument is handled in Example #14)
            self.pixel_cnn = pixel_cnn
        self.num_keys = self.vae.num_embeddings
        self.representation_size = 144 * self.vae.representation_size
        print("*******UNFIX REPRESENTATION SIZE*********")
        print("Location: VQVAE WRAPPER")

        latent_space = Box(
            -10 * np.ones(obs_size or self.representation_size),
            10 * np.ones(obs_size or self.representation_size),
            dtype=np.float32,
        )

        spaces = self.wrapped_env.observation_space.spaces
        spaces['observation'] = latent_space
        spaces['desired_goal'] = latent_space
        spaces['achieved_goal'] = latent_space
        spaces['latent_observation'] = latent_space
        spaces['latent_desired_goal'] = latent_space
        spaces['latent_achieved_goal'] = latent_space
        self.observation_space = Dict(spaces)
Example #3
    def load_demo_path(self,
                       path,
                       is_demo,
                       obs_dict,
                       train_split=None,
                       data_split=None):
        print("loading off-policy path", path)
        data = list(load_local_or_remote_file(path))
        # if not is_demo:
        # data = [data]
        # random.shuffle(data)

        if train_split is None:
            train_split = self.demo_train_split

        if data_split is None:
            data_split = self.demo_data_split

        M = int(len(data) * train_split * data_split)
        N = int(len(data) * data_split)
        print("using", M, "of", N, "paths for training")

        if self.add_demos_to_replay_buffer:
            for path in data[:M]:
                self.load_path(path, self.replay_buffer, obs_dict=obs_dict)

        if is_demo:
            for path in data[:M]:
                self.load_path(path, self.demo_train_buffer, obs_dict=obs_dict)
            for path in data[M:N]:
                self.load_path(path, self.demo_test_buffer, obs_dict=obs_dict)
Example #4
def load_dataset(filename, test_p=0.9):
    dataset = load_local_or_remote_file(filename).item()

    num_trajectories = dataset["observations"].shape[0]
    n_random_steps = dataset["observations"].shape[1]
    #num_trajectories = N // n_random_steps
    n = int(num_trajectories * test_p)

    try:
        train_dataset = InitialObservationDataset({
            'observations':
            dataset['observations'][:n, :, :],
            'env':
            dataset['env'][:n, :],
        })
        test_dataset = InitialObservationDataset({
            'observations':
            dataset['observations'][n:, :, :],
            'env':
            dataset['env'][n:, :],
        })
    except KeyError:  # dataset saved without an 'env' key
        train_dataset = InitialObservationDataset({
            'observations':
            dataset['observations'][:n, :, :],
        })
        test_dataset = InitialObservationDataset({
            'observations':
            dataset['observations'][n:, :, :],
        })

    return train_dataset, test_dataset
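
A usage sketch for load_dataset, assuming the .npy file holds a dict with an 'observations' array of shape (num_trajectories, horizon, imlength) and optionally 'env' (the path is a placeholder):

# Note: despite its name, test_p is the fraction of trajectories kept for training.
train_dataset, test_dataset = load_dataset('/tmp/my_trajectories.npy', test_p=0.9)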
Example #5
    def load_dataset(self, dataset_path):
        dataset = load_local_or_remote_file(dataset_path)
        dataset = dataset.item()

        observations = dataset['observations']
        actions = dataset['actions']

        # dataset['observations'].shape # (2000, 50, 6912)
        # dataset['actions'].shape # (2000, 50, 2)
        # dataset['env'].shape # (2000, 6912)
        N, H, imlength = observations.shape

        self.vae.eval()
        for n in range(N):
            x0 = ptu.from_numpy(dataset['env'][n:n + 1, :] / 255.0)
            x = ptu.from_numpy(observations[n, :, :] / 255.0)
            latents = self.vae.encode(x, x0, distrib=False)

            r1, r2 = self.vae.latent_sizes
            conditioning = latents[0, r1:]
            goal = torch.cat(
                [ptu.randn(self.vae.latent_sizes[0]), conditioning])
            goal = ptu.get_numpy(goal)  # latents[-1, :]

            latents = ptu.get_numpy(latents)
            latent_delta = latents - goal
            distances = np.zeros((H - 1, 1))
            for i in range(H - 1):
                distances[i, 0] = np.linalg.norm(latent_delta[i + 1, :])

            terminals = np.zeros((H - 1, 1))
            # terminals[-1, 0] = 1
            path = dict(
                observations=[],
                actions=actions[n, :H - 1, :],
                next_observations=[],
                rewards=-distances,
                terminals=terminals,
            )

            for t in range(H - 1):
                # reward = -np.linalg.norm(latent_delta[i, :])

                obs = dict(
                    latent_observation=latents[t, :],
                    latent_achieved_goal=latents[t, :],
                    latent_desired_goal=goal,
                )
                next_obs = dict(
                    latent_observation=latents[t + 1, :],
                    latent_achieved_goal=latents[t + 1, :],
                    latent_desired_goal=goal,
                )

                path['observations'].append(obs)
                path['next_observations'].append(next_obs)

            # import ipdb; ipdb.set_trace()
            self.replay_buffer.add_path(path)
Example #6
    def load_demos(self, demo_path):
        data = load_local_or_remote_file(demo_path)
        random.shuffle(data)
        N = int(len(data) * self.train_split)
        print("using", N, "paths for training")
        for path in data[:N]:
            self.load_path(path, self.replay_buffer)

        for path in data[N:]:
            self.load_path(path, self.test_replay_buffer)
Example #7
def generate_vae_dataset_from_demos(variant):
    demo_path = variant["demo_path"]
    test_p = variant.get('test_p', 0.9)
    use_cached = variant.get('use_cached', True)
    imsize = variant.get('imsize', 84)
    num_channels = variant.get('num_channels', 3)
    show = variant.get('show', False)
    init_camera = variant.get('init_camera', None)

    def load_paths(paths):
        data = [load_path(path) for path in paths]
        data = np.concatenate(data, 0)
        return data

    def load_path(path):
        N = len(path["observations"])
        data = np.zeros((N, imsize * imsize * num_channels), dtype=np.uint8)
        i = 0
        for (
                ob,
                action,
                reward,
                next_ob,
                terminal,
                agent_info,
                env_info,
        ) in zip(
                path["observations"],
                path["actions"],
                path["rewards"],
                path["next_observations"],
                path["terminals"],
                path["agent_infos"],
                path["env_infos"],
        ):
            img = ob["image_observation"]
            img = img.reshape(imsize, imsize, 3).transpose()
            data[i, :] = img.flatten()
            i += 1
        return data

    data = load_local_or_remote_file(demo_path)
    random.shuffle(data)
    N = int(len(data) * test_p)
    print("using", N, "paths for training")

    train_data = load_paths(data[:N])
    test_data = load_paths(data[N:])

    print("training data shape", train_data.shape)
    print("test data shape", test_data.shape)

    info = {}

    return train_data, test_data, info
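
A sketch of driving this generator, assuming a demo file saved by the collect_demos scripts shown elsewhere on this page (the path is a placeholder):

variant = dict(
    demo_path='demos/pusher_demos_1000.npy',  # placeholder demo file
    test_p=0.9,       # fraction of demo paths used for training
    imsize=48,        # must match the image size in the demos
    num_channels=3,
)
train_data, test_data, info = generate_vae_dataset_from_demos(variant)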
Example #8
def resume(variant):
    data = load_local_or_remote_file(variant.get("pretrained_algorithm_path"),
                                     map_location="cuda")
    algo = data['algorithm']

    algo.num_epochs = variant['num_epochs']

    post_pretrain_hyperparams = variant["trainer_kwargs"].get(
        "post_pretrain_hyperparams", {})
    algo.trainer.set_algorithm_weights(**post_pretrain_hyperparams)

    algo.train()
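
A sketch of resuming a saved run; the snapshot path is a placeholder, and post_pretrain_hyperparams is optional:

variant = dict(
    pretrained_algorithm_path='path/to/itr_200.pkl',  # placeholder snapshot
    num_epochs=500,
    trainer_kwargs=dict(
        post_pretrain_hyperparams=dict(),  # forwarded to set_algorithm_weights
    ),
)
resume(variant)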
Example #9
def experiment(variant):
    from railrl.core import logger
    import railrl.torch.pytorch_util as ptu
    beta = variant["beta"]
    representation_size = variant["representation_size"]
    train_data, test_data, info = variant['generate_vae_dataset_fn'](
        variant['generate_vae_dataset_kwargs'])
    uniform_dataset = load_local_or_remote_file(
        variant['uniform_dataset_path']).item()
    uniform_dataset = unormalize_image(uniform_dataset['image_desired_goal'])
    logger.save_extra_data(info)
    logger.get_snapshot_dir()
    if 'beta_schedule_kwargs' in variant:
        # kwargs = variant['beta_schedule_kwargs']
        # kwargs['y_values'][2] = variant['beta']
        # kwargs['x_values'][1] = variant['flat_x']
        # kwargs['x_values'][2] = variant['ramp_x'] + variant['flat_x']
        variant['beta_schedule_kwargs']['y_values'][-1] = variant['beta']
        beta_schedule = PiecewiseLinearSchedule(
            **variant['beta_schedule_kwargs'])
    else:
        beta_schedule = None
    m = variant['vae'](representation_size,
                       decoder_output_activation=nn.Sigmoid(),
                       **variant['vae_kwargs'])
    m.to(ptu.device)
    t = ConvVAETrainer(train_data,
                       test_data,
                       m,
                       beta=beta,
                       beta_schedule=beta_schedule,
                       **variant['algo_kwargs'])
    save_period = variant['save_period']
    for epoch in range(variant['num_epochs']):
        should_save_imgs = (epoch % save_period == 0)
        t.train_epoch(epoch)
        t.log_loss_under_uniform(
            m, uniform_dataset,
            variant['algo_kwargs']['priority_function_kwargs'])
        t.test_epoch(epoch,
                     save_reconstruction=should_save_imgs,
                     save_scatterplot=should_save_imgs)
        if should_save_imgs:
            t.dump_samples(epoch)
            if variant['dump_skew_debug_plots']:
                t.dump_best_reconstruction(epoch)
                t.dump_worst_reconstruction(epoch)
                t.dump_sampling_histogram(epoch)
                t.dump_uniform_imgs_and_reconstructions(
                    dataset=uniform_dataset, epoch=epoch)
        t.update_train_weights()
Example #10
def generate_uniform_dataset_pick_and_place(env_class=None,
                                            env_kwargs=None,
                                            num_imgs=1000,
                                            use_cached_dataset=False,
                                            init_camera=None,
                                            imsize=48,
                                            save_file_prefix=None,
                                            env_id=None,
                                            tag='',
                                            dataset_path=None):
    if dataset_path is not None:
        dataset = load_local_or_remote_file(dataset_path)
        return dataset
    import gym
    from gym.envs import registration
    # trigger registration
    import multiworld.envs.pygame
    import multiworld.envs.mujoco
    if not env_class or not env_kwargs:
        env = gym.make(env_id)
    else:
        env = env_class(**env_kwargs)
    env = ImageEnv(
        env,
        imsize,
        init_camera=init_camera,
        transpose=True,
        normalize=True,
    )
    env.non_presampled_goal_img_is_garbage = True
    if save_file_prefix is None and env_id is not None:
        save_file_prefix = env_id
    filename = "/tmp/{}_N{}_imsize{}uniform_images_{}.npy".format(
        save_file_prefix,
        str(num_imgs),
        env.imsize,
        tag,
    )
    if use_cached_dataset and osp.isfile(filename):
        images = np.load(filename)
        print("Loaded data from {}".format(filename))
        return images

    print('Sampling Uniform Dataset')
    dataset = unormalize_image(
        get_image_presampled_goals(env, num_imgs)['image_desired_goal'])
    np.save(filename, dataset)
    print("Saving file to {}".format(filename))
    return dataset
Example #11
    def load_demo_path(self, demo_path, on_policy=True):
        data = list(load_local_or_remote_file(demo_path))
        # if not on_policy:
        # data = [data]
        # random.shuffle(data)
        N = int(len(data) * self.demo_train_split)
        print("using", N, "paths for training")

        if self.add_demos_to_replay_buffer:
            for path in data[:N]:
                self.load_path(path, self.replay_buffer)
        if on_policy:
            for path in data[:N]:
                self.load_path(path, self.demo_train_buffer)
            for path in data[N:]:
                self.load_path(path, self.demo_test_buffer)
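
Called on a trainer that owns the three buffers referenced above; a minimal sketch (the demo path is a placeholder):

# data[:N] fill replay_buffer / demo_train_buffer, data[N:] fill demo_test_buffer.
trainer.load_demo_path('demos/expert_demos.npy', on_policy=True)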
Example #12
from multiworld.core.image_env import ImageEnv
from multiworld.envs.mujoco.cameras import sawyer_init_camera_zoomed_in
import numpy as np
from railrl.demos.collect_demo import collect_demos
from railrl.misc.asset_loader import load_local_or_remote_file

if __name__ == '__main__':
    data = load_local_or_remote_file('/home/murtaza/research/railrl/data/doodads3/11-16-pusher-state-td3-sweep-params-policy-update-period/11-16-pusher_state_td3_sweep_params_policy_update_period_2019_11_17_00_28_45_id000--s62098/params.pkl')
    env = data['evaluation/env']
    policy = data['trainer/trained_policy']
    image_env = ImageEnv(
        env,
        48,
        init_camera=sawyer_init_camera_zoomed_in,
        transpose=True,
        normalize=True,
    )
    collect_demos(image_env, policy, "data/local/demos/pusher_demos_action_noise_1000.npy", N=1000, horizon=50, threshold=.1, add_action_noise=False, key='puck_distance', render=True, noise_sigma=0.0)
    # data = load_local_or_remote_file("demos/pusher_demos_1000.npy")
    # for i in range(100):
    #     goal = data[i]['observations'][49]['desired_goal']
    #     o = env.reset()
    #     path_length = 0
    #     while path_length < 50:
    #         env.set_goal({'state_desired_goal':goal})
    #         o = o['state_observation']
    #         new_obs = np.hstack((o, goal))
    #         a, agent_info = policy.get_action(new_obs)
    #         o, r, d, env_info = env.step(a)
    #         path_length += 1
    #     print(i, env_info['puck_distance'])
Example #13
from railrl.data_management.images import normalize_image
import matplotlib.pyplot as plt

dataset_path = "/tmp/SawyerMultiobjectEnv_N5000_sawyer_init_camera_zoomed_in_imsize48_random_oracle_split_0.npy"
cvae_path = "/home/khazatsky/rail/data/rail-khazatsky/sasha/PCVAE/DCVAE/run103/id0/vae.pkl"
vae_path = "/home/khazatsky/rail/data/rail-khazatsky/sasha/PCVAE/baseline/run1/id0/itr_300.pkl"
prefix = "pusher1_"

# dataset_path = "/tmp/Multiobj2DEnv_N100000_sawyer_init_camera_zoomed_in_imsize48_random_oracle_split_0.npy"
# cvae_path = "/home/khazatsky/rail/data/rail-khazatsky/sasha/PCVAE/dynamics-cvae/run1/id0/itr_500.pkl"
# vae_path = "/home/khazatsky/rail/data/rail-khazatsky/sasha/PCVAE/baseline/run4/id0/itr_300.pkl"
# prefix = "pointmass1_"

N_ROWS = 3

dataset = load_local_or_remote_file(dataset_path)
dataset = dataset.item()

imlength = 6912
imsize = 48

N = dataset['observations'].shape[0]
test_p = 0.9
t = 0  #int(test_p * N)
n = 50
cvae = load_local_or_remote_file(cvae_path)
cvae.eval()
model = cvae.cpu()

cvae_distances = np.zeros((N - t, n))
for j in range(t, N):
Example #14
    def __init__(
        self,
        wrapped_env,
        vae,
        vae_input_key_prefix='image',
        sample_from_true_prior=False,
        decode_goals=False,
        decode_goals_on_reset=True,
        render_goals=False,
        render_rollouts=False,
        reward_params=None,
        goal_sampling_mode="vae_prior",
        imsize=84,
        obs_size=None,
        norm_order=2,
        epsilon=20,
        presampled_goals=None,
    ):
        if reward_params is None:
            reward_params = dict()
        super().__init__(wrapped_env)
        if type(vae) is str:
            self.vae = load_local_or_remote_file(vae)
        else:
            self.vae = vae
        self.representation_size = self.vae.representation_size
        self.input_channels = self.vae.input_channels
        self.sample_from_true_prior = sample_from_true_prior
        self._decode_goals = decode_goals
        self.render_goals = render_goals
        self.render_rollouts = render_rollouts
        self.default_kwargs = dict(
            decode_goals=decode_goals,
            render_goals=render_goals,
            render_rollouts=render_rollouts,
        )
        self.imsize = imsize
        self.reward_params = reward_params
        self.reward_type = self.reward_params.get("type", 'latent_distance')
        self.norm_order = self.reward_params.get("norm_order", norm_order)
        self.epsilon = self.reward_params.get("epsilon", epsilon)
        self.reward_min_variance = self.reward_params.get("min_variance", 0)
        self.decode_goals_on_reset = decode_goals_on_reset

        latent_space = Box(
            -10 * np.ones(obs_size or self.representation_size),
            10 * np.ones(obs_size or self.representation_size),
            dtype=np.float32,
        )

        spaces = self.wrapped_env.observation_space.spaces
        spaces['observation'] = latent_space
        spaces['desired_goal'] = latent_space
        spaces['achieved_goal'] = latent_space
        spaces['latent_observation'] = latent_space
        spaces['latent_desired_goal'] = latent_space
        spaces['latent_achieved_goal'] = latent_space
        self.observation_space = Dict(spaces)
        self._presampled_goals = presampled_goals
        if self._presampled_goals is None:
            self.num_goals_presampled = 0
        else:
            self.num_goals_presampled = presampled_goals[random.choice(list(presampled_goals))].shape[0]

        self.vae_input_key_prefix = vae_input_key_prefix
        assert vae_input_key_prefix in {'image', 'image_proprio'}
        self.vae_input_observation_key = vae_input_key_prefix + '_observation'
        self.vae_input_achieved_goal_key = vae_input_key_prefix + '_achieved_goal'
        self.vae_input_desired_goal_key = vae_input_key_prefix + '_desired_goal'
        self._mode_map = {}
        self.desired_goal = {'latent_desired_goal': latent_space.sample()}
        self._initial_obs = None
        self._custom_goal_sampler = None
        self._goal_sampling_mode = goal_sampling_mode
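
The constructor above (apparently the VAE wrapper that get_envs in Example #16 instantiates as VAEWrappedEnv) reads its reward configuration from reward_params; a sketch of the keys it consults, with illustrative values:

reward_params = dict(
    type='latent_distance',  # becomes self.reward_type
    norm_order=2,            # overrides the norm_order argument
    epsilon=20,              # overrides the epsilon argument
    min_variance=0,          # becomes self.reward_min_variance
)
env = VAEWrappedEnv(image_env, vae, reward_params=reward_params)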
Example #15
from multiworld.core.image_env import ImageEnv
from multiworld.envs.mujoco.cameras import sawyer_init_camera_zoomed_in
import numpy as np
from railrl.demos.collect_demo import collect_demos_fixed
from railrl.misc.asset_loader import load_local_or_remote_file

from railrl.launchers.experiments.ashvin.awr_sac_rl import ENV_PARAMS

if __name__ == '__main__':
    data = load_local_or_remote_file(
        'ashvin/icml2020/mujoco/reference/run1/id2/itr_200.pkl')
    env = data['evaluation/env']
    policy = data['evaluation/policy']
    # import ipdb; ipdb.set_trace()
    # policy =
    policy.to("cpu")
    # image_env = ImageEnv(
    #     env,
    #     48,
    #     init_camera=sawyer_init_camera_zoomed_in,
    #     transpose=True,
    #     normalize=True,
    # )
    env_name = 'pendulum'
    outfile = "/home/ashvin/data/s3doodad/demos/icml2020/mujoco/%s.npy" % env_name
    horizon = ENV_PARAMS[env_name]['max_path_length']
    collect_demos_fixed(
        env, policy, outfile, N=100, horizon=horizon
    )  # , threshold=.1, add_action_noise=False, key='puck_distance', render=True, noise_sigma=0.0)
    # data = load_local_or_remote_file("demos/pusher_demos_1000.npy")
    # for i in range(100):
Example #16
def get_envs(variant):
    from multiworld.core.image_env import ImageEnv
    from railrl.envs.vae_wrappers import VAEWrappedEnv
    from railrl.misc.asset_loader import load_local_or_remote_file

    render = variant.get('render', False)
    vae_path = variant.get("vae_path", None)
    reproj_vae_path = variant.get("reproj_vae_path", None)
    ckpt = variant.get("ckpt", None)
    reward_params = variant.get("reward_params", dict())
    init_camera = variant.get("init_camera", None)
    do_state_exp = variant.get("do_state_exp", False)

    presample_goals = variant.get('presample_goals', False)
    presample_image_goals_only = variant.get('presample_image_goals_only', False)
    presampled_goals_path = variant.get('presampled_goals_path', None)

    if not do_state_exp and type(ckpt) is str:
        vae = load_local_or_remote_file(osp.join(ckpt, 'vae.pkl'))
        if vae is not None:
            from railrl.core import logger
            logger.save_extra_data(vae, 'vae.pkl', mode='pickle')
    else:
        vae = None

    if vae is None and type(vae_path) is str:
        vae = load_local_or_remote_file(osp.join(vae_path, 'vae_params.pkl'))
        from railrl.core import logger

        logger.save_extra_data(vae, 'vae.pkl', mode='pickle')
    elif vae is None:
        vae = vae_path

    if type(vae) is str:
        vae = load_local_or_remote_file(vae)

    if type(reproj_vae_path) is str:
        reproj_vae = load_local_or_remote_file(osp.join(reproj_vae_path, 'vae_params.pkl'))
    else:
        reproj_vae = None

    if 'env_id' in variant:
        import gym
        # trigger registration
        env = gym.make(variant['env_id'])
    else:
        env = variant["env_class"](**variant['env_kwargs'])
    if not do_state_exp:
        if isinstance(env, ImageEnv):
            image_env = env
        else:
            image_env = ImageEnv(
                env,
                variant.get('imsize'),
                init_camera=init_camera,
                transpose=True,
                normalize=True,
            )
        vae_env = VAEWrappedEnv(
            image_env,
            vae,
            imsize=image_env.imsize,
            decode_goals=render,
            render_goals=render,
            render_rollouts=render,
            reward_params=reward_params,
            reproj_vae=reproj_vae,
            **variant.get('vae_wrapped_env_kwargs', {})
        )
        if presample_goals:
            """
            This will fail for online-parallel as presampled_goals will not be
            serialized. Also don't use this for online-vae.
            """
            if presampled_goals_path is None:
                image_env.non_presampled_goal_img_is_garbage = True
                presampled_goals = variant['generate_goal_dataset_fctn'](
                    image_env=image_env,
                    **variant['goal_generation_kwargs']
                )
            else:
                presampled_goals = load_local_or_remote_file(
                    presampled_goals_path
                ).item()
                presampled_goals = {
                    'state_desired_goal': presampled_goals['next_obs_state'],
                    'image_desired_goal': presampled_goals['next_obs'],
                }

            image_env.set_presampled_goals(presampled_goals)
            vae_env.set_presampled_goals(presampled_goals)
            print("Presampling all goals")
        else:
            if presample_image_goals_only:
                presampled_goals = variant['generate_goal_dataset_fctn'](
                    image_env=vae_env.wrapped_env,
                    **variant['goal_generation_kwargs']
                )
                image_env.set_presampled_goals(presampled_goals)
                print("Presampling image goals only")
            else:
                print("Not using presampled goals")

        env = vae_env

    if not do_state_exp:
        training_mode = variant.get("training_mode", "train")
        testing_mode = variant.get("testing_mode", "test")
        env.add_mode('eval', testing_mode)
        env.add_mode('train', training_mode)
        env.add_mode('relabeling', training_mode)
        # relabeling_env.disable_render()
        env.add_mode("video_vae", 'video_vae')
        env.add_mode("video_env", 'video_env')
    return env
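
A minimal variant sketch for get_envs above, assuming a pretrained VAE directory containing vae_params.pkl (the paths and env id are placeholders):

variant = dict(
    env_id='SawyerPushNIPS-v0',   # hypothetical registered env id
    imsize=48,
    init_camera=None,
    vae_path='path/to/vae_run/',  # directory holding vae_params.pkl
    reward_params=dict(type='latent_distance'),
    render=False,
)
env = get_envs(variant)  # returns a VAEWrappedEnv around an ImageEnv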
Example #17
from multiworld.core.image_env import ImageEnv
from multiworld.envs.mujoco.cameras import sawyer_init_camera_zoomed_in
import numpy as np
from railrl.demos.collect_demo import collect_demos
from railrl.misc.asset_loader import load_local_or_remote_file

if __name__ == '__main__':
    data = load_local_or_remote_file('ashvin/icml2020/murtaza/pusher/state/run3/id3/itr_980.pkl')
    env = data['evaluation/env']
    policy = data['trainer/trained_policy']
    policy = policy.to("cpu")
    image_env = ImageEnv(
        env,
        48,
        init_camera=sawyer_init_camera_zoomed_in,
        transpose=True,
        normalize=True,
    )
    collect_demos(image_env, policy, "/home/ashvin/data/s3doodad/demos/icml2020/pusher/demos_action_noise_1000.npy", N=1000, horizon=50, threshold=.1, add_action_noise=False, key='puck_distance', render=True, noise_sigma=0.0)
    # data = load_local_or_remote_file("demos/pusher_demos_1000.npy")
    # for i in range(100):
    #     goal = data[i]['observations'][49]['desired_goal']
    #     o = env.reset()
    #     path_length = 0
    #     while path_length < 50:
    #         env.set_goal({'state_desired_goal':goal})
    #         o = o['state_observation']
    #         new_obs = np.hstack((o, goal))
    #         a, agent_info = policy.get_action(new_obs)
    #         o, r, d, env_info = env.step(a)
    #         path_length += 1
Example #18
def state_td3bc_experiment(variant):
    if variant.get('env_id', None):
        import gym
        import multiworld
        multiworld.register_all_envs()
        eval_env = gym.make(variant['env_id'])
        expl_env = gym.make(variant['env_id'])
    else:
        eval_env_kwargs = variant.get('eval_env_kwargs', variant['env_kwargs'])
        eval_env = variant['env_class'](**eval_env_kwargs)
        expl_env = variant['env_class'](**variant['env_kwargs'])

    observation_key = 'state_observation'
    desired_goal_key = 'state_desired_goal'
    achieved_goal_key = desired_goal_key.replace("desired", "achieved")
    es_strat = variant.get('es', 'ou')
    if es_strat == 'ou':
        es = OUStrategy(
            action_space=expl_env.action_space,
            max_sigma=variant['exploration_noise'],
            min_sigma=variant['exploration_noise'],
        )
    elif es_strat == 'gauss_eps':
        es = GaussianAndEpislonStrategy(
            action_space=expl_env.action_space,
            max_sigma=.2,
            min_sigma=.2,  # constant sigma
            epsilon=.3,
        )
    else:
        raise ValueError("invalid exploration strategy provided")
    obs_dim = expl_env.observation_space.spaces['observation'].low.size
    goal_dim = expl_env.observation_space.spaces['desired_goal'].low.size
    action_dim = expl_env.action_space.low.size
    qf1 = FlattenMlp(
        input_size=obs_dim + goal_dim + action_dim,
        output_size=1,
        **variant['qf_kwargs']
    )
    qf2 = FlattenMlp(
        input_size=obs_dim + goal_dim + action_dim,
        output_size=1,
        **variant['qf_kwargs']
    )
    target_qf1 = FlattenMlp(
        input_size=obs_dim + goal_dim + action_dim,
        output_size=1,
        **variant['qf_kwargs']
    )
    target_qf2 = FlattenMlp(
        input_size=obs_dim + goal_dim + action_dim,
        output_size=1,
        **variant['qf_kwargs']
    )
    policy = TanhMlpPolicy(
        input_size=obs_dim + goal_dim,
        output_size=action_dim,
        **variant['policy_kwargs']
    )
    target_policy = TanhMlpPolicy(
        input_size=obs_dim + goal_dim,
        output_size=action_dim,
        **variant['policy_kwargs']
    )
    expl_policy = PolicyWrappedWithExplorationStrategy(
        exploration_strategy=es,
        policy=policy,
    )
    replay_buffer = ObsDictRelabelingBuffer(
        env=eval_env,
        observation_key=observation_key,
        desired_goal_key=desired_goal_key,
        achieved_goal_key=achieved_goal_key,
        **variant['replay_buffer_kwargs']
    )
    demo_train_buffer = ObsDictRelabelingBuffer(
        env=eval_env,
        observation_key=observation_key,
        desired_goal_key=desired_goal_key,
        achieved_goal_key=achieved_goal_key,
        max_size=variant['replay_buffer_kwargs']['max_size']
    )
    demo_test_buffer = ObsDictRelabelingBuffer(
        env=eval_env,
        observation_key=observation_key,
        desired_goal_key=desired_goal_key,
        achieved_goal_key=achieved_goal_key,
        max_size=variant['replay_buffer_kwargs']['max_size'],
    )
    if variant.get('td3_bc', True):
        td3_trainer = TD3BCTrainer(
            env=expl_env,
            policy=policy,
            qf1=qf1,
            qf2=qf2,
            replay_buffer=replay_buffer,
            demo_train_buffer=demo_train_buffer,
            demo_test_buffer=demo_test_buffer,
            target_qf1=target_qf1,
            target_qf2=target_qf2,
            target_policy=target_policy,
            **variant['td3_bc_trainer_kwargs']
        )
    else:
        td3_trainer = TD3(
            policy=policy,
            qf1=qf1,
            qf2=qf2,
            target_qf1=target_qf1,
            target_qf2=target_qf2,
            target_policy=target_policy,
            **variant['td3_trainer_kwargs']
        )
    trainer = HERTrainer(td3_trainer)
    eval_path_collector = GoalConditionedPathCollector(
        eval_env,
        policy,
        observation_key=observation_key,
        desired_goal_key=desired_goal_key,
    )
    expl_path_collector = GoalConditionedPathCollector(
        expl_env,
        expl_policy,
        observation_key=observation_key,
        desired_goal_key=desired_goal_key,
    )
    algorithm = TorchBatchRLAlgorithm(
        trainer=trainer,
        exploration_env=expl_env,
        evaluation_env=eval_env,
        exploration_data_collector=expl_path_collector,
        evaluation_data_collector=eval_path_collector,
        replay_buffer=replay_buffer,
        **variant['algo_kwargs']
    )

    if variant.get("save_video", True):
        if variant.get("presampled_goals", None):
            variant['image_env_kwargs']['presampled_goals'] = load_local_or_remote_file(variant['presampled_goals']).item()
        image_eval_env = ImageEnv(eval_env, **variant["image_env_kwargs"])
        image_eval_path_collector = GoalConditionedPathCollector(
            image_eval_env,
            policy,
            observation_key='state_observation',
            desired_goal_key='state_desired_goal',
        )
        image_expl_env = ImageEnv(expl_env, **variant["image_env_kwargs"])
        image_expl_path_collector = GoalConditionedPathCollector(
            image_expl_env,
            expl_policy,
            observation_key='state_observation',
            desired_goal_key='state_desired_goal',
        )
        video_func = VideoSaveFunction(
            image_eval_env,
            variant,
            image_expl_path_collector,
            image_eval_path_collector,
        )
        algorithm.post_train_funcs.append(video_func)

    algorithm.to(ptu.device)
    if variant.get('load_demos', False):
        td3_trainer.load_demos()
    if variant.get('pretrain_policy', False):
        td3_trainer.pretrain_policy_with_bc()
    if variant.get('pretrain_rl', False):
        td3_trainer.pretrain_q_with_bc_data()
    algorithm.train()
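
A skeletal variant for the experiment above, listing only keys the function actually reads (all values illustrative; the env id is hypothetical):

variant = dict(
    env_id='SawyerPushNIPS-v0',  # hypothetical; made via gym.make after registration
    es='ou',
    exploration_noise=0.1,
    qf_kwargs=dict(hidden_sizes=[400, 300]),
    policy_kwargs=dict(hidden_sizes=[400, 300]),
    replay_buffer_kwargs=dict(max_size=int(1e6)),
    td3_bc_trainer_kwargs=dict(),
    algo_kwargs=dict(num_epochs=100),
    save_video=False,  # skip the ImageEnv/video section
)
state_td3bc_experiment(variant)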
Example #19
def generate_goal_dataset_using_policy(
        env=None,
        num_goals=1000,
        use_cached_dataset=False,
        policy_file=None,
        show=False,
        path_length=500,
        save_file_prefix=None,
        env_id=None,
        tag='',
):
    if isinstance(env, ImageEnv):
        env_class_name = env._wrapped_env.__class__.__name__
    else:
        env_class_name = env._wrapped_env.wrapped_env.__class__.__name__
    if save_file_prefix is None and env_id is not None:
        save_file_prefix = env_id
    elif save_file_prefix is None:
        save_file_prefix = env_class_name
    filename = "/tmp/{}_N{}_imsize{}goals{}.npy".format(
        save_file_prefix,
        str(num_goals),
        env.imsize,
        tag,
    )
    if use_cached_dataset and osp.isfile(filename):
        goal_dict = np.load(filename, allow_pickle=True).item()
        print("Loaded data from {}".format(filename))
        return goal_dict

    goal_generation_dict = dict()
    for goal_key, obs_key in [
        ('image_desired_goal', 'image_achieved_goal'),
        ('state_desired_goal', 'state_achieved_goal'),
    ]:
        goal_size = env.observation_space.spaces[goal_key].low.size
        goal_generation_dict[goal_key] = [goal_size, obs_key]

    goal_dict = dict()
    policy_file = load_local_or_remote_file(policy_file)
    policy = policy_file['policy']
    policy.to(ptu.device)
    for goal_key in goal_generation_dict:
        goal_size, obs_key = goal_generation_dict[goal_key]
        goal_dict[goal_key] = np.zeros((num_goals, goal_size))
    print('Generating Random Goals')
    for j in range(num_goals):
        obs = env.reset()
        policy.reset()
        for i in range(path_length):
            policy_obs = np.hstack((
                obs['state_observation'],
                obs['state_desired_goal'],
            ))
            action, _ = policy.get_action(policy_obs)
            obs, _, _, _ = env.step(action)
        if show:
            img = obs['image_observation']
            img = img.reshape(3, env.imsize, env.imsize).transpose()
            img = img[::-1, :, ::-1]
            cv2.imshow('img', img)
            cv2.waitKey(1)

        for goal_key in goal_generation_dict:
            goal_size, obs_key = goal_generation_dict[goal_key]
            goal_dict[goal_key][j, :] = obs[obs_key]
    np.save(filename, goal_dict)
    print("Saving file to {}".format(filename))
    return goal_dict
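
A usage sketch, assuming env is already ImageEnv-wrapped and the snapshot file exposes a 'policy' entry (the path is a placeholder):

goal_dict = generate_goal_dataset_using_policy(
    env=image_env,                     # an ImageEnv (or VAEWrappedEnv) instance
    num_goals=1000,
    policy_file='path/to/params.pkl',  # placeholder snapshot with a 'policy' key
    path_length=500,
)
# goal_dict maps 'image_desired_goal' / 'state_desired_goal' to (1000, goal_size) arrays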
Example #20
def generate_uniform_dataset_door(num_imgs=1000,
                                  use_cached_dataset=False,
                                  init_camera=None,
                                  imsize=48,
                                  policy_file=None,
                                  show=False,
                                  path_length=100,
                                  save_file_prefix=None,
                                  env_id=None,
                                  tag='',
                                  dataset_path=None):
    if dataset_path is not None:
        dataset = load_local_or_remote_file(dataset_path)
        return dataset
    import gym
    from gym.envs import registration
    # trigger registration
    import multiworld.envs.pygame
    import multiworld.envs.mujoco
    env = gym.make(env_id)
    env = ImageEnv(
        env,
        imsize,
        init_camera=init_camera,
        transpose=True,
        normalize=True,
    )
    env.non_presampled_goal_img_is_garbage = True
    if save_file_prefix is None and env_id is not None:
        save_file_prefix = env_id
    filename = "/tmp/{}_N{}_imsize{}uniform_images_{}.npy".format(
        save_file_prefix,
        str(num_imgs),
        env.imsize,
        tag,
    )
    if use_cached_dataset and osp.isfile(filename):
        images = np.load(filename)
        print("Loaded data from {}".format(filename))
        return images

    policy_file = load_local_or_remote_file(policy_file)
    policy = policy_file['policy']
    policy.to(ptu.device)
    print('Sampling Uniform Dataset')
    dataset = np.zeros((num_imgs, 3 * env.imsize**2), dtype=np.uint8)
    for j in range(num_imgs):
        obs = env.reset()
        policy.reset()
        for i in range(path_length):
            policy_obs = np.hstack((
                obs['state_observation'],
                obs['state_desired_goal'],
            ))
            action, _ = policy.get_action(policy_obs)
            obs, _, _, _ = env.step(action)
        img_f = obs['image_observation']
        if show:
            img = obs['image_observation']
            img = img.reshape(3, env.imsize, env.imsize).transpose()
            img = img[::-1, :, ::-1]
            cv2.imshow('img', img)
            cv2.waitKey(1)
        print(j)
        dataset[j, :] = unormalize_image(img_f)
    temp = env.reset_free
    env.reset_free = True
    env.reset()
    env.reset_free = temp
    np.save(filename, dataset)
    print("Saving file to {}".format(filename))
    return dataset
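
Usage mirrors the reacher and pick-and-place samplers above, with a policy snapshot to drive the door before each image is taken (the id and path are placeholders):

images = generate_uniform_dataset_door(
    env_id='SawyerDoorHookResetFreeEnv-v1',  # hypothetical env id
    policy_file='path/to/params.pkl',        # snapshot with a 'policy' entry
    num_imgs=1000,
    path_length=100,
    imsize=48,
)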
Example #21
def get_envs(variant):
    from multiworld.core.image_env import ImageEnv
    from railrl.envs.vae_wrappers import VAEWrappedEnv, ConditionalVAEWrappedEnv
    from railrl.misc.asset_loader import load_local_or_remote_file
    from railrl.torch.vae.conditional_conv_vae import CVAE, CDVAE, ACE, CADVAE, DeltaCVAE

    render = variant.get('render', False)
    vae_path = variant.get("vae_path", None)
    reward_params = variant.get("reward_params", dict())
    init_camera = variant.get("init_camera", None)
    do_state_exp = variant.get("do_state_exp", False)
    presample_goals = variant.get('presample_goals', False)
    presample_image_goals_only = variant.get('presample_image_goals_only',
                                             False)
    presampled_goals_path = get_presampled_goals_path(
        variant.get('presampled_goals_path', None))
    vae = load_local_or_remote_file(
        vae_path) if type(vae_path) is str else vae_path
    if 'env_id' in variant:
        import gym
        import multiworld
        multiworld.register_all_envs()
        env = gym.make(variant['env_id'])
    else:
        env = variant["env_class"](**variant['env_kwargs'])
    if not do_state_exp:
        if isinstance(env, ImageEnv):
            image_env = env
        else:
            image_env = ImageEnv(
                env,
                variant.get('imsize'),
                init_camera=init_camera,
                transpose=True,
                normalize=True,
            )
        if presample_goals:
            """
            This will fail for online-parallel as presampled_goals will not be
            serialized. Also don't use this for online-vae.
            """
            if presampled_goals_path is None:
                image_env.non_presampled_goal_img_is_garbage = True
                vae_env = VAEWrappedEnv(image_env,
                                        vae,
                                        imsize=image_env.imsize,
                                        decode_goals=render,
                                        render_goals=render,
                                        render_rollouts=render,
                                        reward_params=reward_params,
                                        **variant.get('vae_wrapped_env_kwargs',
                                                      {}))
                presampled_goals = variant['generate_goal_dataset_fctn'](
                    env=vae_env,
                    env_id=variant.get('env_id', None),
                    **variant['goal_generation_kwargs'])
                del vae_env
            else:
                presampled_goals = load_local_or_remote_file(
                    presampled_goals_path).item()
            del image_env
            image_env = ImageEnv(env,
                                 variant.get('imsize'),
                                 init_camera=init_camera,
                                 transpose=True,
                                 normalize=True,
                                 presampled_goals=presampled_goals,
                                 **variant.get('image_env_kwargs', {}))
            vae_env = VAEWrappedEnv(image_env,
                                    vae,
                                    imsize=image_env.imsize,
                                    decode_goals=render,
                                    render_goals=render,
                                    render_rollouts=render,
                                    reward_params=reward_params,
                                    presampled_goals=presampled_goals,
                                    **variant.get('vae_wrapped_env_kwargs',
                                                  {}))
            print("Presampling all goals only")
        else:
            if type(vae) in (CVAE, CDVAE, ACE, CADVAE, DeltaCVAE):
                vae_env = ConditionalVAEWrappedEnv(
                    image_env,
                    vae,
                    imsize=image_env.imsize,
                    decode_goals=render,
                    render_goals=render,
                    render_rollouts=render,
                    reward_params=reward_params,
                    **variant.get('vae_wrapped_env_kwargs', {}))
            else:
                vae_env = VAEWrappedEnv(image_env,
                                        vae,
                                        imsize=image_env.imsize,
                                        decode_goals=render,
                                        render_goals=render,
                                        render_rollouts=render,
                                        reward_params=reward_params,
                                        **variant.get('vae_wrapped_env_kwargs',
                                                      {}))
            if presample_image_goals_only:
                presampled_goals = variant['generate_goal_dataset_fctn'](
                    image_env=vae_env.wrapped_env,
                    **variant['goal_generation_kwargs'])
                image_env.set_presampled_goals(presampled_goals)
                print("Presampling image goals only")
            else:
                print("Not using presampled goals")

        env = vae_env

    return env
Example #22
def train_reprojection_network_and_update_variant(variant):
    from railrl.core import logger
    from railrl.misc.asset_loader import load_local_or_remote_file
    import railrl.torch.pytorch_util as ptu

    rl_variant = variant.get("rl_variant", {})
    vae_wrapped_env_kwargs = rl_variant.get('vae_wrapped_env_kwargs', {})
    if vae_wrapped_env_kwargs.get("use_reprojection_network", False):
        train_reprojection_network_variant = variant.get(
            "train_reprojection_network_variant", {})

        if train_reprojection_network_variant.get("use_cached_network", False):
            vae_path = rl_variant.get("vae_path", None)
            reprojection_network = load_local_or_remote_file(
                osp.join(vae_path, 'reproj_network.pkl'))
            from railrl.core import logger
            logger.save_extra_data(reprojection_network,
                                   'reproj_network.pkl',
                                   mode='pickle')

            if ptu.gpu_enabled():
                reprojection_network.cuda()

            vae_wrapped_env_kwargs[
                'reprojection_network'] = reprojection_network
        else:
            logger.remove_tabular_output('progress.csv',
                                         relative_to_snapshot_dir=True)
            logger.add_tabular_output('reproj_progress.csv',
                                      relative_to_snapshot_dir=True)

            vae_path = rl_variant.get("vae_path", None)
            ckpt = rl_variant.get("ckpt", None)

            if type(ckpt) is str:
                vae = load_local_or_remote_file(osp.join(ckpt, 'vae.pkl'))
                from railrl.core import logger

                logger.save_extra_data(vae, 'vae.pkl', mode='pickle')
            elif type(vae_path) is str:
                vae = load_local_or_remote_file(
                    osp.join(vae_path, 'vae_params.pkl'))
                from railrl.core import logger

                logger.save_extra_data(vae, 'vae.pkl', mode='pickle')
            else:
                vae = vae_path

            if type(vae) is str:
                vae = load_local_or_remote_file(vae)

            if ptu.gpu_enabled():
                vae.cuda()

            train_reprojection_network_variant['vae'] = vae
            reprojection_network = train_reprojection_network(
                train_reprojection_network_variant)
            vae_wrapped_env_kwargs[
                'reprojection_network'] = reprojection_network
Example #23
def generate_vae_dataset(variant):
    print(variant)
    env_class = variant.get('env_class', None)
    env_kwargs = variant.get('env_kwargs', None)
    env_id = variant.get('env_id', None)
    N = variant.get('N', 10000)
    batch_size = variant.get('batch_size', 128)
    test_p = variant.get('test_p', 0.9)
    use_cached = variant.get('use_cached', True)
    imsize = variant.get('imsize', 84)
    num_channels = variant.get('num_channels', 3)
    show = variant.get('show', False)
    init_camera = variant.get('init_camera', None)
    dataset_path = variant.get('dataset_path', None)
    oracle_dataset_using_set_to_goal = variant.get(
        'oracle_dataset_using_set_to_goal', False)
    random_rollout_data = variant.get('random_rollout_data', False)
    random_rollout_data_set_to_goal = variant.get(
        'random_rollout_data_set_to_goal', True)
    random_and_oracle_policy_data = variant.get(
        'random_and_oracle_policy_data', False)
    random_and_oracle_policy_data_split = variant.get(
        'random_and_oracle_policy_data_split', 0)
    policy_file = variant.get('policy_file', None)
    n_random_steps = variant.get('n_random_steps', 100)
    vae_dataset_specific_env_kwargs = variant.get(
        'vae_dataset_specific_env_kwargs', None)
    save_file_prefix = variant.get('save_file_prefix', None)
    non_presampled_goal_img_is_garbage = variant.get(
        'non_presampled_goal_img_is_garbage', None)

    conditional_vae_dataset = variant.get('conditional_vae_dataset', False)
    use_env_labels = variant.get('use_env_labels', False)
    use_linear_dynamics = variant.get('use_linear_dynamics', False)
    enviorment_dataset = variant.get('enviorment_dataset', False)
    save_trajectories = variant.get('save_trajectories', False)
    save_trajectories = save_trajectories or use_linear_dynamics or conditional_vae_dataset

    tag = variant.get('tag', '')

    assert N % n_random_steps == 0, "Fix N/horizon or dataset generation will fail"

    from multiworld.core.image_env import ImageEnv, unormalize_image
    import railrl.torch.pytorch_util as ptu
    from railrl.misc.asset_loader import load_local_or_remote_file
    from railrl.data_management.dataset import (
        TrajectoryDataset,
        ImageObservationDataset,
        EnvironmentDataset,
        ConditionalDynamicsDataset,
        InitialObservationNumpyDataset,
        InfiniteBatchLoader,
    )

    info = {}
    if dataset_path is not None:
        dataset = load_local_or_remote_file(dataset_path)
        dataset = dataset.item()
        N = dataset['observations'].shape[0] * dataset['observations'].shape[1]
        n_random_steps = dataset['observations'].shape[1]
    else:
        if env_kwargs is None:
            env_kwargs = {}
        if save_file_prefix is None:
            save_file_prefix = env_id
        if save_file_prefix is None:
            save_file_prefix = env_class.__name__
        filename = "/tmp/{}_N{}_{}_imsize{}_random_oracle_split_{}{}.npy".format(
            save_file_prefix,
            str(N),
            init_camera.__name__
            if init_camera and hasattr(init_camera, '__name__') else '',
            imsize,
            random_and_oracle_policy_data_split,
            tag,
        )
        if use_cached and osp.isfile(filename):
            dataset = load_local_or_remote_file(filename)
            if conditional_vae_dataset:
                dataset = dataset.item()
            print("loaded data from saved file", filename)
        else:
            now = time.time()

            if env_id is not None:
                import gym
                import multiworld
                multiworld.register_all_envs()
                env = gym.make(env_id)
            else:
                if vae_dataset_specific_env_kwargs is None:
                    vae_dataset_specific_env_kwargs = {}
                for key, val in env_kwargs.items():
                    if key not in vae_dataset_specific_env_kwargs:
                        vae_dataset_specific_env_kwargs[key] = val
                env = env_class(**vae_dataset_specific_env_kwargs)
            if not isinstance(env, ImageEnv):
                env = ImageEnv(
                    env,
                    imsize,
                    init_camera=init_camera,
                    transpose=True,
                    normalize=True,
                    non_presampled_goal_img_is_garbage=
                    non_presampled_goal_img_is_garbage,
                )
            else:
                imsize = env.imsize
                env.non_presampled_goal_img_is_garbage = non_presampled_goal_img_is_garbage
            env.reset()
            info['env'] = env
            if random_and_oracle_policy_data:
                policy_file = load_local_or_remote_file(policy_file)
                policy = policy_file['policy']
                policy.to(ptu.device)
            if random_rollout_data:
                from railrl.exploration_strategies.ou_strategy import OUStrategy
                policy = OUStrategy(env.action_space)

            if save_trajectories:
                dataset = {
                    'observations':
                    np.zeros((N // n_random_steps, n_random_steps,
                              imsize * imsize * num_channels),
                             dtype=np.uint8),
                    'actions':
                    np.zeros((N // n_random_steps, n_random_steps,
                              env.action_space.shape[0]),
                             dtype=np.float64),  # np.float was removed in modern NumPy
                    'env':
                    np.zeros(
                        (N // n_random_steps, imsize * imsize * num_channels),
                        dtype=np.uint8),
                }
            else:
                dataset = np.zeros((N, imsize * imsize * num_channels),
                                   dtype=np.uint8)
            labels = []
            for i in range(N):
                if random_and_oracle_policy_data:
                    num_random_steps = int(N *
                                           random_and_oracle_policy_data_split)
                    if i < num_random_steps:
                        env.reset()
                        for _ in range(n_random_steps):
                            obs = env.step(env.action_space.sample())[0]
                    else:
                        obs = env.reset()
                        policy.reset()
                        for _ in range(n_random_steps):
                            policy_obs = np.hstack((
                                obs['state_observation'],
                                obs['state_desired_goal'],
                            ))
                            action, _ = policy.get_action(policy_obs)
                            obs, _, _, _ = env.step(action)
                elif random_rollout_data:  #ADD DATA WHERE JUST PUCK MOVES
                    if i % n_random_steps == 0:
                        env.reset()
                        policy.reset()
                        env_img = env._get_obs()['image_observation']
                        if random_rollout_data_set_to_goal:
                            env.set_to_goal(env.get_goal())
                    obs = env._get_obs()
                    u = policy.get_action_from_raw_action(
                        env.action_space.sample())
                    env.step(u)
                elif oracle_dataset_using_set_to_goal:
                    print(i)
                    goal = env.sample_goal()
                    env.set_to_goal(goal)
                    obs = env._get_obs()
                else:
                    env.reset()
                    for _ in range(n_random_steps):
                        obs = env.step(env.action_space.sample())[0]

                img = obs['image_observation']
                if use_env_labels:
                    labels.append(obs['label'])
                if save_trajectories:
                    dataset['observations'][
                        i // n_random_steps,
                        i % n_random_steps, :] = unormalize_image(img)
                    dataset['actions'][i // n_random_steps,
                                       i % n_random_steps, :] = u
                    dataset['env'][i // n_random_steps, :] = unormalize_image(
                        env_img)
                else:
                    dataset[i, :] = unormalize_image(img)

                if show:
                    img = img.reshape(3, imsize, imsize).transpose()
                    img = img[::-1, :, ::-1]
                    cv2.imshow('img', img)
                    cv2.waitKey(1)
                    # radius = input('waiting...')
            print("done making training data", filename, time.time() - now)
            np.save(filename, dataset)
            #np.save(filename[:-4] + 'labels.npy', np.array(labels))

    info['train_labels'] = []
    info['test_labels'] = []

    if use_linear_dynamics and conditional_vae_dataset:
        num_trajectories = N // n_random_steps
        n = int(num_trajectories * test_p)
        # Build the split from shuffled trajectory indices below; the 'env'
        # key is optional, hence the try/except.
        indices = np.arange(num_trajectories)
        np.random.shuffle(indices)
        train_i, test_i = indices[:n], indices[n:]

        try:
            train_dataset = ConditionalDynamicsDataset({
                'observations':
                dataset['observations'][train_i, :, :],
                'actions':
                dataset['actions'][train_i, :, :],
                'env':
                dataset['env'][train_i, :]
            })
            test_dataset = ConditionalDynamicsDataset({
                'observations':
                dataset['observations'][test_i, :, :],
                'actions':
                dataset['actions'][test_i, :, :],
                'env':
                dataset['env'][test_i, :]
            })
        except KeyError:  # the dataset may not contain 'env'
            train_dataset = ConditionalDynamicsDataset({
                'observations':
                dataset['observations'][train_i, :, :],
                'actions':
                dataset['actions'][train_i, :, :],
            })
            test_dataset = ConditionalDynamicsDataset({
                'observations':
                dataset['observations'][test_i, :, :],
                'actions':
                dataset['actions'][test_i, :, :],
            })
    elif use_linear_dynamics:
        num_trajectories = N // n_random_steps
        n = int(num_trajectories * test_p)
        train_dataset = TrajectoryDataset({
            'observations':
            dataset['observations'][:n, :, :],
            'actions':
            dataset['actions'][:n, :, :]
        })
        test_dataset = TrajectoryDataset({
            'observations':
            dataset['observations'][n:, :, :],
            'actions':
            dataset['actions'][n:, :, :]
        })
    elif enviorment_dataset:
        n = int(n_random_steps * test_p)
        train_dataset = EnvironmentDataset({
            'observations':
            dataset['observations'][:, :n, :],
        })
        test_dataset = EnvironmentDataset({
            'observations':
            dataset['observations'][:, n:, :],
        })
    elif conditional_vae_dataset:
        num_trajectories = N // n_random_steps
        n = int(num_trajectories * test_p)
        indices = np.arange(num_trajectories)
        np.random.shuffle(indices)
        train_i, test_i = indices[:n], indices[n:]

        if 'env' in dataset:
            train_dataset = InitialObservationNumpyDataset({
                'observations':
                dataset['observations'][train_i, :, :],
                'env':
                dataset['env'][train_i, :]
            })
            test_dataset = InitialObservationNumpyDataset({
                'observations':
                dataset['observations'][test_i, :, :],
                'env':
                dataset['env'][test_i, :]
            })
        else:
            train_dataset = InitialObservationNumpyDataset({
                'observations':
                dataset['observations'][train_i, :, :],
            })
            test_dataset = InitialObservationNumpyDataset({
                'observations':
                dataset['observations'][test_i, :, :],
            })

        train_batch_loader_kwargs = variant.get(
            'train_batch_loader_kwargs',
            dict(
                batch_size=batch_size,
                num_workers=0,
            ))
        test_batch_loader_kwargs = variant.get(
            'test_batch_loader_kwargs',
            dict(
                batch_size=batch_size,
                num_workers=0,
            ))

        train_data_loader = data.DataLoader(train_dataset,
                                            shuffle=True,
                                            drop_last=True,
                                            **train_batch_loader_kwargs)
        test_data_loader = data.DataLoader(test_dataset,
                                           shuffle=True,
                                           drop_last=True,
                                           **test_batch_loader_kwargs)

        train_dataset = InfiniteBatchLoader(train_data_loader)
        test_dataset = InfiniteBatchLoader(test_data_loader)
    else:
        n = int(N * test_p)
        train_dataset = ImageObservationDataset(dataset[:n, :])
        test_dataset = ImageObservationDataset(dataset[n:, :])
    return train_dataset, test_dataset, info
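A note on the splits above: test_p effectively acts as the train fraction, since the first n shuffled trajectories go to the train set. A minimal standalone sketch of that trajectory-level split (the helper name and RNG seeding are illustrative, not from the source):

import numpy as np

def split_trajectories_by_index(observations, test_p, seed=0):
    # observations: (num_trajectories, traj_len, obs_dim)
    num_trajectories = observations.shape[0]
    n = int(num_trajectories * test_p)
    indices = np.random.default_rng(seed).permutation(num_trajectories)
    # As above, the first n shuffled trajectories become the train split.
    return observations[indices[:n]], observations[indices[n:]]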
Example #24
    def __init__(
        self,
        wrapped_env,
        vae,
        reward_params=None,
        config_params=None,
        imsize=84,
        obs_size=None,
        vae_input_observation_key="image_observation",
        small_image_step=6,
    ):
        if config_params is None:
            config_params = dict()
        if reward_params is None:
            reward_params = dict()
        super().__init__(wrapped_env)
        if type(vae) is str:
            self.vae = load_local_or_remote_file(vae)
        else:
            self.vae = vae
        self.representation_size = self.vae.representation_size
        self.input_channels = self.vae.input_channels
        self.imsize = imsize
        self.config_params = config_params
        self.t = 0
        self.episode_num = 0
        self.reward_params = reward_params
        self.reward_type = self.reward_params.get("type", 'latent_distance')
        self.zT = self.reward_params["goal_latent"]
        self.z0 = self.reward_params["initial_latent"]
        self.dT = self.zT - self.z0

        self.small_image_step = small_image_step
        # if self.config_params["use_initial"]:
        #     self.dT = self.zT - self.z0
        # else:
        #     self.dT = self.zT

        self.vae_input_observation_key = vae_input_observation_key

        latent_size = obs_size or self.representation_size
        latent_space = Box(
            -10 * np.ones(latent_size),
            10 * np.ones(latent_size),
            dtype=np.float32,
        )
        goal_space = Box(
            np.zeros((0, )),
            np.zeros((0, )),
            dtype=np.float32,
        )
        spaces = self.wrapped_env.observation_space.spaces
        spaces['observation'] = latent_space
        spaces['desired_goal'] = goal_space
        spaces['achieved_goal'] = goal_space
        spaces['latent_observation'] = latent_space
        spaces['latent_desired_goal'] = goal_space
        spaces['latent_achieved_goal'] = goal_space

        concat_size = latent_size + spaces["state_observation"].low.size
        concat_space = Box(
            -10 * np.ones(concat_size),
            10 * np.ones(concat_size),
            dtype=np.float32,
        )
        spaces['concat_observation'] = concat_space
        small_image_size = 288 // self.small_image_step
        small_image_imglength = small_image_size * small_image_size * 3
        small_image_space = Box(
            0 * np.ones(small_image_imglength),
            1 * np.ones(small_image_imglength),
            dtype=np.float32,
        )
        spaces['small_image_observation'] = small_image_space
        small_image_observation_with_state_size = (
            small_image_imglength + spaces["state_observation"].low.size)
        small_image_observation_with_state_space = Box(
            0 * np.ones(small_image_observation_with_state_size),
            1 * np.ones(small_image_observation_with_state_size),
            dtype=np.float32,
        )
        spaces['small_image_observation_with_state'] = (
            small_image_observation_with_state_space)

        self.observation_space = Dict(spaces)
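For reference, the small-image observation implies strided downsampling of a 288x288 render (small_image_size = 288 // small_image_step). A minimal sketch, assuming a flattened channel-first image from the wrapped env; the helper name is illustrative:

import numpy as np

def make_small_image(flat_img, small_image_step=6, base_size=288):
    # (3 * 288 * 288,) -> (3 * 48 * 48,) when small_image_step == 6.
    img = flat_img.reshape(3, base_size, base_size)  # assumes channel-first
    small = img[:, ::small_image_step, ::small_image_step]
    return small.ravel()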
Example #25
from multiworld.core.image_env import ImageEnv
import multiworld.envs.mujoco as mwmj
from multiworld.envs.mujoco.cameras import sawyer_door_env_camera_v0

from railrl.demos.collect_demo import collect_demos, collect_demos_fixed
import os.path as osp

import gym

from railrl.misc.asset_loader import load_local_or_remote_file

if __name__ == '__main__':
    data = load_local_or_remote_file(
        '11-16-door-reset-free-state-td3-sweep-params-policy-update-period/11-16-door_reset_free_state_td3_sweep_params_policy_update_period_2019_11_17_00_26_50_id000--s89728/params.pkl'
    )
    env = data['evaluation/env']
    policy = data['trainer/trained_policy']
    presampled_goals_path = osp.join(
        osp.dirname(mwmj.__file__),
        "goals",
        "door_goals.npy",
    )
    presampled_goals = load_local_or_remote_file(presampled_goals_path).item()
    image_env = ImageEnv(
        env,
        48,
        init_camera=sawyer_door_env_camera_v0,
        transpose=True,
        normalize=True,
        presampled_goals=presampled_goals,
    )
    collect_demos(
        image_env,
        policy,
        "data/local/demos/door_demos.npy",  # hypothetical outfile: the original call's remaining arguments were truncated
        N=100,
        horizon=200,
    )
    # data = load_local_or_remote_file(
    # '/home/murtaza/research/railrl/data/local/03-04-bc-hc-v2/03-04-bc_hc_v2_2020_03_04_17_57_54_id000--s90897/bc.pkl')
    # env = gym.make('Ant-v2')
    # policy = data.cpu()
    # collect_demos_fixed(env, policy, "data/local/demos/ant_off_policy_100.npy", N=100, horizon=1000, threshold=8000,
    # render=False)
    # data = load_local_or_remote_file(
    # '/home/murtaza/research/railrl/data/doodads3/03-08-bc-ant-gym-v1/03-08-bc_ant_gym_v1_2020_03_08_19_22_00_id000--s39483/bc.pkl')
    # # env = gym.make('Ant-v2')
    # policy = MakeDeterministic(data.cpu())
    # collect_demos_fixed(env, policy, "data/local/demos/ant_off_policy_10_demos_100.npy", N=100, horizon=1000, threshold=-1,
    # render=False)

    data = load_local_or_remote_file(
        '/home/murtaza/research/railrl/data/local/03-09-bc-ant-frac-trajs-sweep/03-09-bc_ant_frac_trajs_sweep_2020_03_09_17_58_01_id000--s71624/bc.pkl'
    )
    env = gym.make('Ant-v2')
    policy = data.cpu()
    collect_demos_fixed(env,
                        policy,
                        "data/local/demos/ant_off_policy_10_demos_100.npy",
                        N=100,
                        horizon=1000,
                        threshold=-1,
                        render=False)

    data = load_local_or_remote_file(
        '/home/murtaza/research/railrl/data/local/03-09-bc-ant-frac-trajs-sweep/03-09-bc_ant_frac_trajs_sweep_2020_03_09_17_58_02_id000--s47768/bc.pkl'
    )
    env = gym.make('Ant-v2')
Example #27
def _disentangled_grill_her_twin_sac_experiment(
        max_path_length,
        encoder_kwargs,
        disentangled_qf_kwargs,
        qf_kwargs,
        twin_sac_trainer_kwargs,
        replay_buffer_kwargs,
        policy_kwargs,
        vae_evaluation_goal_sampling_mode,
        vae_exploration_goal_sampling_mode,
        base_env_evaluation_goal_sampling_mode,
        base_env_exploration_goal_sampling_mode,
        algo_kwargs,
        env_id=None,
        env_class=None,
        env_kwargs=None,
        observation_key='state_observation',
        desired_goal_key='state_desired_goal',
        achieved_goal_key='state_achieved_goal',
        latent_dim=2,
        vae_wrapped_env_kwargs=None,
        vae_path=None,
        vae_n_vae_training_kwargs=None,
        vectorized=False,
        save_video=True,
        save_video_kwargs=None,
        have_no_disentangled_encoder=False,
        **kwargs):
    if env_kwargs is None:
        env_kwargs = {}
    if vae_wrapped_env_kwargs is None:
        vae_wrapped_env_kwargs = {}
    if save_video_kwargs is None:
        save_video_kwargs = {}
    assert env_id or env_class

    if env_id:
        import gym
        import multiworld
        multiworld.register_all_envs()
        train_env = gym.make(env_id)
        eval_env = gym.make(env_id)
    else:
        eval_env = env_class(**env_kwargs)
        train_env = env_class(**env_kwargs)

    train_env.goal_sampling_mode = base_env_exploration_goal_sampling_mode
    eval_env.goal_sampling_mode = base_env_evaluation_goal_sampling_mode

    if vae_path:
        vae = load_local_or_remote_file(vae_path)
    else:
        vae = get_n_train_vae(latent_dim=latent_dim,
                              env=eval_env,
                              **vae_n_vae_training_kwargs)

    train_env = VAEWrappedEnv(train_env,
                              vae,
                              imsize=train_env.imsize,
                              **vae_wrapped_env_kwargs)
    eval_env = VAEWrappedEnv(eval_env,
                             vae,
                             imsize=train_env.imsize,
                             **vae_wrapped_env_kwargs)

    obs_dim = train_env.observation_space.spaces[observation_key].low.size
    goal_dim = train_env.observation_space.spaces[desired_goal_key].low.size
    action_dim = train_env.action_space.low.size

    encoder = FlattenMlp(input_size=obs_dim,
                         output_size=latent_dim,
                         **encoder_kwargs)

    def make_qf():
        if have_no_disentangled_encoder:
            return FlattenMlp(
                input_size=obs_dim + goal_dim + action_dim,
                output_size=1,
                **qf_kwargs,
            )
        else:
            return DisentangledMlpQf(goal_processor=encoder,
                                     preprocess_obs_dim=obs_dim,
                                     action_dim=action_dim,
                                     qf_kwargs=qf_kwargs,
                                     vectorized=vectorized,
                                     **disentangled_qf_kwargs)

    qf1 = make_qf()
    qf2 = make_qf()
    target_qf1 = make_qf()
    target_qf2 = make_qf()

    policy = TanhGaussianPolicy(obs_dim=obs_dim + goal_dim,
                                action_dim=action_dim,
                                **policy_kwargs)

    replay_buffer = ObsDictRelabelingBuffer(
        env=train_env,
        observation_key=observation_key,
        desired_goal_key=desired_goal_key,
        achieved_goal_key=achieved_goal_key,
        vectorized=vectorized,
        **replay_buffer_kwargs)
    sac_trainer = SACTrainer(env=train_env,
                             policy=policy,
                             qf1=qf1,
                             qf2=qf2,
                             target_qf1=target_qf1,
                             target_qf2=target_qf2,
                             **twin_sac_trainer_kwargs)
    trainer = HERTrainer(sac_trainer)

    eval_path_collector = VAEWrappedEnvPathCollector(
        eval_env,
        MakeDeterministic(policy),
        max_path_length,
        observation_key=observation_key,
        desired_goal_key=desired_goal_key,
        goal_sampling_mode=vae_evaluation_goal_sampling_mode,
    )
    expl_path_collector = VAEWrappedEnvPathCollector(
        train_env,
        policy,
        max_path_length,
        observation_key=observation_key,
        desired_goal_key=desired_goal_key,
        goal_sampling_mode=vae_exploration_goal_sampling_mode,
    )
    algorithm = TorchBatchRLAlgorithm(
        trainer=trainer,
        exploration_env=train_env,
        evaluation_env=eval_env,
        exploration_data_collector=expl_path_collector,
        evaluation_data_collector=eval_path_collector,
        replay_buffer=replay_buffer,
        max_path_length=max_path_length,
        **algo_kwargs,
    )
    algorithm.to(ptu.device)

    if save_video:
        save_vf_heatmap = save_video_kwargs.get('save_vf_heatmap', True)

        if have_no_disentangled_encoder:

            def v_function(obs):
                action = policy.get_actions(obs)
                obs, action = ptu.from_numpy(obs), ptu.from_numpy(action)
                return qf1(obs, action)

            add_heatmap = partial(add_heatmap_img_to_o_dict,
                                  v_function=v_function)
        else:

            def v_function(obs):
                action = policy.get_actions(obs)
                obs, action = ptu.from_numpy(obs), ptu.from_numpy(action)
                return qf1(obs, action, return_individual_q_vals=True)

            add_heatmap = partial(
                add_heatmap_imgs_to_o_dict,
                v_function=v_function,
                vectorized=vectorized,
            )
        rollout_function = rf.create_rollout_function(
            rf.multitask_rollout,
            max_path_length=max_path_length,
            observation_key=observation_key,
            desired_goal_key=desired_goal_key,
            full_o_postprocess_func=add_heatmap if save_vf_heatmap else None,
        )
        img_keys = ['v_vals'] + [
            'v_vals_dim_{}'.format(dim) for dim in range(latent_dim)
        ]
        eval_video_func = get_save_video_function(rollout_function,
                                                  eval_env,
                                                  MakeDeterministic(policy),
                                                  get_extra_imgs=partial(
                                                      get_extra_imgs,
                                                      img_keys=img_keys),
                                                  tag="eval",
                                                  **save_video_kwargs)
        train_video_func = get_save_video_function(rollout_function,
                                                   train_env,
                                                   policy,
                                                   get_extra_imgs=partial(
                                                       get_extra_imgs,
                                                       img_keys=img_keys),
                                                   tag="train",
                                                   **save_video_kwargs)
        algorithm.post_train_funcs.append(eval_video_func)
        algorithm.post_train_funcs.append(train_video_func)
    algorithm.train()
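The heatmap branch above approximates a state-value function by evaluating qf1 at the policy's own action, V(s) ~ Q(s, pi(s)). A minimal standalone sketch of the same idea, assuming a callable torch policy and Q-network rather than railrl's exact interfaces:

import torch

def make_v_function(policy, qf, device="cpu"):
    # V(s) ~= Q(s, pi(s)), evaluated without gradients for visualization.
    def v_function(obs_np):
        obs = torch.as_tensor(obs_np, dtype=torch.float32, device=device)
        with torch.no_grad():
            action = policy(obs)  # assumption: policy(obs) returns actions
            return qf(obs, action).cpu().numpy()
    return v_function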
Example #28
from multiworld.core.image_env import ImageEnv
from multiworld.envs.mujoco.cameras import sawyer_init_camera_zoomed_in
import numpy as np
from railrl.demos.collect_demo import collect_demos
from railrl.misc.asset_loader import load_local_or_remote_file

if __name__ == '__main__':
    data = load_local_or_remote_file(
        'ashvin/masks/pusher/state5/run10/id10/itr_980.pkl')
    env = data['evaluation/env']
    policy = data['evaluation/policy']
    # import ipdb; ipdb.set_trace()
    # policy =
    policy.to("cpu")
    # image_env = ImageEnv(
    #     env,
    #     48,
    #     init_camera=sawyer_init_camera_zoomed_in,
    #     transpose=True,
    #     normalize=True,
    # )
    # env_name = pendulum
    outfile = "/home/ashvin/data/s3doodad/demos/icml2020/pusher/demos100.npy"
    horizon = 200
    collect_demos(
        env, policy, outfile, N=100, horizon=horizon
    )  # , threshold=.1, add_action_noise=False, key='puck_distance', render=True, noise_sigma=0.0)
    # data = load_local_or_remote_file("demos/pusher_demos_1000.npy")
    # for i in range(100):
    #     goal = data[i]['observations'][49]['desired_goal']
    #     o = env.reset()
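Based on the commented-out access pattern above, the saved demo file holds a pickled sequence of trajectories, each with per-step observation dicts. A minimal read-back sketch (the file path is illustrative):

import numpy as np

demos = np.load("demos100.npy", allow_pickle=True)
goal = demos[0]['observations'][49]['desired_goal']  # goal at the 50th step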
Example #29
    def __init__(
        self,
        wrapped_env,
        vae,
        vae_input_key_prefix='image',
        use_vae_goals=True,
        sample_from_true_prior=False,
        decode_goals=False,
        render_goals=False,
        render_rollouts=False,
        reward_params=None,
        mode="train",
        imsize=84,
        obs_size=None,
        norm_order=2,
        epsilon=20,
        temperature=1.0,
        vis_granularity=50,
        presampled_goals=None,
        train_noisy_encoding=False,
        test_noisy_encoding=False,
        noise_type=None,  # deprecated
        num_samples_for_latent_histogram=10000,
        use_reprojection_network=False,
        reprojection_network=None,
        use_vae_dataset=True,
        vae_dataset_path=None,
        clip_encoding_std=True,
        use_replay_buffer_goals=False,  # deprecated feature
        v_func_heatmap_bounds=(-1.5, 0.0),
        reproj_vae=None,
        disable_annotated_images=False,
    ):
        self.quick_init(locals())
        if reward_params is None:
            reward_params = dict()
        super().__init__(wrapped_env)
        if type(vae) is str:
            self.vae = load_local_or_remote_file(vae)
        else:
            self.vae = vae
        if ptu.gpu_enabled():
            self.vae.cuda()

        if reproj_vae is not None:
            if type(reproj_vae) is str:
                self.reproj_vae = load_local_or_remote_file(reproj_vae)
            else:
                self.reproj_vae = reproj_vae
            if ptu.gpu_enabled():
                self.reproj_vae.cuda()
        else:
            self.reproj_vae = None

        self.representation_size = self.vae.representation_size
        if hasattr(self.vae, 'input_channels'):
            self.input_channels = self.vae.input_channels
        else:
            self.input_channels = None
        self._reconstr_image_observation = False
        self._use_vae_goals = use_vae_goals
        self.sample_from_true_prior = sample_from_true_prior
        self.decode_goals = decode_goals
        self.render_goals = render_goals
        self.render_rollouts = render_rollouts
        self.default_kwargs = dict(
            decode_goals=decode_goals,
            render_goals=render_goals,
            render_rollouts=render_rollouts,
        )

        self.clip_encoding_std = clip_encoding_std

        self.train_noisy_encoding = train_noisy_encoding
        self.test_noisy_encoding = test_noisy_encoding
        self.noise_type = noise_type

        self.imsize = imsize
        self.vis_granularity = vis_granularity  # for heatmaps
        self.reward_params = reward_params
        self.reward_type = self.reward_params.get("type", 'latent_distance')
        self.norm_order = self.reward_params.get("norm_order", norm_order)
        self.epsilon = self.reward_params.get("epsilon",
                                              epsilon)  # for sparse reward
        self.temperature = self.reward_params.get(
            "temperature", temperature)  # for exponential reward
        self.reward_min_variance = self.reward_params.get("min_variance", 0)
        latent_space = Box(
            -10 * np.ones(obs_size or self.representation_size),
            10 * np.ones(obs_size or self.representation_size),
            dtype=np.float32,
        )

        spaces = self.wrapped_env.observation_space.spaces
        spaces['observation'] = latent_space
        spaces['desired_goal'] = latent_space
        spaces['achieved_goal'] = latent_space

        spaces['latent_observation'] = latent_space
        spaces['latent_observation_mean'] = latent_space
        spaces['latent_observation_std'] = latent_space

        spaces['latent_desired_goal'] = latent_space
        spaces['latent_desired_goal_mean'] = latent_space
        spaces['latent_desired_goal_std'] = latent_space

        spaces['latent_achieved_goal'] = latent_space
        spaces['latent_achieved_goal_mean'] = latent_space
        spaces['latent_achieved_goal_std'] = latent_space

        self.observation_space = Dict(spaces)
        self.mode(mode)

        self.vae_input_key_prefix = vae_input_key_prefix
        assert vae_input_key_prefix in {'image', 'image_proprio', 'state'}
        self.vae_input_observation_key = vae_input_key_prefix + '_observation'
        self.vae_input_achieved_goal_key = vae_input_key_prefix + '_achieved_goal'
        self.vae_input_desired_goal_key = vae_input_key_prefix + '_desired_goal'

        self._presampled_goals = presampled_goals
        if self._presampled_goals is None:
            self.num_goals_presampled = 0
        else:
            self.num_goals_presampled = presampled_goals[random.choice(
                list(presampled_goals))].shape[0]
            self._presampled_latent_goals, self._presampled_latent_goals_mean, self._presampled_latent_goals_std = \
                self._encode(
                    self._presampled_goals[self.vae_input_desired_goal_key],
                    noisy=self.test_noisy_encoding,  # assumption: self.noisy_encoding was undefined; use the test-time setting
                    batch_size=2500
                )

        self._mode_map = {}
        self.desired_goal = {}
        self.desired_goal['latent_desired_goal'] = latent_space.sample()
        self._initial_obs = None

        self.latent_subgoals = None
        self.latent_subgoals_reproj = None
        self.subgoal_v_vals = None
        self.image_subgoals = None
        self.image_subgoals_stitched = None
        self.image_subgoals_reproj_stitched = None
        self.updated_image_subgoals = False

        self.sweep_goal_mu = None
        self.wrapped_env.reset()
        self.sweep_goal_latents = None

        self.use_reprojection_network = use_reprojection_network
        self.reprojection_network = reprojection_network

        self.use_vae_dataset = use_vae_dataset
        self.vae_dataset_path = vae_dataset_path

        self.num_samples_for_latent_histogram = num_samples_for_latent_histogram
        self.latent_histogram = None
        self.latent_histogram_noisy = None

        self.num_active_dims = 0
        for std in self.vae.dist_std:
            if std > 0.15:
                self.num_active_dims += 1

        sorted_dims = self.vae.dist_std.argsort()
        self.active_dims = sorted_dims[-self.num_active_dims:][::-1]
        self.inactive_dims = sorted_dims[:-self.num_active_dims][::-1]

        self.mu = None
        self.std = None
        self.prior_distr = None

        self.v_func_heatmap_bounds = v_func_heatmap_bounds

        self.vis_blacklist = []
        self.disable_annotated_images = disable_annotated_images
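The active/inactive split above keeps latent dimensions whose prior std exceeds 0.15. The same logic as a standalone helper, with a guard for the zero-active edge case that the inline version does not handle (the function name is illustrative):

import numpy as np

def split_active_dims(dist_std, threshold=0.15):
    dist_std = np.asarray(dist_std)
    num_active = int((dist_std > threshold).sum())
    order = dist_std.argsort()
    if num_active == 0:
        # order[-0:] would select every dim, so handle this case explicitly.
        return np.array([], dtype=int), order[::-1]
    active = order[-num_active:][::-1]    # largest stds first
    inactive = order[:-num_active][::-1]
    return active, inactive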
Example #30
def HER_baseline_td3_experiment(variant):
    import railrl.torch.pytorch_util as ptu
    from railrl.data_management.obs_dict_replay_buffer import \
        ObsDictRelabelingBuffer
    from railrl.exploration_strategies.base import (
        PolicyWrappedWithExplorationStrategy)
    from railrl.torch.her.her_td3 import HerTd3
    from railrl.torch.networks import MergedCNN, CNNPolicy
    import torch
    from multiworld.core.image_env import ImageEnv
    from railrl.misc.asset_loader import load_local_or_remote_file

    init_camera = variant.get("init_camera", None)
    presample_goals = variant.get('presample_goals', False)
    presampled_goals_path = get_presampled_goals_path(
        variant.get('presampled_goals_path', None))

    if 'env_id' in variant:
        import gym
        import multiworld
        multiworld.register_all_envs()
        env = gym.make(variant['env_id'])
    else:
        env = variant["env_class"](**variant['env_kwargs'])
    image_env = ImageEnv(
        env,
        variant.get('imsize'),
        reward_type='image_sparse',
        init_camera=init_camera,
        transpose=True,
        normalize=True,
    )
    if presample_goals:
        if presampled_goals_path is None:
            image_env.non_presampled_goal_img_is_garbage = True
            presampled_goals = variant['generate_goal_dataset_fctn'](
                env=image_env, **variant['goal_generation_kwargs'])
        else:
            presampled_goals = load_local_or_remote_file(
                presampled_goals_path).item()
        del image_env
        env = ImageEnv(
            env,
            variant.get('imsize'),
            reward_type='image_distance',
            init_camera=init_camera,
            transpose=True,
            normalize=True,
            presampled_goals=presampled_goals,
        )
    else:
        env = image_env

    es = get_exploration_strategy(variant, env)

    observation_key = variant.get('observation_key', 'image_observation')
    desired_goal_key = variant.get('desired_goal_key', 'image_desired_goal')
    achieved_goal_key = desired_goal_key.replace("desired", "achieved")
    imsize = variant['imsize']
    action_dim = env.action_space.low.size
    qf1 = MergedCNN(input_width=imsize,
                    input_height=imsize,
                    output_size=1,
                    input_channels=3 * 2,
                    added_fc_input_size=action_dim,
                    **variant['cnn_params'])
    qf2 = MergedCNN(input_width=imsize,
                    input_height=imsize,
                    output_size=1,
                    input_channels=3 * 2,
                    added_fc_input_size=action_dim,
                    **variant['cnn_params'])

    policy = CNNPolicy(
        input_width=imsize,
        input_height=imsize,
        added_fc_input_size=0,
        output_size=action_dim,
        input_channels=3 * 2,
        output_activation=torch.tanh,
        **variant['cnn_params'],
    )
    target_qf1 = MergedCNN(input_width=imsize,
                           input_height=imsize,
                           output_size=1,
                           input_channels=3 * 2,
                           added_fc_input_size=action_dim,
                           **variant['cnn_params'])
    target_qf2 = MergedCNN(input_width=imsize,
                           input_height=imsize,
                           output_size=1,
                           input_channels=3 * 2,
                           added_fc_input_size=action_dim,
                           **variant['cnn_params'])

    target_policy = CNNPolicy(
        input_width=imsize,
        input_height=imsize,
        added_fc_input_size=0,
        output_size=action_dim,
        input_channels=3 * 2,
        output_activation=torch.tanh,
        **variant['cnn_params'],
    )
    exploration_policy = PolicyWrappedWithExplorationStrategy(
        exploration_strategy=es,
        policy=policy,
    )

    replay_buffer = ObsDictRelabelingBuffer(
        env=env,
        observation_key=observation_key,
        desired_goal_key=desired_goal_key,
        achieved_goal_key=achieved_goal_key,
        **variant['replay_buffer_kwargs'])
    algo_kwargs = variant['algo_kwargs']
    algo_kwargs['replay_buffer'] = replay_buffer
    base_kwargs = algo_kwargs['base_kwargs']
    base_kwargs['training_env'] = env
    base_kwargs['render'] = variant["render"]
    base_kwargs['render_during_eval'] = variant["render"]
    her_kwargs = algo_kwargs['her_kwargs']
    her_kwargs['observation_key'] = observation_key
    her_kwargs['desired_goal_key'] = desired_goal_key
    algorithm = HerTd3(env,
                       qf1=qf1,
                       qf2=qf2,
                       policy=policy,
                       target_qf1=target_qf1,
                       target_qf2=target_qf2,
                       target_policy=target_policy,
                       exploration_policy=exploration_policy,
                       **variant['algo_kwargs'])

    algorithm.to(ptu.device)
    algorithm.train()
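For orientation, a skeletal variant for this experiment. The top-level keys match the lookups above, but every concrete value is an illustrative placeholder; cnn_params, replay_buffer_kwargs, and the nested algo_kwargs must be filled with valid arguments for MergedCNN/CNNPolicy, ObsDictRelabelingBuffer, and HerTd3:

variant = dict(
    env_id='SawyerDoorHookResetFreeEnv-v1',  # placeholder multiworld env id
    imsize=48,
    init_camera=None,
    render=False,
    observation_key='image_observation',
    desired_goal_key='image_desired_goal',
    presample_goals=False,
    cnn_params=dict(),            # placeholder: CNN architecture kwargs
    replay_buffer_kwargs=dict(),  # placeholder: buffer size, HER fractions
    algo_kwargs=dict(
        base_kwargs=dict(),       # placeholder: epochs, batch size, ...
        her_kwargs=dict(),
    ),
)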