def tdm_td3_experiment(variant):
    """Assemble a ``TdmTd3`` run from ``variant`` and start training.

    Keys read from ``variant``: ``env_class``, ``env_kwargs``, ``qf_kwargs``,
    ``policy_kwargs``, ``exploration_type`` (``'ou'``, ``'gaussian'`` or
    ``'epsilon'``), ``replay_buffer_class``, ``replay_buffer_kwargs``,
    ``qf_criterion_class`` and ``algo_kwargs``.

    ``variant['algo_kwargs']`` is modified in place: the criterion instance
    is stored under ``td3_kwargs`` and the (``None``) normalizer under
    ``tdm_kwargs``.

    Raises:
        Exception: if ``variant['exploration_type']`` is not one of the
            three supported names.
    """
    environment = variant['env_class'](**variant['env_kwargs'])
    normalizer = None  # this launcher runs without a TDM normalizer

    def build_qf():
        # Both critics are built from the same settings.
        return TdmQf(
            env=environment,
            vectorized=True,
            tdm_normalizer=normalizer,
            **variant['qf_kwargs']
        )

    first_qf = build_qf()
    second_qf = build_qf()
    actor = TdmPolicy(
        env=environment,
        tdm_normalizer=normalizer,
        **variant['policy_kwargs']
    )

    strategy_name = variant['exploration_type']
    # Factories are lazy so only the selected strategy gets constructed.
    strategy_builders = {
        'ou': lambda: OUStrategy(action_space=environment.action_space),
        'gaussian': lambda: GaussianStrategy(
            action_space=environment.action_space,
            max_sigma=0.1,
            min_sigma=0.1,  # equal bounds keep sigma constant
        ),
        'epsilon': lambda: EpsilonGreedy(
            action_space=environment.action_space,
            prob_random_action=0.1,
        ),
    }
    if strategy_name not in strategy_builders:
        raise Exception("Invalid type: " + strategy_name)
    noise = strategy_builders[strategy_name]()
    noisy_actor = PolicyWrappedWithExplorationStrategy(
        exploration_strategy=noise,
        policy=actor,
    )

    buffer = variant['replay_buffer_class'](
        env=environment,
        **variant['replay_buffer_kwargs']
    )
    criterion = variant['qf_criterion_class']()
    settings = variant['algo_kwargs']
    settings['td3_kwargs']['qf_criterion'] = criterion
    settings['tdm_kwargs']['tdm_normalizer'] = normalizer

    trainer = TdmTd3(
        environment,
        qf1=first_qf,
        qf2=second_qf,
        replay_buffer=buffer,
        policy=actor,
        exploration_policy=noisy_actor,
        **settings
    )
    trainer.to(ptu.device)
    trainer.train()
def experiment(variant):
    """Train ``TdmSac`` on a normalized ``MultitaskPoint2DEnv``.

    Keys read from ``variant``: ``policy_kwargs``, ``qf_kwargs``,
    ``vf_kwargs`` and ``algo_params`` (with sub-keys ``sac_kwargs``,
    ``tdm_kwargs``, ``base_kwargs`` -- which must hold
    ``replay_buffer_size`` -- and ``supervised_weight``).
    """
    environment = NormalizedBoxEnv(MultitaskPoint2DEnv())
    use_vectorized = True

    actor = StochasticTdmPolicy(env=environment, **variant['policy_kwargs'])
    critic = TdmQf(
        env=environment,
        vectorized=use_vectorized,
        norm_order=2,
        **variant['qf_kwargs']
    )
    value_fn = TdmVf(
        env=environment,
        vectorized=use_vectorized,
        **variant['vf_kwargs']
    )

    algo_params = variant['algo_params']
    buffer_capacity = algo_params['base_kwargs']['replay_buffer_size']
    buffer = HerReplayBuffer(buffer_capacity, environment)

    # The three kwargs dicts are handed over positionally, in this order.
    trainer = TdmSac(
        environment,
        critic,
        value_fn,
        algo_params['sac_kwargs'],
        algo_params['tdm_kwargs'],
        algo_params['base_kwargs'],
        supervised_weight=algo_params['supervised_weight'],
        policy=actor,
        replay_buffer=buffer,
    )
    if ptu.gpu_enabled():
        trainer.cuda()
    trainer.train()
def experiment(variant):
    """Build a ``TdmDdpg`` algorithm from ``variant`` and train it.

    Keys read from ``variant``: ``env_class``, ``env_kwargs``, ``qf_kwargs``,
    ``policy_kwargs``, ``es_kwargs``, ``her_replay_buffer_kwargs``,
    ``qf_criterion_class`` and ``algo_kwargs``.

    ``variant['algo_kwargs']`` is modified in place (``qf_criterion`` under
    ``ddpg_kwargs`` and ``tdm_normalizer`` under ``tdm_kwargs``).
    """
    environment = variant['env_class'](**variant['env_kwargs'])
    # The TDM normalizer is disabled in this launcher.
    normalizer = None

    critic = TdmQf(
        env=environment,
        vectorized=True,
        tdm_normalizer=normalizer,
        **variant['qf_kwargs']
    )
    actor = TdmPolicy(
        env=environment,
        tdm_normalizer=normalizer,
        **variant['policy_kwargs']
    )
    noise = OUStrategy(
        action_space=environment.action_space,
        **variant['es_kwargs']
    )
    noisy_actor = PolicyWrappedWithExplorationStrategy(
        exploration_strategy=noise,
        policy=actor,
    )
    buffer = HerReplayBuffer(
        env=environment,
        **variant['her_replay_buffer_kwargs']
    )

    criterion = variant['qf_criterion_class']()
    settings = variant['algo_kwargs']
    settings['ddpg_kwargs']['qf_criterion'] = criterion
    settings['tdm_kwargs']['tdm_normalizer'] = normalizer

    trainer = TdmDdpg(
        environment,
        qf=critic,
        replay_buffer=buffer,
        policy=actor,
        exploration_policy=noisy_actor,
        **settings
    )
    trainer.to(ptu.device)
    trainer.train()
