def image_based_goal_conditioned_sac_experiment(
        max_path_length,
        qf_kwargs,
        sac_trainer_kwargs,
        replay_buffer_kwargs,
        policy_kwargs,
        algo_kwargs,
        cnn_kwargs,
        policy_type='tanh-normal',
        env_id=None,
        env_class=None,
        env_kwargs=None,
        exploration_policy_kwargs=None,
        evaluation_goal_sampling_mode=None,
        exploration_goal_sampling_mode=None,
        reward_type='state_distance',
        env_renderer_kwargs=None,
        # Data augmentations
        apply_random_crops=False,
        random_crop_pixel_shift=4,
        # Video parameters
        save_video=True,
        save_video_kwargs=None,
        video_renderer_kwargs=None,
):
    """Train goal-conditioned SAC where observations AND goals are images.

    Builds exploration/evaluation contextual image envs, CNN-based
    Q-functions and policy, a goal-relabeling replay buffer, and a SAC
    trainer, then runs the batch RL loop (optionally saving rollout videos
    after each epoch).

    :param max_path_length: Maximum rollout length.
    :param qf_kwargs: Kwargs for the Q-function MLP head (``ConcatMlp``).
    :param sac_trainer_kwargs: Kwargs forwarded to ``SACTrainer``.
    :param replay_buffer_kwargs: Kwargs for the relabeling replay buffer.
    :param policy_kwargs: Kwargs for the policy's ``MultiHeadedMlp`` head.
    :param algo_kwargs: Kwargs forwarded to ``TorchBatchRLAlgorithm``.
    :param cnn_kwargs: Kwargs for the ``BasicCNN`` image encoders.
    :param policy_type: One of 'normal', 'tanh-normal', 'normal-tanh-mean'.
    :param reward_type: 'state_distance' (reward from the underlying state
        env) or 'pixel_distance' (negative L2 in pixel space).
    :param apply_random_crops: If True, apply pad-then-random-crop data
        augmentation to observation and goal images during training.
    :param random_crop_pixel_shift: Padding (in pixels) used for the crops.
    :raises ValueError: For an unknown ``reward_type`` or ``policy_type``,
        or an unrecognized renderer image format.
    """
    if exploration_policy_kwargs is None:
        exploration_policy_kwargs = {}
    if not save_video_kwargs:
        save_video_kwargs = {}
    if not env_renderer_kwargs:
        env_renderer_kwargs = {}
    if not video_renderer_kwargs:
        video_renderer_kwargs = {}
    # Dictionary keys used consistently for observations and goals below.
    img_observation_key = 'image_observation'
    img_desired_goal_key = 'image_desired_goal'
    state_observation_key = 'state_observation'
    state_desired_goal_key = 'state_desired_goal'
    state_achieved_goal_key = 'state_achieved_goal'
    # For hindsight relabeling: a visited observation can be reused as a goal.
    sample_context_from_obs_dict_fn = RemapKeyFn({
        'image_desired_goal': 'image_observation',
        'state_desired_goal': 'state_observation',
    })

    def setup_contextual_env(env_id, env_class, env_kwargs, goal_sampling_mode,
                             renderer):
        # Wrap a state-based multitask env so it emits image observations and
        # samples image goals, with the configured contextual reward.
        state_env = get_gym_env(env_id, env_class=env_class,
                                env_kwargs=env_kwargs)
        state_env.goal_sampling_mode = goal_sampling_mode
        state_goal_distribution = GoalDictDistributionFromMultitaskEnv(
            state_env,
            desired_goal_keys=[state_desired_goal_key],
        )
        state_diag_fn = GoalConditionedDiagnosticsToContextualDiagnostics(
            state_env.goal_conditioned_diagnostics,
            desired_goal_key=state_desired_goal_key,
            observation_key=state_observation_key,
        )
        image_goal_distribution = AddImageDistribution(
            env=state_env,
            base_distribution=state_goal_distribution,
            image_goal_key=img_desired_goal_key,
            renderer=renderer,
        )
        # Goals are drawn from a fixed pool of 5000 pre-rendered images,
        # presumably to avoid re-rendering on every reset — TODO confirm.
        goal_distribution = PresampledDistribution(image_goal_distribution,
                                                   5000)
        img_env = InsertImageEnv(state_env, renderer=renderer)
        if reward_type == 'state_distance':
            reward_fn = ContextualRewardFnFromMultitaskEnv(
                env=state_env,
                achieved_goal_from_observation=IndexIntoAchievedGoal(
                    'state_observation'),
                desired_goal_key=state_desired_goal_key,
                achieved_goal_key=state_achieved_goal_key,
            )
        elif reward_type == 'pixel_distance':
            reward_fn = NegativeL2Distance(
                achieved_goal_from_observation=IndexIntoAchievedGoal(
                    img_observation_key),
                desired_goal_key=img_desired_goal_key,
            )
        else:
            raise ValueError(reward_type)
        env = ContextualEnv(
            img_env,
            context_distribution=goal_distribution,
            reward_fn=reward_fn,
            observation_key=img_observation_key,
            contextual_diagnostics_fns=[state_diag_fn],
            update_env_info_fn=delete_info,
        )
        return env, goal_distribution, reward_fn

    env_renderer = EnvRenderer(**env_renderer_kwargs)
    expl_env, expl_context_distrib, expl_reward = setup_contextual_env(
        env_id, env_class, env_kwargs, exploration_goal_sampling_mode,
        env_renderer)
    eval_env, eval_context_distrib, eval_reward = setup_contextual_env(
        env_id, env_class, env_kwargs, evaluation_goal_sampling_mode,
        env_renderer)
    action_dim = expl_env.action_space.low.size
    # Unpack the image shape according to the renderer's layout convention.
    if env_renderer.output_image_format == 'WHC':
        img_width, img_height, img_num_channels = (
            expl_env.observation_space[img_observation_key].shape)
    elif env_renderer.output_image_format == 'CHW':
        img_num_channels, img_height, img_width = (
            expl_env.observation_space[img_observation_key].shape)
    else:
        raise ValueError(env_renderer.output_image_format)

    def create_qf():
        # Q(s, g, a): one CNN applied to both the observation and goal
        # images, flattened and concatenated with the action into an MLP.
        cnn = BasicCNN(
            input_width=img_width,
            input_height=img_height,
            input_channels=img_num_channels,
            **cnn_kwargs)
        joint_cnn = ApplyConvToStateAndGoalImage(cnn)
        return basic.MultiInputSequential(
            ApplyToObs(joint_cnn),
            basic.FlattenEachParallel(),
            ConcatMlp(
                input_size=joint_cnn.output_size + action_dim,
                output_size=1,
                **qf_kwargs))

    qf1 = create_qf()
    qf2 = create_qf()
    target_qf1 = create_qf()
    target_qf2 = create_qf()

    # The policy gets its own CNN encoder (not shared with the Q-functions).
    cnn = BasicCNN(
        input_width=img_width,
        input_height=img_height,
        input_channels=img_num_channels,
        **cnn_kwargs)
    joint_cnn = ApplyConvToStateAndGoalImage(cnn)
    policy_obs_dim = joint_cnn.output_size
    # The two MLP heads produce the mean and (log-)scale parameters of the
    # action distribution — presumably; exact head semantics live in
    # PolicyFromDistributionGenerator / Gaussian — TODO confirm.
    if policy_type == 'normal':
        obs_processor = nn.Sequential(
            joint_cnn,
            basic.Flatten(),
            MultiHeadedMlp(
                input_size=policy_obs_dim,
                output_sizes=[action_dim, action_dim],
                **policy_kwargs))
        policy = PolicyFromDistributionGenerator(Gaussian(obs_processor))
    elif policy_type == 'tanh-normal':
        obs_processor = nn.Sequential(
            joint_cnn,
            basic.Flatten(),
            MultiHeadedMlp(
                input_size=policy_obs_dim,
                output_sizes=[action_dim, action_dim],
                **policy_kwargs))
        policy = PolicyFromDistributionGenerator(TanhGaussian(obs_processor))
    elif policy_type == 'normal-tanh-mean':
        obs_processor = nn.Sequential(
            joint_cnn,
            basic.Flatten(),
            MultiHeadedMlp(
                input_size=policy_obs_dim,
                output_sizes=[action_dim, action_dim],
                output_activations=['tanh', 'identity'],
                **policy_kwargs))
        policy = PolicyFromDistributionGenerator(Gaussian(obs_processor))
    else:
        raise ValueError("Unknown policy type: {}".format(policy_type))

    if apply_random_crops:
        # Pad-then-random-crop augmentation. Observation and goal are
        # cropped jointly so the pair stays aligned.
        pad = BatchPad(
            env_renderer.output_image_format,
            random_crop_pixel_shift,
            random_crop_pixel_shift,
        )
        crop = JointRandomCrop(
            env_renderer.output_image_format,
            env_renderer.image_shape,
        )

        def concat_context_to_obs(batch, *args, **kwargs):
            # Augment obs/goal and next_obs/goal (independent crops for the
            # two transitions), then concatenate goal onto the observation.
            obs = batch['observations']
            next_obs = batch['next_observations']
            context = batch[img_desired_goal_key]
            obs_padded = pad(obs)
            next_obs_padded = pad(next_obs)
            context_padded = pad(context)
            obs_aug, context_aug = crop(obs_padded, context_padded)
            next_obs_aug, next_context_aug = crop(next_obs_padded,
                                                  context_padded)
            batch['observations'] = np.concatenate([obs_aug, context_aug],
                                                   axis=1)
            batch['next_observations'] = np.concatenate(
                [next_obs_aug, next_context_aug], axis=1)
            return batch
    else:
        def concat_context_to_obs(batch, *args, **kwargs):
            # No augmentation: just concatenate the goal onto the obs.
            obs = batch['observations']
            next_obs = batch['next_observations']
            context = batch[img_desired_goal_key]
            batch['observations'] = np.concatenate([obs, context], axis=1)
            batch['next_observations'] = np.concatenate([next_obs, context],
                                                        axis=1)
            return batch
    replay_buffer = ContextualRelabelingReplayBuffer(
        env=eval_env,
        context_keys=[img_desired_goal_key, state_desired_goal_key],
        observation_key=img_observation_key,
        observation_keys=[img_observation_key, state_observation_key],
        context_distribution=eval_context_distrib,
        sample_context_from_obs_dict_fn=sample_context_from_obs_dict_fn,
        reward_fn=eval_reward,
        post_process_batch_fn=concat_context_to_obs,
        **replay_buffer_kwargs)
    trainer = SACTrainer(
        env=expl_env,
        policy=policy,
        qf1=qf1,
        qf2=qf2,
        target_qf1=target_qf1,
        target_qf2=target_qf2,
        **sac_trainer_kwargs)
    # Evaluation uses the deterministic (mean) policy.
    eval_path_collector = ContextualPathCollector(
        eval_env,
        MakeDeterministic(policy),
        observation_key=img_observation_key,
        context_keys_for_policy=[img_desired_goal_key],
    )
    exploration_policy = create_exploration_policy(
        expl_env, policy, **exploration_policy_kwargs)
    expl_path_collector = ContextualPathCollector(
        expl_env,
        exploration_policy,
        observation_key=img_observation_key,
        context_keys_for_policy=[img_desired_goal_key],
    )
    algorithm = TorchBatchRLAlgorithm(
        trainer=trainer,
        exploration_env=expl_env,
        evaluation_env=eval_env,
        exploration_data_collector=expl_path_collector,
        evaluation_data_collector=eval_path_collector,
        replay_buffer=replay_buffer,
        max_path_length=max_path_length,
        **algo_kwargs)
    algorithm.to(ptu.device)
    if save_video:
        rollout_function = partial(
            rf.contextual_rollout,
            max_path_length=max_path_length,
            observation_key=img_observation_key,
            context_keys_for_policy=[img_desired_goal_key],
        )
        # A separate (typically higher-resolution) renderer just for videos.
        video_renderer = EnvRenderer(**video_renderer_kwargs)
        video_eval_env = InsertImageEnv(
            eval_env, renderer=video_renderer, image_key='video_observation')
        video_expl_env = InsertImageEnv(
            expl_env, renderer=video_renderer, image_key='video_observation')
        # Rewards are irrelevant for video generation, so use a constant 0.
        video_eval_env = ContextualEnv(
            video_eval_env,
            context_distribution=eval_env.context_distribution,
            reward_fn=lambda *_: np.array([0]),
            observation_key=img_observation_key,
        )
        video_expl_env = ContextualEnv(
            video_expl_env,
            context_distribution=expl_env.context_distribution,
            reward_fn=lambda *_: np.array([0]),
            observation_key=img_observation_key,
        )
        eval_video_func = get_save_video_function(
            rollout_function,
            video_eval_env,
            MakeDeterministic(policy),
            tag="eval",
            # image_shape[1] is used as the display size — presumably the
            # width/height of a square frame; TODO confirm layout.
            imsize=video_renderer.image_shape[1],
            image_formats=[
                env_renderer.output_image_format,
                env_renderer.output_image_format,
                video_renderer.output_image_format,
            ],
            keys_to_show=[
                'image_desired_goal', 'image_observation', 'video_observation'
            ],
            **save_video_kwargs)
        expl_video_func = get_save_video_function(
            rollout_function,
            video_expl_env,
            exploration_policy,
            tag="xplor",
            imsize=video_renderer.image_shape[1],
            image_formats=[
                env_renderer.output_image_format,
                env_renderer.output_image_format,
                video_renderer.output_image_format,
            ],
            keys_to_show=[
                'image_desired_goal', 'image_observation', 'video_observation'
            ],
            **save_video_kwargs)
        algorithm.post_train_funcs.append(eval_video_func)
        algorithm.post_train_funcs.append(expl_video_func)
    algorithm.train()
