Esempio n. 1
0
def generate_vae_dataset(
    N=10000,
    test_p=0.9,
    use_cached=False,
    imsize=84,
    show=False,
    dataset_path=None,
    env_class=SawyerReachTorqueEnv,
    env_kwargs=None,
    init_camera=sawyer_torque_reacher_camera,
):
    """Collect (or load) a dataset of flattened images for VAE training.

    Runs an OU-noise-wrapped random policy in an image-wrapped torque
    reaching env and stores each rendered frame as a uint8 row.

    :param N: number of images to collect.
    :param test_p: fraction of the data assigned to the train split; the
        remainder becomes the test split.
    :param use_cached: reuse a previously saved /tmp dataset if present.
    :param imsize: side length of the square rendered image.
    :param show: display each frame with OpenCV while collecting.
    :param dataset_path: optional explicit dataset path (local or s3).
    :param env_class: environment class to instantiate.
    :param env_kwargs: optional kwargs for ``env_class``.
    :param init_camera: camera initialization function for rendering.
    :return: (train_dataset, test_dataset, info); ``info['env']`` holds
        the wrapped env only when data was freshly generated.
    """
    filename = "/tmp/sawyer_torque_data" + str(N) + ".npy"
    info = {}
    if dataset_path is not None:
        filename = local_path_from_s3_or_local_path(dataset_path)
        dataset = np.load(filename)
    elif use_cached and osp.isfile(filename):
        dataset = np.load(filename)
        print("loaded data from saved file", filename)
    else:
        now = time.time()
        if env_kwargs is None:  # `is None`, not `== None` (PEP 8)
            env_kwargs = dict()
        env = env_class(**env_kwargs)
        env = ImageEnv(
            env,
            imsize,
            transpose=True,
            init_camera=init_camera,
            normalize=True,
        )
        info['env'] = env
        policy = RandomPolicy(env.action_space)
        es = OUStrategy(action_space=env.action_space, theta=0)
        exploration_policy = PolicyWrappedWithExplorationStrategy(
            exploration_strategy=es,
            policy=policy,
        )
        dataset = np.zeros((N, imsize * imsize * 3), dtype=np.uint8)
        for i in range(N):
            if i % 50 == 0:
                print('Reset')
                env.reset_model()
                exploration_policy.reset()
            # One down-scaled exploration step per recorded frame.
            action = exploration_policy.get_action()[0] * 1 / 10
            env.wrapped_env.step(action)
            img = env._get_flat_img()
            dataset[i, :] = unormalize_image(img)
            if show:
                # Use imsize rather than a hard-coded 84 so `show`
                # works for any image size.
                cv2.imshow('img', img.reshape(3, imsize, imsize).transpose())
                cv2.waitKey(1)
            print(i)
        print("done making training data", time.time() - now)
        np.save(filename, dataset)
    n = int(N * test_p)
    train_dataset = dataset[:n, :]
    test_dataset = dataset[n:, :]
    return train_dataset, test_dataset, info
Esempio n. 2
0
def generate_vae_dataset(
        N=10000, test_p=0.9, use_cached=True, imsize=84, show=False,
        dataset_path=None, env_class=None, env_kwargs=None, init_camera=sawyer_door_env_camera,
):
    """Collect (or load) door-env images: first half via oracle goal
    sampling, second half via random exploration rollouts.

    :param N: number of images to collect.
    :param test_p: fraction of the data assigned to the train split.
    :param use_cached: reuse a previously saved /tmp dataset if present.
    :param imsize: side length of the square rendered image.
    :param show: display each frame with OpenCV while collecting.
    :param dataset_path: optional explicit dataset path (local or s3).
    :param env_class: environment class to instantiate (required when
        generating fresh data).
    :param env_kwargs: optional kwargs for ``env_class``.
    :param init_camera: camera initialization function for rendering.
    :return: (train_dataset, test_dataset, info).
    """
    filename = "/tmp/sawyer_door_push_open_and_reach" + str(N) + ".npy"
    info = {}
    if dataset_path is not None:
        filename = local_path_from_s3_or_local_path(dataset_path)
        dataset = np.load(filename)
    elif use_cached and osp.isfile(filename):
        dataset = np.load(filename)
        print("loaded data from saved file", filename)
    else:
        if env_kwargs is None:  # the None default would crash **env_kwargs
            env_kwargs = {}
        env = env_class(**env_kwargs)
        env = ImageEnv(
            env, imsize,
            transpose=True,
            init_camera=init_camera,
            normalize=True,
        )
        oracle_sampled_data = int(N / 2)
        dataset = np.zeros((N, imsize * imsize * 3))
        print('Goal Space Sampling')
        for i in range(oracle_sampled_data):
            goal = env.sample_goal()
            env.set_to_goal(goal)
            img = env._get_flat_img()
            dataset[i, :] = img
            if show:
                cv2.imshow('img', img.reshape(3, imsize, imsize).transpose())
                cv2.waitKey(1)
            print(i)
        # Restrict the hand's workspace before the random-rollout phase.
        env._wrapped_env.min_y_pos = .6
        policy = RandomPolicy(env.action_space)
        es = OUStrategy(action_space=env.action_space, theta=0)
        exploration_policy = PolicyWrappedWithExplorationStrategy(
            exploration_strategy=es,
            policy=policy,
        )
        print('Random Sampling')
        for i in range(oracle_sampled_data, N):
            if i % 20 == 0:
                env.reset()
                exploration_policy.reset()
            for _ in range(10):
                action = exploration_policy.get_action()[0]
                env.wrapped_env.step(action)
            img = env._get_flat_img()
            dataset[i, :] = img
            if show:
                cv2.imshow('img', img.reshape(3, imsize, imsize).transpose())
                cv2.waitKey(1)
            print(i)
        # Save so the use_cached=True default can actually find the data
        # on the next run (previously nothing was ever written).
        np.save(filename, dataset)
    n = int(N * test_p)
    train_dataset = dataset[:n, :]
    test_dataset = dataset[n:, :]
    return train_dataset, test_dataset, info
Esempio n. 3
0
def generate_vae_dataset(
    N=10000,
    test_p=0.9,
    use_cached=True,
    imsize=84,
    show=False,
    init_camera=sawyer_init_camera_zoomed_in,
    dataset_path=None,
    env_kwargs=None,
):
    """Collect (or load) images from the variable-puck push env by
    sampling oracle goal/hand configurations.

    :param N: number of images to collect (overridden by the size of a
        loaded dataset when ``dataset_path`` is given).
    :param test_p: fraction of the data assigned to the train split.
    :param use_cached: reuse a previously saved /tmp dataset if present.
    :param imsize: side length of the square rendered image.
    :param show: display each frame with OpenCV while collecting.
    :param init_camera: camera initialization function for rendering.
    :param dataset_path: optional explicit dataset path (local or s3).
    :param env_kwargs: optional kwargs for ``SawyerPushXYVariableEnv``.
    :return: (train_dataset, test_dataset, info).
    """
    if env_kwargs is None:
        env_kwargs = {}
    filename = "/tmp/sawyer_push_variable{}_{}.npy".format(
        str(N),
        init_camera.__name__,
    )
    info = {}
    if dataset_path is not None:
        filename = local_path_from_s3_or_local_path(dataset_path)
        dataset = np.load(filename)
        N = dataset.shape[0]
    elif use_cached and osp.isfile(filename):
        dataset = np.load(filename)
        print("loaded data from saved file", filename)
    else:
        now = time.time()
        env = SawyerPushXYVariableEnv(hide_goal=True, **env_kwargs)
        env = ImageMujocoEnv(
            env,
            imsize,
            transpose=True,
            init_camera=init_camera,
            normalize=True,
        )
        info['env'] = env

        dataset = np.zeros((N, imsize * imsize * 3))
        for i in range(N):
            # Place puck and hand independently at sampled positions,
            # then render via a single step.
            goal = env.sample_goal_for_rollout()
            hand_pos = env.sample_hand_xy()
            env.set_to_goal(goal, reset_hand=False)
            env.set_hand_xy(hand_pos)
            img = env.step(env.action_space.sample())[0]
            dataset[i, :] = img
            if show:
                # Use imsize rather than a hard-coded 84 so `show`
                # works for any image size.
                img = img.reshape(3, imsize, imsize).transpose()
                img = img[::-1, :, ::-1]
                cv2.imshow('img', img)
                cv2.waitKey(1)
        print("done making training data", filename, time.time() - now)
        np.save(filename, dataset)

    n = int(N * test_p)
    train_dataset = dataset[:n, :]
    test_dataset = dataset[n:, :]
    return train_dataset, test_dataset, info
Esempio n. 4
0
def generate_vae_dataset(
    N=10000,
    test_p=0.9,
    use_cached=True,
    imsize=84,
    show=False,
    init_camera=sawyer_init_camera_zoomed_in,
    dataset_path=None,
    env_kwargs=None,
):
    """Collect (or load) images from the reset-free push env.

    Oracle means that we use `set_to_goal` rather than doing random
    rollouts.

    :param N: number of images to collect (overridden by the size of a
        loaded dataset when ``dataset_path`` is given).
    :param test_p: fraction of the data assigned to the train split.
    :param use_cached: reuse a previously saved /tmp dataset if present.
    :param imsize: side length of the square rendered image.
    :param show: display each frame with OpenCV while collecting.
    :param init_camera: camera initialization function for rendering.
    :param dataset_path: optional explicit dataset path (local or s3).
    :param env_kwargs: optional kwargs for ``SawyerResetFreePushEnv``.
    :return: (train_dataset, test_dataset, info).
    """
    if env_kwargs is None:
        env_kwargs = {}
    filename = "/tmp/sawyer_reset_free_push{}_{}.npy".format(
        str(N),
        init_camera.__name__,
    )
    info = {}
    if dataset_path is not None:
        filename = local_path_from_s3_or_local_path(dataset_path)
        dataset = np.load(filename)
        N = dataset.shape[0]
    elif use_cached and osp.isfile(filename):
        dataset = np.load(filename)
        print("loaded data from saved file", filename)
    else:
        now = time.time()
        env = SawyerResetFreePushEnv(hide_goal=True, **env_kwargs)
        env = ImageMujocoEnv(
            env,
            imsize,
            transpose=True,
            init_camera=init_camera,
            normalize=True,
        )
        info['env'] = env

        dataset = np.zeros((N, imsize * imsize * 3))
        for i in range(N):
            goal = env.sample_goal_for_rollout()
            env.set_to_goal(goal)
            img = env.reset()
            dataset[i, :] = img
            if show:
                # Use imsize rather than a hard-coded 84 so `show`
                # works for any image size.
                img = img.reshape(3, imsize, imsize).transpose()
                img = img[::-1, :, ::-1]
                cv2.imshow('img', img)
                cv2.waitKey(1)
        print("done making training data", filename, time.time() - now)
        np.save(filename, dataset)

    n = int(N * test_p)
    train_dataset = dataset[:n, :]
    test_dataset = dataset[n:, :]
    return train_dataset, test_dataset, info
Esempio n. 5
0
def generate_vae_dataset(
    N=10000,
    test_p=0.9,
    use_cached=True,
    imsize=84,
    show=False,
    init_camera=sawyer_init_camera_zoomed_in,
    dataset_path=None,
):
    """Collect (or load) images from the easy push env via random
    rollouts (100 random steps per recorded frame).

    :param N: number of images to collect (overridden by the size of a
        loaded dataset when ``dataset_path`` is given).
    :param test_p: fraction of the data assigned to the train split.
    :param use_cached: reuse a previously saved /tmp dataset if present.
    :param imsize: side length of the square rendered image.
    :param show: display each frame with OpenCV while collecting.
    :param init_camera: camera initialization function for rendering.
    :param dataset_path: optional explicit dataset path (local or s3).
    :return: (train_dataset, test_dataset, info).
    """
    filename = "/tmp/sawyer_push_new_easy{}_{}.npy".format(
        str(N),
        init_camera.__name__,
    )
    info = {}
    if dataset_path is not None:
        filename = local_path_from_s3_or_local_path(dataset_path)
        dataset = np.load(filename)
        N = dataset.shape[0]
    elif use_cached and osp.isfile(filename):
        dataset = np.load(filename)
        print("loaded data from saved file", filename)
    else:
        now = time.time()
        env = SawyerPushXYEasyEnv(hide_goal=True)
        env = ImageMujocoEnv(
            env,
            imsize,
            transpose=True,
            init_camera=init_camera,
            normalize=True,
        )
        info['env'] = env

        dataset = np.zeros((N, imsize * imsize * 3))
        for i in range(N):
            env.reset()
            # Scramble the scene with random actions on the inner env,
            # then render one frame via the image wrapper.
            for _ in range(100):
                action = env.wrapped_env.action_space.sample()
                env.wrapped_env.step(action)
            img = env.step(env.action_space.sample())[0]
            dataset[i, :] = img
            print(i)
            if show:
                # Use imsize rather than a hard-coded 84 so `show`
                # works for any image size.
                cv2.imshow('img', img.reshape(3, imsize, imsize).transpose())
                cv2.waitKey(1)
        print("done making training data", filename, time.time() - now)
        np.save(filename, dataset)

    n = int(N * test_p)
    train_dataset = dataset[:n, :]
    test_dataset = dataset[n:, :]
    return train_dataset, test_dataset, info
