def contextual_env_distrib_and_reward(
        env_id, env_class, env_kwargs, goal_sampling_mode
):
    """Create a ``ContextualEnv`` plus its goal distribution and reward.

    ``desired_goal_key``, ``achieved_goal_key``, ``observation_key`` and
    ``delete_info`` are not parameters; they are read from outside this
    function.

    Args:
        env_id: Forwarded to ``get_gym_env`` as its positional argument.
        env_class: Forwarded to ``get_gym_env`` as ``env_class``.
        env_kwargs: Forwarded to ``get_gym_env`` as ``env_kwargs``.
        goal_sampling_mode: Assigned to the ``goal_sampling_mode``
            attribute of the freshly created environment.

    Returns:
        A tuple ``(env, goal_distribution, reward_fn)`` where ``env`` is
        the ``ContextualEnv`` wrapping the created environment.
    """
    base_env = get_gym_env(
        env_id, env_class=env_class, env_kwargs=env_kwargs,
    )
    base_env.goal_sampling_mode = goal_sampling_mode

    # Goals are drawn from the wrapped environment itself.
    goal_distribution = GoalDictDistributionFromMultitaskEnv(
        base_env,
        desired_goal_keys=[desired_goal_key],
    )
    achieved_goal_getter = IndexIntoAchievedGoal(observation_key)
    reward_fn = ContextualRewardFnFromMultitaskEnv(
        env=base_env,
        achieved_goal_from_observation=achieved_goal_getter,
        desired_goal_key=desired_goal_key,
        achieved_goal_key=achieved_goal_key,
    )
    diagnostics_fn = GoalConditionedDiagnosticsToContextualDiagnostics(
        base_env.goal_conditioned_diagnostics,
        desired_goal_key=desired_goal_key,
        observation_key=observation_key,
    )
    contextual_env = ContextualEnv(
        base_env,
        context_distribution=goal_distribution,
        reward_fn=reward_fn,
        observation_key=observation_key,
        contextual_diagnostics_fns=[diagnostics_fn],
        update_env_info_fn=delete_info,
    )
    return contextual_env, goal_distribution, reward_fn
def setup_env(state_env, encoder, reward_fn):
    """Wrap ``state_env`` in a ``ContextualEnv`` with encoded, masked goals.

    The goal distribution is built in layers: goals from the multitask env,
    optionally augmented with images and presampled (when
    ``use_image_observations`` is truthy), then passed through ``encoder``
    via ``EncodedGoalDictDistribution`` and finally wrapped in a
    ``MaskedGoalDictDistribution`` using ``'one_hot_masks'``.

    Configuration such as ``use_image_observations``, ``env_renderer``,
    ``num_presampled_goals``, ``latent_dim``, ``mask_key``, the various
    ``*_key`` names and ``contextual_env_kwargs`` is read from outside this
    function.

    Args:
        state_env: Environment to wrap; must expose
            ``goal_conditioned_diagnostics``.
        encoder: Passed to ``EncodedGoalDictDistribution`` as ``encoder``.
        reward_fn: Passed to ``ContextualEnv`` as ``reward_fn``.

    Returns:
        A tuple ``(env, goal_distribution)``.
    """
    distribution = GoalDictDistributionFromMultitaskEnv(
        state_env,
        desired_goal_keys=[state_desired_goal_key],
    )

    # Pick the env to wrap and the encoder input depending on whether
    # image observations are in use.
    if use_image_observations:
        distribution = AddImageDistribution(
            env=state_env,
            base_distribution=distribution,
            image_goal_key=img_desired_goal_key,
            renderer=env_renderer,
        )
        base_env = InsertImageEnv(state_env, renderer=env_renderer)
        distribution = PresampledDistribution(
            distribution, num_presampled_goals)
        kept_keys = [state_desired_goal_key, img_desired_goal_key]
        encoder_source_key = img_desired_goal_key
    else:
        base_env = state_env
        kept_keys = [state_desired_goal_key]
        encoder_source_key = state_desired_goal_key

    distribution = EncodedGoalDictDistribution(
        distribution,
        encoder=encoder,
        keys_to_keep=kept_keys,
        encoder_input_key=encoder_source_key,
        encoder_output_key=latent_desired_goal_key,
    )
    distribution = MaskedGoalDictDistribution(
        distribution,
        mask_key=mask_key,
        mask_dim=latent_dim,
        distribution_type='one_hot_masks',
    )

    # Diagnostics are always computed from the state env, even when the
    # wrapped env is the image-inserting one.
    diagnostics_fn = GoalConditionedDiagnosticsToContextualDiagnostics(
        state_env.goal_conditioned_diagnostics,
        desired_goal_key=state_desired_goal_key,
        observation_key=state_observation_key,
    )
    contextual_env = ContextualEnv(
        base_env,
        context_distribution=distribution,
        reward_fn=reward_fn,
        contextual_diagnostics_fns=[diagnostics_fn],
        update_env_info_fn=delete_info,
        **contextual_env_kwargs,
    )
    return contextual_env, distribution