def sac_on_gym_goal_env_experiment(
        max_path_length,
        qf_kwargs,
        sac_trainer_kwargs,
        replay_buffer_kwargs,
        policy_kwargs,
        algo_kwargs,
        env_id=None,
        env_class=None,
        env_kwargs=None,
        observation_key='observation',
        desired_goal_key='desired_goal',
        achieved_goal_key='achieved_goal',
        exploration_policy_kwargs=None,
        evaluation_goal_sampling_mode=None,
        exploration_goal_sampling_mode=None,
        # Video parameters
        save_video=True,
        save_video_kwargs=None,
        renderer_kwargs=None,
):
    """Run goal-conditioned SAC on a gym ``GoalEnv`` (Fetch robotics tasks).

    Wraps the env as a contextual env with a thresholded-distance sparse
    reward, builds MLP Q-functions and a tanh-Gaussian policy over
    observation+goal, a relabeling replay buffer, and runs batch RL.

    :param observation_key: Key of the flat observation in the obs dict.
    :param desired_goal_key: Key of the goal sampled each episode.
    :param achieved_goal_key: Key of the goal achieved by the current state;
        used for hindsight relabeling.
    :raises TypeError: If the env is not a known Fetch env, since the
        success threshold for the sparse reward is env-specific.
    """
    if exploration_policy_kwargs is None:
        exploration_policy_kwargs = {}
    if not save_video_kwargs:
        save_video_kwargs = {}
    if not renderer_kwargs:
        renderer_kwargs = {}
    context_key = desired_goal_key
    # For hindsight relabeling: an achieved goal can be reused as a context.
    sample_context_from_obs_dict_fn = RemapKeyFn(
        {context_key: achieved_goal_key})

    def contextual_env_distrib_and_reward(env_id, env_class, env_kwargs,
                                          goal_sampling_mode):
        # Build one contextual env + goal distribution + reward function.
        env = get_gym_env(
            env_id,
            env_class=env_class,
            env_kwargs=env_kwargs,
            unwrap_timed_envs=True,
        )
        env.goal_sampling_mode = goal_sampling_mode
        goal_distribution = GoalDictDistributionFromGymGoalEnv(
            env,
            desired_goal_key=desired_goal_key,
        )
        distance_fn = L2Distance(
            achieved_goal_from_observation=IndexIntoAchievedGoal(
                achieved_goal_key,
            ),
            desired_goal_key=desired_goal_key,
        )
        # The sparse-reward success threshold is env-specific; only the
        # standard Fetch envs are supported here.
        if isinstance(env, (robotics.FetchReachEnv,
                            robotics.FetchPushEnv,
                            robotics.FetchPickAndPlaceEnv,
                            robotics.FetchSlideEnv)):
            success_threshold = 0.05
        else:
            raise TypeError("I don't know the success threshold of env ", env)
        reward_fn = ThresholdDistanceReward(distance_fn, success_threshold)
        diag_fn = GenericGoalConditionedContextualDiagnostics(
            desired_goal_key=desired_goal_key,
            achieved_goal_key=achieved_goal_key,
            success_threshold=success_threshold,
        )
        env = ContextualEnv(
            env,
            context_distribution=goal_distribution,
            reward_fn=reward_fn,
            observation_key=observation_key,
            contextual_diagnostics_fns=[diag_fn],
            update_env_info_fn=delete_info,
        )
        return env, goal_distribution, reward_fn

    expl_env, expl_context_distrib, expl_reward = (
        contextual_env_distrib_and_reward(
            env_id, env_class, env_kwargs, exploration_goal_sampling_mode))
    eval_env, eval_context_distrib, eval_reward = (
        contextual_env_distrib_and_reward(
            env_id, env_class, env_kwargs, evaluation_goal_sampling_mode))
    # Networks consume the observation concatenated with the goal.
    obs_dim = (expl_env.observation_space.spaces[observation_key].low.size +
               expl_env.observation_space.spaces[context_key].low.size)
    action_dim = expl_env.action_space.low.size

    def create_qf():
        return ConcatMlp(
            input_size=obs_dim + action_dim,
            output_size=1,
            **qf_kwargs)

    qf1 = create_qf()
    qf2 = create_qf()
    target_qf1 = create_qf()
    target_qf2 = create_qf()
    policy = TanhGaussianPolicy(
        obs_dim=obs_dim,
        action_dim=action_dim,
        **policy_kwargs)

    def concat_context_to_obs(batch, *args, **kwargs):
        # Concatenate the (possibly relabeled) goal onto the observation.
        obs = batch['observations']
        next_obs = batch['next_observations']
        context = batch[context_key]
        batch['observations'] = np.concatenate([obs, context], axis=1)
        batch['next_observations'] = np.concatenate([next_obs, context],
                                                    axis=1)
        return batch

    replay_buffer = ContextualRelabelingReplayBuffer(
        env=eval_env,
        context_keys=[context_key],
        observation_keys_to_save=[observation_key, achieved_goal_key],
        context_distribution=eval_context_distrib,
        sample_context_from_obs_dict_fn=sample_context_from_obs_dict_fn,
        reward_fn=eval_reward,
        post_process_batch_fn=concat_context_to_obs,
        **replay_buffer_kwargs)
    trainer = SACTrainer(
        env=expl_env,
        policy=policy,
        qf1=qf1,
        qf2=qf2,
        target_qf1=target_qf1,
        target_qf2=target_qf2,
        **sac_trainer_kwargs)
    # Evaluation uses the deterministic (mean) policy.
    eval_path_collector = ContextualPathCollector(
        eval_env,
        MakeDeterministic(policy),
        observation_key=observation_key,
        context_keys_for_policy=[context_key],
    )
    exploration_policy = create_exploration_policy(
        policy=policy, env=expl_env, **exploration_policy_kwargs)
    expl_path_collector = ContextualPathCollector(
        expl_env,
        exploration_policy,
        observation_key=observation_key,
        context_keys_for_policy=[context_key],
    )
    algorithm = TorchBatchRLAlgorithm(
        trainer=trainer,
        exploration_env=expl_env,
        evaluation_env=eval_env,
        exploration_data_collector=expl_path_collector,
        evaluation_data_collector=eval_path_collector,
        replay_buffer=replay_buffer,
        max_path_length=max_path_length,
        **algo_kwargs)
    algorithm.to(ptu.device)
    if save_video:
        # Setting the goal like this is discouraged, but the Fetch
        # environments are designed to visualize the goal by setting their
        # goal parameter directly.
        def set_goal_for_visualization(env, policy, o):
            goal = o[desired_goal_key]
            env.unwrapped.goal = goal

        rollout_function = partial(
            rf.contextual_rollout,
            max_path_length=max_path_length,
            observation_key=observation_key,
            context_keys_for_policy=[context_key],
            reset_callback=set_goal_for_visualization,
        )
        renderer = GymEnvRenderer(**renderer_kwargs)

        def add_images(env, context_distribution):
            # Re-wrap the raw env so rollouts carry an 'image_observation'.
            state_env = env.env
            img_env = InsertImageEnv(
                state_env,
                renderer=renderer,
                image_key='image_observation',
            )
            return ContextualEnv(
                img_env,
                context_distribution=context_distribution,
                reward_fn=eval_reward,
                observation_key=observation_key,
                update_env_info_fn=delete_info,
            )

        img_eval_env = add_images(eval_env, eval_context_distrib)
        img_expl_env = add_images(expl_env, expl_context_distrib)
        eval_video_func = get_save_video_function(
            rollout_function,
            img_eval_env,
            MakeDeterministic(policy),
            tag="eval",
            # image_chw[1] — presumably the height of a (C, H, W) frame;
            # TODO confirm against GymEnvRenderer.
            imsize=renderer.image_chw[1],
            image_format=renderer.output_image_format,
            keys_to_show=['image_observation'],
            **save_video_kwargs)
        expl_video_func = get_save_video_function(
            rollout_function,
            img_expl_env,
            exploration_policy,
            tag="train",
            imsize=renderer.image_chw[1],
            image_format=renderer.output_image_format,
            keys_to_show=['image_observation'],
            **save_video_kwargs)
        algorithm.post_train_funcs.append(eval_video_func)
        algorithm.post_train_funcs.append(expl_video_func)
    algorithm.train()
def goal_conditioned_sac_experiment(
        max_path_length,
        qf_kwargs,
        sac_trainer_kwargs,
        replay_buffer_kwargs,
        policy_kwargs,
        algo_kwargs,
        env_id=None,
        env_class=None,
        env_kwargs=None,
        observation_key='state_observation',
        desired_goal_key='state_desired_goal',
        achieved_goal_key='state_achieved_goal',
        exploration_policy_kwargs=None,
        evaluation_goal_sampling_mode=None,
        exploration_goal_sampling_mode=None,
        # Video parameters
        save_video=True,
        save_video_kwargs=None,
        renderer_kwargs=None,
):
    """Run goal-conditioned SAC on a state-based multitask env.

    Wraps the env as a contextual env whose reward comes from the multitask
    env itself, builds MLP Q-functions and a tanh-Gaussian policy over
    observation+goal, a goal-relabeling replay buffer, and runs batch RL
    (optionally rendering videos with goal images).

    :param observation_key: Key of the state observation in the obs dict.
    :param desired_goal_key: Key of the sampled goal (also the context key).
    :param achieved_goal_key: Key of the achieved goal, consumed by the
        multitask-env reward function.
    """
    if exploration_policy_kwargs is None:
        exploration_policy_kwargs = {}
    if not save_video_kwargs:
        save_video_kwargs = {}
    if not renderer_kwargs:
        renderer_kwargs = {}
    context_key = desired_goal_key
    # For hindsight relabeling: a visited state can be reused as a goal.
    sample_context_from_obs_dict_fn = RemapKeyFn({context_key: observation_key})

    def contextual_env_distrib_and_reward(
            env_id, env_class, env_kwargs, goal_sampling_mode
    ):
        # Build one contextual env + goal distribution + reward function.
        env = get_gym_env(env_id, env_class=env_class, env_kwargs=env_kwargs)
        env.goal_sampling_mode = goal_sampling_mode
        goal_distribution = GoalDictDistributionFromMultitaskEnv(
            env,
            desired_goal_keys=[desired_goal_key],
        )
        reward_fn = ContextualRewardFnFromMultitaskEnv(
            env=env,
            achieved_goal_from_observation=IndexIntoAchievedGoal(
                observation_key),
            desired_goal_key=desired_goal_key,
            achieved_goal_key=achieved_goal_key,
        )
        diag_fn = GoalConditionedDiagnosticsToContextualDiagnostics(
            env.goal_conditioned_diagnostics,
            desired_goal_key=desired_goal_key,
            observation_key=observation_key,
        )
        env = ContextualEnv(
            env,
            context_distribution=goal_distribution,
            reward_fn=reward_fn,
            observation_key=observation_key,
            contextual_diagnostics_fns=[diag_fn],
            update_env_info_fn=delete_info,
        )
        return env, goal_distribution, reward_fn

    expl_env, expl_context_distrib, expl_reward = contextual_env_distrib_and_reward(
        env_id, env_class, env_kwargs, exploration_goal_sampling_mode
    )
    eval_env, eval_context_distrib, eval_reward = contextual_env_distrib_and_reward(
        env_id, env_class, env_kwargs, evaluation_goal_sampling_mode
    )
    # Networks consume the observation concatenated with the goal.
    obs_dim = (
            expl_env.observation_space.spaces[observation_key].low.size
            + expl_env.observation_space.spaces[context_key].low.size
    )
    action_dim = expl_env.action_space.low.size

    def create_qf():
        return ConcatMlp(
            input_size=obs_dim + action_dim,
            output_size=1,
            **qf_kwargs
        )

    qf1 = create_qf()
    qf2 = create_qf()
    target_qf1 = create_qf()
    target_qf2 = create_qf()
    policy = TanhGaussianPolicy(
        obs_dim=obs_dim,
        action_dim=action_dim,
        **policy_kwargs
    )

    def concat_context_to_obs(batch, *args, **kwargs):
        # Concatenate the (possibly relabeled) goal onto the observation.
        obs = batch['observations']
        next_obs = batch['next_observations']
        context = batch[context_key]
        batch['observations'] = np.concatenate([obs, context], axis=1)
        batch['next_observations'] = np.concatenate([next_obs, context],
                                                    axis=1)
        return batch

    replay_buffer = ContextualRelabelingReplayBuffer(
        env=eval_env,
        context_keys=[context_key],
        observation_keys=[observation_key],
        context_distribution=eval_context_distrib,
        sample_context_from_obs_dict_fn=sample_context_from_obs_dict_fn,
        reward_fn=eval_reward,
        post_process_batch_fn=concat_context_to_obs,
        **replay_buffer_kwargs
    )
    trainer = SACTrainer(
        env=expl_env,
        policy=policy,
        qf1=qf1,
        qf2=qf2,
        target_qf1=target_qf1,
        target_qf2=target_qf2,
        **sac_trainer_kwargs
    )
    # Evaluation uses the deterministic (mean) policy.
    eval_path_collector = ContextualPathCollector(
        eval_env,
        MakeDeterministic(policy),
        observation_key=observation_key,
        context_keys_for_policy=[context_key],
    )
    exploration_policy = create_exploration_policy(
        policy=policy, env=expl_env, **exploration_policy_kwargs)
    expl_path_collector = ContextualPathCollector(
        expl_env,
        exploration_policy,
        observation_key=observation_key,
        context_keys_for_policy=[context_key],
    )
    algorithm = TorchBatchRLAlgorithm(
        trainer=trainer,
        exploration_env=expl_env,
        evaluation_env=eval_env,
        exploration_data_collector=expl_path_collector,
        evaluation_data_collector=eval_path_collector,
        replay_buffer=replay_buffer,
        max_path_length=max_path_length,
        **algo_kwargs
    )
    algorithm.to(ptu.device)
    if save_video:
        rollout_function = partial(
            rf.contextual_rollout,
            max_path_length=max_path_length,
            observation_key=observation_key,
            context_keys_for_policy=[context_key],
        )
        renderer = EnvRenderer(**renderer_kwargs)

        def add_images(env, state_distribution):
            # Re-wrap the raw env so rollouts carry rendered images of both
            # the observation and the sampled goal.
            state_env = env.env
            image_goal_distribution = AddImageDistribution(
                env=state_env,
                base_distribution=state_distribution,
                image_goal_key='image_desired_goal',
                renderer=renderer,
            )
            img_env = InsertImageEnv(state_env, renderer=renderer)
            return ContextualEnv(
                img_env,
                context_distribution=image_goal_distribution,
                reward_fn=eval_reward,
                observation_key=observation_key,
                update_env_info_fn=delete_info,
            )

        img_eval_env = add_images(eval_env, eval_context_distrib)
        img_expl_env = add_images(expl_env, expl_context_distrib)
        eval_video_func = get_save_video_function(
            rollout_function,
            img_eval_env,
            MakeDeterministic(policy),
            tag="eval",
            imsize=renderer.width,
            image_format=renderer.output_image_format,
            **save_video_kwargs
        )
        expl_video_func = get_save_video_function(
            rollout_function,
            img_expl_env,
            exploration_policy,
            tag="train",
            imsize=renderer.width,
            image_format=renderer.output_image_format,
            **save_video_kwargs
        )
        algorithm.post_train_funcs.append(eval_video_func)
        algorithm.post_train_funcs.append(expl_video_func)
    algorithm.train()
def her_sac_experiment(
        max_path_length,
        qf_kwargs,
        twin_sac_trainer_kwargs,
        replay_buffer_kwargs,
        policy_kwargs,
        evaluation_goal_sampling_mode,
        exploration_goal_sampling_mode,
        algo_kwargs,
        save_video=True,
        env_id=None,
        env_class=None,
        env_kwargs=None,
        observation_key='state_observation',
        desired_goal_key='state_desired_goal',
        achieved_goal_key='state_achieved_goal',
        # Video parameters
        save_video_kwargs=None,
        exploration_policy_kwargs=None,
        **kwargs
):
    """Run HER + SAC on a multiworld goal env (obs-dict replay buffer).

    Unlike the contextual-env variants in this file, this uses the older
    ``ObsDictRelabelingBuffer`` + ``HERTrainer`` + goal-conditioned path
    collector pipeline.

    :param twin_sac_trainer_kwargs: Kwargs forwarded to ``SACTrainer``.
    :param kwargs: Ignored; accepted for launcher compatibility.
    """
    if exploration_policy_kwargs is None:
        exploration_policy_kwargs = {}
    # Imports are local so this launcher can live in a registry module
    # without forcing these dependencies at import time — presumably;
    # TODO confirm that was the intent.
    import rlkit.samplers.rollout_functions as rf
    import rlkit.torch.pytorch_util as ptu
    from rlkit.data_management.obs_dict_replay_buffer import \
        ObsDictRelabelingBuffer
    from rlkit.torch.networks import ConcatMlp
    from rlkit.torch.sac.policies import TanhGaussianPolicy
    from rlkit.torch.torch_rl_algorithm import TorchBatchRLAlgorithm
    if not save_video_kwargs:
        save_video_kwargs = {}
    if env_kwargs is None:
        env_kwargs = {}
    assert env_id or env_class
    if env_id:
        import gym
        import multiworld
        multiworld.register_all_envs()
        train_env = gym.make(env_id)
        eval_env = gym.make(env_id)
    else:
        eval_env = env_class(**env_kwargs)
        train_env = env_class(**env_kwargs)
    # Networks consume the observation concatenated with the goal.
    obs_dim = (
            train_env.observation_space.spaces[observation_key].low.size
            + train_env.observation_space.spaces[desired_goal_key].low.size
    )
    action_dim = train_env.action_space.low.size
    qf1 = ConcatMlp(
        input_size=obs_dim + action_dim,
        output_size=1,
        **qf_kwargs
    )
    qf2 = ConcatMlp(
        input_size=obs_dim + action_dim,
        output_size=1,
        **qf_kwargs
    )
    target_qf1 = ConcatMlp(
        input_size=obs_dim + action_dim,
        output_size=1,
        **qf_kwargs
    )
    target_qf2 = ConcatMlp(
        input_size=obs_dim + action_dim,
        output_size=1,
        **qf_kwargs
    )
    policy = TanhGaussianPolicy(
        obs_dim=obs_dim,
        action_dim=action_dim,
        **policy_kwargs
    )
    replay_buffer = ObsDictRelabelingBuffer(
        env=train_env,
        observation_key=observation_key,
        desired_goal_key=desired_goal_key,
        achieved_goal_key=achieved_goal_key,
        **replay_buffer_kwargs
    )
    trainer = SACTrainer(
        env=train_env,
        policy=policy,
        qf1=qf1,
        qf2=qf2,
        target_qf1=target_qf1,
        target_qf2=target_qf2,
        **twin_sac_trainer_kwargs
    )
    # HERTrainer relabels goals in sampled batches before delegating to SAC.
    trainer = HERTrainer(trainer)
    eval_path_collector = GoalConditionedPathCollector(
        eval_env,
        MakeDeterministic(policy),
        max_path_length,
        observation_key=observation_key,
        desired_goal_key=desired_goal_key,
        goal_sampling_mode=evaluation_goal_sampling_mode,
    )
    exploration_policy = create_exploration_policy(
        train_env, policy, **exploration_policy_kwargs)
    expl_path_collector = GoalConditionedPathCollector(
        train_env,
        exploration_policy,
        max_path_length,
        observation_key=observation_key,
        desired_goal_key=desired_goal_key,
        goal_sampling_mode=exploration_goal_sampling_mode,
    )
    algorithm = TorchBatchRLAlgorithm(
        trainer=trainer,
        exploration_env=train_env,
        evaluation_env=eval_env,
        exploration_data_collector=expl_path_collector,
        evaluation_data_collector=eval_path_collector,
        replay_buffer=replay_buffer,
        max_path_length=max_path_length,
        **algo_kwargs
    )
    algorithm.to(ptu.device)
    if save_video:
        rollout_function = rf.create_rollout_function(
            rf.multitask_rollout,
            max_path_length=max_path_length,
            observation_key=observation_key,
            desired_goal_key=desired_goal_key,
            return_dict_obs=True,
        )
        eval_video_func = get_save_video_function(
            rollout_function,
            eval_env,
            MakeDeterministic(policy),
            tag="eval",
            **save_video_kwargs
        )
        train_video_func = get_save_video_function(
            rollout_function,
            train_env,
            exploration_policy,
            tag="expl",
            **save_video_kwargs
        )
        algorithm.post_train_funcs.append(eval_video_func)
        algorithm.post_train_funcs.append(train_video_func)
    algorithm.train()