def experiment(variant):
    """Build a ``TdmTd3`` algorithm with a ``HerReplayBuffer`` and train it.

    Keys read from ``variant``: ``env_class``, ``env_kwargs``, ``qf_kwargs``,
    ``policy_kwargs``, ``es_kwargs``, ``her_replay_buffer_kwargs``,
    ``qf_criterion_class`` and ``algo_kwargs``.

    ``variant['algo_kwargs']`` is modified in place (``qf_criterion`` under
    ``td3_kwargs`` and ``tdm_normalizer`` under ``tdm_kwargs``).
    """
    environment = variant['env_class'](**variant['env_kwargs'])
    normalizer = None  # no TDM normalizer in this configuration

    def new_critic():
        # The two critics share identical construction arguments.
        return TdmQf(
            env=environment,
            vectorized=True,
            tdm_normalizer=normalizer,
            **variant['qf_kwargs']
        )

    critic_a = new_critic()
    critic_b = new_critic()
    actor = TdmPolicy(
        env=environment,
        tdm_normalizer=normalizer,
        **variant['policy_kwargs']
    )
    ou_noise = OUStrategy(
        action_space=environment.action_space,
        **variant['es_kwargs']
    )
    behaviour_policy = PolicyWrappedWithExplorationStrategy(
        exploration_strategy=ou_noise,
        policy=actor,
    )
    her_buffer = HerReplayBuffer(
        env=environment,
        **variant['her_replay_buffer_kwargs']
    )

    criterion = variant['qf_criterion_class']()
    settings = variant['algo_kwargs']
    settings['td3_kwargs']['qf_criterion'] = criterion
    settings['tdm_kwargs']['tdm_normalizer'] = normalizer

    trainer = TdmTd3(
        environment,
        qf1=critic_a,
        qf2=critic_b,
        replay_buffer=her_buffer,
        policy=actor,
        exploration_policy=behaviour_policy,
        **settings
    )
    trainer.to(ptu.device)
    trainer.train()
def experiment(variant):
    """Train ``TdmDdpg`` on ``MultiTaskSawyerXYZReachingEnv``.

    Keys read from ``variant``: ``env_params``, ``hidden_sizes``,
    ``es_kwargs``, ``her_replay_buffer_kwargs``, ``qf_criterion_class`` and
    ``ddpg_tdm_kwargs`` (which must provide ``tdm_kwargs['max_tau']`` and a
    ``ddpg_kwargs`` dict).

    ``variant['ddpg_tdm_kwargs']`` is left untouched: the criterion instance
    is written into a deep copy, and that copy is what the algorithm is
    built from.
    """
    env_params = variant['env_params']
    env = MultiTaskSawyerXYZReachingEnv(env_params)
    tdm_normalizer = TdmNormalizer(
        env,
        vectorized=True,
        max_tau=variant['ddpg_tdm_kwargs']['tdm_kwargs']['max_tau'],
    )
    # Both networks use two hidden layers of the same configured width.
    hidden_sizes = [variant['hidden_sizes'], variant['hidden_sizes']]
    qf = TdmQf(
        env=env,
        vectorized=True,
        hidden_sizes=hidden_sizes,
        structure='norm_difference',
        tdm_normalizer=tdm_normalizer,
    )
    policy = TdmPolicy(
        env=env,
        hidden_sizes=list(hidden_sizes),
        tdm_normalizer=tdm_normalizer,
    )
    es = OUStrategy(
        action_space=env.action_space,
        **variant['es_kwargs']
    )
    exploration_policy = PolicyWrappedWithExplorationStrategy(
        exploration_strategy=es,
        policy=policy,
    )
    replay_buffer = HerReplayBuffer(
        env=env,
        **variant['her_replay_buffer_kwargs']
    )
    qf_criterion = variant['qf_criterion_class']()
    ddpg_tdm_kwargs = copy.deepcopy(variant['ddpg_tdm_kwargs'])
    ddpg_tdm_kwargs['ddpg_kwargs']['qf_criterion'] = qf_criterion
    # Pass the copy that carries qf_criterion. Unpacking
    # variant['ddpg_tdm_kwargs'] here would silently drop the criterion
    # that was just configured.
    algorithm = TdmDdpg(
        env,
        qf=qf,
        replay_buffer=replay_buffer,
        policy=policy,
        exploration_policy=exploration_policy,
        **ddpg_tdm_kwargs
    )
    if ptu.gpu_enabled():
        algorithm.cuda()
    algorithm.train()
def experiment(variant):
    """Build a ``TdmSac`` algorithm with a ``TdmNormalizer`` and train it.

    Keys read from ``variant``: ``sac_tdm_kwargs`` (including
    ``tdm_kwargs['vectorized']`` and ``tdm_kwargs['max_tau']``),
    ``env_class``, ``env_kwargs``, ``tdm_normalizer_kwargs``, ``qf_kwargs``,
    ``vf_kwargs``, ``policy_kwargs`` and ``her_replay_buffer_kwargs``.
    """
    tdm_settings = variant['sac_tdm_kwargs']['tdm_kwargs']
    use_vectorized = tdm_settings['vectorized']
    environment = NormalizedBoxEnv(
        variant['env_class'](**variant['env_kwargs'])
    )
    horizon = variant['sac_tdm_kwargs']['tdm_kwargs']['max_tau']

    normalizer = TdmNormalizer(
        environment,
        use_vectorized,
        max_tau=horizon,
        **variant['tdm_normalizer_kwargs']
    )
    # The same normalizer instance is shared by all three networks.
    critic = TdmQf(
        env=environment,
        vectorized=use_vectorized,
        tdm_normalizer=normalizer,
        **variant['qf_kwargs']
    )
    value_fn = TdmVf(
        env=environment,
        vectorized=use_vectorized,
        tdm_normalizer=normalizer,
        **variant['vf_kwargs']
    )
    actor = StochasticTdmPolicy(
        env=environment,
        tdm_normalizer=normalizer,
        **variant['policy_kwargs']
    )
    her_buffer = HerReplayBuffer(
        env=environment,
        **variant['her_replay_buffer_kwargs']
    )

    trainer = TdmSac(
        env=environment,
        policy=actor,
        qf=critic,
        vf=value_fn,
        replay_buffer=her_buffer,
        **variant['sac_tdm_kwargs']
    )
    trainer.to(ptu.device)
    trainer.train()