Esempio n. 6
0
        def update_networks(algo, epoch):
            """Hot-swap the evaluation policy from a checkpoint series.

            Called per-epoch; every ``algo._eval_epoch_freq`` epochs it
            loads ``itr_<epoch>.pkl`` from the checkpoint directory and
            installs its evaluation policy on the eval data collector
            and all additional collectors.

            Note: ``variant`` and ``addl_collectors`` are closure
            variables — presumably defined in the enclosing experiment
            function (not visible here); confirm against the caller.
            """
            # No-op when a single fixed checkpoint epoch was requested;
            # that case is handled elsewhere.
            if 'ckpt_epoch' in variant:
                return

            if epoch % algo._eval_epoch_freq == 0:
                filename = local_path_from_s3_or_local_path(osp.join(variant['ckpt'], 'itr_%d.pkl' % epoch))
                print("Loading ckpt from", filename)
                data = torch.load(filename)#, map_location='cuda:1')
                eval_policy = data['evaluation/policy']
                # Move the loaded policy to the configured torch device.
                eval_policy.to(ptu.device)
                algo.eval_data_collector._policy = eval_policy
                for collector in addl_collectors:
                    collector._policy = eval_policy
Esempio n. 7
0
def generate_vae_dataset(
    N=10000,
    test_p=0.9,
    use_cached=True,
    imsize=84,
    show=False,
    dataset_path=None,
):
    """Collect (or load) images from the easy push env using OU-noise
    exploration toward a goal that is resampled every 100 steps.

    :param N: number of images to collect.
    :param test_p: fraction of the data assigned to the train split.
    :param use_cached: reuse a previously saved /tmp dataset if present.
    :param imsize: side length of the square rendered image.
    :param show: display each frame with OpenCV while collecting.
    :param dataset_path: optional explicit dataset path (local or s3).
    :return: (train_dataset, test_dataset, info).
    """
    filename = "/tmp/sawyer_push_new_easy_wider2_" + str(N) + ".npy"
    info = {}
    if dataset_path is not None:
        filename = local_path_from_s3_or_local_path(dataset_path)
        dataset = np.load(filename)
    elif use_cached and osp.isfile(filename):
        dataset = np.load(filename)
        print("loaded data from saved file", filename)
    else:
        now = time.time()
        env = SawyerPushXYEasyEnv(hide_goal=True)
        env = ImageMujocoEnv(
            env,
            imsize,
            transpose=True,
            init_camera=sawyer_init_camera_zoomed_in,
            normalize=True,
        )
        info['env'] = env
        policy = OUStrategy(env.action_space)

        dataset = np.zeros((N, imsize * imsize * 3))
        for i in range(N):
            # Resample the goal and reset the noise every 100 frames;
            # the env itself is deliberately never reset in between.
            if i % 100 == 0:
                g = env.sample_goal_for_rollout()
                env.set_goal(g)
                policy.reset()
            u = policy.get_action_from_raw_action(env.action_space.sample())
            img = env.step(u)[0]
            dataset[i, :] = img
            if show:
                # Use imsize rather than a hard-coded 84 so `show`
                # works for any image size.
                cv2.imshow('img', img.reshape(3, imsize, imsize).transpose())
                cv2.waitKey(1)
        print("done making training data", filename, time.time() - now)
        np.save(filename, dataset)

    n = int(N * test_p)
    train_dataset = dataset[:n, :]
    test_dataset = dataset[n:, :]
    return train_dataset, test_dataset, info
Esempio n. 8
0
def generate_vae_dataset(
    N=10000,
    test_p=0.9,
    use_cached=True,
    imsize=84,
    show=False,
    dataset_path=None,
):
    """Collect (or load) images from the XY position-control reach env
    via random rollouts (50 random inner-env steps per recorded frame).

    :param N: number of images to collect.
    :param test_p: fraction of the data assigned to the train split.
    :param use_cached: reuse a previously saved /tmp dataset if present.
    :param imsize: side length of the square rendered image.
    :param show: display each frame with OpenCV while collecting.
    :param dataset_path: optional explicit dataset path (local or s3).
    :return: (train_dataset, test_dataset, info).
    """
    filename = "/tmp/sawyer_xy_pos_control_imgs" + str(N) + ".npy"
    info = {}
    if dataset_path is not None:
        filename = local_path_from_s3_or_local_path(dataset_path)
        dataset = np.load(filename)
    elif use_cached and osp.isfile(filename):
        dataset = np.load(filename)
        print("loaded data from saved file", filename)
    else:
        now = time.time()
        env = SawyerReachXYEnv(hide_goal_markers=True)
        env = ImageEnv(
            env,
            imsize,
            transpose=True,
            init_camera=init_sawyer_camera_v1,
            normalize=True,
        )
        info['env'] = env
        dataset = np.zeros((N, imsize * imsize * 3))
        for i in range(N):
            # Move the goal out of the image
            env.reset()
            for _ in range(50):
                env.wrapped_env.step(env.wrapped_env.action_space.sample())
            img = env.step(env.action_space.sample())[0]['image_observation']

            dataset[i, :] = img
            if show:
                # Use imsize rather than a hard-coded 84 so `show`
                # works for any image size.
                cv2.imshow('img', img.reshape(3, imsize, imsize).transpose())
                cv2.waitKey(1)
            print(i)
        print("done making training data", filename, time.time() - now)
        np.save(filename, dataset)

    n = int(N * test_p)
    train_dataset = dataset[:n, :]
    test_dataset = dataset[n:, :]
    return train_dataset, test_dataset, info
Esempio n. 9
0
def generate_vae_dataset(
    N=10000,
    test_p=0.9,
    use_cached=True,
    imsize=84,
    show=False,
    dataset_path=None,
    policy_path=None,
    action_space_sampling=False,
    env_class=SawyerPushAndPullDoorEnv,
    env_kwargs=None,
    action_plus_random_sampling=False,
    init_camera=sawyer_door_env_camera,
    ratio_action_sample_to_random=1 / 2,
    env_id=None,
):
    """Collect (or load) door-env images, mixing oracle goal sampling
    with random exploration rollouts.

    Only the ``action_plus_random_sampling`` generation path is
    implemented; other uncached configurations raise
    ``NotImplementedError``.

    :param N: number of images to collect.
    :param test_p: fraction of the data assigned to the train split.
    :param use_cached: reuse a previously saved /tmp dataset if present.
    :param imsize: side length of the square rendered image.
    :param show: display each frame with OpenCV while collecting.
    :param dataset_path: optional explicit dataset path (local or s3).
    :param policy_path: only affects the cache filename chosen.
    :param action_space_sampling: only affects the cache filename chosen.
    :param env_class: environment class to instantiate.
    :param env_kwargs: optional kwargs for ``env_class``.
    :param action_plus_random_sampling: enable the implemented
        generation path.
    :param init_camera: camera initialization function for rendering.
    :param ratio_action_sample_to_random: fraction of N drawn via oracle
        goal sampling; the rest comes from random rollouts.
    :param env_id: if given, build the env with ``gym.make`` instead of
        ``env_class`` (note: then no ImageEnv wrapping is applied).
    :return: (train_dataset, test_dataset, info).
    """
    if policy_path is not None:
        filename = "/tmp/sawyer_door_push_and_pull_open_oracle+random_policy_data_closer_zoom_action_limited" + str(
            N) + ".npy"
    elif action_space_sampling:
        filename = "/tmp/sawyer_door_push_and_pull_open_zoomed_in_action_space_sampling" + str(
            N) + ".npy"
    else:
        filename = "/tmp/sawyer_door_push_and_pull_open" + str(N) + ".npy"
    info = {}
    if dataset_path is not None:
        filename = local_path_from_s3_or_local_path(dataset_path)
        dataset = np.load(filename)
    elif use_cached and osp.isfile(filename):
        dataset = np.load(filename)
        print("loaded data from saved file", filename)
    elif action_plus_random_sampling:
        if env_id is not None:
            import gym
            env = gym.make(env_id)
        else:
            if env_kwargs is None:  # the None default would crash **env_kwargs
                env_kwargs = {}
            env = env_class(**env_kwargs)
            env = ImageEnv(
                env,
                imsize,
                transpose=True,
                init_camera=init_camera,
                normalize=True,
            )
        action_sampled_data = int(N * ratio_action_sample_to_random)
        dataset = np.zeros((N, imsize * imsize * 3), dtype=np.uint8)
        print('Action Space Sampling')
        for i in range(action_sampled_data):
            goal = env.sample_goal()
            env.set_to_goal(goal)
            img = env._get_flat_img()
            dataset[i, :] = unormalize_image(img)
            if show:
                cv2.imshow('img', img.reshape(3, imsize, imsize).transpose())
                cv2.waitKey(1)
            print(i)
        policy = RandomPolicy(env.action_space)
        es = OUStrategy(action_space=env.action_space, theta=0)
        exploration_policy = PolicyWrappedWithExplorationStrategy(
            exploration_strategy=es,
            policy=policy,
        )
        print('Random Sampling')
        for i in range(action_sampled_data, N):
            if i % 20 == 0:
                env.reset()
                exploration_policy.reset()
            for _ in range(10):
                action = exploration_policy.get_action()[0]
                env.wrapped_env.step(action)
            # After the rollout, pin only the door angle to a sampled goal.
            goal = env.sample_goal()
            env.set_to_goal_angle(goal['state_desired_goal'])
            img = env._get_flat_img()
            dataset[i, :] = unormalize_image(img)
            if show:
                cv2.imshow('img', img.reshape(3, imsize, imsize).transpose())
                cv2.waitKey(1)
            print(i)
        env._wrapped_env.min_y_pos = .5
        info['env'] = env
    else:
        raise NotImplementedError()
    n = int(N * test_p)
    train_dataset = dataset[:n, :]
    test_dataset = dataset[n:, :]
    return train_dataset, test_dataset, info
Esempio n. 10
0
def generate_vae_dataset(
    env_class,
    N=10000,
    test_p=0.9,
    use_cached=True,
    imsize=84,
    show=False,
    init_camera=sawyer_init_camera_zoomed_in,
    dataset_path=None,
    env_kwargs=None,
    oracle_dataset=False,
    n_random_steps=100,
):
    """Collect (or load) images for any image-wrappable env class.

    With ``oracle_dataset`` each frame comes from ``set_to_goal`` on a
    sampled goal; otherwise from ``n_random_steps`` random actions after
    a reset. In both cases one final random step renders the frame.

    :param env_class: environment class to instantiate.
    :param N: number of images to collect (overridden by the size of a
        loaded dataset when ``dataset_path`` is given).
    :param test_p: fraction of the data assigned to the train split.
    :param use_cached: reuse a previously saved /tmp dataset if present.
    :param imsize: side length of the square rendered image.
    :param show: display each frame with OpenCV while collecting.
    :param init_camera: camera initialization function for rendering.
    :param dataset_path: optional explicit dataset path (local or s3).
    :param env_kwargs: optional kwargs for ``env_class``.
    :param oracle_dataset: sample frames from goals instead of rollouts.
    :param n_random_steps: random steps per frame in rollout mode.
    :return: (train_dataset, test_dataset, info).
    """
    if env_kwargs is None:
        env_kwargs = {}
    filename = "/tmp/{}_{}_{}_oracle{}.npy".format(
        env_class.__name__,
        str(N),
        init_camera.__name__,
        oracle_dataset,
    )
    info = {}
    if dataset_path is not None:
        filename = local_path_from_s3_or_local_path(dataset_path)
        dataset = np.load(filename)
        N = dataset.shape[0]
    elif use_cached and osp.isfile(filename):
        dataset = np.load(filename)
        print("loaded data from saved file", filename)
    else:
        now = time.time()
        env = env_class(**env_kwargs)
        env = ImageEnv(
            env,
            imsize,
            init_camera=init_camera,
            transpose=True,
            normalize=True,
        )
        env.reset()
        info['env'] = env

        dataset = np.zeros((N, imsize * imsize * 3))
        for i in range(N):
            if oracle_dataset:
                goal = env.sample_goal()
                env.set_to_goal(goal)
            else:
                env.reset()
                for _ in range(n_random_steps):
                    env.step(env.action_space.sample())
            # One extra step to obtain the observation dict with the image.
            obs = env.step(env.action_space.sample())[0]
            img = obs['image_observation']
            dataset[i, :] = img
            if show:
                # Use imsize rather than a hard-coded 84 so `show`
                # works for any image size.
                img = img.reshape(3, imsize, imsize).transpose()
                img = img[::-1, :, ::-1]
                cv2.imshow('img', img)
                cv2.waitKey(1)
        print("done making training data", filename, time.time() - now)
        np.save(filename, dataset)

    n = int(N * test_p)
    train_dataset = dataset[:n, :]
    test_dataset = dataset[n:, :]
    return train_dataset, test_dataset, info
