def experiment(variant):
    """Train TD3 on a VAE-wrapped multitask image Point2D environment.

    Keys read from ``variant``:
        rdim: selects which hard-coded VAE snapshot path to load (2 or 4).
        multitask: must be truthy; no other environment is wired up.
        env_kwargs: keyword arguments for ``MultitaskImagePoint2DEnv``.
        normalize: when truthy, wrap the env in ``NormalizedBoxEnv``.
        exploration_type: one of ``'ou'``, ``'gaussian'``, ``'epsilon'``.
        algo_kwargs: extra keyword arguments forwarded to ``TD3``.
        use_gpu / gpu_id: when ``use_gpu`` is truthy, select the device and
            move the algorithm and the VAE onto it.

    Raises:
        KeyError: if ``rdim`` has no registered snapshot path.
        NotImplementedError: if ``variant['multitask']`` is falsy.
        Exception: if ``exploration_type`` is not one of the known values.
    """
    rdim = variant["rdim"]
    vae_paths = {
        2: "/home/ashvin/data/s3doodad/ashvin/vae/point2d-conv-sweep2/run0/id1/params.pkl",
        4: "/home/ashvin/data/s3doodad/ashvin/vae/point2d-conv-sweep2/run0/id4/params.pkl",
    }
    vae_path = vae_paths[rdim]

    # The non-multitask branch was never implemented (it only existed as a
    # commented-out stub), which previously left ``env`` unbound and crashed
    # with an UnboundLocalError after the VAE had already been loaded.
    # Fail early with an explicit message instead.
    if not variant['multitask']:
        raise NotImplementedError(
            "experiment() only supports variant['multitask'] == True"
        )

    vae = joblib.load(vae_path)
    print("loaded", vae_path)

    env = MultitaskImagePoint2DEnv(**variant['env_kwargs'])
    env = VAEWrappedEnv(env, vae, use_vae_obs=True, use_vae_reward=False,
                        use_vae_goals=False)
    env = MultitaskToFlatEnv(env)
    if variant['normalize']:
        env = NormalizedBoxEnv(env)

    exploration_type = variant['exploration_type']
    if exploration_type == 'ou':
        es = OUStrategy(action_space=env.action_space)
    elif exploration_type == 'gaussian':
        es = GaussianStrategy(
            action_space=env.action_space,
            max_sigma=0.1,
            min_sigma=0.1,  # Constant sigma
        )
    elif exploration_type == 'epsilon':
        es = EpsilonGreedy(
            action_space=env.action_space,
            prob_random_action=0.1,
        )
    else:
        raise Exception("Invalid type: " + exploration_type)

    obs_dim = env.observation_space.low.size
    action_dim = env.action_space.low.size
    qf1 = FlattenMlp(
        input_size=obs_dim + action_dim,
        output_size=1,
        hidden_sizes=[400, 300],
    )
    qf2 = FlattenMlp(
        input_size=obs_dim + action_dim,
        output_size=1,
        hidden_sizes=[400, 300],
    )
    policy = TanhMlpPolicy(
        input_size=obs_dim,
        output_size=action_dim,
        hidden_sizes=[400, 300],
    )
    exploration_policy = PolicyWrappedWithExplorationStrategy(
        exploration_strategy=es,
        policy=policy,
    )
    algorithm = TD3(env,
                    training_env=env,
                    qf1=qf1,
                    qf2=qf2,
                    policy=policy,
                    exploration_policy=exploration_policy,
                    **variant['algo_kwargs'])

    print("use_gpu", variant["use_gpu"], bool(variant["use_gpu"]))
    if variant["use_gpu"]:
        gpu_id = variant["gpu_id"]
        ptu.set_gpu_mode(True)
        ptu.set_device(gpu_id)
        algorithm.to(ptu.device)
        # The VAE lives on the env wrapped by the outermost wrapper.
        env._wrapped_env.vae.to(ptu.device)
    algorithm.train()