def rig_experiment(
        max_path_length,
        qf_kwargs,
        sac_trainer_kwargs,
        replay_buffer_kwargs,
        policy_kwargs,
        algo_kwargs,
        train_vae_kwargs,
        env_id=None,
        env_class=None,
        env_kwargs=None,
        observation_key='latent_observation',
        desired_goal_key='latent_desired_goal',
        state_goal_key='state_desired_goal',
        state_observation_key='state_observation',
        image_goal_key='image_desired_goal',
        exploration_policy_kwargs=None,
        evaluation_goal_sampling_mode=None,
        exploration_goal_sampling_mode=None,
        # Video parameters
        save_video=True,
        save_video_kwargs=None,
        renderer_kwargs=None,
        imsize=48,
        pretrained_vae_path="",
        init_camera=None,
):
    """Run RIG: goal-conditioned SAC in the latent space of a (trained) VAE.

    Images are encoded into a latent observation; goals are either sampled
    from the VAE prior or taken from env resets and encoded. SAC then
    operates on latent observation + latent goal with a latent-distance
    reward.

    :param train_vae_kwargs: Used to train a VAE when ``pretrained_vae_path``
        is empty; otherwise the VAE is loaded from that path.
    :param imsize: Image side length used for VAE training and saved videos.
    :raises NotImplementedError: For an unknown goal sampling mode.
    """
    if exploration_policy_kwargs is None:
        exploration_policy_kwargs = {}
    if not save_video_kwargs:
        save_video_kwargs = {}
    if not renderer_kwargs:
        renderer_kwargs = {}
    renderer = EnvRenderer(init_camera=init_camera, **renderer_kwargs)

    def contextual_env_distrib_and_reward(env_id, env_class, env_kwargs,
                                          goal_sampling_mode):
        # Build one VAE-encoded contextual env + latent goal distribution.
        state_env = get_gym_env(env_id, env_class=env_class,
                                env_kwargs=env_kwargs)
        # NOTE(review): this shadows the outer `renderer` with an identical
        # one — presumably so each env gets its own renderer instance;
        # confirm before consolidating.
        renderer = EnvRenderer(init_camera=init_camera, **renderer_kwargs)
        img_env = InsertImageEnv(state_env, renderer=renderer)
        encoded_env = EncoderWrappedEnv(
            img_env,
            model,
            dict(image_observation="latent_observation", ),
        )
        if goal_sampling_mode == "vae_prior":
            # Sample latent goals directly from the VAE prior.
            latent_goal_distribution = PriorDistribution(
                model.representation_size,
                desired_goal_key,
            )
            diagnostics = StateImageGoalDiagnosticsFn({}, )
        elif goal_sampling_mode == "reset_of_env":
            # Sample a state goal from env resets, render it, then encode
            # the image into a latent goal.
            state_goal_env = get_gym_env(env_id, env_class=env_class,
                                         env_kwargs=env_kwargs)
            state_goal_distribution = GoalDictDistributionFromMultitaskEnv(
                state_goal_env,
                desired_goal_keys=[state_goal_key],
            )
            image_goal_distribution = AddImageDistribution(
                env=state_env,
                base_distribution=state_goal_distribution,
                image_goal_key=image_goal_key,
                renderer=renderer,
            )
            latent_goal_distribution = AddLatentDistribution(
                image_goal_distribution,
                image_goal_key,
                desired_goal_key,
                model,
            )
            if hasattr(state_goal_env, 'goal_conditioned_diagnostics'):
                diagnostics = GoalConditionedDiagnosticsToContextualDiagnostics(
                    state_goal_env.goal_conditioned_diagnostics,
                    desired_goal_key=state_goal_key,
                    observation_key=state_observation_key,
                )
            else:
                # Fall back to the env's own contextual diagnostics hook
                # (raises AttributeError here if the env provides neither).
                diagnostics = state_goal_env.get_contextual_diagnostics
        else:
            raise NotImplementedError('unknown goal sampling method: %s' %
                                      goal_sampling_mode)
        reward_fn = DistanceRewardFn(
            observation_key=observation_key,
            desired_goal_key=desired_goal_key,
        )
        env = ContextualEnv(
            encoded_env,
            context_distribution=latent_goal_distribution,
            reward_fn=reward_fn,
            observation_key=observation_key,
            contextual_diagnostics_fns=[diagnostics],
        )
        return env, latent_goal_distribution, reward_fn

    # `model` must exist before the closures above are *called* (not
    # defined), so create/load it before building the envs.
    if pretrained_vae_path:
        model = load_local_or_remote_file(pretrained_vae_path)
    else:
        model = train_vae(train_vae_kwargs, env_kwargs, env_id, env_class,
                          imsize, init_camera)
    expl_env, expl_context_distrib, expl_reward = (
        contextual_env_distrib_and_reward(
            env_id, env_class, env_kwargs, exploration_goal_sampling_mode))
    eval_env, eval_context_distrib, eval_reward = (
        contextual_env_distrib_and_reward(
            env_id, env_class, env_kwargs, evaluation_goal_sampling_mode))
    context_key = desired_goal_key
    # Networks consume the latent observation concatenated with the goal.
    obs_dim = (expl_env.observation_space.spaces[observation_key].low.size +
               expl_env.observation_space.spaces[context_key].low.size)
    action_dim = expl_env.action_space.low.size

    def create_qf():
        return ConcatMlp(
            input_size=obs_dim + action_dim,
            output_size=1,
            **qf_kwargs)

    qf1 = create_qf()
    qf2 = create_qf()
    target_qf1 = create_qf()
    target_qf2 = create_qf()
    policy = TanhGaussianPolicy(
        obs_dim=obs_dim,
        action_dim=action_dim,
        **policy_kwargs)

    def concat_context_to_obs(batch, *args, **kwargs):
        # Concatenate the (possibly relabeled) latent goal onto the obs.
        obs = batch['observations']
        next_obs = batch['next_observations']
        context = batch[context_key]
        batch['observations'] = np.concatenate([obs, context], axis=1)
        batch['next_observations'] = np.concatenate([next_obs, context],
                                                    axis=1)
        return batch

    replay_buffer = ContextualRelabelingReplayBuffer(
        env=eval_env,
        context_keys=[context_key],
        observation_keys=[observation_key],
        observation_key=observation_key,
        # NOTE(review): other experiments in this file pass the *eval*
        # context distribution here; this one passes the exploration
        # distribution — confirm whether that is intentional.
        context_distribution=expl_context_distrib,
        sample_context_from_obs_dict_fn=RemapKeyFn(
            {context_key: observation_key}),
        reward_fn=eval_reward,
        post_process_batch_fn=concat_context_to_obs,
        **replay_buffer_kwargs)
    trainer = SACTrainer(
        env=expl_env,
        policy=policy,
        qf1=qf1,
        qf2=qf2,
        target_qf1=target_qf1,
        target_qf2=target_qf2,
        **sac_trainer_kwargs)
    # Evaluation uses the deterministic (mean) policy.
    eval_path_collector = ContextualPathCollector(
        eval_env,
        MakeDeterministic(policy),
        observation_key=observation_key,
        context_keys_for_policy=[
            context_key,
        ],
    )
    exploration_policy = create_exploration_policy(
        expl_env, policy, **exploration_policy_kwargs)
    expl_path_collector = ContextualPathCollector(
        expl_env,
        exploration_policy,
        observation_key=observation_key,
        context_keys_for_policy=[
            context_key,
        ],
    )
    algorithm = TorchBatchRLAlgorithm(
        trainer=trainer,
        exploration_env=expl_env,
        evaluation_env=eval_env,
        exploration_data_collector=expl_path_collector,
        evaluation_data_collector=eval_path_collector,
        replay_buffer=replay_buffer,
        max_path_length=max_path_length,
        **algo_kwargs)
    algorithm.to(ptu.device)
    if save_video:
        expl_video_func = RIGVideoSaveFunction(
            model,
            expl_path_collector,
            "train",
            decode_goal_image_key="image_decoded_goal",
            reconstruction_key="image_reconstruction",
            rows=2,
            columns=5,
            unnormalize=True,
            # Use the configured image size rather than a hard-coded 48.
            imsize=imsize,
            image_format=renderer.output_image_format,
            **save_video_kwargs)
        algorithm.post_train_funcs.append(expl_video_func)
        eval_video_func = RIGVideoSaveFunction(
            model,
            eval_path_collector,
            "eval",
            goal_image_key=image_goal_key,
            decode_goal_image_key="image_decoded_goal",
            reconstruction_key="image_reconstruction",
            num_imgs=4,
            rows=2,
            columns=5,
            unnormalize=True,
            imsize=imsize,
            image_format=renderer.output_image_format,
            **save_video_kwargs)
        algorithm.post_train_funcs.append(eval_video_func)
    algorithm.train()