def experiment(variant):
    """Build a ``TdmDdpg`` algorithm with a ``TdmNormalizer`` and train it.

    Keys read from ``variant``: ``vectorized``, ``norm_order``,
    ``ddpg_tdm_kwargs`` (including ``tdm_kwargs['max_tau']`` and a
    ``ddpg_kwargs`` dict), ``env_class``, ``env_kwargs``,
    ``tdm_normalizer_kwargs``, ``qf_kwargs``, ``policy_kwargs``,
    ``es_kwargs``, ``her_replay_buffer_kwargs``, ``qf_criterion_class`` and
    ``qf_criterion_kwargs``.

    ``variant['ddpg_tdm_kwargs']`` is modified in place: ``vectorized``,
    ``norm_order`` and ``tdm_normalizer`` are written under ``tdm_kwargs``
    and ``qf_criterion`` under ``ddpg_kwargs``.
    """
    use_vectorized = variant['vectorized']
    order = variant['norm_order']
    # Copy the top-level settings into the algorithm's TDM kwargs.
    variant['ddpg_tdm_kwargs']['tdm_kwargs']['vectorized'] = use_vectorized
    variant['ddpg_tdm_kwargs']['tdm_kwargs']['norm_order'] = order

    environment = NormalizedBoxEnv(
        variant['env_class'](**variant['env_kwargs'])
    )
    horizon = variant['ddpg_tdm_kwargs']['tdm_kwargs']['max_tau']
    normalizer = TdmNormalizer(
        environment,
        use_vectorized,
        max_tau=horizon,
        **variant['tdm_normalizer_kwargs']
    )

    critic = TdmQf(
        env=environment,
        vectorized=use_vectorized,
        norm_order=order,
        tdm_normalizer=normalizer,
        **variant['qf_kwargs']
    )
    actor = TdmPolicy(
        env=environment,
        tdm_normalizer=normalizer,
        **variant['policy_kwargs']
    )
    ou_noise = OUStrategy(
        action_space=environment.action_space,
        **variant['es_kwargs']
    )
    behaviour_policy = PolicyWrappedWithExplorationStrategy(
        exploration_strategy=ou_noise,
        policy=actor,
    )
    her_buffer = HerReplayBuffer(
        env=environment,
        **variant['her_replay_buffer_kwargs']
    )

    criterion = variant['qf_criterion_class'](
        **variant['qf_criterion_kwargs']
    )
    settings = variant['ddpg_tdm_kwargs']
    settings['ddpg_kwargs']['qf_criterion'] = criterion
    settings['tdm_kwargs']['tdm_normalizer'] = normalizer

    trainer = TdmDdpg(
        environment,
        qf=critic,
        replay_buffer=her_buffer,
        policy=actor,
        exploration_policy=behaviour_policy,
        **settings
    )
    trainer.to(ptu.device)
    trainer.train()
def experiment(variant):
    """Train ``TdmDdpg`` on ``MultiTaskSawyerXYZReachingEnv``.

    Keys read from ``variant``: ``env_params`` (unpacked as keyword
    arguments for the environment), ``ddpg_tdm_kwargs`` (including
    ``tdm_kwargs['max_tau']`` and a ``ddpg_kwargs`` dict), ``qf_kwargs``,
    ``policy_kwargs``, ``es_kwargs``, ``her_replay_buffer_kwargs`` and
    ``qf_criterion_class``.

    ``variant['ddpg_tdm_kwargs']`` is modified in place (``qf_criterion``
    under ``ddpg_kwargs`` and ``tdm_normalizer`` under ``tdm_kwargs``).
    """
    environment = MultiTaskSawyerXYZReachingEnv(**variant['env_params'])
    horizon = variant['ddpg_tdm_kwargs']['tdm_kwargs']['max_tau']
    normalizer = TdmNormalizer(
        environment,
        vectorized=True,
        max_tau=horizon,
    )

    critic = TdmQf(
        env=environment,
        vectorized=True,
        norm_order=2,
        tdm_normalizer=normalizer,
        **variant['qf_kwargs']
    )
    actor = TdmPolicy(
        env=environment,
        tdm_normalizer=normalizer,
        **variant['policy_kwargs']
    )
    ou_noise = OUStrategy(
        action_space=environment.action_space,
        **variant['es_kwargs']
    )
    behaviour_policy = PolicyWrappedWithExplorationStrategy(
        exploration_strategy=ou_noise,
        policy=actor,
    )
    her_buffer = HerReplayBuffer(
        env=environment,
        **variant['her_replay_buffer_kwargs']
    )

    criterion = variant['qf_criterion_class']()
    settings = variant['ddpg_tdm_kwargs']
    settings['ddpg_kwargs']['qf_criterion'] = criterion
    settings['tdm_kwargs']['tdm_normalizer'] = normalizer

    trainer = TdmDdpg(
        environment,
        qf=critic,
        replay_buffer=her_buffer,
        policy=actor,
        exploration_policy=behaviour_policy,
        **settings
    )
    if ptu.gpu_enabled():
        trainer.cuda()
    trainer.train()