def setup_contextual_env(env_id, env_class, env_kwargs, goal_sampling_mode,
                         renderer, num_presampled_goals=5000):
    """Build an image-observation ``ContextualEnv`` with presampled goals.

    A state environment is created, its goals are augmented with rendered
    images, and the resulting distribution is presampled. The reward is
    chosen by ``reward_type`` (read from outside this function, as are the
    ``*_key`` names and ``delete_info``).

    Args:
        env_id: Forwarded to ``get_gym_env`` as its positional argument.
        env_class: Forwarded to ``get_gym_env`` as ``env_class``.
        env_kwargs: Forwarded to ``get_gym_env`` as ``env_kwargs``.
        goal_sampling_mode: Assigned to ``state_env.goal_sampling_mode``.
        renderer: Used by both ``AddImageDistribution`` and
            ``InsertImageEnv``.
        num_presampled_goals: Second positional argument given to
            ``PresampledDistribution``. Defaults to 5000, the value that
            used to be hard-coded here.

    Returns:
        A tuple ``(env, goal_distribution, reward_fn)``.

    Raises:
        ValueError: If ``reward_type`` is neither ``'state_distance'`` nor
            ``'pixel_distance'``.
    """
    state_env = get_gym_env(env_id, env_class=env_class,
                            env_kwargs=env_kwargs)
    state_env.goal_sampling_mode = goal_sampling_mode

    state_goal_distribution = GoalDictDistributionFromMultitaskEnv(
        state_env,
        desired_goal_keys=[state_desired_goal_key],
    )
    state_diag_fn = GoalConditionedDiagnosticsToContextualDiagnostics(
        state_env.goal_conditioned_diagnostics,
        desired_goal_key=state_desired_goal_key,
        observation_key=state_observation_key,
    )
    image_goal_distribution = AddImageDistribution(
        env=state_env,
        base_distribution=state_goal_distribution,
        image_goal_key=img_desired_goal_key,
        renderer=renderer,
    )
    goal_distribution = PresampledDistribution(
        image_goal_distribution, num_presampled_goals)
    img_env = InsertImageEnv(state_env, renderer=renderer)

    if reward_type == 'state_distance':
        # NOTE(review): the observation key is the literal
        # 'state_observation' rather than state_observation_key; confirm
        # the two are meant to be the same.
        reward_fn = ContextualRewardFnFromMultitaskEnv(
            env=state_env,
            achieved_goal_from_observation=IndexIntoAchievedGoal(
                'state_observation'),
            desired_goal_key=state_desired_goal_key,
            achieved_goal_key=state_achieved_goal_key,
        )
    elif reward_type == 'pixel_distance':
        reward_fn = NegativeL2Distance(
            achieved_goal_from_observation=IndexIntoAchievedGoal(
                img_observation_key),
            desired_goal_key=img_desired_goal_key,
        )
    else:
        raise ValueError(reward_type)

    env = ContextualEnv(
        img_env,
        context_distribution=goal_distribution,
        reward_fn=reward_fn,
        observation_key=img_observation_key,
        contextual_diagnostics_fns=[state_diag_fn],
        update_env_info_fn=delete_info,
    )
    return env, goal_distribution, reward_fn
