def eval_rollout_function(self):
    """Return the rollout function used for evaluation.

    The result is whatever ``create_rollout_function`` produces for
    ``tdm_rollout`` when given this instance's tau settings and
    observation / desired-goal keys.
    """
    rollout_kwargs = dict(
        init_tau=self.max_tau,
        # The same attribute drives both tau cycling and tau decrementing.
        cycle_tau=self.cycle_taus_for_rollout,
        decrement_tau=self.cycle_taus_for_rollout,
        observation_key=self.observation_key,
        desired_goal_key=self.desired_goal_key,
    )
    return create_rollout_function(tdm_rollout, **rollout_kwargs)
def make_video(args):
    """Load a pickled snapshot and dump a grid video of policy rollouts.

    Args:
        args: Namespace-like object. This function reads ``args.pause``
            (drop into ``ipdb`` before doing anything), ``args.file`` (path
            of the pickled snapshot), ``args.gpu`` (whether to enable GPU
            mode) and ``args.mode`` (mode applied to a ``VAEWrappedEnv``).

    Raises:
        AttributeError: If the snapshot has no policy or no environment
            under any of the recognised keys.
    """
    if args.pause:
        import ipdb
        ipdb.set_trace()

    # NOTE(security): unpickling executes arbitrary code; only load
    # snapshots that come from a trusted source.
    with open(args.file, "rb") as snapshot_file:
        data = pickle.load(snapshot_file)

    def _first_present(candidate_keys, description):
        """Return the value of the first candidate key found in ``data``."""
        for key in candidate_keys:
            if key in data:
                return data[key]
        raise AttributeError(
            "No {} found in snapshot; tried keys {}".format(
                description, list(candidate_keys)))

    policy = _first_present(('policy', 'evaluation/policy'), 'policy')
    env = _first_present(('env', 'evaluation/env'), 'env')

    if isinstance(env, RemoteRolloutEnv):
        # Roll out against the underlying environment directly.
        env = env._wrapped_env
    print("Policy loaded")

    # Both branches of the GPU switch only differ in the flag value.
    ptu.set_gpu_mode(bool(args.gpu))
    policy.to(ptu.device)

    if isinstance(env, VAEWrappedEnv):
        env.mode(args.mode)

    max_path_length = 100
    observation_key = 'latent_observation'
    desired_goal_key = 'latent_desired_goal'
    rollout_function = rf.create_rollout_function(
        rf.multitask_rollout,
        observation_key=observation_key,
        desired_goal_key=desired_goal_key,
    )
    env.mode(env._mode_map['video_env'])

    # Short random suffix so repeated runs do not overwrite each other.
    random_id = str(uuid.uuid4()).split('-')[0]
    dump_video(
        env,
        policy,
        'rollouts_{}.mp4'.format(random_id),
        rollout_function,
        rows=3,
        columns=6,
        pad_length=0,
        pad_color=255,
        do_timer=True,
        horizon=max_path_length,
        dirname_to_save_images=None,
        subdirname="rollouts",
        imsize=48,
    )
def __init__(
        self,
        observation_key='observation',
        desired_goal_key='desired_goal',
):
    """Store the dict keys and build the evaluation rollout function.

    Args:
        observation_key: Key stored on the instance and forwarded to
            ``rf.create_rollout_function`` as ``observation_key``.
        desired_goal_key: Key stored on the instance and forwarded to
            ``rf.create_rollout_function`` as ``desired_goal_key``.
    """
    self.observation_key = observation_key
    self.desired_goal_key = desired_goal_key
    # Read the keys back from the instance, as the stored values are the
    # ones the rollout function is configured with.
    key_kwargs = {
        'observation_key': self.observation_key,
        'desired_goal_key': self.desired_goal_key,
    }
    self.eval_rollout_function = rf.create_rollout_function(
        rf.multitask_rollout, **key_kwargs)
def __init__(
        self,
        observation_key=None,
        desired_goal_key=None,
        rollout_goal_params=None,
):
    """Store goal-conditioning settings and build the rollout functions.

    Args:
        observation_key: Key forwarded to ``create_rollout_function``.
        desired_goal_key: Key forwarded to ``create_rollout_function``.
        rollout_goal_params: Stored on the instance as-is; not otherwise
            used here.
    """
    self.observation_key = observation_key
    self.desired_goal_key = desired_goal_key
    self.rollout_goal_params = rollout_goal_params
    self._rollout_goal = None

    shared_rollout_function = create_rollout_function(
        multitask_rollout,
        observation_key=self.observation_key,
        desired_goal_key=self.desired_goal_key,
    )
    # Training and evaluation share the very same rollout function object.
    self.train_rollout_function = shared_rollout_function
    self.eval_rollout_function = self.train_rollout_function
def grill_her_td3_experiment(variant):
    """Assemble and run a ``HerTd3`` experiment described by ``variant``.

    Args:
        variant: Experiment configuration dict. Keys read here include
            ``observation_key`` / ``desired_goal_key`` (optional),
            ``qf_kwargs``, ``policy_kwargs``, ``replay_buffer_kwargs``,
            ``algo_kwargs`` (with ``td3_kwargs`` and ``her_kwargs``
            sub-dicts), ``render`` and ``save_video`` (optional).
            ``variant['algo_kwargs']`` is mutated in place.
    """
    env = get_envs(variant)
    es = get_exploration_strategy(variant, env)

    observation_key = variant.get('observation_key', 'latent_observation')
    desired_goal_key = variant.get('desired_goal_key', 'latent_desired_goal')
    achieved_goal_key = desired_goal_key.replace("desired", "achieved")

    # Networks consume the observation concatenated with the desired goal.
    spaces = env.observation_space.spaces
    obs_dim = (spaces[observation_key].low.size
               + spaces[desired_goal_key].low.size)
    action_dim = env.action_space.low.size

    def build_qf():
        return FlattenMlp(input_size=obs_dim + action_dim,
                          output_size=1,
                          **variant['qf_kwargs'])

    qf1 = build_qf()
    qf2 = build_qf()
    policy = TanhMlpPolicy(input_size=obs_dim,
                           output_size=action_dim,
                           **variant['policy_kwargs'])
    exploration_policy = PolicyWrappedWithExplorationStrategy(
        exploration_strategy=es,
        policy=policy,
    )

    replay_buffer = ObsDictRelabelingBuffer(
        env=env,
        observation_key=observation_key,
        desired_goal_key=desired_goal_key,
        achieved_goal_key=achieved_goal_key,
        **variant['replay_buffer_kwargs'])

    # These writes go into variant['algo_kwargs'] itself, which is what
    # gets unpacked into the algorithm constructor below.
    algo_kwargs = variant['algo_kwargs']
    algo_kwargs['replay_buffer'] = replay_buffer
    td3_kwargs = algo_kwargs['td3_kwargs']
    td3_kwargs['training_env'] = env
    td3_kwargs['render'] = variant["render"]
    her_kwargs = algo_kwargs['her_kwargs']
    her_kwargs['observation_key'] = observation_key
    her_kwargs['desired_goal_key'] = desired_goal_key

    algorithm = HerTd3(env,
                       qf1=qf1,
                       qf2=qf2,
                       policy=policy,
                       exploration_policy=exploration_policy,
                       **variant['algo_kwargs'])

    if variant.get("save_video", True):
        rollout_function = rf.create_rollout_function(
            rf.multitask_rollout,
            max_path_length=algorithm.max_path_length,
            observation_key=algorithm.observation_key,
            desired_goal_key=algorithm.desired_goal_key,
        )
        algorithm.post_epoch_funcs.append(
            get_video_save_func(
                rollout_function,
                env,
                algorithm.eval_policy,
                variant,
            ))

    algorithm.to(ptu.device)
    env.vae.to(ptu.device)
    algorithm.train()
def eval_rollout_function(self):
    """Return the rollout function used for evaluation.

    The result is whatever ``create_rollout_function`` produces for
    ``multitask_rollout`` with this instance's observation and
    desired-goal keys.
    """
    key_kwargs = {
        'observation_key': self.observation_key,
        'desired_goal_key': self.desired_goal_key,
    }
    return create_rollout_function(multitask_rollout, **key_kwargs)