def experiment(variant):
    """Train ``TdmSac``, optionally exploring/evaluating with MPC controllers.

    Keys read from ``variant``: ``sac_tdm_kwargs`` (including
    ``tdm_kwargs['vectorized']`` and ``tdm_kwargs['max_tau']``),
    ``env_class``, ``env_kwargs``, ``qf_kwargs``, ``tdm_normalizer_kwargs``,
    ``vf_kwargs``, ``policy_kwargs``, ``her_replay_buffer_kwargs``,
    ``mpc_controller_kwargs``, ``state_only_mpc_controller_kwargs``,
    ``es_kwargs``, ``explore_with`` and ``eval_with``.

    When ``explore_with`` / ``eval_with`` equals ``'TdmLBfgsBCMC'`` or
    ``'TdmLBfgsBStateOnlyCMC'``, the matching controller is written into
    ``variant['sac_tdm_kwargs']['base_kwargs']`` as ``exploration_policy``
    (wrapped with a ``GaussianStrategy``) or ``eval_policy`` respectively.
    Any other value leaves ``base_kwargs`` unchanged.
    """
    use_vectorized = variant['sac_tdm_kwargs']['tdm_kwargs']['vectorized']
    environment = NormalizedBoxEnv(
        variant['env_class'](**variant['env_kwargs'])
    )
    horizon = variant['sac_tdm_kwargs']['tdm_kwargs']['max_tau']

    # The Q-function is created before the normalizer and does not get it.
    critic = TdmQf(
        environment,
        vectorized=use_vectorized,
        **variant['qf_kwargs']
    )
    normalizer = TdmNormalizer(
        environment,
        use_vectorized,
        max_tau=horizon,
        **variant['tdm_normalizer_kwargs']
    )
    implicit_model = TdmToImplicitModel(
        environment,
        critic,
        tau=0,
    )
    value_fn = TdmVf(
        env=environment,
        vectorized=use_vectorized,
        tdm_normalizer=normalizer,
        **variant['vf_kwargs']
    )
    actor = StochasticTdmPolicy(
        env=environment,
        tdm_normalizer=normalizer,
        **variant['policy_kwargs']
    )
    her_buffer = HerReplayBuffer(
        env=environment,
        **variant['her_replay_buffer_kwargs']
    )

    goal_slice = environment.ob_to_goal_slice
    model_controller = TdmLBfgsBCMC(
        implicit_model,
        environment,
        goal_slice=goal_slice,
        multitask_goal_slice=goal_slice,
        **variant['mpc_controller_kwargs']
    )
    state_controller = TdmLBfgsBStateOnlyCMC(
        value_fn,
        actor,
        environment,
        goal_slice=goal_slice,
        multitask_goal_slice=goal_slice,
        **variant['state_only_mpc_controller_kwargs']
    )
    noise = GaussianStrategy(
        action_space=environment.action_space,
        **variant['es_kwargs']
    )

    # Controller names, checked in this order for both selections below.
    named_controllers = (
        ('TdmLBfgsBCMC', model_controller),
        ('TdmLBfgsBStateOnlyCMC', state_controller),
    )

    for label, controller in named_controllers:
        if variant['explore_with'] == label:
            wrapped = PolicyWrappedWithExplorationStrategy(
                exploration_strategy=noise,
                policy=controller,
            )
            variant['sac_tdm_kwargs']['base_kwargs'][
                'exploration_policy'] = wrapped
            break

    for label, controller in named_controllers:
        if variant['eval_with'] == label:
            variant['sac_tdm_kwargs']['base_kwargs'][
                'eval_policy'] = controller
            break

    trainer = TdmSac(
        env=environment,
        policy=actor,
        qf=critic,
        vf=value_fn,
        replay_buffer=her_buffer,
        **variant['sac_tdm_kwargs']
    )
    trainer.to(ptu.device)
    trainer.train()