Esempio n. 11
0
def rl_context_experiment(variant):
    import rlkit.torch.pytorch_util as ptu
    from rlkit.torch.td3.td3 import TD3 as TD3Trainer
    from rlkit.torch.sac.sac import SACTrainer
    from rlkit.torch.torch_rl_algorithm import TorchBatchRLAlgorithm
    from rlkit.torch.networks import ConcatMlp, TanhMlpPolicy
    from rlkit.torch.sac.policies import TanhGaussianPolicy
    from rlkit.torch.sac.policies import MakeDeterministic

    preprocess_rl_variant(variant)
    max_path_length = variant['max_path_length']
    observation_key = variant.get('observation_key', 'latent_observation')
    desired_goal_key = variant.get('desired_goal_key', 'latent_desired_goal')
    achieved_goal_key = variant.get('achieved_goal_key', 'latent_achieved_goal')

    contextual_mdp = variant.get('contextual_mdp', True)
    print("contextual_mdp:", contextual_mdp)

    mask_variant = variant.get('mask_variant', {})
    mask_conditioned = mask_variant.get('mask_conditioned', False)
    print("mask_conditioned:", mask_conditioned)

    if mask_conditioned:
        assert contextual_mdp

    if 'sac' in variant['algorithm'].lower():
        rl_algo = 'sac'
    elif 'td3' in variant['algorithm'].lower():
        rl_algo = 'td3'
    else:
        raise NotImplementedError
    print("RL algorithm:", rl_algo)

    ### load the example dataset, if running checkpoints ###
    if 'ckpt' in variant:
        import os.path as osp
        example_set_variant = variant.get('example_set_variant', dict())
        example_set_variant['use_cache'] = True
        example_set_variant['cache_path'] = osp.join(variant['ckpt'], 'example_dataset.npy')

    if mask_conditioned:
        env = get_envs(variant)
        mask_format = mask_variant['param_variant']['mask_format']
        assert mask_format in ['vector', 'matrix', 'distribution', 'cond_distribution']
        goal_dim = env.observation_space.spaces[desired_goal_key].low.size
        if mask_format in ['vector']:
            context_dim_for_networks = goal_dim + goal_dim
        elif mask_format in ['matrix', 'distribution', 'cond_distribution']:
            context_dim_for_networks = goal_dim + (goal_dim * goal_dim)
        else:
            raise TypeError

        if 'ckpt' in variant:
            from rlkit.misc.asset_loader import local_path_from_s3_or_local_path
            import os.path as osp

            filename = local_path_from_s3_or_local_path(osp.join(variant['ckpt'], 'masks.npy'))
            masks = np.load(filename, allow_pickle=True)[()]
        else:
            masks = get_mask_params(
                env=env,
                example_set_variant=variant['example_set_variant'],
                param_variant=mask_variant['param_variant'],
            )

        mask_keys = list(masks.keys())
        context_keys = [desired_goal_key] + mask_keys
    else:
        context_keys = [desired_goal_key]


    def contextual_env_distrib_and_reward(mode='expl'):
        assert mode in ['expl', 'eval']
        env = get_envs(variant)

        if mode == 'expl':
            goal_sampling_mode = variant.get('expl_goal_sampling_mode', None)
        elif mode == 'eval':
            goal_sampling_mode = variant.get('eval_goal_sampling_mode', None)
        if goal_sampling_mode not in [None, 'example_set']:
            env.goal_sampling_mode = goal_sampling_mode

        mask_ids_for_training = mask_variant.get('mask_ids_for_training', None)

        if mask_conditioned:
            context_distrib = MaskDictDistribution(
                env,
                desired_goal_keys=[desired_goal_key],
                mask_format=mask_format,
                masks=masks,
                max_subtasks_to_focus_on=mask_variant.get('max_subtasks_to_focus_on', None),
                prev_subtask_weight=mask_variant.get('prev_subtask_weight', None),
                mask_distr=mask_variant.get('train_mask_distr', None),
                mask_ids=mask_ids_for_training,
            )
            reward_fn = ContextualMaskingRewardFn(
                achieved_goal_from_observation=IndexIntoAchievedGoal(achieved_goal_key),
                desired_goal_key=desired_goal_key,
                achieved_goal_key=achieved_goal_key,
                mask_keys=mask_keys,
                mask_format=mask_format,
                use_g_for_mean=mask_variant['use_g_for_mean'],
                use_squared_reward=mask_variant.get('use_squared_reward', False),
            )
        else:
            if goal_sampling_mode == 'example_set':
                example_dataset = gen_example_sets(get_envs(variant), variant['example_set_variant'])
                assert len(example_dataset['list_of_waypoints']) == 1
                from rlkit.envs.contextual.set_distributions import GoalDictDistributionFromSet
                context_distrib = GoalDictDistributionFromSet(
                    example_dataset['list_of_waypoints'][0],
                    desired_goal_keys=[desired_goal_key],
                )
            else:
                context_distrib = GoalDictDistributionFromMultitaskEnv(
                    env,
                    desired_goal_keys=[desired_goal_key],
                )
            reward_fn = ContextualRewardFnFromMultitaskEnv(
                env=env,
                achieved_goal_from_observation=IndexIntoAchievedGoal(achieved_goal_key),
                desired_goal_key=desired_goal_key,
                achieved_goal_key=achieved_goal_key,
                additional_obs_keys=variant['contextual_replay_buffer_kwargs'].get('observation_keys', None),
            )
        diag_fn = GoalConditionedDiagnosticsToContextualDiagnostics(
            env.goal_conditioned_diagnostics,
            desired_goal_key=desired_goal_key,
            observation_key=observation_key,
        )
        env = ContextualEnv(
            env,
            context_distribution=context_distrib,
            reward_fn=reward_fn,
            observation_key=observation_key,
            contextual_diagnostics_fns=[diag_fn],
            update_env_info_fn=delete_info if not variant.get('keep_env_infos', False) else None,
        )
        return env, context_distrib, reward_fn

    env, context_distrib, reward_fn = contextual_env_distrib_and_reward(mode='expl')
    eval_env, eval_context_distrib, _ = contextual_env_distrib_and_reward(mode='eval')

    if mask_conditioned:
        obs_dim = (
            env.observation_space.spaces[observation_key].low.size
            + context_dim_for_networks
        )
    elif contextual_mdp:
        obs_dim = (
            env.observation_space.spaces[observation_key].low.size
            + env.observation_space.spaces[desired_goal_key].low.size
        )
    else:
        obs_dim = env.observation_space.spaces[observation_key].low.size

    action_dim = env.action_space.low.size

    if 'ckpt' in variant and 'ckpt_epoch' in variant:
        from rlkit.misc.asset_loader import local_path_from_s3_or_local_path
        import os.path as osp

        ckpt_epoch = variant['ckpt_epoch']
        if ckpt_epoch is not None:
            epoch = variant['ckpt_epoch']
            filename = local_path_from_s3_or_local_path(osp.join(variant['ckpt'], 'itr_%d.pkl' % epoch))
        else:
            filename = local_path_from_s3_or_local_path(osp.join(variant['ckpt'], 'params.pkl'))
        print("Loading ckpt from", filename)
        data = torch.load(filename)
        qf1 = data['trainer/qf1']
        qf2 = data['trainer/qf2']
        target_qf1 = data['trainer/target_qf1']
        target_qf2 = data['trainer/target_qf2']
        policy = data['trainer/policy']
        eval_policy = data['evaluation/policy']
        expl_policy = data['exploration/policy']
    else:
        qf1 = ConcatMlp(
            input_size=obs_dim + action_dim,
            output_size=1,
            **variant['qf_kwargs']
        )
        qf2 = ConcatMlp(
            input_size=obs_dim + action_dim,
            output_size=1,
            **variant['qf_kwargs']
        )
        target_qf1 = ConcatMlp(
            input_size=obs_dim + action_dim,
            output_size=1,
            **variant['qf_kwargs']
        )
        target_qf2 = ConcatMlp(
            input_size=obs_dim + action_dim,
            output_size=1,
            **variant['qf_kwargs']
        )
        if rl_algo == 'td3':
            policy = TanhMlpPolicy(
                input_size=obs_dim,
                output_size=action_dim,
                **variant['policy_kwargs']
            )
            target_policy = TanhMlpPolicy(
                input_size=obs_dim,
                output_size=action_dim,
                **variant['policy_kwargs']
            )
            expl_policy = create_exploration_policy(
                env, policy,
                exploration_version=variant['exploration_type'],
                exploration_noise=variant['exploration_noise'],
            )
            eval_policy = policy
        elif rl_algo == 'sac':
            policy = TanhGaussianPolicy(
                obs_dim=obs_dim,
                action_dim=action_dim,
                **variant['policy_kwargs']
            )
            expl_policy = policy
            eval_policy = MakeDeterministic(policy)

    post_process_mask_fn = partial(
        full_post_process_mask_fn,
        mask_conditioned=mask_conditioned,
        mask_variant=mask_variant,
        context_distrib=context_distrib,
        context_key=desired_goal_key,
        achieved_goal_key=achieved_goal_key,
    )

    def context_from_obs_dict_fn(obs_dict):
        context_dict = {
            desired_goal_key: obs_dict[achieved_goal_key]
        }

        if mask_conditioned:
            sample_masks_for_relabeling = mask_variant.get('sample_masks_for_relabeling', True)
            if sample_masks_for_relabeling:
                batch_size = next(iter(obs_dict.values())).shape[0]
                sampled_contexts = context_distrib.sample(batch_size)
                for mask_key in mask_keys:
                    context_dict[mask_key] = sampled_contexts[mask_key]
            else:
                for mask_key in mask_keys:
                    context_dict[mask_key] = obs_dict[mask_key]

        return context_dict

    def concat_context_to_obs(batch, replay_buffer=None, obs_dict=None, next_obs_dict=None, new_contexts=None):
        obs = batch['observations']
        next_obs = batch['next_observations']
        batch_size = obs.shape[0]
        if mask_conditioned:
            if obs_dict is not None and new_contexts is not None:
                if not mask_variant.get('relabel_masks', True):
                    for k in mask_keys:
                        new_contexts[k] = next_obs_dict[k][:]
                    batch.update(new_contexts)
                if not mask_variant.get('relabel_goals', True):
                    new_contexts[desired_goal_key] = next_obs_dict[desired_goal_key][:]
                    batch.update(new_contexts)

                new_contexts = post_process_mask_fn(obs_dict, new_contexts)
                batch.update(new_contexts)

            if mask_format in ['vector', 'matrix']:
                goal = batch[desired_goal_key]
                mask = batch['mask'].reshape((batch_size, -1))
                batch['observations'] = np.concatenate([obs, goal, mask], axis=1)
                batch['next_observations'] = np.concatenate([next_obs, goal, mask], axis=1)
            elif mask_format == 'distribution':
                goal = batch[desired_goal_key]
                sigma_inv = batch['mask_sigma_inv'].reshape((batch_size, -1))
                batch['observations'] = np.concatenate([obs, goal, sigma_inv], axis=1)
                batch['next_observations'] = np.concatenate([next_obs, goal, sigma_inv], axis=1)
            elif mask_format == 'cond_distribution':
                goal = batch[desired_goal_key]
                mu_w = batch['mask_mu_w']
                mu_g = batch['mask_mu_g']
                mu_A = batch['mask_mu_mat']
                sigma_inv = batch['mask_sigma_inv']
                if mask_variant['use_g_for_mean']:
                    mu_w_given_g = goal
                else:
                    mu_w_given_g = mu_w + np.squeeze(mu_A @ np.expand_dims(goal - mu_g, axis=-1), axis=-1)
                sigma_w_given_g_inv = sigma_inv.reshape((batch_size, -1))
                batch['observations'] = np.concatenate([obs, mu_w_given_g, sigma_w_given_g_inv], axis=1)
                batch['next_observations'] = np.concatenate([next_obs, mu_w_given_g, sigma_w_given_g_inv], axis=1)
            else:
                raise NotImplementedError
        elif contextual_mdp:
            goal = batch[desired_goal_key]
            batch['observations'] = np.concatenate([obs, goal], axis=1)
            batch['next_observations'] = np.concatenate([next_obs, goal], axis=1)
        else:
            batch['observations'] = obs
            batch['next_observations'] = next_obs

        return batch

    # Ensure the replay buffer stores both the policy observation and the
    # achieved goal (the latter is needed for context relabeling).
    if 'observation_keys' not in variant['contextual_replay_buffer_kwargs']:
        variant['contextual_replay_buffer_kwargs']['observation_keys'] = []
    observation_keys = variant['contextual_replay_buffer_kwargs']['observation_keys']
    if observation_key not in observation_keys:
        observation_keys.append(observation_key)
    if achieved_goal_key not in observation_keys:
        observation_keys.append(achieved_goal_key)

    # Relabeling replay buffer: resamples contexts from stored observations,
    # recomputes rewards with reward_fn, and concatenates context onto the
    # observations via concat_context_to_obs before handing batches to SAC/TD3.
    replay_buffer = ContextualRelabelingReplayBuffer(
        env=env,
        context_keys=context_keys,
        context_distribution=context_distrib,
        sample_context_from_obs_dict_fn=context_from_obs_dict_fn,
        reward_fn=reward_fn,
        post_process_batch_fn=concat_context_to_obs,
        **variant['contextual_replay_buffer_kwargs']
    )

    if rl_algo == 'td3':
        trainer = TD3Trainer(
            policy=policy,
            qf1=qf1,
            qf2=qf2,
            target_qf1=target_qf1,
            target_qf2=target_qf2,
            target_policy=target_policy,
            **variant['td3_trainer_kwargs']
        )
    elif rl_algo == 'sac':
        trainer = SACTrainer(
            env=env,
            policy=policy,
            qf1=qf1,
            qf2=qf2,
            target_qf1=target_qf1,
            target_qf2=target_qf2,
            **variant['sac_trainer_kwargs']
        )
    # NOTE(review): any rl_algo other than 'td3'/'sac' leaves `trainer`
    # unbound and fails later with a NameError — consider raising explicitly.

    def create_path_collector(
            env,
            policy,
            mode='expl',
            mask_kwargs=None,
    ):
        """Build the path collector for ``env``/``policy``.

        :param env: environment to roll out in.
        :param policy: policy used to generate actions.
        :param mode: 'expl' or 'eval'; selects which mask defaults and which
            mask sampler (expl vs eval context distribution) to use.
        :param mask_kwargs: optional per-collector overrides for
            'rollout_mask_order', 'mask_distr', and 'mask_ids'.
        :return: a MaskPathCollector or ContextualPathCollector.
        """
        assert mode in ['expl', 'eval']
        if mask_kwargs is None:
            # Avoid the original mutable default argument; behavior unchanged.
            mask_kwargs = {}

        save_env_in_snapshot = variant.get('save_env_in_snapshot', True)

        if mask_conditioned:
            # Explicit overrides in mask_kwargs win over mask_variant defaults.
            if 'rollout_mask_order' in mask_kwargs:
                rollout_mask_order = mask_kwargs['rollout_mask_order']
            else:
                if mode == 'expl':
                    rollout_mask_order = mask_variant.get('rollout_mask_order_for_expl', 'fixed')
                elif mode == 'eval':
                    rollout_mask_order = mask_variant.get('rollout_mask_order_for_eval', 'fixed')
                else:
                    raise TypeError

            if 'mask_distr' in mask_kwargs:
                mask_distr = mask_kwargs['mask_distr']
            else:
                if mode == 'expl':
                    mask_distr = mask_variant['expl_mask_distr']
                elif mode == 'eval':
                    mask_distr = mask_variant['eval_mask_distr']
                else:
                    raise TypeError

            if 'mask_ids' in mask_kwargs:
                mask_ids = mask_kwargs['mask_ids']
            else:
                if mode == 'expl':
                    mask_ids = mask_variant.get('mask_ids_for_expl', None)
                elif mode == 'eval':
                    mask_ids = mask_variant.get('mask_ids_for_eval', None)
                else:
                    raise TypeError

            prev_subtask_weight = mask_variant.get('prev_subtask_weight', None)
            max_subtasks_to_focus_on = mask_variant.get('max_subtasks_to_focus_on', None)
            max_subtasks_per_rollout = mask_variant.get('max_subtasks_per_rollout', None)

            # BUG FIX: the original rebound the *parameter* `mode` here
            # (mode = mask_variant.get('context_post_process_mode', None)),
            # so the `mode == 'expl'` test below could never be true and the
            # exploration collector silently sampled masks from the eval
            # context distribution. Use a distinct local name instead.
            post_process_mode = mask_variant.get('context_post_process_mode', None)
            if post_process_mode in ['dilute_prev_subtasks_uniform', 'dilute_prev_subtasks_fixed']:
                prev_subtask_weight = 0.5

            return MaskPathCollector(
                env,
                policy,
                observation_key=observation_key,
                context_keys_for_policy=context_keys,
                concat_context_to_obs_fn=concat_context_to_obs,
                save_env_in_snapshot=save_env_in_snapshot,
                mask_sampler=(context_distrib if mode == 'expl' else eval_context_distrib),
                mask_distr=mask_distr.copy(),
                mask_ids=mask_ids,
                max_path_length=max_path_length,
                rollout_mask_order=rollout_mask_order,
                prev_subtask_weight=prev_subtask_weight,
                max_subtasks_to_focus_on=max_subtasks_to_focus_on,
                max_subtasks_per_rollout=max_subtasks_per_rollout,
            )
        elif contextual_mdp:
            # Goal-conditioned but unmasked: policy sees obs + goal context.
            return ContextualPathCollector(
                env,
                policy,
                observation_key=observation_key,
                context_keys_for_policy=context_keys,
                save_env_in_snapshot=save_env_in_snapshot,
            )
        else:
            # Plain MDP: no context keys are passed to the policy.
            return ContextualPathCollector(
                env,
                policy,
                observation_key=observation_key,
                context_keys_for_policy=[],
                save_env_in_snapshot=save_env_in_snapshot,
            )

    expl_path_collector = create_path_collector(env, expl_policy, mode='expl')
    eval_path_collector = create_path_collector(eval_env, eval_policy, mode='eval')

    algorithm = TorchBatchRLAlgorithm(
        trainer=trainer,
        exploration_env=env,
        evaluation_env=eval_env,
        exploration_data_collector=expl_path_collector,
        evaluation_data_collector=eval_path_collector,
        replay_buffer=replay_buffer,
        max_path_length=max_path_length,
        **variant['algo_kwargs']
    )

    # Move all trainer networks to the configured device (CPU/GPU).
    algorithm.to(ptu.device)

    if variant.get("save_video", True):
        save_period = variant.get('save_video_period', 50)
        dump_video_kwargs = variant.get("dump_video_kwargs", dict())
        # Rollout videos are cut at the training horizon.
        dump_video_kwargs['horizon'] = max_path_length

        renderer = EnvRenderer(**variant.get('renderer_kwargs', {}))

        def add_images(env, state_distribution):
            """Wrap `env` so observations and sampled goals carry rendered images."""
            base_env = env.env
            goal_distribution_with_images = AddImageDistribution(
                env=base_env,
                base_distribution=state_distribution,
                image_goal_key='image_desired_goal',
                renderer=renderer,
            )
            env_with_images = InsertImagesEnv(
                base_env,
                renderers={'image_observation': renderer},
            )
            return ContextualEnv(
                env_with_images,
                context_distribution=goal_distribution_with_images,
                reward_fn=reward_fn,
                observation_key=observation_key,
                update_env_info_fn=delete_info,
            )

        # Evaluation env wrapped so rollouts also produce rendered images.
        img_eval_env = add_images(eval_env, eval_context_distrib)

        if variant.get('log_eval_video', True):
            video_path_collector = create_path_collector(img_eval_env, eval_policy, mode='eval')
            rollout_function = video_path_collector._rollout_fn
            eval_video_func = get_save_video_function(
                rollout_function,
                img_eval_env,
                eval_policy,
                tag="eval",
                imsize=variant['renderer_kwargs']['width'],
                # NOTE(review): this video uses 'CHW' while the mask-specific
                # videos below use 'HWC' — confirm which layout the renderer
                # actually emits.
                image_format='CHW',
                save_video_period=save_period,
                **dump_video_kwargs
            )
            algorithm.post_train_funcs.append(eval_video_func)

        # additional eval videos for mask conditioned case
        if mask_conditioned:
            default_list = [
                'atomic',
                'atomic_seq',
                'cumul_seq',
                'full',
            ]
            eval_rollouts_for_videos = mask_variant.get('eval_rollouts_for_videos', default_list)
            for key in eval_rollouts_for_videos:
                assert key in default_list

            # Video of cumulative, sequential mask rollouts.
            if 'cumul_seq' in eval_rollouts_for_videos:
                video_path_collector = create_path_collector(
                    img_eval_env,
                    eval_policy,
                    mode='eval',
                    mask_kwargs=dict(
                        mask_distr=dict(
                            cumul_seq=1.0
                        ),
                    ),
                )
                rollout_function = video_path_collector._rollout_fn
                eval_video_func = get_save_video_function(
                    rollout_function,
                    img_eval_env,
                    eval_policy,
                    # NOTE(review): mask_conditioned is always True in this
                    # branch, so the tag is always "eval_cumul".
                    tag="eval_cumul" if mask_conditioned else "eval",
                    imsize=variant['renderer_kwargs']['width'],
                    image_format='HWC',
                    save_video_period=save_period,
                    **dump_video_kwargs
                )
                algorithm.post_train_funcs.append(eval_video_func)

            # Video of full-mask rollouts.
            if 'full' in eval_rollouts_for_videos:
                video_path_collector = create_path_collector(
                    img_eval_env,
                    eval_policy,
                    mode='eval',
                    mask_kwargs=dict(
                        mask_distr=dict(
                            full=1.0
                        ),
                    ),
                )
                rollout_function = video_path_collector._rollout_fn
                eval_video_func = get_save_video_function(
                    rollout_function,
                    img_eval_env,
                    eval_policy,
                    tag="eval_full",
                    imsize=variant['renderer_kwargs']['width'],
                    image_format='HWC',
                    save_video_period=save_period,
                    **dump_video_kwargs
                )
                algorithm.post_train_funcs.append(eval_video_func)

            # Video of atomic, sequential mask rollouts.
            if 'atomic_seq' in eval_rollouts_for_videos:
                video_path_collector = create_path_collector(
                    img_eval_env,
                    eval_policy,
                    mode='eval',
                    mask_kwargs=dict(
                        mask_distr=dict(
                            atomic_seq=1.0
                        ),
                    ),
                )
                rollout_function = video_path_collector._rollout_fn
                eval_video_func = get_save_video_function(
                    rollout_function,
                    img_eval_env,
                    eval_policy,
                    tag="eval_atomic",
                    imsize=variant['renderer_kwargs']['width'],
                    image_format='HWC',
                    save_video_period=save_period,
                    **dump_video_kwargs
                )
                algorithm.post_train_funcs.append(eval_video_func)

        # Exploration video (skipped entirely for eval-only runs).
        if variant.get('log_expl_video', True) and not variant['algo_kwargs'].get('eval_only', False):
            img_expl_env = add_images(env, context_distrib)
            video_path_collector = create_path_collector(img_expl_env, expl_policy, mode='expl')
            rollout_function = video_path_collector._rollout_fn
            expl_video_func = get_save_video_function(
                rollout_function,
                img_expl_env,
                expl_policy,
                tag="expl",
                imsize=variant['renderer_kwargs']['width'],
                image_format='CHW',
                save_video_period=save_period,
                **dump_video_kwargs
            )
            algorithm.post_train_funcs.append(expl_video_func)

    # Extra eval collectors, one per mask configuration, used purely for
    # logging per-mask diagnostics — they do not feed the replay buffer.
    addl_collectors = []
    addl_log_prefixes = []
    if mask_conditioned and mask_variant.get('log_mask_diagnostics', True):
        default_list = [
            'atomic',
            'atomic_seq',
            'cumul_seq',
            'full',
        ]
        eval_rollouts_to_log = mask_variant.get('eval_rollouts_to_log', default_list)
        for key in eval_rollouts_to_log:
            assert key in default_list

        # atomic masks
        if 'atomic' in eval_rollouts_to_log:
            # One collector per individual mask id.
            for mask_id in eval_path_collector.mask_ids:
                mask_kwargs=dict(
                    mask_ids=[mask_id],
                    mask_distr=dict(
                        atomic=1.0,
                    ),
                )
                collector = create_path_collector(eval_env, eval_policy, mode='eval', mask_kwargs=mask_kwargs)
                addl_collectors.append(collector)
            addl_log_prefixes += [
                # NOTE(review): ''.join(str(mask_id)) is just str(mask_id);
                # the join is a no-op.
                'mask_{}/'.format(''.join(str(mask_id)))
                for mask_id in eval_path_collector.mask_ids
            ]

        # full mask
        if 'full' in eval_rollouts_to_log:
            mask_kwargs=dict(
                mask_distr=dict(
                    full=1.0,
                ),
            )
            collector = create_path_collector(eval_env, eval_policy, mode='eval', mask_kwargs=mask_kwargs)
            addl_collectors.append(collector)
            addl_log_prefixes.append('mask_full/')

        # cumulative, sequential mask
        if 'cumul_seq' in eval_rollouts_to_log:
            mask_kwargs=dict(
                rollout_mask_order='fixed',
                mask_distr=dict(
                    cumul_seq=1.0,
                ),
            )
            collector = create_path_collector(eval_env, eval_policy, mode='eval', mask_kwargs=mask_kwargs)
            addl_collectors.append(collector)
            addl_log_prefixes.append('mask_cumul_seq/')

        # atomic, sequential mask
        if 'atomic_seq' in eval_rollouts_to_log:
            mask_kwargs=dict(
                rollout_mask_order='fixed',
                mask_distr=dict(
                    atomic_seq=1.0,
                ),
            )
            collector = create_path_collector(eval_env, eval_policy, mode='eval', mask_kwargs=mask_kwargs)
            addl_collectors.append(collector)
            addl_log_prefixes.append('mask_atomic_seq/')

        def get_mask_diagnostics(unused):
            """Roll out each extra collector and log the final per-env-info
            means under that collector's prefix."""
            from rlkit.core.logging import append_log, add_prefix, OrderedDict
            diagnostics = OrderedDict()
            for prefix, collector in zip(addl_log_prefixes, addl_collectors):
                paths = collector.collect_new_paths(
                    max_path_length,
                    variant['algo_kwargs']['num_eval_steps_per_epoch'],
                    discard_incomplete_paths=True,
                )
                all_stats = eval_env.get_diagnostics(paths)

                # Keep only the "final ... Mean" env-info statistics,
                # preserving their original order.
                filtered_stats = OrderedDict(
                    (stat_key, stat_value)
                    for stat_key, stat_value in all_stats.items()
                    if ('env_infos' in stat_key
                        and 'final' in stat_key
                        and 'Mean' in stat_key)
                )

                append_log(diagnostics, add_prefix(filtered_stats, prefix))

            for collector in addl_collectors:
                collector.end_epoch(0)
            return diagnostics

        algorithm._eval_get_diag_fns.append(get_mask_diagnostics)

    # Eval-only mode driven from saved checkpoints: before each epoch's eval,
    # reload the policy snapshot saved for that epoch.
    if 'ckpt' in variant:
        from rlkit.misc.asset_loader import local_path_from_s3_or_local_path
        import os.path as osp
        assert variant['algo_kwargs'].get('eval_only', False)

        def update_networks(algo, epoch):
            # With a fixed ckpt_epoch there is nothing to refresh per epoch.
            if 'ckpt_epoch' in variant:
                return

            if epoch % algo._eval_epoch_freq == 0:
                filename = local_path_from_s3_or_local_path(osp.join(variant['ckpt'], 'itr_%d.pkl' % epoch))
                print("Loading ckpt from", filename)
                data = torch.load(filename)#, map_location='cuda:1')
                eval_policy = data['evaluation/policy']
                eval_policy.to(ptu.device)
                # Point every eval collector at the freshly loaded policy.
                algo.eval_data_collector._policy = eval_policy
                for collector in addl_collectors:
                    collector._policy = eval_policy

        # Run first among post-train hooks so the video/diagnostic hooks
        # added above see the reloaded policy.
        algorithm.post_train_funcs.insert(0, update_networks)

    algorithm.train()