def probabilistic_goal_reaching_experiment(
        max_path_length,
        qf_kwargs,
        policy_kwargs,
        pgr_trainer_kwargs,
        replay_buffer_kwargs,
        algo_kwargs,
        env_id,
        discount_factor,
        reward_type,
        # Dynamics model
        dynamics_model_version,
        dynamics_model_config,
        dynamics_delta_model_config=None,
        dynamics_adam_config=None,
        dynamics_ensemble_kwargs=None,
        # Discount model
        learn_discount_model=False,
        discount_adam_config=None,
        discount_model_config=None,
        prior_discount_weight_schedule_kwargs=None,
        # Environment
        env_class=None,
        env_kwargs=None,
        observation_key='state_observation',
        desired_goal_key='state_desired_goal',
        exploration_policy_kwargs=None,
        action_noise_scale=0.,
        num_presampled_goals=4096,
        success_threshold=0.05,
        # Video / visualization parameters
        save_video=True,
        save_video_kwargs=None,
        video_renderer_kwargs=None,
        plot_renderer_kwargs=None,
        eval_env_ids=None,
        # Debugging params
        visualize_dynamics=False,
        visualize_discount_model=False,
        visualize_all_plots=False,
        plot_discount=False,
        plot_reward=False,
        plot_bootstrap_value=False,
        # env specific-params
        normalize_distances_for_full_state_ant=False,
):
    """Set up and run a probabilistic goal-reaching (PGR) experiment.

    Wires together a goal-conditioned contextual environment, a learned
    goal dynamics model, an optional learned discount model, a PGRTrainer,
    a contextual relabeling replay buffer, exploration/evaluation path
    collectors, and (optionally) a suite of video/plot visualizers, then
    calls ``algorithm.train()``.

    :param max_path_length: maximum episode length for rollouts.
    :param qf_kwargs: kwargs forwarded to each ``ConcatMlp`` Q-function.
    :param policy_kwargs: kwargs for the ``MultiHeadedMlp`` policy body.
    :param pgr_trainer_kwargs: extra kwargs for ``PGRTrainer``.
    :param replay_buffer_kwargs: kwargs for the contextual replay buffer.
    :param algo_kwargs: kwargs for ``TorchBatchRLAlgorithm``.
    :param env_id: gym id of the training environment.
    :param discount_factor: MDP discount used both by the trainer and the
        probabilistic-goal reward function.
    :param reward_type: 'sparse', 'negative_distance', or anything else to
        use the PGR reward derived from the dynamics model.
    :param dynamics_model_version: which dynamics-model family to build;
        also controls whether a dynamics optimizer is created.
    :param eval_env_ids: dict mapping eval-collector name -> env id;
        defaults to ``{'eval': env_id}``.
    :param normalize_distances_for_full_state_ant: if True (ant full-pos
        env only), wraps the env in ``NormalizeAntFullPositionGoalEnv`` and
        rescales goals by presampled qpos std for visualization.
    """
    # Normalize every optional kwargs dict to a real dict so downstream
    # **-expansion never sees None.
    if dynamics_ensemble_kwargs is None:
        dynamics_ensemble_kwargs = {}
    if eval_env_ids is None:
        eval_env_ids = {'eval': env_id}
    if discount_model_config is None:
        discount_model_config = {}
    if dynamics_delta_model_config is None:
        dynamics_delta_model_config = {}
    if dynamics_adam_config is None:
        dynamics_adam_config = {}
    if discount_adam_config is None:
        discount_adam_config = {}
    if exploration_policy_kwargs is None:
        exploration_policy_kwargs = {}
    if not save_video_kwargs:
        save_video_kwargs = {}
    if not video_renderer_kwargs:
        video_renderer_kwargs = {}
    if not plot_renderer_kwargs:
        # Plots default to the video renderer settings at a lower dpi.
        plot_renderer_kwargs = video_renderer_kwargs.copy()
        plot_renderer_kwargs['dpi'] = 48
    context_key = desired_goal_key
    # A throwaway env instance used only to probe spaces / env type.
    stub_env = get_gym_env(
        env_id,
        env_class=env_class,
        env_kwargs=env_kwargs,
        unwrap_timed_envs=True,
    )
    is_gym_env = (
        isinstance(stub_env, FetchEnv)
        or isinstance(stub_env, AntXYGoalEnv)
        or isinstance(stub_env, AntFullPositionGoalEnv)
        # or isinstance(stub_env, HopperFullPositionGoalEnv)
    )
    is_ant_full_pos = isinstance(stub_env, AntFullPositionGoalEnv)
    if is_gym_env:
        # Gym GoalEnvs expose an achieved-goal key parallel to the desired one.
        achieved_goal_key = desired_goal_key.replace('desired', 'achieved')
        ob_keys_to_save_in_buffer = [observation_key, achieved_goal_key]
    elif isinstance(stub_env, SawyerPickAndPlaceEnvYZ):
        achieved_goal_key = desired_goal_key.replace('desired', 'achieved')
        ob_keys_to_save_in_buffer = [observation_key, achieved_goal_key]
    else:
        # Otherwise the full observation doubles as the achieved goal.
        achieved_goal_key = observation_key
        ob_keys_to_save_in_buffer = [observation_key]
    # TODO move all env-specific code to other file
    if isinstance(stub_env, SawyerDoorHookEnv):
        init_camera = sawyer_door_env_camera_v0
    elif isinstance(stub_env, SawyerPushAndReachXYEnv):
        init_camera = sawyer_init_camera_zoomed_in
    elif isinstance(stub_env, SawyerPickAndPlaceEnvYZ):
        init_camera = sawyer_pick_and_place_camera
    else:
        init_camera = None
    full_ob_space = stub_env.observation_space
    action_space = stub_env.action_space
    # Maps a full state to its goal-space projection; shared by the
    # dynamics model, the discount trainer, and the model trainer.
    state_to_goal = StateToGoalFn(stub_env)
    dynamics_model = create_goal_dynamics_model(
        full_ob_space[observation_key],
        action_space,
        full_ob_space[achieved_goal_key],
        dynamics_model_version,
        state_to_goal,
        dynamics_model_config,
        dynamics_delta_model_config,
        ensemble_model_kwargs=dynamics_ensemble_kwargs,
    )
    sample_context_from_obs_dict_fn = RemapKeyFn(
        {context_key: achieved_goal_key})

    def contextual_env_distrib_reward(_env_id, _env_class=None,
                                      _env_kwargs=None):
        """Build (contextual env, goal distribution, reward fn) for one env id.

        NOTE(review): ``_env_class`` and ``_env_kwargs`` are accepted but
        ignored — the body closes over the outer ``env_class``/``env_kwargs``.
        Callers pass them positionally, so behavior matches the main env,
        but extra eval envs cannot vary class/kwargs. Confirm intent.
        """
        base_env = get_gym_env(
            _env_id,
            env_class=env_class,
            env_kwargs=env_kwargs,
            unwrap_timed_envs=True,
        )
        if init_camera:
            base_env.initialize_camera(init_camera)
        if (isinstance(stub_env, AntFullPositionGoalEnv)
                and normalize_distances_for_full_state_ant):
            base_env = NormalizeAntFullPositionGoalEnv(base_env)
            normalize_env = base_env
        else:
            normalize_env = None
        env = NoisyAction(base_env, action_noise_scale)
        diag_fns = []
        if is_gym_env:
            goal_distribution = GoalDictDistributionFromGymGoalEnv(
                env,
                desired_goal_key=desired_goal_key,
            )
            diag_fns.append(
                GenericGoalConditionedContextualDiagnostics(
                    desired_goal_key=desired_goal_key,
                    achieved_goal_key=achieved_goal_key,
                    success_threshold=success_threshold,
                ))
        else:
            goal_distribution = GoalDictDistributionFromMultitaskEnv(
                env,
                desired_goal_keys=[desired_goal_key],
            )
            diag_fns.append(
                GoalConditionedDiagnosticsToContextualDiagnostics(
                    env.goal_conditioned_diagnostics,
                    desired_goal_key=desired_goal_key,
                    observation_key=observation_key,
                ))
        if isinstance(stub_env, AntFullPositionGoalEnv):
            diag_fns.append(
                AntFullPositionGoalEnvDiagnostics(
                    desired_goal_key=desired_goal_key,
                    achieved_goal_key=achieved_goal_key,
                    success_threshold=success_threshold,
                    normalize_env=normalize_env,
                ))
        # (A HopperFullPositionGoalEnvDiagnostics branch was here but is
        # disabled, matching the disabled isinstance check above.)
        achieved_from_ob = IndexIntoAchievedGoal(achieved_goal_key, )
        if reward_type == 'sparse':
            distance_fn = L2Distance(
                achieved_goal_from_observation=achieved_from_ob,
                desired_goal_key=desired_goal_key,
            )
            reward_fn = ThresholdDistanceReward(distance_fn, success_threshold)
        elif reward_type == 'negative_distance':
            reward_fn = NegativeL2Distance(
                achieved_goal_from_observation=achieved_from_ob,
                desired_goal_key=desired_goal_key,
            )
        else:
            # Any other reward_type string is interpreted by the PGR
            # reward function itself.
            reward_fn = ProbabilisticGoalRewardFn(
                dynamics_model,
                state_key=observation_key,
                context_key=context_key,
                reward_type=reward_type,
                discount_factor=discount_factor,
            )
        # Pre-sample a fixed pool of goals so sampling at rollout time is cheap.
        goal_distribution = PresampledDistribution(goal_distribution,
                                                  num_presampled_goals)
        final_env = ContextualEnv(
            env,
            context_distribution=goal_distribution,
            reward_fn=reward_fn,
            observation_key=observation_key,
            contextual_diagnostics_fns=diag_fns,
            update_env_info_fn=delete_info,
        )
        return final_env, goal_distribution, reward_fn

    expl_env, expl_context_distrib, reward_fn = contextual_env_distrib_reward(
        env_id,
        env_class,
        env_kwargs,
    )
    # Networks see observation and goal concatenated.
    obs_dim = (expl_env.observation_space.spaces[observation_key].low.size
               + expl_env.observation_space.spaces[context_key].low.size)
    action_dim = expl_env.action_space.low.size

    def create_qf():
        # One Q-network; called four times for twin + target critics.
        return ConcatMlp(input_size=obs_dim + action_dim,
                         output_size=1,
                         **qf_kwargs)

    qf1 = create_qf()
    qf2 = create_qf()
    target_qf1 = create_qf()
    target_qf2 = create_qf()

    def create_policy():
        # Two heads: mean and (pre-)log-std for a tanh-Gaussian policy.
        obs_processor = MultiHeadedMlp(input_size=obs_dim,
                                       output_sizes=[action_dim, action_dim],
                                       **policy_kwargs)
        return PolicyFromDistributionGenerator(TanhGaussian(obs_processor))

    policy = create_policy()

    def concat_context_to_obs(batch, replay_buffer, obs_dict, next_obs_dict,
                              new_contexts):
        """Append the goal context to (next_)observations in a sampled batch.

        Also stashes the raw observations under ``original_*`` keys, which
        the dynamics/discount trainers read back.
        """
        obs = batch['observations']
        next_obs = batch['next_observations']
        batch['original_observations'] = obs
        batch['original_next_observations'] = next_obs
        context = batch[context_key]
        batch['observations'] = np.concatenate([obs, context], axis=1)
        batch['next_observations'] = np.concatenate([next_obs, context],
                                                    axis=1)
        return batch

    replay_buffer = ContextualRelabelingReplayBuffer(
        env=expl_env,
        context_keys=[context_key],
        observation_keys=ob_keys_to_save_in_buffer,
        context_distribution=expl_context_distrib,
        sample_context_from_obs_dict_fn=sample_context_from_obs_dict_fn,
        reward_fn=reward_fn,
        post_process_batch_fn=concat_context_to_obs,
        **replay_buffer_kwargs)

    def create_trainer():
        """Assemble the joint trainer (PGR + optional discount/dynamics).

        Returns ``(JointTrainer, pgr_trainer)``; the bare PGR trainer is
        returned too because the visualizers below need direct access.
        """
        trainers = OrderedDict()
        if learn_discount_model:
            discount_model = create_discount_model(
                ob_space=stub_env.observation_space[observation_key],
                goal_space=stub_env.observation_space[context_key],
                action_space=stub_env.action_space,
                model_kwargs=discount_model_config)
            optimizer = optim.Adam(discount_model.parameters(),
                                   **discount_adam_config)
            discount_trainer = DiscountModelTrainer(
                discount_model,
                optimizer,
                observation_key='observations',
                next_observation_key='original_next_observations',
                goal_key=context_key,
                state_to_goal_fn=state_to_goal,
            )
            trainers['discount_trainer'] = discount_trainer
        else:
            discount_model = None
        if prior_discount_weight_schedule_kwargs is not None:
            schedule = create_schedule(**prior_discount_weight_schedule_kwargs)
        else:
            schedule = None
        pgr_trainer = PGRTrainer(env=expl_env,
                                 policy=policy,
                                 qf1=qf1,
                                 qf2=qf2,
                                 target_qf1=target_qf1,
                                 target_qf2=target_qf2,
                                 discount=discount_factor,
                                 discount_model=discount_model,
                                 prior_discount_weight_schedule=schedule,
                                 **pgr_trainer_kwargs)
        # Empty key => PGR stats are logged without a prefix.
        trainers[''] = pgr_trainer
        # NOTE(review): ``optimizers`` is populated but never returned or
        # read afterwards — presumably leftover from snapshotting; confirm.
        optimizers = [
            pgr_trainer.qf1_optimizer,
            pgr_trainer.qf2_optimizer,
            pgr_trainer.alpha_optimizer,
            pgr_trainer.policy_optimizer,
        ]
        if dynamics_model_version in {
            'learned_model',
            'learned_model_ensemble',
            'learned_model_laplace',
            'learned_model_laplace_global_variance',
            'learned_model_gaussian_global_variance',
        }:
            model_opt = optim.Adam(dynamics_model.parameters(),
                                   **dynamics_adam_config)
        elif dynamics_model_version in {
            'fixed_standard_laplace',
            'fixed_standard_gaussian',
        }:
            # Fixed models have no parameters to fit.
            model_opt = None
        else:
            raise NotImplementedError()
        model_trainer = GenerativeGoalDynamicsModelTrainer(
            dynamics_model,
            model_opt,
            state_to_goal=state_to_goal,
            observation_key='original_observations',
            next_observation_key='original_next_observations',
        )
        trainers['dynamics_trainer'] = model_trainer
        optimizers.append(model_opt)
        return JointTrainer(trainers), pgr_trainer

    trainer, pgr_trainer = create_trainer()
    eval_policy = MakeDeterministic(policy)

    def create_eval_path_collector(some_eval_env):
        return ContextualPathCollector(
            some_eval_env,
            eval_policy,
            observation_key=observation_key,
            context_keys_for_policy=[context_key],
        )

    # One named eval collector (and env) per entry in eval_env_ids.
    path_collectors = dict()
    eval_env_name_to_env_and_context_distrib = dict()
    for name, extra_env_id in eval_env_ids.items():
        env, context_distrib, _ = contextual_env_distrib_reward(extra_env_id)
        path_collectors[name] = create_eval_path_collector(env)
        eval_env_name_to_env_and_context_distrib[name] = (env, context_distrib)
    eval_path_collector = JointPathCollector(path_collectors)
    exploration_policy = create_exploration_policy(expl_env, policy,
                                                   **exploration_policy_kwargs)
    expl_path_collector = ContextualPathCollector(
        expl_env,
        exploration_policy,
        observation_key=observation_key,
        context_keys_for_policy=[context_key],
    )

    def get_eval_diagnostics(key_to_paths):
        """Merge per-eval-env diagnostics, prefixing stats by env name."""
        stats = OrderedDict()
        for eval_env_name, paths in key_to_paths.items():
            env, _ = eval_env_name_to_env_and_context_distrib[eval_env_name]
            stats.update(
                add_prefix(
                    env.get_diagnostics(paths),
                    eval_env_name,
                    divider='/',
                ))
            stats.update(
                add_prefix(
                    eval_util.get_generic_path_information(paths),
                    eval_env_name,
                    divider='/',
                ))
        return stats

    algorithm = TorchBatchRLAlgorithm(
        trainer=trainer,
        exploration_env=expl_env,
        # evaluation_env is unused because a JointPathCollector plus the
        # custom diagnostics function above handles evaluation.
        evaluation_env=None,
        exploration_data_collector=expl_path_collector,
        evaluation_data_collector=eval_path_collector,
        replay_buffer=replay_buffer,
        max_path_length=max_path_length,
        evaluation_get_diagnostic_functions=[get_eval_diagnostics],
        **algo_kwargs)
    algorithm.to(ptu.device)
    if normalize_distances_for_full_state_ant and is_ant_full_pos:
        # Per-dimension std of presampled qpos, used to un-normalize goals
        # when pushing them back into the raw env for visualization.
        qpos_weights = expl_env.unwrapped.presampled_qpos.std(axis=0)
    else:
        qpos_weights = None
    if save_video:
        if is_gym_env:
            video_renderer = GymEnvRenderer(**video_renderer_kwargs)

            def set_goal_for_visualization(env, policy, o):
                # Push the sampled goal into the raw env so the renderer
                # draws it; undo goal normalization if it was applied.
                goal = o[desired_goal_key]
                if normalize_distances_for_full_state_ant and is_ant_full_pos:
                    unnormalized_goal = goal * qpos_weights
                    env.unwrapped.goal = unnormalized_goal
                else:
                    env.unwrapped.goal = goal

            rollout_function = partial(
                rf.contextual_rollout,
                max_path_length=max_path_length,
                observation_key=observation_key,
                context_keys_for_policy=[context_key],
                reset_callback=set_goal_for_visualization,
            )
        else:
            video_renderer = EnvRenderer(**video_renderer_kwargs)
            rollout_function = partial(
                rf.contextual_rollout,
                max_path_length=max_path_length,
                observation_key=observation_key,
                context_keys_for_policy=[context_key],
                reset_callback=None,
            )
        renderers = OrderedDict(image_observation=video_renderer, )
        state_env = expl_env.env
        state_space = state_env.observation_space[observation_key]
        # Build a grid of (x, y) points covering the state space, matched
        # to the renderer's pixel resolution, for heat-map visualizers.
        # NOTE(review): assumes the observation space is effectively 2-D
        # with a shared [low, high] range per axis — confirm for each env.
        low = state_space.low.min()
        high = state_space.high.max()
        y = np.linspace(low, high, num=video_renderer.image_chw[1])
        x = np.linspace(low, high, num=video_renderer.image_chw[2])
        all_xy_np = np.transpose([np.tile(x, len(y)), np.repeat(y, len(x))])
        all_xy_torch = ptu.from_numpy(all_xy_np)
        num_states = all_xy_torch.shape[0]
        if visualize_dynamics:
            def create_dynamics_visualizer(show_prob, vary_state=False):
                """Return a fn mapping (obs dict, action) -> per-grid-cell
                (log-)probability under the learned dynamics model."""
                def get_prob(obs_dict, action):
                    obs = obs_dict['state_observation']
                    obs_torch = ptu.from_numpy(obs)[None]
                    action_torch = ptu.from_numpy(action)[None]
                    if vary_state:
                        # Evaluate p(goal | s, 0-action) at every grid state.
                        # NOTE(review): hard-codes a 2-D action space here.
                        action_repeated = torch.zeros((num_states, 2))
                        dist = dynamics_model(all_xy_torch, action_repeated)
                        goal = ptu.from_numpy(
                            obs_dict['state_desired_goal'][None])
                        log_probs = dist.log_prob(goal)
                    else:
                        # Evaluate p(grid cell | s, a) for the current step.
                        dist = dynamics_model(obs_torch, action_torch)
                        log_probs = dist.log_prob(all_xy_torch)
                    if show_prob:
                        return log_probs.exp()
                    else:
                        return log_probs
                return get_prob

            renderers['log_prob'] = ValueRenderer(
                create_dynamics_visualizer(False), **video_renderer_kwargs)
            renderers['log_prob_vary_state'] = ValueRenderer(
                create_dynamics_visualizer(False, vary_state=True),
                only_get_image_once_per_episode=True,
                max_out_walls=isinstance(stub_env, PickAndPlaceEnv),
                **video_renderer_kwargs)
            # ('prob' / 'prob_vary_state' variants exist but are disabled.)
        if visualize_discount_model and pgr_trainer.discount_model:
            def get_discount_values(obs, action):
                # Discount prediction for the current (s, a) paired with
                # every grid goal.
                obs = obs['state_observation']
                obs_torch = ptu.from_numpy(obs)[None]
                combined_obs = torch.cat([
                    obs_torch.repeat(num_states, 1),
                    all_xy_torch,
                ], dim=1)
                action_torch = ptu.from_numpy(action)[None]
                action_repeated = action_torch.repeat(num_states, 1)
                return pgr_trainer.discount_model(combined_obs,
                                                  action_repeated)

            renderers['discount_model'] = ValueRenderer(
                get_discount_values,
                states_to_eval=all_xy_torch,
                **video_renderer_kwargs)
        if 'log_prob' in renderers and 'discount_model' in renderers:
            renderers['log_prob_time_discount'] = ProductRenderer(
                renderers['discount_model'], renderers['log_prob'],
                **video_renderer_kwargs)

        def get_reward(obs_dict, action, next_obs_dict):
            """Scalar reward for one transition, via the env's reward_fn."""
            o = batchify(obs_dict)
            a = batchify(action)
            next_o = batchify(next_obs_dict)
            reward = reward_fn(o, a, next_o, next_o)
            return reward[0]

        def get_bootstrap(obs_dict, action, next_obs_dict, return_float=True):
            """Bootstrap (next-state value) term from the PGR trainer."""
            context_pt = ptu.from_numpy(obs_dict[context_key][None])
            o_pt = ptu.from_numpy(obs_dict[observation_key][None])
            next_o_pt = ptu.from_numpy(next_obs_dict[observation_key][None])
            action_torch = ptu.from_numpy(action[None])
            bootstrap, *_ = pgr_trainer.get_bootstrap_stats(
                torch.cat((o_pt, context_pt), dim=1),
                action_torch,
                torch.cat((next_o_pt, context_pt), dim=1),
            )
            if return_float:
                return ptu.get_numpy(bootstrap)[0, 0]
            else:
                return bootstrap

        def get_discount(obs_dict, action, next_obs_dict):
            """Effective discount factor used for this transition,
            clipped to [1e-3, 1] for log-scale-friendly plotting."""
            bootstrap = get_bootstrap(obs_dict, action, next_obs_dict,
                                      return_float=False)
            reward_np = get_reward(obs_dict, action, next_obs_dict)
            reward = ptu.from_numpy(reward_np[None, None])
            context_pt = ptu.from_numpy(obs_dict[context_key][None])
            o_pt = ptu.from_numpy(obs_dict[observation_key][None])
            obs = torch.cat((o_pt, context_pt), dim=1)
            actions = ptu.from_numpy(action[None])
            discount = pgr_trainer.get_discount_factor(
                bootstrap,
                reward,
                obs,
                actions,
            )
            if isinstance(discount, torch.Tensor):
                discount = ptu.get_numpy(discount)[0, 0]
            return np.clip(discount, a_min=1e-3, a_max=1)

        def create_modify_fn(
                title,
                set_params=None,
                scientific=True,
        ):
            """Build a matplotlib-axis tweaker for the number plots."""
            def modify(ax):
                ax.set_title(title)
                if set_params:
                    ax.set(**set_params)
                if scientific:
                    scaler = ScalarFormatter(useOffset=True)
                    scaler.set_powerlimits((1, 1))
                    ax.yaxis.set_major_formatter(scaler)
                    ax.ticklabel_format(axis='y', style='sci')
            return modify

        def add_left_margin(fig):
            # Leave room for the scientific-notation y-axis labels.
            fig.subplots_adjust(left=0.2)

        if visualize_all_plots or plot_discount:
            renderers['discount'] = DynamicNumberEnvRenderer(
                dynamic_number_fn=get_discount,
                modify_ax_fn=create_modify_fn(
                    title='discount',
                    set_params=dict(
                        # yscale='log',
                        ylim=[-0.05, 1.1],
                    ),
                    # scientific=False,
                ),
                modify_fig_fn=add_left_margin,
                # autoscale_y=False,
                **plot_renderer_kwargs)
        if visualize_all_plots or plot_reward:
            renderers['reward'] = DynamicNumberEnvRenderer(
                dynamic_number_fn=get_reward,
                modify_ax_fn=create_modify_fn(title='reward',
                                              # scientific=False,
                                              ),
                modify_fig_fn=add_left_margin,
                **plot_renderer_kwargs)
        if visualize_all_plots or plot_bootstrap_value:
            renderers['bootstrap-value'] = DynamicNumberEnvRenderer(
                dynamic_number_fn=get_bootstrap,
                modify_ax_fn=create_modify_fn(title='bootstrap value',
                                              # scientific=False,
                                              ),
                modify_fig_fn=add_left_margin,
                **plot_renderer_kwargs)

        def add_images(env, state_distribution):
            """Wrap a contextual env so rollouts carry debug images from
            every registered renderer (plus a goal image for non-gym envs)."""
            state_env = env.env
            if is_gym_env:
                goal_distribution = state_distribution
            else:
                goal_distribution = AddImageDistribution(
                    env=state_env,
                    base_distribution=state_distribution,
                    image_goal_key='image_desired_goal',
                    renderer=video_renderer,
                )
            context_env = ContextualEnv(
                state_env,
                context_distribution=goal_distribution,
                reward_fn=reward_fn,
                observation_key=observation_key,
                update_env_info_fn=delete_info,
            )
            return InsertDebugImagesEnv(
                context_env,
                renderers=renderers,
            )

        img_expl_env = add_images(expl_env, expl_context_distrib)
        if is_gym_env:
            imgs_to_show = list(renderers.keys())
        else:
            imgs_to_show = ['image_desired_goal'] + list(renderers.keys())
        img_formats = [video_renderer.output_image_format]
        img_formats += [r.output_image_format for r in renderers.values()]
        expl_video_func = get_save_video_function(
            rollout_function,
            img_expl_env,
            exploration_policy,
            tag="xplor",
            imsize=video_renderer.image_chw[1],
            image_formats=img_formats,
            keys_to_show=imgs_to_show,
            **save_video_kwargs)
        algorithm.post_train_funcs.append(expl_video_func)
        # One eval video per named eval env, using the deterministic policy.
        for eval_env_name, (env, context_distrib) in (
                eval_env_name_to_env_and_context_distrib.items()):
            img_eval_env = add_images(env, context_distrib)
            eval_video_func = get_save_video_function(
                rollout_function,
                img_eval_env,
                eval_policy,
                tag=eval_env_name,
                imsize=video_renderer.image_chw[1],
                image_formats=img_formats,
                keys_to_show=imgs_to_show,
                **save_video_kwargs)
            algorithm.post_train_funcs.append(eval_video_func)

    algorithm.train()