def get_envs(variant):
    """Build the environment described by ``variant``.

    A VAE is resolved from, in order of precedence, ``variant['ckpt']``
    (``<ckpt>/vae.pkl``, skipped for state experiments) and
    ``variant['vae_path']`` (``<vae_path>/vae_params.pkl`` when it is a
    string, otherwise the value itself). The base env comes from
    ``variant['env_id']`` via ``gym.make`` or from ``env_class`` /
    ``env_kwargs``.

    Unless ``do_state_exp`` is set, the env is wrapped in ``ImageEnv`` (if it
    is not one already) and ``VAEWrappedEnv``, presampled goals are installed
    as requested, and the 'eval', 'train', 'relabeling', 'video_vae' and
    'video_env' modes are registered on the wrapper.

    Returns:
        The raw env for state experiments, otherwise the VAE-wrapped env.
    """
    from multiworld.core.image_env import ImageEnv
    from railrl.envs.vae_wrappers import VAEWrappedEnv
    from railrl.misc.asset_loader import load_local_or_remote_file

    render = variant.get('render', False)
    vae_path = variant.get("vae_path", None)
    reproj_vae_path = variant.get("reproj_vae_path", None)
    ckpt = variant.get("ckpt", None)
    reward_params = variant.get("reward_params", dict())
    init_camera = variant.get("init_camera", None)
    do_state_exp = variant.get("do_state_exp", False)

    presample_goals = variant.get('presample_goals', False)
    presample_image_goals_only = variant.get(
        'presample_image_goals_only', False)
    presampled_goals_path = variant.get('presampled_goals_path', None)

    # 1) Try the checkpoint directory first.
    vae = None
    if not do_state_exp and type(ckpt) is str:
        vae = load_local_or_remote_file(osp.join(ckpt, 'vae.pkl'))
        if vae is not None:
            from railrl.core import logger
            logger.save_extra_data(vae, 'vae.pkl', mode='pickle')

    # 2) Fall back to vae_path (a directory string, or an already-built VAE).
    if vae is None:
        if type(vae_path) is str:
            vae = load_local_or_remote_file(
                osp.join(vae_path, 'vae_params.pkl'))
            from railrl.core import logger
            logger.save_extra_data(vae, 'vae.pkl', mode='pickle')
        else:
            vae = vae_path

    # 3) Whatever we have may still be a path; resolve it.
    if type(vae) is str:
        vae = load_local_or_remote_file(vae)

    reproj_vae = (
        load_local_or_remote_file(
            osp.join(reproj_vae_path, 'vae_params.pkl'))
        if type(reproj_vae_path) is str
        else None
    )

    if 'env_id' in variant:
        import gym
        # trigger registration
        env = gym.make(variant['env_id'])
    else:
        env = variant["env_class"](**variant['env_kwargs'])

    # NOTE(review): the original indentation is not recoverable from SOURCE;
    # this assumes every add_mode() call below applies only to the
    # VAE-wrapped env -- confirm.
    if do_state_exp:
        return env

    if isinstance(env, ImageEnv):
        image_env = env
    else:
        image_env = ImageEnv(
            env,
            variant.get('imsize'),
            init_camera=init_camera,
            transpose=True,
            normalize=True,
        )
    vae_env = VAEWrappedEnv(
        image_env,
        vae,
        imsize=image_env.imsize,
        decode_goals=render,
        render_goals=render,
        render_rollouts=render,
        reward_params=reward_params,
        reproj_vae=reproj_vae,
        **variant.get('vae_wrapped_env_kwargs', {})
    )

    if presample_goals:
        # This will fail for online-parallel as presampled_goals will not be
        # serialized. Also don't use this for online-vae.
        if presampled_goals_path is None:
            image_env.non_presampled_goal_img_is_garbage = True
            presampled_goals = variant['generate_goal_dataset_fctn'](
                image_env=image_env,
                **variant['goal_generation_kwargs']
            )
        else:
            loaded_goals = load_local_or_remote_file(
                presampled_goals_path
            ).item()
            presampled_goals = {
                'state_desired_goal': loaded_goals['next_obs_state'],
                'image_desired_goal': loaded_goals['next_obs'],
            }
        image_env.set_presampled_goals(presampled_goals)
        vae_env.set_presampled_goals(presampled_goals)
        print("Presampling all goals")
    elif presample_image_goals_only:
        presampled_goals = variant['generate_goal_dataset_fctn'](
            image_env=vae_env.wrapped_env,
            **variant['goal_generation_kwargs']
        )
        image_env.set_presampled_goals(presampled_goals)
        print("Presampling image goals only")
    else:
        print("Not using presampled goals")

    env = vae_env

    training_mode = variant.get("training_mode", "train")
    testing_mode = variant.get("testing_mode", "test")
    for mode_name, mode_value in (
            ('eval', testing_mode),
            ('train', training_mode),
            ('relabeling', training_mode),
            ("video_vae", 'video_vae'),
            ("video_env", 'video_env'),
    ):
        env.add_mode(mode_name, mode_value)
    return env
