def experiment(variant):
    """Train TD3 with Gaussian exploration on a normalized environment.

    Args:
        variant: Configuration mapping. Reads ``'env_class'`` (called with no
            arguments to build the env), ``'es_kwargs'``, ``'qf_kwargs'``,
            ``'policy_kwargs'`` and ``'algo_kwargs'``.
    """
    env = NormalizedBoxEnv(variant['env_class']())
    strategy = GaussianStrategy(
        action_space=env.action_space,
        **variant['es_kwargs']
    )
    o_dim = env.observation_space.low.size
    a_dim = env.action_space.low.size
    # Each critic scores an (observation, action) pair.
    critic_in = o_dim + a_dim

    qf1 = ConcatMlp(input_size=critic_in, output_size=1, **variant['qf_kwargs'])
    qf2 = ConcatMlp(input_size=critic_in, output_size=1, **variant['qf_kwargs'])
    actor = TanhMlpPolicy(
        input_size=o_dim,
        output_size=a_dim,
        **variant['policy_kwargs']
    )
    noisy_actor = PolicyWrappedWithExplorationStrategy(
        exploration_strategy=strategy,
        policy=actor,
    )
    td3 = TD3(
        env,
        qf1=qf1,
        qf2=qf2,
        policy=actor,
        exploration_policy=noisy_actor,
        **variant['algo_kwargs']
    )
    td3.to(ptu.device)
    td3.train()
def experiment(variant):
    """Train goal-conditioned SAC (``HerSac``) on ``SawyerXYZEnv``.

    Args:
        variant: Configuration mapping. Reads ``'env_kwargs'``,
            ``'normalize'``, ``'qf_kwargs'``, ``'vf_kwargs'``,
            ``'policy_kwargs'``, ``'replay_buffer_kwargs'`` and
            ``'algo_kwargs'``.
    """
    env = SawyerXYZEnv(**variant['env_kwargs'])
    if variant['normalize']:
        env = NormalizedBoxEnv(env)

    ob_size = env.observation_space.low.size
    act_size = env.action_space.low.size
    goal_size = env.goal_dim
    # Every network is conditioned on the goal as well as the observation.
    conditioned = ob_size + goal_size

    qf = ConcatMlp(
        input_size=conditioned + act_size,
        output_size=1,
        **variant['qf_kwargs']
    )
    vf = ConcatMlp(
        input_size=conditioned,
        output_size=1,
        **variant['vf_kwargs']
    )
    policy = TanhGaussianPolicy(
        obs_dim=conditioned,
        action_dim=act_size,
        **variant['policy_kwargs']
    )
    buffer = SimpleHerReplayBuffer(env=env, **variant['replay_buffer_kwargs'])
    her_sac = HerSac(
        env=env,
        policy=policy,
        qf=qf,
        vf=vf,
        replay_buffer=buffer,
        **variant['algo_kwargs']
    )
    her_sac.to(ptu.device)
    her_sac.train()
def experiment(variant):
    """Train SoftActorCritic on ``SawyerXYZEnv`` flattened by ``MultitaskToFlatEnv``.

    Args:
        variant: Configuration mapping. Reads ``'env_kwargs'``,
            ``'normalize'``, ``'qf_kwargs'``, ``'vf_kwargs'``,
            ``'policy_kwargs'`` and ``'algo_kwargs'``.
    """
    env = MultitaskToFlatEnv(SawyerXYZEnv(**variant['env_kwargs']))
    if variant['normalize']:
        env = NormalizedBoxEnv(env)

    n_obs = env.observation_space.low.size
    n_act = env.action_space.low.size

    qf = ConcatMlp(input_size=n_obs + n_act, output_size=1, **variant['qf_kwargs'])
    vf = ConcatMlp(input_size=n_obs, output_size=1, **variant['vf_kwargs'])
    policy = TanhGaussianPolicy(
        obs_dim=n_obs,
        action_dim=n_act,
        **variant['policy_kwargs']
    )
    sac = SoftActorCritic(
        env=env,
        policy=policy,
        qf=qf,
        vf=vf,
        **variant['algo_kwargs']
    )
    sac.to(ptu.device)
    sac.train()
def get_td3_trainer(env, hidden_sizes=None, **kwargs):
    """Build a ``TD3Trainer`` with freshly constructed networks for ``env``.

    Args:
        env: Environment whose ``observation_space.low.size`` and
            ``action_space.low.size`` give the network input/output sizes.
        hidden_sizes: Hidden-layer widths used by every Q-function and policy
            network. ``None`` (the default) means ``[256, 256]``.
        **kwargs: Forwarded to ``TD3Trainer`` as trainer hyperparameters.

    Returns:
        The constructed ``TD3Trainer``.
    """
    # A fresh list per call instead of a shared mutable default argument.
    if hidden_sizes is None:
        hidden_sizes = [256, 256]
    obs_dim = env.observation_space.low.size
    action_dim = env.action_space.low.size

    def make_qf():
        return ConcatMlp(
            input_size=obs_dim + action_dim,
            output_size=1,
            hidden_sizes=hidden_sizes,
        )

    def make_policy():
        return TanhMlpPolicy(
            input_size=obs_dim,
            output_size=action_dim,
            hidden_sizes=hidden_sizes,
        )

    # Same construction order as the original: four critics, then two actors.
    qf1 = make_qf()
    qf2 = make_qf()
    target_qf1 = make_qf()
    target_qf2 = make_qf()
    policy = make_policy()
    target_policy = make_policy()
    # hidden_sizes only shapes the networks above. Like the SAC/DDPG trainer
    # factories, network sizes are not passed to the trainer itself; the
    # caller's extra keyword arguments are passed instead.
    return TD3Trainer(
        policy=policy,
        qf1=qf1,
        qf2=qf2,
        target_qf1=target_qf1,
        target_qf2=target_qf2,
        target_policy=target_policy,
        **kwargs
    )
def experiment(variant):
    """Train twin-Q SAC on a flattened, normalized ``Point2DEnv``.

    One environment instance serves both exploration and evaluation.
    Evaluation rolls out ``MakeDeterministic(policy)``; exploration uses the
    stochastic policy.

    Args:
        variant: Configuration mapping. Reads ``'env_kwargs'``,
            ``'qf_kwargs'``, ``'policy_kwargs'``, ``'replay_buffer_size'``,
            ``'trainer_kwargs'`` and ``'algo_kwargs'``.
    """
    env = NormalizedBoxEnv(FlatGoalEnv(Point2DEnv(**variant['env_kwargs'])))
    action_dim = int(np.prod(env.action_space.shape))
    obs_dim = int(np.prod(env.observation_space.shape))

    def make_qf():
        return ConcatMlp(
            input_size=obs_dim + action_dim,
            output_size=1,
            **variant['qf_kwargs']
        )

    qf1 = make_qf()
    qf2 = make_qf()
    target_qf1 = make_qf()
    target_qf2 = make_qf()
    policy = TanhGaussianPolicy(
        obs_dim=obs_dim,
        action_dim=action_dim,
        **variant['policy_kwargs']
    )

    eval_env = expl_env = env
    eval_path_collector = MdpPathCollector(eval_env, MakeDeterministic(policy))
    expl_path_collector = MdpPathCollector(expl_env, policy)
    replay_buffer = EnvReplayBuffer(variant['replay_buffer_size'], expl_env)
    trainer = TwinSACTrainer(
        env=eval_env,
        policy=policy,
        qf1=qf1,
        qf2=qf2,
        target_qf1=target_qf1,
        target_qf2=target_qf2,
        **variant['trainer_kwargs']
    )
    algorithm = TorchBatchRLAlgorithm(
        trainer=trainer,
        exploration_env=expl_env,
        evaluation_env=eval_env,
        exploration_data_collector=expl_path_collector,
        evaluation_data_collector=eval_path_collector,
        # Use the same keyword as every other TorchBatchRLAlgorithm call in
        # this file; the previous ``data_buffer=`` did not match it.
        replay_buffer=replay_buffer,
        **variant['algo_kwargs']
    )
    algorithm.to(ptu.device)
    algorithm.train()
def experiment(variant):
    """Train TD3 on a 2D pusher environment with a configurable exploration strategy.

    ``variant['multitask']`` selects ``CylinderXYPusher2DEnv`` flattened by
    ``MultitaskToFlatEnv``; otherwise ``Pusher2DEnv`` is used.
    ``variant['exploration_type']`` must be ``'ou'``, ``'gaussian'`` or
    ``'epsilon'``; any other value raises ``Exception``. All networks use
    hidden layers of widths 400 and 300.
    """
    if not variant['multitask']:
        env = Pusher2DEnv(**variant['env_kwargs'])
    else:
        env = MultitaskToFlatEnv(CylinderXYPusher2DEnv(**variant['env_kwargs']))
    if variant['normalize']:
        env = NormalizedBoxEnv(env)

    space = env.action_space
    # Builders are lazy so only the selected strategy is constructed.
    builders = {
        'ou': lambda: OUStrategy(action_space=space),
        # Constant sigma.
        'gaussian': lambda: GaussianStrategy(
            action_space=space,
            max_sigma=0.1,
            min_sigma=0.1,
        ),
        'epsilon': lambda: EpsilonGreedy(
            action_space=space,
            prob_random_action=0.1,
        ),
    }
    kind = variant['exploration_type']
    if kind not in builders:
        raise Exception("Invalid type: " + kind)
    es = builders[kind]()

    o_dim = env.observation_space.low.size
    a_dim = env.action_space.low.size
    qf1 = ConcatMlp(input_size=o_dim + a_dim, output_size=1, hidden_sizes=[400, 300])
    qf2 = ConcatMlp(input_size=o_dim + a_dim, output_size=1, hidden_sizes=[400, 300])
    policy = TanhMlpPolicy(input_size=o_dim, output_size=a_dim, hidden_sizes=[400, 300])
    exploration_policy = PolicyWrappedWithExplorationStrategy(
        exploration_strategy=es,
        policy=policy,
    )
    algorithm = TD3(
        env,
        qf1=qf1,
        qf2=qf2,
        policy=policy,
        exploration_policy=exploration_policy,
        **variant['algo_kwargs']
    )
    algorithm.to(ptu.device)
    algorithm.train()