def _e2e_disentangled_experiment(max_path_length,
                                 encoder_kwargs,
                                 disentangled_qf_kwargs,
                                 qf_kwargs,
                                 twin_sac_trainer_kwargs,
                                 replay_buffer_kwargs,
                                 policy_kwargs,
                                 vae_evaluation_goal_sampling_mode,
                                 vae_exploration_goal_sampling_mode,
                                 base_env_evaluation_goal_sampling_mode,
                                 base_env_exploration_goal_sampling_mode,
                                 algo_kwargs,
                                 env_id=None,
                                 env_class=None,
                                 env_kwargs=None,
                                 observation_key='state_observation',
                                 desired_goal_key='state_desired_goal',
                                 achieved_goal_key='state_achieved_goal',
                                 latent_dim=2,
                                 vae_wrapped_env_kwargs=None,
                                 vae_path=None,
                                 vae_n_vae_training_kwargs=None,
                                 vectorized=False,
                                 save_video=True,
                                 save_video_kwargs=None,
                                 have_no_disentangled_encoder=False,
                                 **kwargs):
    """Build and run a HER + SAC experiment on VAE-wrapped environments.

    Either ``env_id`` (looked up through ``gym.make``) or ``env_class``
    (instantiated with ``env_kwargs``) must be given. A VAE is loaded from
    ``vae_path`` when provided, otherwise obtained from ``get_n_train_vae``.
    Q-functions are ``DisentangledMlpQf`` instances sharing one encoder,
    or plain ``ConcatMlp`` networks when ``have_no_disentangled_encoder``
    is true.

    Args:
        max_path_length: Passed to the path collectors, the algorithm and
            the video rollout function.
        encoder_kwargs, disentangled_qf_kwargs, qf_kwargs, policy_kwargs,
        twin_sac_trainer_kwargs, replay_buffer_kwargs, algo_kwargs:
            Keyword dicts unpacked into the corresponding constructors.
        vae_evaluation_goal_sampling_mode,
        vae_exploration_goal_sampling_mode: ``goal_sampling_mode`` of the
            evaluation / exploration path collectors.
        base_env_evaluation_goal_sampling_mode,
        base_env_exploration_goal_sampling_mode: Assigned to
            ``goal_sampling_mode`` on the raw eval / train environments.
        env_id, env_class, env_kwargs: Environment specification.
        observation_key, desired_goal_key, achieved_goal_key: Observation
            dict keys used for sizing networks, relabeling and rollouts.
        latent_dim: Encoder output size; also the number of per-dimension
            heatmap image keys.
        vae_wrapped_env_kwargs: Extra kwargs for ``VAEWrappedEnv``;
            ``None`` means no extra kwargs.
        vae_path: Optional location of a pre-trained VAE.
        vae_n_vae_training_kwargs: Extra kwargs for ``get_n_train_vae``;
            ``None`` means no extra kwargs.
        vectorized: Forwarded to the Q-functions, replay buffer and
            heatmap helper.
        save_video: Whether to register video-saving post-train hooks.
        save_video_kwargs: Kwargs for ``get_save_video_function``; its
            ``save_vf_heatmap`` entry (default ``True``) also toggles the
            heatmap post-processing. ``None`` means no extra kwargs.
        have_no_disentangled_encoder: Selects the plain ``ConcatMlp``
            Q-function.
        **kwargs: Accepted and ignored.
    """
    # Optional kwargs dicts default to None in the signature; normalise
    # them so they can be unpacked / queried unconditionally below.
    if env_kwargs is None:
        env_kwargs = {}
    if vae_wrapped_env_kwargs is None:
        vae_wrapped_env_kwargs = {}
    if vae_n_vae_training_kwargs is None:
        vae_n_vae_training_kwargs = {}
    if save_video_kwargs is None:
        save_video_kwargs = {}
    assert env_id or env_class

    if env_id:
        import gym
        import multiworld
        multiworld.register_all_envs()
        train_env = gym.make(env_id)
        eval_env = gym.make(env_id)
    else:
        eval_env = env_class(**env_kwargs)
        train_env = env_class(**env_kwargs)

    train_env.goal_sampling_mode = base_env_exploration_goal_sampling_mode
    eval_env.goal_sampling_mode = base_env_evaluation_goal_sampling_mode

    if vae_path:
        vae = load_local_or_remote_file(vae_path)
    else:
        vae = get_n_train_vae(latent_dim=latent_dim,
                              env=eval_env,
                              **vae_n_vae_training_kwargs)

    train_env = VAEWrappedEnv(train_env,
                              vae,
                              imsize=train_env.imsize,
                              **vae_wrapped_env_kwargs)
    # NOTE(review): imsize is read from the already-wrapped train_env here;
    # presumably the wrapper exposes the inner env's attribute - confirm.
    eval_env = VAEWrappedEnv(eval_env,
                             vae,
                             imsize=train_env.imsize,
                             **vae_wrapped_env_kwargs)

    obs_dim = train_env.observation_space.spaces[observation_key].low.size
    goal_dim = train_env.observation_space.spaces[desired_goal_key].low.size
    action_dim = train_env.action_space.low.size

    # A single encoder instance is shared by every disentangled Q-function.
    encoder = ConcatMlp(input_size=obs_dim,
                        output_size=latent_dim,
                        **encoder_kwargs)

    def make_qf():
        if have_no_disentangled_encoder:
            return ConcatMlp(
                input_size=obs_dim + goal_dim + action_dim,
                output_size=1,
                **qf_kwargs,
            )
        else:
            return DisentangledMlpQf(encoder=encoder,
                                     preprocess_obs_dim=obs_dim,
                                     action_dim=action_dim,
                                     qf_kwargs=qf_kwargs,
                                     vectorized=vectorized,
                                     **disentangled_qf_kwargs)

    qf1 = make_qf()
    qf2 = make_qf()
    target_qf1 = make_qf()
    target_qf2 = make_qf()

    policy = TanhGaussianPolicy(obs_dim=obs_dim + goal_dim,
                                action_dim=action_dim,
                                **policy_kwargs)

    replay_buffer = ObsDictRelabelingBuffer(
        env=train_env,
        observation_key=observation_key,
        desired_goal_key=desired_goal_key,
        achieved_goal_key=achieved_goal_key,
        vectorized=vectorized,
        **replay_buffer_kwargs)

    sac_trainer = SACTrainer(env=train_env,
                             policy=policy,
                             qf1=qf1,
                             qf2=qf2,
                             target_qf1=target_qf1,
                             target_qf2=target_qf2,
                             **twin_sac_trainer_kwargs)
    trainer = HERTrainer(sac_trainer)

    eval_path_collector = VAEWrappedEnvPathCollector(
        eval_env,
        MakeDeterministic(policy),
        max_path_length,
        observation_key=observation_key,
        desired_goal_key=desired_goal_key,
        goal_sampling_mode=vae_evaluation_goal_sampling_mode,
    )
    expl_path_collector = VAEWrappedEnvPathCollector(
        train_env,
        policy,
        max_path_length,
        observation_key=observation_key,
        desired_goal_key=desired_goal_key,
        goal_sampling_mode=vae_exploration_goal_sampling_mode,
    )

    algorithm = TorchBatchRLAlgorithm(
        trainer=trainer,
        exploration_env=train_env,
        evaluation_env=eval_env,
        exploration_data_collector=expl_path_collector,
        evaluation_data_collector=eval_path_collector,
        replay_buffer=replay_buffer,
        max_path_length=max_path_length,
        **algo_kwargs,
    )
    algorithm.to(ptu.device)

    if save_video:
        save_vf_heatmap = save_video_kwargs.get('save_vf_heatmap', True)

        if have_no_disentangled_encoder:

            def v_function(obs):
                action = policy.get_actions(obs)
                obs, action = ptu.from_numpy(obs), ptu.from_numpy(action)
                return qf1(obs, action)

            add_heatmap = partial(add_heatmap_img_to_o_dict,
                                  v_function=v_function)
        else:

            def v_function(obs):
                action = policy.get_actions(obs)
                obs, action = ptu.from_numpy(obs), ptu.from_numpy(action)
                return qf1(obs, action, return_individual_q_vals=True)

            add_heatmap = partial(
                add_heatmap_imgs_to_o_dict,
                v_function=v_function,
                vectorized=vectorized,
            )

        rollout_function = rf.create_rollout_function(
            rf.multitask_rollout,
            max_path_length=max_path_length,
            observation_key=observation_key,
            desired_goal_key=desired_goal_key,
            full_o_postprocess_func=add_heatmap if save_vf_heatmap else None,
        )
        img_keys = ['v_vals'] + [
            'v_vals_dim_{}'.format(dim) for dim in range(latent_dim)
        ]
        eval_video_func = get_save_video_function(
            rollout_function,
            eval_env,
            MakeDeterministic(policy),
            get_extra_imgs=partial(get_extra_imgs, img_keys=img_keys),
            tag="eval",
            **save_video_kwargs)
        train_video_func = get_save_video_function(
            rollout_function,
            train_env,
            policy,
            get_extra_imgs=partial(get_extra_imgs, img_keys=img_keys),
            tag="train",
            **save_video_kwargs)
        algorithm.post_train_funcs.append(eval_video_func)
        algorithm.post_train_funcs.append(train_video_func)

    algorithm.train()
def td3_experiment(variant):
    """Assemble and run a HER + TD3 experiment described by ``variant``.

    Args:
        variant: Experiment configuration dict. Keys read here include
            ``max_path_length``, ``qf_kwargs``, ``policy_kwargs``,
            ``replay_buffer_kwargs``, ``td3_trainer_kwargs``,
            ``algo_kwargs`` and the optional ``observation_key``,
            ``desired_goal_key``, ``achieved_goal_key``, ``use_masks``
            (with ``mask_wrapper_kwargs``, ``expl_mask_distribution_kwargs``
            and ``eval_mask_distribution_kwargs``), ``use_subgoal_policy``
            (with ``subgoal_policy_kwargs``), ``do_state_exp`` and
            ``save_video``. When ``do_state_exp`` is false,
            ``evaluation_goal_sampling_mode`` and
            ``exploration_goal_sampling_mode`` are required.

    Raises:
        KeyError: If a required ``variant`` entry is missing.
    """
    import rlkit.samplers.rollout_functions as rf
    import rlkit.torch.pytorch_util as ptu
    from rlkit.data_management.obs_dict_replay_buffer import \
        ObsDictRelabelingBuffer
    from rlkit.exploration_strategies.base import (
        PolicyWrappedWithExplorationStrategy)
    from rlkit.torch.td3.td3 import TD3 as TD3Trainer
    from rlkit.torch.torch_rl_algorithm import TorchBatchRLAlgorithm
    from rlkit.torch.networks import ConcatMlp, TanhMlpPolicy

    env = get_envs(variant)
    expl_env = env
    eval_env = env
    es = get_exploration_strategy(variant, env)

    if variant.get("use_masks", False):
        mask_wrapper_kwargs = variant.get("mask_wrapper_kwargs", dict())

        expl_mask_distribution_kwargs = variant[
            "expl_mask_distribution_kwargs"]
        expl_mask_distribution = DiscreteDistribution(
            **expl_mask_distribution_kwargs)
        expl_env = RewardMaskWrapper(env, expl_mask_distribution,
                                     **mask_wrapper_kwargs)

        eval_mask_distribution_kwargs = variant[
            "eval_mask_distribution_kwargs"]
        eval_mask_distribution = DiscreteDistribution(
            **eval_mask_distribution_kwargs)
        eval_env = RewardMaskWrapper(env, eval_mask_distribution,
                                     **mask_wrapper_kwargs)
        # From here on ``env`` refers to the evaluation wrapper.
        env = eval_env

    max_path_length = variant['max_path_length']

    observation_key = variant.get('observation_key', 'latent_observation')
    desired_goal_key = variant.get('desired_goal_key', 'latent_desired_goal')
    achieved_goal_key = variant.get('achieved_goal_key',
                                    'latent_achieved_goal')

    # Networks consume the observation concatenated with the desired goal.
    obs_dim = (env.observation_space.spaces[observation_key].low.size +
               env.observation_space.spaces[desired_goal_key].low.size)
    action_dim = env.action_space.low.size

    qf1 = ConcatMlp(input_size=obs_dim + action_dim,
                    output_size=1,
                    **variant['qf_kwargs'])
    qf2 = ConcatMlp(input_size=obs_dim + action_dim,
                    output_size=1,
                    **variant['qf_kwargs'])
    policy = TanhMlpPolicy(input_size=obs_dim,
                           output_size=action_dim,
                           **variant['policy_kwargs'])
    target_qf1 = ConcatMlp(input_size=obs_dim + action_dim,
                           output_size=1,
                           **variant['qf_kwargs'])
    target_qf2 = ConcatMlp(input_size=obs_dim + action_dim,
                           output_size=1,
                           **variant['qf_kwargs'])
    target_policy = TanhMlpPolicy(input_size=obs_dim,
                                  output_size=action_dim,
                                  **variant['policy_kwargs'])

    if variant.get("use_subgoal_policy", False):
        from rlkit.policies.timed_policy import SubgoalPolicyWrapper

        subgoal_policy_kwargs = variant.get('subgoal_policy_kwargs', {})
        policy = SubgoalPolicyWrapper(wrapped_policy=policy,
                                      env=env,
                                      episode_length=max_path_length,
                                      **subgoal_policy_kwargs)
        target_policy = SubgoalPolicyWrapper(wrapped_policy=target_policy,
                                             env=env,
                                             episode_length=max_path_length,
                                             **subgoal_policy_kwargs)

    expl_policy = PolicyWrappedWithExplorationStrategy(
        exploration_strategy=es,
        policy=policy,
    )

    replay_buffer = ObsDictRelabelingBuffer(
        env=env,
        observation_key=observation_key,
        desired_goal_key=desired_goal_key,
        achieved_goal_key=achieved_goal_key,
        **variant['replay_buffer_kwargs'])

    trainer = TD3Trainer(policy=policy,
                         qf1=qf1,
                         qf2=qf2,
                         target_qf1=target_qf1,
                         target_qf2=target_qf2,
                         target_policy=target_policy,
                         **variant['td3_trainer_kwargs'])
    trainer = HERTrainer(trainer)

    if variant.get("do_state_exp", False):
        eval_path_collector = GoalConditionedPathCollector(
            eval_env,
            policy,
            observation_key=observation_key,
            desired_goal_key=desired_goal_key,
        )
        expl_path_collector = GoalConditionedPathCollector(
            expl_env,
            expl_policy,
            observation_key=observation_key,
            desired_goal_key=desired_goal_key,
        )
    else:
        # The sampling modes are configuration values looked up in the
        # variant, not literal lists of key names.
        eval_path_collector = VAEWrappedEnvPathCollector(
            env,
            policy,
            observation_key=observation_key,
            desired_goal_key=desired_goal_key,
            goal_sampling_mode=variant['evaluation_goal_sampling_mode'],
        )
        expl_path_collector = VAEWrappedEnvPathCollector(
            env,
            expl_policy,
            observation_key=observation_key,
            desired_goal_key=desired_goal_key,
            goal_sampling_mode=variant['exploration_goal_sampling_mode'],
        )

    algorithm = TorchBatchRLAlgorithm(
        trainer=trainer,
        exploration_env=env,
        evaluation_env=env,
        exploration_data_collector=expl_path_collector,
        evaluation_data_collector=eval_path_collector,
        replay_buffer=replay_buffer,
        max_path_length=max_path_length,
        **variant['algo_kwargs'])

    if variant.get("save_video", True):
        if variant.get("do_state_exp", False):
            rollout_function = rf.create_rollout_function(
                rf.multitask_rollout,
                max_path_length=max_path_length,
                observation_key=observation_key,
                desired_goal_key=desired_goal_key,
            )
            video_func = get_video_save_func(
                rollout_function,
                env,
                policy,
                variant,
            )
        else:
            video_func = VideoSaveFunction(
                env,
                variant,
            )
        algorithm.post_train_funcs.append(video_func)

    algorithm.to(ptu.device)
    if not variant.get("do_state_exp", False):
        # Only the VAE-based setup reaches into ``env.vae``.
        env.vae.to(ptu.device)
    algorithm.train()