def awac_rig_experiment(
        max_path_length,
        qf_kwargs,
        trainer_kwargs,
        replay_buffer_kwargs,
        policy_kwargs,
        algo_kwargs,
        train_vae_kwargs,
        policy_class=TanhGaussianPolicy,
        env_id=None,
        env_class=None,
        env_kwargs=None,
        reward_kwargs=None,
        observation_key='latent_observation',
        desired_goal_key='latent_desired_goal',
        state_observation_key='state_observation',
        state_goal_key='state_desired_goal',
        image_goal_key='image_desired_goal',
        path_loader_class=MDPPathLoader,
        demo_replay_buffer_kwargs=None,
        path_loader_kwargs=None,
        env_demo_path='',
        env_offpolicy_data_path='',
        debug=False,
        epsilon=1.0,
        exploration_policy_kwargs=None,
        evaluation_goal_sampling_mode=None,
        exploration_goal_sampling_mode=None,
        add_env_demos=False,
        add_env_offpolicy_data=False,
        save_paths=False,
        load_demos=False,
        pretrain_policy=False,
        pretrain_rl=False,
        save_pretrained_algorithm=False,
        # Video parameters
        save_video=True,
        save_video_kwargs=None,
        renderer_kwargs=None,
        imsize=84,
        pretrained_vae_path="",
        presampled_goals_path="",
        init_camera=None,
        qf_class=ConcatMlp,
):
    """Set up and run an AWAC + RIG experiment.

    Builds an image-wrapped contextual environment with a (currently
    degenerate, zero-dimensional) goal distribution and a grasping reward,
    an AWAC trainer with demo replay buffers, optional BC/Q pretraining
    from demonstrations, optional video saving, then trains.

    NOTE(review): ``epsilon`` is accepted but never used in this body —
    confirm whether it was meant to feed exploration_policy_kwargs.
    NOTE(review): ``reward_kwargs.get(...)`` below will raise
    AttributeError if ``reward_kwargs`` is left at its ``None`` default —
    callers apparently always pass a dict; confirm.
    """
    # Kwarg definitions: normalize optional dicts so **-expansion is safe.
    if exploration_policy_kwargs is None:
        exploration_policy_kwargs = {}
    if demo_replay_buffer_kwargs is None:
        demo_replay_buffer_kwargs = {}
    if path_loader_kwargs is None:
        path_loader_kwargs = {}
    if not save_video_kwargs:
        save_video_kwargs = {}
    if not renderer_kwargs:
        renderer_kwargs = {}
    if debug:
        # Shrink everything for a fast smoke-test run.
        max_path_length = 5
        algo_kwargs['batch_size'] = 5
        algo_kwargs['num_epochs'] = 5
        algo_kwargs['num_eval_steps_per_epoch'] = 100
        algo_kwargs['num_expl_steps_per_train_loop'] = 100
        algo_kwargs['num_trains_per_train_loop'] = 10
        algo_kwargs['min_num_steps_before_training'] = 100
        # NOTE(review): duplicated assignment — harmless but likely a typo.
        algo_kwargs['min_num_steps_before_training'] = 100
        trainer_kwargs['bc_num_pretrain_steps'] = min(
            10, trainer_kwargs.get('bc_num_pretrain_steps', 0))
        trainer_kwargs['q_num_pretrain1_steps'] = min(
            10, trainer_kwargs.get('q_num_pretrain1_steps', 0))
        trainer_kwargs['q_num_pretrain2_steps'] = min(
            10, trainer_kwargs.get('q_num_pretrain2_steps', 0))

    # Environment wrapping
    renderer = EnvRenderer(init_camera=init_camera, **renderer_kwargs)

    def contextual_env_distrib_and_reward(env_id, env_class, env_kwargs,
                                          goal_sampling_mode,
                                          presampled_goals_path):
        """Build (contextual env, goal distribution, reward fn).

        NOTE(review): ``goal_sampling_mode`` and ``presampled_goals_path``
        are currently ignored: the original vae_prior / presampled /
        reset_of_env goal-sampling branches are disabled (dead code
        removed here), and a zero-dimensional "no_goal" distribution with
        a GraspingRewardFn is used unconditionally. Also note this inner
        ``renderer`` shadows the outer one.
        """
        state_env = get_gym_env(env_id,
                                env_class=env_class,
                                env_kwargs=env_kwargs)
        renderer = EnvRenderer(init_camera=init_camera, **renderer_kwargs)
        img_env = InsertImageEnv(state_env, renderer=renderer)
        diagnostics = StateImageGoalDiagnosticsFn({}, )
        no_goal_distribution = PriorDistribution(
            representation_size=0,
            key="no_goal",
        )
        reward_fn = GraspingRewardFn(
            # img_env,
            # state_env,
            # observation_key=observation_key,
            # desired_goal_key=desired_goal_key,
            # **reward_kwargs
        )
        env = ContextualEnv(
            img_env,
            # state_env,
            context_distribution=no_goal_distribution,
            reward_fn=reward_fn,
            observation_key=observation_key,
            contextual_diagnostics_fns=[diagnostics],
        )
        return env, no_goal_distribution, reward_fn

    # VAE setup: load a pretrained model or train one from scratch.
    if pretrained_vae_path:
        model = load_local_or_remote_file(pretrained_vae_path)
    else:
        model = train_vae(train_vae_kwargs, env_kwargs, env_id, env_class,
                          imsize, init_camera)
    path_loader_kwargs['model_path'] = pretrained_vae_path

    # Environment definitions
    expl_env, expl_context_distrib, expl_reward = contextual_env_distrib_and_reward(
        env_id, env_class, env_kwargs, exploration_goal_sampling_mode,
        presampled_goals_path)
    eval_env, eval_context_distrib, eval_reward = contextual_env_distrib_and_reward(
        env_id, env_class, env_kwargs, evaluation_goal_sampling_mode,
        presampled_goals_path)
    path_loader_kwargs['env'] = eval_env

    # AWAC code: optionally extend the demo list with env-specific data.
    if add_env_demos:
        path_loader_kwargs["demo_paths"].append(env_demo_path)
    if add_env_offpolicy_data:
        path_loader_kwargs["demo_paths"].append(env_offpolicy_data_path)

    # Key setting
    context_key = desired_goal_key
    obs_dim = (expl_env.observation_space.spaces[observation_key].low.size
               + expl_env.observation_space.spaces[context_key].low.size)
    action_dim = expl_env.action_space.low.size
    # NOTE(review): state_rewards is computed but unused — the branch that
    # consumed it (state-goal remapping) is disabled below.
    state_rewards = reward_kwargs.get('reward_type', 'dense') == 'wrapped_env'
    mapper = RemapKeyFn({context_key: observation_key})
    obs_keys = [observation_key]
    cont_keys = [context_key]

    # Replay buffer
    def concat_context_to_obs(batch, replay_buffer, obs_dict, next_obs_dict,
                              new_contexts):
        """Append the (zero-dim) goal context to sampled observations."""
        obs = batch['observations']
        next_obs = batch['next_observations']
        context = batch[context_key]
        batch['observations'] = np.concatenate([obs, context], axis=1)
        batch['next_observations'] = np.concatenate([next_obs, context],
                                                    axis=1)
        return batch

    replay_buffer = ContextualRelabelingReplayBuffer(
        env=eval_env,
        context_keys=cont_keys,
        observation_keys=obs_keys,
        observation_key=observation_key,
        context_distribution=expl_context_distrib,
        sample_context_from_obs_dict_fn=mapper,
        reward_fn=eval_reward,
        post_process_batch_fn=concat_context_to_obs,
        **replay_buffer_kwargs)
    # The demo buffers reuse the same kwargs, overridden by the
    # demo-specific settings. NOTE: this mutates replay_buffer_kwargs
    # in place AFTER the main buffer is built — order matters.
    replay_buffer_kwargs.update(demo_replay_buffer_kwargs)
    demo_train_buffer = ContextualRelabelingReplayBuffer(
        env=eval_env,
        context_keys=cont_keys,
        observation_keys=obs_keys,
        observation_key=observation_key,
        context_distribution=expl_context_distrib,
        sample_context_from_obs_dict_fn=mapper,
        reward_fn=eval_reward,
        post_process_batch_fn=concat_context_to_obs,
        **replay_buffer_kwargs)
    demo_test_buffer = ContextualRelabelingReplayBuffer(
        env=eval_env,
        context_keys=cont_keys,
        observation_keys=obs_keys,
        observation_key=observation_key,
        context_distribution=expl_context_distrib,
        sample_context_from_obs_dict_fn=mapper,
        reward_fn=eval_reward,
        post_process_batch_fn=concat_context_to_obs,
        **replay_buffer_kwargs)

    # Neural network architecture
    def create_qf():
        # Fill in the size kwarg appropriate to the chosen Q-network class.
        # NOTE(review): mutates the shared qf_kwargs dict (idempotently).
        if qf_class is ConcatMlp:
            qf_kwargs["input_size"] = obs_dim + action_dim
        if qf_class is ConcatCNN:
            qf_kwargs["added_fc_input_size"] = action_dim
        return qf_class(output_size=1, **qf_kwargs)

    qf1 = create_qf()
    qf2 = create_qf()
    target_qf1 = create_qf()
    target_qf2 = create_qf()
    policy = policy_class(
        obs_dim=obs_dim,
        action_dim=action_dim,
        **policy_kwargs,
    )

    # Path collectors
    eval_path_collector = ContextualPathCollector(
        eval_env,
        MakeDeterministic(policy),
        observation_key=observation_key,
        context_keys_for_policy=[
            context_key,
        ],
    )
    exploration_policy = create_exploration_policy(expl_env, policy,
                                                   **exploration_policy_kwargs)
    expl_path_collector = ContextualPathCollector(
        expl_env,
        exploration_policy,
        observation_key=observation_key,
        context_keys_for_policy=[
            context_key,
        ],
    )

    # Algorithm
    trainer = AWACTrainer(env=eval_env,
                          policy=policy,
                          qf1=qf1,
                          qf2=qf2,
                          target_qf1=target_qf1,
                          target_qf2=target_qf2,
                          **trainer_kwargs)
    algorithm = TorchBatchRLAlgorithm(
        trainer=trainer,
        exploration_env=expl_env,
        evaluation_env=eval_env,
        exploration_data_collector=expl_path_collector,
        evaluation_data_collector=eval_path_collector,
        replay_buffer=replay_buffer,
        max_path_length=max_path_length,
        **algo_kwargs)
    algorithm.to(ptu.device)

    # Video saving
    if save_video:
        expl_video_func = RIGVideoSaveFunction(
            model,
            expl_path_collector,
            "train",
            # decode_goal_image_key="image_decoded_goal",
            # reconstruction_key="image_reconstruction",
            rows=2,
            columns=5,
            unnormalize=True,
            imsize=imsize,
            image_format=renderer.output_image_format,
            **save_video_kwargs)
        algorithm.post_train_funcs.append(expl_video_func)
        eval_video_func = RIGVideoSaveFunction(
            model,
            eval_path_collector,
            "eval",
            # goal_image_key=image_goal_key,
            # decode_goal_image_key="image_decoded_goal",
            # reconstruction_key="image_reconstruction",
            num_imgs=4,
            rows=2,
            columns=5,
            unnormalize=True,
            imsize=imsize,
            image_format=renderer.output_image_format,
            **save_video_kwargs)
        algorithm.post_train_funcs.append(eval_video_func)

    # AWAC code
    if save_paths:
        # NOTE(review): this appends the boolean flag itself as a
        # post-train hook; presumably a `save_paths` callable (shadowed by
        # this parameter) was intended — confirm against callers.
        algorithm.post_train_funcs.append(save_paths)
    if load_demos:
        path_loader = path_loader_class(
            trainer,
            replay_buffer=replay_buffer,
            demo_train_buffer=demo_train_buffer,
            demo_test_buffer=demo_test_buffer,
            # reward_fn=eval_reward, # omit reward because its recomputed later
            **path_loader_kwargs)
        path_loader.load_demos()
    if pretrain_policy:
        # Behavior-clone the policy on the demo buffers before RL.
        trainer.pretrain_policy_with_bc(
            policy,
            demo_train_buffer,
            demo_test_buffer,
            trainer.bc_num_pretrain_steps,
        )
    if pretrain_rl:
        trainer.pretrain_q_with_bc_data()
    if save_pretrained_algorithm:
        # Snapshot the pretrained algorithm in both .p and .pt form.
        p_path = osp.join(logger.get_snapshot_dir(), 'pretrain_algorithm.p')
        pt_path = osp.join(logger.get_snapshot_dir(), 'pretrain_algorithm.pt')
        data = algorithm._get_snapshot()
        data['algorithm'] = algorithm
        torch.save(data, open(pt_path, "wb"))
        torch.save(data, open(p_path, "wb"))

    algorithm.train()