def td3_experiment(variant):
    """Train TD3 on a flattened multitask environment.

    The env is ``variant['env_class'](**variant['env_kwargs'])`` wrapped in
    ``MultitaskToFlatEnv``. It is then wrapped in
    ``MultitaskEnvToSilentMultitaskEnv`` unless ``variant['make_silent_env']``
    is falsy (default True), and optionally normalized.
    ``variant['exploration_type']`` picks ``'ou'``, ``'gaussian'`` or
    ``'epsilon'`` exploration; anything else raises ``Exception``.
    """
    def pick_strategy(kind, action_space):
        if kind == 'ou':
            return OUStrategy(action_space=action_space)
        if kind == 'gaussian':
            # Constant sigma.
            return GaussianStrategy(
                action_space=action_space,
                max_sigma=0.1,
                min_sigma=0.1,
            )
        if kind == 'epsilon':
            return EpsilonGreedy(
                action_space=action_space,
                prob_random_action=0.1,
            )
        raise Exception("Invalid type: " + kind)

    env = MultitaskToFlatEnv(variant['env_class'](**variant['env_kwargs']))
    if variant.get('make_silent_env', True):
        env = MultitaskEnvToSilentMultitaskEnv(env)
    if variant['normalize']:
        env = NormalizedBoxEnv(env)
    es = pick_strategy(variant['exploration_type'], env.action_space)

    obs_size = env.observation_space.low.size
    act_size = env.action_space.low.size
    qf1 = ConcatMlp(input_size=obs_size + act_size, output_size=1, **variant['qf_kwargs'])
    qf2 = ConcatMlp(input_size=obs_size + act_size, output_size=1, **variant['qf_kwargs'])
    policy = TanhMlpPolicy(
        input_size=obs_size,
        output_size=act_size,
        **variant['policy_kwargs']
    )
    noisy_policy = PolicyWrappedWithExplorationStrategy(
        exploration_strategy=es,
        policy=policy,
    )
    algorithm = TD3(
        env,
        qf1=qf1,
        qf2=qf2,
        policy=policy,
        exploration_policy=noisy_policy,
        **variant['algo_kwargs']
    )
    algorithm.to(ptu.device)
    algorithm.train()
def experiment(variant):
    """Train goal-conditioned DQN with hindsight relabeling on ``GoalGridworld-v0``.

    Evaluation uses the greedy ``ArgmaxDiscretePolicy`` over the Q-network;
    exploration wraps that policy in ``EpsilonGreedy``.

    Args:
        variant: Configuration mapping. Reads ``"replay_buffer_kwargs"``,
            ``"trainer_kwargs"`` and ``"algo_kwargs"``.
    """
    expl_env = gym.make("GoalGridworld-v0")
    eval_env = gym.make("GoalGridworld-v0")
    spaces = expl_env.observation_space.spaces
    obs_dim = spaces["observation"].low.size
    goal_dim = spaces["desired_goal"].low.size
    n_actions = expl_env.action_space.n

    def build_q_network():
        # One output per discrete action, given observation and goal.
        return ConcatMlp(
            input_size=obs_dim + goal_dim,
            output_size=n_actions,
            hidden_sizes=[400, 300],
        )

    qf = build_q_network()
    target_qf = build_q_network()
    greedy_policy = ArgmaxDiscretePolicy(qf)
    expl_policy = PolicyWrappedWithExplorationStrategy(
        exploration_strategy=EpsilonGreedy(action_space=expl_env.action_space, ),
        policy=greedy_policy,
    )
    replay_buffer = ObsDictRelabelingBuffer(
        env=eval_env,
        **variant["replay_buffer_kwargs"]
    )

    goal_keys = dict(observation_key="observation", desired_goal_key="desired_goal")
    eval_collector = GoalConditionedPathCollector(eval_env, greedy_policy, **goal_keys)
    expl_collector = GoalConditionedPathCollector(expl_env, expl_policy, **goal_keys)

    trainer = HERTrainer(
        DQNTrainer(qf=qf, target_qf=target_qf, **variant["trainer_kwargs"])
    )
    algorithm = TorchBatchRLAlgorithm(
        trainer=trainer,
        exploration_env=expl_env,
        evaluation_env=eval_env,
        exploration_data_collector=expl_collector,
        evaluation_data_collector=eval_collector,
        replay_buffer=replay_buffer,
        **variant["algo_kwargs"]
    )
    algorithm.to(ptu.device)
    algorithm.train()
def experiment(variant):
    """Train goal-conditioned TD3 (``HerTd3``) with a configurable exploration strategy.

    ``variant['exploration_type']`` must be ``'ou'``, ``'gaussian'`` or
    ``'epsilon'``; any other value raises ``Exception``. The replay buffer is
    ``variant['replay_buffer_class'](env=env, **variant['replay_buffer_kwargs'])``.
    Critics see observation, action and goal; the policy sees observation and goal.
    """
    env = variant['env_class'](**variant['env_kwargs'])
    if variant['normalize']:
        env = NormalizedBoxEnv(env)

    space = env.action_space
    builders = {
        'ou': lambda: OUStrategy(action_space=space),
        # Constant sigma.
        'gaussian': lambda: GaussianStrategy(
            action_space=space,
            max_sigma=0.1,
            min_sigma=0.1,
        ),
        'epsilon': lambda: EpsilonGreedy(
            action_space=space,
            prob_random_action=0.1,
        ),
    }
    choice = variant['exploration_type']
    if choice not in builders:
        raise Exception("Invalid type: " + choice)
    es = builders[choice]()

    obs_dim = env.observation_space.low.size
    action_dim = env.action_space.low.size
    goal_dim = env.goal_dim
    critic_in = obs_dim + action_dim + goal_dim
    qf1 = ConcatMlp(input_size=critic_in, output_size=1, hidden_sizes=[400, 300])
    qf2 = ConcatMlp(input_size=critic_in, output_size=1, hidden_sizes=[400, 300])
    policy = TanhMlpPolicy(
        input_size=obs_dim + goal_dim,
        output_size=action_dim,
        hidden_sizes=[400, 300],
    )
    exploration_policy = PolicyWrappedWithExplorationStrategy(
        exploration_strategy=es,
        policy=policy,
    )
    buffer_cls = variant['replay_buffer_class']
    replay_buffer = buffer_cls(env=env, **variant['replay_buffer_kwargs'])
    algorithm = HerTd3(
        env,
        qf1=qf1,
        qf2=qf2,
        policy=policy,
        exploration_policy=exploration_policy,
        replay_buffer=replay_buffer,
        **variant['algo_kwargs']
    )
    algorithm.to(ptu.device)
    algorithm.train()
def experiment(variant):
    """Train ``ExpectedSAC`` on a normalized ``HalfCheetahEnv``.

    Args:
        variant: Configuration mapping. ``'net_size'`` sets the width of both
            hidden layers in every network. ``'algo_params'`` is forwarded to
            ``ExpectedSAC``.
    """
    env = NormalizedBoxEnv(HalfCheetahEnv())
    obs_dim = int(np.prod(env.observation_space.shape))
    action_dim = int(np.prod(env.action_space.shape))
    width = variant['net_size']

    qf = ConcatMlp(
        hidden_sizes=[width, width],
        input_size=obs_dim + action_dim,
        output_size=1,
    )
    vf = ConcatMlp(
        hidden_sizes=[width, width],
        input_size=obs_dim,
        output_size=1,
    )
    policy = TanhGaussianPolicy(
        hidden_sizes=[width, width],
        obs_dim=obs_dim,
        action_dim=action_dim,
    )
    # TODO(vitchyr): no QFPolicyPlotter is created here because just creating
    # the plotter crashes EC2.
    algorithm = ExpectedSAC(
        env=env,
        policy=policy,
        qf=qf,
        vf=vf,
        **variant['algo_params']
    )
    algorithm.to(ptu.device)
    algorithm.train()