def tdm_td3_experiment(variant):
    """Run ``TdmTd3``, optionally from a checkpoint and with a subgoal planner.

    The networks are either loaded with ``joblib`` (when ``variant`` has a
    ``'ckpt'`` entry, optionally narrowed by ``'ckpt_epoch'``) or freshly
    constructed from ``qf_kwargs`` / ``policy_kwargs``. When
    ``variant['eval_policy'] == 'SubgoalPlanner'`` a ``SubgoalPlanner`` is
    used as the evaluation policy; otherwise ``eval_policy`` is ``None``.

    ``variant`` is modified in place: ``vectorized`` is written into
    ``algo_kwargs['tdm_kwargs']`` and ``replay_buffer_kwargs``; the replay
    buffer, training env, render flags and observation/goal keys are written
    into ``algo_kwargs``; ``qf_kwargs`` receives ``vectorized``,
    ``norm_order`` and possibly ``output_dim`` when networks are built; and
    ``algo_kwargs['base_kwargs']['reward_scale']`` is overwritten from the
    loaded policy when a checkpoint is used.
    """
    import railrl.samplers.rollout_functions as rf
    import railrl.torch.pytorch_util as ptu
    from railrl.data_management.obs_dict_replay_buffer import \
        ObsDictRelabelingBuffer
    from railrl.exploration_strategies.base import (
        PolicyWrappedWithExplorationStrategy
    )
    from railrl.state_distance.tdm_networks import TdmQf, TdmPolicy
    from railrl.state_distance.tdm_td3 import TdmTd3
    from railrl.state_distance.subgoal_planner import SubgoalPlanner
    from railrl.misc.asset_loader import local_path_from_s3_or_local_path
    import joblib

    preprocess_rl_variant(variant)
    env = get_envs(variant)
    noise = get_exploration_strategy(variant, env)

    observation_key = variant.get('observation_key', 'latent_observation')
    desired_goal_key = variant.get('desired_goal_key', 'latent_desired_goal')
    achieved_goal_key = desired_goal_key.replace("desired", "achieved")

    use_vectorized = 'vectorized' in env.reward_type
    variant['algo_kwargs']['tdm_kwargs']['vectorized'] = use_vectorized
    variant['replay_buffer_kwargs']['vectorized'] = use_vectorized

    if 'ckpt' in variant:
        # A specific epoch selects itr_<epoch>.pkl; otherwise params.pkl.
        if 'ckpt_epoch' in variant:
            epoch = variant['ckpt_epoch']
            snapshot_name = 'itr_%d.pkl' % epoch
        else:
            snapshot_name = 'params.pkl'
        filename = local_path_from_s3_or_local_path(
            osp.join(variant['ckpt'], snapshot_name)
        )
        print("Loading ckpt from", filename)
        snapshot = joblib.load(filename)
        critic_a = snapshot['qf1']
        critic_b = snapshot['qf2']
        actor = snapshot['policy']
        variant['algo_kwargs']['base_kwargs']['reward_scale'] = (
            actor.reward_scale
        )
    else:
        obs_dim = env.observation_space.spaces[observation_key].low.size
        goal_dim = env.observation_space.spaces[desired_goal_key].low.size
        action_dim = env.action_space.low.size

        variant['qf_kwargs']['vectorized'] = use_vectorized
        order = env.norm_order
        variant['qf_kwargs']['norm_order'] = order

        # Take one random step; a reward with a length sets output_dim.
        env.reset()
        _, sample_reward, _, _ = env.step(env.action_space.sample())
        if hasattr(sample_reward, "__len__"):
            variant['qf_kwargs']['output_dim'] = len(sample_reward)

        critic_a = TdmQf(
            env=env,
            observation_dim=obs_dim,
            goal_dim=goal_dim,
            action_dim=action_dim,
            **variant['qf_kwargs']
        )
        critic_b = TdmQf(
            env=env,
            observation_dim=obs_dim,
            goal_dim=goal_dim,
            action_dim=action_dim,
            **variant['qf_kwargs']
        )
        actor = TdmPolicy(
            env=env,
            observation_dim=obs_dim,
            goal_dim=goal_dim,
            action_dim=action_dim,
            reward_scale=variant['algo_kwargs']['base_kwargs'].get(
                'reward_scale', 1.0),
            **variant['policy_kwargs']
        )

    eval_policy = None
    if variant.get('eval_policy', None) == 'SubgoalPlanner':
        eval_policy = SubgoalPlanner(
            env,
            critic_a,
            actor,
            observation_key=observation_key,
            desired_goal_key=desired_goal_key,
            achieved_goal_key=achieved_goal_key,
            state_based=variant.get("do_state_exp", False),
            max_tau=variant['algo_kwargs']['tdm_kwargs']['max_tau'],
            reward_scale=variant['algo_kwargs']['base_kwargs'].get(
                'reward_scale', 1.0),
            **variant['SubgoalPlanner_kwargs']
        )

    behaviour_policy = PolicyWrappedWithExplorationStrategy(
        exploration_strategy=noise,
        policy=actor,
    )
    relabeling_buffer = ObsDictRelabelingBuffer(
        env=env,
        observation_key=observation_key,
        desired_goal_key=desired_goal_key,
        achieved_goal_key=achieved_goal_key,
        **variant['replay_buffer_kwargs']
    )

    settings = variant['algo_kwargs']
    settings['replay_buffer'] = relabeling_buffer
    base_settings = settings['base_kwargs']
    base_settings['training_env'] = env
    base_settings['render'] = variant.get("render", False)
    base_settings['render_during_eval'] = variant.get(
        "render_during_eval", False)
    tdm_settings = settings['tdm_kwargs']
    tdm_settings['observation_key'] = observation_key
    tdm_settings['desired_goal_key'] = desired_goal_key

    algorithm = TdmTd3(
        env,
        qf1=critic_a,
        qf2=critic_b,
        policy=actor,
        exploration_policy=behaviour_policy,
        eval_policy=eval_policy,
        **settings
    )

    if variant.get("test_ckpt", False):
        algorithm.post_epoch_funcs.append(get_update_networks_func(variant))

    vis_settings = variant.get('vis_kwargs', {})
    vis_list = vis_settings.get('vis_list', [])
    if vis_settings.get("save_video", True):
        rollout_function = rf.create_rollout_function(
            rf.tdm_rollout,
            init_tau=algorithm._sample_max_tau_for_rollout(),
            decrement_tau=algorithm.cycle_taus_for_rollout,
            cycle_tau=algorithm.cycle_taus_for_rollout,
            max_path_length=algorithm.max_path_length,
            observation_key=algorithm.observation_key,
            desired_goal_key=algorithm.desired_goal_key,
            vis_list=vis_list,
            dont_terminate=True,
        )
        video_func = get_video_save_func(
            rollout_function,
            env,
            variant,
        )
        algorithm.post_epoch_funcs.append(video_func)

    # "do_state_exp" gates every VAE-specific step below.
    is_state_experiment = variant.get("do_state_exp", False)

    if ptu.gpu_enabled():
        print("using GPU")
        algorithm.cuda()
        if not is_state_experiment:
            env.vae.cuda()

    env.reset()
    if not is_state_experiment:
        env.dump_samples(epoch=None)
        env.dump_reconstructions(epoch=None)
        env.dump_latent_plots(epoch=None)

    algorithm.train()