Esempio n. 12
0
def generate_vae_dataset(
    N=10000,
    test_p=0.9,
    use_cached=True,
    imsize=84,
    show=False,
    dataset_path=None,
    policy_path=None,
    action_space_sampling=False,
    env_class=SawyerDoorEnv,
    env_kwargs=None,
    init_camera=sawyer_door_env_camera_v2,
):
    """Create (or load from cache) a dataset of door-env images for VAE training.

    :param N: number of images to generate.
    :param test_p: fraction of the dataset used for the train split; the
        remainder becomes the test split.
    :param use_cached: reuse a previously saved dataset file if one exists.
    :param imsize: side length of the (square) rendered images.
    :param show: if True, display each image with OpenCV while generating.
    :param dataset_path: optional explicit dataset path to load instead.
    :param policy_path: only selects the cache filename; no policy is
        actually loaded here.
    :param action_space_sampling: only selects the cache filename.
    :param env_class: environment class used to generate images.
    :param env_kwargs: keyword arguments for ``env_class``; may be None.
    :param init_camera: camera-initialization function for rendering.
    :return: (train_dataset, test_dataset, info) where the datasets are
        uint8 arrays of shape (n, imsize * imsize * 3) and info holds the
        environment used for generation (when data was freshly generated).
    """
    # The cache filename encodes how the data was collected.
    if policy_path is not None:
        filename = "/tmp/sawyer_door_pull_open_oracle+random_policy_data_closer_zoom_action_limited" + str(
            N) + ".npy"
    elif action_space_sampling:
        filename = "/tmp/sawyer_door_pull_open_zoomed_in_action_space_sampling" + str(
            N) + ".npy"
    else:
        filename = "/tmp/sawyer_door_pull_open" + str(N) + ".npy"
    info = {}
    if dataset_path is not None:
        filename = local_path_from_s3_or_local_path(dataset_path)
        dataset = np.load(filename)
    elif use_cached and osp.isfile(filename):
        dataset = np.load(filename)
        print("loaded data from saved file", filename)
    else:
        now = time.time()
        if env_kwargs is None:
            # BUG FIX: env_class(**env_kwargs) raised a TypeError whenever the
            # caller relied on the default env_kwargs=None (the sibling
            # generator in this file already guards against this).
            env_kwargs = dict()
        env = env_class(**env_kwargs)
        env = ImageEnv(
            env,
            imsize,
            transpose=True,
            init_camera=init_camera,
            normalize=True,
        )
        info['env'] = env
        # Random actions perturbed by OU noise (theta=0) for exploration.
        policy = RandomPolicy(env.action_space)
        es = OUStrategy(action_space=env.action_space, theta=0)
        exploration_policy = PolicyWrappedWithExplorationStrategy(
            exploration_strategy=es,
            policy=policy,
        )
        env.wrapped_env.reset()
        dataset = np.zeros((N, imsize * imsize * 3), dtype=np.uint8)
        for i in range(N):
            # Periodically reset so the data covers many door configurations.
            if i % 20 == 0:
                env.reset_model()
                exploration_policy.reset()
            for _ in range(10):
                action = exploration_policy.get_action()[0]
                env.wrapped_env.step(action)
            # env.set_to_goal_angle(env.get_goal()['state_desired_goal'])
            img = env._get_flat_img()
            dataset[i, :] = unormalize_image(img)
            if show:
                cv2.imshow('img', img.reshape(3, 84, 84).transpose())
                cv2.waitKey(1)
            print(i)
        print("done making training data", filename, time.time() - now)
        np.save(filename, dataset)

    n = int(N * test_p)
    train_dataset = dataset[:n, :]
    test_dataset = dataset[n:, :]
    return train_dataset, test_dataset, info
