def add_images(env, base_distribution):
    if use_image_observations:
        img_env = env
        image_goal_distribution = base_distribution
    else:
        state_env = env.env
        image_goal_distribution = AddImageDistribution(
            env=state_env,
            base_distribution=base_distribution,
            image_goal_key='image_desired_goal',
            renderer=video_renderer,
        )
        img_env = InsertImageEnv(state_env, renderer=video_renderer)
    img_env = InsertDebugImagesEnv(
        img_env,
        obj1_sweep_renderers,
        compute_shared_data=obj1_sweeper,
    )
    img_env = InsertDebugImagesEnv(
        img_env,
        obj0_sweep_renderers,
        compute_shared_data=obj0_sweeper,
    )
    return ContextualEnv(
        img_env,
        context_distribution=image_goal_distribution,
        reward_fn=reward_fn,
        observation_key=observation_key_for_rl,
        update_env_info_fn=delete_info,
    )
def contextual_env_distrib_and_reward(
        env_id, env_class, env_kwargs, goal_sampling_mode
):
    env = get_gym_env(env_id, env_class=env_class, env_kwargs=env_kwargs)
    env.goal_sampling_mode = goal_sampling_mode
    goal_distribution = GoalDictDistributionFromMultitaskEnv(
        env,
        desired_goal_keys=[desired_goal_key],
    )
    reward_fn = ContextualRewardFnFromMultitaskEnv(
        env=env,
        achieved_goal_from_observation=IndexIntoAchievedGoal(observation_key),
        desired_goal_key=desired_goal_key,
        achieved_goal_key=achieved_goal_key,
    )
    diag_fn = GoalConditionedDiagnosticsToContextualDiagnostics(
        env.goal_conditioned_diagnostics,
        desired_goal_key=desired_goal_key,
        observation_key=observation_key,
    )
    env = ContextualEnv(
        env,
        context_distribution=goal_distribution,
        reward_fn=reward_fn,
        observation_key=observation_key,
        contextual_diagnostics_fns=[diag_fn],
        update_env_info_fn=delete_info,
    )
    return env, goal_distribution, reward_fn
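
# Minimal call sketch for the factory above. It is normally invoked from inside
# the enclosing launcher, where desired_goal_key, observation_key,
# achieved_goal_key, and delete_info are in scope; the env id and the goal
# sampling modes below are placeholders, not values taken from this code.
expl_env, expl_goal_distrib, expl_reward_fn = contextual_env_distrib_and_reward(
    env_id='SomeMultitaskEnv-v0',  # placeholder id
    env_class=None,
    env_kwargs={},
    goal_sampling_mode='random',  # placeholder mode
)
eval_env, eval_goal_distrib, eval_reward_fn = contextual_env_distrib_and_reward(
    env_id='SomeMultitaskEnv-v0',  # placeholder id
    env_class=None,
    env_kwargs={},
    goal_sampling_mode='reset_of_env',  # placeholder mode
)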
def add_images(env, base_distribution):
    if use_image_observations:
        video_env = InsertImageEnv(
            env,
            renderer=video_renderer,
            image_key='video_observation',
        )
        image_goal_distribution = base_distribution
    else:
        video_env = InsertImageEnv(
            env,
            renderer=video_renderer,
            image_key='image_observation',
        )
        state_env = env.env
        image_goal_distribution = AddImageDistribution(
            env=state_env,
            base_distribution=base_distribution,
            image_goal_key='image_desired_goal',
            renderer=video_renderer,
        )
    return ContextualEnv(
        video_env,
        context_distribution=image_goal_distribution,
        reward_fn=reward_fn,
        observation_key=observation_key_for_rl,
        update_env_info_fn=delete_info,
    )
def setup_env(state_env, encoder, reward_fn):
    goal_distribution = GoalDictDistributionFromMultitaskEnv(
        state_env,
        desired_goal_keys=[state_desired_goal_key],
    )
    if use_image_observations:
        goal_distribution = AddImageDistribution(
            env=state_env,
            base_distribution=goal_distribution,
            image_goal_key=img_desired_goal_key,
            renderer=env_renderer,
        )
        base_env = InsertImageEnv(state_env, renderer=env_renderer)
        goal_distribution = PresampledDistribution(
            goal_distribution, num_presampled_goals)
        goal_distribution = EncodedGoalDictDistribution(
            goal_distribution,
            encoder=encoder,
            keys_to_keep=[state_desired_goal_key, img_desired_goal_key],
            encoder_input_key=img_desired_goal_key,
            encoder_output_key=latent_desired_goal_key,
        )
    else:
        base_env = state_env
        goal_distribution = EncodedGoalDictDistribution(
            goal_distribution,
            encoder=encoder,
            keys_to_keep=[state_desired_goal_key],
            encoder_input_key=state_desired_goal_key,
            encoder_output_key=latent_desired_goal_key,
        )
    goal_distribution = MaskedGoalDictDistribution(
        goal_distribution,
        mask_key=mask_key,
        mask_dim=latent_dim,
        distribution_type='one_hot_masks',
    )
    state_diag_fn = GoalConditionedDiagnosticsToContextualDiagnostics(
        state_env.goal_conditioned_diagnostics,
        desired_goal_key=state_desired_goal_key,
        observation_key=state_observation_key,
    )
    env = ContextualEnv(
        base_env,
        context_distribution=goal_distribution,
        reward_fn=reward_fn,
        contextual_diagnostics_fns=[state_diag_fn],
        update_env_info_fn=delete_info,
        **contextual_env_kwargs,
    )
    return env, goal_distribution
def add_images(env, context_distribution):
    state_env = env.env
    img_env = InsertImageEnv(
        state_env,
        renderer=renderer,
        image_key='image_observation',
    )
    return ContextualEnv(
        img_env,
        context_distribution=context_distribution,
        reward_fn=eval_reward,
        observation_key=observation_key,
        update_env_info_fn=delete_info,
    )
def setup_contextual_env(
        env_id, env_class, env_kwargs, goal_sampling_mode, renderer
):
    state_env = get_gym_env(env_id, env_class=env_class, env_kwargs=env_kwargs)
    state_env.goal_sampling_mode = goal_sampling_mode
    state_goal_distribution = GoalDictDistributionFromMultitaskEnv(
        state_env,
        desired_goal_keys=[state_desired_goal_key],
    )
    state_diag_fn = GoalConditionedDiagnosticsToContextualDiagnostics(
        state_env.goal_conditioned_diagnostics,
        desired_goal_key=state_desired_goal_key,
        observation_key=state_observation_key,
    )
    image_goal_distribution = AddImageDistribution(
        env=state_env,
        base_distribution=state_goal_distribution,
        image_goal_key=img_desired_goal_key,
        renderer=renderer,
    )
    goal_distribution = PresampledDistribution(image_goal_distribution, 5000)
    img_env = InsertImageEnv(state_env, renderer=renderer)
    if reward_type == 'state_distance':
        reward_fn = ContextualRewardFnFromMultitaskEnv(
            env=state_env,
            achieved_goal_from_observation=IndexIntoAchievedGoal(
                'state_observation'),
            desired_goal_key=state_desired_goal_key,
            achieved_goal_key=state_achieved_goal_key,
        )
    elif reward_type == 'pixel_distance':
        reward_fn = NegativeL2Distance(
            achieved_goal_from_observation=IndexIntoAchievedGoal(
                img_observation_key),
            desired_goal_key=img_desired_goal_key,
        )
    else:
        raise ValueError(reward_type)
    env = ContextualEnv(
        img_env,
        context_distribution=goal_distribution,
        reward_fn=reward_fn,
        observation_key=img_observation_key,
        contextual_diagnostics_fns=[state_diag_fn],
        update_env_info_fn=delete_info,
    )
    return env, goal_distribution, reward_fn
def add_images(env, state_distribution):
    state_env = env.env
    image_goal_distribution = AddImageDistribution(
        env=state_env,
        base_distribution=state_distribution,
        image_goal_key='image_desired_goal',
        renderer=renderer,
    )
    img_env = InsertImageEnv(state_env, renderer=renderer)
    return ContextualEnv(
        img_env,
        context_distribution=image_goal_distribution,
        reward_fn=eval_reward,
        observation_key=observation_key,
        update_env_info_fn=delete_info,
    )
def get_video_func(
        env,
        policy,
        tag,
):
    renderer = EnvRenderer(**renderer_kwargs)
    state_goal_distribution = GoalDictDistributionFromMultitaskEnv(
        env,
        desired_goal_keys=[desired_goal_key],
    )
    image_goal_distribution = AddImageDistribution(
        env=env,
        base_distribution=state_goal_distribution,
        image_goal_key="image_desired_goal",
        renderer=renderer,
    )
    img_env = InsertImageEnv(env, renderer=renderer)
    rollout_function = partial(
        rf.multitask_rollout,
        max_path_length=variant["max_path_length"],
        observation_key=observation_key,
        desired_goal_key=desired_goal_key,
        return_dict_obs=True,
    )
    reward_fn = ContextualRewardFnFromMultitaskEnv(
        env=env,
        achieved_goal_from_observation=IndexIntoAchievedGoal(
            observation_key),
        desired_goal_key=desired_goal_key,
        achieved_goal_key="state_achieved_goal",
    )
    contextual_env = ContextualEnv(
        img_env,
        context_distribution=image_goal_distribution,
        reward_fn=reward_fn,
        observation_key=observation_key,
    )
    video_func = get_save_video_function(
        rollout_function,
        contextual_env,
        policy,
        tag=tag,
        imsize=renderer.width,
        image_format="CWH",
        **save_video_kwargs,
    )
    return video_func
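
# Sketch of how a video function built above is typically hooked into training.
# 'eval_env', 'policy', and 'algorithm' are assumed to be defined by the
# enclosing experiment, with 'algorithm' exposing a post_train_funcs list as in
# the larger experiment further below.
eval_video_func = get_video_func(eval_env, MakeDeterministic(policy), tag='eval')
algorithm.post_train_funcs.append(eval_video_func)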
def contextual_env_distrib_and_reward(mode='expl'):
    assert mode in ['expl', 'eval']
    env = make(env_id, env_class, env_kwargs, normalize_env)
    env = GymToMultiEnv(env)
    # env = get_gym_env(env_id, env_class=env_class, env_kwargs=env_kwargs)
    no_goal_distribution = PriorDistribution(
        representation_size=0,
        key="no_goal",
    )
    contextual_reward_fn = None
    env = ContextualEnv(
        env,
        context_distribution=no_goal_distribution,
        reward_fn=contextual_reward_fn,
        observation_key=observation_key,
        # contextual_diagnostics_fns=[state_diag_fn],
        update_env_info_fn=None,
    )
    return env, no_goal_distribution, contextual_reward_fn
def add_images(env, state_distribution):
    state_env = env.env
    image_goal_distribution = AddImageDistribution(
        env=state_env,
        base_distribution=state_distribution,
        image_goal_key='image_desired_goal',
        renderer=renderer,
    )
    img_env = InsertImagesEnv(state_env, renderers={
        'image_observation': renderer,
    })
    context_env = ContextualEnv(
        img_env,
        context_distribution=image_goal_distribution,
        reward_fn=reward_fn,
        observation_key=observation_key,
        update_env_info_fn=None,
    )
    return context_env
def contextual_env_distrib_and_reward(
        env_id, env_class, env_kwargs, goal_sampling_mode
):
    env = get_gym_env(
        env_id,
        env_class=env_class,
        env_kwargs=env_kwargs,
        unwrap_timed_envs=True,
    )
    env.goal_sampling_mode = goal_sampling_mode
    goal_distribution = GoalDictDistributionFromGymGoalEnv(
        env,
        desired_goal_key=desired_goal_key,
    )
    distance_fn = L2Distance(
        achieved_goal_from_observation=IndexIntoAchievedGoal(
            achieved_goal_key,
        ),
        desired_goal_key=desired_goal_key,
    )
    if (isinstance(env, robotics.FetchReachEnv)
            or isinstance(env, robotics.FetchPushEnv)
            or isinstance(env, robotics.FetchPickAndPlaceEnv)
            or isinstance(env, robotics.FetchSlideEnv)):
        success_threshold = 0.05
    else:
        raise TypeError(
            "Unknown success threshold for env: {}".format(env))
    reward_fn = ThresholdDistanceReward(distance_fn, success_threshold)
    diag_fn = GenericGoalConditionedContextualDiagnostics(
        desired_goal_key=desired_goal_key,
        achieved_goal_key=achieved_goal_key,
        success_threshold=success_threshold,
    )
    env = ContextualEnv(
        env,
        context_distribution=goal_distribution,
        reward_fn=reward_fn,
        observation_key=observation_key,
        contextual_diagnostics_fns=[diag_fn],
        update_env_info_fn=delete_info,
    )
    return env, goal_distribution, reward_fn
def add_images(env, state_distribution):
    state_env = env.env
    if is_gym_env:
        goal_distribution = state_distribution
    else:
        goal_distribution = AddImageDistribution(
            env=state_env,
            base_distribution=state_distribution,
            image_goal_key='image_desired_goal',
            renderer=video_renderer,
        )
    context_env = ContextualEnv(
        state_env,
        context_distribution=goal_distribution,
        reward_fn=reward_fn,
        observation_key=observation_key,
        update_env_info_fn=delete_info,
    )
    return InsertDebugImagesEnv(
        context_env,
        renderers=renderers,
    )
def image_based_goal_conditioned_sac_experiment(
        max_path_length,
        qf_kwargs,
        sac_trainer_kwargs,
        replay_buffer_kwargs,
        policy_kwargs,
        algo_kwargs,
        cnn_kwargs,
        policy_type='tanh-normal',
        env_id=None,
        env_class=None,
        env_kwargs=None,
        exploration_policy_kwargs=None,
        evaluation_goal_sampling_mode=None,
        exploration_goal_sampling_mode=None,
        reward_type='state_distance',
        env_renderer_kwargs=None,
        # Data augmentations
        apply_random_crops=False,
        random_crop_pixel_shift=4,
        # Video parameters
        save_video=True,
        save_video_kwargs=None,
        video_renderer_kwargs=None,
):
    if exploration_policy_kwargs is None:
        exploration_policy_kwargs = {}
    if not save_video_kwargs:
        save_video_kwargs = {}
    if not env_renderer_kwargs:
        env_renderer_kwargs = {}
    if not video_renderer_kwargs:
        video_renderer_kwargs = {}

    img_observation_key = 'image_observation'
    img_desired_goal_key = 'image_desired_goal'
    state_observation_key = 'state_observation'
    state_desired_goal_key = 'state_desired_goal'
    state_achieved_goal_key = 'state_achieved_goal'
    sample_context_from_obs_dict_fn = RemapKeyFn({
        'image_desired_goal': 'image_observation',
        'state_desired_goal': 'state_observation',
    })

    def setup_contextual_env(
            env_id, env_class, env_kwargs, goal_sampling_mode, renderer
    ):
        state_env = get_gym_env(
            env_id, env_class=env_class, env_kwargs=env_kwargs)
        state_env.goal_sampling_mode = goal_sampling_mode
        state_goal_distribution = GoalDictDistributionFromMultitaskEnv(
            state_env,
            desired_goal_keys=[state_desired_goal_key],
        )
        state_diag_fn = GoalConditionedDiagnosticsToContextualDiagnostics(
            state_env.goal_conditioned_diagnostics,
            desired_goal_key=state_desired_goal_key,
            observation_key=state_observation_key,
        )
        image_goal_distribution = AddImageDistribution(
            env=state_env,
            base_distribution=state_goal_distribution,
            image_goal_key=img_desired_goal_key,
            renderer=renderer,
        )
        goal_distribution = PresampledDistribution(
            image_goal_distribution, 5000)
        img_env = InsertImageEnv(state_env, renderer=renderer)
        if reward_type == 'state_distance':
            reward_fn = ContextualRewardFnFromMultitaskEnv(
                env=state_env,
                achieved_goal_from_observation=IndexIntoAchievedGoal(
                    'state_observation'),
                desired_goal_key=state_desired_goal_key,
                achieved_goal_key=state_achieved_goal_key,
            )
        elif reward_type == 'pixel_distance':
            reward_fn = NegativeL2Distance(
                achieved_goal_from_observation=IndexIntoAchievedGoal(
                    img_observation_key),
                desired_goal_key=img_desired_goal_key,
            )
        else:
            raise ValueError(reward_type)
        env = ContextualEnv(
            img_env,
            context_distribution=goal_distribution,
            reward_fn=reward_fn,
            observation_key=img_observation_key,
            contextual_diagnostics_fns=[state_diag_fn],
            update_env_info_fn=delete_info,
        )
        return env, goal_distribution, reward_fn

    env_renderer = EnvRenderer(**env_renderer_kwargs)
    expl_env, expl_context_distrib, expl_reward = setup_contextual_env(
        env_id, env_class, env_kwargs, exploration_goal_sampling_mode,
        env_renderer)
    eval_env, eval_context_distrib, eval_reward = setup_contextual_env(
        env_id, env_class, env_kwargs, evaluation_goal_sampling_mode,
        env_renderer)
    action_dim = expl_env.action_space.low.size

    if env_renderer.output_image_format == 'WHC':
        img_width, img_height, img_num_channels = (
            expl_env.observation_space[img_observation_key].shape)
    elif env_renderer.output_image_format == 'CHW':
        img_num_channels, img_height, img_width = (
            expl_env.observation_space[img_observation_key].shape)
    else:
        raise ValueError(env_renderer.output_image_format)

    def create_qf():
        cnn = BasicCNN(
            input_width=img_width,
            input_height=img_height,
            input_channels=img_num_channels,
            **cnn_kwargs)
        joint_cnn = ApplyConvToStateAndGoalImage(cnn)
        return basic.MultiInputSequential(
            ApplyToObs(joint_cnn),
            basic.FlattenEachParallel(),
            ConcatMlp(
                input_size=joint_cnn.output_size + action_dim,
                output_size=1,
                **qf_kwargs))

    qf1 = create_qf()
    qf2 = create_qf()
    target_qf1 = create_qf()
    target_qf2 = create_qf()

    cnn = BasicCNN(
        input_width=img_width,
        input_height=img_height,
        input_channels=img_num_channels,
        **cnn_kwargs)
    joint_cnn = ApplyConvToStateAndGoalImage(cnn)
    policy_obs_dim = joint_cnn.output_size
    if policy_type == 'normal':
        obs_processor = nn.Sequential(
            joint_cnn,
            basic.Flatten(),
            MultiHeadedMlp(
                input_size=policy_obs_dim,
                output_sizes=[action_dim, action_dim],
                **policy_kwargs))
        policy = PolicyFromDistributionGenerator(Gaussian(obs_processor))
    elif policy_type == 'tanh-normal':
        obs_processor = nn.Sequential(
            joint_cnn,
            basic.Flatten(),
            MultiHeadedMlp(
                input_size=policy_obs_dim,
                output_sizes=[action_dim, action_dim],
                **policy_kwargs))
        policy = PolicyFromDistributionGenerator(TanhGaussian(obs_processor))
    elif policy_type == 'normal-tanh-mean':
        obs_processor = nn.Sequential(
            joint_cnn,
            basic.Flatten(),
            MultiHeadedMlp(
                input_size=policy_obs_dim,
                output_sizes=[action_dim, action_dim],
                output_activations=['tanh', 'identity'],
                **policy_kwargs))
        policy = PolicyFromDistributionGenerator(Gaussian(obs_processor))
    else:
        raise ValueError("Unknown policy type: {}".format(policy_type))

    if apply_random_crops:
        pad = BatchPad(
            env_renderer.output_image_format,
            random_crop_pixel_shift,
            random_crop_pixel_shift,
        )
        crop = JointRandomCrop(
            env_renderer.output_image_format,
            env_renderer.image_shape,
        )

        def concat_context_to_obs(batch, *args, **kwargs):
            obs = batch['observations']
            next_obs = batch['next_observations']
            context = batch[img_desired_goal_key]
            obs_padded = pad(obs)
            next_obs_padded = pad(next_obs)
            context_padded = pad(context)
            obs_aug, context_aug = crop(obs_padded, context_padded)
            next_obs_aug, next_context_aug = crop(
                next_obs_padded, context_padded)
            batch['observations'] = np.concatenate(
                [obs_aug, context_aug], axis=1)
            batch['next_observations'] = np.concatenate(
                [next_obs_aug, next_context_aug], axis=1)
            return batch
    else:
        def concat_context_to_obs(batch, *args, **kwargs):
            obs = batch['observations']
            next_obs = batch['next_observations']
            context = batch[img_desired_goal_key]
            batch['observations'] = np.concatenate([obs, context], axis=1)
            batch['next_observations'] = np.concatenate(
                [next_obs, context], axis=1)
            return batch

    replay_buffer = ContextualRelabelingReplayBuffer(
        env=eval_env,
        context_keys=[img_desired_goal_key, state_desired_goal_key],
        observation_key=img_observation_key,
        observation_keys=[img_observation_key, state_observation_key],
        context_distribution=eval_context_distrib,
        sample_context_from_obs_dict_fn=sample_context_from_obs_dict_fn,
        reward_fn=eval_reward,
        post_process_batch_fn=concat_context_to_obs,
        **replay_buffer_kwargs)
    trainer = SACTrainer(
        env=expl_env,
        policy=policy,
        qf1=qf1,
        qf2=qf2,
        target_qf1=target_qf1,
        target_qf2=target_qf2,
        **sac_trainer_kwargs)
    eval_path_collector = ContextualPathCollector(
        eval_env,
        MakeDeterministic(policy),
        observation_key=img_observation_key,
        context_keys_for_policy=[img_desired_goal_key],
    )
    exploration_policy = create_exploration_policy(
        expl_env, policy, **exploration_policy_kwargs)
    expl_path_collector = ContextualPathCollector(
        expl_env,
        exploration_policy,
        observation_key=img_observation_key,
        context_keys_for_policy=[img_desired_goal_key],
    )
    algorithm = TorchBatchRLAlgorithm(
        trainer=trainer,
        exploration_env=expl_env,
        evaluation_env=eval_env,
        exploration_data_collector=expl_path_collector,
        evaluation_data_collector=eval_path_collector,
        replay_buffer=replay_buffer,
        max_path_length=max_path_length,
        **algo_kwargs)
    algorithm.to(ptu.device)

    if save_video:
        rollout_function = partial(
            rf.contextual_rollout,
            max_path_length=max_path_length,
            observation_key=img_observation_key,
            context_keys_for_policy=[img_desired_goal_key],
        )
        video_renderer = EnvRenderer(**video_renderer_kwargs)
        video_eval_env = InsertImageEnv(
            eval_env, renderer=video_renderer, image_key='video_observation')
        video_expl_env = InsertImageEnv(
            expl_env, renderer=video_renderer, image_key='video_observation')
        video_eval_env = ContextualEnv(
            video_eval_env,
            context_distribution=eval_env.context_distribution,
            reward_fn=lambda *_: np.array([0]),
            observation_key=img_observation_key,
        )
        video_expl_env = ContextualEnv(
            video_expl_env,
            context_distribution=expl_env.context_distribution,
            reward_fn=lambda *_: np.array([0]),
            observation_key=img_observation_key,
        )
        eval_video_func = get_save_video_function(
            rollout_function,
            video_eval_env,
            MakeDeterministic(policy),
            tag="eval",
            imsize=video_renderer.image_shape[1],
            image_formats=[
                env_renderer.output_image_format,
                env_renderer.output_image_format,
                video_renderer.output_image_format,
            ],
            keys_to_show=[
                'image_desired_goal',
                'image_observation',
                'video_observation',
            ],
            **save_video_kwargs)
        expl_video_func = get_save_video_function(
            rollout_function,
            video_expl_env,
            exploration_policy,
            tag="xplor",
            imsize=video_renderer.image_shape[1],
            image_formats=[
                env_renderer.output_image_format,
                env_renderer.output_image_format,
                video_renderer.output_image_format,
            ],
            keys_to_show=[
                'image_desired_goal',
                'image_observation',
                'video_observation',
            ],
            **save_video_kwargs)
        algorithm.post_train_funcs.append(eval_video_func)
        algorithm.post_train_funcs.append(expl_video_func)

    algorithm.train()
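
# Hypothetical driver for the experiment entry point above. Every value here is
# a placeholder sketch, not a configuration taken from this code; the exact
# kwargs each component accepts (BasicCNN, SACTrainer,
# ContextualRelabelingReplayBuffer, the batch RL algorithm) should be checked
# against the surrounding codebase before use.
if __name__ == '__main__':
    image_based_goal_conditioned_sac_experiment(
        max_path_length=100,
        qf_kwargs=dict(hidden_sizes=[256, 256]),
        sac_trainer_kwargs=dict(discount=0.99, soft_target_tau=1e-3),
        replay_buffer_kwargs=dict(
            max_size=int(1e5),
            fraction_future_context=0.5,
            fraction_distribution_context=0.5,
        ),
        policy_kwargs=dict(hidden_sizes=[256, 256]),
        algo_kwargs=dict(
            batch_size=256,
            num_epochs=100,
            num_eval_steps_per_epoch=1000,
            num_expl_steps_per_train_loop=1000,
            num_trains_per_train_loop=1000,
            min_num_steps_before_training=1000,
        ),
        cnn_kwargs=dict(
            kernel_sizes=[3, 3],
            n_channels=[16, 16],
            strides=[1, 1],
            paddings=[0, 0],
        ),
        env_id='SomeMultitaskEnv-v0',  # placeholder id
        exploration_goal_sampling_mode='random',       # placeholder mode
        evaluation_goal_sampling_mode='reset_of_env',  # placeholder mode
    )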
def contextual_env_distrib_reward(_env_id, _env_class=None, _env_kwargs=None):
    base_env = get_gym_env(
        _env_id,
        env_class=env_class,
        env_kwargs=env_kwargs,
        unwrap_timed_envs=True,
    )
    if init_camera:
        base_env.initialize_camera(init_camera)
    if (isinstance(stub_env, AntFullPositionGoalEnv)
            and normalize_distances_for_full_state_ant):
        base_env = NormalizeAntFullPositionGoalEnv(base_env)
        normalize_env = base_env
    else:
        normalize_env = None
    env = NoisyAction(base_env, action_noise_scale)
    diag_fns = []
    if is_gym_env:
        goal_distribution = GoalDictDistributionFromGymGoalEnv(
            env,
            desired_goal_key=desired_goal_key,
        )
        diag_fns.append(
            GenericGoalConditionedContextualDiagnostics(
                desired_goal_key=desired_goal_key,
                achieved_goal_key=achieved_goal_key,
                success_threshold=success_threshold,
            ))
    else:
        goal_distribution = GoalDictDistributionFromMultitaskEnv(
            env,
            desired_goal_keys=[desired_goal_key],
        )
        diag_fns.append(
            GoalConditionedDiagnosticsToContextualDiagnostics(
                env.goal_conditioned_diagnostics,
                desired_goal_key=desired_goal_key,
                observation_key=observation_key,
            ))
    if isinstance(stub_env, AntFullPositionGoalEnv):
        diag_fns.append(
            AntFullPositionGoalEnvDiagnostics(
                desired_goal_key=desired_goal_key,
                achieved_goal_key=achieved_goal_key,
                success_threshold=success_threshold,
                normalize_env=normalize_env,
            ))
    # if isinstance(stub_env, HopperFullPositionGoalEnv):
    #     diag_fns.append(
    #         HopperFullPositionGoalEnvDiagnostics(
    #             desired_goal_key=desired_goal_key,
    #             achieved_goal_key=achieved_goal_key,
    #             success_threshold=success_threshold,
    #         )
    #     )
    achieved_from_ob = IndexIntoAchievedGoal(achieved_goal_key)
    if reward_type == 'sparse':
        distance_fn = L2Distance(
            achieved_goal_from_observation=achieved_from_ob,
            desired_goal_key=desired_goal_key,
        )
        reward_fn = ThresholdDistanceReward(distance_fn, success_threshold)
    elif reward_type == 'negative_distance':
        reward_fn = NegativeL2Distance(
            achieved_goal_from_observation=achieved_from_ob,
            desired_goal_key=desired_goal_key,
        )
    else:
        reward_fn = ProbabilisticGoalRewardFn(
            dynamics_model,
            state_key=observation_key,
            context_key=context_key,
            reward_type=reward_type,
            discount_factor=discount_factor,
        )
    goal_distribution = PresampledDistribution(
        goal_distribution, num_presampled_goals)
    final_env = ContextualEnv(
        env,
        context_distribution=goal_distribution,
        reward_fn=reward_fn,
        observation_key=observation_key,
        contextual_diagnostics_fns=diag_fns,
        update_env_info_fn=delete_info,
    )
    return final_env, goal_distribution, reward_fn
def contextual_env_distrib_and_reward(
        env_id, env_class, env_kwargs, goal_sampling_mode,
        presampled_goals_path
):
    state_env = get_gym_env(env_id, env_class=env_class, env_kwargs=env_kwargs)
    renderer = EnvRenderer(init_camera=init_camera, **renderer_kwargs)
    img_env = InsertImageEnv(state_env, renderer=renderer)
    # encoded_env = EncoderWrappedEnv(
    #     img_env,
    #     model,
    #     dict(image_observation="latent_observation", ),
    # )
    # if goal_sampling_mode == "vae_prior":
    #     latent_goal_distribution = PriorDistribution(
    #         model.representation_size,
    #         desired_goal_key,
    #     )
    #     diagnostics = StateImageGoalDiagnosticsFn({}, )
    # elif goal_sampling_mode == "presampled":
    #     diagnostics = state_env.get_contextual_diagnostics
    #     image_goal_distribution = PresampledPathDistribution(
    #         presampled_goals_path,
    #     )
    #     latent_goal_distribution = AddLatentDistribution(
    #         image_goal_distribution,
    #         image_goal_key,
    #         desired_goal_key,
    #         model,
    #     )
    # elif goal_sampling_mode == "reset_of_env":
    #     state_goal_env = get_gym_env(
    #         env_id, env_class=env_class, env_kwargs=env_kwargs)
    #     state_goal_distribution = GoalDictDistributionFromMultitaskEnv(
    #         state_goal_env,
    #         desired_goal_keys=[state_goal_key],
    #     )
    #     image_goal_distribution = AddImageDistribution(
    #         env=state_env,
    #         base_distribution=state_goal_distribution,
    #         image_goal_key=image_goal_key,
    #         renderer=renderer,
    #     )
    #     latent_goal_distribution = AddLatentDistribution(
    #         image_goal_distribution,
    #         image_goal_key,
    #         desired_goal_key,
    #         model,
    #     )
    #     no_goal_distribution = PriorDistribution(
    #         representation_size=0,
    #         key="no_goal",
    #     )
    #     diagnostics = state_goal_env.get_contextual_diagnostics
    # else:
    #     error
    diagnostics = StateImageGoalDiagnosticsFn({}, )
    no_goal_distribution = PriorDistribution(
        representation_size=0,
        key="no_goal",
    )
    reward_fn = GraspingRewardFn(
        # img_env,
        # state_env,
        # observation_key=observation_key,
        # desired_goal_key=desired_goal_key,
        # **reward_kwargs
    )
    env = ContextualEnv(
        img_env,
        # state_env,
        context_distribution=no_goal_distribution,
        reward_fn=reward_fn,
        observation_key=observation_key,
        contextual_diagnostics_fns=[diagnostics],
    )
    return env, no_goal_distribution, reward_fn
def contextual_env_distrib_and_reward(mode='expl'):
    assert mode in ['expl', 'eval']
    env = get_envs(variant)
    if mode == 'expl':
        goal_sampling_mode = variant.get('expl_goal_sampling_mode', None)
    elif mode == 'eval':
        goal_sampling_mode = variant.get('eval_goal_sampling_mode', None)
    if goal_sampling_mode not in [None, 'example_set']:
        env.goal_sampling_mode = goal_sampling_mode
    mask_ids_for_training = mask_variant.get('mask_ids_for_training', None)
    if mask_conditioned:
        context_distrib = MaskDictDistribution(
            env,
            desired_goal_keys=[desired_goal_key],
            mask_format=mask_format,
            masks=masks,
            max_subtasks_to_focus_on=mask_variant.get(
                'max_subtasks_to_focus_on', None),
            prev_subtask_weight=mask_variant.get('prev_subtask_weight', None),
            mask_distr=mask_variant.get('train_mask_distr', None),
            mask_ids=mask_ids_for_training,
        )
        reward_fn = ContextualMaskingRewardFn(
            achieved_goal_from_observation=IndexIntoAchievedGoal(
                achieved_goal_key),
            desired_goal_key=desired_goal_key,
            achieved_goal_key=achieved_goal_key,
            mask_keys=mask_keys,
            mask_format=mask_format,
            use_g_for_mean=mask_variant['use_g_for_mean'],
            use_squared_reward=mask_variant.get('use_squared_reward', False),
        )
    else:
        if goal_sampling_mode == 'example_set':
            example_dataset = gen_example_sets(
                get_envs(variant), variant['example_set_variant'])
            assert len(example_dataset['list_of_waypoints']) == 1
            from rlkit.envs.contextual.set_distributions import (
                GoalDictDistributionFromSet,
            )
            context_distrib = GoalDictDistributionFromSet(
                example_dataset['list_of_waypoints'][0],
                desired_goal_keys=[desired_goal_key],
            )
        else:
            context_distrib = GoalDictDistributionFromMultitaskEnv(
                env,
                desired_goal_keys=[desired_goal_key],
            )
        reward_fn = ContextualRewardFnFromMultitaskEnv(
            env=env,
            achieved_goal_from_observation=IndexIntoAchievedGoal(
                achieved_goal_key),
            desired_goal_key=desired_goal_key,
            achieved_goal_key=achieved_goal_key,
            additional_obs_keys=variant['contextual_replay_buffer_kwargs'].get(
                'observation_keys', None),
        )
    diag_fn = GoalConditionedDiagnosticsToContextualDiagnostics(
        env.goal_conditioned_diagnostics,
        desired_goal_key=desired_goal_key,
        observation_key=observation_key,
    )
    env = ContextualEnv(
        env,
        context_distribution=context_distrib,
        reward_fn=reward_fn,
        observation_key=observation_key,
        contextual_diagnostics_fns=[diag_fn],
        update_env_info_fn=(
            delete_info if not variant.get('keep_env_infos', False) else None),
    )
    return env, context_distrib, reward_fn
def contextual_env_distrib_and_reward(mode='expl'):
    assert mode in ['expl', 'eval']
    env = get_gym_env(env_id, env_class=env_class, env_kwargs=env_kwargs)
    if mode == 'expl':
        goal_sampling_mode = expl_goal_sampling_mode
    elif mode == 'eval':
        goal_sampling_mode = eval_goal_sampling_mode
    else:
        goal_sampling_mode = None
    if goal_sampling_mode is not None:
        env.goal_sampling_mode = goal_sampling_mode
    if mask_conditioned:
        context_distrib = MaskedGoalDictDistributionFromMultitaskEnv(
            env,
            desired_goal_keys=[desired_goal_key],
            mask_keys=mask_keys,
            mask_dims=mask_dims,
            mask_format=mask_format,
            max_subtasks_to_focus_on=max_subtasks_to_focus_on,
            prev_subtask_weight=prev_subtask_weight,
            masks=masks,
            idx_masks=idx_masks,
            matrix_masks=matrix_masks,
            mask_distr=train_mask_distr,
        )
        reward_fn = ContextualRewardFnFromMultitaskEnv(
            env=env,
            achieved_goal_from_observation=IndexIntoAchievedGoal(
                achieved_goal_key),  # observation_key
            desired_goal_key=desired_goal_key,
            achieved_goal_key=achieved_goal_key,
            additional_obs_keys=contextual_replay_buffer_kwargs.get(
                'observation_keys', None),
            additional_context_keys=mask_keys,
            reward_fn=partial(
                mask_reward_fn,
                mask_format=mask_format,
                use_g_for_mean=use_g_for_mean,
            ),
        )
    else:
        context_distrib = GoalDictDistributionFromMultitaskEnv(
            env,
            desired_goal_keys=[desired_goal_key],
        )
        reward_fn = ContextualRewardFnFromMultitaskEnv(
            env=env,
            achieved_goal_from_observation=IndexIntoAchievedGoal(
                achieved_goal_key),  # observation_key
            desired_goal_key=desired_goal_key,
            achieved_goal_key=achieved_goal_key,
            additional_obs_keys=contextual_replay_buffer_kwargs.get(
                'observation_keys', None),
        )
    diag_fn = GoalConditionedDiagnosticsToContextualDiagnostics(
        env.goal_conditioned_diagnostics,
        desired_goal_key=desired_goal_key,
        observation_key=observation_key,
    )
    env = ContextualEnv(
        env,
        context_distribution=context_distrib,
        reward_fn=reward_fn,
        observation_key=observation_key,
        contextual_diagnostics_fns=[diag_fn],
        update_env_info_fn=delete_info,
    )
    return env, context_distrib, reward_fn
def contextual_env_distrib_and_reward(
        env_id, env_class, env_kwargs, goal_sampling_mode
):
    state_env = get_gym_env(env_id, env_class=env_class, env_kwargs=env_kwargs)
    renderer = EnvRenderer(init_camera=init_camera, **renderer_kwargs)
    img_env = InsertImageEnv(state_env, renderer=renderer)
    encoded_env = EncoderWrappedEnv(
        img_env,
        model,
        dict(image_observation="latent_observation"),
    )
    if goal_sampling_mode == "vae_prior":
        latent_goal_distribution = PriorDistribution(
            model.representation_size,
            desired_goal_key,
        )
        diagnostics = StateImageGoalDiagnosticsFn({}, )
    elif goal_sampling_mode == "reset_of_env":
        state_goal_env = get_gym_env(
            env_id, env_class=env_class, env_kwargs=env_kwargs)
        state_goal_distribution = GoalDictDistributionFromMultitaskEnv(
            state_goal_env,
            desired_goal_keys=[state_goal_key],
        )
        image_goal_distribution = AddImageDistribution(
            env=state_env,
            base_distribution=state_goal_distribution,
            image_goal_key=image_goal_key,
            renderer=renderer,
        )
        latent_goal_distribution = AddLatentDistribution(
            image_goal_distribution,
            image_goal_key,
            desired_goal_key,
            model,
        )
        if hasattr(state_goal_env, 'goal_conditioned_diagnostics'):
            diagnostics = GoalConditionedDiagnosticsToContextualDiagnostics(
                state_goal_env.goal_conditioned_diagnostics,
                desired_goal_key=state_goal_key,
                observation_key=state_observation_key,
            )
        else:
            diagnostics = state_goal_env.get_contextual_diagnostics
    else:
        raise NotImplementedError(
            'unknown goal sampling method: %s' % goal_sampling_mode)
    reward_fn = DistanceRewardFn(
        observation_key=observation_key,
        desired_goal_key=desired_goal_key,
    )
    env = ContextualEnv(
        encoded_env,
        context_distribution=latent_goal_distribution,
        reward_fn=reward_fn,
        observation_key=observation_key,
        contextual_diagnostics_fns=[diagnostics],
    )
    return env, latent_goal_distribution, reward_fn
def contextual_env_distrib_and_reward(
        env_id,
        env_class,
        env_kwargs,
        goal_sampling_mode,
        env_mask_distribution_type,
        static_mask=None,
):
    env = get_gym_env(env_id, env_class=env_class, env_kwargs=env_kwargs)
    env.goal_sampling_mode = goal_sampling_mode
    use_static_mask = (
        not do_masking
        or env_mask_distribution_type == 'static_mask')
    # Default to an all-ones mask when static masks should be used but no
    # static mask was provided.
    if use_static_mask and static_mask is None:
        static_mask = np.ones(mask_dim)
    if not do_masking:
        assert env_mask_distribution_type == 'static_mask'
    goal_distribution = GoalDictDistributionFromMultitaskEnv(
        env,
        desired_goal_keys=[desired_goal_key],
    )
    goal_distribution = MaskedGoalDictDistribution(
        goal_distribution,
        mask_key=mask_key,
        mask_dim=mask_dim,
        distribution_type=env_mask_distribution_type,
        static_mask=static_mask,
    )
    reward_fn = ContextualRewardFnFromMultitaskEnv(
        env=env,
        desired_goal_key=desired_goal_key,
        achieved_goal_key=achieved_goal_key,
        achieved_goal_from_observation=IndexIntoAchievedGoal(observation_key),
    )
    if do_masking:
        if masking_reward_fn:
            reward_fn = partial(
                masking_reward_fn,
                context_key=context_key,
                mask_key=mask_key)
        else:
            reward_fn = partial(
                default_masked_reward_fn,
                context_key=context_key,
                mask_key=mask_key)
    state_diag_fn = GoalConditionedDiagnosticsToContextualDiagnostics(
        env.goal_conditioned_diagnostics,
        desired_goal_key=desired_goal_key,
        observation_key=observation_key,
    )
    env = ContextualEnv(
        env,
        context_distribution=goal_distribution,
        reward_fn=reward_fn,
        observation_key=observation_key,
        contextual_diagnostics_fns=[state_diag_fn],
    )
    return env, goal_distribution, reward_fn
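
# Minimal call sketch for the masked variant above; mask_dim, do_masking, and
# the other module-level names are assumed to be defined by the enclosing
# experiment, and the values below are placeholders, not values taken from
# this code.
env, goal_distrib, reward_fn = contextual_env_distrib_and_reward(
    env_id='SomeMultitaskEnv-v0',  # placeholder id
    env_class=None,
    env_kwargs={},
    goal_sampling_mode='random',  # placeholder mode
    env_mask_distribution_type='static_mask',
    static_mask=None,  # falls back to np.ones(mask_dim) when static masks apply
)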