def experiment(variant):
    """Train SoftActorCritic on a normalized ``MultiGoalEnv``.

    Every network uses two hidden layers of width 100.
    ``variant['algo_params']`` is forwarded to ``SoftActorCritic``.
    """
    goal_env = MultiGoalEnv(
        actuation_cost_coeff=10,
        distance_cost_coeff=1,
        goal_reward=10,
    )
    env = NormalizedBoxEnv(goal_env)
    n_obs = int(np.prod(env.observation_space.shape))
    n_act = int(np.prod(env.action_space.shape))

    qf = ConcatMlp(hidden_sizes=[100, 100], input_size=n_obs + n_act, output_size=1)
    vf = ConcatMlp(hidden_sizes=[100, 100], input_size=n_obs, output_size=1)
    policy = TanhGaussianPolicy(hidden_sizes=[100, 100], obs_dim=n_obs, action_dim=n_act)

    # NOTE(review): the plotter is built but not handed to SoftActorCritic;
    # confirm whether it is still needed.
    probe_obs = np.array([[-2.5, 0.0], [0.0, 0.0], [2.5, 2.5]])
    plotter = QFPolicyPlotter(
        qf=qf,
        policy=policy,
        obs_lst=probe_obs,
        default_action=[np.nan, np.nan],
        n_samples=100,
    )
    algorithm = SoftActorCritic(
        env=env,
        policy=policy,
        qf=qf,
        vf=vf,
        **variant['algo_params']
    )
    algorithm.to(ptu.device)
    algorithm.train()
def create_qf():
    """Build an image-based Q-function that encodes state and goal images with a CNN.

    Uses ``img_width``, ``img_height``, ``img_num_channels``, ``cnn_kwargs``,
    ``action_dim`` and ``qf_kwargs`` from the surrounding scope.
    """
    conv = BasicCNN(
        input_width=img_width,
        input_height=img_height,
        input_channels=img_num_channels,
        **cnn_kwargs
    )
    image_encoder = ApplyConvToStateAndGoalImage(conv)
    obs_stage = ApplyToObs(image_encoder)
    flatten_stage = basic.FlattenEachParallel()
    head = ConcatMlp(
        input_size=image_encoder.output_size + action_dim,
        output_size=1,
        **qf_kwargs
    )
    return basic.MultiInputSequential(obs_stage, flatten_stage, head)
def experiment(variant):
    """Train goal-conditioned TD3 (``HerTd3``) on ``CylinderXYPusher2DEnv``.

    Exploration is epsilon-greedy with ``prob_random_action=0.1``. Critics see
    observation, action and goal; the policy sees observation and goal. All
    networks use hidden widths 400 and 300.
    """
    env = CylinderXYPusher2DEnv(**variant['env_kwargs'])
    if variant['normalize']:
        env = NormalizedBoxEnv(env)
    strategy = EpsilonGreedy(
        action_space=env.action_space,
        prob_random_action=0.1,
    )

    o_dim = env.observation_space.low.size
    a_dim = env.action_space.low.size
    g_dim = env.goal_dim
    critic_in = o_dim + a_dim + g_dim
    qf1 = ConcatMlp(input_size=critic_in, output_size=1, hidden_sizes=[400, 300])
    qf2 = ConcatMlp(input_size=critic_in, output_size=1, hidden_sizes=[400, 300])
    actor = TanhMlpPolicy(
        input_size=o_dim + g_dim,
        output_size=a_dim,
        hidden_sizes=[400, 300],
    )
    noisy_actor = PolicyWrappedWithExplorationStrategy(
        exploration_strategy=strategy,
        policy=actor,
    )
    buffer = SimpleHerReplayBuffer(env=env, **variant['replay_buffer_kwargs'])
    algorithm = HerTd3(
        env,
        qf1=qf1,
        qf2=qf2,
        policy=actor,
        exploration_policy=noisy_actor,
        replay_buffer=buffer,
        **variant['algo_kwargs']
    )
    algorithm.to(ptu.device)
    algorithm.train()
def get_sac_trainer(env, hidden_sizes=None, reward_scale=1):
    """Build a ``SACTrainer`` with twin Q-functions and their target networks.

    Args:
        env: Environment whose ``observation_space.low.size`` and
            ``action_space.low.size`` give the network input/output sizes.
            It is also passed to the trainer.
        hidden_sizes: Hidden-layer widths for every network. ``None`` (the
            default) means ``[256, 256]``.
        reward_scale: Forwarded to ``SACTrainer``.

    Returns:
        A ``SACTrainer`` configured with discount 0.99, soft target tau 5e-3,
        target update period 1, policy/Q learning rates 3e-4 and automatic
        entropy tuning enabled.
    """
    # A fresh list per call instead of a shared mutable default argument.
    if hidden_sizes is None:
        hidden_sizes = [256, 256]
    obs_dim = env.observation_space.low.size
    action_dim = env.action_space.low.size

    def make_qf():
        return ConcatMlp(
            input_size=obs_dim + action_dim,
            output_size=1,
            hidden_sizes=hidden_sizes,
        )

    qf1 = make_qf()
    qf2 = make_qf()
    target_qf1 = make_qf()
    target_qf2 = make_qf()
    policy = TanhGaussianPolicy(
        obs_dim=obs_dim,
        action_dim=action_dim,
        hidden_sizes=hidden_sizes,
    )
    trainer = SACTrainer(
        env=env,
        policy=policy,
        qf1=qf1,
        qf2=qf2,
        target_qf1=target_qf1,
        target_qf2=target_qf2,
        discount=0.99,
        soft_target_tau=5e-3,
        target_update_period=1,
        policy_lr=3E-4,
        qf_lr=3E-4,
        reward_scale=reward_scale,
        use_automatic_entropy_tuning=True,
    )
    return trainer
def experiment(variant):
    """Train SoftActorCritic on a normalized ``variant['env_class']()`` environment.

    Reads ``'qf_kwargs'``, ``'vf_kwargs'``, ``'policy_kwargs'`` and
    ``'algo_kwargs'`` from ``variant``.
    """
    env = NormalizedBoxEnv(variant['env_class']())
    n_obs = int(np.prod(env.observation_space.shape))
    n_act = int(np.prod(env.action_space.shape))

    qf = ConcatMlp(input_size=n_obs + n_act, output_size=1, **variant['qf_kwargs'])
    vf = ConcatMlp(input_size=n_obs, output_size=1, **variant['vf_kwargs'])
    policy = TanhGaussianPolicy(
        obs_dim=n_obs,
        action_dim=n_act,
        **variant['policy_kwargs']
    )
    sac = SoftActorCritic(
        env=env,
        policy=policy,
        qf=qf,
        vf=vf,
        **variant['algo_kwargs']
    )
    sac.to(ptu.device)
    sac.train()
def experiment(variant):
    """Train ``TwinSAC`` (two Q-functions plus a value function).

    The env is a normalized ``variant['env_class']()``. Reads ``'qf_kwargs'``,
    ``'vf_kwargs'``, ``'policy_kwargs'`` and ``'algo_kwargs'`` from
    ``variant``.
    """
    env = NormalizedBoxEnv(variant['env_class']())
    ob = env.observation_space.low.size
    ac = env.action_space.low.size

    qf1 = ConcatMlp(input_size=ob + ac, output_size=1, **variant['qf_kwargs'])
    qf2 = ConcatMlp(input_size=ob + ac, output_size=1, **variant['qf_kwargs'])
    vf = ConcatMlp(input_size=ob, output_size=1, **variant['vf_kwargs'])
    policy = TanhGaussianPolicy(obs_dim=ob, action_dim=ac, **variant['policy_kwargs'])
    twin_sac = TwinSAC(
        env,
        policy=policy,
        qf1=qf1,
        qf2=qf2,
        vf=vf,
        **variant['algo_kwargs']
    )
    twin_sac.to(ptu.device)
    twin_sac.train()
def make_qf():
    """Build one Q-function, disentangled unless ``have_no_disentangled_encoder``.

    Uses ``have_no_disentangled_encoder``, ``obs_dim``, ``goal_dim``,
    ``action_dim``, ``qf_kwargs``, ``encoder``, ``vectorized`` and
    ``disentangled_qf_kwargs`` from the surrounding scope.
    """
    if not have_no_disentangled_encoder:
        return DisentangledMlpQf(
            encoder=encoder,
            preprocess_obs_dim=obs_dim,
            action_dim=action_dim,
            qf_kwargs=qf_kwargs,
            vectorized=vectorized,
            **disentangled_qf_kwargs
        )
    # Without a disentangled encoder: a plain MLP over the concatenated
    # observation, goal and action.
    return ConcatMlp(
        input_size=obs_dim + goal_dim + action_dim,
        output_size=1,
        **qf_kwargs,
    )
def experiment(variant):
    """Train SoftActorCritic on a normalized gym ``HalfCheetah-v1``.

    Args:
        variant: Configuration mapping. ``'net_size'`` sets the width of both
            hidden layers in every network. ``'algo_params'`` is forwarded to
            ``SoftActorCritic``.
    """
    env = NormalizedBoxEnv(gym.make('HalfCheetah-v1'))
    obs_dim = int(np.prod(env.observation_space.shape))
    action_dim = int(np.prod(env.action_space.shape))
    size = variant['net_size']

    qf = ConcatMlp(hidden_sizes=[size, size], input_size=obs_dim + action_dim, output_size=1)
    vf = ConcatMlp(hidden_sizes=[size, size], input_size=obs_dim, output_size=1)
    policy = TanhGaussianPolicy(
        hidden_sizes=[size, size],
        obs_dim=obs_dim,
        action_dim=action_dim,
    )
    algorithm = SoftActorCritic(
        env=env,
        policy=policy,
        qf=qf,
        vf=vf,
        **variant['algo_params']
    )
    algorithm.to(ptu.device)
    algorithm.train()