Esempio n. 13
0
def generate_vae_dataset_from_params(
    env_class=None,
    env_kwargs=None,
    env_id=None,
    N=10000,
    test_p=0.9,
    use_cached=True,
    imsize=84,
    num_channels=1,
    show=False,
    init_camera=None,
    dataset_path=None,
    oracle_dataset=False,
    n_random_steps=100,
    vae_dataset_specific_env_kwargs=None,
    save_file_prefix=None,
    use_linear_dynamics=False,
):
    """Create (or load from cache) a VAE image dataset from pre-sampled goals.

    :param env_class: environment class to instantiate when env_id is None.
    :param env_kwargs: default kwargs merged into the dataset-specific ones.
    :param env_id: gym env id; takes precedence over env_class.
    :param N: number of images (overridden by the loaded file's size when
        dataset_path is given).
    :param test_p: fraction of the dataset used for the train split.
    :param use_cached: reuse a previously saved dataset file if one exists.
    :param imsize: side length of the (square) rendered images.
    :param num_channels: channels per image in the stored dataset.
    :param show: if True, display each image with OpenCV while generating.
    :param init_camera: camera-initialization function for rendering.
    :param dataset_path: optional explicit dataset path to load instead.
    :param oracle_dataset: must be True; only the pre-sampled-goal (oracle)
        pipeline is implemented here.
    :param n_random_steps: unused in this oracle pipeline.
    :param vae_dataset_specific_env_kwargs: env kwargs that override
        env_kwargs for dataset generation only.
    :param save_file_prefix: cache filename prefix; defaults to env_id, then
        to the env class name.
    :param use_linear_dynamics: unused in this oracle pipeline.
    :return: (train_dataset, test_dataset, info).
    """
    from multiworld.core.image_env import ImageEnv, unormalize_image
    from rlkit.misc.asset_loader import local_path_from_s3_or_local_path
    import time

    # Only the oracle (pre-sampled goal image) pipeline is implemented.
    assert oracle_dataset == True

    if env_kwargs is None:
        env_kwargs = {}
    # Fall back from env_id to the class name for the cache filename prefix.
    if save_file_prefix is None:
        save_file_prefix = env_id
    if save_file_prefix is None:
        save_file_prefix = env_class.__name__
    filename = "/tmp/{}_N{}_{}_imsize{}_oracle{}.npy".format(
        save_file_prefix,
        str(N),
        init_camera.__name__ if init_camera else '',
        imsize,
        oracle_dataset,
    )
    info = {}
    if dataset_path is not None:
        filename = local_path_from_s3_or_local_path(dataset_path)
        dataset = np.load(filename)
        np.random.shuffle(dataset)
        N = dataset.shape[0]
    elif use_cached and osp.isfile(filename):
        dataset = np.load(filename)
        np.random.shuffle(dataset)
        print("loaded data from saved file", filename)
    else:
        now = time.time()

        if env_id is not None:
            import gym
            import multiworld
            multiworld.register_all_envs()
            env = gym.make(env_id)
        else:
            if vae_dataset_specific_env_kwargs is None:
                vae_dataset_specific_env_kwargs = {}
            else:
                # BUG FIX: copy before merging defaults below so the caller's
                # dict is not mutated as a side effect of this call.
                vae_dataset_specific_env_kwargs = dict(
                    vae_dataset_specific_env_kwargs)
            for key, val in env_kwargs.items():
                vae_dataset_specific_env_kwargs.setdefault(key, val)
            env = env_class(**vae_dataset_specific_env_kwargs)
        if not isinstance(env, ImageEnv):
            env = ImageEnv(
                env,
                imsize,
                init_camera=init_camera,
                transpose=True,
                normalize=True,
            )
        setup_pickup_image_env(env, num_presampled_goals=N)
        env.reset()
        info['env'] = env

        dataset = np.zeros((N, imsize * imsize * num_channels), dtype=np.uint8)
        for i in range(N):
            img = env._presampled_goals['image_desired_goal'][i]
            dataset[i, :] = unormalize_image(img)
            if show:
                # NOTE(review): the preview assumes 3 channels even though the
                # stored dataset honors num_channels — confirm before using
                # show=True with num_channels != 3.
                img = img.reshape(3, imsize, imsize).transpose()
                img = img[::-1, :, ::-1]
                cv2.imshow('img', img)
                cv2.waitKey(1)
                time.sleep(.2)
                # radius = input('waiting...')
        print("done making training data", filename, time.time() - now)
        np.random.shuffle(dataset)
        np.save(filename, dataset)

    n = int(N * test_p)
    train_dataset = dataset[:n, :]
    test_dataset = dataset[n:, :]
    return train_dataset, test_dataset, info