def tdm_twin_sac_experiment(variant):
    """Build a ``TdmTwinSAC`` algorithm from ``variant`` and train it.

    Keys read from ``variant``: ``observation_key`` and ``desired_goal_key``
    (optional, defaulting to the latent keys), ``algo_kwargs`` (with
    ``base_kwargs`` and ``tdm_kwargs``), ``qf_kwargs``, ``vf_kwargs``,
    ``policy_kwargs``, ``replay_buffer_kwargs``, ``render``, and the
    optional ``save_video`` and ``do_state_exp`` flags.

    ``variant`` is modified in place: ``vectorized`` / ``norm_order`` taken
    from the environment are written into the network and algorithm kwargs,
    and the replay buffer, training env, render flags and observation/goal
    keys are written into ``algo_kwargs``.
    """
    import railrl.samplers.rollout_functions as rf
    import railrl.torch.pytorch_util as ptu
    from railrl.data_management.obs_dict_replay_buffer import \
        ObsDictRelabelingBuffer
    from railrl.state_distance.tdm_networks import (
        TdmQf,
        TdmVf,
        StochasticTdmPolicy,
    )
    from railrl.state_distance.tdm_twin_sac import TdmTwinSAC

    preprocess_rl_variant(variant)
    env = get_envs(variant)

    observation_key = variant.get('observation_key', 'latent_observation')
    desired_goal_key = variant.get('desired_goal_key', 'latent_desired_goal')
    achieved_goal_key = desired_goal_key.replace("desired", "achieved")

    obs_dim = env.observation_space.spaces[observation_key].low.size
    goal_dim = env.observation_space.spaces[desired_goal_key].low.size
    action_dim = env.action_space.low.size

    use_vectorized = 'vectorized' in env.reward_type
    order = env.norm_order
    variant['algo_kwargs']['tdm_kwargs']['vectorized'] = use_vectorized
    variant['qf_kwargs']['vectorized'] = use_vectorized
    variant['vf_kwargs']['vectorized'] = use_vectorized
    variant['qf_kwargs']['norm_order'] = order
    variant['vf_kwargs']['norm_order'] = order

    critic_a = TdmQf(
        env=env,
        observation_dim=obs_dim,
        goal_dim=goal_dim,
        action_dim=action_dim,
        **variant['qf_kwargs']
    )
    critic_b = TdmQf(
        env=env,
        observation_dim=obs_dim,
        goal_dim=goal_dim,
        action_dim=action_dim,
        **variant['qf_kwargs']
    )
    value_fn = TdmVf(
        env=env,
        observation_dim=obs_dim,
        goal_dim=goal_dim,
        **variant['vf_kwargs']
    )
    actor = StochasticTdmPolicy(
        env=env,
        observation_dim=obs_dim,
        goal_dim=goal_dim,
        action_dim=action_dim,
        **variant['policy_kwargs']
    )

    variant['replay_buffer_kwargs']['vectorized'] = use_vectorized
    relabeling_buffer = ObsDictRelabelingBuffer(
        env=env,
        observation_key=observation_key,
        desired_goal_key=desired_goal_key,
        achieved_goal_key=achieved_goal_key,
        **variant['replay_buffer_kwargs']
    )

    settings = variant['algo_kwargs']
    settings['replay_buffer'] = relabeling_buffer
    base_settings = settings['base_kwargs']
    base_settings['training_env'] = env
    # Both render flags come from the single "render" entry.
    base_settings['render'] = variant["render"]
    base_settings['render_during_eval'] = variant["render"]
    tdm_settings = settings['tdm_kwargs']
    tdm_settings['observation_key'] = observation_key
    tdm_settings['desired_goal_key'] = desired_goal_key

    algorithm = TdmTwinSAC(
        env,
        qf1=critic_a,
        qf2=critic_b,
        vf=value_fn,
        policy=actor,
        **settings
    )

    if variant.get("save_video", True):
        rollout_function = rf.create_rollout_function(
            rf.tdm_rollout,
            init_tau=algorithm._sample_max_tau_for_rollout(),
            decrement_tau=algorithm.cycle_taus_for_rollout,
            cycle_tau=algorithm.cycle_taus_for_rollout,
            max_path_length=algorithm.max_path_length,
            observation_key=algorithm.observation_key,
            desired_goal_key=algorithm.desired_goal_key,
        )
        video_func = get_video_save_func(
            rollout_function,
            env,
            algorithm.eval_policy,
            variant,
        )
        algorithm.post_train_funcs.append(video_func)

    algorithm.to(ptu.device)
    if not variant.get("do_state_exp", False):
        env.vae.to(ptu.device)

    algorithm.train()
def tdm_td3_experiment_online_vae(variant):
    """Build an ``OnlineVaeTdmTd3`` algorithm from ``variant`` and train it.

    Keys read from ``variant``: ``vae_trainer_kwargs`` (optional; a missing
    or ``None`` value is treated as no extra trainer arguments),
    ``observation_key`` and ``desired_goal_key`` (optional, defaulting to
    the latent keys), ``algo_kwargs`` (with ``tdm_td3_kwargs`` -- holding
    ``td3_kwargs`` and ``tdm_kwargs`` -- and ``online_vae_kwargs``),
    ``qf_kwargs``, ``policy_kwargs``, ``replay_buffer_kwargs``,
    ``vae_train_data``, ``vae_test_data``, ``online_vae_beta``, ``render``,
    and the optional ``save_video`` and ``do_state_exp`` flags.

    ``variant['algo_kwargs']['tdm_td3_kwargs']`` is modified in place: it
    receives ``vectorized``, the training env, the observation/goal keys and
    the replay buffer.

    Raises:
        AssertionError: if ``'vae_training_schedule'`` is a top-level key of
            ``variant``.
    """
    import railrl.samplers.rollout_functions as rf
    import railrl.torch.pytorch_util as ptu
    from railrl.data_management.online_vae_replay_buffer import \
        OnlineVaeRelabelingBuffer
    from railrl.exploration_strategies.base import (
        PolicyWrappedWithExplorationStrategy)
    from railrl.state_distance.tdm_networks import TdmQf, TdmPolicy
    from railrl.torch.vae.vae_trainer import ConvVAETrainer
    from railrl.torch.online_vae.online_vae_tdm_td3 import OnlineVaeTdmTd3

    preprocess_rl_variant(variant)
    env = get_envs(variant)
    es = get_exploration_strategy(variant, env)
    # Fall back to an empty dict: ``**None`` would raise TypeError when the
    # trainer is constructed below.
    vae_trainer_kwargs = variant.get('vae_trainer_kwargs') or {}

    observation_key = variant.get('observation_key', 'latent_observation')
    desired_goal_key = variant.get('desired_goal_key', 'latent_desired_goal')
    achieved_goal_key = desired_goal_key.replace("desired", "achieved")
    obs_dim = env.observation_space.spaces[observation_key].low.size
    goal_dim = env.observation_space.spaces[desired_goal_key].low.size
    action_dim = env.action_space.low.size

    vectorized = 'vectorized' in env.reward_type
    variant['algo_kwargs']['tdm_td3_kwargs']['tdm_kwargs'][
        'vectorized'] = vectorized
    norm_order = env.norm_order

    qf1 = TdmQf(
        env=env,
        vectorized=vectorized,
        norm_order=norm_order,
        observation_dim=obs_dim,
        goal_dim=goal_dim,
        action_dim=action_dim,
        **variant['qf_kwargs']
    )
    qf2 = TdmQf(
        env=env,
        vectorized=vectorized,
        norm_order=norm_order,
        observation_dim=obs_dim,
        goal_dim=goal_dim,
        action_dim=action_dim,
        **variant['qf_kwargs']
    )
    policy = TdmPolicy(
        env=env,
        observation_dim=obs_dim,
        goal_dim=goal_dim,
        action_dim=action_dim,
        **variant['policy_kwargs']
    )
    exploration_policy = PolicyWrappedWithExplorationStrategy(
        exploration_strategy=es,
        policy=policy,
    )

    vae = env.vae
    replay_buffer = OnlineVaeRelabelingBuffer(
        vae=vae,
        env=env,
        observation_key=observation_key,
        desired_goal_key=desired_goal_key,
        achieved_goal_key=achieved_goal_key,
        **variant['replay_buffer_kwargs']
    )

    algo_kwargs = variant['algo_kwargs']['tdm_td3_kwargs']
    td3_kwargs = algo_kwargs['td3_kwargs']
    td3_kwargs['training_env'] = env
    tdm_kwargs = algo_kwargs['tdm_kwargs']
    tdm_kwargs['observation_key'] = observation_key
    tdm_kwargs['desired_goal_key'] = desired_goal_key
    algo_kwargs["replay_buffer"] = replay_buffer

    t = ConvVAETrainer(
        variant['vae_train_data'],
        variant['vae_test_data'],
        vae,
        beta=variant['online_vae_beta'],
        **vae_trainer_kwargs
    )
    # NOTE(review): "render" is read here but not used further in this
    # function; the lookup is kept so a missing key still raises KeyError.
    render = variant["render"]
    assert 'vae_training_schedule' not in variant, "Just put it in algo_kwargs"

    algorithm = OnlineVaeTdmTd3(
        online_vae_kwargs=dict(
            vae=vae,
            vae_trainer=t,
            **variant['algo_kwargs']['online_vae_kwargs']
        ),
        tdm_td3_kwargs=dict(
            env=env,
            qf1=qf1,
            qf2=qf2,
            policy=policy,
            exploration_policy=exploration_policy,
            **variant['algo_kwargs']['tdm_td3_kwargs']
        ),
    )
    algorithm.to(ptu.device)
    vae.to(ptu.device)

    if variant.get("save_video", True):
        policy.train(False)
        rollout_function = rf.create_rollout_function(
            rf.tdm_rollout,
            init_tau=algorithm._sample_max_tau_for_rollout(),
            decrement_tau=algorithm.cycle_taus_for_rollout,
            cycle_tau=algorithm.cycle_taus_for_rollout,
            max_path_length=algorithm.max_path_length,
            observation_key=algorithm.observation_key,
            desired_goal_key=algorithm.desired_goal_key,
        )
        video_func = get_video_save_func(
            rollout_function,
            env,
            algorithm.eval_policy,
            variant,
        )
        algorithm.post_train_funcs.append(video_func)

    # Device placement is repeated here, matching the original flow.
    algorithm.to(ptu.device)
    if not variant.get("do_state_exp", False):
        env.vae.to(ptu.device)

    algorithm.train()