def contextual_env_distrib_and_reward(env_id, env_class, env_kwargs,
                                      goal_sampling_mode):
    """Build a latent-observation ``ContextualEnv`` and its goal distribution.

    The state environment is wrapped with ``InsertImageEnv`` and then with
    ``EncoderWrappedEnv`` (mapping ``image_observation`` to
    ``latent_observation`` using ``model``). ``model``, ``init_camera``,
    ``renderer_kwargs`` and the ``*_key`` names are read from outside this
    function.

    Args:
        env_id: Forwarded to ``get_gym_env`` as its positional argument.
        env_class: Forwarded to ``get_gym_env`` as ``env_class``.
        env_kwargs: Forwarded to ``get_gym_env`` as ``env_kwargs``.
        goal_sampling_mode: ``"vae_prior"`` to use a ``PriorDistribution``
            of size ``model.representation_size``, or ``"reset_of_env"`` to
            take goals from a second environment instance, render them and
            add latents via ``AddLatentDistribution``.

    Returns:
        A tuple ``(env, latent_goal_distribution, reward_fn)``.

    Raises:
        NotImplementedError: If ``goal_sampling_mode`` is neither of the
            two supported values.
    """
    state_env = get_gym_env(env_id, env_class=env_class,
                            env_kwargs=env_kwargs)
    renderer = EnvRenderer(init_camera=init_camera, **renderer_kwargs)
    img_env = InsertImageEnv(state_env, renderer=renderer)
    encoded_env = EncoderWrappedEnv(
        img_env,
        model,
        dict(image_observation="latent_observation", ),
    )

    if goal_sampling_mode == "vae_prior":
        latent_goal_distribution = PriorDistribution(
            model.representation_size,
            desired_goal_key,
        )
        diagnostics = StateImageGoalDiagnosticsFn({}, )
    elif goal_sampling_mode == "reset_of_env":
        # A separate env instance supplies the goals; images are still
        # rendered with ``state_env``.
        state_goal_env = get_gym_env(env_id, env_class=env_class,
                                     env_kwargs=env_kwargs)
        state_goal_distribution = GoalDictDistributionFromMultitaskEnv(
            state_goal_env,
            desired_goal_keys=[state_goal_key],
        )
        image_goal_distribution = AddImageDistribution(
            env=state_env,
            base_distribution=state_goal_distribution,
            image_goal_key=image_goal_key,
            renderer=renderer,
        )
        latent_goal_distribution = AddLatentDistribution(
            image_goal_distribution,
            image_goal_key,
            desired_goal_key,
            model,
        )
        if hasattr(state_goal_env, 'goal_conditioned_diagnostics'):
            diagnostics = GoalConditionedDiagnosticsToContextualDiagnostics(
                state_goal_env.goal_conditioned_diagnostics,
                desired_goal_key=state_goal_key,
                observation_key=state_observation_key,
            )
        else:
            # Fall back to the env's own contextual diagnostics callable.
            diagnostics = state_goal_env.get_contextual_diagnostics
    else:
        raise NotImplementedError('unknown goal sampling method: %s' %
                                  goal_sampling_mode)

    reward_fn = DistanceRewardFn(
        observation_key=observation_key,
        desired_goal_key=desired_goal_key,
    )
    env = ContextualEnv(
        encoded_env,
        context_distribution=latent_goal_distribution,
        reward_fn=reward_fn,
        observation_key=observation_key,
        contextual_diagnostics_fns=[diagnostics],
    )
    return env, latent_goal_distribution, reward_fn