Esempio n. 14
0
def representation_learning_with_goal_distribution_launcher(
        max_path_length,
        contextual_replay_buffer_kwargs,
        sac_trainer_kwargs,
        algo_kwargs,
        qf_kwargs=None,
        policy_kwargs=None,
        # env settings
        env_id=None,
        env_class=None,
        env_kwargs=None,
        observation_key='latent_observation',
        desired_goal_key='latent_desired_goal',
        achieved_goal_key='latent_achieved_goal',
        renderer_kwargs=None,
        # mask settings
        mask_variant=None,  # TODO: manually unpack this as well
        mask_conditioned=True,
        mask_format='vector',
        infer_masks=False,
        # rollout
        expl_goal_sampling_mode=None,
        eval_goal_sampling_mode=None,
        eval_rollouts_for_videos=None,
        eval_rollouts_to_log=None,
        # debugging
        log_mask_diagnostics=True,
        log_expl_video=True,
        log_eval_video=True,
        save_video=True,
        save_video_period=50,
        save_env_in_snapshot=True,
        dump_video_kwargs=None,
        # re-loading
        ckpt=None,
        ckpt_epoch=None,
        seedid=0,
):
    if eval_rollouts_to_log is None:
        eval_rollouts_to_log = [
            'atomic',
            'atomic_seq',
            'cumul_seq',
            'full',
        ]
    if renderer_kwargs is None:
        renderer_kwargs = {}
    if dump_video_kwargs is None:
        dump_video_kwargs = {}
    if eval_rollouts_for_videos is None:
        eval_rollouts_for_videos = [
            'atomic',
            'atomic_seq',
            'cumul_seq',
            'full',
        ]
    if mask_variant is None:
        mask_variant = {}
    if policy_kwargs is None:
        policy_kwargs = {}
    if qf_kwargs is None:
        qf_kwargs = {}
    context_key = desired_goal_key
    prev_subtask_weight = mask_variant.get('prev_subtask_weight', None)

    context_post_process_mode = mask_variant.get('context_post_process_mode',
                                                 None)
    if context_post_process_mode in [
        'dilute_prev_subtasks_uniform', 'dilute_prev_subtasks_fixed'
    ]:
        prev_subtask_weight = 0.5
    prev_subtasks_solved = mask_variant.get('prev_subtasks_solved', False)
    max_subtasks_to_focus_on = mask_variant.get(
        'max_subtasks_to_focus_on', None)
    max_subtasks_per_rollout = mask_variant.get(
        'max_subtasks_per_rollout', None)
    mask_groups = mask_variant.get('mask_groups', None)
    rollout_mask_order_for_expl = mask_variant.get(
        'rollout_mask_order_for_expl', 'fixed')
    rollout_mask_order_for_eval = mask_variant.get(
        'rollout_mask_order_for_eval', 'fixed')
    masks = mask_variant.get('masks', None)
    idx_masks = mask_variant.get('idx_masks', None)
    matrix_masks = mask_variant.get('matrix_masks', None)
    train_mask_distr = mask_variant.get('train_mask_distr', None)
    mask_inference_variant = mask_variant.get('mask_inference_variant', {})
    mask_reward_fn = mask_variant.get('reward_fn', default_masked_reward_fn)
    expl_mask_distr = mask_variant['expl_mask_distr']
    eval_mask_distr = mask_variant['eval_mask_distr']
    use_g_for_mean = mask_variant['use_g_for_mean']
    context_post_process_frac = mask_variant.get(
        'context_post_process_frac', 0.50)
    sample_masks_for_relabeling = mask_variant.get(
        'sample_masks_for_relabeling', True)

    if mask_conditioned:
        env = get_gym_env(env_id, env_class=env_class, env_kwargs=env_kwargs)
        assert mask_format in ['vector', 'matrix', 'distribution']
        goal_dim = env.observation_space.spaces[context_key].low.size
        if mask_format == 'vector':
            mask_keys = ['mask']
            mask_dims = [(goal_dim,)]
            context_dim = goal_dim + goal_dim
        elif mask_format == 'matrix':
            mask_keys = ['mask']
            mask_dims = [(goal_dim, goal_dim)]
            context_dim = goal_dim + (goal_dim * goal_dim)
        elif mask_format == 'distribution':
            mask_keys = ['mask_mu_w', 'mask_mu_g', 'mask_mu_mat',
                         'mask_sigma_inv']
            mask_dims = [(goal_dim,), (goal_dim,), (goal_dim, goal_dim),
                         (goal_dim, goal_dim)]
            context_dim = goal_dim + (goal_dim * goal_dim)  # mu and sigma_inv
        else:
            raise NotImplementedError

        if infer_masks:
            assert mask_format == 'distribution'
            env_kwargs_copy = copy.deepcopy(env_kwargs)
            env_kwargs_copy['lite_reset'] = True
            infer_masks_env = get_gym_env(env_id, env_class=env_class,
                                          env_kwargs=env_kwargs_copy)

            masks = infer_masks_fn(
                infer_masks_env,
                idx_masks,
                mask_inference_variant,
            )

        context_keys = [context_key] + mask_keys
    else:
        context_keys = [context_key]

    def contextual_env_distrib_and_reward(mode='expl'):
        assert mode in ['expl', 'eval']
        env = get_gym_env(env_id, env_class=env_class, env_kwargs=env_kwargs)

        if mode == 'expl':
            goal_sampling_mode = expl_goal_sampling_mode
        elif mode == 'eval':
            goal_sampling_mode = eval_goal_sampling_mode
        else:
            goal_sampling_mode = None
        if goal_sampling_mode is not None:
            env.goal_sampling_mode = goal_sampling_mode

        if mask_conditioned:
            context_distrib = MaskedGoalDictDistributionFromMultitaskEnv(
                env,
                desired_goal_keys=[desired_goal_key],
                mask_keys=mask_keys,
                mask_dims=mask_dims,
                mask_format=mask_format,
                max_subtasks_to_focus_on=max_subtasks_to_focus_on,
                prev_subtask_weight=prev_subtask_weight,
                masks=masks,
                idx_masks=idx_masks,
                matrix_masks=matrix_masks,
                mask_distr=train_mask_distr,
            )
            reward_fn = ContextualRewardFnFromMultitaskEnv(
                env=env,
                achieved_goal_from_observation=IndexIntoAchievedGoal(
                    achieved_goal_key),  # observation_key
                desired_goal_key=desired_goal_key,
                achieved_goal_key=achieved_goal_key,
                additional_obs_keys=contextual_replay_buffer_kwargs.get(
                    'observation_keys', None),
                additional_context_keys=mask_keys,
                reward_fn=partial(
                    mask_reward_fn,
                    mask_format=mask_format,
                    use_g_for_mean=use_g_for_mean
                ),
            )
        else:
            context_distrib = GoalDictDistributionFromMultitaskEnv(
                env,
                desired_goal_keys=[desired_goal_key],
            )
            reward_fn = ContextualRewardFnFromMultitaskEnv(
                env=env,
                achieved_goal_from_observation=IndexIntoAchievedGoal(
                    achieved_goal_key),  # observation_key
                desired_goal_key=desired_goal_key,
                achieved_goal_key=achieved_goal_key,
                additional_obs_keys=contextual_replay_buffer_kwargs.get(
                    'observation_keys', None),
            )
        diag_fn = GoalConditionedDiagnosticsToContextualDiagnostics(
            env.goal_conditioned_diagnostics,
            desired_goal_key=desired_goal_key,
            observation_key=observation_key,
        )
        env = ContextualEnv(
            env,
            context_distribution=context_distrib,
            reward_fn=reward_fn,
            observation_key=observation_key,
            contextual_diagnostics_fns=[diag_fn],
            update_env_info_fn=delete_info,
        )
        return env, context_distrib, reward_fn

    env, context_distrib, reward_fn = contextual_env_distrib_and_reward(
        mode='expl')
    eval_env, eval_context_distrib, _ = contextual_env_distrib_and_reward(
        mode='eval')

    if mask_conditioned:
        obs_dim = (
                env.observation_space.spaces[observation_key].low.size
                + context_dim
        )
    else:
        obs_dim = (
                env.observation_space.spaces[observation_key].low.size
                + env.observation_space.spaces[context_key].low.size
        )

    action_dim = env.action_space.low.size

    if ckpt:
        from rlkit.misc.asset_loader import local_path_from_s3_or_local_path
        import os.path as osp

        if ckpt_epoch is not None:
            epoch = ckpt_epoch
            filename = local_path_from_s3_or_local_path(
                osp.join(ckpt, 'itr_%d.pkl' % epoch))
        else:
            filename = local_path_from_s3_or_local_path(
                osp.join(ckpt, 'params.pkl'))
        print("Loading ckpt from", filename)
        # data = joblib.load(filename)
        data = torch.load(filename, map_location='cuda:1')
        qf1 = data['trainer/qf1']
        qf2 = data['trainer/qf2']
        target_qf1 = data['trainer/target_qf1']
        target_qf2 = data['trainer/target_qf2']
        policy = data['trainer/policy']
        eval_policy = data['evaluation/policy']
        expl_policy = data['exploration/policy']
    else:
        def create_qf():
            return ConcatMlp(
                input_size=obs_dim + action_dim,
                output_size=1,
                **qf_kwargs
            )

        qf1 = create_qf()
        qf2 = create_qf()
        target_qf1 = create_qf()
        target_qf2 = create_qf()
        policy = TanhGaussianPolicy(
            obs_dim=obs_dim,
            action_dim=action_dim,
            **policy_kwargs
        )
        expl_policy = policy
        eval_policy = MakeDeterministic(policy)

    def context_from_obs_dict_fn(obs_dict):
        context_dict = {
            context_key: obs_dict[achieved_goal_key],  # observation_key
        }
        if mask_conditioned:
            if sample_masks_for_relabeling:
                batch_size = obs_dict[list(obs_dict.keys())[0]].shape[0]
                sampled_contexts = context_distrib.sample(batch_size)
                for mask_key in mask_keys:
                    context_dict[mask_key] = sampled_contexts[mask_key]
            else:
                for mask_key in mask_keys:
                    context_dict[mask_key] = obs_dict[mask_key]
        return context_dict

    def post_process_mask_fn(obs_dict, context_dict):
        """Relabel masks/goals in a sampled context batch.

        Depending on ``context_post_process_mode``, a random fraction
        (``context_post_process_frac``) of each cumulative-mask group gets
        relabeled: previously-handled objects are marked as solved in the
        goal, or their mask weights are diluted (uniform random or fixed
        0.5), and atomic masks are rewritten to their corresponding
        cumulative masks.

        Returns a deep copy of ``context_dict`` with the relabeled entries;
        the inputs are not modified.

        BUGFIX: the original writes used chained fancy indexing
        (``arr[indices][:, cols] = ...`` and ``arr[rows][sub] = ...``).
        With integer-array indices the first subscript returns a *copy*,
        so those assignments silently had no effect. They now index with
        ``np.ix_`` / a composed index array so the writes reach the
        underlying buffer.
        """
        assert mask_conditioned
        pp_context_dict = copy.deepcopy(context_dict)

        assert context_post_process_mode in [
            'prev_subtasks_solved',
            'dilute_prev_subtasks_uniform',
            'dilute_prev_subtasks_fixed',
            'atomic_to_corresp_cumul',
            None
        ]

        if context_post_process_mode in [
            'prev_subtasks_solved',
            'dilute_prev_subtasks_uniform',
            'dilute_prev_subtasks_fixed',
            'atomic_to_corresp_cumul'
        ]:
            # For each cumulative-mask group, keep a random subset of its
            # batch indices: only those rows will be relabeled.
            frac = context_post_process_frac
            cumul_mask_to_indices = context_distrib.get_cumul_mask_to_indices(
                context_dict['mask']
            )
            for k in cumul_mask_to_indices:
                indices = cumul_mask_to_indices[k]
                subset = np.random.choice(len(indices),
                                          int(len(indices) * frac),
                                          replace=False)
                cumul_mask_to_indices[k] = indices[subset]
        else:
            cumul_mask_to_indices = None

        mode = context_post_process_mode
        if mode in [
            'prev_subtasks_solved', 'dilute_prev_subtasks_uniform',
            'dilute_prev_subtasks_fixed'
        ]:
            cumul_masks = list(cumul_mask_to_indices.keys())
            for i in range(1, len(cumul_masks)):
                curr_mask = cumul_masks[i]
                prev_mask = cumul_masks[i - 1]
                # Object dimensions active in the *previous* cumulative mask.
                prev_obj_indices = np.where(np.array(prev_mask) > 0)[0]
                indices = cumul_mask_to_indices[curr_mask]
                if mode == 'prev_subtasks_solved':
                    # Pretend earlier subtasks are done: copy the achieved
                    # positions of previous objects into the goal.
                    pp_context_dict[context_key][
                        np.ix_(indices, prev_obj_indices)] = \
                        obs_dict[achieved_goal_key][
                            np.ix_(indices, prev_obj_indices)]
                elif mode == 'dilute_prev_subtasks_uniform':
                    # Randomize the mask weight of earlier subtasks.
                    pp_context_dict['mask'][
                        np.ix_(indices, prev_obj_indices)] = \
                        np.random.uniform(
                            size=(len(indices), len(prev_obj_indices)))
                elif mode == 'dilute_prev_subtasks_fixed':
                    # Down-weight earlier subtasks to a fixed 0.5.
                    pp_context_dict['mask'][
                        np.ix_(indices, prev_obj_indices)] = 0.5
            # NOTE(review): this atomic->cumulative rewrite runs for the
            # three dilution modes above but NOT for
            # 'atomic_to_corresp_cumul' (excluded by the `if mode in`
            # check) — confirm the branch placement is intended.
            indices_to_relabel = np.concatenate(
                list(cumul_mask_to_indices.values()))
            orig_masks = obs_dict['mask'][indices_to_relabel]
            atomic_mask_to_subindices = context_distrib.get_atomic_mask_to_indices(
                orig_masks)
            atomic_masks = list(atomic_mask_to_subindices.keys())
            cumul_masks = list(cumul_mask_to_indices.keys())
            for i in range(1, len(atomic_masks)):
                orig_atomic_mask = atomic_masks[i]
                relabeled_cumul_mask = cumul_masks[i]
                subindices = atomic_mask_to_subindices[orig_atomic_mask]
                # Compose the integer index arrays before assigning so the
                # write hits the real buffer, not a fancy-index copy.
                pp_context_dict['mask'][
                    indices_to_relabel[subindices]] = relabeled_cumul_mask

        return pp_context_dict

    # if mask_conditioned:
    #     variant['contextual_replay_buffer_kwargs']['post_process_batch_fn'] = post_process_mask_fn

    def concat_context_to_obs(batch, replay_buffer=None, obs_dict=None,
                              next_obs_dict=None, new_contexts=None):
        """Append the goal (and mask) context onto obs/next_obs in *batch*.

        In the mask-conditioned case, freshly sampled contexts are first
        post-processed (relabeled) before being concatenated.
        """
        obs = batch['observations']
        next_obs = batch['next_observations']
        context = batch[context_key]

        if not mask_conditioned:
            batch['observations'] = np.concatenate([obs, context], axis=1)
            batch['next_observations'] = np.concatenate(
                [next_obs, context], axis=1)
            return batch

        # Relabel the newly sampled contexts before using them.
        if obs_dict is not None and new_contexts is not None:
            batch.update(post_process_mask_fn(obs_dict, new_contexts))

        if mask_format in ('vector', 'matrix'):
            assert len(mask_keys) == 1
            mask = batch[mask_keys[0]].reshape((len(context), -1))
            obs_parts = [obs, context, mask]
            next_obs_parts = [next_obs, context, mask]
        elif mask_format == 'distribution':
            # Condition on the parameters of the Gaussian mask instead of
            # a raw mask vector/matrix.
            g = context
            mu_w = batch['mask_mu_w']
            mu_g = batch['mask_mu_g']
            mu_A = batch['mask_mu_mat']
            sigma_inv = batch['mask_sigma_inv']
            if use_g_for_mean:
                mu_w_given_g = g
            else:
                # Conditional mean: mu_w + A (g - mu_g), batched matvec.
                mu_w_given_g = mu_w + np.squeeze(
                    mu_A @ np.expand_dims(g - mu_g, axis=-1), axis=-1)
            sigma_w_given_g_inv = sigma_inv.reshape((len(context), -1))
            obs_parts = [obs, mu_w_given_g, sigma_w_given_g_inv]
            next_obs_parts = [next_obs, mu_w_given_g, sigma_w_given_g_inv]
        else:
            raise NotImplementedError

        batch['observations'] = np.concatenate(obs_parts, axis=1)
        batch['next_observations'] = np.concatenate(next_obs_parts, axis=1)
        return batch

    # Make sure the buffer stores the raw observation and achieved-goal
    # arrays — both are needed for context relabeling at sample time.
    if 'observation_keys' not in contextual_replay_buffer_kwargs:
        contextual_replay_buffer_kwargs['observation_keys'] = []
    observation_keys = contextual_replay_buffer_kwargs['observation_keys']
    if observation_key not in observation_keys:
        observation_keys.append(observation_key)
    if achieved_goal_key not in observation_keys:
        observation_keys.append(achieved_goal_key)

    # HER-style buffer: resamples contexts from stored observations and
    # recomputes rewards; batches are post-processed by concat_context_to_obs.
    replay_buffer = ContextualRelabelingReplayBuffer(
        env=env,
        context_keys=context_keys,
        context_distribution=context_distrib,
        sample_context_from_obs_dict_fn=context_from_obs_dict_fn,
        reward_fn=reward_fn,
        post_process_batch_fn=concat_context_to_obs,
        **contextual_replay_buffer_kwargs
    )

    # Twin-Q SAC trainer over the networks built above.
    trainer = SACTrainer(
        env=env,
        policy=policy,
        qf1=qf1,
        qf2=qf2,
        target_qf1=target_qf1,
        target_qf2=target_qf2,
        **sac_trainer_kwargs
    )

    def create_path_collector(
            env,
            policy,
            mode='expl',
            mask_kwargs=None,
    ):
        """Build a path collector for *env* / *policy*.

        *mode* ('expl' or 'eval') selects the default mask rollout order
        and mask distribution; entries in *mask_kwargs* override those
        defaults. Without mask conditioning a plain contextual collector
        is returned.
        """
        if mask_kwargs is None:
            mask_kwargs = {}
        assert mode in ['expl', 'eval']

        if not mask_conditioned:
            return ContextualPathCollector(
                env,
                policy,
                observation_key=observation_key,
                context_keys_for_policy=context_keys,
                save_env_in_snapshot=save_env_in_snapshot,
            )

        # Per-mode defaults, overridable through mask_kwargs.
        mode_defaults = {
            'expl': (rollout_mask_order_for_expl, expl_mask_distr),
            'eval': (rollout_mask_order_for_eval, eval_mask_distr),
        }
        default_order, default_distr = mode_defaults[mode]
        rollout_mask_order = mask_kwargs.get(
            'rollout_mask_order', default_order)
        mask_distr = mask_kwargs.get('mask_distr', default_distr)

        return MaskPathCollector(
            env,
            policy,
            observation_key=observation_key,
            context_keys_for_policy=context_keys,
            concat_context_to_obs_fn=concat_context_to_obs,
            save_env_in_snapshot=save_env_in_snapshot,
            mask_sampler=(
                context_distrib if mode == 'expl' else eval_context_distrib),
            mask_distr=mask_distr.copy(),
            mask_groups=mask_groups,
            max_path_length=max_path_length,
            rollout_mask_order=rollout_mask_order,
            prev_subtask_weight=prev_subtask_weight,
            prev_subtasks_solved=prev_subtasks_solved,
            max_subtasks_to_focus_on=max_subtasks_to_focus_on,
            max_subtasks_per_rollout=max_subtasks_per_rollout,
        )

    expl_path_collector = create_path_collector(env, expl_policy, mode='expl')
    eval_path_collector = create_path_collector(eval_env, eval_policy,
                                                mode='eval')

    # Standard batch RL loop: alternate data collection and SAC updates.
    algorithm = TorchBatchRLAlgorithm(
        trainer=trainer,
        exploration_env=env,
        evaluation_env=eval_env,
        exploration_data_collector=expl_path_collector,
        evaluation_data_collector=eval_path_collector,
        replay_buffer=replay_buffer,
        max_path_length=max_path_length,
        **algo_kwargs
    )

    # Move every network to the configured torch device (CPU/GPU).
    algorithm.to(ptu.device)

    if save_video:
        renderer = EnvRenderer(**renderer_kwargs)

        def add_images(env, state_distribution):
            """Wrap a state env so observations and goals carry rendered images.

            Goals get an 'image_desired_goal' entry via AddImageDistribution;
            observations get an 'image_observation' entry via InsertImagesEnv.
            """
            state_env = env.env
            goal_distribution_with_images = AddImageDistribution(
                env=state_env,
                base_distribution=state_distribution,
                image_goal_key='image_desired_goal',
                renderer=renderer,
            )
            env_with_images = InsertImagesEnv(
                state_env,
                renderers={'image_observation': renderer},
            )
            return ContextualEnv(
                env_with_images,
                context_distribution=goal_distribution_with_images,
                reward_fn=reward_fn,
                observation_key=observation_key,
                update_env_info_fn=delete_info,
            )

        img_eval_env = add_images(eval_env, eval_context_distrib)

        # Periodically dump evaluation rollout videos after training epochs.
        if log_eval_video:
            video_path_collector = create_path_collector(img_eval_env,
                                                         eval_policy,
                                                         mode='eval')
            rollout_function = video_path_collector._rollout_fn
            eval_video_func = get_save_video_function(
                rollout_function,
                img_eval_env,
                eval_policy,
                tag="eval",
                imsize=renderer_kwargs['width'],
                image_format='CHW',
                save_video_period=save_video_period,
                horizon=max_path_length,
                **dump_video_kwargs
            )
            algorithm.post_train_funcs.append(eval_video_func)

        # additional eval videos for mask conditioned case
        # One video stream per requested mask rollout scheme; each collector
        # forces a single mask_distr mode with probability 1.0.
        # NOTE(review): these use image_format='HWC' while the main
        # eval/expl videos use 'CHW' — confirm this asymmetry is intended.
        if mask_conditioned:
            if 'cumul_seq' in eval_rollouts_for_videos:
                video_path_collector = create_path_collector(
                    img_eval_env,
                    eval_policy,
                    mode='eval',
                    mask_kwargs=dict(
                        mask_distr=dict(
                            cumul_seq=1.0
                        ),
                    ),
                )
                rollout_function = video_path_collector._rollout_fn
                eval_video_func = get_save_video_function(
                    rollout_function,
                    img_eval_env,
                    eval_policy,
                    # Always "eval_cumul" here: this branch only runs when
                    # mask_conditioned is True, so the ternary is redundant.
                    tag="eval_cumul" if mask_conditioned else "eval",
                    imsize=renderer_kwargs['width'],
                    image_format='HWC',
                    save_video_period=save_video_period,
                    horizon=max_path_length,
                    **dump_video_kwargs
                )
                algorithm.post_train_funcs.append(eval_video_func)

            if 'full' in eval_rollouts_for_videos:
                video_path_collector = create_path_collector(
                    img_eval_env,
                    eval_policy,
                    mode='eval',
                    mask_kwargs=dict(
                        mask_distr=dict(
                            full=1.0
                        ),
                    ),
                )
                rollout_function = video_path_collector._rollout_fn
                eval_video_func = get_save_video_function(
                    rollout_function,
                    img_eval_env,
                    eval_policy,
                    tag="eval_full",
                    imsize=renderer_kwargs['width'],
                    image_format='HWC',
                    save_video_period=save_video_period,
                    horizon=max_path_length,
                    **dump_video_kwargs
                )
                algorithm.post_train_funcs.append(eval_video_func)

            if 'atomic_seq' in eval_rollouts_for_videos:
                video_path_collector = create_path_collector(
                    img_eval_env,
                    eval_policy,
                    mode='eval',
                    mask_kwargs=dict(
                        mask_distr=dict(
                            atomic_seq=1.0
                        ),
                    ),
                )
                rollout_function = video_path_collector._rollout_fn
                eval_video_func = get_save_video_function(
                    rollout_function,
                    img_eval_env,
                    eval_policy,
                    tag="eval_atomic",
                    imsize=renderer_kwargs['width'],
                    image_format='HWC',
                    save_video_period=save_video_period,
                    horizon=max_path_length,
                    **dump_video_kwargs
                )
                algorithm.post_train_funcs.append(eval_video_func)

        # Exploration rollout videos use the exploration env/policy.
        if log_expl_video:
            img_expl_env = add_images(env, context_distrib)
            video_path_collector = create_path_collector(img_expl_env,
                                                         expl_policy,
                                                         mode='expl')
            rollout_function = video_path_collector._rollout_fn
            expl_video_func = get_save_video_function(
                rollout_function,
                img_expl_env,
                expl_policy,
                tag="expl",
                imsize=renderer_kwargs['width'],
                image_format='CHW',
                save_video_period=save_video_period,
                horizon=max_path_length,
                **dump_video_kwargs
            )
            algorithm.post_train_funcs.append(expl_video_func)

    # Extra eval-time diagnostics: one collector per mask rollout scheme,
    # each paired with a log prefix (kept in lockstep in the two lists).
    if mask_conditioned and log_mask_diagnostics:
        collectors = []
        log_prefixes = []

        default_list = [
            'atomic',
            'atomic_seq',
            'cumul_seq',
            'full',
        ]
        for key in eval_rollouts_to_log:
            assert key in default_list

        # One collector per individual mask, rolled out on its own.
        if 'atomic' in eval_rollouts_to_log:
            num_masks = len(eval_path_collector.mask_groups)
            for mask_id in range(num_masks):
                mask_kwargs = dict(
                    rollout_mask_order=[mask_id],
                    mask_distr=dict(
                        atomic_seq=1.0,
                    ),
                )
                collector = create_path_collector(eval_env, eval_policy,
                                                  mode='eval',
                                                  mask_kwargs=mask_kwargs)
                collectors.append(collector)
            log_prefixes += [
                # NOTE(review): ''.join(str(mask_id)) is a no-op identity;
                # equivalent to 'mask_{}/'.format(mask_id).
                'mask_{}/'.format(''.join(str(mask_id)))
                for mask_id in range(num_masks)
            ]

        # full mask
        if 'full' in eval_rollouts_to_log:
            mask_kwargs = dict(
                mask_distr=dict(
                    full=1.0,
                ),
            )
            collector = create_path_collector(eval_env, eval_policy,
                                              mode='eval',
                                              mask_kwargs=mask_kwargs)
            collectors.append(collector)
            log_prefixes.append('mask_full/')

        # cumulative, sequential mask
        if 'cumul_seq' in eval_rollouts_to_log:
            mask_kwargs = dict(
                rollout_mask_order='fixed',
                mask_distr=dict(
                    cumul_seq=1.0,
                ),
            )
            collector = create_path_collector(eval_env, eval_policy,
                                              mode='eval',
                                              mask_kwargs=mask_kwargs)
            collectors.append(collector)
            log_prefixes.append('mask_cumul_seq/')

        # atomic, sequential mask
        if 'atomic_seq' in eval_rollouts_to_log:
            mask_kwargs = dict(
                rollout_mask_order='fixed',
                mask_distr=dict(
                    atomic_seq=1.0,
                ),
            )
            collector = create_path_collector(eval_env, eval_policy,
                                              mode='eval',
                                              mask_kwargs=mask_kwargs)
            collectors.append(collector)
            log_prefixes.append('mask_atomic_seq/')

        def get_mask_diagnostics(unused):
            """Run each mask-specific collector and gather its diagnostics.

            Only the 'final ... Mean' env-info statistics are kept; each
            collector's stats are logged under its associated prefix.
            """
            from rlkit.core.logging import append_log, add_prefix, OrderedDict
            log = OrderedDict()
            for prefix, collector in zip(log_prefixes, collectors):
                paths = collector.collect_new_paths(
                    max_path_length,
                    max_path_length,  # masking_eval_steps,
                    discard_incomplete_paths=True,
                )
                all_stats = eval_env.get_diagnostics(paths)

                # Keep only final per-env-info means.
                filtered = OrderedDict(
                    (stat_name, all_stats[stat_name])
                    for stat_name in all_stats.keys()
                    if ('env_infos' in stat_name
                        and 'final' in stat_name
                        and 'Mean' in stat_name)
                )
                append_log(log, add_prefix(filtered, prefix))

            # Reset collector state so diagnostics don't leak across epochs.
            for collector in collectors:
                collector.end_epoch(0)
            return log

        # Hook the mask diagnostics into the algorithm's eval phase.
        algorithm._eval_get_diag_fns.append(get_mask_diagnostics)
    # Run the full training loop (blocks until done).
    algorithm.train()