def experiment(variant):
    """Train ``FiniteHorizonDDPG`` on ``variant['env_class'](**variant['env_kwargs'])``.

    The environment is used unnormalized. Reads ``'qf_kwargs'``,
    ``'policy_kwargs'`` and ``'algo_kwargs'`` from ``variant``.
    """
    env = variant['env_class'](**variant['env_kwargs'])
    act_size = env.action_space.low.size
    obs_size = env.observation_space.low.size

    critic = ConcatMlp(
        input_size=act_size + obs_size,
        output_size=1,
        **variant['qf_kwargs']
    )
    actor = TanhMlpPolicy(
        input_size=obs_size,
        output_size=act_size,
        **variant['policy_kwargs']
    )
    ddpg = FiniteHorizonDDPG(env, critic, actor, **variant['algo_kwargs'])
    ddpg.to(ptu.device)
    ddpg.train()
def experiment(variant):
    """Train SoftActorCritic on ``MultiGoalEnv`` under the PyTorch autograd profiler.

    After training, the profile is written as a Chrome trace to
    ``tmp-torch-chrome-trace.prof`` in the current working directory.
    ``variant['algo_params']`` is forwarded to ``SoftActorCritic``.
    """
    goal_env = MultiGoalEnv(
        actuation_cost_coeff=10,
        distance_cost_coeff=1,
        goal_reward=10,
    )
    env = NormalizedBoxEnv(goal_env)
    n_obs = int(np.prod(env.observation_space.shape))
    n_act = int(np.prod(env.action_space.shape))

    qf = ConcatMlp(hidden_sizes=[100, 100], input_size=n_obs + n_act, output_size=1)
    vf = ConcatMlp(hidden_sizes=[100, 100], input_size=n_obs, output_size=1)
    policy = TanhGaussianPolicy(hidden_sizes=[100, 100], obs_dim=n_obs, action_dim=n_act)
    algorithm = SoftActorCritic(
        env=env,
        policy=policy,
        qf=qf,
        vf=vf,
        **variant['algo_params']
    )
    algorithm.to(ptu.device)
    with torch.autograd.profiler.profile() as prof:
        algorithm.train()
    # Export after the profiling context has exited and results are final.
    prof.export_chrome_trace("tmp-torch-chrome-trace.prof")
def experiment(variant):
    """Train DDPG on ``HalfCheetahEnv`` with separate exploration/evaluation envs.

    The target networks start as deep copies of the online networks.
    Exploration wraps the policy with ``OUStrategy``; evaluation runs the
    policy directly.

    Args:
        variant: Configuration mapping. Reads ``'qf_kwargs'``,
            ``'policy_kwargs'``, ``'replay_buffer_size'``, ``'trainer_kwargs'``
            and ``'algorithm_kwargs'``.
    """
    # For a specific version, gym can be used instead, e.g.
    # NormalizedBoxEnv(gym.make('HalfCheetah-v1')).
    eval_env = NormalizedBoxEnv(HalfCheetahEnv())
    expl_env = NormalizedBoxEnv(HalfCheetahEnv())
    obs_dim = eval_env.observation_space.low.size
    action_dim = eval_env.action_space.low.size

    qf = ConcatMlp(
        input_size=obs_dim + action_dim,
        output_size=1,
        **variant['qf_kwargs']
    )
    policy = TanhMlpPolicy(
        input_size=obs_dim,
        output_size=action_dim,
        **variant['policy_kwargs']
    )
    target_qf, target_policy = copy.deepcopy(qf), copy.deepcopy(policy)

    eval_collector = MdpPathCollector(eval_env, policy)
    noisy_policy = PolicyWrappedWithExplorationStrategy(
        exploration_strategy=OUStrategy(action_space=expl_env.action_space),
        policy=policy,
    )
    expl_collector = MdpPathCollector(expl_env, noisy_policy)
    buffer = EnvReplayBuffer(variant['replay_buffer_size'], expl_env)
    trainer = DDPGTrainer(
        qf=qf,
        target_qf=target_qf,
        policy=policy,
        target_policy=target_policy,
        **variant['trainer_kwargs']
    )
    algorithm = TorchBatchRLAlgorithm(
        trainer=trainer,
        exploration_env=expl_env,
        evaluation_env=eval_env,
        exploration_data_collector=expl_collector,
        evaluation_data_collector=eval_collector,
        replay_buffer=buffer,
        **variant['algorithm_kwargs']
    )
    algorithm.to(ptu.device)
    algorithm.train()
def get_ddpg_trainer(env, hidden_sizes=None):
    """Build a ``DDPGTrainer`` whose target networks are deep copies of the online ones.

    Args:
        env: Environment whose ``observation_space.low.size`` and
            ``action_space.low.size`` give the network input/output sizes.
        hidden_sizes: Hidden-layer widths for the Q-function and policy.
            ``None`` (the default) means ``[256, 256]``.

    Returns:
        A ``DDPGTrainer`` using soft target updates (tau 1e-2), discount 0.99,
        Q learning rate 1e-3 and policy learning rate 1e-4.
    """
    # A fresh list per call instead of a shared mutable default argument.
    if hidden_sizes is None:
        hidden_sizes = [256, 256]
    obs_dim = env.observation_space.low.size
    action_dim = env.action_space.low.size
    qf = ConcatMlp(
        input_size=obs_dim + action_dim,
        output_size=1,
        hidden_sizes=hidden_sizes,
    )
    policy = TanhMlpPolicy(
        input_size=obs_dim,
        output_size=action_dim,
        hidden_sizes=hidden_sizes,
    )
    target_qf = copy.deepcopy(qf)
    target_policy = copy.deepcopy(policy)
    trainer = DDPGTrainer(
        qf=qf,
        target_qf=target_qf,
        policy=policy,
        target_policy=target_policy,
        use_soft_update=True,
        tau=1e-2,
        discount=0.99,
        qf_learning_rate=1e-3,
        policy_learning_rate=1e-4,
    )
    return trainer
def her_sac_experiment(variant):
    """Train goal-conditioned SAC (``HerSac``) with an ``ObsDictRelabelingBuffer``.

    The observation and goal entries come from ``variant['observation_key']``
    (default ``'observation'``) and ``variant['desired_goal_key']`` (default
    ``'desired_goal'``). Network input sizes are derived from those same keys.

    Args:
        variant: Configuration mapping. Also reads ``'env_class'``,
            ``'env_kwargs'``, ``'replay_buffer_kwargs'``, ``'normalize'``,
            ``'qf_kwargs'``, ``'vf_kwargs'``, ``'policy_kwargs'`` and
            ``'algo_kwargs'``.
    """
    env = variant['env_class'](**variant['env_kwargs'])
    observation_key = variant.get('observation_key', 'observation')
    desired_goal_key = variant.get('desired_goal_key', 'desired_goal')
    # The buffer is built from the unwrapped env, before optional normalization.
    replay_buffer = ObsDictRelabelingBuffer(
        env=env,
        observation_key=observation_key,
        desired_goal_key=desired_goal_key,
        **variant['replay_buffer_kwargs']
    )
    # Size the networks from the configured keys, matching the keys the
    # buffer and algorithm read. With the default keys this is unchanged.
    spaces = env.observation_space.spaces
    obs_dim = spaces[observation_key].low.size
    action_dim = env.action_space.low.size
    goal_dim = spaces[desired_goal_key].low.size
    if variant['normalize']:
        env = NormalizedBoxEnv(env)

    qf = ConcatMlp(
        input_size=obs_dim + action_dim + goal_dim,
        output_size=1,
        **variant['qf_kwargs']
    )
    vf = ConcatMlp(
        input_size=obs_dim + goal_dim,
        output_size=1,
        **variant['vf_kwargs']
    )
    policy = TanhGaussianPolicy(
        obs_dim=obs_dim + goal_dim,
        action_dim=action_dim,
        **variant['policy_kwargs']
    )
    algorithm = HerSac(
        env,
        qf=qf,
        vf=vf,
        policy=policy,
        replay_buffer=replay_buffer,
        observation_key=observation_key,
        desired_goal_key=desired_goal_key,
        **variant['algo_kwargs']
    )
    if ptu.gpu_enabled():
        qf.to(ptu.device)
        vf.to(ptu.device)
        policy.to(ptu.device)
        algorithm.to(ptu.device)
    algorithm.train()
def experiment(variant):
    """Train SAC with twin Q-functions using separate exploration/evaluation envs.

    Both envs come from ``make_env()``. Evaluation uses
    ``MakeDeterministic(policy)``.

    Args:
        variant: Configuration mapping. ``'layer_size'`` sets the width of both
            hidden layers. Also reads ``'replay_buffer_size'``,
            ``'trainer_kwargs'`` and ``'algorithm_kwargs'``.
    """
    expl_env = make_env()
    eval_env = make_env()
    obs_dim = expl_env.observation_space.low.size
    action_dim = eval_env.action_space.low.size
    width = variant['layer_size']

    def q_network():
        return ConcatMlp(
            input_size=obs_dim + action_dim,
            output_size=1,
            hidden_sizes=[width, width],
        )

    # Built in order: qf1, qf2, target_qf1, target_qf2.
    qf1, qf2, target_qf1, target_qf2 = [q_network() for _ in range(4)]
    policy = TanhGaussianPolicy(
        obs_dim=obs_dim,
        action_dim=action_dim,
        hidden_sizes=[width, width],
    )

    eval_collector = MdpPathCollector(eval_env, MakeDeterministic(policy))
    expl_collector = MdpPathCollector(expl_env, policy)
    buffer = EnvReplayBuffer(variant['replay_buffer_size'], expl_env)
    trainer = SACTrainer(
        env=eval_env,
        policy=policy,
        qf1=qf1,
        qf2=qf2,
        target_qf1=target_qf1,
        target_qf2=target_qf2,
        **variant['trainer_kwargs']
    )
    algorithm = TorchBatchRLAlgorithm(
        trainer=trainer,
        exploration_env=expl_env,
        evaluation_env=eval_env,
        exploration_data_collector=expl_collector,
        evaluation_data_collector=eval_collector,
        replay_buffer=buffer,
        **variant['algorithm_kwargs']
    )
    algorithm.to(ptu.device)
    algorithm.train()