def contextual_env_distrib_reward(_env_id, _env_class=None, _env_kwargs=None):
    """Build a noisy-action ``ContextualEnv`` with presampled goals.

    Most configuration (``env_class``, ``env_kwargs``, ``init_camera``,
    ``stub_env``, ``is_gym_env``, ``reward_type``, ``success_threshold``,
    ``action_noise_scale``, ``num_presampled_goals``, the ``*_key`` names,
    ...) is read from outside this function.

    Args:
        _env_id: Forwarded to ``get_gym_env`` as its positional argument.
        _env_class: Accepted but not read by this function.
        _env_kwargs: Accepted but not read by this function.

    Returns:
        A tuple ``(final_env, goal_distribution, reward_fn)`` where
        ``goal_distribution`` is the ``PresampledDistribution`` wrapper.
    """
    # NOTE(review): the outer ``env_class`` / ``env_kwargs`` are passed here,
    # not the ``_env_class`` / ``_env_kwargs`` parameters -- confirm this is
    # intentional.
    base_env = get_gym_env(
        _env_id,
        env_class=env_class,
        env_kwargs=env_kwargs,
        unwrap_timed_envs=True,
    )
    if init_camera:
        base_env.initialize_camera(init_camera)

    is_full_state_ant = isinstance(stub_env, AntFullPositionGoalEnv)
    normalize_env = None
    if is_full_state_ant and normalize_distances_for_full_state_ant:
        base_env = NormalizeAntFullPositionGoalEnv(base_env)
        normalize_env = base_env
    env = NoisyAction(base_env, action_noise_scale)

    # Goal distribution and the first diagnostics function depend on
    # whether the env follows the gym goal-env interface.
    diag_fns = []
    if is_gym_env:
        goal_distribution = GoalDictDistributionFromGymGoalEnv(
            env,
            desired_goal_key=desired_goal_key,
        )
        primary_diag_fn = GenericGoalConditionedContextualDiagnostics(
            desired_goal_key=desired_goal_key,
            achieved_goal_key=achieved_goal_key,
            success_threshold=success_threshold,
        )
    else:
        goal_distribution = GoalDictDistributionFromMultitaskEnv(
            env,
            desired_goal_keys=[desired_goal_key],
        )
        primary_diag_fn = GoalConditionedDiagnosticsToContextualDiagnostics(
            env.goal_conditioned_diagnostics,
            desired_goal_key=desired_goal_key,
            observation_key=observation_key,
        )
    diag_fns.append(primary_diag_fn)

    if is_full_state_ant:
        diag_fns.append(
            AntFullPositionGoalEnvDiagnostics(
                desired_goal_key=desired_goal_key,
                achieved_goal_key=achieved_goal_key,
                success_threshold=success_threshold,
                normalize_env=normalize_env,
            ))
    # if isinstance(stub_env, HopperFullPositionGoalEnv):
    #     diag_fns.append(
    #         HopperFullPositionGoalEnvDiagnostics(
    #             desired_goal_key=desired_goal_key,
    #             achieved_goal_key=achieved_goal_key,
    #             success_threshold=success_threshold,
    #         )
    #     )

    achieved_from_ob = IndexIntoAchievedGoal(achieved_goal_key, )
    if reward_type == 'sparse':
        distance_fn = L2Distance(
            achieved_goal_from_observation=achieved_from_ob,
            desired_goal_key=desired_goal_key,
        )
        reward_fn = ThresholdDistanceReward(distance_fn, success_threshold)
    elif reward_type == 'negative_distance':
        reward_fn = NegativeL2Distance(
            achieved_goal_from_observation=achieved_from_ob,
            desired_goal_key=desired_goal_key,
        )
    else:
        # Any other reward type is delegated to the probabilistic reward.
        reward_fn = ProbabilisticGoalRewardFn(
            dynamics_model,
            state_key=observation_key,
            context_key=context_key,
            reward_type=reward_type,
            discount_factor=discount_factor,
        )

    goal_distribution = PresampledDistribution(
        goal_distribution, num_presampled_goals)
    final_env = ContextualEnv(
        env,
        context_distribution=goal_distribution,
        reward_fn=reward_fn,
        observation_key=observation_key,
        contextual_diagnostics_fns=diag_fns,
        update_env_info_fn=delete_info,
    )
    return final_env, goal_distribution, reward_fn