def tdm_td3_experiment(variant):
    """Build a ``TdmTd3`` algorithm from ``variant`` and train it.

    Keys read from ``variant``: ``observation_key`` and ``desired_goal_key``
    (optional, defaulting to the latent keys), ``algo_kwargs`` (with
    ``base_kwargs`` and ``tdm_kwargs``), ``qf_kwargs``, ``policy_kwargs``,
    ``replay_buffer_kwargs``, ``render``, and the optional ``save_video``
    and ``do_state_exp`` flags.

    ``variant`` is modified in place: ``vectorized`` / ``norm_order`` taken
    from the environment are written into the network, buffer and algorithm
    kwargs, and the replay buffer, training env, render flags and
    observation/goal keys are written into ``algo_kwargs``.
    """
    import railrl.samplers.rollout_functions as rf
    import railrl.torch.pytorch_util as ptu
    from railrl.core import logger
    from railrl.data_management.obs_dict_replay_buffer import \
        ObsDictRelabelingBuffer
    from railrl.exploration_strategies.base import (
        PolicyWrappedWithExplorationStrategy)
    from railrl.state_distance.tdm_networks import TdmQf, TdmPolicy
    from railrl.state_distance.tdm_td3 import TdmTd3

    preprocess_rl_variant(variant)
    env = get_envs(variant)
    noise = get_exploration_strategy(variant, env)

    observation_key = variant.get('observation_key', 'latent_observation')
    desired_goal_key = variant.get('desired_goal_key', 'latent_desired_goal')
    achieved_goal_key = desired_goal_key.replace("desired", "achieved")

    obs_dim = env.observation_space.spaces[observation_key].low.size
    goal_dim = env.observation_space.spaces[desired_goal_key].low.size
    action_dim = env.action_space.low.size

    use_vectorized = 'vectorized' in env.reward_type
    order = env.norm_order
    variant['algo_kwargs']['tdm_kwargs']['vectorized'] = use_vectorized
    variant['qf_kwargs']['vectorized'] = use_vectorized
    variant['qf_kwargs']['norm_order'] = order

    critic_a = TdmQf(
        env=env,
        observation_dim=obs_dim,
        goal_dim=goal_dim,
        action_dim=action_dim,
        **variant['qf_kwargs']
    )
    critic_b = TdmQf(
        env=env,
        observation_dim=obs_dim,
        goal_dim=goal_dim,
        action_dim=action_dim,
        **variant['qf_kwargs']
    )
    actor = TdmPolicy(
        env=env,
        observation_dim=obs_dim,
        goal_dim=goal_dim,
        action_dim=action_dim,
        **variant['policy_kwargs']
    )
    behaviour_policy = PolicyWrappedWithExplorationStrategy(
        exploration_strategy=noise,
        policy=actor,
    )

    variant['replay_buffer_kwargs']['vectorized'] = use_vectorized
    relabeling_buffer = ObsDictRelabelingBuffer(
        env=env,
        observation_key=observation_key,
        desired_goal_key=desired_goal_key,
        achieved_goal_key=achieved_goal_key,
        **variant['replay_buffer_kwargs']
    )

    settings = variant['algo_kwargs']
    settings['replay_buffer'] = relabeling_buffer
    base_settings = settings['base_kwargs']
    base_settings['training_env'] = env
    # Both render flags come from the single "render" entry.
    base_settings['render'] = variant["render"]
    base_settings['render_during_eval'] = variant["render"]
    tdm_settings = settings['tdm_kwargs']
    tdm_settings['observation_key'] = observation_key
    tdm_settings['desired_goal_key'] = desired_goal_key

    algorithm = TdmTd3(
        env,
        qf1=critic_a,
        qf2=critic_b,
        policy=actor,
        exploration_policy=behaviour_policy,
        **settings
    )

    algorithm.to(ptu.device)
    if not variant.get("do_state_exp", False):
        env.vae.to(ptu.device)

    if variant.get("save_video", True):
        # NOTE(review): the snapshot dir is fetched but not used below.
        logdir = logger.get_snapshot_dir()
        policy_in_eval_mode = actor
        policy_in_eval_mode.train(False)
        rollout_function = rf.create_rollout_function(
            rf.tdm_rollout,
            init_tau=algorithm.max_tau,
            max_path_length=algorithm.max_path_length,
            observation_key=algorithm.observation_key,
            desired_goal_key=algorithm.desired_goal_key,
        )
        video_func = get_video_save_func(
            rollout_function,
            env,
            policy_in_eval_mode,
            variant,
        )
        algorithm.post_train_funcs.append(video_func)

    algorithm.train()