def simulate_policy(args):
    """Load a snapshot with joblib and dump a video of TDM rollouts.

    Args:
        args: Namespace-like object. This function reads ``args.file``
            (snapshot path, also used to derive the output location),
            ``args.gpu``, ``args.pause`` and ``args.H`` (rollout horizon),
            and passes ``args`` itself to ``get_max_tau``.

    Raises:
        Exception: If the snapshot contains none of the recognised policy
            keys.
    """
    data = joblib.load(args.file)

    # Policy keys are tried in order of preference.
    for policy_key in ('eval_policy', 'policy', 'exploration_policy'):
        if policy_key in data:
            policy = data[policy_key]
            break
    else:
        raise Exception("No policy found in loaded dict. Keys: {}".format(
            data.keys()))

    max_tau = get_max_tau(args)

    env = data['env']
    env.mode("video_env")
    env.decode_goals = True

    if hasattr(env, 'enable_render'):
        # some environments need to be reconfigured for visualization
        env.enable_render()

    use_gpu = bool(args.gpu)
    set_gpu_mode(use_gpu)
    if use_gpu:
        policy.to(ptu.device)
        if hasattr(env, "vae"):
            env.vae.to(ptu.device)
    else:
        # make sure everything is on the CPU
        policy.cpu()
        if hasattr(env, "vae"):
            env.vae.cpu()

    if args.pause:
        import ipdb
        ipdb.set_trace()

    if isinstance(policy, PyTorchModule):
        policy.train(False)

    ROWS = 3
    COLUMNS = 6

    # The video is written next to the snapshot it was generated from.
    dirname = osp.dirname(args.file)
    input_file_name = os.path.splitext(os.path.basename(args.file))[0]
    filename = osp.join(dirname, "video_{}.mp4".format(input_file_name))

    rollout_function = create_rollout_function(
        tdm_rollout,
        init_tau=max_tau,
        observation_key='observation',
        desired_goal_key='desired_goal',
    )
    paths = dump_video(
        env,
        policy,
        filename,
        rollout_function,
        ROWS=ROWS,
        COLUMNS=COLUMNS,
        horizon=args.H,
        dirname_to_save_images=dirname,
        subdirname="rollouts_" + input_file_name,
    )

    if hasattr(env, "log_diagnostics"):
        env.log_diagnostics(paths)
    logger.dump_tabular()
def tdm_twin_sac_experiment(variant):
    """Assemble and run a ``TdmTwinSAC`` experiment described by ``variant``.

    Args:
        variant: Experiment configuration dict. It is first passed to
            ``preprocess_rl_variant`` and then mutated in place: the
            ``vectorized`` / ``norm_order`` settings derived from the
            environment and the replay buffer, env and render flags are
            written into its kwargs sub-dicts.
    """
    import rlkit.samplers.rollout_functions as rf
    import rlkit.torch.pytorch_util as ptu
    from rlkit.data_management.obs_dict_replay_buffer import \
        ObsDictRelabelingBuffer
    from rlkit.state_distance.tdm_networks import (
        TdmQf,
        TdmVf,
        StochasticTdmPolicy,
    )
    from rlkit.state_distance.tdm_twin_sac import TdmTwinSAC

    preprocess_rl_variant(variant)
    env = get_envs(variant)

    observation_key = variant.get('observation_key', 'latent_observation')
    desired_goal_key = variant.get('desired_goal_key', 'latent_desired_goal')
    achieved_goal_key = desired_goal_key.replace("desired", "achieved")

    spaces = env.observation_space.spaces
    obs_dim = spaces[observation_key].low.size
    goal_dim = spaces[desired_goal_key].low.size
    action_dim = env.action_space.low.size

    # Settings derived from the environment are pushed into the kwargs
    # dicts before any network is constructed.
    vectorized = 'vectorized' in env.reward_type
    norm_order = env.norm_order
    variant['algo_kwargs']['tdm_kwargs']['vectorized'] = vectorized
    for network_kwargs_name in ('qf_kwargs', 'vf_kwargs'):
        variant[network_kwargs_name]['vectorized'] = vectorized
    for network_kwargs_name in ('qf_kwargs', 'vf_kwargs'):
        variant[network_kwargs_name]['norm_order'] = norm_order

    def build_qf():
        return TdmQf(env=env,
                     observation_dim=obs_dim,
                     goal_dim=goal_dim,
                     action_dim=action_dim,
                     **variant['qf_kwargs'])

    qf1 = build_qf()
    qf2 = build_qf()
    vf = TdmVf(env=env,
               observation_dim=obs_dim,
               goal_dim=goal_dim,
               **variant['vf_kwargs'])
    policy = StochasticTdmPolicy(env=env,
                                 observation_dim=obs_dim,
                                 goal_dim=goal_dim,
                                 action_dim=action_dim,
                                 **variant['policy_kwargs'])

    variant['replay_buffer_kwargs']['vectorized'] = vectorized
    replay_buffer = ObsDictRelabelingBuffer(
        env=env,
        observation_key=observation_key,
        desired_goal_key=desired_goal_key,
        achieved_goal_key=achieved_goal_key,
        **variant['replay_buffer_kwargs'])

    # These writes go into variant['algo_kwargs'] itself, which is what
    # gets unpacked into the algorithm constructor below.
    algo_kwargs = variant['algo_kwargs']
    algo_kwargs['replay_buffer'] = replay_buffer
    base_kwargs = algo_kwargs['base_kwargs']
    base_kwargs['training_env'] = env
    base_kwargs['render'] = variant["render"]
    base_kwargs['render_during_eval'] = variant["render"]
    tdm_kwargs = algo_kwargs['tdm_kwargs']
    tdm_kwargs['observation_key'] = observation_key
    tdm_kwargs['desired_goal_key'] = desired_goal_key

    algorithm = TdmTwinSAC(env,
                           qf1=qf1,
                           qf2=qf2,
                           vf=vf,
                           policy=policy,
                           **variant['algo_kwargs'])

    if variant.get("save_video", True):
        rollout_function = rf.create_rollout_function(
            rf.tdm_rollout,
            init_tau=algorithm._sample_max_tau_for_rollout(),
            decrement_tau=algorithm.cycle_taus_for_rollout,
            cycle_tau=algorithm.cycle_taus_for_rollout,
            max_path_length=algorithm.max_path_length,
            observation_key=algorithm.observation_key,
            desired_goal_key=algorithm.desired_goal_key,
        )
        algorithm.post_train_funcs.append(
            get_video_save_func(
                rollout_function,
                env,
                algorithm.eval_policy,
                variant,
            ))

    algorithm.to(ptu.device)
    if not variant.get("do_state_exp", False):
        env.vae.to(ptu.device)
    algorithm.train()
def tdm_td3_experiment(variant):
    """Assemble and run a ``TdmTd3`` experiment described by ``variant``.

    Args:
        variant: Experiment configuration dict. It is first passed to
            ``preprocess_rl_variant`` and then mutated in place: the
            ``vectorized`` / ``norm_order`` settings derived from the
            environment and the replay buffer, env and render flags are
            written into its kwargs sub-dicts.
    """
    import rlkit.samplers.rollout_functions as rf
    import rlkit.torch.pytorch_util as ptu
    from rlkit.core import logger
    from rlkit.data_management.obs_dict_replay_buffer import \
        ObsDictRelabelingBuffer
    from rlkit.exploration_strategies.base import (
        PolicyWrappedWithExplorationStrategy)
    from rlkit.state_distance.tdm_networks import TdmQf, TdmPolicy
    from rlkit.state_distance.tdm_td3 import TdmTd3

    preprocess_rl_variant(variant)
    env = get_envs(variant)
    es = get_exploration_strategy(variant, env)

    observation_key = variant.get('observation_key', 'latent_observation')
    desired_goal_key = variant.get('desired_goal_key', 'latent_desired_goal')
    achieved_goal_key = desired_goal_key.replace("desired", "achieved")

    spaces = env.observation_space.spaces
    obs_dim = spaces[observation_key].low.size
    goal_dim = spaces[desired_goal_key].low.size
    action_dim = env.action_space.low.size

    # Settings derived from the environment are pushed into the kwargs
    # dicts before any network is constructed.
    vectorized = 'vectorized' in env.reward_type
    norm_order = env.norm_order
    variant['algo_kwargs']['tdm_kwargs']['vectorized'] = vectorized
    variant['qf_kwargs']['vectorized'] = vectorized
    variant['qf_kwargs']['norm_order'] = norm_order

    def build_qf():
        return TdmQf(env=env,
                     observation_dim=obs_dim,
                     goal_dim=goal_dim,
                     action_dim=action_dim,
                     **variant['qf_kwargs'])

    qf1 = build_qf()
    qf2 = build_qf()
    policy = TdmPolicy(env=env,
                       observation_dim=obs_dim,
                       goal_dim=goal_dim,
                       action_dim=action_dim,
                       **variant['policy_kwargs'])
    exploration_policy = PolicyWrappedWithExplorationStrategy(
        exploration_strategy=es,
        policy=policy,
    )

    variant['replay_buffer_kwargs']['vectorized'] = vectorized
    replay_buffer = ObsDictRelabelingBuffer(
        env=env,
        observation_key=observation_key,
        desired_goal_key=desired_goal_key,
        achieved_goal_key=achieved_goal_key,
        **variant['replay_buffer_kwargs'])

    # These writes go into variant['algo_kwargs'] itself, which is what
    # gets unpacked into the algorithm constructor below.
    algo_kwargs = variant['algo_kwargs']
    algo_kwargs['replay_buffer'] = replay_buffer
    base_kwargs = algo_kwargs['base_kwargs']
    base_kwargs['training_env'] = env
    base_kwargs['render'] = variant["render"]
    base_kwargs['render_during_eval'] = variant["render"]
    tdm_kwargs = algo_kwargs['tdm_kwargs']
    tdm_kwargs['observation_key'] = observation_key
    tdm_kwargs['desired_goal_key'] = desired_goal_key

    algorithm = TdmTd3(env,
                       qf1=qf1,
                       qf2=qf2,
                       policy=policy,
                       exploration_policy=exploration_policy,
                       **variant['algo_kwargs'])

    algorithm.to(ptu.device)
    if not variant.get("do_state_exp", False):
        env.vae.to(ptu.device)

    if variant.get("save_video", True):
        # NOTE(review): the snapshot directory is fetched but not used
        # below - confirm whether the call is still needed.
        logdir = logger.get_snapshot_dir()
        policy.train(False)
        rollout_function = rf.create_rollout_function(
            rf.tdm_rollout,
            init_tau=algorithm.max_tau,
            max_path_length=algorithm.max_path_length,
            observation_key=algorithm.observation_key,
            desired_goal_key=algorithm.desired_goal_key,
        )
        algorithm.post_train_funcs.append(
            get_video_save_func(
                rollout_function,
                env,
                policy,
                variant,
            ))

    algorithm.train()