def active_representation_learning_experiment(variant):
    """Train SAC on a representation-wrapped image env while training the model.

    Builds ``variant['model_class'](**variant['model_kwargs'])``, wraps
    ``variant['env_class'](**variant['env_kwargs'])`` in ``ImageEnv`` and then
    ``RepresentationWrappedEnv``, and trains twin-Q ``SACTrainer`` networks on
    the observation entry ``observation_key`` (default
    ``'latent_observation'``). Training runs through
    ``ActiveRepresentationLearningAlgorithm``, together with a model trainer
    built from ``variant['model_trainer_class']``.

    Side effects on ``variant``: ``preprocess_rl_variant`` is applied to it,
    and ``variant["vae_path"]`` is overwritten with the constructed model.
    """
    import rlkit.torch.pytorch_util as ptu
    from rlkit.data_management.obs_dict_replay_buffer import ObsDictReplayBuffer
    from rlkit.torch.networks import ConcatMlp
    from rlkit.torch.sac.policies import TanhGaussianPolicy
    from rlkit.torch.arl.active_representation_learning_algorithm import \
        ActiveRepresentationLearningAlgorithm
    from rlkit.torch.arl.representation_wrappers import RepresentationWrappedEnv
    from multiworld.core.image_env import ImageEnv
    from rlkit.samplers.data_collector import MdpPathCollector
    preprocess_rl_variant(variant)
    model_class = variant.get('model_class')
    model_kwargs = variant.get('model_kwargs')
    model = model_class(**model_kwargs)
    # NOTE(review): hard-coded overrides; presumably these must agree with
    # the model's real output size and with variant['imsize'] used below —
    # confirm.
    model.representation_size = 4
    model.imsize = 48
    variant["vae_path"] = model
    # NOTE(review): reward_params is not used below.
    reward_params = variant.get("reward_params", dict())
    init_camera = variant.get("init_camera", None)
    env = variant["env_class"](**variant['env_kwargs'])
    image_env = ImageEnv(
        env,
        variant.get('imsize'),
        init_camera=init_camera,
        transpose=True,
        normalize=True,
    )
    env = RepresentationWrappedEnv(
        image_env,
        model,
    )
    # Optional dataset handed to the algorithm as ``uniform_dataset``.
    uniform_dataset_fn = variant.get('generate_uniform_dataset_fn', None)
    if uniform_dataset_fn:
        uniform_dataset = uniform_dataset_fn(
            **variant['generate_uniform_dataset_kwargs'])
    else:
        uniform_dataset = None
    observation_key = variant.get('observation_key', 'latent_observation')
    desired_goal_key = variant.get('desired_goal_key', 'latent_desired_goal')
    # NOTE(review): achieved_goal_key is not used below.
    achieved_goal_key = desired_goal_key.replace("desired", "achieved")
    # Only the observation entry sizes the networks; no goal is concatenated.
    obs_dim = env.observation_space.spaces[observation_key].low.size
    action_dim = env.action_space.low.size
    hidden_sizes = variant.get('hidden_sizes', [400, 300])
    qf1 = ConcatMlp(
        input_size=obs_dim + action_dim,
        output_size=1,
        hidden_sizes=hidden_sizes,
    )
    qf2 = ConcatMlp(
        input_size=obs_dim + action_dim,
        output_size=1,
        hidden_sizes=hidden_sizes,
    )
    target_qf1 = ConcatMlp(
        input_size=obs_dim + action_dim,
        output_size=1,
        hidden_sizes=hidden_sizes,
    )
    target_qf2 = ConcatMlp(
        input_size=obs_dim + action_dim,
        output_size=1,
        hidden_sizes=hidden_sizes,
    )
    policy = TanhGaussianPolicy(
        obs_dim=obs_dim,
        action_dim=action_dim,
        hidden_sizes=hidden_sizes,
    )
    # NOTE(review): assumes RepresentationWrappedEnv exposes the model as
    # ``.vae`` — confirm.
    vae = env.vae
    replay_buffer = ObsDictReplayBuffer(env=env, **variant['replay_buffer_kwargs'])
    model_trainer_class = variant.get('model_trainer_class')
    model_trainer_kwargs = variant.get('model_trainer_kwargs')
    model_trainer = model_trainer_class(
        model,
        **model_trainer_kwargs,
    )
    # vae_trainer = ConvVAETrainer(
    #     env.vae,
    #     **variant['online_vae_trainer_kwargs']
    # )
    assert 'vae_training_schedule' not in variant, "Just put it in algo_kwargs"
    max_path_length = variant['max_path_length']
    trainer = SACTrainer(env=env, policy=policy, qf1=qf1, qf2=qf2, target_qf1=target_qf1, target_qf2=target_qf2, **variant['twin_sac_trainer_kwargs'])
    # trainer = HERTrainer(trainer)
    # Evaluation rolls out the deterministic version of the policy.
    eval_path_collector = MdpPathCollector(
        env,
        MakeDeterministic(policy),
        # max_path_length,
        # observation_key=observation_key,
        # desired_goal_key=desired_goal_key,
    )
    expl_path_collector = MdpPathCollector(
        env,
        policy,
        # max_path_length,
        # observation_key=observation_key,
        # desired_goal_key=desired_goal_key,
    )
    # The same wrapped env instance is used for exploration and evaluation.
    algorithm = ActiveRepresentationLearningAlgorithm(
        trainer=trainer,
        exploration_env=env,
        evaluation_env=env,
        exploration_data_collector=expl_path_collector,
        evaluation_data_collector=eval_path_collector,
        replay_buffer=replay_buffer,
        model=model,
        model_trainer=model_trainer,
        uniform_dataset=uniform_dataset,
        max_path_length=max_path_length,
        **variant['algo_kwargs'])
    algorithm.to(ptu.device)
    vae.to(ptu.device)
    algorithm.train()