# Esempio n. 15 (example separator from the original collection)
# 0 (vote count)
def generate_vae_dataset(
    N=10000,
    test_p=0.9,
    use_cached=True,
    imsize=84,
    show=False,
    dataset_path=None,
    policy_path=None,
    ratio_oracle_policy_data_to_random=1 / 2,
    action_space_sampling=False,
    env_class=None,
    env_kwargs=None,
    action_plus_random_sampling=False,
    init_camera=sawyer_door_env_camera,
):
    """Generate (or load from cache) flattened door-env images for VAE training.

    Args:
        N: total number of images.
        test_p: fraction of N that becomes the train split; the rest is test.
        use_cached: load from the /tmp cache file when it already exists.
        imsize: side length of the rendered square images.
        show: display each generated image with cv2 (debugging aid).
        dataset_path: explicit dataset file (s3 or local) to load instead of
            generating new data.
        policy_path: only influences which cache filename is used.
        ratio_oracle_policy_data_to_random: currently unused; kept for
            backward compatibility of the signature.
        action_space_sampling: sample arm positions from a hand-built Box.
        env_class: env constructor for the action_plus_random_sampling branch.
        env_kwargs: kwargs for the env constructor; defaults to {}.
        action_plus_random_sampling: half action-space samples, half
            random-policy rollouts.
        init_camera: camera initializer for the image env.

    Returns:
        (train_dataset, test_dataset, info) — arrays of shape
        (n, imsize*imsize*3) plus an info dict that may contain the env.
    """
    # BUGFIX: env_kwargs defaults to None but is unpacked below via
    # SawyerDoorPushOpenEnv(**env_kwargs) / env_class(**env_kwargs),
    # which raised TypeError unless the caller always passed a dict.
    if env_kwargs is None:
        env_kwargs = dict()
    # Cache filename depends on which sampling scheme produced the data.
    if policy_path is not None:
        filename = "/tmp/sawyer_door_push_open_oracle+random_policy_data_closer_zoom_action_limited" + str(
            N) + ".npy"
    elif action_space_sampling:
        filename = "/tmp/sawyer_door_push_open_zoomed_in_action_space_sampling" + str(
            N) + ".npy"
    else:
        filename = "/tmp/sawyer_door_push_open" + str(N) + ".npy"
    info = {}
    if dataset_path is not None:
        filename = local_path_from_s3_or_local_path(dataset_path)
        dataset = np.load(filename)
    elif use_cached and osp.isfile(filename):
        dataset = np.load(filename)
        print("loaded data from saved file", filename)
    elif action_space_sampling:
        # Place the arm at random reachable positions and render goal states.
        env = SawyerDoorPushOpenEnv(**env_kwargs)
        env = ImageEnv(
            env,
            imsize,
            transpose=False,
            init_camera=sawyer_door_env_camera,
            normalize=False,
        )
        action_space = Box(np.array([-env.max_x_pos, .5, .06]),
                           np.array([env.max_x_pos, env.max_y_pos, .06]))
        dataset = np.zeros((N, imsize * imsize * 3))
        for i in range(N):
            env.set_to_goal_pos(action_space.sample())  # move arm to spot
            goal = env.sample_goal()
            env.set_to_goal(goal)
            img = env.get_image().flatten()
            dataset[i, :] = img
            if show:
                cv2.imshow('img', img.reshape(3, 84, 84).transpose())
                cv2.waitKey(1)
            print(i)
        info['env'] = env
        # NOTE(review): this branch never calls np.save, so use_cached can
        # never hit for this filename — confirm whether caching was intended.
    elif action_plus_random_sampling:
        # First half: arm placed via action-space sampling; second half:
        # random-policy rollouts with OU exploration noise.
        env = env_class(**env_kwargs)
        env = ImageEnv(
            env,
            imsize,
            transpose=True,
            init_camera=init_camera,
            normalize=True,
        )
        action_space = Box(np.array([-env.max_x_pos, .5, .06]),
                           np.array([env.max_x_pos, .6, .06]))
        action_sampled_data = int(N / 2)
        dataset = np.zeros((N, imsize * imsize * 3))
        print('Action Space Sampling')
        for i in range(action_sampled_data):
            env.set_to_goal_pos(action_space.sample())  # move arm to spot
            goal = env.sample_goal()
            env.set_to_goal(goal)
            img = env._get_flat_img()
            dataset[i, :] = img
            if show:
                cv2.imshow('img', img.reshape(3, 84, 84).transpose())
                cv2.waitKey(1)
            print(i)
        # Restrict the workspace for the random-rollout phase, then restore.
        env._wrapped_env.min_y_pos = .6
        policy = RandomPolicy(env.action_space)
        es = OUStrategy(action_space=env.action_space, theta=0)
        exploration_policy = PolicyWrappedWithExplorationStrategy(
            exploration_strategy=es,
            policy=policy,
        )
        print('Random Sampling')
        for i in range(action_sampled_data, N):
            if i % 20 == 0:
                env.reset()
                exploration_policy.reset()
            for _ in range(10):
                action = exploration_policy.get_action()[0]
                env.wrapped_env.step(action)
            img = env._get_flat_img()
            dataset[i, :] = img
            if show:
                cv2.imshow('img', img.reshape(3, 84, 84).transpose())
                cv2.waitKey(1)
            print(i)
        env._wrapped_env.min_y_pos = .5
        info['env'] = env
    else:
        # Default: random-policy rollouts, then snap to a sampled goal and
        # record the resulting image. This branch writes the cache file.
        now = time.time()
        env = SawyerDoorPushOpenEnv(max_angle=.5)
        env = ImageEnv(
            env,
            imsize,
            transpose=True,
            init_camera=sawyer_door_env_camera,
            normalize=True,
        )
        info['env'] = env
        policy = RandomPolicy(env.action_space)
        es = OUStrategy(action_space=env.action_space, theta=0)
        exploration_policy = PolicyWrappedWithExplorationStrategy(
            exploration_strategy=es,
            policy=policy,
        )
        dataset = np.zeros((N, imsize * imsize * 3))
        for i in range(N):
            if i % 100 == 0:
                env.reset()
                exploration_policy.reset()
            for _ in range(25):
                action = exploration_policy.get_action()[0]
                env.wrapped_env.step(action)
            goal = env.sample_goal_for_rollout()
            env.set_to_goal(goal)
            img = env.step(env.action_space.sample())[0]
            dataset[i, :] = img
            if show:
                cv2.imshow('img', img.reshape(3, 84, 84).transpose())
                cv2.waitKey(1)
            print(i)
        print("done making training data", filename, time.time() - now)
        np.save(filename, dataset)

    # Split into train/test by row count.
    n = int(N * test_p)
    train_dataset = dataset[:n, :]
    test_dataset = dataset[n:, :]
    return train_dataset, test_dataset, info