def grill_her_sac_experiment(variant):
    """Train ``HerSac`` on an image env wrapped in a VAE, then dump videos.

    Args:
        variant: Experiment configuration dict. Keys read here include
            ``env_class``, ``env_kwargs``, ``render``, ``rdim``,
            ``vae_paths``, ``normalize``, ``replay_kwargs``,
            ``algo_kwargs`` and the optional ``reward_params``,
            ``init_camera``, ``vae_wrapped_env_kwargs``,
            ``observation_key``, ``desired_goal_key``, ``hidden_sizes``,
            ``training_mode``, ``testing_mode`` and ``save_video``.
            ``variant['algo_kwargs']`` gains a ``replay_buffer`` entry.
    """
    raw_env = variant["env_class"](**variant['env_kwargs'])
    render = variant["render"]
    rdim = variant["rdim"]
    vae_path = variant["vae_paths"][str(rdim)]
    reward_params = variant.get("reward_params", dict())

    # Without an explicit camera setup the "topview" camera is used.
    init_camera = variant.get("init_camera", None)
    camera_name = "topview" if init_camera is None else None

    env = ImageEnv(
        raw_env,
        84,
        init_camera=init_camera,
        camera_name=camera_name,
        transpose=True,
        normalize=True,
    )
    env = VAEWrappedEnv(env,
                        vae_path,
                        decode_goals=render,
                        render_goals=render,
                        render_rollouts=render,
                        reward_params=reward_params,
                        **variant.get('vae_wrapped_env_kwargs', {}))
    if variant['normalize']:
        env = NormalizedBoxEnv(env)

    observation_key = variant.get('observation_key', 'latent_observation')
    desired_goal_key = variant.get('desired_goal_key', 'latent_desired_goal')
    achieved_goal_key = desired_goal_key.replace("desired", "achieved")

    # Networks consume the observation concatenated with the desired goal.
    spaces = env.observation_space.spaces
    obs_dim = (spaces[observation_key].low.size
               + spaces[desired_goal_key].low.size)
    action_dim = env.action_space.low.size
    hidden_sizes = variant.get('hidden_sizes', [400, 300])

    qf = ConcatMlp(
        input_size=obs_dim + action_dim,
        output_size=1,
        hidden_sizes=hidden_sizes,
    )
    vf = ConcatMlp(
        input_size=obs_dim,
        output_size=1,
        hidden_sizes=hidden_sizes,
    )
    policy = TanhGaussianPolicy(
        obs_dim=obs_dim,
        action_dim=action_dim,
        hidden_sizes=hidden_sizes,
    )

    training_mode = variant.get("training_mode", "train")
    testing_mode = variant.get("testing_mode", "test")

    def clone_env_in_mode(mode):
        # A pickle round trip yields an independent copy of the env.
        env_copy = pickle.loads(pickle.dumps(env))
        env_copy.mode(mode)
        return env_copy

    testing_env = clone_env_in_mode(testing_mode)
    training_env = clone_env_in_mode(training_mode)
    relabeling_env = clone_env_in_mode(training_mode)
    relabeling_env.disable_render()
    video_vae_env = clone_env_in_mode("video_vae")
    video_goal_env = clone_env_in_mode("video_env")

    replay_buffer = ObsDictRelabelingBuffer(
        env=relabeling_env,
        observation_key=observation_key,
        desired_goal_key=desired_goal_key,
        achieved_goal_key=achieved_goal_key,
        **variant['replay_kwargs'])
    variant["algo_kwargs"]["replay_buffer"] = replay_buffer

    algorithm = HerSac(testing_env,
                       training_env=training_env,
                       qf=qf,
                       vf=vf,
                       policy=policy,
                       render=render,
                       render_during_eval=render,
                       observation_key=observation_key,
                       desired_goal_key=desired_goal_key,
                       **variant['algo_kwargs'])

    if ptu.gpu_enabled():
        print("using GPU")
        for module in (qf, vf, policy, algorithm):
            module.to(ptu.device)
        for e in [testing_env, training_env, video_vae_env, video_goal_env]:
            e.vae.to(ptu.device)

    algorithm.train()

    if variant.get("save_video", True):
        logdir = logger.get_snapshot_dir()
        policy.train(False)
        rollout_function = rf.create_rollout_function(
            rf.multitask_rollout,
            max_path_length=algorithm.max_path_length,
            observation_key=algorithm.observation_key,
            desired_goal_key=algorithm.desired_goal_key,
        )
        # One final video per visualisation env, goal env first.
        video_jobs = (
            (video_goal_env, 'video_final_env.mp4'),
            (video_vae_env, 'video_final_vae.mp4'),
        )
        for video_env, video_name in video_jobs:
            filename = osp.join(logdir, video_name)
            dump_video(video_env, policy, filename, rollout_function)
def grill_her_td3_experiment(variant):
    """Train ``HerTd3`` on an image env wrapped in a VAE, then dump videos.

    Args:
        variant: Experiment configuration dict. Keys read here include
            ``env_class``, ``env_kwargs``, ``render``, ``rdim``,
            ``vae_paths``, ``normalize``, ``exploration_type``,
            ``replay_kwargs``, ``algo_kwargs`` and the optional
            ``reward_params``, ``init_camera``, ``vae_wrapped_env_kwargs``,
            ``exploration_noise``, ``observation_key``,
            ``desired_goal_key``, ``hidden_sizes``, ``training_mode``,
            ``testing_mode`` and ``save_video``.
            ``variant['algo_kwargs']`` gains a ``replay_buffer`` entry.

    Raises:
        Exception: If ``exploration_type`` is not one of ``'ou'``,
            ``'gaussian'`` or ``'epsilon'``.
    """
    raw_env = variant["env_class"](**variant['env_kwargs'])
    render = variant["render"]
    rdim = variant["rdim"]
    vae_path = variant["vae_paths"][str(rdim)]
    reward_params = variant.get("reward_params", dict())

    # Without an explicit camera setup the "topview" camera is used.
    init_camera = variant.get("init_camera", None)
    camera_name = "topview" if init_camera is None else None

    env = ImageEnv(
        raw_env,
        84,
        init_camera=init_camera,
        camera_name=camera_name,
        transpose=True,
        normalize=True,
    )
    env = VAEWrappedEnv(env,
                        vae_path,
                        decode_goals=render,
                        render_goals=render,
                        render_rollouts=render,
                        reward_params=reward_params,
                        **variant.get('vae_wrapped_env_kwargs', {}))
    if variant['normalize']:
        env = NormalizedBoxEnv(env)

    exploration_type = variant['exploration_type']
    exploration_noise = variant.get('exploration_noise', 0.1)
    # Factories are lazy so only the selected strategy is constructed.
    strategy_factories = {
        'ou': lambda: OUStrategy(action_space=env.action_space),
        'gaussian': lambda: GaussianStrategy(
            action_space=env.action_space,
            max_sigma=exploration_noise,
            min_sigma=exploration_noise,  # Constant sigma
        ),
        'epsilon': lambda: EpsilonGreedy(
            action_space=env.action_space,
            prob_random_action=exploration_noise,
        ),
    }
    if exploration_type not in strategy_factories:
        raise Exception("Invalid type: " + exploration_type)
    es = strategy_factories[exploration_type]()

    observation_key = variant.get('observation_key', 'latent_observation')
    desired_goal_key = variant.get('desired_goal_key', 'latent_desired_goal')
    achieved_goal_key = desired_goal_key.replace("desired", "achieved")

    # Networks consume the observation concatenated with the desired goal.
    spaces = env.observation_space.spaces
    obs_dim = (spaces[observation_key].low.size
               + spaces[desired_goal_key].low.size)
    action_dim = env.action_space.low.size
    hidden_sizes = variant.get('hidden_sizes', [400, 300])

    def build_qf():
        return ConcatMlp(
            input_size=obs_dim + action_dim,
            output_size=1,
            hidden_sizes=hidden_sizes,
        )

    qf1 = build_qf()
    qf2 = build_qf()
    policy = TanhMlpPolicy(
        input_size=obs_dim,
        output_size=action_dim,
        hidden_sizes=hidden_sizes,
    )
    exploration_policy = PolicyWrappedWithExplorationStrategy(
        exploration_strategy=es,
        policy=policy,
    )

    training_mode = variant.get("training_mode", "train")
    testing_mode = variant.get("testing_mode", "test")

    def clone_env_in_mode(mode):
        # A pickle round trip yields an independent copy of the env.
        env_copy = pickle.loads(pickle.dumps(env))
        env_copy.mode(mode)
        return env_copy

    testing_env = clone_env_in_mode(testing_mode)
    training_env = clone_env_in_mode(training_mode)
    relabeling_env = clone_env_in_mode(training_mode)
    relabeling_env.disable_render()
    video_vae_env = clone_env_in_mode("video_vae")
    video_goal_env = clone_env_in_mode("video_env")

    replay_buffer = ObsDictRelabelingBuffer(
        env=relabeling_env,
        observation_key=observation_key,
        desired_goal_key=desired_goal_key,
        achieved_goal_key=achieved_goal_key,
        **variant['replay_kwargs'])
    variant["algo_kwargs"]["replay_buffer"] = replay_buffer

    algorithm = HerTd3(testing_env,
                       training_env=training_env,
                       qf1=qf1,
                       qf2=qf2,
                       policy=policy,
                       exploration_policy=exploration_policy,
                       render=render,
                       render_during_eval=render,
                       observation_key=observation_key,
                       desired_goal_key=desired_goal_key,
                       **variant['algo_kwargs'])

    if ptu.gpu_enabled():
        print("using GPU")
        algorithm.to(ptu.device)
        for e in [testing_env, training_env, video_vae_env, video_goal_env]:
            e.vae.to(ptu.device)

    algorithm.train()

    if variant.get("save_video", True):
        logdir = logger.get_snapshot_dir()
        policy.train(False)
        rollout_function = rf.create_rollout_function(
            rf.multitask_rollout,
            max_path_length=algorithm.max_path_length,
            observation_key=algorithm.observation_key,
            desired_goal_key=algorithm.desired_goal_key,
        )
        # One final video per visualisation env, goal env first.
        video_jobs = (
            (video_goal_env, 'video_final_env.mp4'),
            (video_vae_env, 'video_final_vae.mp4'),
        )
        for video_env, video_name in video_jobs:
            filename = osp.join(logdir, video_name)
            dump_video(video_env, policy, filename, rollout_function)
