def her_twin_sac_experiment(variant):
    # Goal-conditioned env with dict observations.
    env = variant['env_class'](**variant['env_kwargs'])
    observation_key = variant.get('observation_key', 'observation')
    desired_goal_key = variant.get('desired_goal_key', 'desired_goal')
    # HER replay buffer that relabels desired goals when sampling transitions.
    replay_buffer = ObsDictRelabelingBuffer(
        env=env,
        observation_key=observation_key,
        desired_goal_key=desired_goal_key,
        **variant['replay_buffer_kwargs']
    )
    obs_dim = env.observation_space.spaces['observation'].low.size
    action_dim = env.action_space.low.size
    goal_dim = env.observation_space.spaces['desired_goal'].low.size
    if variant['normalize']:
        env = NormalizedBoxEnv(env)
    # Twin Q-functions, value function, and tanh-Gaussian policy, all goal-conditioned.
    qf1 = ConcatMlp(
        input_size=obs_dim + action_dim + goal_dim,
        output_size=1,
        **variant['qf_kwargs']
    )
    qf2 = ConcatMlp(
        input_size=obs_dim + action_dim + goal_dim,
        output_size=1,
        **variant['qf_kwargs']
    )
    vf = ConcatMlp(
        input_size=obs_dim + goal_dim,
        output_size=1,
        **variant['vf_kwargs']
    )
    policy = TanhGaussianPolicy(
        obs_dim=obs_dim + goal_dim,
        action_dim=action_dim,
        **variant['policy_kwargs']
    )
    algorithm = HerTwinSac(
        env,
        qf1=qf1,
        qf2=qf2,
        vf=vf,
        policy=policy,
        replay_buffer=replay_buffer,
        observation_key=observation_key,
        desired_goal_key=desired_goal_key,
        **variant['algo_kwargs']
    )
    # Move networks to the GPU if one is configured.
    if ptu.gpu_enabled():
        qf1.to(ptu.device)
        qf2.to(ptu.device)
        vf.to(ptu.device)
        policy.to(ptu.device)
        algorithm.to(ptu.device)
    algorithm.train()
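
# Illustrative sketch only: the builder below enumerates the keys that
# her_twin_sac_experiment() reads from `variant`. The env class is supplied by
# the caller, and every hyperparameter value here is an assumption chosen for
# demonstration, not a tuned setting from this repository.
def example_her_twin_sac_variant(env_class):
    """Return a minimal variant dict for her_twin_sac_experiment (illustrative values)."""
    return dict(
        env_class=env_class,
        env_kwargs=dict(),
        normalize=True,
        observation_key='observation',
        desired_goal_key='desired_goal',
        replay_buffer_kwargs=dict(max_size=int(1e6)),
        qf_kwargs=dict(hidden_sizes=[400, 300]),
        vf_kwargs=dict(hidden_sizes=[400, 300]),
        policy_kwargs=dict(hidden_sizes=[400, 300]),
        algo_kwargs=dict(),
    )
# Usage (assuming `MyGoalEnv` is some dict-observation goal env class):
#     her_twin_sac_experiment(example_her_twin_sac_variant(MyGoalEnv))
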
def grill_her_sac_experiment(variant):
    env = variant["env_class"](**variant['env_kwargs'])
    render = variant["render"]
    rdim = variant["rdim"]
    vae_path = variant["vae_paths"][str(rdim)]
    reward_params = variant.get("reward_params", dict())
    init_camera = variant.get("init_camera", None)
    if init_camera is None:
        camera_name = "topview"
    else:
        camera_name = None
    # Wrap the raw env so it produces 84x84 image observations.
    env = ImageEnv(
        env,
        84,
        init_camera=init_camera,
        camera_name=camera_name,
        transpose=True,
        normalize=True,
    )
    # Wrap with a pre-trained VAE so observations and goals live in latent space.
    env = VAEWrappedEnv(
        env,
        vae_path,
        decode_goals=render,
        render_goals=render,
        render_rollouts=render,
        reward_params=reward_params,
        **variant.get('vae_wrapped_env_kwargs', {})
    )
    if variant['normalize']:
        env = NormalizedBoxEnv(env)
    observation_key = variant.get('observation_key', 'latent_observation')
    desired_goal_key = variant.get('desired_goal_key', 'latent_desired_goal')
    achieved_goal_key = desired_goal_key.replace("desired", "achieved")
    obs_dim = (
        env.observation_space.spaces[observation_key].low.size
        + env.observation_space.spaces[desired_goal_key].low.size
    )
    action_dim = env.action_space.low.size
    hidden_sizes = variant.get('hidden_sizes', [400, 300])
    qf = ConcatMlp(
        input_size=obs_dim + action_dim,
        output_size=1,
        hidden_sizes=hidden_sizes,
    )
    vf = ConcatMlp(
        input_size=obs_dim,
        output_size=1,
        hidden_sizes=hidden_sizes,
    )
    policy = TanhGaussianPolicy(
        obs_dim=obs_dim,
        action_dim=action_dim,
        hidden_sizes=hidden_sizes,
    )
    training_mode = variant.get("training_mode", "train")
    testing_mode = variant.get("testing_mode", "test")
    # Separate deep-copied env instances for testing, training, relabeling, and video.
    testing_env = pickle.loads(pickle.dumps(env))
    testing_env.mode(testing_mode)
    training_env = pickle.loads(pickle.dumps(env))
    training_env.mode(training_mode)
    relabeling_env = pickle.loads(pickle.dumps(env))
    relabeling_env.mode(training_mode)
    relabeling_env.disable_render()
    video_vae_env = pickle.loads(pickle.dumps(env))
    video_vae_env.mode("video_vae")
    video_goal_env = pickle.loads(pickle.dumps(env))
    video_goal_env.mode("video_env")
    # HER replay buffer that relabels achieved latent goals.
    replay_buffer = ObsDictRelabelingBuffer(
        env=relabeling_env,
        observation_key=observation_key,
        desired_goal_key=desired_goal_key,
        achieved_goal_key=achieved_goal_key,
        **variant['replay_kwargs']
    )
    variant["algo_kwargs"]["replay_buffer"] = replay_buffer
    algorithm = HerSac(
        testing_env,
        training_env=training_env,
        qf=qf,
        vf=vf,
        policy=policy,
        render=render,
        render_during_eval=render,
        observation_key=observation_key,
        desired_goal_key=desired_goal_key,
        **variant['algo_kwargs']
    )
    if ptu.gpu_enabled():
        print("using GPU")
        qf.to(ptu.device)
        vf.to(ptu.device)
        policy.to(ptu.device)
        algorithm.to(ptu.device)
        for e in [testing_env, training_env, video_vae_env, video_goal_env]:
            e.vae.to(ptu.device)
    algorithm.train()
    # Optionally dump final rollout videos in the goal env and in VAE reconstructions.
    if variant.get("save_video", True):
        logdir = logger.get_snapshot_dir()
        policy.train(False)
        filename = osp.join(logdir, 'video_final_env.mp4')
        rollout_function = rf.create_rollout_function(
            rf.multitask_rollout,
            max_path_length=algorithm.max_path_length,
            observation_key=algorithm.observation_key,
            desired_goal_key=algorithm.desired_goal_key,
        )
        dump_video(video_goal_env, policy, filename, rollout_function)
        filename = osp.join(logdir, 'video_final_vae.mp4')
        dump_video(video_vae_env, policy, filename, rollout_function)
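
# Illustrative sketch only: the builder below lists the keys that
# grill_her_sac_experiment() reads from `variant`. `env_class` and `vae_path`
# are supplied by the caller (the VAE checkpoint must already exist), and all
# other values are assumptions for demonstration, not settings from this repo.
def example_grill_her_sac_variant(env_class, vae_path, rdim=16):
    """Return a minimal variant dict for grill_her_sac_experiment (illustrative values)."""
    return dict(
        env_class=env_class,
        env_kwargs=dict(),
        render=False,
        rdim=rdim,
        vae_paths={str(rdim): vae_path},
        reward_params=dict(),
        init_camera=None,
        normalize=False,
        hidden_sizes=[400, 300],
        training_mode='train',
        testing_mode='test',
        replay_kwargs=dict(max_size=int(1e6)),
        algo_kwargs=dict(),
        save_video=True,
    )
# Usage (assuming an image-capable env class and a trained VAE checkpoint path):
#     grill_her_sac_experiment(example_grill_her_sac_variant(MyImageEnv, 'vae.pkl'))
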
def main(
        env_name,
        seed,
        deterministic,
        traj_prior,
        start_ft_after,
        ft_steps,
        avoid_freezing_z,
        lr,
        batch_size,
        avoid_loading_critics,
):
    # Load the experiment config for this env and merge it into the defaults.
    config = "configs/{}.json".format(env_name)
    variant = default_config
    if config:
        with open(osp.join(config)) as f:
            exp_params = json.load(f)
        variant = deep_update_dict(exp_params, variant)
    exp_name = variant['env_name']
    print("Experiment: {}".format(exp_name))

    env = NormalizedBoxEnv(ENVS[exp_name](**variant['env_params']))
    obs_dim = int(np.prod(env.observation_space.shape))
    action_dim = int(np.prod(env.action_space.shape))
    print("Observation space:")
    print(env.observation_space)
    print(obs_dim)
    print("Action space:")
    print(env.action_space)
    print(action_dim)
    print("-" * 10)

    # instantiate networks
    latent_dim = variant['latent_size']
    reward_dim = 1
    context_encoder_input_dim = (
        2 * obs_dim + action_dim + reward_dim
        if variant['algo_params']['use_next_obs_in_context']
        else obs_dim + action_dim + reward_dim
    )
    context_encoder_output_dim = (
        latent_dim * 2
        if variant['algo_params']['use_information_bottleneck']
        else latent_dim
    )
    net_size = variant['net_size']
    recurrent = variant['algo_params']['recurrent']
    encoder_model = RecurrentEncoder if recurrent else MlpEncoder

    # Context encoder whose output parameterizes the latent task variable z.
    context_encoder = encoder_model(
        hidden_sizes=[200, 200, 200],
        input_size=context_encoder_input_dim,
        output_size=context_encoder_output_dim,
    )
    qf1 = FlattenMlp(
        hidden_sizes=[net_size, net_size, net_size],
        input_size=obs_dim + action_dim + latent_dim,
        output_size=1,
    )
    qf2 = FlattenMlp(
        hidden_sizes=[net_size, net_size, net_size],
        input_size=obs_dim + action_dim + latent_dim,
        output_size=1,
    )
    target_qf1 = qf1.copy()
    target_qf2 = qf2.copy()
    policy = TanhGaussianPolicy(
        hidden_sizes=[net_size, net_size, net_size],
        obs_dim=obs_dim + latent_dim,
        latent_dim=latent_dim,
        action_dim=action_dim,
    )
    agent = PEARLAgent(
        latent_dim,
        context_encoder,
        policy,
        **variant['algo_params']
    )

    # deterministic eval
    if deterministic:
        agent = MakeDeterministic(agent)

    # load trained weights (otherwise simulate random policy)
    path_to_exp = "output/{}/pearl_{}".format(env_name, seed - 1)
    print("Based on experiment: {}".format(path_to_exp))
    context_encoder.load_state_dict(
        torch.load(os.path.join(path_to_exp, 'context_encoder.pth')))
    policy.load_state_dict(
        torch.load(os.path.join(path_to_exp, 'policy.pth')))
    if not avoid_loading_critics:
        qf1.load_state_dict(
            torch.load(os.path.join(path_to_exp, 'qf1.pth')))
        qf2.load_state_dict(
            torch.load(os.path.join(path_to_exp, 'qf2.pth')))
        target_qf1.load_state_dict(
            torch.load(os.path.join(path_to_exp, 'target_qf1.pth')))
        target_qf2.load_state_dict(
            torch.load(os.path.join(path_to_exp, 'target_qf2.pth')))

    # optional GPU mode
    ptu.set_gpu_mode(variant['util_params']['use_gpu'], variant['util_params']['gpu_id'])
    if ptu.gpu_enabled():
        agent.to(ptu.device)
        policy.to(ptu.device)
        context_encoder.to(ptu.device)
        qf1.to(ptu.device)
        qf2.to(ptu.device)
        target_qf1.to(ptu.device)
        target_qf2.to(ptu.device)

    # Fine-tune the pre-trained PEARL agent on the target task.
    helper = PEARLFineTuningHelper(
        env=env,
        agent=agent,
        qf1=qf1,
        qf2=qf2,
        target_qf1=target_qf1,
        target_qf2=target_qf2,
        num_exp_traj_eval=traj_prior,
        start_fine_tuning=start_ft_after,
        fine_tuning_steps=ft_steps,
        should_freeze_z=(not avoid_freezing_z),
        replay_buffer_size=int(1e6),
        batch_size=batch_size,
        discount=0.99,
        policy_lr=lr,
        qf_lr=lr,
        temp_lr=lr,
        target_entropy=-action_dim,
    )
    helper.fine_tune(variant=variant, seed=seed)
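
# Illustrative sketch only: the call below mirrors main()'s signature so the
# expected arguments are visible in one place. The env name and all numeric
# values are assumptions; they presuppose a matching configs/cheetah-vel.json
# and a pre-trained run under output/cheetah-vel/pearl_0 (i.e. seed - 1 = 0).
def example_fine_tuning_call():
    """Show one hypothetical way to invoke main() for PEARL fine-tuning."""
    main(
        env_name='cheetah-vel',
        seed=1,
        deterministic=True,
        traj_prior=2,            # trajectories collected with the prior before posterior sampling
        start_ft_after=10,       # start fine-tuning after this many evaluation trajectories
        ft_steps=1000,           # number of gradient steps during fine-tuning
        avoid_freezing_z=False,  # keep the inferred task variable z frozen
        lr=3e-4,
        batch_size=256,
        avoid_loading_critics=False,
    )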