def get_envs(variant):
    """Build a VAE-wrapped image environment from an experiment variant.

    Keys read from ``variant``:
        env_id: passed to ``gym.make`` (required).
        init_camera: forwarded to ``ImageEnv`` (required).
        vae_model: object handed to ``VAEWrappedEnv`` as the VAE (required).
        image_size: size given to both wrappers (defaults to 48).
        render: passed as ``decode_goals``, ``render_goals`` and
            ``render_rollouts`` (defaults to False).
        reward_params: forwarded to ``VAEWrappedEnv`` (defaults to
            ``dict(type='latent_distance')``).

    Returns:
        The ``VAEWrappedEnv`` that wraps an ``ImageEnv`` around the gym env.
    """
    from rlkit.envs.vae_wrapper import VAEWrappedEnv

    base_env = gym.make(variant['env_id'])
    show = variant.get('render', False)
    reward_cfg = variant.get(
        "reward_params", dict(type='latent_distance'))
    camera = variant['init_camera']

    pixels_env = ImageEnv(
        base_env,
        variant.get('image_size', 48),
        init_camera=camera,
        transpose=True,
        normalize=True,
    )

    # The VAE is looked up only after the image wrapper exists, as before.
    return VAEWrappedEnv(
        pixels_env,
        variant['vae_model'],
        imsize=variant.get('image_size', 48),
        decode_goals=show,
        render_goals=show,
        render_rollouts=show,
        reward_params=reward_cfg,
        sample_from_true_prior=True,
    )
def get_envs(variant):
    """Create the (optionally VAE-wrapped) environment described by ``variant``.

    Keys read from ``variant``:
        env_id: if present, multiworld envs are registered and the env is
            built with ``gym.make``; otherwise ``env_class(**env_kwargs)``.
        vae_path: a ``str`` is loaded via ``load_local_or_remote_file``;
            any other value is used as the VAE directly.
        do_state_exp: when truthy, the bare env is returned unwrapped.
        imsize, init_camera, image_env_kwargs: forwarded to ``ImageEnv``.
        render: passed as ``decode_goals``, ``render_goals`` and
            ``render_rollouts`` of ``VAEWrappedEnv``.
        reward_params, vae_wrapped_env_kwargs: forwarded to
            ``VAEWrappedEnv``.
        presample_goals, presampled_goals_path, presample_image_goals_only,
        generate_goal_dataset_fctn, goal_generation_kwargs: control how
            (and whether) goals are presampled.

    Returns:
        The bare env when ``do_state_exp`` is set, else a ``VAEWrappedEnv``.
    """
    from multiworld.core.image_env import ImageEnv
    from rlkit.envs.vae_wrapper import VAEWrappedEnv
    from rlkit.util.io import load_local_or_remote_file

    render = variant.get('render', False)
    vae_path = variant.get("vae_path", None)
    reward_params = variant.get("reward_params", dict())
    init_camera = variant.get("init_camera", None)
    do_state_exp = variant.get("do_state_exp", False)
    presample_goals = variant.get('presample_goals', False)
    presample_image_goals_only = variant.get(
        'presample_image_goals_only', False)
    presampled_goals_path = variant.get('presampled_goals_path', None)

    # Exact-type check: only a plain ``str`` is treated as a path to load.
    if type(vae_path) is str:
        vae = load_local_or_remote_file(vae_path)
    else:
        vae = vae_path

    if 'env_id' in variant:
        import gym
        import multiworld
        multiworld.register_all_envs()
        env = gym.make(variant['env_id'])
    else:
        env = variant["env_class"](**variant['env_kwargs'])

    if do_state_exp:
        return env

    def wrap_with_vae(img_env, **extra):
        # Single place for the VAEWrappedEnv arguments shared by all paths.
        return VAEWrappedEnv(
            img_env,
            vae,
            imsize=img_env.imsize,
            decode_goals=render,
            render_goals=render,
            render_rollouts=render,
            reward_params=reward_params,
            **extra,
            **variant.get('vae_wrapped_env_kwargs', {}),
        )

    if isinstance(env, ImageEnv):
        image_env = env
    else:
        image_env = ImageEnv(
            env,
            variant.get('imsize'),
            init_camera=init_camera,
            transpose=True,
            normalize=True,
        )

    if not presample_goals:
        vae_env = wrap_with_vae(image_env)
        if presample_image_goals_only:
            presampled_goals = variant['generate_goal_dataset_fctn'](
                image_env=vae_env.wrapped_env,
                **variant['goal_generation_kwargs'])
            image_env.set_presampled_goals(presampled_goals)
            print("Presampling image goals only")
        else:
            print("Not using presampled goals")
        return vae_env

    # This will fail for online-parallel as presampled_goals will not be
    # serialized. Also don't use this for online-vae.
    if presampled_goals_path is None:
        image_env.non_presampled_goal_img_is_garbage = True
        # Temporary wrapper used only to generate the goal dataset.
        scratch_env = wrap_with_vae(image_env)
        presampled_goals = variant['generate_goal_dataset_fctn'](
            env=scratch_env,
            env_id=variant.get('env_id', None),
            **variant['goal_generation_kwargs'])
        del scratch_env
    else:
        presampled_goals = load_local_or_remote_file(
            presampled_goals_path).item()

    # Drop the first image wrapper before building the one holding the goals.
    del image_env
    image_env = ImageEnv(
        env,
        variant.get('imsize'),
        init_camera=init_camera,
        transpose=True,
        normalize=True,
        presampled_goals=presampled_goals,
        **variant.get('image_env_kwargs', {}),
    )
    vae_env = wrap_with_vae(image_env, presampled_goals=presampled_goals)
    print("Presampling all goals only")
    return vae_env