def her_td3_experiment(variant):
    """Assemble and run a state-based HER + TD3 experiment.

    The environment comes from ``gym.make(variant['env_id'])`` when
    ``env_id`` is present, otherwise from ``variant['env_class']``. Network
    input sizes are taken from the same observation-dict entries that the
    replay buffer and path collectors use, so that the networks match the
    data they are fed.

    Args:
        variant: Experiment configuration dict. Keys read here include
            ``env_id`` or ``env_class`` / ``env_kwargs`` (plus the
            optional ``eval_env_kwargs``), ``qf_kwargs``,
            ``policy_kwargs``, ``replay_buffer_kwargs``,
            ``trainer_kwargs``, ``algo_kwargs`` and the optional
            ``save_video`` (default ``False``).
    """
    import gym

    import rlkit.torch.pytorch_util as ptu
    from rlkit.data_management.obs_dict_replay_buffer import ObsDictRelabelingBuffer
    from rlkit.exploration_strategies.base import \
        PolicyWrappedWithExplorationStrategy
    from rlkit.exploration_strategies.gaussian_and_epislon import \
        GaussianAndEpislonStrategy
    from rlkit.launchers.launcher_util import setup_logger
    from rlkit.samplers.data_collector import GoalConditionedPathCollector
    from rlkit.torch.her.her import HERTrainer
    from rlkit.torch.networks import ConcatMlp, TanhMlpPolicy
    from rlkit.torch.td3.td3 import TD3
    from rlkit.torch.torch_rl_algorithm import TorchBatchRLAlgorithm
    import rlkit.samplers.rollout_functions as rf
    from rlkit.torch.grill.launcher import get_state_experiment_video_save_function

    if 'env_id' in variant:
        eval_env = gym.make(variant['env_id'])
        expl_env = gym.make(variant['env_id'])
    else:
        # The evaluation env may use its own kwargs; it falls back to the
        # exploration kwargs otherwise.
        eval_env_kwargs = variant.get('eval_env_kwargs', variant['env_kwargs'])
        eval_env = variant['env_class'](**eval_env_kwargs)
        expl_env = variant['env_class'](**variant['env_kwargs'])

    observation_key = 'state_observation'
    desired_goal_key = 'state_desired_goal'
    achieved_goal_key = desired_goal_key.replace("desired", "achieved")

    es = GaussianAndEpislonStrategy(
        action_space=expl_env.action_space,
        max_sigma=.2,
        min_sigma=.2,  # constant sigma
        epsilon=.3,
    )

    # Size the networks from the entries the buffer and collectors read.
    spaces = expl_env.observation_space.spaces
    obs_dim = spaces[observation_key].low.size
    goal_dim = spaces[desired_goal_key].low.size
    action_dim = expl_env.action_space.low.size

    qf1 = ConcatMlp(input_size=obs_dim + goal_dim + action_dim,
                    output_size=1,
                    **variant['qf_kwargs'])
    qf2 = ConcatMlp(input_size=obs_dim + goal_dim + action_dim,
                    output_size=1,
                    **variant['qf_kwargs'])
    target_qf1 = ConcatMlp(input_size=obs_dim + goal_dim + action_dim,
                           output_size=1,
                           **variant['qf_kwargs'])
    target_qf2 = ConcatMlp(input_size=obs_dim + goal_dim + action_dim,
                           output_size=1,
                           **variant['qf_kwargs'])
    policy = TanhMlpPolicy(input_size=obs_dim + goal_dim,
                           output_size=action_dim,
                           **variant['policy_kwargs'])
    target_policy = TanhMlpPolicy(input_size=obs_dim + goal_dim,
                                  output_size=action_dim,
                                  **variant['policy_kwargs'])
    expl_policy = PolicyWrappedWithExplorationStrategy(
        exploration_strategy=es,
        policy=policy,
    )

    replay_buffer = ObsDictRelabelingBuffer(
        env=eval_env,
        observation_key=observation_key,
        desired_goal_key=desired_goal_key,
        achieved_goal_key=achieved_goal_key,
        **variant['replay_buffer_kwargs'])

    trainer = TD3(policy=policy,
                  qf1=qf1,
                  qf2=qf2,
                  target_qf1=target_qf1,
                  target_qf2=target_qf2,
                  target_policy=target_policy,
                  **variant['trainer_kwargs'])
    trainer = HERTrainer(trainer)

    eval_path_collector = GoalConditionedPathCollector(
        eval_env,
        policy,
        observation_key=observation_key,
        desired_goal_key=desired_goal_key,
    )
    expl_path_collector = GoalConditionedPathCollector(
        expl_env,
        expl_policy,
        observation_key=observation_key,
        desired_goal_key=desired_goal_key,
    )

    algorithm = TorchBatchRLAlgorithm(
        trainer=trainer,
        exploration_env=expl_env,
        evaluation_env=eval_env,
        exploration_data_collector=expl_path_collector,
        evaluation_data_collector=eval_path_collector,
        replay_buffer=replay_buffer,
        **variant['algo_kwargs'])

    if variant.get("save_video", False):
        rollout_function = rf.create_rollout_function(
            rf.multitask_rollout,
            max_path_length=algorithm.max_path_length,
            observation_key=observation_key,
            desired_goal_key=desired_goal_key,
        )
        video_func = get_state_experiment_video_save_function(
            rollout_function,
            eval_env,
            policy,
            variant,
        )
        algorithm.post_epoch_funcs.append(video_func)

    algorithm.to(ptu.device)
    algorithm.train()
def _use_disentangled_encoder_distance(
        max_path_length,
        encoder_kwargs,
        disentangled_qf_kwargs,
        qf_kwargs,
        sac_trainer_kwargs,
        replay_buffer_kwargs,
        policy_kwargs,
        evaluation_goal_sampling_mode,
        exploration_goal_sampling_mode,
        algo_kwargs,
        env_id=None,
        env_class=None,
        env_kwargs=None,
        encoder_key_prefix='encoder',
        encoder_input_prefix='state',
        latent_dim=2,
        reward_mode=EncoderWrappedEnv.ENCODER_DISTANCE_REWARD,
        # Video parameters
        save_video=True,
        save_video_kwargs=None,
        save_vf_heatmap=True,
        **kwargs
):
    """Run HER + SAC on environments wrapped with an encoder-based reward.

    The raw train/eval environments are wrapped in ``EncoderWrappedEnv`` so
    that observations and goals are exposed under
    ``'<encoder_key_prefix>_observation'`` / ``'_desired_goal'`` /
    ``'_achieved_goal'``. Disentangled Q-functions, a tanh-Gaussian policy,
    a relabeling buffer and a batch RL algorithm are then built and trained.

    Either ``env_id`` or ``env_class`` must be given. The remaining
    ``*_kwargs`` arguments are forwarded to the component named by their
    prefix. Extra keyword arguments in ``**kwargs`` are accepted and never
    referenced.

    When ``save_video`` is truthy, eval and train video hooks are appended
    to ``algorithm.post_train_funcs``. ``save_vf_heatmap`` controls whether
    the heat-map post-processing function is attached to the rollouts.
    """
    if save_video_kwargs is None:
        save_video_kwargs = {}
    if env_kwargs is None:
        env_kwargs = {}
    assert env_id or env_class
    vectorized = (
        reward_mode == EncoderWrappedEnv.VECTORIZED_ENCODER_DISTANCE_REWARD
    )

    if env_id:
        import gym
        import multiworld
        multiworld.register_all_envs()
        raw_train_env = gym.make(env_id)
        raw_eval_env = gym.make(env_id)
    else:
        raw_eval_env = env_class(**env_kwargs)
        raw_train_env = env_class(**env_kwargs)

    raw_train_env.goal_sampling_mode = exploration_goal_sampling_mode
    raw_eval_env.goal_sampling_mode = evaluation_goal_sampling_mode

    raw_spaces = raw_train_env.observation_space.spaces
    raw_obs_dim = raw_spaces['state_observation'].low.size
    action_dim = raw_train_env.action_space.low.size

    # NOTE(review): the ConcatMlp encoder is built and then immediately
    # replaced by an Identity module, so everything below uses the identity
    # mapping -- confirm this override is intentional.
    encoder = ConcatMlp(
        input_size=raw_obs_dim,
        output_size=latent_dim,
        **encoder_kwargs
    )
    encoder = Identity()
    encoder.input_size = raw_obs_dim
    encoder.output_size = raw_obs_dim

    np_encoder = EncoderFromNetwork(encoder)

    def wrap_env(raw_env):
        return EncoderWrappedEnv(
            raw_env,
            np_encoder,
            encoder_input_prefix,
            key_prefix=encoder_key_prefix,
            reward_mode=reward_mode,
        )

    train_env = wrap_env(raw_train_env)
    eval_env = wrap_env(raw_eval_env)

    observation_key = '{}_observation'.format(encoder_key_prefix)
    desired_goal_key = '{}_desired_goal'.format(encoder_key_prefix)
    achieved_goal_key = '{}_achieved_goal'.format(encoder_key_prefix)

    wrapped_spaces = train_env.observation_space.spaces
    obs_dim = wrapped_spaces[observation_key].low.size
    goal_dim = wrapped_spaces[desired_goal_key].low.size

    def build_q_function():
        return DisentangledMlpQf(
            encoder=encoder,
            preprocess_obs_dim=obs_dim,
            action_dim=action_dim,
            qf_kwargs=qf_kwargs,
            vectorized=vectorized,
            **disentangled_qf_kwargs
        )

    # Built one after another in the order qf1, qf2, target_qf1, target_qf2.
    qf1, qf2, target_qf1, target_qf2 = (
        build_q_function() for _ in range(4)
    )

    policy = TanhGaussianPolicy(
        obs_dim=obs_dim + goal_dim,
        action_dim=action_dim,
        **policy_kwargs
    )

    replay_buffer = ObsDictRelabelingBuffer(
        env=train_env,
        observation_key=observation_key,
        desired_goal_key=desired_goal_key,
        achieved_goal_key=achieved_goal_key,
        vectorized=vectorized,
        **replay_buffer_kwargs
    )
    sac_trainer = SACTrainer(
        env=train_env,
        policy=policy,
        qf1=qf1,
        qf2=qf2,
        target_qf1=target_qf1,
        target_qf2=target_qf2,
        **sac_trainer_kwargs
    )
    trainer = HERTrainer(sac_trainer)

    def build_collector(env, acting_policy):
        # Both collectors use the 'env' goal sampling mode.
        return GoalConditionedPathCollector(
            env,
            acting_policy,
            max_path_length,
            observation_key=observation_key,
            desired_goal_key=desired_goal_key,
            goal_sampling_mode='env',
        )

    eval_path_collector = build_collector(eval_env, MakeDeterministic(policy))
    expl_path_collector = build_collector(train_env, policy)

    algorithm = TorchBatchRLAlgorithm(
        trainer=trainer,
        exploration_env=train_env,
        evaluation_env=eval_env,
        exploration_data_collector=expl_path_collector,
        evaluation_data_collector=eval_path_collector,
        replay_buffer=replay_buffer,
        max_path_length=max_path_length,
        **algo_kwargs
    )
    algorithm.to(ptu.device)

    if save_video:
        def value_of(obs):
            # Evaluate qf1 at the action the current policy picks for obs.
            chosen = policy.get_actions(obs)
            obs_t, chosen_t = ptu.from_numpy(obs), ptu.from_numpy(chosen)
            return qf1(obs_t, chosen_t, return_individual_q_vals=True)

        add_heatmap = partial(
            add_heatmap_imgs_to_o_dict,
            v_function=value_of,
            vectorized=vectorized,
        )
        postprocess = add_heatmap if save_vf_heatmap else None
        rollout_function = rf.create_rollout_function(
            rf.multitask_rollout,
            max_path_length=max_path_length,
            observation_key=observation_key,
            desired_goal_key=desired_goal_key,
            full_o_postprocess_func=postprocess,
        )

        img_keys = ['v_vals']
        img_keys += ['v_vals_dim_{}'.format(dim) for dim in range(latent_dim)]

        def build_video_func(env, acting_policy, tag):
            return get_save_video_function(
                rollout_function,
                env,
                acting_policy,
                get_extra_imgs=partial(get_extra_imgs, img_keys=img_keys),
                tag=tag,
                **save_video_kwargs
            )

        eval_video_func = build_video_func(
            eval_env, MakeDeterministic(policy), "eval")
        train_video_func = build_video_func(train_env, policy, "train")
        for hook in (eval_video_func, train_video_func):
            algorithm.post_train_funcs.append(hook)

    algorithm.train()