def _disentangled_grill_her_twin_sac_experiment( max_path_length, encoder_kwargs, disentangled_qf_kwargs, qf_kwargs, twin_sac_trainer_kwargs, replay_buffer_kwargs, policy_kwargs, vae_evaluation_goal_sampling_mode, vae_exploration_goal_sampling_mode, base_env_evaluation_goal_sampling_mode, base_env_exploration_goal_sampling_mode, algo_kwargs, env_id=None, env_class=None, env_kwargs=None, observation_key='state_observation', desired_goal_key='state_desired_goal', achieved_goal_key='state_achieved_goal', latent_dim=2, vae_wrapped_env_kwargs=None, vae_path=None, vae_n_vae_training_kwargs=None, vectorized=False, save_video=True, save_video_kwargs=None, have_no_disentangled_encoder=False, **kwargs): if env_kwargs is None: env_kwargs = {} assert env_id or env_class if env_id: import gym import multiworld multiworld.register_all_envs() train_env = gym.make(env_id) eval_env = gym.make(env_id) else: eval_env = env_class(**env_kwargs) train_env = env_class(**env_kwargs) train_env.goal_sampling_mode = base_env_exploration_goal_sampling_mode eval_env.goal_sampling_mode = base_env_evaluation_goal_sampling_mode if vae_path: vae = load_local_or_remote_file(vae_path) else: vae = get_n_train_vae(latent_dim=latent_dim, env=eval_env, **vae_n_vae_training_kwargs) train_env = VAEWrappedEnv(train_env, vae, imsize=train_env.imsize, **vae_wrapped_env_kwargs) eval_env = VAEWrappedEnv(eval_env, vae, imsize=train_env.imsize, **vae_wrapped_env_kwargs) obs_dim = train_env.observation_space.spaces[observation_key].low.size goal_dim = train_env.observation_space.spaces[desired_goal_key].low.size action_dim = train_env.action_space.low.size encoder = FlattenMlp(input_size=obs_dim, output_size=latent_dim, **encoder_kwargs) def make_qf(): if have_no_disentangled_encoder: return FlattenMlp( input_size=obs_dim + goal_dim + action_dim, output_size=1, **qf_kwargs, ) else: return DisentangledMlpQf(goal_processor=encoder, preprocess_obs_dim=obs_dim, action_dim=action_dim, qf_kwargs=qf_kwargs, 
vectorized=vectorized, **disentangled_qf_kwargs) qf1 = make_qf() qf2 = make_qf() target_qf1 = make_qf() target_qf2 = make_qf() policy = TanhGaussianPolicy(obs_dim=obs_dim + goal_dim, action_dim=action_dim, **policy_kwargs) replay_buffer = ObsDictRelabelingBuffer( env=train_env, observation_key=observation_key, desired_goal_key=desired_goal_key, achieved_goal_key=achieved_goal_key, vectorized=vectorized, **replay_buffer_kwargs) sac_trainer = SACTrainer(env=train_env, policy=policy, qf1=qf1, qf2=qf2, target_qf1=target_qf1, target_qf2=target_qf2, **twin_sac_trainer_kwargs) trainer = HERTrainer(sac_trainer) eval_path_collector = VAEWrappedEnvPathCollector( eval_env, MakeDeterministic(policy), max_path_length, observation_key=observation_key, desired_goal_key=desired_goal_key, goal_sampling_mode=vae_evaluation_goal_sampling_mode, ) expl_path_collector = VAEWrappedEnvPathCollector( train_env, policy, max_path_length, observation_key=observation_key, desired_goal_key=desired_goal_key, goal_sampling_mode=vae_exploration_goal_sampling_mode, ) algorithm = TorchBatchRLAlgorithm( trainer=trainer, exploration_env=train_env, evaluation_env=eval_env, exploration_data_collector=expl_path_collector, evaluation_data_collector=eval_path_collector, replay_buffer=replay_buffer, max_path_length=max_path_length, **algo_kwargs, ) algorithm.to(ptu.device) if save_video: save_vf_heatmap = save_video_kwargs.get('save_vf_heatmap', True) if have_no_disentangled_encoder: def v_function(obs): action = policy.get_actions(obs) obs, action = ptu.from_numpy(obs), ptu.from_numpy(action) return qf1(obs, action) add_heatmap = partial(add_heatmap_img_to_o_dict, v_function=v_function) else: def v_function(obs): action = policy.get_actions(obs) obs, action = ptu.from_numpy(obs), ptu.from_numpy(action) return qf1(obs, action, return_individual_q_vals=True) add_heatmap = partial( add_heatmap_imgs_to_o_dict, v_function=v_function, vectorized=vectorized, ) rollout_function = rf.create_rollout_function( 
rf.multitask_rollout, max_path_length=max_path_length, observation_key=observation_key, desired_goal_key=desired_goal_key, full_o_postprocess_func=add_heatmap if save_vf_heatmap else None, ) img_keys = ['v_vals'] + [ 'v_vals_dim_{}'.format(dim) for dim in range(latent_dim) ] eval_video_func = get_save_video_function(rollout_function, eval_env, MakeDeterministic(policy), get_extra_imgs=partial( get_extra_imgs, img_keys=img_keys), tag="eval", **save_video_kwargs) train_video_func = get_save_video_function(rollout_function, train_env, policy, get_extra_imgs=partial( get_extra_imgs, img_keys=img_keys), tag="train", **save_video_kwargs) algorithm.post_train_funcs.append(eval_video_func) algorithm.post_train_funcs.append(train_video_func) algorithm.train()
def get_envs(variant):
    """Create the (optionally VAE-wrapped) environment for ``variant``.

    The base env is made from ``variant['env_id']`` (after registering the
    multiworld envs) or from ``env_class`` / ``env_kwargs``. For state
    experiments (``do_state_exp``) it is returned as is. Otherwise it is
    wrapped in ``ImageEnv`` and then in ``VAEWrappedEnv``, or in
    ``ConditionalVAEWrappedEnv`` when the VAE is exactly one of the
    conditional VAE classes and goals are not fully presampled.

    Presampled goals are either generated through
    ``variant['generate_goal_dataset_fctn']`` or loaded from
    ``presampled_goals_path``.
    """
    from multiworld.core.image_env import ImageEnv
    from railrl.envs.vae_wrappers import VAEWrappedEnv, ConditionalVAEWrappedEnv
    from railrl.misc.asset_loader import load_local_or_remote_file
    from railrl.torch.vae.conditional_conv_vae import CVAE, CDVAE, ACE, CADVAE, DeltaCVAE

    render = variant.get('render', False)
    vae_path = variant.get("vae_path", None)
    reward_params = variant.get("reward_params", dict())
    init_camera = variant.get("init_camera", None)
    do_state_exp = variant.get("do_state_exp", False)
    presample_goals = variant.get('presample_goals', False)
    presample_image_goals_only = variant.get(
        'presample_image_goals_only', False)
    presampled_goals_path = get_presampled_goals_path(
        variant.get('presampled_goals_path', None))

    if type(vae_path) is str:
        vae = load_local_or_remote_file(vae_path)
    else:
        vae = vae_path

    if 'env_id' in variant:
        import gym
        import multiworld
        multiworld.register_all_envs()
        env = gym.make(variant['env_id'])
    else:
        env = variant["env_class"](**variant['env_kwargs'])

    if do_state_exp:
        return env

    def wrap_with_vae(wrapper_cls, inner_env, **extra):
        # Every VAE wrapper in this function is built with these arguments.
        return wrapper_cls(inner_env,
                           vae,
                           imsize=inner_env.imsize,
                           decode_goals=render,
                           render_goals=render,
                           render_rollouts=render,
                           reward_params=reward_params,
                           **extra,
                           **variant.get('vae_wrapped_env_kwargs', {}))

    if isinstance(env, ImageEnv):
        image_env = env
    else:
        image_env = ImageEnv(
            env,
            variant.get('imsize'),
            init_camera=init_camera,
            transpose=True,
            normalize=True,
        )

    if presample_goals:
        # This will fail for online-parallel as presampled_goals will not be
        # serialized. Also don't use this for online-vae.
        if presampled_goals_path is None:
            image_env.non_presampled_goal_img_is_garbage = True
            # Temporary wrapper, used only to generate the goal dataset.
            vae_env = wrap_with_vae(VAEWrappedEnv, image_env)
            presampled_goals = variant['generate_goal_dataset_fctn'](
                env=vae_env,
                env_id=variant.get('env_id', None),
                **variant['goal_generation_kwargs'])
            del vae_env
        else:
            presampled_goals = load_local_or_remote_file(
                presampled_goals_path).item()
        # Rebuild both wrappers so they are constructed with the goals.
        del image_env
        image_env = ImageEnv(env,
                             variant.get('imsize'),
                             init_camera=init_camera,
                             transpose=True,
                             normalize=True,
                             presampled_goals=presampled_goals,
                             **variant.get('image_env_kwargs', {}))
        vae_env = wrap_with_vae(VAEWrappedEnv, image_env,
                                presampled_goals=presampled_goals)
        print("Presampling all goals only")
    else:
        conditional_classes = (CVAE, CDVAE, ACE, CADVAE, DeltaCVAE)
        # Exact-type match, not isinstance.
        if any(type(vae) is cls for cls in conditional_classes):
            vae_env = wrap_with_vae(ConditionalVAEWrappedEnv, image_env)
        else:
            vae_env = wrap_with_vae(VAEWrappedEnv, image_env)

        if presample_image_goals_only:
            presampled_goals = variant['generate_goal_dataset_fctn'](
                image_env=vae_env.wrapped_env,
                **variant['goal_generation_kwargs'])
            image_env.set_presampled_goals(presampled_goals)
            print("Presampling image goals only")
        else:
            print("Not using presampled goals")

    env = vae_env
    return env
