# Example no. 1
# 0
def get_envs(variant):
    """Build a VAE-wrapped image environment from ``variant``.

    Keys read from ``variant``:
        env_id: id passed to ``gym.make`` (required).
        init_camera: forwarded to ``ImageEnv`` (required).
        vae_model: object passed to ``VAEWrappedEnv`` as the VAE (required).
        image_size: size handed to both wrappers; 48 when absent.
        render: single switch used for ``decode_goals``, ``render_goals``
            and ``render_rollouts``; False when absent.
        reward_params: ``dict(type='latent_distance')`` when absent.

    Returns:
        The ``VAEWrappedEnv`` that wraps the ``ImageEnv`` built around the
        gym environment.
    """
    from rlkit.envs.vae_wrapper import VAEWrappedEnv

    base_env = gym.make(variant['env_id'])
    show = variant.get('render', False)
    reward_cfg = variant.get("reward_params", dict(type='latent_distance'))
    camera = variant['init_camera']
    size = variant.get('image_size', 48)

    pixel_env = ImageEnv(
        base_env,
        size,
        init_camera=camera,
        transpose=True,
        normalize=True,
    )

    # One flag drives every rendering-related option of the wrapper.
    render_flags = dict(
        decode_goals=show,
        render_goals=show,
        render_rollouts=show,
    )
    return VAEWrappedEnv(
        pixel_env,
        variant['vae_model'],
        imsize=size,
        reward_params=reward_cfg,
        sample_from_true_prior=True,
        **render_flags
    )
def get_envs(variant):
    """Build the environment described by ``variant``.

    The base env is ``gym.make(variant['env_id'])`` when ``'env_id'`` is
    present (after ``multiworld.register_all_envs()``), otherwise
    ``variant['env_class'](**variant['env_kwargs'])``. Unless
    ``do_state_exp`` is truthy, the base env is wrapped in an ``ImageEnv``
    (skipped when it already is one) and then in a ``VAEWrappedEnv``.

    Optional keys read from ``variant`` (default in parentheses):
        render (False): value used for ``decode_goals``, ``render_goals``
            and ``render_rollouts``.
        vae_path (None): a ``str`` is loaded with
            ``load_local_or_remote_file``; anything else is used as the VAE
            as-is.
        reward_params (``dict()``), init_camera (None), imsize (None).
        do_state_exp (False): return the base env without any wrapping.
        presample_goals (False): presample goals and rebuild both wrappers
            around them.
        presampled_goals_path (None): where to load presampled goals from;
            when None they are produced by
            ``variant['generate_goal_dataset_fctn']``.
        presample_image_goals_only (False): only consulted when
            ``presample_goals`` is falsy.
        vae_wrapped_env_kwargs (``{}``): extra ``VAEWrappedEnv`` kwargs.
        image_env_kwargs (``{}``): extra kwargs, applied only to the
            ``ImageEnv`` that is rebuilt with presampled goals.

    Returns:
        The base env when ``do_state_exp`` is truthy, otherwise the
        ``VAEWrappedEnv``.

    Raises:
        KeyError: if a key required by the selected branch is missing
            (``env_class``/``env_kwargs``, ``generate_goal_dataset_fctn``,
            ``goal_generation_kwargs``).
    """
    from multiworld.core.image_env import ImageEnv
    from rlkit.envs.vae_wrapper import VAEWrappedEnv
    from rlkit.util.io import load_local_or_remote_file

    render = variant.get('render', False)
    vae_path = variant.get("vae_path", None)
    reward_params = variant.get("reward_params", dict())
    init_camera = variant.get("init_camera", None)
    do_state_exp = variant.get("do_state_exp", False)
    presample_goals = variant.get('presample_goals', False)
    presample_image_goals_only = variant.get('presample_image_goals_only',
                                             False)
    presampled_goals_path = variant.get('presampled_goals_path', None)

    # ``vae_path`` is either a location to load from or the VAE itself.
    if isinstance(vae_path, str):
        vae = load_local_or_remote_file(vae_path)
    else:
        vae = vae_path

    if 'env_id' in variant:
        import gym
        import multiworld
        multiworld.register_all_envs()
        env = gym.make(variant['env_id'])
    else:
        env = variant["env_class"](**variant['env_kwargs'])

    if do_state_exp:
        return env

    def wrap_with_vae(wrapped, **extra_kwargs):
        # Single place for the VAEWrappedEnv arguments shared by every branch.
        return VAEWrappedEnv(wrapped,
                             vae,
                             imsize=wrapped.imsize,
                             decode_goals=render,
                             render_goals=render,
                             render_rollouts=render,
                             reward_params=reward_params,
                             **extra_kwargs,
                             **variant.get('vae_wrapped_env_kwargs', {}))

    if isinstance(env, ImageEnv):
        image_env = env
    else:
        image_env = ImageEnv(
            env,
            variant.get('imsize'),
            init_camera=init_camera,
            transpose=True,
            normalize=True,
        )

    if not presample_goals:
        vae_env = wrap_with_vae(image_env)
        if presample_image_goals_only:
            presampled_goals = variant['generate_goal_dataset_fctn'](
                image_env=vae_env.wrapped_env,
                **variant['goal_generation_kwargs'])
            image_env.set_presampled_goals(presampled_goals)
            print("Presampling image goals only")
        else:
            print("Not using presampled goals")
        return vae_env

    # This will fail for online-parallel as presampled_goals will not be
    # serialized. Also don't use this for online-vae.
    if presampled_goals_path is None:
        image_env.non_presampled_goal_img_is_garbage = True
        # Throwaway wrapper, only needed to generate the goal dataset.
        vae_env = wrap_with_vae(image_env)
        presampled_goals = variant['generate_goal_dataset_fctn'](
            env=vae_env,
            env_id=variant.get('env_id', None),
            **variant['goal_generation_kwargs'])
        del vae_env
    else:
        presampled_goals = load_local_or_remote_file(
            presampled_goals_path).item()

    # Drop the first image env before building the one that carries the
    # presampled goals.
    del image_env
    image_env = ImageEnv(env,
                         variant.get('imsize'),
                         init_camera=init_camera,
                         transpose=True,
                         normalize=True,
                         presampled_goals=presampled_goals,
                         **variant.get('image_env_kwargs', {}))
    vae_env = wrap_with_vae(image_env, presampled_goals=presampled_goals)
    print("Presampling all goals only")
    return vae_env