def td3_experiment_offpolicy_online_vae(variant):
    """Train TD3 (wrapped in ``HERTrainer``) alongside an online-trained VAE.

    Environments come from ``get_envs(variant)``. Policy and Q-function inputs
    are the concatenation of the ``observation_key`` and ``desired_goal_key``
    entries (defaults ``'latent_observation'`` / ``'latent_desired_goal'``).
    Exploration wraps the policy in ``GaussianAndEpislonStrategy`` (sigma 0.2,
    epsilon 0.3). The VAE is trained by ``variant['vae_trainer_class']``
    (default ``ConvVAETrainer``). ``algorithm.pretrain()`` runs before
    ``algorithm.train()``.
    """
    import rlkit.torch.pytorch_util as ptu
    from rlkit.data_management.online_vae_replay_buffer import \
        OnlineVaeRelabelingBuffer
    from rlkit.torch.networks import ConcatMlp, TanhMlpPolicy
    from rlkit.torch.vae.vae_trainer import ConvVAETrainer
    from rlkit.torch.td3.td3 import TD3
    from rlkit.exploration_strategies.base import (
        PolicyWrappedWithExplorationStrategy)
    from rlkit.exploration_strategies.gaussian_and_epislon import \
        GaussianAndEpislonStrategy
    from rlkit.torch.vae.online_vae_offpolicy_algorithm import OnlineVaeOffpolicyAlgorithm
    preprocess_rl_variant(variant)
    env = get_envs(variant)
    # Optional dataset handed to the algorithm as ``uniform_dataset``.
    uniform_dataset_fn = variant.get('generate_uniform_dataset_fn', None)
    if uniform_dataset_fn:
        uniform_dataset = uniform_dataset_fn(
            **variant['generate_uniform_dataset_kwargs'])
    else:
        uniform_dataset = None
    observation_key = variant.get('observation_key', 'latent_observation')
    desired_goal_key = variant.get('desired_goal_key', 'latent_desired_goal')
    achieved_goal_key = desired_goal_key.replace("desired", "achieved")
    # Networks see the observation concatenated with the desired goal.
    obs_dim = (env.observation_space.spaces[observation_key].low.size
               + env.observation_space.spaces[desired_goal_key].low.size)
    action_dim = env.action_space.low.size
    hidden_sizes = variant.get('hidden_sizes', [400, 300])
    qf1 = ConcatMlp(
        input_size=obs_dim + action_dim,
        output_size=1,
        hidden_sizes=hidden_sizes,
    )
    qf2 = ConcatMlp(
        input_size=obs_dim + action_dim,
        output_size=1,
        hidden_sizes=hidden_sizes,
    )
    target_qf1 = ConcatMlp(
        input_size=obs_dim + action_dim,
        output_size=1,
        hidden_sizes=hidden_sizes,
    )
    target_qf2 = ConcatMlp(
        input_size=obs_dim + action_dim,
        output_size=1,
        hidden_sizes=hidden_sizes,
    )
    policy = TanhMlpPolicy(
        input_size=obs_dim,
        output_size=action_dim,
        hidden_sizes=hidden_sizes,
        # **variant['policy_kwargs']
    )
    target_policy = TanhMlpPolicy(
        input_size=obs_dim,
        output_size=action_dim,
        hidden_sizes=hidden_sizes,
        # **variant['policy_kwargs']
    )
    es = GaussianAndEpislonStrategy(
        action_space=env.action_space,
        max_sigma=.2,
        min_sigma=.2,  # constant sigma
        epsilon=.3,
    )
    expl_policy = PolicyWrappedWithExplorationStrategy(
        exploration_strategy=es,
        policy=policy,
    )
    vae = env.vae
    replay_buffer_class = variant.get("replay_buffer_class", OnlineVaeRelabelingBuffer)
    replay_buffer = replay_buffer_class(vae=env.vae, env=env, observation_key=observation_key, desired_goal_key=desired_goal_key, achieved_goal_key=achieved_goal_key, **variant['replay_buffer_kwargs'])
    # Copy the VAE's representation size onto the buffer.
    replay_buffer.representation_size = vae.representation_size
    vae_trainer_class = variant.get("vae_trainer_class", ConvVAETrainer)
    vae_trainer = vae_trainer_class(env.vae, **variant['online_vae_trainer_kwargs'])
    assert 'vae_training_schedule' not in variant, "Just put it in algo_kwargs"
    max_path_length = variant['max_path_length']
    trainer = TD3(policy=policy, qf1=qf1, qf2=qf2, target_qf1=target_qf1, target_qf2=target_qf2, target_policy=target_policy, **variant['td3_trainer_kwargs'])
    trainer = HERTrainer(trainer)
    # Evaluation uses the un-noised policy; exploration uses expl_policy.
    # Goal-sampling modes for each are taken from the variant.
    eval_path_collector = VAEWrappedEnvPathCollector(
        variant['evaluation_goal_sampling_mode'],
        env,
        policy,
        max_path_length,
        observation_key=observation_key,
        desired_goal_key=desired_goal_key,
    )
    expl_path_collector = VAEWrappedEnvPathCollector(
        variant['exploration_goal_sampling_mode'],
        env,
        expl_policy,
        max_path_length,
        observation_key=observation_key,
        desired_goal_key=desired_goal_key,
    )
    algorithm = OnlineVaeOffpolicyAlgorithm(
        trainer=trainer,
        exploration_env=env,
        evaluation_env=env,
        exploration_data_collector=expl_path_collector,
        evaluation_data_collector=eval_path_collector,
        replay_buffer=replay_buffer,
        vae=vae,
        vae_trainer=vae_trainer,
        uniform_dataset=uniform_dataset,
        max_path_length=max_path_length,
        **variant['algo_kwargs'])
    if variant.get("save_video", True):
        video_func = VideoSaveFunction(
            env,
            variant,
        )
        algorithm.post_train_funcs.append(video_func)
    # NOTE(review): indexed directly, so 'custom_goal_sampler' is a required
    # key (unlike the .get lookups above).
    if variant['custom_goal_sampler'] == 'replay_buffer':
        env.custom_goal_sampler = replay_buffer.sample_buffer_goals
    algorithm.to(ptu.device)
    vae.to(ptu.device)
    algorithm.pretrain()
    algorithm.train()
def td3_experiment_online_vae_exploring(variant):
    """Jointly train a control TD3 agent and an exploring TD3 agent with an online VAE.

    Both agents share one ``OnlineVaeRelabelingBuffer``, which is injected
    through ``variant["algo_kwargs"]["replay_buffer"]``. They are combined by
    ``OnlineVaeHerJointAlgo`` with prefixes ``"Control_"`` and
    ``"VAE_Exploration_"``. The VAE is trained by a ``ConvVAETrainer`` on
    ``variant['vae_train_data']`` / ``variant['vae_test_data']``.

    Side effects on ``variant``: ``preprocess_rl_variant`` is applied to it,
    and ``variant["algo_kwargs"]`` gains a ``"replay_buffer"`` entry.
    """
    import rlkit.samplers.rollout_functions as rf
    import rlkit.torch.pytorch_util as ptu
    from rlkit.data_management.online_vae_replay_buffer import \
        OnlineVaeRelabelingBuffer
    from rlkit.exploration_strategies.base import (
        PolicyWrappedWithExplorationStrategy)
    from rlkit.torch.her.online_vae_joint_algo import OnlineVaeHerJointAlgo
    from rlkit.torch.networks import ConcatMlp, TanhMlpPolicy
    from rlkit.torch.td3.td3 import TD3
    from rlkit.torch.vae.vae_trainer import ConvVAETrainer
    preprocess_rl_variant(variant)
    env = get_envs(variant)
    es = get_exploration_strategy(variant, env)
    observation_key = variant.get('observation_key', 'latent_observation')
    desired_goal_key = variant.get('desired_goal_key', 'latent_desired_goal')
    achieved_goal_key = desired_goal_key.replace("desired", "achieved")
    # Networks see the observation concatenated with the desired goal.
    obs_dim = (env.observation_space.spaces[observation_key].low.size
               + env.observation_space.spaces[desired_goal_key].low.size)
    action_dim = env.action_space.low.size
    # Networks for the control agent.
    qf1 = ConcatMlp(
        input_size=obs_dim + action_dim,
        output_size=1,
        **variant['qf_kwargs'],
    )
    qf2 = ConcatMlp(
        input_size=obs_dim + action_dim,
        output_size=1,
        **variant['qf_kwargs'],
    )
    policy = TanhMlpPolicy(
        input_size=obs_dim,
        output_size=action_dim,
        **variant['policy_kwargs'],
    )
    exploration_policy = PolicyWrappedWithExplorationStrategy(
        exploration_strategy=es,
        policy=policy,
    )
    # Networks for the exploring agent (same architecture).
    exploring_qf1 = ConcatMlp(
        input_size=obs_dim + action_dim,
        output_size=1,
        **variant['qf_kwargs'],
    )
    exploring_qf2 = ConcatMlp(
        input_size=obs_dim + action_dim,
        output_size=1,
        **variant['qf_kwargs'],
    )
    exploring_policy = TanhMlpPolicy(
        input_size=obs_dim,
        output_size=action_dim,
        **variant['policy_kwargs'],
    )
    # NOTE(review): the same ``es`` instance wraps both policies; confirm the
    # strategy holds no per-policy state.
    exploring_exploration_policy = PolicyWrappedWithExplorationStrategy(
        exploration_strategy=es,
        policy=exploring_policy,
    )
    vae = env.vae
    replay_buffer = OnlineVaeRelabelingBuffer(
        vae=vae,
        env=env,
        observation_key=observation_key,
        desired_goal_key=desired_goal_key,
        achieved_goal_key=achieved_goal_key,
        **variant['replay_buffer_kwargs'])
    # Mutates the caller's variant so both TD3 instances below receive the
    # same replay buffer via **variant['algo_kwargs'].
    variant["algo_kwargs"]["replay_buffer"] = replay_buffer
    if variant.get('use_replay_buffer_goals', False):
        env.replay_buffer = replay_buffer
        env.use_replay_buffer_goals = True
    # NOTE(review): variant.get returns None when the key is absent, and
    # ``**None`` raises TypeError, so 'vae_trainer_kwargs' is effectively
    # required.
    vae_trainer_kwargs = variant.get('vae_trainer_kwargs')
    t = ConvVAETrainer(variant['vae_train_data'],
                       variant['vae_test_data'],
                       vae,
                       beta=variant['online_vae_beta'],
                       **vae_trainer_kwargs)
    control_algorithm = TD3(
        env=env,
        training_env=env,
        qf1=qf1,
        qf2=qf2,
        policy=policy,
        exploration_policy=exploration_policy,
        **variant['algo_kwargs']
    )
    exploring_algorithm = TD3(
        env=env,
        training_env=env,
        qf1=exploring_qf1,
        qf2=exploring_qf2,
        policy=exploring_policy,
        exploration_policy=exploring_exploration_policy,
        **variant['algo_kwargs']
    )
    assert 'vae_training_schedule' not in variant,\
        "Just put it in joint_algo_kwargs"
    algorithm = OnlineVaeHerJointAlgo(
        vae=vae,
        vae_trainer=t,
        env=env,
        training_env=env,
        policy=policy,
        exploration_policy=exploration_policy,
        replay_buffer=replay_buffer,
        algo1=control_algorithm,
        algo2=exploring_algorithm,
        algo1_prefix="Control_",
        algo2_prefix="VAE_Exploration_",
        observation_key=observation_key,
        desired_goal_key=desired_goal_key,
        **variant['joint_algo_kwargs']
    )
    algorithm.to(ptu.device)
    vae.to(ptu.device)
    if variant.get("save_video", True):
        # Presumably switches the control policy to evaluation mode before
        # building the video rollout — confirm.
        policy.train(False)
        rollout_function = rf.create_rollout_function(
            rf.multitask_rollout,
            max_path_length=algorithm.max_path_length,
            observation_key=algorithm.observation_key,
            desired_goal_key=algorithm.desired_goal_key,
        )
        video_func = get_video_save_func(
            rollout_function,
            env,
            algorithm.eval_policy,
            variant,
        )
        algorithm.post_train_funcs.append(video_func)
    algorithm.train()