def grill_her_sac_experiment(variant):
    """Run HER + SAC in the latent space of a VAE-wrapped image env.

    The env built from ``variant['env_class']`` is wrapped in ``ImageEnv``
    (84px) and ``VAEWrappedEnv`` (VAE chosen by ``variant['rdim']`` from
    ``variant['vae_paths']``), optionally normalised, then pickle-copied into
    separate testing / training / relabeling / video envs. After training,
    final rollout videos are dumped unless ``variant['save_video']`` is
    falsy.

    Side effect: stores the replay buffer in
    ``variant['algo_kwargs']['replay_buffer']``.
    """
    env = variant["env_class"](**variant['env_kwargs'])
    render = variant["render"]
    rdim = variant["rdim"]
    vae_path = variant["vae_paths"][str(rdim)]
    reward_params = variant.get("reward_params", dict())
    init_camera = variant.get("init_camera", None)
    # Fall back to the named top-view camera when no init function is given.
    camera_name = "topview" if init_camera is None else None

    env = ImageEnv(
        env,
        84,
        init_camera=init_camera,
        camera_name=camera_name,
        transpose=True,
        normalize=True,
    )
    env = VAEWrappedEnv(
        env,
        vae_path,
        decode_goals=render,
        render_goals=render,
        render_rollouts=render,
        reward_params=reward_params,
        **variant.get('vae_wrapped_env_kwargs', {})
    )
    if variant['normalize']:
        env = NormalizedBoxEnv(env)

    observation_key = variant.get('observation_key', 'latent_observation')
    desired_goal_key = variant.get('desired_goal_key', 'latent_desired_goal')
    achieved_goal_key = desired_goal_key.replace("desired", "achieved")

    # Networks consume the observation concatenated with the desired goal.
    spaces = env.observation_space.spaces
    obs_dim = (
        spaces[observation_key].low.size
        + spaces[desired_goal_key].low.size
    )
    action_dim = env.action_space.low.size
    hidden_sizes = variant.get('hidden_sizes', [400, 300])
    qf = FlattenMlp(
        input_size=obs_dim + action_dim,
        output_size=1,
        hidden_sizes=hidden_sizes,
    )
    vf = FlattenMlp(
        input_size=obs_dim,
        output_size=1,
        hidden_sizes=hidden_sizes,
    )
    policy = TanhGaussianPolicy(
        obs_dim=obs_dim,
        action_dim=action_dim,
        hidden_sizes=hidden_sizes,
    )

    training_mode = variant.get("training_mode", "train")
    testing_mode = variant.get("testing_mode", "test")

    def env_copy_in_mode(mode):
        # Independent copy of the wrapped env via a pickle round trip.
        duplicate = pickle.loads(pickle.dumps(env))
        duplicate.mode(mode)
        return duplicate

    testing_env = env_copy_in_mode(testing_mode)
    training_env = env_copy_in_mode(training_mode)
    relabeling_env = env_copy_in_mode(training_mode)
    relabeling_env.disable_render()
    video_vae_env = env_copy_in_mode("video_vae")
    video_goal_env = env_copy_in_mode("video_env")

    replay_buffer = ObsDictRelabelingBuffer(
        env=relabeling_env,
        observation_key=observation_key,
        desired_goal_key=desired_goal_key,
        achieved_goal_key=achieved_goal_key,
        **variant['replay_kwargs']
    )
    variant["algo_kwargs"]["replay_buffer"] = replay_buffer
    algorithm = HerSac(
        testing_env,
        training_env=training_env,
        qf=qf,
        vf=vf,
        policy=policy,
        render=render,
        render_during_eval=render,
        observation_key=observation_key,
        desired_goal_key=desired_goal_key,
        **variant['algo_kwargs']
    )

    if ptu.gpu_enabled():
        print("using GPU")
        for module in (qf, vf, policy, algorithm):
            module.to(ptu.device)
        for vae_holder in (testing_env, training_env,
                           video_vae_env, video_goal_env):
            vae_holder.vae.to(ptu.device)

    algorithm.train()

    if variant.get("save_video", True):
        logdir = logger.get_snapshot_dir()
        policy.train(False)
        env_video_path = osp.join(logdir, 'video_final_env.mp4')
        rollout_function = rf.create_rollout_function(
            rf.multitask_rollout,
            max_path_length=algorithm.max_path_length,
            observation_key=algorithm.observation_key,
            desired_goal_key=algorithm.desired_goal_key,
        )
        dump_video(video_goal_env, policy, env_video_path, rollout_function)
        vae_video_path = osp.join(logdir, 'video_final_vae.mp4')
        dump_video(video_vae_env, policy, vae_video_path, rollout_function)
