def skewfit_experiment(variant):
    import rlkit.torch.pytorch_util as ptu
    from rlkit.data_management.online_vae_replay_buffer import \
        OnlineVaeRelabelingBuffer
    from rlkit.torch.networks import FlattenMlp
    from rlkit.torch.sac.policies import TanhGaussianPolicy
    from rlkit.torch.vae.vae_trainer import ConvVAETrainer
    import rlkit.torch.vae.vae_schedules as vae_schedules
    # SACTrainer, HERTrainer, VAEWrappedEnvPathCollector, MakeDeterministic,
    # OnlineVaeAlgorithm, and get_envs are used below and are assumed to be
    # imported at module level (as in rlkit's Skew-Fit launcher).

    # Get parameters for training the VAE and RIG/Skew-Fit, falling back to
    # inline defaults for anything missing from `variant`.
    env = get_envs(variant)
    observation_key = variant.get('observation_key', 'latent_observation')
    desired_goal_key = variant.get('desired_goal_key', 'latent_desired_goal')
    achieved_goal_key = desired_goal_key.replace("desired", "achieved")
    obs_dim = (env.observation_space.spaces[observation_key].low.size
               + env.observation_space.spaces[desired_goal_key].low.size)
    action_dim = env.action_space.low.size
    hidden_sizes = variant.get('hidden_sizes', [400, 300])
    replay_buffer_kwargs = variant.get(
        'replay_buffer_kwargs',
        dict(
            start_skew_epoch=10,
            max_size=int(100000),
            fraction_goals_rollout_goals=0.2,
            fraction_goals_env_goals=0.5,
            exploration_rewards_type='None',
            vae_priority_type='vae_prob',
            priority_function_kwargs=dict(
                sampling_method='importance_sampling',
                decoder_distribution='gaussian_identity_variance',
                num_latents_to_sample=10,
            ),
            power=0,
            relabeling_goal_sampling_mode='vae_prior',
        ))
    online_vae_trainer_kwargs = variant.get('online_vae_trainer_kwargs',
                                            dict(beta=20, lr=1e-3))
    max_path_length = variant.get('max_path_length', 50)
    algo_kwargs = variant.get(
        'algo_kwargs',
        dict(
            batch_size=1024,
            num_epochs=1000,
            num_eval_steps_per_epoch=500,
            num_expl_steps_per_train_loop=500,
            num_trains_per_train_loop=1000,
            min_num_steps_before_training=10000,
            vae_training_schedule=vae_schedules.custom_schedule_2,
            oracle_data=False,
            vae_save_period=50,
            parallel_vae_train=False,
        ))
    twin_sac_trainer_kwargs = variant.get(
        'twin_sac_trainer_kwargs',
        dict(
            discount=0.99,
            reward_scale=1,
            soft_target_tau=1e-3,
            target_update_period=1,
            use_automatic_entropy_tuning=True,
        ))

    # Build the twin Q-functions, their target networks, and the policy.
    qf1 = FlattenMlp(input_size=obs_dim + action_dim,
                     output_size=1,
                     hidden_sizes=hidden_sizes)
    qf2 = FlattenMlp(input_size=obs_dim + action_dim,
                     output_size=1,
                     hidden_sizes=hidden_sizes)
    target_qf1 = FlattenMlp(input_size=obs_dim + action_dim,
                            output_size=1,
                            hidden_sizes=hidden_sizes)
    target_qf2 = FlattenMlp(input_size=obs_dim + action_dim,
                            output_size=1,
                            hidden_sizes=hidden_sizes)
    policy = TanhGaussianPolicy(obs_dim=obs_dim,
                                action_dim=action_dim,
                                hidden_sizes=hidden_sizes)
    vae = variant['vae_model']

    # Create a replay buffer that relabels goals and supports online VAE training.
    replay_buffer = OnlineVaeRelabelingBuffer(
        vae=vae,
        env=env,
        observation_key=observation_key,
        desired_goal_key=desired_goal_key,
        achieved_goal_key=achieved_goal_key,
        **replay_buffer_kwargs)

    # Create an online VAE trainer to train the VAE on the fly.
    vae_trainer = ConvVAETrainer(variant['vae_train_data'],
                                 variant['vae_test_data'],
                                 vae,
                                 **online_vae_trainer_kwargs)

    # Create a SACTrainer to learn the soft Q-functions and the corresponding
    # policy, and wrap it with HER for goal relabeling.
    trainer = SACTrainer(env=env,
                         policy=policy,
                         qf1=qf1,
                         qf2=qf2,
                         target_qf1=target_qf1,
                         target_qf2=target_qf2,
                         **twin_sac_trainer_kwargs)
    trainer = HERTrainer(trainer)

    eval_path_collector = VAEWrappedEnvPathCollector(
        variant.get('evaluation_goal_sampling_mode', 'reset_of_env'),
        env,
        MakeDeterministic(policy),
        max_path_length,
        observation_key=observation_key,
        desired_goal_key=desired_goal_key,
    )
    expl_path_collector = VAEWrappedEnvPathCollector(
        variant.get('exploration_goal_sampling_mode', 'vae_prior'),
        env,
        policy,
        max_path_length,
        observation_key=observation_key,
        desired_goal_key=desired_goal_key,
    )

    algorithm = OnlineVaeAlgorithm(
        trainer=trainer,
        exploration_env=env,
        evaluation_env=env,
        exploration_data_collector=expl_path_collector,
        evaluation_data_collector=eval_path_collector,
        replay_buffer=replay_buffer,
        vae=vae,
        vae_trainer=vae_trainer,
        max_path_length=max_path_length,
        **algo_kwargs)

    if variant['custom_goal_sampler'] == 'replay_buffer':
        env.custom_goal_sampler = replay_buffer.sample_buffer_goals

    algorithm.to(ptu.device)
    vae.to(ptu.device)
    algorithm.train()
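# Illustrative usage sketch (not from the original code): the minimal `variant`
# the launcher above expects. Keys read with variant[...] ('vae_model',
# 'vae_train_data', 'vae_test_data', 'custom_goal_sampler') have no defaults
# and must be supplied; everything else falls back to the defaults defined
# inside skewfit_experiment. The values below are placeholders that only show
# the expected types -- a real run needs a ConvVAE instance, real image data,
# and whatever env-related keys get_envs() expects in this codebase.
import numpy as np

example_variant = dict(
    vae_model=None,                       # placeholder: a ConvVAE instance in a real run
    vae_train_data=np.zeros((100, 6912), dtype=np.float32),  # placeholder flattened 48x48x3 images
    vae_test_data=np.zeros((10, 6912), dtype=np.float32),
    custom_goal_sampler='replay_buffer',  # enables goal sampling from the replay buffer
    max_path_length=50,
    hidden_sizes=[400, 300],
    exploration_goal_sampling_mode='vae_prior',
    evaluation_goal_sampling_mode='reset_of_env',
)
# skewfit_experiment(example_variant)  # would also require get_envs() to resolve an env from `variant`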
def skewfit_experiment(variant, other_variant):
    # Same launcher as above, except that every kwargs dict must be supplied in
    # `variant` (no inline defaults) and `other_variant` is threaded through to
    # the VAE trainer and the path collectors.
    import rlkit.torch.pytorch_util as ptu
    from rlkit.data_management.online_vae_replay_buffer import \
        OnlineVaeRelabelingBuffer
    from rlkit.torch.networks import FlattenMlp
    from rlkit.torch.sac.policies import TanhGaussianPolicy
    from rlkit.torch.vae.vae_trainer import ConvVAETrainer

    skewfit_preprocess_variant(variant)
    env = get_envs(variant)

    # Optional dataset of uniformly sampled states, handed to the algorithm
    # (used for VAE evaluation/debugging).
    uniform_dataset_fn = variant.get('generate_uniform_dataset_fn', None)
    if uniform_dataset_fn:
        uniform_dataset = uniform_dataset_fn(
            **variant['generate_uniform_dataset_kwargs'])
    else:
        uniform_dataset = None

    observation_key = variant.get('observation_key', 'latent_observation')
    desired_goal_key = variant.get('desired_goal_key', 'latent_desired_goal')
    achieved_goal_key = desired_goal_key.replace("desired", "achieved")
    obs_dim = (env.observation_space.spaces[observation_key].low.size
               + env.observation_space.spaces[desired_goal_key].low.size)
    action_dim = env.action_space.low.size
    hidden_sizes = variant.get('hidden_sizes', [400, 300])

    # Twin Q-functions, their target networks, and the policy.
    qf1 = FlattenMlp(
        input_size=obs_dim + action_dim,
        output_size=1,
        hidden_sizes=hidden_sizes,
    )
    qf2 = FlattenMlp(
        input_size=obs_dim + action_dim,
        output_size=1,
        hidden_sizes=hidden_sizes,
    )
    target_qf1 = FlattenMlp(
        input_size=obs_dim + action_dim,
        output_size=1,
        hidden_sizes=hidden_sizes,
    )
    target_qf2 = FlattenMlp(
        input_size=obs_dim + action_dim,
        output_size=1,
        hidden_sizes=hidden_sizes,
    )
    policy = TanhGaussianPolicy(
        obs_dim=obs_dim,
        action_dim=action_dim,
        hidden_sizes=hidden_sizes,
    )

    # Replay buffer with online VAE relabeling; the VAE comes from the wrapped env.
    vae = env.vae
    replay_buffer = OnlineVaeRelabelingBuffer(
        vae=env.vae,
        env=env,
        observation_key=observation_key,
        desired_goal_key=desired_goal_key,
        achieved_goal_key=achieved_goal_key,
        **variant['replay_buffer_kwargs'])
    vae_trainer = ConvVAETrainer(variant['vae_train_data'],
                                 variant['vae_test_data'],
                                 env.vae,
                                 other_variant,
                                 **variant['online_vae_trainer_kwargs'])

    assert 'vae_training_schedule' not in variant, "Just put it in algo_kwargs"
    max_path_length = variant['max_path_length']

    trainer = SACTrainer(env=env,
                         policy=policy,
                         qf1=qf1,
                         qf2=qf2,
                         target_qf1=target_qf1,
                         target_qf2=target_qf2,
                         **variant['twin_sac_trainer_kwargs'])
    trainer = HERTrainer(trainer)

    eval_path_collector = VAEWrappedEnvPathCollector(
        variant['evaluation_goal_sampling_mode'],
        env,
        MakeDeterministic(policy),
        max_path_length,
        other_variant=other_variant,
        observation_key=observation_key,
        desired_goal_key=desired_goal_key,
    )
    expl_path_collector = VAEWrappedEnvPathCollector(
        variant['exploration_goal_sampling_mode'],
        env,
        policy,
        max_path_length,
        other_variant=other_variant,
        observation_key=observation_key,
        desired_goal_key=desired_goal_key,
    )

    algorithm = OnlineVaeAlgorithm(
        trainer=trainer,
        exploration_env=env,
        evaluation_env=env,
        exploration_data_collector=expl_path_collector,
        evaluation_data_collector=eval_path_collector,
        replay_buffer=replay_buffer,
        vae=vae,
        vae_trainer=vae_trainer,
        uniform_dataset=uniform_dataset,
        max_path_length=max_path_length,
        **variant['algo_kwargs'])

    if variant['custom_goal_sampler'] == 'replay_buffer':
        env.custom_goal_sampler = replay_buffer.sample_buffer_goals

    algorithm.to(ptu.device)
    vae.to(ptu.device)
    algorithm.train()
def skewfit_experiment(cfgs):
    import rlkit.torch.pytorch_util as ptu
    from rlkit.data_management.online_vae_replay_buffer import \
        OnlineVaeRelabelingBuffer
    from rlkit.torch.networks import FlattenMlp
    from rlkit.torch.sac.policies import TanhGaussianPolicy
    from rlkit.torch.vae.vae_trainer import ConvVAETrainer

    skewfit_preprocess_variant(cfgs)
    env = get_envs(cfgs)  # TODO

    uniform_dataset_fn = cfgs.GENERATE_VAE_DATASET.get(
        'uniform_dataset_generator', None)
    if uniform_dataset_fn:
        uniform_dataset = uniform_dataset_fn(
            **cfgs.GENERATE_VAE_DATASET.generate_uniform_dataset_kwargs)
    else:
        uniform_dataset = None

    observation_key = cfgs.SKEW_FIT.get('observation_key',
                                        'latent_observation')
    desired_goal_key = cfgs.SKEW_FIT.get('desired_goal_key',
                                         'latent_desired_goal')
    achieved_goal_key = desired_goal_key.replace("desired", "achieved")
    obs_dim = (env.observation_space.spaces[observation_key].low.size
               + env.observation_space.spaces[desired_goal_key].low.size)
    action_dim = env.action_space.low.size
    hidden_sizes = cfgs.Q_FUNCTION.get('hidden_sizes', [400, 300])

    qf1 = FlattenMlp(
        input_size=obs_dim + action_dim,
        output_size=1,
        hidden_sizes=hidden_sizes,
    )
    qf2 = FlattenMlp(
        input_size=obs_dim + action_dim,
        output_size=1,
        hidden_sizes=hidden_sizes,
    )
    target_qf1 = FlattenMlp(
        input_size=obs_dim + action_dim,
        output_size=1,
        hidden_sizes=hidden_sizes,
    )
    target_qf2 = FlattenMlp(
        input_size=obs_dim + action_dim,
        output_size=1,
        hidden_sizes=hidden_sizes,
    )
    policy = TanhGaussianPolicy(
        obs_dim=obs_dim,
        action_dim=action_dim,
        hidden_sizes=cfgs.POLICY.get('hidden_sizes', [400, 300]),
    )

    vae = env.vae
    replay_buffer = OnlineVaeRelabelingBuffer(
        vae=env.vae,
        env=env,
        observation_key=observation_key,
        desired_goal_key=desired_goal_key,
        achieved_goal_key=achieved_goal_key,
        priority_function_kwargs=cfgs.PRIORITY_FUNCTION,
        **cfgs.REPLAY_BUFFER)
    vae_trainer = ConvVAETrainer(
        cfgs.VAE_TRAINER.train_data,
        cfgs.VAE_TRAINER.test_data,
        env.vae,
        beta=cfgs.VAE_TRAINER.beta,
        lr=cfgs.VAE_TRAINER.lr,
    )
    # assert 'vae_training_schedule' not in cfgs, "Just put it in algo_kwargs"

    max_path_length = cfgs.SKEW_FIT.max_path_length
    trainer = SACTrainer(env=env,
                         policy=policy,
                         qf1=qf1,
                         qf2=qf2,
                         target_qf1=target_qf1,
                         target_qf2=target_qf2,
                         **cfgs.TWIN_SAC_TRAINER)
    trainer = HERTrainer(trainer)

    eval_path_collector = VAEWrappedEnvPathCollector(
        cfgs.SKEW_FIT.evaluation_goal_sampling_mode,
        env,
        MakeDeterministic(policy),
        decode_goals=True,  # TODO check this
        observation_key=observation_key,
        desired_goal_key=desired_goal_key,
    )
    expl_path_collector = VAEWrappedEnvPathCollector(
        cfgs.SKEW_FIT.exploration_goal_sampling_mode,
        env,
        policy,
        decode_goals=True,
        observation_key=observation_key,
        desired_goal_key=desired_goal_key,
    )

    algorithm = OnlineVaeAlgorithm(
        trainer=trainer,
        exploration_env=env,
        evaluation_env=env,
        exploration_data_collector=expl_path_collector,
        evaluation_data_collector=eval_path_collector,
        replay_buffer=replay_buffer,
        vae=vae,
        vae_trainer=vae_trainer,
        uniform_dataset=uniform_dataset,  # TODO used in test vae
        max_path_length=max_path_length,
        parallel_vae_train=cfgs.VAE_TRAINER.parallel_train,
        **cfgs.ALGORITHM)

    if cfgs.SKEW_FIT.custom_goal_sampler == 'replay_buffer':
        env.custom_goal_sampler = replay_buffer.sample_buffer_goals

    algorithm.to(ptu.device)
    vae.to(ptu.device)
    algorithm.train()
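# A minimal sketch (not part of the original code) of the structured `cfgs`
# object the cfgs-based launcher above reads. The real project presumably uses
# its own config class (e.g. a yacs/EasyDict-style node); the tiny `_Section`
# stand-in below only mimics the two access patterns used above: attribute
# access and dict-style .get() with a default. Values mirror the inline
# defaults from the first launcher version; they are illustrative, not tuned.
class _Section(dict):
    """Attribute-style view over a plain dict (supports .get() and **unpacking)."""
    __getattr__ = dict.__getitem__
    __setattr__ = dict.__setitem__


cfgs = _Section(
    # 'uniform_dataset_generator' omitted, so .get(..., None) returns None.
    GENERATE_VAE_DATASET=_Section(),
    SKEW_FIT=_Section(
        observation_key='latent_observation',
        desired_goal_key='latent_desired_goal',
        max_path_length=50,
        evaluation_goal_sampling_mode='reset_of_env',
        exploration_goal_sampling_mode='vae_prior',
        custom_goal_sampler='replay_buffer',
    ),
    Q_FUNCTION=_Section(hidden_sizes=[400, 300]),
    POLICY=_Section(hidden_sizes=[400, 300]),
    PRIORITY_FUNCTION=_Section(
        sampling_method='importance_sampling',
        decoder_distribution='gaussian_identity_variance',
        num_latents_to_sample=10,
    ),
    REPLAY_BUFFER=_Section(
        start_skew_epoch=10,
        max_size=100000,
        fraction_goals_rollout_goals=0.2,
        fraction_goals_env_goals=0.5,
        exploration_rewards_type='None',
        vae_priority_type='vae_prob',
        power=0,
        relabeling_goal_sampling_mode='vae_prior',
    ),
    # train_data / test_data (image arrays) would be attached at runtime.
    VAE_TRAINER=_Section(beta=20, lr=1e-3, parallel_train=False),
    TWIN_SAC_TRAINER=_Section(
        discount=0.99,
        reward_scale=1,
        soft_target_tau=1e-3,
        target_update_period=1,
        use_automatic_entropy_tuning=True,
    ),
    ALGORITHM=_Section(
        batch_size=1024,
        num_epochs=1000,
        num_eval_steps_per_epoch=500,
        num_expl_steps_per_train_loop=500,
        num_trains_per_train_loop=1000,
        min_num_steps_before_training=10000,
        vae_save_period=50,
        oracle_data=False,
        # A vae_training_schedule entry would also go here, per the
        # commented-out assert in the launcher above.
    ),
)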