# Example no. 3
# 0
def get_envs(variant):
    """Build a ``VAEWrappedEnv`` from ``variant`` and register its modes.

    The base env is ``gym.make(variant['env_id'])`` when ``'env_id'`` is
    present, otherwise ``variant['env_class'](**variant['env_kwargs'])``.
    It is wrapped in an ``ImageEnv`` (skipped when it already is one) and
    then in a ``VAEWrappedEnv``.

    Optional keys read from ``variant`` (default in parentheses):
        render (False): value used for ``decode_goals``, ``render_goals``
            and ``render_rollouts``.
        vae_path (None): a ``str`` is loaded with
            ``load_local_or_remote_file``; anything else is used as the VAE
            as-is.
        reward_params (``dict()``), init_camera (None), imsize (None).
        presample_goals (False): presample goals and rebuild both wrappers
            around them.
        presampled_goals_path (None): where to load presampled goals from;
            when None they are produced by
            ``variant['generate_goal_dataset_fctn']``.
        presample_image_goals_only (False): only consulted when
            ``presample_goals`` is falsy.
        vae_wrapped_env_kwargs (``{}``): extra ``VAEWrappedEnv`` kwargs.
        image_env_kwargs (``{}``): extra kwargs, applied only to the
            ``ImageEnv`` that is rebuilt with presampled goals.
        training_mode ("train"): mode registered under 'train' and
            'relabeling'.
        testing_mode ("test"): mode registered under 'eval'.

    Returns:
        The ``VAEWrappedEnv`` after ``add_mode`` has been called for
        'eval', 'train', 'relabeling', 'video_vae' and 'video_env'.

    Raises:
        KeyError: if a key required by the selected branch is missing
            (``env_class``/``env_kwargs``, ``generate_goal_dataset_fctn``,
            ``goal_generation_kwargs``).
    """
    render = variant.get('render', False)
    vae_path = variant.get("vae_path", None)
    reward_params = variant.get("reward_params", dict())
    init_camera = variant.get("init_camera", None)
    presample_goals = variant.get('presample_goals', False)
    presample_image_goals_only = variant.get('presample_image_goals_only',
                                             False)
    presampled_goals_path = variant.get('presampled_goals_path', None)

    # ``vae_path`` is either a location to load from or the VAE itself.
    if isinstance(vae_path, str):
        vae = load_local_or_remote_file(vae_path)
    else:
        vae = vae_path

    if 'env_id' in variant:
        env = gym.make(variant['env_id'])
    else:
        env = variant["env_class"](**variant['env_kwargs'])

    def wrap_with_vae(wrapped, **extra_kwargs):
        # Single place for the VAEWrappedEnv arguments shared by every branch.
        return VAEWrappedEnv(wrapped,
                             vae,
                             imsize=wrapped.imsize,
                             decode_goals=render,
                             render_goals=render,
                             render_rollouts=render,
                             reward_params=reward_params,
                             **extra_kwargs,
                             **variant.get('vae_wrapped_env_kwargs', {}))

    if isinstance(env, ImageEnv):
        image_env = env
    else:
        image_env = ImageEnv(
            env,
            variant.get('imsize'),
            init_camera=init_camera,
            transpose=True,
            normalize=True,
        )

    if presample_goals:
        # This will fail for online-parallel as presampled_goals will not be
        # serialized. Also don't use this for online-vae.
        if presampled_goals_path is None:
            image_env.non_presampled_goal_img_is_garbage = True
            # Throwaway wrapper, only needed to generate the goal dataset.
            vae_env = wrap_with_vae(image_env)
            presampled_goals = variant['generate_goal_dataset_fctn'](
                env=vae_env,
                env_id=variant.get('env_id', None),
                **variant['goal_generation_kwargs'])
            del vae_env
        else:
            presampled_goals = load_local_or_remote_file(
                presampled_goals_path).item()

        # Drop the first image env before building the one that carries the
        # presampled goals.
        del image_env
        image_env = ImageEnv(env,
                             variant.get('imsize'),
                             init_camera=init_camera,
                             transpose=True,
                             normalize=True,
                             presampled_goals=presampled_goals,
                             **variant.get('image_env_kwargs', {}))
        vae_env = wrap_with_vae(image_env, presampled_goals=presampled_goals)
        print("Presampling all goals only")
    else:
        vae_env = wrap_with_vae(image_env)
        if presample_image_goals_only:
            presampled_goals = variant['generate_goal_dataset_fctn'](
                image_env=vae_env.wrapped_env,
                **variant['goal_generation_kwargs'])
            image_env.set_presampled_goals(presampled_goals)
            print("Presampling image goals only")
        else:
            print("Not using presampled goals")

    env = vae_env

    training_mode = variant.get("training_mode", "train")
    testing_mode = variant.get("testing_mode", "test")
    env.add_mode('eval', testing_mode)
    env.add_mode('train', training_mode)
    # Relabeling shares the training mode.
    env.add_mode('relabeling', training_mode)
    env.add_mode("video_vae", 'video_vae')
    env.add_mode("video_env", 'video_env')
    return env