def grill_her_td3_experiment(variant):
    """Run HER + TD3 in the latent space of a VAE-wrapped image env.

    Same environment pipeline as the SAC variant: ``env_class`` ->
    ``ImageEnv`` (84px) -> ``VAEWrappedEnv`` -> optional
    ``NormalizedBoxEnv``, then pickle-copied into testing / training /
    relabeling / video envs. Exploration noise is selected by
    ``variant['exploration_type']`` (``'ou'``, ``'gaussian'``,
    ``'epsilon'``) and scaled by ``variant['exploration_noise']``
    (default 0.1).

    Side effect: stores the replay buffer in
    ``variant['algo_kwargs']['replay_buffer']``.

    Raises:
        Exception: if ``exploration_type`` is not a known strategy.
    """
    base_env = variant["env_class"](**variant['env_kwargs'])
    render = variant["render"]
    rdim = variant["rdim"]
    vae_path = variant["vae_paths"][str(rdim)]
    reward_params = variant.get("reward_params", dict())
    init_camera = variant.get("init_camera", None)
    if init_camera is not None:
        camera_name = None
    else:
        camera_name = "topview"

    image_env = ImageEnv(
        base_env,
        84,
        init_camera=init_camera,
        camera_name=camera_name,
        transpose=True,
        normalize=True,
    )
    env = VAEWrappedEnv(
        image_env,
        vae_path,
        decode_goals=render,
        render_goals=render,
        render_rollouts=render,
        reward_params=reward_params,
        **variant.get('vae_wrapped_env_kwargs', {})
    )
    if variant['normalize']:
        env = NormalizedBoxEnv(env)

    exploration_type = variant['exploration_type']
    exploration_noise = variant.get('exploration_noise', 0.1)
    if exploration_type == 'ou':
        es = OUStrategy(action_space=env.action_space)
    elif exploration_type == 'gaussian':
        # max == min keeps sigma constant.
        es = GaussianStrategy(
            action_space=env.action_space,
            max_sigma=exploration_noise,
            min_sigma=exploration_noise,
        )
    elif exploration_type == 'epsilon':
        es = EpsilonGreedy(
            action_space=env.action_space,
            prob_random_action=exploration_noise,
        )
    else:
        raise Exception("Invalid type: " + exploration_type)

    observation_key = variant.get('observation_key', 'latent_observation')
    desired_goal_key = variant.get('desired_goal_key', 'latent_desired_goal')
    achieved_goal_key = desired_goal_key.replace("desired", "achieved")

    # Network input is [observation, desired goal] (plus action for the Qs).
    obs_size = env.observation_space.spaces[observation_key].low.size
    goal_size = env.observation_space.spaces[desired_goal_key].low.size
    obs_dim = obs_size + goal_size
    action_dim = env.action_space.low.size
    hidden_sizes = variant.get('hidden_sizes', [400, 300])

    qf1 = FlattenMlp(
        input_size=obs_dim + action_dim,
        output_size=1,
        hidden_sizes=hidden_sizes,
    )
    qf2 = FlattenMlp(
        input_size=obs_dim + action_dim,
        output_size=1,
        hidden_sizes=hidden_sizes,
    )
    policy = TanhMlpPolicy(
        input_size=obs_dim,
        output_size=action_dim,
        hidden_sizes=hidden_sizes,
    )
    exploration_policy = PolicyWrappedWithExplorationStrategy(
        exploration_strategy=es,
        policy=policy,
    )

    training_mode = variant.get("training_mode", "train")
    testing_mode = variant.get("testing_mode", "test")

    def spawn(mode):
        # Pickle round trip gives each role its own env instance.
        clone = pickle.loads(pickle.dumps(env))
        clone.mode(mode)
        return clone

    testing_env = spawn(testing_mode)
    training_env = spawn(training_mode)
    relabeling_env = spawn(training_mode)
    relabeling_env.disable_render()
    video_vae_env = spawn("video_vae")
    video_goal_env = spawn("video_env")

    replay_buffer = ObsDictRelabelingBuffer(
        env=relabeling_env,
        observation_key=observation_key,
        desired_goal_key=desired_goal_key,
        achieved_goal_key=achieved_goal_key,
        **variant['replay_kwargs']
    )
    variant["algo_kwargs"]["replay_buffer"] = replay_buffer
    algorithm = HerTd3(
        testing_env,
        training_env=training_env,
        qf1=qf1,
        qf2=qf2,
        policy=policy,
        exploration_policy=exploration_policy,
        render=render,
        render_during_eval=render,
        observation_key=observation_key,
        desired_goal_key=desired_goal_key,
        **variant['algo_kwargs']
    )

    if ptu.gpu_enabled():
        print("using GPU")
        algorithm.to(ptu.device)
        for wrapped in (testing_env, training_env,
                        video_vae_env, video_goal_env):
            wrapped.vae.to(ptu.device)

    algorithm.train()

    if variant.get("save_video", True):
        logdir = logger.get_snapshot_dir()
        policy.train(False)
        filename = osp.join(logdir, 'video_final_env.mp4')
        rollout_function = rf.create_rollout_function(
            rf.multitask_rollout,
            max_path_length=algorithm.max_path_length,
            observation_key=algorithm.observation_key,
            desired_goal_key=algorithm.desired_goal_key,
        )
        dump_video(video_goal_env, policy, filename, rollout_function)
        filename = osp.join(logdir, 'video_final_vae.mp4')
        dump_video(video_vae_env, policy, filename, rollout_function)