def tdm_td3_experiment(variant):
    """Build a ``TdmTd3`` algorithm for a (multiworld) env and train it.

    Keys read from ``variant``: ``env_kwargs``, ``reward_params``,
    ``env_class``, ``multiworld_env`` (optional, default ``True``),
    ``render``, ``normalize``, ``exploration_type`` (``'ou'``,
    ``'gaussian'`` or ``'epsilon'``), ``es_kwargs``, ``algo_kwargs`` (with
    ``tdm_kwargs`` and ``td3_kwargs``), ``qf_kwargs``, ``policy_kwargs``,
    ``replay_buffer_kwargs``, and the optional ``observation_key``,
    ``desired_goal_key``, ``achieved_goal_key`` and ``tau_schedule_kwargs``.

    ``variant`` is modified in place: ``reward_params`` is merged into
    ``env_kwargs``, and ``algo_kwargs`` receives ``vectorized``,
    ``norm_order``, the observation/goal keys (multiworld case), the
    training env and ``epoch_max_tau_schedule``.

    Raises:
        Exception: if ``variant['exploration_type']`` is not recognised.
    """
    variant['env_kwargs'].update(variant['reward_params'])
    env = variant['env_class'](**variant['env_kwargs'])

    # Only the literal ``True`` selects the multiworld code path.
    multiworld_env = variant.get('multiworld_env', True)
    is_multiworld = multiworld_env is True
    if not is_multiworld:
        env = MultitaskEnvToSilentMultitaskEnv(env)
        if variant["render"]:
            env.pause_on_goal = True

    if variant['normalize']:
        env = NormalizedBoxEnv(env)

    strategy_name = variant['exploration_type']
    if strategy_name == 'ou':
        noise = OUStrategy(
            action_space=env.action_space,
            max_sigma=0.1,
            **variant['es_kwargs']
        )
    elif strategy_name == 'gaussian':
        noise = GaussianStrategy(
            action_space=env.action_space,
            max_sigma=0.1,
            min_sigma=0.1,  # equal bounds keep sigma constant
            **variant['es_kwargs']
        )
    elif strategy_name == 'epsilon':
        noise = EpsilonGreedy(
            action_space=env.action_space,
            prob_random_action=0.1,
            **variant['es_kwargs']
        )
    else:
        raise Exception("Invalid type: " + strategy_name)

    if is_multiworld:
        obs_dim = env.observation_space.spaces['observation'].low.size
        action_dim = env.action_space.low.size
        goal_dim = env.observation_space.spaces['desired_goal'].low.size
    else:
        # No explicit dimensions are passed to the networks in this case.
        obs_dim = None
        action_dim = None
        goal_dim = None

    use_vectorized = 'vectorized' in env.reward_type
    variant['algo_kwargs']['tdm_kwargs']['vectorized'] = use_vectorized
    order = env.norm_order
    variant['algo_kwargs']['tdm_kwargs']['norm_order'] = order

    critic_a = TdmQf(
        env=env,
        observation_dim=obs_dim,
        action_dim=action_dim,
        goal_dim=goal_dim,
        vectorized=use_vectorized,
        norm_order=order,
        **variant['qf_kwargs']
    )
    critic_b = TdmQf(
        env=env,
        observation_dim=obs_dim,
        action_dim=action_dim,
        goal_dim=goal_dim,
        vectorized=use_vectorized,
        norm_order=order,
        **variant['qf_kwargs']
    )
    actor = TdmPolicy(
        env=env,
        observation_dim=obs_dim,
        action_dim=action_dim,
        goal_dim=goal_dim,
        **variant['policy_kwargs']
    )
    behaviour_policy = PolicyWrappedWithExplorationStrategy(
        exploration_strategy=noise,
        policy=actor,
    )

    # The replay buffer gets its own copy of the env via a pickle round trip.
    relabeling_env = pickle.loads(pickle.dumps(env))

    settings = variant['algo_kwargs']
    if is_multiworld:
        observation_key = variant.get('observation_key', 'state_observation')
        desired_goal_key = variant.get(
            'desired_goal_key', 'state_desired_goal')
        achieved_goal_key = variant.get(
            'achieved_goal_key', 'state_achieved_goal')
        buffer = ObsDictRelabelingBuffer(
            env=relabeling_env,
            observation_key=observation_key,
            desired_goal_key=desired_goal_key,
            achieved_goal_key=achieved_goal_key,
            vectorized=use_vectorized,
            **variant['replay_buffer_kwargs']
        )
        settings['tdm_kwargs']['observation_key'] = observation_key
        settings['tdm_kwargs']['desired_goal_key'] = desired_goal_key
    else:
        buffer = RelabelingReplayBuffer(
            env=relabeling_env,
            **variant['replay_buffer_kwargs']
        )

    settings['td3_kwargs']['training_env'] = env

    tau_schedule = (
        IntPiecewiseLinearSchedule(**variant['tau_schedule_kwargs'])
        if 'tau_schedule_kwargs' in variant
        else None
    )
    settings['tdm_kwargs']['epoch_max_tau_schedule'] = tau_schedule

    algorithm = TdmTd3(
        env,
        qf1=critic_a,
        qf2=critic_b,
        policy=actor,
        exploration_policy=behaviour_policy,
        replay_buffer=buffer,
        **settings
    )

    if ptu.gpu_enabled():
        for module in (critic_a, critic_b, actor, algorithm):
            module.to(ptu.device)

    algorithm.train()