def get_envs(variant):
    """Build a VAE-wrapped image env and register its named modes.

    Keys read from ``variant``:
        env_id: if present the env comes from ``gym.make``; otherwise it is
            built as ``env_class(**env_kwargs)``.
        vae_path: a ``str`` is loaded via ``load_local_or_remote_file``;
            any other value is used as the VAE directly.
        imsize, init_camera, image_env_kwargs: forwarded to ``ImageEnv``.
        render: passed as ``decode_goals``, ``render_goals`` and
            ``render_rollouts`` of ``VAEWrappedEnv``.
        reward_params, vae_wrapped_env_kwargs: forwarded to
            ``VAEWrappedEnv``.
        presample_goals, presampled_goals_path, presample_image_goals_only,
        generate_goal_dataset_fctn, goal_generation_kwargs: control how
            (and whether) goals are presampled.
        training_mode / testing_mode: mode values registered through
            ``add_mode`` (default ``"train"`` / ``"test"``).

    Returns:
        The ``VAEWrappedEnv`` after the ``add_mode`` calls.
    """
    render = variant.get('render', False)
    vae_path = variant.get("vae_path", None)
    reward_params = variant.get("reward_params", dict())
    init_camera = variant.get("init_camera", None)
    presample_goals = variant.get('presample_goals', False)
    presample_image_goals_only = variant.get(
        'presample_image_goals_only', False)
    presampled_goals_path = variant.get('presampled_goals_path', None)

    # Exact-type check: only a plain ``str`` is treated as a path to load.
    if type(vae_path) is str:
        vae = load_local_or_remote_file(vae_path)
    else:
        vae = vae_path

    if 'env_id' in variant:
        env = gym.make(variant['env_id'])
    else:
        env = variant["env_class"](**variant['env_kwargs'])

    if isinstance(env, ImageEnv):
        image_env = env
    else:
        image_env = ImageEnv(
            env,
            variant.get('imsize'),
            init_camera=init_camera,
            transpose=True,
            normalize=True,
        )

    # Keyword arguments common to every VAEWrappedEnv built below.
    shared_kwargs = dict(
        decode_goals=render,
        render_goals=render,
        render_rollouts=render,
        reward_params=reward_params,
    )

    if not presample_goals:
        vae_env = VAEWrappedEnv(
            image_env,
            vae,
            imsize=image_env.imsize,
            **shared_kwargs,
            **variant.get('vae_wrapped_env_kwargs', {}),
        )
        if presample_image_goals_only:
            presampled_goals = variant['generate_goal_dataset_fctn'](
                image_env=vae_env.wrapped_env,
                **variant['goal_generation_kwargs'])
            image_env.set_presampled_goals(presampled_goals)
            print("Presampling image goals only")
        else:
            print("Not using presampled goals")
    else:
        # This will fail for online-parallel as presampled_goals will not be
        # serialized. Also don't use this for online-vae.
        if presampled_goals_path is None:
            image_env.non_presampled_goal_img_is_garbage = True
            # Temporary wrapper used only to generate the goal dataset.
            vae_env = VAEWrappedEnv(
                image_env,
                vae,
                imsize=image_env.imsize,
                **shared_kwargs,
                **variant.get('vae_wrapped_env_kwargs', {}),
            )
            presampled_goals = variant['generate_goal_dataset_fctn'](
                env=vae_env,
                env_id=variant.get('env_id', None),
                **variant['goal_generation_kwargs'])
            del vae_env
        else:
            presampled_goals = load_local_or_remote_file(
                presampled_goals_path).item()

        # Drop the first image wrapper before building the one with goals.
        del image_env
        image_env = ImageEnv(
            env,
            variant.get('imsize'),
            init_camera=init_camera,
            transpose=True,
            normalize=True,
            presampled_goals=presampled_goals,
            **variant.get('image_env_kwargs', {}),
        )
        vae_env = VAEWrappedEnv(
            image_env,
            vae,
            imsize=image_env.imsize,
            **shared_kwargs,
            presampled_goals=presampled_goals,
            **variant.get('vae_wrapped_env_kwargs', {}),
        )
        print("Presampling all goals only")

    env = vae_env
    training_mode = variant.get("training_mode", "train")
    testing_mode = variant.get("testing_mode", "test")

    # Registered in the same order as the original sequence of calls.
    mode_table = (
        ('eval', testing_mode),
        ('train', training_mode),
        ('relabeling', training_mode),
        ("video_vae", 'video_vae'),
        ("video_env", 'video_env'),
    )
    for mode_name, mode_value in mode_table:
        env.add_mode(mode_name, mode_value)
    return env