def experiment(variant):
    """Build a TD3 algorithm on a VAE-wrapped Mujoco image environment.

    ``variant['env'](**variant['env_kwargs'])`` is wrapped in
    ``ImageMujocoEnv`` + ``NormalizedBoxEnv`` (optionally in a second
    ``ImageMujocoEnv`` when ``wrap_mujoco_env`` is set), then in either
    ``VAEWrappedImageGoalEnv`` (``use_env_goals``) or ``VAEWrappedEnv`` using
    the VAE path selected by ``rdim``, and finally flattened with
    ``MultitaskToFlatEnv``. The algorithm and the VAE are moved to
    ``ptu.device``.

    Raises:
        Exception: if ``variant['exploration_type']`` is not ``'ou'``,
            ``'gaussian'`` or ``'epsilon'``.
    """
    rdim = variant["rdim"]
    use_env_goals = variant["use_env_goals"]
    vae_path = variant["vae_paths"][str(rdim)]
    render = variant["render"]
    wrap_mujoco_env = variant.get("wrap_mujoco_env", False)

    from railrl.envs.wrappers import ImageMujocoEnv, NormalizedBoxEnv
    from railrl.images.camera import sawyer_init_camera

    raw_env = variant["env"](**variant['env_kwargs'])
    image_mujoco_env = ImageMujocoEnv(
        raw_env,
        imsize=84,
        keep_prev=0,
        init_camera=sawyer_init_camera,
    )
    env = NormalizedBoxEnv(image_mujoco_env)
    if wrap_mujoco_env:
        env = ImageMujocoEnv(env, 84, camera_name="topview", transpose=True,
                             normalize=True)

    # Arguments shared by both VAE wrapper flavours.
    vae_wrapper_kwargs = dict(
        use_vae_obs=True,
        use_vae_reward=True,
        use_vae_goals=True,
        render_goals=render,
        render_rollouts=render,
    )
    if use_env_goals:
        track_qpos_goal = variant.get("track_qpos_goal", 0)
        env = VAEWrappedImageGoalEnv(env, vae_path,
                                     track_qpos_goal=track_qpos_goal,
                                     **vae_wrapper_kwargs)
    else:
        env = VAEWrappedEnv(env, vae_path, **vae_wrapper_kwargs)

    env = MultitaskToFlatEnv(env)
    if variant['normalize']:
        env = NormalizedBoxEnv(env)

    exploration_type = variant['exploration_type']
    # Builders are lazy so only the selected strategy is constructed.
    strategy_builders = {
        'ou': lambda: OUStrategy(action_space=env.action_space),
        'gaussian': lambda: GaussianStrategy(
            action_space=env.action_space,
            max_sigma=0.1,
            min_sigma=0.1,  # Constant sigma
        ),
        'epsilon': lambda: EpsilonGreedy(
            action_space=env.action_space,
            prob_random_action=0.1,
        ),
    }
    if exploration_type not in strategy_builders:
        raise Exception("Invalid type: " + exploration_type)
    es = strategy_builders[exploration_type]()

    obs_dim = env.observation_space.low.size
    action_dim = env.action_space.low.size
    critic_input_size = obs_dim + action_dim
    qf1 = FlattenMlp(
        input_size=critic_input_size,
        output_size=1,
        hidden_sizes=[400, 300],
    )
    qf2 = FlattenMlp(
        input_size=critic_input_size,
        output_size=1,
        hidden_sizes=[400, 300],
    )
    policy = TanhMlpPolicy(
        input_size=obs_dim,
        output_size=action_dim,
        hidden_sizes=[400, 300],
    )
    exploration_policy = PolicyWrappedWithExplorationStrategy(
        exploration_strategy=es,
        policy=policy,
    )
    algorithm = TD3(
        env,
        training_env=env,
        qf1=qf1,
        qf2=qf2,
        policy=policy,
        exploration_policy=exploration_policy,
        **variant['algo_kwargs']
    )
    algorithm.to(ptu.device)
    env._wrapped_env.vae.to(ptu.device)
    # NOTE(review): unlike the sibling experiment functions, this one never
    # calls algorithm.train() -- confirm whether that is intentional.