def td3_experiment_online_vae_exploring(variant):
    """Jointly train a control TD3 agent and an exploring TD3 agent with a VAE.

    Two independent sets of TD3 networks (control and exploring) are built
    over the same environment. They share one online-VAE relabeling buffer
    and are combined in an ``OnlineVaeHerJointAlgo`` which is then trained.

    Args:
        variant: Experiment configuration dict. Keys read here:
            ``observation_key`` / ``desired_goal_key`` (optional, default to
            the latent keys), ``qf_kwargs``, ``policy_kwargs``,
            ``replay_buffer_kwargs``, ``algo_kwargs``,
            ``use_replay_buffer_goals`` (optional), ``vae_trainer_kwargs``
            (optional), ``vae_train_data``, ``vae_test_data``,
            ``online_vae_beta``, ``joint_algo_kwargs`` and ``save_video``
            (optional, defaults to ``True``). ``vae_training_schedule`` must
            NOT be present. ``variant['algo_kwargs']`` is mutated: the
            replay buffer is stored under ``'replay_buffer'``.
    """
    import rlkit.samplers.rollout_functions as rf
    import rlkit.torch.pytorch_util as ptu
    from rlkit.data_management.online_vae_replay_buffer import (
        OnlineVaeRelabelingBuffer)
    from rlkit.exploration_strategies.base import (
        PolicyWrappedWithExplorationStrategy)
    from rlkit.torch.her.online_vae_joint_algo import OnlineVaeHerJointAlgo
    from rlkit.torch.networks import ConcatMlp, TanhMlpPolicy
    from rlkit.torch.td3.td3 import TD3
    from rlkit.torch.vae.vae_trainer import ConvVAETrainer

    preprocess_rl_variant(variant)
    env = get_envs(variant)
    es = get_exploration_strategy(variant, env)

    observation_key = variant.get('observation_key', 'latent_observation')
    desired_goal_key = variant.get('desired_goal_key', 'latent_desired_goal')
    achieved_goal_key = desired_goal_key.replace("desired", "achieved")

    # Here obs_dim already includes the goal: networks take observation and
    # goal together.
    obs_dim = (
        env.observation_space.spaces[observation_key].low.size
        + env.observation_space.spaces[desired_goal_key].low.size
    )
    action_dim = env.action_space.low.size

    def make_qf():
        return ConcatMlp(
            input_size=obs_dim + action_dim,
            output_size=1,
            **variant['qf_kwargs']
        )

    def make_policy():
        return TanhMlpPolicy(
            input_size=obs_dim,
            output_size=action_dim,
            **variant['policy_kwargs']
        )

    # Control networks first, then the exploring ones, in the original
    # construction order.
    qf1 = make_qf()
    qf2 = make_qf()
    policy = make_policy()
    exploration_policy = PolicyWrappedWithExplorationStrategy(
        exploration_strategy=es,
        policy=policy,
    )
    exploring_qf1 = make_qf()
    exploring_qf2 = make_qf()
    exploring_policy = make_policy()
    exploring_exploration_policy = PolicyWrappedWithExplorationStrategy(
        exploration_strategy=es,
        policy=exploring_policy,
    )

    vae = env.vae
    replay_buffer = OnlineVaeRelabelingBuffer(
        vae=vae,
        env=env,
        observation_key=observation_key,
        desired_goal_key=desired_goal_key,
        achieved_goal_key=achieved_goal_key,
        **variant['replay_buffer_kwargs']
    )
    # Both TD3 instances receive the same buffer through algo_kwargs.
    variant["algo_kwargs"]["replay_buffer"] = replay_buffer
    if variant.get('use_replay_buffer_goals', False):
        env.replay_buffer = replay_buffer
        env.use_replay_buffer_goals = True

    # A missing (or explicitly None) 'vae_trainer_kwargs' used to make the
    # ``**`` unpacking below raise TypeError; treat it as "no extra kwargs".
    vae_trainer_kwargs = variant.get('vae_trainer_kwargs') or {}
    t = ConvVAETrainer(
        variant['vae_train_data'],
        variant['vae_test_data'],
        vae,
        beta=variant['online_vae_beta'],
        **vae_trainer_kwargs
    )

    control_algorithm = TD3(
        env=env,
        training_env=env,
        qf1=qf1,
        qf2=qf2,
        policy=policy,
        exploration_policy=exploration_policy,
        **variant['algo_kwargs']
    )
    exploring_algorithm = TD3(
        env=env,
        training_env=env,
        qf1=exploring_qf1,
        qf2=exploring_qf2,
        policy=exploring_policy,
        exploration_policy=exploring_exploration_policy,
        **variant['algo_kwargs']
    )

    assert 'vae_training_schedule' not in variant,\
        "Just put it in joint_algo_kwargs"
    algorithm = OnlineVaeHerJointAlgo(
        vae=vae,
        vae_trainer=t,
        env=env,
        training_env=env,
        policy=policy,
        exploration_policy=exploration_policy,
        replay_buffer=replay_buffer,
        algo1=control_algorithm,
        algo2=exploring_algorithm,
        algo1_prefix="Control_",
        algo2_prefix="VAE_Exploration_",
        observation_key=observation_key,
        desired_goal_key=desired_goal_key,
        **variant['joint_algo_kwargs']
    )

    algorithm.to(ptu.device)
    vae.to(ptu.device)

    if variant.get("save_video", True):
        policy.train(False)
        rollout_function = rf.create_rollout_function(
            rf.multitask_rollout,
            max_path_length=algorithm.max_path_length,
            observation_key=algorithm.observation_key,
            desired_goal_key=algorithm.desired_goal_key,
        )
        video_func = get_video_save_func(
            rollout_function,
            env,
            algorithm.eval_policy,
            variant,
        )
        algorithm.post_train_funcs.append(video_func)

    algorithm.train()
def _disentangled_her_twin_sac_experiment_v2(
        max_path_length,
        encoder_kwargs,
        disentangled_qf_kwargs,
        qf_kwargs,
        twin_sac_trainer_kwargs,
        replay_buffer_kwargs,
        policy_kwargs,
        evaluation_goal_sampling_mode,
        exploration_goal_sampling_mode,
        algo_kwargs,
        save_video=True,
        env_id=None,
        env_class=None,
        env_kwargs=None,
        observation_key='state_observation',
        desired_goal_key='state_desired_goal',
        achieved_goal_key='state_achieved_goal',
        # Video parameters
        latent_dim=2,
        save_video_kwargs=None,
        **kwargs
):
    """Run HER + SAC with disentangled Q-functions sharing one goal encoder.

    A single ``ConcatMlp`` encoder (goal_dim -> latent_dim) is shared by both
    online Q-functions. Each target Q-function gets its own ``Detach``
    wrapper around that same encoder. Either ``env_id`` or ``env_class`` must
    be provided. Extra ``**kwargs`` are accepted and never referenced.

    When ``save_video`` is truthy, eval and train video hooks are appended
    to ``algorithm.post_train_funcs``. ``save_video_kwargs`` may carry a
    ``'save_vf_heatmap'`` entry (default ``True``) and is also forwarded in
    full to ``get_save_video_function``.
    """
    import rlkit.samplers.rollout_functions as rf
    import rlkit.torch.pytorch_util as ptu
    from rlkit.data_management.obs_dict_replay_buffer import (
        ObsDictRelabelingBuffer)
    from rlkit.torch.networks import ConcatMlp
    from rlkit.torch.sac.policies import TanhGaussianPolicy
    from rlkit.torch.torch_rl_algorithm import TorchBatchRLAlgorithm

    save_video_kwargs = {} if save_video_kwargs is None else save_video_kwargs
    env_kwargs = {} if env_kwargs is None else env_kwargs
    assert env_id or env_class

    if env_id:
        import gym
        import multiworld
        multiworld.register_all_envs()
        train_env = gym.make(env_id)
        eval_env = gym.make(env_id)
    else:
        eval_env = env_class(**env_kwargs)
        train_env = env_class(**env_kwargs)

    spaces = train_env.observation_space.spaces
    obs_dim = spaces[observation_key].low.size
    goal_dim = spaces[desired_goal_key].low.size
    action_dim = train_env.action_space.low.size

    encoder = ConcatMlp(
        input_size=goal_dim,
        output_size=latent_dim,
        **encoder_kwargs
    )

    def build_qf(qf_encoder):
        return DisentangledMlpQf(
            encoder=qf_encoder,
            preprocess_obs_dim=obs_dim,
            action_dim=action_dim,
            qf_kwargs=qf_kwargs,
            **disentangled_qf_kwargs
        )

    qf1 = build_qf(encoder)
    qf2 = build_qf(encoder)
    target_qf1 = build_qf(Detach(encoder))
    target_qf2 = build_qf(Detach(encoder))

    policy = TanhGaussianPolicy(
        obs_dim=obs_dim + goal_dim,
        action_dim=action_dim,
        **policy_kwargs
    )

    replay_buffer = ObsDictRelabelingBuffer(
        env=train_env,
        observation_key=observation_key,
        desired_goal_key=desired_goal_key,
        achieved_goal_key=achieved_goal_key,
        **replay_buffer_kwargs
    )
    sac_trainer = SACTrainer(
        env=train_env,
        policy=policy,
        qf1=qf1,
        qf2=qf2,
        target_qf1=target_qf1,
        target_qf2=target_qf2,
        **twin_sac_trainer_kwargs
    )
    trainer = HERTrainer(sac_trainer)

    def build_collector(env, acting_policy, sampling_mode):
        return GoalConditionedPathCollector(
            env,
            acting_policy,
            max_path_length,
            observation_key=observation_key,
            desired_goal_key=desired_goal_key,
            goal_sampling_mode=sampling_mode,
        )

    eval_path_collector = build_collector(
        eval_env, MakeDeterministic(policy), evaluation_goal_sampling_mode)
    expl_path_collector = build_collector(
        train_env, policy, exploration_goal_sampling_mode)

    algorithm = TorchBatchRLAlgorithm(
        trainer=trainer,
        exploration_env=train_env,
        evaluation_env=eval_env,
        exploration_data_collector=expl_path_collector,
        evaluation_data_collector=eval_path_collector,
        replay_buffer=replay_buffer,
        max_path_length=max_path_length,
        **algo_kwargs
    )
    algorithm.to(ptu.device)

    if save_video:
        save_vf_heatmap = save_video_kwargs.get('save_vf_heatmap', True)

        def value_of(obs):
            # Evaluate qf1 at the action the current policy picks for obs.
            chosen = policy.get_actions(obs)
            obs_t, chosen_t = ptu.from_numpy(obs), ptu.from_numpy(chosen)
            return qf1(obs_t, chosen_t, return_individual_q_vals=True)

        add_heatmap = partial(add_heatmap_imgs_to_o_dict, v_function=value_of)
        postprocess = add_heatmap if save_vf_heatmap else None
        rollout_function = rf.create_rollout_function(
            rf.multitask_rollout,
            max_path_length=max_path_length,
            observation_key=observation_key,
            desired_goal_key=desired_goal_key,
            full_o_postprocess_func=postprocess,
        )

        img_keys = ['v_vals']
        img_keys += ['v_vals_dim_{}'.format(dim) for dim in range(latent_dim)]

        def build_video_func(env, acting_policy, tag):
            return get_save_video_function(
                rollout_function,
                env,
                acting_policy,
                tag=tag,
                get_extra_imgs=partial(get_extra_imgs, img_keys=img_keys),
                **save_video_kwargs
            )

        eval_video_func = build_video_func(
            eval_env, MakeDeterministic(policy), "eval")
        train_video_func = build_video_func(train_env, policy, "train")

        # The decoder is built and moved to the device but not used by any
        # active hook: the decoder-training, encoder-plotting and
        # buffer-plotting post-train hooks are currently disabled.
        decoder = ConcatMlp(
            input_size=obs_dim,
            output_size=obs_dim,
            hidden_sizes=[128, 128],
        )
        decoder.to(ptu.device)

        for hook in (eval_video_func, train_video_func):
            algorithm.post_train_funcs.append(hook)

    algorithm.train()