def contextual_env_distrib_and_reward(mode='expl'):
    """Build the exploration or evaluation ``ContextualEnv`` from ``variant``.

    ``variant``, ``mask_variant``, ``mask_conditioned``, ``mask_format``,
    ``masks``, ``mask_keys`` and the ``*_key`` names are read from outside
    this function.

    Args:
        mode: ``'expl'`` or ``'eval'``; selects which
            ``*_goal_sampling_mode`` entry of ``variant`` is used.

    Returns:
        A tuple ``(env, context_distrib, reward_fn)``.
    """
    assert mode in ['expl', 'eval']
    base_env = get_envs(variant)

    if mode == 'expl':
        goal_sampling_mode = variant.get('expl_goal_sampling_mode', None)
    elif mode == 'eval':
        goal_sampling_mode = variant.get('eval_goal_sampling_mode', None)
    # 'example_set' is handled below and is not forwarded to the env.
    if goal_sampling_mode not in [None, 'example_set']:
        base_env.goal_sampling_mode = goal_sampling_mode

    mask_ids_for_training = mask_variant.get('mask_ids_for_training', None)

    if mask_conditioned:
        context_distrib = MaskDictDistribution(
            base_env,
            desired_goal_keys=[desired_goal_key],
            mask_format=mask_format,
            masks=masks,
            max_subtasks_to_focus_on=mask_variant.get(
                'max_subtasks_to_focus_on', None),
            prev_subtask_weight=mask_variant.get('prev_subtask_weight', None),
            mask_distr=mask_variant.get('train_mask_distr', None),
            mask_ids=mask_ids_for_training,
        )
        reward_fn = ContextualMaskingRewardFn(
            achieved_goal_from_observation=IndexIntoAchievedGoal(
                achieved_goal_key),
            desired_goal_key=desired_goal_key,
            achieved_goal_key=achieved_goal_key,
            mask_keys=mask_keys,
            mask_format=mask_format,
            use_g_for_mean=mask_variant['use_g_for_mean'],
            use_squared_reward=mask_variant.get('use_squared_reward', False),
        )
    else:
        if goal_sampling_mode == 'example_set':
            # Goals come from a generated example set built on a fresh env.
            example_dataset = gen_example_sets(
                get_envs(variant), variant['example_set_variant'])
            assert len(example_dataset['list_of_waypoints']) == 1
            from rlkit.envs.contextual.set_distributions import GoalDictDistributionFromSet
            context_distrib = GoalDictDistributionFromSet(
                example_dataset['list_of_waypoints'][0],
                desired_goal_keys=[desired_goal_key],
            )
        else:
            context_distrib = GoalDictDistributionFromMultitaskEnv(
                base_env,
                desired_goal_keys=[desired_goal_key],
            )
        replay_buffer_kwargs = variant['contextual_replay_buffer_kwargs']
        reward_fn = ContextualRewardFnFromMultitaskEnv(
            env=base_env,
            achieved_goal_from_observation=IndexIntoAchievedGoal(
                achieved_goal_key),
            desired_goal_key=desired_goal_key,
            achieved_goal_key=achieved_goal_key,
            additional_obs_keys=replay_buffer_kwargs.get(
                'observation_keys', None),
        )

    diagnostics_fn = GoalConditionedDiagnosticsToContextualDiagnostics(
        base_env.goal_conditioned_diagnostics,
        desired_goal_key=desired_goal_key,
        observation_key=observation_key,
    )

    # Env infos are stripped unless the variant asks to keep them.
    if variant.get('keep_env_infos', False):
        info_updater = None
    else:
        info_updater = delete_info
    contextual_env = ContextualEnv(
        base_env,
        context_distribution=context_distrib,
        reward_fn=reward_fn,
        observation_key=observation_key,
        contextual_diagnostics_fns=[diagnostics_fn],
        update_env_info_fn=info_updater,
    )
    return contextual_env, context_distrib, reward_fn
def contextual_env_distrib_and_reward(mode='expl'):
    """Build the exploration or evaluation ``ContextualEnv``.

    ``env_id``, ``env_class``, ``env_kwargs``, the two
    ``*_goal_sampling_mode`` values, ``mask_conditioned`` and all mask and
    key settings are read from outside this function.

    Args:
        mode: ``'expl'`` or ``'eval'``; selects ``expl_goal_sampling_mode``
            or ``eval_goal_sampling_mode`` respectively.

    Returns:
        A tuple ``(env, context_distrib, reward_fn)``.
    """
    assert mode in ['expl', 'eval']
    base_env = get_gym_env(
        env_id, env_class=env_class, env_kwargs=env_kwargs,
    )

    if mode == 'expl':
        goal_sampling_mode = expl_goal_sampling_mode
    elif mode == 'eval':
        goal_sampling_mode = eval_goal_sampling_mode
    else:
        goal_sampling_mode = None
    if goal_sampling_mode is not None:
        base_env.goal_sampling_mode = goal_sampling_mode

    # Both branches use the same reward class; the masked branch only adds
    # the mask context keys and a custom reward callable.
    if mask_conditioned:
        context_distrib = MaskedGoalDictDistributionFromMultitaskEnv(
            base_env,
            desired_goal_keys=[desired_goal_key],
            mask_keys=mask_keys,
            mask_dims=mask_dims,
            mask_format=mask_format,
            max_subtasks_to_focus_on=max_subtasks_to_focus_on,
            prev_subtask_weight=prev_subtask_weight,
            masks=masks,
            idx_masks=idx_masks,
            matrix_masks=matrix_masks,
            mask_distr=train_mask_distr,
        )
        mask_reward_kwargs = dict(
            additional_context_keys=mask_keys,
            reward_fn=partial(
                mask_reward_fn,
                mask_format=mask_format,
                use_g_for_mean=use_g_for_mean
            ),
        )
    else:
        context_distrib = GoalDictDistributionFromMultitaskEnv(
            base_env,
            desired_goal_keys=[desired_goal_key],
        )
        mask_reward_kwargs = {}

    reward_fn = ContextualRewardFnFromMultitaskEnv(
        env=base_env,
        achieved_goal_from_observation=IndexIntoAchievedGoal(
            achieved_goal_key),  # observation_key
        desired_goal_key=desired_goal_key,
        achieved_goal_key=achieved_goal_key,
        additional_obs_keys=contextual_replay_buffer_kwargs.get(
            'observation_keys', None),
        **mask_reward_kwargs,
    )

    diagnostics_fn = GoalConditionedDiagnosticsToContextualDiagnostics(
        base_env.goal_conditioned_diagnostics,
        desired_goal_key=desired_goal_key,
        observation_key=observation_key,
    )
    contextual_env = ContextualEnv(
        base_env,
        context_distribution=context_distrib,
        reward_fn=reward_fn,
        observation_key=observation_key,
        contextual_diagnostics_fns=[diagnostics_fn],
        update_env_info_fn=delete_info,
    )
    return contextual_env, context_distrib, reward_fn