def rl_context_experiment(variant):
    """Launch a goal-conditioned (optionally mask-conditioned) RL experiment.

    Builds exploration/evaluation contextual envs, SAC or TD3 networks, a
    goal/mask-relabeling replay buffer, path collectors, video and mask
    diagnostic hooks, then runs ``TorchBatchRLAlgorithm.train()``.

    :param variant: experiment-configuration dict; the keys consumed are the
        ``variant[...]`` / ``variant.get(...)`` reads below (e.g.
        ``'algorithm'``, ``'max_path_length'``, ``'mask_variant'``,
        ``'contextual_replay_buffer_kwargs'``, optional ``'ckpt'``).
    """
    import rlkit.torch.pytorch_util as ptu
    from rlkit.torch.td3.td3 import TD3 as TD3Trainer
    from rlkit.torch.sac.sac import SACTrainer
    from rlkit.torch.torch_rl_algorithm import TorchBatchRLAlgorithm
    from rlkit.torch.networks import ConcatMlp, TanhMlpPolicy
    from rlkit.torch.sac.policies import TanhGaussianPolicy
    from rlkit.torch.sac.policies import MakeDeterministic

    preprocess_rl_variant(variant)
    max_path_length = variant['max_path_length']
    observation_key = variant.get('observation_key', 'latent_observation')
    desired_goal_key = variant.get('desired_goal_key', 'latent_desired_goal')
    achieved_goal_key = variant.get('achieved_goal_key', 'latent_achieved_goal')
    contextual_mdp = variant.get('contextual_mdp', True)
    print("contextual_mdp:", contextual_mdp)

    mask_variant = variant.get('mask_variant', {})
    mask_conditioned = mask_variant.get('mask_conditioned', False)
    print("mask_conditioned:", mask_conditioned)

    if mask_conditioned:
        assert contextual_mdp

    # Dispatch on the algorithm name embedded in variant['algorithm'].
    if 'sac' in variant['algorithm'].lower():
        rl_algo = 'sac'
    elif 'td3' in variant['algorithm'].lower():
        rl_algo = 'td3'
    else:
        raise NotImplementedError
    print("RL algorithm:", rl_algo)

    ### load the example dataset, if running checkpoints ###
    if 'ckpt' in variant:
        import os.path as osp
        example_set_variant = variant.get('example_set_variant', dict())
        example_set_variant['use_cache'] = True
        example_set_variant['cache_path'] = osp.join(
            variant['ckpt'], 'example_dataset.npy')

    if mask_conditioned:
        env = get_envs(variant)
        mask_format = mask_variant['param_variant']['mask_format']
        assert mask_format in [
            'vector', 'matrix', 'distribution', 'cond_distribution']
        goal_dim = env.observation_space.spaces[desired_goal_key].low.size
        # Networks see [obs, goal, mask-params]; the mask part contributes
        # goal_dim entries (vector) or goal_dim**2 entries (matrix-like).
        if mask_format in ['vector']:
            context_dim_for_networks = goal_dim + goal_dim
        elif mask_format in ['matrix', 'distribution', 'cond_distribution']:
            context_dim_for_networks = goal_dim + (goal_dim * goal_dim)
        else:
            raise TypeError
        if 'ckpt' in variant:
            # Reuse the mask parameters saved by the original run.
            from rlkit.misc.asset_loader import local_path_from_s3_or_local_path
            import os.path as osp
            filename = local_path_from_s3_or_local_path(
                osp.join(variant['ckpt'], 'masks.npy'))
            masks = np.load(filename, allow_pickle=True)[()]
        else:
            masks = get_mask_params(
                env=env,
                example_set_variant=variant['example_set_variant'],
                param_variant=mask_variant['param_variant'],
            )
        mask_keys = list(masks.keys())
        context_keys = [desired_goal_key] + mask_keys
    else:
        context_keys = [desired_goal_key]

    def contextual_env_distrib_and_reward(mode='expl'):
        """Build (contextual env, goal/mask distribution, reward fn) for
        either the exploration or the evaluation side."""
        assert mode in ['expl', 'eval']
        env = get_envs(variant)
        if mode == 'expl':
            goal_sampling_mode = variant.get('expl_goal_sampling_mode', None)
        elif mode == 'eval':
            goal_sampling_mode = variant.get('eval_goal_sampling_mode', None)
        # 'example_set' is handled below by sampling goals from a dataset
        # rather than from the env itself.
        if goal_sampling_mode not in [None, 'example_set']:
            env.goal_sampling_mode = goal_sampling_mode
        mask_ids_for_training = mask_variant.get('mask_ids_for_training', None)
        if mask_conditioned:
            context_distrib = MaskDictDistribution(
                env,
                desired_goal_keys=[desired_goal_key],
                mask_format=mask_format,
                masks=masks,
                max_subtasks_to_focus_on=mask_variant.get(
                    'max_subtasks_to_focus_on', None),
                prev_subtask_weight=mask_variant.get(
                    'prev_subtask_weight', None),
                mask_distr=mask_variant.get('train_mask_distr', None),
                mask_ids=mask_ids_for_training,
            )
            reward_fn = ContextualMaskingRewardFn(
                achieved_goal_from_observation=IndexIntoAchievedGoal(
                    achieved_goal_key),
                desired_goal_key=desired_goal_key,
                achieved_goal_key=achieved_goal_key,
                mask_keys=mask_keys,
                mask_format=mask_format,
                use_g_for_mean=mask_variant['use_g_for_mean'],
                use_squared_reward=mask_variant.get(
                    'use_squared_reward', False),
            )
        else:
            if goal_sampling_mode == 'example_set':
                example_dataset = gen_example_sets(
                    get_envs(variant), variant['example_set_variant'])
                assert len(example_dataset['list_of_waypoints']) == 1
                from rlkit.envs.contextual.set_distributions import \
                    GoalDictDistributionFromSet
                context_distrib = GoalDictDistributionFromSet(
                    example_dataset['list_of_waypoints'][0],
                    desired_goal_keys=[desired_goal_key],
                )
            else:
                context_distrib = GoalDictDistributionFromMultitaskEnv(
                    env,
                    desired_goal_keys=[desired_goal_key],
                )
            reward_fn = ContextualRewardFnFromMultitaskEnv(
                env=env,
                achieved_goal_from_observation=IndexIntoAchievedGoal(
                    achieved_goal_key),
                desired_goal_key=desired_goal_key,
                achieved_goal_key=achieved_goal_key,
                additional_obs_keys=variant[
                    'contextual_replay_buffer_kwargs'].get(
                        'observation_keys', None),
            )
        diag_fn = GoalConditionedDiagnosticsToContextualDiagnostics(
            env.goal_conditioned_diagnostics,
            desired_goal_key=desired_goal_key,
            observation_key=observation_key,
        )
        env = ContextualEnv(
            env,
            context_distribution=context_distrib,
            reward_fn=reward_fn,
            observation_key=observation_key,
            contextual_diagnostics_fns=[diag_fn],
            update_env_info_fn=delete_info
            if not variant.get('keep_env_infos', False) else None,
        )
        return env, context_distrib, reward_fn

    env, context_distrib, reward_fn = contextual_env_distrib_and_reward(
        mode='expl')
    eval_env, eval_context_distrib, _ = contextual_env_distrib_and_reward(
        mode='eval')

    # Policy/Q-function input size: observation plus whatever context is
    # concatenated onto it (see concat_context_to_obs below).
    if mask_conditioned:
        obs_dim = (
            env.observation_space.spaces[observation_key].low.size
            + context_dim_for_networks
        )
    elif contextual_mdp:
        obs_dim = (
            env.observation_space.spaces[observation_key].low.size
            + env.observation_space.spaces[desired_goal_key].low.size
        )
    else:
        obs_dim = env.observation_space.spaces[observation_key].low.size
    action_dim = env.action_space.low.size

    if 'ckpt' in variant and 'ckpt_epoch' in variant:
        # Resume networks from a saved snapshot instead of building new ones.
        # NOTE(review): for rl_algo == 'td3' this path never defines
        # target_policy, which the TD3Trainer construction below requires —
        # presumably checkpoints are only used with SAC; confirm.
        from rlkit.misc.asset_loader import local_path_from_s3_or_local_path
        import os.path as osp
        ckpt_epoch = variant['ckpt_epoch']
        if ckpt_epoch is not None:
            epoch = variant['ckpt_epoch']
            filename = local_path_from_s3_or_local_path(
                osp.join(variant['ckpt'], 'itr_%d.pkl' % epoch))
        else:
            filename = local_path_from_s3_or_local_path(
                osp.join(variant['ckpt'], 'params.pkl'))
        print("Loading ckpt from", filename)
        data = torch.load(filename)
        qf1 = data['trainer/qf1']
        qf2 = data['trainer/qf2']
        target_qf1 = data['trainer/target_qf1']
        target_qf2 = data['trainer/target_qf2']
        policy = data['trainer/policy']
        eval_policy = data['evaluation/policy']
        expl_policy = data['exploration/policy']
    else:
        qf1 = ConcatMlp(
            input_size=obs_dim + action_dim,
            output_size=1,
            **variant['qf_kwargs']
        )
        qf2 = ConcatMlp(
            input_size=obs_dim + action_dim,
            output_size=1,
            **variant['qf_kwargs']
        )
        target_qf1 = ConcatMlp(
            input_size=obs_dim + action_dim,
            output_size=1,
            **variant['qf_kwargs']
        )
        target_qf2 = ConcatMlp(
            input_size=obs_dim + action_dim,
            output_size=1,
            **variant['qf_kwargs']
        )
        if rl_algo == 'td3':
            policy = TanhMlpPolicy(
                input_size=obs_dim,
                output_size=action_dim,
                **variant['policy_kwargs']
            )
            target_policy = TanhMlpPolicy(
                input_size=obs_dim,
                output_size=action_dim,
                **variant['policy_kwargs']
            )
            expl_policy = create_exploration_policy(
                env, policy,
                exploration_version=variant['exploration_type'],
                exploration_noise=variant['exploration_noise'],
            )
            eval_policy = policy
        elif rl_algo == 'sac':
            policy = TanhGaussianPolicy(
                obs_dim=obs_dim,
                action_dim=action_dim,
                **variant['policy_kwargs']
            )
            expl_policy = policy
            eval_policy = MakeDeterministic(policy)

    post_process_mask_fn = partial(
        full_post_process_mask_fn,
        mask_conditioned=mask_conditioned,
        mask_variant=mask_variant,
        context_distrib=context_distrib,
        context_key=desired_goal_key,
        achieved_goal_key=achieved_goal_key,
    )

    def context_from_obs_dict_fn(obs_dict):
        """HER-style relabeling: treat the achieved goal as the desired goal,
        and either resample masks or carry over the observed ones."""
        context_dict = {
            desired_goal_key: obs_dict[achieved_goal_key]
        }
        if mask_conditioned:
            sample_masks_for_relabeling = mask_variant.get(
                'sample_masks_for_relabeling', True)
            if sample_masks_for_relabeling:
                batch_size = next(iter(obs_dict.values())).shape[0]
                sampled_contexts = context_distrib.sample(batch_size)
                for mask_key in mask_keys:
                    context_dict[mask_key] = sampled_contexts[mask_key]
            else:
                for mask_key in mask_keys:
                    context_dict[mask_key] = obs_dict[mask_key]
        return context_dict

    def concat_context_to_obs(batch, replay_buffer=None, obs_dict=None,
                              next_obs_dict=None, new_contexts=None):
        """Replay-buffer post-processor: fold the sampled context (goal and,
        if mask-conditioned, mask parameters) into the observation arrays."""
        obs = batch['observations']
        next_obs = batch['next_observations']
        batch_size = obs.shape[0]
        if mask_conditioned:
            if obs_dict is not None and new_contexts is not None:
                # Optionally undo relabeling by restoring the masks/goals that
                # were actually observed at the next step.
                if not mask_variant.get('relabel_masks', True):
                    for k in mask_keys:
                        new_contexts[k] = next_obs_dict[k][:]
                    batch.update(new_contexts)
                if not mask_variant.get('relabel_goals', True):
                    new_contexts[desired_goal_key] = \
                        next_obs_dict[desired_goal_key][:]
                    batch.update(new_contexts)
                new_contexts = post_process_mask_fn(obs_dict, new_contexts)
                batch.update(new_contexts)
            if mask_format in ['vector', 'matrix']:
                goal = batch[desired_goal_key]
                mask = batch['mask'].reshape((batch_size, -1))
                batch['observations'] = np.concatenate(
                    [obs, goal, mask], axis=1)
                batch['next_observations'] = np.concatenate(
                    [next_obs, goal, mask], axis=1)
            elif mask_format == 'distribution':
                goal = batch[desired_goal_key]
                sigma_inv = batch['mask_sigma_inv'].reshape((batch_size, -1))
                batch['observations'] = np.concatenate(
                    [obs, goal, sigma_inv], axis=1)
                batch['next_observations'] = np.concatenate(
                    [next_obs, goal, sigma_inv], axis=1)
            elif mask_format == 'cond_distribution':
                goal = batch[desired_goal_key]
                mu_w = batch['mask_mu_w']
                mu_g = batch['mask_mu_g']
                mu_A = batch['mask_mu_mat']
                sigma_inv = batch['mask_sigma_inv']
                if mask_variant['use_g_for_mean']:
                    mu_w_given_g = goal
                else:
                    # Conditional Gaussian mean: mu_w + A (g - mu_g).
                    mu_w_given_g = mu_w + np.squeeze(
                        mu_A @ np.expand_dims(goal - mu_g, axis=-1), axis=-1)
                sigma_w_given_g_inv = sigma_inv.reshape((batch_size, -1))
                batch['observations'] = np.concatenate(
                    [obs, mu_w_given_g, sigma_w_given_g_inv], axis=1)
                batch['next_observations'] = np.concatenate(
                    [next_obs, mu_w_given_g, sigma_w_given_g_inv], axis=1)
            else:
                raise NotImplementedError
        elif contextual_mdp:
            goal = batch[desired_goal_key]
            batch['observations'] = np.concatenate([obs, goal], axis=1)
            batch['next_observations'] = np.concatenate(
                [next_obs, goal], axis=1)
        else:
            batch['observations'] = obs
            batch['next_observations'] = next_obs
        return batch

    # Make sure the buffer stores the raw observation and achieved goal so
    # the relabeling functions above can read them.
    if 'observation_keys' not in variant['contextual_replay_buffer_kwargs']:
        variant['contextual_replay_buffer_kwargs']['observation_keys'] = []
    observation_keys = \
        variant['contextual_replay_buffer_kwargs']['observation_keys']
    if observation_key not in observation_keys:
        observation_keys.append(observation_key)
    if achieved_goal_key not in observation_keys:
        observation_keys.append(achieved_goal_key)

    replay_buffer = ContextualRelabelingReplayBuffer(
        env=env,
        context_keys=context_keys,
        context_distribution=context_distrib,
        sample_context_from_obs_dict_fn=context_from_obs_dict_fn,
        reward_fn=reward_fn,
        post_process_batch_fn=concat_context_to_obs,
        **variant['contextual_replay_buffer_kwargs']
    )

    if rl_algo == 'td3':
        trainer = TD3Trainer(
            policy=policy,
            qf1=qf1,
            qf2=qf2,
            target_qf1=target_qf1,
            target_qf2=target_qf2,
            target_policy=target_policy,
            **variant['td3_trainer_kwargs']
        )
    elif rl_algo == 'sac':
        trainer = SACTrainer(
            env=env,
            policy=policy,
            qf1=qf1,
            qf2=qf2,
            target_qf1=target_qf1,
            target_qf2=target_qf2,
            **variant['sac_trainer_kwargs']
        )

    def create_path_collector(
            env,
            policy,
            mode='expl',
            mask_kwargs=None,
    ):
        """Build a path collector; mask_kwargs overrides the per-mode mask
        rollout settings (mask_distr / mask_ids / rollout_mask_order)."""
        mask_kwargs = mask_kwargs if mask_kwargs is not None else {}
        assert mode in ['expl', 'eval']
        save_env_in_snapshot = variant.get('save_env_in_snapshot', True)
        if mask_conditioned:
            if 'rollout_mask_order' in mask_kwargs:
                rollout_mask_order = mask_kwargs['rollout_mask_order']
            else:
                if mode == 'expl':
                    rollout_mask_order = mask_variant.get(
                        'rollout_mask_order_for_expl', 'fixed')
                elif mode == 'eval':
                    rollout_mask_order = mask_variant.get(
                        'rollout_mask_order_for_eval', 'fixed')
                else:
                    raise TypeError
            if 'mask_distr' in mask_kwargs:
                mask_distr = mask_kwargs['mask_distr']
            else:
                if mode == 'expl':
                    mask_distr = mask_variant['expl_mask_distr']
                elif mode == 'eval':
                    mask_distr = mask_variant['eval_mask_distr']
                else:
                    raise TypeError
            if 'mask_ids' in mask_kwargs:
                mask_ids = mask_kwargs['mask_ids']
            else:
                if mode == 'expl':
                    mask_ids = mask_variant.get('mask_ids_for_expl', None)
                elif mode == 'eval':
                    mask_ids = mask_variant.get('mask_ids_for_eval', None)
                else:
                    raise TypeError
            prev_subtask_weight = mask_variant.get('prev_subtask_weight', None)
            max_subtasks_to_focus_on = mask_variant.get(
                'max_subtasks_to_focus_on', None)
            max_subtasks_per_rollout = mask_variant.get(
                'max_subtasks_per_rollout', None)
            # BUGFIX: this value used to be assigned to a local named `mode`,
            # shadowing the expl/eval parameter, so the mask_sampler ternary
            # below always picked eval_context_distrib (the post-process mode
            # never equals 'expl'). Keep it in its own variable instead.
            post_process_mode = mask_variant.get(
                'context_post_process_mode', None)
            if post_process_mode in [
                    'dilute_prev_subtasks_uniform',
                    'dilute_prev_subtasks_fixed']:
                prev_subtask_weight = 0.5
            return MaskPathCollector(
                env,
                policy,
                observation_key=observation_key,
                context_keys_for_policy=context_keys,
                concat_context_to_obs_fn=concat_context_to_obs,
                save_env_in_snapshot=save_env_in_snapshot,
                mask_sampler=(context_distrib
                              if mode == 'expl' else eval_context_distrib),
                mask_distr=mask_distr.copy(),
                mask_ids=mask_ids,
                max_path_length=max_path_length,
                rollout_mask_order=rollout_mask_order,
                prev_subtask_weight=prev_subtask_weight,
                max_subtasks_to_focus_on=max_subtasks_to_focus_on,
                max_subtasks_per_rollout=max_subtasks_per_rollout,
            )
        elif contextual_mdp:
            return ContextualPathCollector(
                env,
                policy,
                observation_key=observation_key,
                context_keys_for_policy=context_keys,
                save_env_in_snapshot=save_env_in_snapshot,
            )
        else:
            return ContextualPathCollector(
                env,
                policy,
                observation_key=observation_key,
                context_keys_for_policy=[],
                save_env_in_snapshot=save_env_in_snapshot,
            )

    expl_path_collector = create_path_collector(env, expl_policy, mode='expl')
    eval_path_collector = create_path_collector(
        eval_env, eval_policy, mode='eval')

    algorithm = TorchBatchRLAlgorithm(
        trainer=trainer,
        exploration_env=env,
        evaluation_env=eval_env,
        exploration_data_collector=expl_path_collector,
        evaluation_data_collector=eval_path_collector,
        replay_buffer=replay_buffer,
        max_path_length=max_path_length,
        **variant['algo_kwargs']
    )
    algorithm.to(ptu.device)

    if variant.get("save_video", True):
        save_period = variant.get('save_video_period', 50)
        dump_video_kwargs = variant.get("dump_video_kwargs", dict())
        dump_video_kwargs['horizon'] = max_path_length
        renderer = EnvRenderer(**variant.get('renderer_kwargs', {}))

        def add_images(env, state_distribution):
            """Wrap a contextual env so rollouts carry rendered images for
            video dumping."""
            state_env = env.env
            image_goal_distribution = AddImageDistribution(
                env=state_env,
                base_distribution=state_distribution,
                image_goal_key='image_desired_goal',
                renderer=renderer,
            )
            img_env = InsertImagesEnv(state_env, renderers={
                'image_observation': renderer,
            })
            context_env = ContextualEnv(
                img_env,
                context_distribution=image_goal_distribution,
                reward_fn=reward_fn,
                observation_key=observation_key,
                update_env_info_fn=delete_info,
            )
            return context_env

        img_eval_env = add_images(eval_env, eval_context_distrib)

        if variant.get('log_eval_video', True):
            video_path_collector = create_path_collector(
                img_eval_env, eval_policy, mode='eval')
            rollout_function = video_path_collector._rollout_fn
            eval_video_func = get_save_video_function(
                rollout_function,
                img_eval_env,
                eval_policy,
                tag="eval",
                imsize=variant['renderer_kwargs']['width'],
                image_format='CHW',
                save_video_period=save_period,
                **dump_video_kwargs
            )
            algorithm.post_train_funcs.append(eval_video_func)

        # additional eval videos for mask conditioned case
        if mask_conditioned:
            default_list = [
                'atomic',
                'atomic_seq',
                'cumul_seq',
                'full',
            ]
            eval_rollouts_for_videos = mask_variant.get(
                'eval_rollouts_for_videos', default_list)
            for key in eval_rollouts_for_videos:
                assert key in default_list

            if 'cumul_seq' in eval_rollouts_for_videos:
                video_path_collector = create_path_collector(
                    img_eval_env,
                    eval_policy,
                    mode='eval',
                    mask_kwargs=dict(
                        mask_distr=dict(
                            cumul_seq=1.0
                        ),
                    ),
                )
                rollout_function = video_path_collector._rollout_fn
                eval_video_func = get_save_video_function(
                    rollout_function,
                    img_eval_env,
                    eval_policy,
                    tag="eval_cumul" if mask_conditioned else "eval",
                    imsize=variant['renderer_kwargs']['width'],
                    image_format='HWC',
                    save_video_period=save_period,
                    **dump_video_kwargs
                )
                algorithm.post_train_funcs.append(eval_video_func)

            if 'full' in eval_rollouts_for_videos:
                video_path_collector = create_path_collector(
                    img_eval_env,
                    eval_policy,
                    mode='eval',
                    mask_kwargs=dict(
                        mask_distr=dict(
                            full=1.0
                        ),
                    ),
                )
                rollout_function = video_path_collector._rollout_fn
                eval_video_func = get_save_video_function(
                    rollout_function,
                    img_eval_env,
                    eval_policy,
                    tag="eval_full",
                    imsize=variant['renderer_kwargs']['width'],
                    image_format='HWC',
                    save_video_period=save_period,
                    **dump_video_kwargs
                )
                algorithm.post_train_funcs.append(eval_video_func)

            if 'atomic_seq' in eval_rollouts_for_videos:
                video_path_collector = create_path_collector(
                    img_eval_env,
                    eval_policy,
                    mode='eval',
                    mask_kwargs=dict(
                        mask_distr=dict(
                            atomic_seq=1.0
                        ),
                    ),
                )
                rollout_function = video_path_collector._rollout_fn
                eval_video_func = get_save_video_function(
                    rollout_function,
                    img_eval_env,
                    eval_policy,
                    tag="eval_atomic",
                    imsize=variant['renderer_kwargs']['width'],
                    image_format='HWC',
                    save_video_period=save_period,
                    **dump_video_kwargs
                )
                algorithm.post_train_funcs.append(eval_video_func)

        if variant.get('log_expl_video', True) \
                and not variant['algo_kwargs'].get('eval_only', False):
            img_expl_env = add_images(env, context_distrib)
            video_path_collector = create_path_collector(
                img_expl_env, expl_policy, mode='expl')
            rollout_function = video_path_collector._rollout_fn
            expl_video_func = get_save_video_function(
                rollout_function,
                img_expl_env,
                expl_policy,
                tag="expl",
                imsize=variant['renderer_kwargs']['width'],
                image_format='CHW',
                save_video_period=save_period,
                **dump_video_kwargs
            )
            algorithm.post_train_funcs.append(expl_video_func)

    # Per-mask evaluation diagnostics: one extra collector per configured
    # rollout scheme, logged under a 'mask_*/' prefix each epoch.
    addl_collectors = []
    addl_log_prefixes = []
    if mask_conditioned and mask_variant.get('log_mask_diagnostics', True):
        default_list = [
            'atomic',
            'atomic_seq',
            'cumul_seq',
            'full',
        ]
        eval_rollouts_to_log = mask_variant.get(
            'eval_rollouts_to_log', default_list)
        for key in eval_rollouts_to_log:
            assert key in default_list

        # atomic masks
        if 'atomic' in eval_rollouts_to_log:
            for mask_id in eval_path_collector.mask_ids:
                mask_kwargs = dict(
                    mask_ids=[mask_id],
                    mask_distr=dict(
                        atomic=1.0,
                    ),
                )
                collector = create_path_collector(
                    eval_env, eval_policy, mode='eval',
                    mask_kwargs=mask_kwargs)
                addl_collectors.append(collector)
            addl_log_prefixes += [
                'mask_{}/'.format(''.join(str(mask_id)))
                for mask_id in eval_path_collector.mask_ids
            ]

        # full mask
        if 'full' in eval_rollouts_to_log:
            mask_kwargs = dict(
                mask_distr=dict(
                    full=1.0,
                ),
            )
            collector = create_path_collector(
                eval_env, eval_policy, mode='eval', mask_kwargs=mask_kwargs)
            addl_collectors.append(collector)
            addl_log_prefixes.append('mask_full/')

        # cumulative, sequential mask
        if 'cumul_seq' in eval_rollouts_to_log:
            mask_kwargs = dict(
                rollout_mask_order='fixed',
                mask_distr=dict(
                    cumul_seq=1.0,
                ),
            )
            collector = create_path_collector(
                eval_env, eval_policy, mode='eval', mask_kwargs=mask_kwargs)
            addl_collectors.append(collector)
            addl_log_prefixes.append('mask_cumul_seq/')

        # atomic, sequential mask
        if 'atomic_seq' in eval_rollouts_to_log:
            mask_kwargs = dict(
                rollout_mask_order='fixed',
                mask_distr=dict(
                    atomic_seq=1.0,
                ),
            )
            collector = create_path_collector(
                eval_env, eval_policy, mode='eval', mask_kwargs=mask_kwargs)
            addl_collectors.append(collector)
            addl_log_prefixes.append('mask_atomic_seq/')

        def get_mask_diagnostics(unused):
            """Collect paths with each extra collector and keep only the
            final env-info means, prefixed per mask scheme."""
            from rlkit.core.logging import append_log, add_prefix, OrderedDict
            log = OrderedDict()
            for prefix, collector in zip(addl_log_prefixes, addl_collectors):
                paths = collector.collect_new_paths(
                    max_path_length,
                    variant['algo_kwargs']['num_eval_steps_per_epoch'],
                    discard_incomplete_paths=True,
                )
                old_path_info = eval_env.get_diagnostics(paths)
                keys_to_keep = []
                for key in old_path_info.keys():
                    if ('env_infos' in key) \
                            and ('final' in key) \
                            and ('Mean' in key):
                        keys_to_keep.append(key)
                path_info = OrderedDict()
                for key in keys_to_keep:
                    path_info[key] = old_path_info[key]
                generic_info = add_prefix(
                    path_info,
                    prefix,
                )
                append_log(log, generic_info)
            for collector in addl_collectors:
                collector.end_epoch(0)
            return log

        algorithm._eval_get_diag_fns.append(get_mask_diagnostics)

    if 'ckpt' in variant:
        # Eval-only replay of a checkpointed run: swap in the policy saved at
        # each epoch before evaluating it.
        from rlkit.misc.asset_loader import local_path_from_s3_or_local_path
        import os.path as osp
        assert variant['algo_kwargs'].get('eval_only', False)

        def update_networks(algo, epoch):
            if 'ckpt_epoch' in variant:
                return
            if epoch % algo._eval_epoch_freq == 0:
                filename = local_path_from_s3_or_local_path(
                    osp.join(variant['ckpt'], 'itr_%d.pkl' % epoch))
                print("Loading ckpt from", filename)
                data = torch.load(filename)  # , map_location='cuda:1'
                eval_policy = data['evaluation/policy']
                eval_policy.to(ptu.device)
                algo.eval_data_collector._policy = eval_policy
                for collector in addl_collectors:
                    collector._policy = eval_policy

        algorithm.post_train_funcs.insert(0, update_networks)

    algorithm.train()