def get_envs(cfgs):
    """Create the (optionally VAE-wrapped) environment described by ``cfgs``.

    Values read from ``cfgs``:
        render: passed as ``decode_goals``, ``render_goals`` and
            ``render_rollouts`` of ``VAEWrappedEnv``.
        reward_params: forwarded to ``VAEWrappedEnv``.
        do_state_exp: when truthy, the bare env is returned unwrapped.
        VAE_TRAINER.vae_path: a ``str`` is loaded via
            ``load_local_or_remote_file``; any other value is used as the
            VAE directly.
        ENV.id: if truthy, roworld envs are registered and the env is built
            with ``gym.make``; otherwise ``ENV.cls(**ENV.kwargs)``.
        ENV.imsize, ENV.init_camera: forwarded to ``ImageEnv``.
        SKEW_FIT.presample_goals, SKEW_FIT.presampled_goals_path,
        SKEW_FIT.presample_image_goals_only, generate_goal_dataset_fctn,
        goal_generation_kwargs: control how (and whether) goals are
            presampled.

    Returns:
        The bare env when ``do_state_exp`` is set, else a ``VAEWrappedEnv``.
    """
    from roworld.core.image_env import ImageEnv
    from rlkit.envs.vae_wrapper import VAEWrappedEnv
    from rlkit.util.io import load_local_or_remote_file

    render = cfgs.get('render', False)
    reward_params = cfgs.get("reward_params", dict())
    do_state_exp = cfgs.get("do_state_exp", False)  # TODO
    vae_path = cfgs.VAE_TRAINER.get("vae_path", None)
    init_camera = cfgs.ENV.get("init_camera", None)
    skew_fit_cfg = cfgs.SKEW_FIT
    presample_goals = skew_fit_cfg.get('presample_goals', False)
    presample_image_goals_only = skew_fit_cfg.get(
        'presample_image_goals_only', False)
    presampled_goals_path = skew_fit_cfg.get('presampled_goals_path', None)

    # Exact-type check: only a plain ``str`` is treated as a path to load.
    if type(vae_path) is str:
        vae = load_local_or_remote_file(vae_path)
    else:
        vae = vae_path

    if cfgs.ENV.id:
        import gym
        import roworld
        roworld.register_all_envs()
        env = gym.make(cfgs.ENV.id)
    else:
        env = cfgs.ENV.cls(**cfgs.ENV.kwargs)

    if do_state_exp:
        return env

    if isinstance(env, ImageEnv):
        image_env = env
    else:
        image_env = ImageEnv(
            env,
            cfgs.ENV.imsize,
            init_camera=init_camera,
            transpose=True,
            normalize=True,
        )

    # Keyword arguments common to every VAEWrappedEnv built below.
    shared_kwargs = dict(
        decode_goals=render,
        render_goals=render,
        render_rollouts=render,
        reward_params=reward_params,
    )

    if not presample_goals:
        vae_env = VAEWrappedEnv(
            image_env,
            vae,
            imsize=image_env.imsize,
            **shared_kwargs,
            goal_sampling_mode='vae_prior',
            sample_from_true_prior=True,
        )
        if presample_image_goals_only:
            presampled_goals = cfgs['generate_goal_dataset_fctn'](
                image_env=vae_env.wrapped_env,
                **cfgs['goal_generation_kwargs'])
            image_env.set_presampled_goals(presampled_goals)
            print("Pre sampling image goals only")
        else:
            print("Not using presampled goals")
        return vae_env

    # This will fail for online-parallel as presampled_goals will not be
    # serialized. Also don't use this for online-vae.
    if presampled_goals_path is None:
        image_env.non_presampled_goal_img_is_garbage = True
        # Temporary wrapper used only to generate the goal dataset.
        scratch_env = VAEWrappedEnv(
            image_env,
            vae,
            imsize=image_env.imsize,
            **shared_kwargs,
            **cfgs.get('vae_wrapped_env_kwargs', {}),
        )
        # NOTE(review): the env id is read from ``cfgs.ENV.id`` above, but
        # the top-level 'env_id' key is looked up here - confirm intended.
        presampled_goals = cfgs['generate_goal_dataset_fctn'](
            env=scratch_env,
            env_id=cfgs.get('env_id', None),
            **cfgs['goal_generation_kwargs'])
        del scratch_env
    else:
        presampled_goals = load_local_or_remote_file(
            presampled_goals_path).item()

    # Drop the first image wrapper before building the one holding the goals.
    del image_env
    # NOTE(review): the image size is read with ``.get('imsize')`` here but
    # as the attribute ``cfgs.ENV.imsize`` above - confirm both are intended.
    image_env = ImageEnv(
        env,
        cfgs.ENV.get('imsize'),
        init_camera=init_camera,
        transpose=True,
        normalize=True,
        presampled_goals=presampled_goals,
    )
    vae_env = VAEWrappedEnv(
        image_env,
        vae,
        imsize=image_env.imsize,
        **shared_kwargs,
        presampled_goals=presampled_goals,
        sample_from_true_prior=True,
    )
    print("Pre sampling all goals only")
    return vae_env