def contextual_env_distrib_and_reward(
        env_id,
        env_class,
        env_kwargs,
        goal_sampling_mode,
        env_mask_distribution_type,
        static_mask=None,
):
    """Build a ``ContextualEnv`` whose goal distribution carries a mask.

    ``do_masking``, ``mask_dim``, ``mask_key``, ``context_key``,
    ``masking_reward_fn``, ``default_masked_reward_fn`` and the ``*_key``
    names are read from outside this function.

    Args:
        env_id: Forwarded to ``get_gym_env`` as its positional argument.
        env_class: Forwarded to ``get_gym_env`` as ``env_class``.
        env_kwargs: Forwarded to ``get_gym_env`` as ``env_kwargs``.
        goal_sampling_mode: Assigned to ``env.goal_sampling_mode``.
        env_mask_distribution_type: Passed to ``MaskedGoalDictDistribution``
            as ``distribution_type``. Asserted to be ``'static_mask'`` when
            ``do_masking`` is falsy.
        static_mask: Mask passed to ``MaskedGoalDictDistribution``. When it
            is ``None`` and a static mask is in use, ``np.ones(mask_dim)``
            is substituted.

    Returns:
        A tuple ``(env, goal_distribution, reward_fn)``.
    """
    base_env = get_gym_env(
        env_id, env_class=env_class, env_kwargs=env_kwargs,
    )
    base_env.goal_sampling_mode = goal_sampling_mode

    use_static_mask = (
        not do_masking or env_mask_distribution_type == 'static_mask'
    )
    # Default to all ones mask if static mask isn't defined and should
    # be using static masks.
    if static_mask is None and use_static_mask:
        static_mask = np.ones(mask_dim)
    if not do_masking:
        assert env_mask_distribution_type == 'static_mask'

    unmasked_distribution = GoalDictDistributionFromMultitaskEnv(
        base_env,
        desired_goal_keys=[desired_goal_key],
    )
    goal_distribution = MaskedGoalDictDistribution(
        unmasked_distribution,
        mask_key=mask_key,
        mask_dim=mask_dim,
        distribution_type=env_mask_distribution_type,
        static_mask=static_mask,
    )

    reward_fn = ContextualRewardFnFromMultitaskEnv(
        env=base_env,
        desired_goal_key=desired_goal_key,
        achieved_goal_key=achieved_goal_key,
        achieved_goal_from_observation=IndexIntoAchievedGoal(
            observation_key),
    )
    if do_masking:
        # With masking on, the reward above is replaced by a masked reward:
        # the configured one if truthy, otherwise the default.
        reward_fn = partial(
            masking_reward_fn or default_masked_reward_fn,
            context_key=context_key,
            mask_key=mask_key,
        )

    diagnostics_fn = GoalConditionedDiagnosticsToContextualDiagnostics(
        base_env.goal_conditioned_diagnostics,
        desired_goal_key=desired_goal_key,
        observation_key=observation_key,
    )
    contextual_env = ContextualEnv(
        base_env,
        context_distribution=goal_distribution,
        reward_fn=reward_fn,
        observation_key=observation_key,
        contextual_diagnostics_fns=[diagnostics_fn],
    )
    return contextual_env, goal_distribution, reward_fn