def disco_experiment(
        max_path_length,
        qf_kwargs,
        sac_trainer_kwargs,
        replay_buffer_kwargs,
        policy_kwargs,
        algo_kwargs,
        generate_set_for_rl_kwargs,
        # VAE parameters
        create_vae_kwargs,
        vae_trainer_kwargs,
        vae_algo_kwargs,
        data_loader_kwargs,
        generate_set_for_vae_pretraining_kwargs,
        num_ungrouped_images,
        beta_schedule_kwargs=None,
        # Oracle settings
        use_ground_truth_reward=False,
        use_onehot_set_embedding=False,
        use_dummy_model=False,
        observation_key="latent_observation",
        # RIG comparison
        rig_goal_setter_kwargs=None,
        rig=False,
        # Miscellaneous
        reward_fn_kwargs=None,
        # None-VAE Params
        env_id=None,
        env_class=None,
        env_kwargs=None,
        latent_observation_key="latent_observation",
        state_observation_key="state_observation",
        image_observation_key="image_observation",
        set_description_key="set_description",
        example_state_key="example_state",
        example_image_key="example_image",
        # Exploration
        exploration_policy_kwargs=None,
        # Video parameters
        save_video=True,
        save_video_kwargs=None,
        renderer_kwargs=None,
):
    """Launch a DisCo-style SAC experiment: (optionally) train a set-VAE,
    build set-conditioned contextual envs, then run batch SAC.

    The set-VAE is either trained here (``train_set_vae``) or replaced by a
    dummy model when ``use_dummy_model`` is set. ``rig`` switches the policy
    context to the set mean only (RIG comparison); otherwise the context is
    either a one-hot set embedding or the set's (mean, covariance).
    """
    # Normalize optional kwargs to empty dicts.
    if rig_goal_setter_kwargs is None:
        rig_goal_setter_kwargs = {}
    if reward_fn_kwargs is None:
        reward_fn_kwargs = {}
    if exploration_policy_kwargs is None:
        exploration_policy_kwargs = {}
    if not save_video_kwargs:
        save_video_kwargs = {}
    if not renderer_kwargs:
        renderer_kwargs = {}
    renderer = EnvRenderer(**renderer_kwargs)

    # Example sets that define the goal distributions (state + image pairs).
    sets = create_sets(
        env_id,
        env_class,
        env_kwargs,
        renderer,
        example_state_key=example_state_key,
        example_image_key=example_image_key,
        **generate_set_for_rl_kwargs,
    )
    if use_dummy_model:
        model = create_dummy_image_vae(
            img_chw=renderer.image_chw,
            **create_vae_kwargs)
    else:
        model = train_set_vae(
            create_vae_kwargs,
            vae_trainer_kwargs,
            vae_algo_kwargs,
            data_loader_kwargs,
            generate_set_for_vae_pretraining_kwargs,
            num_ungrouped_images,
            env_id=env_id,
            env_class=env_class,
            env_kwargs=env_kwargs,
            beta_schedule_kwargs=beta_schedule_kwargs,
            sets=sets,
            renderer=renderer,
        )

    # Separate contextual envs for exploration and evaluation; only the eval
    # env receives oracle_rig_goal (tied to the `rig` flag).
    expl_env, expl_context_distrib, expl_reward = (
        contextual_env_distrib_and_reward(
            vae=model,
            sets=sets,
            state_env=get_gym_env(
                env_id,
                env_class=env_class,
                env_kwargs=env_kwargs,
            ),
            renderer=renderer,
            reward_fn_kwargs=reward_fn_kwargs,
            use_ground_truth_reward=use_ground_truth_reward,
            state_observation_key=state_observation_key,
            latent_observation_key=latent_observation_key,
            example_image_key=example_image_key,
            set_description_key=set_description_key,
            observation_key=observation_key,
            image_observation_key=image_observation_key,
            rig_goal_setter_kwargs=rig_goal_setter_kwargs,
        ))
    eval_env, eval_context_distrib, eval_reward = (
        contextual_env_distrib_and_reward(
            vae=model,
            sets=sets,
            state_env=get_gym_env(
                env_id,
                env_class=env_class,
                env_kwargs=env_kwargs,
            ),
            renderer=renderer,
            reward_fn_kwargs=reward_fn_kwargs,
            use_ground_truth_reward=use_ground_truth_reward,
            state_observation_key=state_observation_key,
            latent_observation_key=latent_observation_key,
            example_image_key=example_image_key,
            set_description_key=set_description_key,
            observation_key=observation_key,
            image_observation_key=image_observation_key,
            rig_goal_setter_kwargs=rig_goal_setter_kwargs,
            oracle_rig_goal=rig,
        ))

    # Everything stored in the buffer vs. the subset fed to the policy/Q-fns.
    context_keys = [
        expl_context_distrib.mean_key,
        expl_context_distrib.covariance_key,
        expl_context_distrib.set_index_key,
        expl_context_distrib.set_embedding_key,
    ]
    if rig:
        context_keys_for_rl = [
            expl_context_distrib.mean_key,
        ]
    else:
        if use_onehot_set_embedding:
            context_keys_for_rl = [
                expl_context_distrib.set_embedding_key,
            ]
        else:
            context_keys_for_rl = [
                expl_context_distrib.mean_key,
                expl_context_distrib.covariance_key,
            ]

    # Network input size = observation + the RL context entries.
    obs_dim = np.prod(
        expl_env.observation_space.spaces[observation_key].shape)
    obs_dim += sum([
        np.prod(expl_env.observation_space.spaces[k].shape)
        for k in context_keys_for_rl
    ])
    action_dim = np.prod(expl_env.action_space.shape)

    def create_qf():
        # One Q-network over [obs, context, action].
        return ConcatMlp(
            input_size=obs_dim + action_dim,
            output_size=1,
            **qf_kwargs)
    qf1 = create_qf()
    qf2 = create_qf()
    target_qf1 = create_qf()
    target_qf2 = create_qf()
    policy = TanhGaussianPolicy(
        obs_dim=obs_dim,
        action_dim=action_dim,
        **policy_kwargs)

    def concat_context_to_obs(batch, *args, **kwargs):
        # Replay post-processor: fold the RL context into the observations.
        obs = batch["observations"]
        next_obs = batch["next_observations"]
        contexts = [batch[k] for k in context_keys_for_rl]
        batch["observations"] = np.concatenate((obs, *contexts), axis=1)
        batch["next_observations"] = np.concatenate(
            (next_obs, *contexts), axis=1,
        )
        return batch

    replay_buffer = ContextualRelabelingReplayBuffer(
        env=eval_env,
        context_keys=context_keys,
        observation_keys=list(
            {observation_key, state_observation_key, latent_observation_key}),
        observation_key=observation_key,
        context_distribution=FilterKeys(
            expl_context_distrib,
            context_keys,
        ),
        sample_context_from_obs_dict_fn=None,
        # RemapKeyFn({context_key: observation_key}),
        reward_fn=eval_reward,
        post_process_batch_fn=concat_context_to_obs,
        **replay_buffer_kwargs,
    )
    trainer = SACTrainer(
        env=expl_env,
        policy=policy,
        qf1=qf1,
        qf2=qf2,
        target_qf1=target_qf1,
        target_qf2=target_qf2,
        **sac_trainer_kwargs,
    )

    eval_path_collector = ContextualPathCollector(
        eval_env,
        MakeDeterministic(policy),
        observation_key=observation_key,
        context_keys_for_policy=context_keys_for_rl,
    )
    exploration_policy = create_exploration_policy(
        expl_env, policy, **exploration_policy_kwargs)
    expl_path_collector = ContextualPathCollector(
        expl_env,
        exploration_policy,
        observation_key=observation_key,
        context_keys_for_policy=context_keys_for_rl,
    )

    algorithm = TorchBatchRLAlgorithm(
        trainer=trainer,
        exploration_env=expl_env,
        evaluation_env=eval_env,
        exploration_data_collector=expl_path_collector,
        evaluation_data_collector=eval_path_collector,
        replay_buffer=replay_buffer,
        max_path_length=max_path_length,
        **algo_kwargs,
    )
    algorithm.to(ptu.device)

    if save_video:
        # Post-train hooks that dump train/eval rollout videos alongside the
        # VAE reconstructions and set visualizations.
        set_index_key = eval_context_distrib.set_index_key
        expl_video_func = DisCoVideoSaveFunction(
            model,
            sets,
            expl_path_collector,
            tag="train",
            reconstruction_key="image_reconstruction",
            decode_set_image_key="decoded_set_prior",
            set_visualization_key="set_visualization",
            example_image_key=example_image_key,
            set_index_key=set_index_key,
            columns=len(sets),
            unnormalize=True,
            imsize=48,
            image_format=renderer.output_image_format,
            **save_video_kwargs,
        )
        algorithm.post_train_funcs.append(expl_video_func)
        eval_video_func = DisCoVideoSaveFunction(
            model,
            sets,
            eval_path_collector,
            tag="eval",
            reconstruction_key="image_reconstruction",
            decode_set_image_key="decoded_set_prior",
            set_visualization_key="set_visualization",
            example_image_key=example_image_key,
            set_index_key=set_index_key,
            columns=len(sets),
            unnormalize=True,
            imsize=48,
            image_format=renderer.output_image_format,
            **save_video_kwargs,
        )
        algorithm.post_train_funcs.append(eval_video_func)

    algorithm.train()