# Example no. 4
# 0
def get_envs(cfgs):
    """Build the environment described by the ``cfgs`` config object.

    The base env is ``gym.make(cfgs.ENV.id)`` when ``cfgs.ENV.id`` is
    truthy (after ``roworld.register_all_envs()``), otherwise
    ``cfgs.ENV.cls(**cfgs.ENV.kwargs)``. Unless ``do_state_exp`` is truthy,
    the base env is wrapped in an ``ImageEnv`` (skipped when it already is
    one) and then in a ``VAEWrappedEnv``.

    Values read from ``cfgs`` (default in parentheses):
        cfgs.get('render') (False): value used for ``decode_goals``,
            ``render_goals`` and ``render_rollouts``.
        cfgs.get('reward_params') (``dict()``).
        cfgs.get('do_state_exp') (False): return the base env unwrapped.
        cfgs.VAE_TRAINER.get('vae_path') (None): a ``str`` is loaded with
            ``load_local_or_remote_file``; anything else is used as the VAE
            as-is.
        cfgs.ENV.get('init_camera') (None), cfgs.ENV.imsize.
        cfgs.SKEW_FIT.get('presample_goals') (False),
        cfgs.SKEW_FIT.get('presample_image_goals_only') (False),
        cfgs.SKEW_FIT.get('presampled_goals_path') (None).
        cfgs.get('vae_wrapped_env_kwargs') (``{}``): only applied to the
            temporary wrapper used while generating presampled goals.
        cfgs['generate_goal_dataset_fctn'], cfgs['goal_generation_kwargs']:
            required only by the branches that generate goals.

    Returns:
        The base env when ``do_state_exp`` is truthy, otherwise the
        ``VAEWrappedEnv``.
    """
    from roworld.core.image_env import ImageEnv
    from rlkit.envs.vae_wrapper import VAEWrappedEnv
    from rlkit.util.io import load_local_or_remote_file

    render = cfgs.get('render', False)
    reward_params = cfgs.get("reward_params", dict())
    do_state_exp = cfgs.get("do_state_exp", False)  # TODO

    vae_path = cfgs.VAE_TRAINER.get("vae_path", None)
    init_camera = cfgs.ENV.get("init_camera", None)

    presample_goals = cfgs.SKEW_FIT.get('presample_goals', False)
    presample_image_goals_only = cfgs.SKEW_FIT.get(
        'presample_image_goals_only', False)
    presampled_goals_path = cfgs.SKEW_FIT.get('presampled_goals_path', None)

    # ``vae_path`` is either a location to load from or the VAE itself.
    if isinstance(vae_path, str):
        vae = load_local_or_remote_file(vae_path)
    else:
        vae = vae_path

    if cfgs.ENV.id:
        import gym
        import roworld
        roworld.register_all_envs()
        env = gym.make(cfgs.ENV.id)
    else:
        env = cfgs.ENV.cls(**cfgs.ENV.kwargs)

    if do_state_exp:
        return env

    def wrap_with_vae(wrapped, extra_kwargs):
        # Single place for the VAEWrappedEnv arguments shared by every
        # branch; ``extra_kwargs`` holds the branch-specific ones.
        return VAEWrappedEnv(
            wrapped,
            vae,
            imsize=wrapped.imsize,
            decode_goals=render,
            render_goals=render,
            render_rollouts=render,
            reward_params=reward_params,
            **extra_kwargs
        )

    if isinstance(env, ImageEnv):
        image_env = env
    else:
        image_env = ImageEnv(
            env,
            cfgs.ENV.imsize,
            init_camera=init_camera,
            transpose=True,
            normalize=True,
        )

    if not presample_goals:
        vae_env = wrap_with_vae(
            image_env,
            dict(goal_sampling_mode='vae_prior', sample_from_true_prior=True),
        )
        if presample_image_goals_only:
            presampled_goals = cfgs['generate_goal_dataset_fctn'](
                image_env=vae_env.wrapped_env,
                **cfgs['goal_generation_kwargs'])
            image_env.set_presampled_goals(presampled_goals)
            print("Pre sampling image goals only")
        else:
            print("Not using presampled goals")
        return vae_env

    # This will fail for online-parallel as presampled_goals will not be
    # serialized. Also don't use this for online-vae.
    if presampled_goals_path is None:
        image_env.non_presampled_goal_img_is_garbage = True
        # Throwaway wrapper, only needed to generate the goal dataset.
        vae_env = wrap_with_vae(image_env,
                                cfgs.get('vae_wrapped_env_kwargs', {}))
        # NOTE(review): the env id used above lives at ``cfgs.ENV.id`` while
        # this reads a top-level 'env_id' key; confirm which one the goal
        # generator should receive.
        presampled_goals = cfgs['generate_goal_dataset_fctn'](
            env=vae_env,
            env_id=cfgs.get('env_id', None),
            **cfgs['goal_generation_kwargs'])
        del vae_env
    else:
        presampled_goals = load_local_or_remote_file(
            presampled_goals_path).item()

    # Drop the first image env before building the one that carries the
    # presampled goals.
    del image_env
    # NOTE(review): the first ImageEnv reads ``cfgs.ENV.imsize`` while this
    # one uses ``cfgs.ENV.get('imsize')``; confirm the difference is intended.
    image_env = ImageEnv(
        env,
        cfgs.ENV.get('imsize'),
        init_camera=init_camera,
        transpose=True,
        normalize=True,
        presampled_goals=presampled_goals,
    )
    vae_env = wrap_with_vae(
        image_env,
        dict(presampled_goals=presampled_goals, sample_from_true_prior=True),
    )
    print("Pre sampling all goals only")
    return vae_env