def train_vae(variant):
    """Train a VAE on a stored dataset and return the trained model.

    The dataset is loaded with ``np.load`` from
    ``variant['generate_vae_dataset_kwargs']['dataset_path']`` and split
    along its first axis. Depending on ``variant['vae_type']`` a state
    ``VAE`` ("VAE-state"), a ``ConvVAE2`` ("VAE2") or a ``ConvVAE`` (anything
    else) is built and trained for ``variant['num_epochs']`` epochs.
    A ``VAEWrappedEnv`` around the model is used for periodic visualisation.

    While training, tabular logging is redirected from ``progress.csv`` to
    ``vae_progress.csv``; it is switched back before returning. The final
    model is saved as ``vae.pkl`` via the logger.

    Notes:
        * ``variant['vae_kwargs']`` is mutated (``action_dim`` and
          ``imsize`` / ``input_size`` are filled in).
        * Despite its name, ``test_p`` is the fraction of samples that goes
          into the *training* split (the first ``int(N * test_p)`` rows).

    Returns:
        The trained VAE model.
    """
    from railrl.misc.ml_util import PiecewiseLinearSchedule
    from railrl.torch.vae.conv_vae import ConvVAE
    from railrl.torch.vae.conv_vae_trainer import ConvVAETrainer
    from railrl.core import logger
    import railrl.torch.pytorch_util as ptu
    from multiworld.core.image_env import ImageEnv
    from railrl.envs.vae_wrappers import VAEWrappedEnv
    from railrl.misc.asset_loader import local_path_from_s3_or_local_path

    logger.remove_tabular_output('progress.csv',
                                 relative_to_snapshot_dir=True)
    logger.add_tabular_output('vae_progress.csv',
                              relative_to_snapshot_dir=True)

    dataset_kwargs = variant['generate_vae_dataset_kwargs']
    env_id = dataset_kwargs.get('env_id', None)
    if env_id is not None:
        import gym
        env = gym.make(env_id)
    else:
        env_class = dataset_kwargs['env_class']
        env_kwargs = dataset_kwargs['env_kwargs']
        env = env_class(**env_kwargs)

    representation_size = variant["representation_size"]
    beta = variant["beta"]
    if 'beta_schedule_kwargs' in variant:
        beta_schedule = PiecewiseLinearSchedule(
            **variant['beta_schedule_kwargs'])
    else:
        beta_schedule = None

    # obtain training and testing data
    dataset_path = dataset_kwargs.get('dataset_path', None)
    test_p = dataset_kwargs.get('test_p', 0.9)
    filename = local_path_from_s3_or_local_path(dataset_path)
    # NOTE: allow_pickle=True unpickles the file's contents; only load
    # datasets from trusted locations.
    dataset = np.load(filename, allow_pickle=True).item()

    N = dataset['obs'].shape[0]
    n = int(N * test_p)
    train_data = {k: v[:n, :] for k, v in dataset.items()}
    test_data = {k: v[n:, :] for k, v in dataset.items()}

    # setup vae
    vae_type = variant.get('vae_type', None)
    variant['vae_kwargs']['action_dim'] = train_data['actions'].shape[1]
    if vae_type == "VAE-state":
        from railrl.torch.vae.vae import VAE
        input_size = train_data['obs'].shape[1]
        variant['vae_kwargs']['input_size'] = input_size
        m = VAE(representation_size, **variant['vae_kwargs'])
    elif vae_type == "VAE2":
        from railrl.torch.vae.conv_vae2 import ConvVAE2
        variant['vae_kwargs']['imsize'] = variant['imsize']
        m = ConvVAE2(representation_size, **variant['vae_kwargs'])
    else:
        variant['vae_kwargs']['imsize'] = variant['imsize']
        m = ConvVAE(representation_size, **variant['vae_kwargs'])
    if ptu.gpu_enabled():
        m.cuda()

    # setup vae trainer
    if vae_type == "VAE-state":
        from railrl.torch.vae.vae_trainer import VAETrainer
        t = VAETrainer(train_data, test_data, m, beta=beta,
                       beta_schedule=beta_schedule,
                       **variant['algo_kwargs'])
    else:
        t = ConvVAETrainer(train_data, test_data, m, beta=beta,
                           beta_schedule=beta_schedule,
                           **variant['algo_kwargs'])

    # visualization
    vis_variant = variant.get('vis_kwargs', {})
    save_video = vis_variant.get('save_video', False)
    if isinstance(env, ImageEnv):
        image_env = env
    else:
        image_env = ImageEnv(
            env,
            dataset_kwargs.get('imsize'),
            init_camera=dataset_kwargs.get('init_camera'),
            transpose=True,
            normalize=True,
        )
    render = variant.get('render', False)
    reward_params = variant.get("reward_params", dict())
    vae_env = VAEWrappedEnv(image_env,
                            m,
                            imsize=image_env.imsize,
                            decode_goals=render,
                            render_goals=render,
                            render_rollouts=render,
                            reward_params=reward_params,
                            **variant.get('vae_wrapped_env_kwargs', {}))
    vae_env.reset()
    vae_env.add_mode("video_env", 'video_env')
    vae_env.add_mode("video_vae", 'video_vae')

    # Resolved once and shared with the closure below. When the variant
    # provides 'dump_video_kwargs' this is that very dict (so updates are
    # still visible through the variant); when it does not, a fresh dict is
    # used instead of failing later with a KeyError.
    dump_video_kwargs = variant.get("dump_video_kwargs", dict())
    if save_video:
        import railrl.samplers.rollout_functions as rf
        from railrl.policies.simple import RandomPolicy
        random_policy = RandomPolicy(vae_env.action_space)
        rollout_function = rf.create_rollout_function(
            rf.multitask_rollout,
            max_path_length=100,
            observation_key='latent_observation',
            desired_goal_key='latent_desired_goal',
            vis_list=vis_variant.get('vis_list', []),
            dont_terminate=True,
        )

        dump_video_kwargs['imsize'] = vae_env.imsize
        dump_video_kwargs['vis_list'] = [
            'image_observation',
            'reconstr_image_observation',
            'image_latent_histogram_2d',
            'image_latent_histogram_mu_2d',
            'image_plt',
            'image_rew',
            'image_rew_euclidean',
            'image_rew_mahalanobis',
            'image_rew_logp',
            'image_rew_kl',
            'image_rew_kl_rev',
        ]

    def visualization_post_processing(save_vis, save_video, epoch):
        """Dump reconstructions/samples/histograms and optional videos."""
        vis_list = vis_variant.get('vis_list', [])

        if save_vis:
            if vae_env.vae_input_key_prefix == 'state':
                vae_env.dump_reconstructions(
                    epoch, n_recon=vis_variant.get('n_recon', 16))
            vae_env.dump_samples(
                epoch, n_samples=vis_variant.get('n_samples', 64))
            if 'latent_representation' in vis_list:
                vae_env.dump_latent_plots(epoch)
            if any(elem in vis_list for elem in [
                    'latent_histogram', 'latent_histogram_mu',
                    'latent_histogram_2d', 'latent_histogram_mu_2d'
            ]):
                vae_env.compute_latent_histogram()
            if not save_video and ('latent_histogram' in vis_list):
                vae_env.dump_latent_histogram(epoch=epoch, noisy=True,
                                              use_true_prior=True)
            if not save_video and ('latent_histogram_mu' in vis_list):
                vae_env.dump_latent_histogram(epoch=epoch, noisy=False,
                                              use_true_prior=True)

        if save_video and save_vis:
            from railrl.envs.vae_wrappers import temporary_mode
            from railrl.misc.video_gen import dump_video
            from railrl.core import logger
            vae_env.compute_goal_encodings()
            logdir = logger.get_snapshot_dir()
            filename = osp.join(logdir,
                                'video_{epoch}.mp4'.format(epoch=epoch))
            dump_video_kwargs['epoch'] = epoch
            temporary_mode(vae_env,
                           mode='video_env',
                           func=dump_video,
                           args=(vae_env, random_policy, filename,
                                 rollout_function),
                           kwargs=dump_video_kwargs)
            if not vis_variant.get('save_video_env_only', True):
                filename = osp.join(
                    logdir, 'video_{epoch}_vae.mp4'.format(epoch=epoch))
                temporary_mode(vae_env,
                               mode='video_vae',
                               func=dump_video,
                               args=(vae_env, random_policy, filename,
                                     rollout_function),
                               kwargs=dump_video_kwargs)

    # train vae
    num_epochs = variant['num_epochs']
    for epoch in range(num_epochs):
        is_last_epoch = epoch == num_epochs - 1
        save_vis = (epoch % vis_variant['save_period'] == 0
                    or is_last_epoch)
        save_vae = (epoch % variant['snapshot_gap'] == 0 or is_last_epoch)

        t.train_epoch(epoch)
        t.test_epoch(
            epoch,
            save_reconstruction=save_vis,
            save_interpolation=save_vis,
            save_vae=save_vae,
        )
        # Visualisation runs every 300 epochs and on the final epoch.
        if epoch % 300 == 0 or is_last_epoch:
            visualization_post_processing(save_vis, save_video, epoch)

    logger.save_extra_data(m, 'vae.pkl', mode='pickle')
    logger.remove_tabular_output(
        'vae_progress.csv',
        relative_to_snapshot_dir=True,
    )
    logger.add_tabular_output(
        'progress.csv',
        relative_to_snapshot_dir=True,
    )
    print("finished --------------------!!!!!!!!!!!!!!!")
    return m