def twin_sac_experiment_online_vae(variant):
    """Train twin-Q SAC with HER relabeling while training the VAE online.

    Args:
        variant: Experiment configuration dict, preprocessed in place by
            ``preprocess_rl_variant``. Keys read directly include
            ``replay_buffer_kwargs``, ``online_vae_trainer_kwargs``,
            ``max_path_length``, ``twin_sac_trainer_kwargs``,
            ``evaluation_goal_sampling_mode``,
            ``exploration_goal_sampling_mode`` and ``algo_kwargs``. Optional
            keys include ``generate_uniform_dataset_fn`` (with
            ``generate_uniform_dataset_kwargs``), ``observation_key``,
            ``desired_goal_key``, ``hidden_sizes``, ``replay_buffer_class``,
            ``vae_trainer_class``, ``save_video`` and
            ``custom_goal_sampler``.

    Raises:
        AssertionError: if ``variant`` contains ``vae_training_schedule``
            (it must be given inside ``algo_kwargs`` instead).
    """
    import rlkit.torch.pytorch_util as ptu
    from rlkit.data_management.online_vae_replay_buffer import \
        OnlineVaeRelabelingBuffer
    from rlkit.torch.networks import ConcatMlp
    from rlkit.torch.sac.policies import TanhGaussianPolicy
    from rlkit.torch.vae.vae_trainer import ConvVAETrainer

    preprocess_rl_variant(variant)
    env = get_envs(variant)

    uniform_dataset_fn = variant.get('generate_uniform_dataset_fn', None)
    if uniform_dataset_fn:
        uniform_dataset = uniform_dataset_fn(
            **variant['generate_uniform_dataset_kwargs'])
    else:
        uniform_dataset = None

    observation_key = variant.get('observation_key', 'latent_observation')
    desired_goal_key = variant.get('desired_goal_key', 'latent_desired_goal')
    # e.g. 'latent_desired_goal' -> 'latent_achieved_goal'.
    achieved_goal_key = desired_goal_key.replace("desired", "achieved")
    # Goal-conditioned input: observation concatenated with desired goal.
    obs_dim = (env.observation_space.spaces[observation_key].low.size
               + env.observation_space.spaces[desired_goal_key].low.size)
    action_dim = env.action_space.low.size
    hidden_sizes = variant.get('hidden_sizes', [400, 300])

    def make_qf():
        # Q-network over (goal-conditioned observation, action).
        return ConcatMlp(
            input_size=obs_dim + action_dim,
            output_size=1,
            hidden_sizes=hidden_sizes,
        )

    # Construction order matches the original (qf1, qf2, targets, policy).
    qf1 = make_qf()
    qf2 = make_qf()
    target_qf1 = make_qf()
    target_qf2 = make_qf()
    policy = TanhGaussianPolicy(
        obs_dim=obs_dim,
        action_dim=action_dim,
        hidden_sizes=hidden_sizes,
    )

    vae = env.vae
    replay_buffer_class = variant.get("replay_buffer_class",
                                      OnlineVaeRelabelingBuffer)
    replay_buffer = replay_buffer_class(
        vae=env.vae,
        env=env,
        observation_key=observation_key,
        desired_goal_key=desired_goal_key,
        achieved_goal_key=achieved_goal_key,
        **variant['replay_buffer_kwargs'])
    vae_trainer_class = variant.get("vae_trainer_class", ConvVAETrainer)
    vae_trainer = vae_trainer_class(env.vae,
                                    **variant['online_vae_trainer_kwargs'])
    assert 'vae_training_schedule' not in variant, \
        "Just put it in algo_kwargs"
    max_path_length = variant['max_path_length']

    trainer = SACTrainer(
        env=env,
        policy=policy,
        qf1=qf1,
        qf2=qf2,
        target_qf1=target_qf1,
        target_qf2=target_qf2,
        **variant['twin_sac_trainer_kwargs'])
    trainer = HERTrainer(trainer)

    eval_path_collector = VAEWrappedEnvPathCollector(
        variant['evaluation_goal_sampling_mode'],
        env,
        MakeDeterministic(policy),
        max_path_length,
        observation_key=observation_key,
        desired_goal_key=desired_goal_key,
    )
    expl_path_collector = VAEWrappedEnvPathCollector(
        variant['exploration_goal_sampling_mode'],
        env,
        policy,
        max_path_length,
        observation_key=observation_key,
        desired_goal_key=desired_goal_key,
    )

    algorithm = OnlineVaeAlgorithm(
        trainer=trainer,
        exploration_env=env,
        evaluation_env=env,
        exploration_data_collector=expl_path_collector,
        evaluation_data_collector=eval_path_collector,
        replay_buffer=replay_buffer,
        vae=vae,
        vae_trainer=vae_trainer,
        uniform_dataset=uniform_dataset,
        max_path_length=max_path_length,
        **variant['algo_kwargs'])

    if variant.get("save_video", True):
        video_func = VideoSaveFunction(
            env,
            variant,
        )
        algorithm.post_train_funcs.append(video_func)
    # `custom_goal_sampler` is optional like the other settings above; only
    # the 'replay_buffer' value installs a custom sampler on the env.
    if variant.get('custom_goal_sampler') == 'replay_buffer':
        env.custom_goal_sampler = replay_buffer.sample_buffer_goals

    algorithm.to(ptu.device)
    vae.to(ptu.device)
    algorithm.train()
def experiment(variant):
    """Run SAC whose exploration environment has modified parameters.

    ``ENV_PARAMS[variant['env']]`` is merged into ``variant`` (in place). The
    exploration env is built from ``variant['env_mod']`` (read before the
    merge) and the evaluation env from an empty parameter dict. Both are
    wrapped in ``NormalizedBoxEnv``.

    When ``variant['collection_mode'] == 'online'``, data is collected step by
    step and trained with ``TorchOnlineRLAlgorithm``. Otherwise paths are
    collected and trained with ``TorchBatchRLAlgorithmModEnv``, which also
    receives the env-modification schedule and distribution.
    """
    env_params = ENV_PARAMS[variant['env']]
    mod_params = variant['env_mod']
    variant.update(env_params)

    make_env = variant['env_class']
    expl_env = NormalizedBoxEnv(make_env(mod_params))
    eval_env = NormalizedBoxEnv(make_env({}))
    obs_dim = expl_env.observation_space.low.size
    action_dim = eval_env.action_space.low.size
    width = variant['layer_size']

    def build_qf():
        # Fresh hidden_sizes list per network, as in a literal [M, M].
        return ConcatMlp(
            input_size=obs_dim + action_dim,
            output_size=1,
            hidden_sizes=[width, width],
        )

    qf1, qf2, target_qf1, target_qf2 = [build_qf() for _ in range(4)]
    policy = TanhGaussianPolicy(
        obs_dim=obs_dim,
        action_dim=action_dim,
        hidden_sizes=[width, width],
    )
    eval_path_collector = MdpPathCollector(eval_env, MakeDeterministic(policy))
    replay_buffer = EnvReplayBuffer(variant['replay_buffer_size'], expl_env)
    trainer = SACTrainer(
        env=eval_env,
        policy=policy,
        qf1=qf1,
        qf2=qf2,
        target_qf1=target_qf1,
        target_qf2=target_qf2,
        **variant['trainer_kwargs'],
    )

    online = variant['collection_mode'] == 'online'
    collector_cls = MdpStepCollector if online else MdpPathCollector
    expl_path_collector = collector_cls(expl_env, policy)

    shared_kwargs = dict(
        trainer=trainer,
        exploration_env=expl_env,
        evaluation_env=eval_env,
        exploration_data_collector=expl_path_collector,
        evaluation_data_collector=eval_path_collector,
        replay_buffer=replay_buffer,
        max_path_length=variant['max_path_length'],
        batch_size=variant['batch_size'],
        num_epochs=variant['num_epochs'],
        num_eval_steps_per_epoch=variant['num_eval_steps_per_epoch'],
        num_expl_steps_per_train_loop=variant[
            'num_expl_steps_per_train_loop'],
        num_trains_per_train_loop=variant['num_trains_per_train_loop'],
        min_num_steps_before_training=variant[
            'min_num_steps_before_training'],
    )
    if online:
        algorithm = TorchOnlineRLAlgorithm(**shared_kwargs)
    else:
        # `variant['env_mod']` is re-read here, after the ENV_PARAMS merge.
        algorithm = TorchBatchRLAlgorithmModEnv(
            mod_env_epoch_schedule=variant['mod_env_epoch_schedule'],
            env_mod_dist=variant['mod_env_dist'],
            env_class=variant['env_class'],
            env_mod_params=variant['env_mod'],
            **shared_kwargs,
        )

    algorithm.to(ptu.device)
    algorithm.train()