def tdm_td3_experiment_online_vae(variant):
    """Train TDM-TD3 together with an online-trained VAE.

    Builds TDM Q-functions and policy, an online-VAE relabeling replay
    buffer and a ``ConvVAETrainer``, combines them in ``OnlineVaeTdmTd3``
    and runs training.

    Args:
        variant: Experiment configuration dict. Keys read here:
            ``vae_trainer_kwargs`` (optional), ``observation_key`` /
            ``desired_goal_key`` (optional, default to the latent keys),
            ``algo_kwargs`` (with ``'tdm_td3_kwargs'`` holding
            ``'tdm_kwargs'`` and ``'td3_kwargs'``, plus
            ``'online_vae_kwargs'``), ``qf_kwargs``, ``policy_kwargs``,
            ``replay_buffer_kwargs``, ``vae_train_data``, ``vae_test_data``,
            ``online_vae_beta``, ``save_video`` (optional, defaults to
            ``True``) and ``do_state_exp`` (optional).
            ``vae_training_schedule`` must NOT be present. The nested
            ``algo_kwargs`` dicts are mutated in place (vectorized flag,
            training env, observation/goal keys, replay buffer).
    """
    import rlkit.samplers.rollout_functions as rf
    import rlkit.torch.pytorch_util as ptu
    from rlkit.data_management.online_vae_replay_buffer import (
        OnlineVaeRelabelingBuffer)
    from rlkit.exploration_strategies.base import (
        PolicyWrappedWithExplorationStrategy)
    from rlkit.state_distance.tdm_networks import TdmQf, TdmPolicy
    from rlkit.torch.vae.vae_trainer import ConvVAETrainer
    from rlkit.torch.online_vae.online_vae_tdm_td3 import OnlineVaeTdmTd3

    preprocess_rl_variant(variant)
    env = get_envs(variant)
    es = get_exploration_strategy(variant, env)

    # A missing (or explicitly None) 'vae_trainer_kwargs' used to make the
    # ``**`` unpacking further down raise TypeError; treat it as empty.
    vae_trainer_kwargs = variant.get('vae_trainer_kwargs') or {}

    observation_key = variant.get('observation_key', 'latent_observation')
    desired_goal_key = variant.get('desired_goal_key', 'latent_desired_goal')
    achieved_goal_key = desired_goal_key.replace("desired", "achieved")

    spaces = env.observation_space.spaces
    obs_dim = spaces[observation_key].low.size
    goal_dim = spaces[desired_goal_key].low.size
    action_dim = env.action_space.low.size

    algo_kwargs = variant['algo_kwargs']['tdm_td3_kwargs']
    tdm_kwargs = algo_kwargs['tdm_kwargs']
    td3_kwargs = algo_kwargs['td3_kwargs']

    vectorized = 'vectorized' in env.reward_type
    tdm_kwargs['vectorized'] = vectorized
    # norm_order is passed to the Q-functions only; it is not written into
    # tdm_kwargs.
    norm_order = env.norm_order

    def make_qf():
        return TdmQf(
            env=env,
            vectorized=vectorized,
            norm_order=norm_order,
            observation_dim=obs_dim,
            goal_dim=goal_dim,
            action_dim=action_dim,
            **variant['qf_kwargs']
        )

    qf1 = make_qf()
    qf2 = make_qf()
    policy = TdmPolicy(
        env=env,
        observation_dim=obs_dim,
        goal_dim=goal_dim,
        action_dim=action_dim,
        **variant['policy_kwargs']
    )
    exploration_policy = PolicyWrappedWithExplorationStrategy(
        exploration_strategy=es,
        policy=policy,
    )

    vae = env.vae
    replay_buffer = OnlineVaeRelabelingBuffer(
        vae=vae,
        env=env,
        observation_key=observation_key,
        desired_goal_key=desired_goal_key,
        achieved_goal_key=achieved_goal_key,
        **variant['replay_buffer_kwargs']
    )

    td3_kwargs['training_env'] = env
    tdm_kwargs['observation_key'] = observation_key
    tdm_kwargs['desired_goal_key'] = desired_goal_key
    algo_kwargs["replay_buffer"] = replay_buffer

    t = ConvVAETrainer(
        variant['vae_train_data'],
        variant['vae_test_data'],
        vae,
        beta=variant['online_vae_beta'],
        **vae_trainer_kwargs
    )
    # The former ``render = variant["render"]`` lookup was dropped: the value
    # was never used and only forced callers to supply the key.

    assert 'vae_training_schedule' not in variant, "Just put it in algo_kwargs"
    algorithm = OnlineVaeTdmTd3(
        online_vae_kwargs=dict(
            vae=vae,
            vae_trainer=t,
            **variant['algo_kwargs']['online_vae_kwargs']
        ),
        tdm_td3_kwargs=dict(
            env=env,
            qf1=qf1,
            qf2=qf2,
            policy=policy,
            exploration_policy=exploration_policy,
            **variant['algo_kwargs']['tdm_td3_kwargs']
        ),
    )

    algorithm.to(ptu.device)
    vae.to(ptu.device)

    if variant.get("save_video", True):
        policy.train(False)
        rollout_function = rf.create_rollout_function(
            rf.tdm_rollout,
            init_tau=algorithm._sample_max_tau_for_rollout(),
            decrement_tau=algorithm.cycle_taus_for_rollout,
            cycle_tau=algorithm.cycle_taus_for_rollout,
            max_path_length=algorithm.max_path_length,
            observation_key=algorithm.observation_key,
            desired_goal_key=algorithm.desired_goal_key,
        )
        video_func = get_video_save_func(
            rollout_function,
            env,
            algorithm.eval_policy,
            variant,
        )
        algorithm.post_train_funcs.append(video_func)

    # NOTE(review): these repeat the device transfers done above and appear
    # redundant; kept to preserve the original call sequence.
    algorithm.to(ptu.device)
    if not variant.get("do_state_exp", False):
        env.vae.to(ptu.device)

    algorithm.train()