def masking_sac_experiment(
        max_path_length,
        qf_kwargs,
        sac_trainer_kwargs,
        replay_buffer_kwargs,
        policy_kwargs,
        algo_kwargs,
        env_class=None,
        env_kwargs=None,
        observation_key='state_observation',
        desired_goal_key='state_desired_goal',
        achieved_goal_key='state_achieved_goal',
        exploration_policy_kwargs=None,
        evaluation_goal_sampling_mode=None,
        exploration_goal_sampling_mode=None,
        # Video parameters
        save_video=True,
        save_video_kwargs=None,
        renderer_kwargs=None,
        train_env_id=None,
        eval_env_id=None,
        do_masking=True,
        mask_key="masked_observation",
        masking_eval_steps=200,
        log_mask_diagnostics=True,
        mask_dim=None,
        masking_reward_fn=None,
        masking_for_exploration=True,
        rotate_masks_for_eval=False,
        rotate_masks_for_expl=False,
        mask_distribution=None,
        num_steps_per_mask_change=10,
        tag=None,
):
    """Launch a state-based, mask-conditioned goal-reaching SAC experiment.

    Goals are augmented with a binary mask (under ``mask_key``) drawn from
    ``mask_distribution``; the policy/Q-functions consume
    [obs, goal, mask]. With ``do_masking=False`` the mask degenerates to a
    static all-ones vector. Per-dimension one-hot-mask diagnostics are
    logged each epoch when ``log_mask_diagnostics`` is set.

    NOTE(review): ``tag`` is accepted but never read in this body.
    """
    if exploration_policy_kwargs is None:
        exploration_policy_kwargs = {}
    if not save_video_kwargs:
        save_video_kwargs = {}
    if not renderer_kwargs:
        renderer_kwargs = {}

    context_key = desired_goal_key

    # A throwaway env instance, used to infer mask_dim and to back the
    # relabeling distribution below.
    env = get_gym_env(train_env_id, env_class=env_class, env_kwargs=env_kwargs)
    mask_dim = (mask_dim or env.observation_space.spaces[context_key].low.size)

    if not do_masking:
        mask_distribution = 'all_ones'
    assert mask_distribution in [
        'one_hot_masks', 'random_bit_masks', 'all_ones'
    ]
    # 'all_ones' is implemented as a static (constant) mask.
    if mask_distribution == 'all_ones':
        mask_distribution = 'static_mask'

    def contextual_env_distrib_and_reward(
            env_id,
            env_class,
            env_kwargs,
            goal_sampling_mode,
            env_mask_distribution_type,
            static_mask=None,
    ):
        """Build (contextual env, masked goal distribution, reward fn) for
        one env id / goal-sampling mode / mask-distribution combination."""
        env = get_gym_env(env_id, env_class=env_class, env_kwargs=env_kwargs)
        env.goal_sampling_mode = goal_sampling_mode

        use_static_mask = (not do_masking
                           or env_mask_distribution_type == 'static_mask')
        # Default to all ones mask if static mask isn't defined and should
        # be using static masks.
        if use_static_mask and static_mask is None:
            static_mask = np.ones(mask_dim)
        if not do_masking:
            assert env_mask_distribution_type == 'static_mask'

        goal_distribution = GoalDictDistributionFromMultitaskEnv(
            env,
            desired_goal_keys=[desired_goal_key],
        )
        goal_distribution = MaskedGoalDictDistribution(
            goal_distribution,
            mask_key=mask_key,
            mask_dim=mask_dim,
            distribution_type=env_mask_distribution_type,
            static_mask=static_mask,
        )

        reward_fn = ContextualRewardFnFromMultitaskEnv(
            env=env,
            desired_goal_key=desired_goal_key,
            achieved_goal_key=achieved_goal_key,
            achieved_goal_from_observation=IndexIntoAchievedGoal(
                observation_key),
        )
        if do_masking:
            # Wrap the base reward with a mask-aware one (custom if given).
            if masking_reward_fn:
                reward_fn = partial(
                    masking_reward_fn,
                    context_key=context_key,
                    mask_key=mask_key)
            else:
                reward_fn = partial(
                    default_masked_reward_fn,
                    context_key=context_key,
                    mask_key=mask_key)

        state_diag_fn = GoalConditionedDiagnosticsToContextualDiagnostics(
            env.goal_conditioned_diagnostics,
            desired_goal_key=desired_goal_key,
            observation_key=observation_key,
        )
        env = ContextualEnv(
            env,
            context_distribution=goal_distribution,
            reward_fn=reward_fn,
            observation_key=observation_key,
            contextual_diagnostics_fns=[state_diag_fn],
        )
        return env, goal_distribution, reward_fn

    # Exploration can use the full mask distribution; evaluation always uses
    # a static (all-ones by default) mask.
    expl_env, expl_context_distrib, expl_reward = \
        contextual_env_distrib_and_reward(
            train_env_id,
            env_class,
            env_kwargs,
            exploration_goal_sampling_mode,
            mask_distribution if masking_for_exploration else 'static_mask',
        )
    eval_env, eval_context_distrib, eval_reward = \
        contextual_env_distrib_and_reward(
            eval_env_id,
            env_class,
            env_kwargs,
            evaluation_goal_sampling_mode,
            'static_mask')

    # Distribution for relabeling
    relabel_context_distrib = GoalDictDistributionFromMultitaskEnv(
        env,
        desired_goal_keys=[desired_goal_key],
    )
    relabel_context_distrib = MaskedGoalDictDistribution(
        relabel_context_distrib,
        mask_key=mask_key,
        mask_dim=mask_dim,
        distribution_type=mask_distribution,
        static_mask=None if do_masking else np.ones(mask_dim),
    )

    # Networks consume [obs, goal, mask].
    obs_dim = (
        expl_env.observation_space.spaces[observation_key].low.size
        + expl_env.observation_space.spaces[context_key].low.size
        + mask_dim
    )
    action_dim = expl_env.action_space.low.size

    def create_qf():
        return ConcatMlp(
            input_size=obs_dim + action_dim,
            output_size=1,
            **qf_kwargs)
    qf1 = create_qf()
    qf2 = create_qf()
    target_qf1 = create_qf()
    target_qf2 = create_qf()
    policy = TanhGaussianPolicy(
        obs_dim=obs_dim,
        action_dim=action_dim,
        **policy_kwargs)

    def context_from_obs_dict_fn(obs_dict):
        # HER-style relabeling: achieved state becomes the desired goal; the
        # mask is carried over from the observation as-is.
        achieved_goal = obs_dict['state_achieved_goal']
        # Should the mask be randomized for future relabeling?
        # batch_size = len(achieved_goal)
        # mask = np.random.choice(2, size=(batch_size, mask_dim))
        mask = obs_dict[mask_key]
        return {
            mask_key: mask,
            context_key: achieved_goal,
        }

    def concat_context_to_obs(batch, *args, **kwargs):
        # Replay post-processor: fold goal + mask into the observations.
        obs = batch['observations']
        next_obs = batch['next_observations']
        context = batch[context_key]
        mask = batch[mask_key]
        batch['observations'] = np.concatenate(
            [obs, context, mask], axis=1)
        batch['next_observations'] = np.concatenate(
            [next_obs, context, mask], axis=1)
        return batch

    replay_buffer = ContextualRelabelingReplayBuffer(
        env=eval_env,
        context_keys=[context_key, mask_key],
        observation_keys=[observation_key, 'state_achieved_goal'],
        context_distribution=relabel_context_distrib,
        sample_context_from_obs_dict_fn=context_from_obs_dict_fn,
        reward_fn=eval_reward,
        post_process_batch_fn=concat_context_to_obs,
        **replay_buffer_kwargs)
    trainer = SACTrainer(
        env=expl_env,
        policy=policy,
        qf1=qf1,
        qf2=qf2,
        target_qf1=target_qf1,
        target_qf2=target_qf2,
        **sac_trainer_kwargs)

    def create_path_collector(env, policy, is_rotating):
        """Rotating collectors cycle through one-hot masks mid-rollout;
        otherwise a plain contextual collector is used."""
        if is_rotating:
            assert do_masking and mask_distribution == 'one_hot_masks'
            return RotatingMaskingPathCollector(
                env,
                policy,
                observation_key=observation_key,
                context_keys_for_policy=[context_key, mask_key],
                mask_key=mask_key,
                mask_length=mask_dim,
                num_steps_per_mask_change=num_steps_per_mask_change,
            )
        else:
            return ContextualPathCollector(
                env,
                policy,
                observation_key=observation_key,
                context_keys_for_policy=[context_key, mask_key],
            )

    exploration_policy = create_exploration_policy(
        expl_env, policy, **exploration_policy_kwargs)
    eval_path_collector = create_path_collector(
        eval_env, MakeDeterministic(policy), rotate_masks_for_eval)
    expl_path_collector = create_path_collector(
        expl_env, exploration_policy, rotate_masks_for_expl)

    algorithm = TorchBatchRLAlgorithm(
        trainer=trainer,
        exploration_env=expl_env,
        evaluation_env=eval_env,
        exploration_data_collector=expl_path_collector,
        evaluation_data_collector=eval_path_collector,
        replay_buffer=replay_buffer,
        max_path_length=max_path_length,
        **algo_kwargs)
    algorithm.to(ptu.device)

    if save_video:
        rollout_function = partial(
            rf.contextual_rollout,
            max_path_length=max_path_length,
            observation_key=observation_key,
            context_keys_for_policy=[context_key, mask_key],
            # Eval on everything for the base video
            obs_processor=lambda o: np.hstack(
                (o[observation_key], o[context_key], np.ones(mask_dim))))
        renderer = EnvRenderer(**renderer_kwargs)

        def add_images(env, state_distribution):
            # Wrap a contextual env so rollouts include rendered images.
            state_env = env.env
            image_goal_distribution = AddImageDistribution(
                env=state_env,
                base_distribution=state_distribution,
                image_goal_key='image_desired_goal',
                renderer=renderer,
            )
            img_env = InsertImageEnv(state_env, renderer=renderer)
            return ContextualEnv(
                img_env,
                context_distribution=image_goal_distribution,
                reward_fn=eval_reward,
                observation_key=observation_key,
                # update_env_info_fn=DeleteOldEnvInfo(),
            )
        img_eval_env = add_images(eval_env, eval_context_distrib)
        img_expl_env = add_images(expl_env, expl_context_distrib)
        # NOTE(review): image_format 'CWH' differs from the 'CHW'/'HWC'
        # used elsewhere in this file — confirm it is intentional.
        eval_video_func = get_save_video_function(
            rollout_function,
            img_eval_env,
            MakeDeterministic(policy),
            tag="eval",
            imsize=renderer.image_shape[0],
            image_format='CWH',
            **save_video_kwargs)
        expl_video_func = get_save_video_function(
            rollout_function,
            img_expl_env,
            exploration_policy,
            tag="train",
            imsize=renderer.image_shape[0],
            image_format='CWH',
            **save_video_kwargs)
        algorithm.post_train_funcs.append(eval_video_func)
        algorithm.post_train_funcs.append(expl_video_func)

    # For diagnostics, evaluate the policy on each individual dimension of the
    # mask.
    masks = []
    collectors = []
    for mask_idx in range(mask_dim):
        mask = np.zeros(mask_dim)
        mask[mask_idx] = 1
        masks.append(mask)
    for mask in masks:
        # Without masking every diagnostic env still uses the all-ones mask.
        for_dim_mask = mask if do_masking else np.ones(mask_dim)
        masked_env, _, _ = contextual_env_distrib_and_reward(
            eval_env_id,
            env_class,
            env_kwargs,
            evaluation_goal_sampling_mode,
            'static_mask',
            static_mask=for_dim_mask)
        collector = ContextualPathCollector(
            masked_env,
            MakeDeterministic(policy),
            observation_key=observation_key,
            context_keys_for_policy=[context_key, mask_key],
        )
        collectors.append(collector)
    log_prefixes = [
        'mask_{}/'.format(''.join(mask.astype(int).astype(str)))
        for mask in masks
    ]

    def get_mask_diagnostics(unused):
        """Collect paths per one-hot mask and log generic path statistics
        under a 'mask_<bits>/' prefix for each."""
        from rlkit.core.logging import append_log, add_prefix, OrderedDict
        from rlkit.misc import eval_util
        log = OrderedDict()
        for prefix, collector in zip(log_prefixes, collectors):
            paths = collector.collect_new_paths(
                max_path_length,
                masking_eval_steps,
                discard_incomplete_paths=True,
            )
            generic_info = add_prefix(
                eval_util.get_generic_path_information(paths),
                prefix,
            )
            append_log(log, generic_info)
        for collector in collectors:
            collector.end_epoch(0)
        return log

    if log_mask_diagnostics:
        algorithm._eval_get_diag_fns.append(get_mask_diagnostics)

    algorithm.train()