def _use_disentangled_encoder_distance(
        max_path_length,
        encoder_kwargs,
        disentangled_qf_kwargs,
        qf_kwargs,
        sac_trainer_kwargs,
        replay_buffer_kwargs,
        policy_kwargs,
        evaluation_goal_sampling_mode,
        exploration_goal_sampling_mode,
        algo_kwargs,
        env_id=None,
        env_class=None,
        env_kwargs=None,
        encoder_key_prefix='encoder',
        encoder_input_prefix='state',
        latent_dim=2,
        reward_mode=EncoderWrappedEnv.ENCODER_DISTANCE_REWARD,
        # Video parameters
        save_video=True,
        save_video_kwargs=None,
        save_vf_heatmap=True,
        **kwargs):
    """Train goal-conditioned SAC + HER with disentangled encoder Q-functions.

    The raw train/eval envs are wrapped in ``EncoderWrappedEnv`` so that
    observations and goals are exposed under ``'<encoder_key_prefix>_*'``
    keys. The twin Q-functions (and their targets) are ``DisentangledMlpQf``
    instances built around the same encoder. Training runs through
    ``TorchBatchRLAlgorithm`` with an ``ObsDictRelabelingBuffer`` and a
    ``HERTrainer``-wrapped ``SACTrainer``.

    Args:
        max_path_length: Rollout horizon for path collectors, algorithm and
            videos.
        encoder_kwargs: Passed to a ``ConcatMlp`` encoder that is then
            discarded (see the review note in the body).
        disentangled_qf_kwargs: Extra kwargs for each ``DisentangledMlpQf``.
        qf_kwargs: Forwarded as ``qf_kwargs`` to each ``DisentangledMlpQf``.
        sac_trainer_kwargs: Extra kwargs for ``SACTrainer``.
        replay_buffer_kwargs: Extra kwargs for ``ObsDictRelabelingBuffer``.
        policy_kwargs: Extra kwargs for ``TanhGaussianPolicy``.
        evaluation_goal_sampling_mode: Set as ``goal_sampling_mode`` on the
            raw eval env.
        exploration_goal_sampling_mode: Set as ``goal_sampling_mode`` on the
            raw train env.
        algo_kwargs: Extra kwargs for ``TorchBatchRLAlgorithm``.
        env_id: If given, envs are created with ``gym.make(env_id)`` after
            registering multiworld envs.
        env_class: Used when ``env_id`` is falsy; called with ``env_kwargs``.
        env_kwargs: Constructor kwargs for ``env_class`` (default ``{}``).
        encoder_key_prefix: Prefix for the wrapped env's observation/goal keys.
        encoder_input_prefix: Passed to ``EncoderWrappedEnv`` as its input
            prefix.
        latent_dim: Output size of the discarded ``ConcatMlp`` encoder. Also
            sets how many per-dimension ``'v_vals_dim_{i}'`` video image keys
            are requested.
        reward_mode: ``EncoderWrappedEnv`` reward mode. The vectorized mode
            switches the Q-functions and replay buffer to ``vectorized=True``.
        save_video: If true, register eval and train video post-train hooks.
        save_video_kwargs: Extra kwargs for ``get_save_video_function``
            (default ``{}``).
        save_vf_heatmap: If true, rollouts are post-processed with
            ``add_heatmap_imgs_to_o_dict`` for the videos.
        **kwargs: Accepted and ignored.

    Raises:
        AssertionError: if neither ``env_id`` nor ``env_class`` is given.
    """
    if save_video_kwargs is None:
        save_video_kwargs = {}
    if env_kwargs is None:
        env_kwargs = {}
    assert env_id or env_class
    vectorized = (
        reward_mode == EncoderWrappedEnv.VECTORIZED_ENCODER_DISTANCE_REWARD)

    if env_id:
        import gym
        import multiworld
        multiworld.register_all_envs()
        raw_train_env = gym.make(env_id)
        raw_eval_env = gym.make(env_id)
    else:
        raw_eval_env = env_class(**env_kwargs)
        raw_train_env = env_class(**env_kwargs)

    raw_train_env.goal_sampling_mode = exploration_goal_sampling_mode
    raw_eval_env.goal_sampling_mode = evaluation_goal_sampling_mode

    raw_obs_dim = (
        raw_train_env.observation_space.spaces['state_observation'].low.size)
    action_dim = raw_train_env.action_space.low.size

    # NOTE(review): this ConcatMlp encoder is immediately overwritten by
    # Identity() below, so `encoder_kwargs` has no effect on training and
    # the encoder's output size is `raw_obs_dim`, not `latent_dim`. Building
    # the discarded network may still consume RNG state during weight init;
    # confirm before deleting it if seeded runs must stay reproducible.
    encoder = ConcatMlp(input_size=raw_obs_dim, output_size=latent_dim,
                        **encoder_kwargs)
    encoder = Identity()
    # Identity() is given input/output sizes by hand, presumably because
    # downstream wrappers read these attributes — TODO confirm.
    encoder.input_size = raw_obs_dim
    encoder.output_size = raw_obs_dim

    np_encoder = EncoderFromNetwork(encoder)
    train_env = EncoderWrappedEnv(
        raw_train_env,
        np_encoder,
        encoder_input_prefix,
        key_prefix=encoder_key_prefix,
        reward_mode=reward_mode,
    )
    eval_env = EncoderWrappedEnv(
        raw_eval_env,
        np_encoder,
        encoder_input_prefix,
        key_prefix=encoder_key_prefix,
        reward_mode=reward_mode,
    )
    observation_key = '{}_observation'.format(encoder_key_prefix)
    desired_goal_key = '{}_desired_goal'.format(encoder_key_prefix)
    achieved_goal_key = '{}_achieved_goal'.format(encoder_key_prefix)
    obs_dim = train_env.observation_space.spaces[observation_key].low.size
    goal_dim = train_env.observation_space.spaces[desired_goal_key].low.size

    def make_qf():
        # Every Q-function (online and target) wraps the same encoder object.
        return DisentangledMlpQf(encoder=encoder,
                                 preprocess_obs_dim=obs_dim,
                                 action_dim=action_dim,
                                 qf_kwargs=qf_kwargs,
                                 vectorized=vectorized,
                                 **disentangled_qf_kwargs)

    qf1 = make_qf()
    qf2 = make_qf()
    target_qf1 = make_qf()
    target_qf2 = make_qf()

    # The policy sees the observation concatenated with the desired goal.
    policy = TanhGaussianPolicy(obs_dim=obs_dim + goal_dim,
                                action_dim=action_dim,
                                **policy_kwargs)

    replay_buffer = ObsDictRelabelingBuffer(
        env=train_env,
        observation_key=observation_key,
        desired_goal_key=desired_goal_key,
        achieved_goal_key=achieved_goal_key,
        vectorized=vectorized,
        **replay_buffer_kwargs)
    sac_trainer = SACTrainer(env=train_env,
                             policy=policy,
                             qf1=qf1,
                             qf2=qf2,
                             target_qf1=target_qf1,
                             target_qf2=target_qf2,
                             **sac_trainer_kwargs)
    trainer = HERTrainer(sac_trainer)

    # Both collectors use goal_sampling_mode='env'; presumably this defers to
    # the goal_sampling_mode set on the raw envs above — TODO confirm.
    eval_path_collector = GoalConditionedPathCollector(
        eval_env,
        MakeDeterministic(policy),
        max_path_length,
        observation_key=observation_key,
        desired_goal_key=desired_goal_key,
        goal_sampling_mode='env',
    )
    expl_path_collector = GoalConditionedPathCollector(
        train_env,
        policy,
        max_path_length,
        observation_key=observation_key,
        desired_goal_key=desired_goal_key,
        goal_sampling_mode='env',
    )
    algorithm = TorchBatchRLAlgorithm(
        trainer=trainer,
        exploration_env=train_env,
        evaluation_env=eval_env,
        exploration_data_collector=expl_path_collector,
        evaluation_data_collector=eval_path_collector,
        replay_buffer=replay_buffer,
        max_path_length=max_path_length,
        **algo_kwargs)
    algorithm.to(ptu.device)

    if save_video:

        def v_function(obs):
            """Return qf1's individual Q-values at the policy's action.

            `obs` is converted with ptu.from_numpy, so it is expected to be a
            numpy batch of goal-conditioned observations (shape not checked
            here — TODO confirm with add_heatmap_imgs_to_o_dict).
            """
            action = policy.get_actions(obs)
            obs, action = ptu.from_numpy(obs), ptu.from_numpy(action)
            return qf1(obs, action, return_individual_q_vals=True)

        add_heatmap = partial(
            add_heatmap_imgs_to_o_dict,
            v_function=v_function,
            vectorized=vectorized,
        )
        rollout_function = rf.create_rollout_function(
            rf.multitask_rollout,
            max_path_length=max_path_length,
            observation_key=observation_key,
            desired_goal_key=desired_goal_key,
            full_o_postprocess_func=add_heatmap if save_vf_heatmap else None,
        )
        # NOTE(review): per-dimension keys are generated from `latent_dim`,
        # but the Identity encoder above has `raw_obs_dim` outputs; verify
        # that these keys match what the heatmap post-processing produces.
        img_keys = ['v_vals'] + [
            'v_vals_dim_{}'.format(dim) for dim in range(latent_dim)
        ]
        eval_video_func = get_save_video_function(rollout_function,
                                                  eval_env,
                                                  MakeDeterministic(policy),
                                                  get_extra_imgs=partial(
                                                      get_extra_imgs,
                                                      img_keys=img_keys),
                                                  tag="eval",
                                                  **save_video_kwargs)
        train_video_func = get_save_video_function(rollout_function,
                                                   train_env,
                                                   policy,
                                                   get_extra_imgs=partial(
                                                       get_extra_imgs,
                                                       img_keys=img_keys),
                                                   tag="train",
                                                   **save_video_kwargs)
        algorithm.post_train_funcs.append(eval_video_func)
        algorithm.post_train_funcs.append(train_video_func)
    algorithm.train()