def her_sac_experiment(
        max_path_length,
        qf_kwargs,
        twin_sac_trainer_kwargs,
        replay_buffer_kwargs,
        policy_kwargs,
        evaluation_goal_sampling_mode,
        exploration_goal_sampling_mode,
        algo_kwargs,
        save_video=True,
        env_id=None,
        env_class=None,
        env_kwargs=None,
        observation_key='state_observation',
        desired_goal_key='state_desired_goal',
        achieved_goal_key='state_achieved_goal',
        # Video parameters
        save_video_kwargs=None,
        exploration_policy_kwargs=None,
        **kwargs
):
    """Run HER + SAC on a goal-conditioned environment.

    Either ``env_id`` or ``env_class`` must be provided. Twin Q-functions
    with target copies and a tanh-Gaussian policy are built over the
    concatenated observation and goal. Evaluation uses a deterministic
    version of the policy. Exploration uses the policy returned by
    ``create_exploration_policy``. Extra ``**kwargs`` are accepted and never
    referenced.

    When ``save_video`` is truthy, video hooks tagged ``"eval"`` and
    ``"expl"`` are appended to ``algorithm.post_train_funcs``.
    """
    exploration_policy_kwargs = (
        {} if exploration_policy_kwargs is None else exploration_policy_kwargs
    )

    import rlkit.samplers.rollout_functions as rf
    import rlkit.torch.pytorch_util as ptu
    from rlkit.data_management.obs_dict_replay_buffer import (
        ObsDictRelabelingBuffer)
    from rlkit.torch.networks import ConcatMlp
    from rlkit.torch.sac.policies import TanhGaussianPolicy
    from rlkit.torch.torch_rl_algorithm import TorchBatchRLAlgorithm

    # Any falsy value (None or an empty mapping) becomes a fresh dict.
    save_video_kwargs = save_video_kwargs or {}
    env_kwargs = {} if env_kwargs is None else env_kwargs
    assert env_id or env_class

    if env_id:
        import gym
        import multiworld
        multiworld.register_all_envs()
        train_env = gym.make(env_id)
        eval_env = gym.make(env_id)
    else:
        eval_env = env_class(**env_kwargs)
        train_env = env_class(**env_kwargs)

    spaces = train_env.observation_space.spaces
    # Observation and goal are fed to the networks together.
    policy_input_dim = (
        spaces[observation_key].low.size + spaces[desired_goal_key].low.size
    )
    action_dim = train_env.action_space.low.size

    def build_qf():
        return ConcatMlp(
            input_size=policy_input_dim + action_dim,
            output_size=1,
            **qf_kwargs
        )

    qf1 = build_qf()
    qf2 = build_qf()
    target_qf1 = build_qf()
    target_qf2 = build_qf()
    policy = TanhGaussianPolicy(
        obs_dim=policy_input_dim,
        action_dim=action_dim,
        **policy_kwargs
    )

    replay_buffer = ObsDictRelabelingBuffer(
        env=train_env,
        observation_key=observation_key,
        desired_goal_key=desired_goal_key,
        achieved_goal_key=achieved_goal_key,
        **replay_buffer_kwargs
    )
    sac_trainer = SACTrainer(
        env=train_env,
        policy=policy,
        qf1=qf1,
        qf2=qf2,
        target_qf1=target_qf1,
        target_qf2=target_qf2,
        **twin_sac_trainer_kwargs
    )
    trainer = HERTrainer(sac_trainer)

    def build_collector(env, acting_policy, sampling_mode):
        return GoalConditionedPathCollector(
            env,
            acting_policy,
            max_path_length,
            observation_key=observation_key,
            desired_goal_key=desired_goal_key,
            goal_sampling_mode=sampling_mode,
        )

    eval_path_collector = build_collector(
        eval_env, MakeDeterministic(policy), evaluation_goal_sampling_mode)
    exploration_policy = create_exploration_policy(
        train_env, policy, **exploration_policy_kwargs)
    expl_path_collector = build_collector(
        train_env, exploration_policy, exploration_goal_sampling_mode)

    algorithm = TorchBatchRLAlgorithm(
        trainer=trainer,
        exploration_env=train_env,
        evaluation_env=eval_env,
        exploration_data_collector=expl_path_collector,
        evaluation_data_collector=eval_path_collector,
        replay_buffer=replay_buffer,
        max_path_length=max_path_length,
        **algo_kwargs
    )
    algorithm.to(ptu.device)

    if save_video:
        rollout_function = rf.create_rollout_function(
            rf.multitask_rollout,
            max_path_length=max_path_length,
            observation_key=observation_key,
            desired_goal_key=desired_goal_key,
            return_dict_obs=True,
        )

        def build_video_func(env, acting_policy, tag):
            return get_save_video_function(
                rollout_function,
                env,
                acting_policy,
                tag=tag,
                **save_video_kwargs
            )

        eval_video_func = build_video_func(
            eval_env, MakeDeterministic(policy), "eval")
        train_video_func = build_video_func(
            train_env, exploration_policy, "expl")

        # Buffer-plotting post-train hooks are currently disabled; only the
        # two video hooks are registered.
        for hook in (eval_video_func, train_video_func):
            algorithm.post_train_funcs.append(hook)

    algorithm.train()
def her_td3_experiment(variant):
    """Train ``HerTd3`` on a goal-conditioned environment.

    Args:
        variant: Experiment configuration dict. Keys read here: ``env_id``
            (optional) or ``env_class`` together with ``env_kwargs``,
            ``observation_key``, ``desired_goal_key``, ``algo_kwargs`` (its
            ``'her_kwargs'`` entry is updated in place with the two keys),
            ``normalize`` (optional), ``replay_buffer_kwargs``,
            ``exploration_type`` (``'ou'``, ``'gaussian'`` or
            ``'epsilon'``), ``es_kwargs``, ``qf_kwargs``, ``policy_kwargs``
            and ``save_video`` (optional, defaults to ``False``).

    Raises:
        NotImplementedError: If ``variant['normalize']`` is truthy.
        Exception: If ``exploration_type`` is not one of the supported names.
    """
    if 'env_id' in variant:
        env = gym.make(variant['env_id'])
    else:
        env = variant['env_class'](**variant['env_kwargs'])

    observation_key = variant['observation_key']
    desired_goal_key = variant['desired_goal_key']
    her_kwargs = variant['algo_kwargs']['her_kwargs']
    her_kwargs['observation_key'] = observation_key
    her_kwargs['desired_goal_key'] = desired_goal_key
    if variant.get('normalize', False):
        raise NotImplementedError()

    achieved_goal_key = desired_goal_key.replace("desired", "achieved")
    replay_buffer = ObsDictRelabelingBuffer(
        env=env,
        observation_key=observation_key,
        desired_goal_key=desired_goal_key,
        achieved_goal_key=achieved_goal_key,
        **variant['replay_buffer_kwargs']
    )

    # Size the networks from the configured keys. The previous hard-coded
    # 'observation' / 'desired_goal' entries could disagree with the entries
    # the buffer and the algorithm are configured to use.
    spaces = env.observation_space.spaces
    obs_dim = spaces[observation_key].low.size
    action_dim = env.action_space.low.size
    goal_dim = spaces[desired_goal_key].low.size

    strategy_classes = {
        'ou': OUStrategy,
        'gaussian': GaussianStrategy,
        'epsilon': EpsilonGreedy,
    }
    exploration_type = variant['exploration_type']
    if exploration_type not in strategy_classes:
        raise Exception("Invalid type: " + exploration_type)
    es = strategy_classes[exploration_type](
        action_space=env.action_space,
        **variant['es_kwargs']
    )

    def make_qf():
        return FlattenMlp(
            input_size=obs_dim + action_dim + goal_dim,
            output_size=1,
            **variant['qf_kwargs']
        )

    qf1 = make_qf()
    qf2 = make_qf()
    policy = TanhMlpPolicy(
        input_size=obs_dim + goal_dim,
        output_size=action_dim,
        **variant['policy_kwargs']
    )
    exploration_policy = PolicyWrappedWithExplorationStrategy(
        exploration_strategy=es,
        policy=policy,
    )
    algorithm = HerTd3(
        env,
        qf1=qf1,
        qf2=qf2,
        policy=policy,
        exploration_policy=exploration_policy,
        replay_buffer=replay_buffer,
        **variant['algo_kwargs']
    )

    if variant.get("save_video", False):
        rollout_function = rf.create_rollout_function(
            rf.multitask_rollout,
            max_path_length=algorithm.max_path_length,
            observation_key=algorithm.observation_key,
            desired_goal_key=algorithm.desired_goal_key,
        )
        video_func = get_video_save_func(
            rollout_function,
            env,
            policy,
            variant,
        )
        algorithm.post_epoch_funcs.append(video_func)

    algorithm.to(ptu.device)
    algorithm.train()
def her_td3_experiment(variant):
    """Train ``HerTd3`` on a goal-conditioned environment.

    All project dependencies are imported inside the function body.

    Args:
        variant: Experiment configuration dict. Keys read here: ``env_id``
            (optional) or ``env_class`` together with ``env_kwargs``,
            ``observation_key``, ``desired_goal_key``, ``algo_kwargs`` (its
            ``'her_kwargs'`` entry is updated in place with the two keys),
            ``normalize`` (optional), ``replay_buffer_kwargs``,
            ``exploration_type`` (``'ou'``, ``'gaussian'`` or
            ``'epsilon'``), ``es_kwargs``, ``qf_kwargs``, ``policy_kwargs``
            and ``save_video`` (optional, defaults to ``False``).

    Raises:
        NotImplementedError: If ``variant['normalize']`` is truthy.
        Exception: If ``exploration_type`` is not one of the supported names.
    """
    import gym

    # NOTE(review): these two modules are never referenced below; presumably
    # imported for their side effects (environment registration) -- confirm.
    import multiworld.envs.mujoco
    import multiworld.envs.pygame
    import rlkit.samplers.rollout_functions as rf
    import rlkit.torch.pytorch_util as ptu
    from rlkit.exploration_strategies.base import (
        PolicyWrappedWithExplorationStrategy)
    from rlkit.exploration_strategies.epsilon_greedy import EpsilonGreedy
    from rlkit.exploration_strategies.gaussian_strategy import GaussianStrategy
    from rlkit.exploration_strategies.ou_strategy import OUStrategy
    from rlkit.torch.grill.launcher import get_video_save_func
    from rlkit.torch.her.her_td3 import HerTd3
    from rlkit.torch.networks import ConcatMlp, TanhMlpPolicy
    from rlkit.data_management.obs_dict_replay_buffer import (
        ObsDictRelabelingBuffer)

    if 'env_id' in variant:
        env = gym.make(variant['env_id'])
    else:
        env = variant['env_class'](**variant['env_kwargs'])

    observation_key = variant['observation_key']
    desired_goal_key = variant['desired_goal_key']
    variant['algo_kwargs']['her_kwargs']['observation_key'] = observation_key
    variant['algo_kwargs']['her_kwargs']['desired_goal_key'] = desired_goal_key
    if variant.get('normalize', False):
        raise NotImplementedError()

    achieved_goal_key = desired_goal_key.replace("desired", "achieved")
    replay_buffer = ObsDictRelabelingBuffer(
        env=env,
        observation_key=observation_key,
        desired_goal_key=desired_goal_key,
        achieved_goal_key=achieved_goal_key,
        **variant['replay_buffer_kwargs']
    )

    # Network input sizes must match the dict entries that are actually
    # used, so read them via the configured keys rather than the hard-coded
    # 'observation' / 'desired_goal' entries.
    obs_dim = env.observation_space.spaces[observation_key].low.size
    action_dim = env.action_space.low.size
    goal_dim = env.observation_space.spaces[desired_goal_key].low.size

    exploration_type = variant['exploration_type']
    if exploration_type == 'ou':
        strategy_class = OUStrategy
    elif exploration_type == 'gaussian':
        strategy_class = GaussianStrategy
    elif exploration_type == 'epsilon':
        strategy_class = EpsilonGreedy
    else:
        raise Exception("Invalid type: " + exploration_type)
    es = strategy_class(
        action_space=env.action_space,
        **variant['es_kwargs']
    )

    qf_input_size = obs_dim + action_dim + goal_dim
    qf1 = ConcatMlp(
        input_size=qf_input_size,
        output_size=1,
        **variant['qf_kwargs']
    )
    qf2 = ConcatMlp(
        input_size=qf_input_size,
        output_size=1,
        **variant['qf_kwargs']
    )
    policy = TanhMlpPolicy(
        input_size=obs_dim + goal_dim,
        output_size=action_dim,
        **variant['policy_kwargs']
    )
    exploration_policy = PolicyWrappedWithExplorationStrategy(
        exploration_strategy=es,
        policy=policy,
    )
    algorithm = HerTd3(
        env,
        qf1=qf1,
        qf2=qf2,
        policy=policy,
        exploration_policy=exploration_policy,
        replay_buffer=replay_buffer,
        **variant['algo_kwargs']
    )

    if variant.get("save_video", False):
        rollout_function = rf.create_rollout_function(
            rf.multitask_rollout,
            max_path_length=algorithm.max_path_length,
            observation_key=algorithm.observation_key,
            desired_goal_key=algorithm.desired_goal_key,
        )
        video_func = get_video_save_func(
            rollout_function,
            env,
            policy,
            variant,
        )
        algorithm.post_epoch_funcs.append(video_func)

    algorithm.to(ptu.device)
    algorithm.train()