def contextual_env_distrib_and_reward(
        env_id, env_class, env_kwargs, goal_sampling_mode
):
    env = get_gym_env(env_id, env_class=env_class, env_kwargs=env_kwargs)
    env.goal_sampling_mode = goal_sampling_mode
    goal_distribution = GoalDictDistributionFromMultitaskEnv(
        env,
        desired_goal_keys=[desired_goal_key],
    )
    reward_fn = ContextualRewardFnFromMultitaskEnv(
        env=env,
        achieved_goal_from_observation=IndexIntoAchievedGoal(observation_key),
        desired_goal_key=desired_goal_key,
        achieved_goal_key=achieved_goal_key,
    )
    diag_fn = GoalConditionedDiagnosticsToContextualDiagnostics(
        env.goal_conditioned_diagnostics,
        desired_goal_key=desired_goal_key,
        observation_key=observation_key,
    )
    env = ContextualEnv(
        env,
        context_distribution=goal_distribution,
        reward_fn=reward_fn,
        observation_key=observation_key,
        contextual_diagnostics_fns=[diag_fn],
        update_env_info_fn=delete_info,
    )
    return env, goal_distribution, reward_fn
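# Hedged usage sketch (not from the source script): the factory above closes
# over names bound in the surrounding launcher (observation_key,
# desired_goal_key, achieved_goal_key, delete_info, and the rlkit contextual
# imports). Assuming those are already defined, a call might look like the
# following; the env id is a hypothetical multiworld id, not the authors' choice.
#
# eval_env, eval_goal_distrib, eval_reward_fn = contextual_env_distrib_and_reward(
#     env_id='SawyerPushNIPSEasy-v0',   # hypothetical env id
#     env_class=None,
#     env_kwargs={},
#     goal_sampling_mode='reset_of_env',
# )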
def setup_env(state_env, encoder, reward_fn):
    goal_distribution = GoalDictDistributionFromMultitaskEnv(
        state_env,
        desired_goal_keys=[state_desired_goal_key],
    )
    if use_image_observations:
        goal_distribution = AddImageDistribution(
            env=state_env,
            base_distribution=goal_distribution,
            image_goal_key=img_desired_goal_key,
            renderer=env_renderer,
        )
        base_env = InsertImageEnv(state_env, renderer=env_renderer)
        goal_distribution = PresampledDistribution(
            goal_distribution, num_presampled_goals)
        goal_distribution = EncodedGoalDictDistribution(
            goal_distribution,
            encoder=encoder,
            keys_to_keep=[state_desired_goal_key, img_desired_goal_key],
            encoder_input_key=img_desired_goal_key,
            encoder_output_key=latent_desired_goal_key,
        )
    else:
        base_env = state_env
        goal_distribution = EncodedGoalDictDistribution(
            goal_distribution,
            encoder=encoder,
            keys_to_keep=[state_desired_goal_key],
            encoder_input_key=state_desired_goal_key,
            encoder_output_key=latent_desired_goal_key,
        )
    goal_distribution = MaskedGoalDictDistribution(
        goal_distribution,
        mask_key=mask_key,
        mask_dim=latent_dim,
        distribution_type='one_hot_masks',
    )
    state_diag_fn = GoalConditionedDiagnosticsToContextualDiagnostics(
        state_env.goal_conditioned_diagnostics,
        desired_goal_key=state_desired_goal_key,
        observation_key=state_observation_key,
    )
    env = ContextualEnv(
        base_env,
        context_distribution=goal_distribution,
        reward_fn=reward_fn,
        contextual_diagnostics_fns=[state_diag_fn],
        update_env_info_fn=delete_info,
        **contextual_env_kwargs,
    )
    return env, goal_distribution
def setup_contextual_env(env_id, env_class, env_kwargs, goal_sampling_mode,
                         renderer):
    state_env = get_gym_env(env_id, env_class=env_class, env_kwargs=env_kwargs)
    state_env.goal_sampling_mode = goal_sampling_mode
    state_goal_distribution = GoalDictDistributionFromMultitaskEnv(
        state_env,
        desired_goal_keys=[state_desired_goal_key],
    )
    state_diag_fn = GoalConditionedDiagnosticsToContextualDiagnostics(
        state_env.goal_conditioned_diagnostics,
        desired_goal_key=state_desired_goal_key,
        observation_key=state_observation_key,
    )
    image_goal_distribution = AddImageDistribution(
        env=state_env,
        base_distribution=state_goal_distribution,
        image_goal_key=img_desired_goal_key,
        renderer=renderer,
    )
    goal_distribution = PresampledDistribution(image_goal_distribution, 5000)
    img_env = InsertImageEnv(state_env, renderer=renderer)
    if reward_type == 'state_distance':
        reward_fn = ContextualRewardFnFromMultitaskEnv(
            env=state_env,
            achieved_goal_from_observation=IndexIntoAchievedGoal(
                'state_observation'),
            desired_goal_key=state_desired_goal_key,
            achieved_goal_key=state_achieved_goal_key,
        )
    elif reward_type == 'pixel_distance':
        reward_fn = NegativeL2Distance(
            achieved_goal_from_observation=IndexIntoAchievedGoal(
                img_observation_key),
            desired_goal_key=img_desired_goal_key,
        )
    else:
        raise ValueError(reward_type)
    env = ContextualEnv(
        img_env,
        context_distribution=goal_distribution,
        reward_fn=reward_fn,
        observation_key=img_observation_key,
        contextual_diagnostics_fns=[state_diag_fn],
        update_env_info_fn=delete_info,
    )
    return env, goal_distribution, reward_fn
def get_video_func(
        env,
        policy,
        tag,
):
    renderer = EnvRenderer(**renderer_kwargs)
    state_goal_distribution = GoalDictDistributionFromMultitaskEnv(
        env,
        desired_goal_keys=[desired_goal_key],
    )
    image_goal_distribution = AddImageDistribution(
        env=env,
        base_distribution=state_goal_distribution,
        image_goal_key="image_desired_goal",
        renderer=renderer,
    )
    img_env = InsertImageEnv(env, renderer=renderer)
    rollout_function = partial(
        rf.multitask_rollout,
        max_path_length=variant["max_path_length"],
        observation_key=observation_key,
        desired_goal_key=desired_goal_key,
        return_dict_obs=True,
    )
    reward_fn = ContextualRewardFnFromMultitaskEnv(
        env=env,
        achieved_goal_from_observation=IndexIntoAchievedGoal(
            observation_key),
        desired_goal_key=desired_goal_key,
        achieved_goal_key="state_achieved_goal",
    )
    contextual_env = ContextualEnv(
        img_env,
        context_distribution=image_goal_distribution,
        reward_fn=reward_fn,
        observation_key=observation_key,
    )
    video_func = get_save_video_function(
        rollout_function,
        contextual_env,
        policy,
        tag=tag,
        imsize=renderer.width,
        image_format="CWH",
        **save_video_kwargs,
    )
    return video_func
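# Hedged usage sketch: the same attachment pattern that masking_sac_experiment
# below uses for its videos. `algorithm`, `eval_env`, and `policy` are assumed
# to come from the surrounding launcher; `get_video_func` is the helper above.
#
# eval_video_func = get_video_func(eval_env, MakeDeterministic(policy), tag="eval")
# algorithm.post_train_funcs.append(eval_video_func)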
def contextual_env_distrib_and_reward(env_id, env_class, env_kwargs,
                                      goal_sampling_mode):
    state_env = get_gym_env(env_id, env_class=env_class, env_kwargs=env_kwargs)
    renderer = EnvRenderer(init_camera=init_camera, **renderer_kwargs)
    img_env = InsertImageEnv(state_env, renderer=renderer)
    encoded_env = EncoderWrappedEnv(
        img_env,
        model,
        dict(image_observation="latent_observation"),
    )
    if goal_sampling_mode == "vae_prior":
        latent_goal_distribution = PriorDistribution(
            model.representation_size,
            desired_goal_key,
        )
        diagnostics = StateImageGoalDiagnosticsFn({})
    elif goal_sampling_mode == "reset_of_env":
        state_goal_env = get_gym_env(env_id, env_class=env_class,
                                     env_kwargs=env_kwargs)
        state_goal_distribution = GoalDictDistributionFromMultitaskEnv(
            state_goal_env,
            desired_goal_keys=[state_goal_key],
        )
        image_goal_distribution = AddImageDistribution(
            env=state_env,
            base_distribution=state_goal_distribution,
            image_goal_key=image_goal_key,
            renderer=renderer,
        )
        latent_goal_distribution = AddLatentDistribution(
            image_goal_distribution,
            image_goal_key,
            desired_goal_key,
            model,
        )
        if hasattr(state_goal_env, 'goal_conditioned_diagnostics'):
            diagnostics = GoalConditionedDiagnosticsToContextualDiagnostics(
                state_goal_env.goal_conditioned_diagnostics,
                desired_goal_key=state_goal_key,
                observation_key=state_observation_key,
            )
        else:
            diagnostics = state_goal_env.get_contextual_diagnostics
    else:
        raise NotImplementedError(
            'unknown goal sampling method: %s' % goal_sampling_mode)
    reward_fn = DistanceRewardFn(
        observation_key=observation_key,
        desired_goal_key=desired_goal_key,
    )
    env = ContextualEnv(
        encoded_env,
        context_distribution=latent_goal_distribution,
        reward_fn=reward_fn,
        observation_key=observation_key,
        contextual_diagnostics_fns=[diagnostics],
    )
    return env, latent_goal_distribution, reward_fn
def contextual_env_distrib_reward(_env_id, _env_class=None, _env_kwargs=None):
    base_env = get_gym_env(
        _env_id,
        env_class=_env_class,
        env_kwargs=_env_kwargs,
        unwrap_timed_envs=True,
    )
    if init_camera:
        base_env.initialize_camera(init_camera)
    if (isinstance(stub_env, AntFullPositionGoalEnv)
            and normalize_distances_for_full_state_ant):
        base_env = NormalizeAntFullPositionGoalEnv(base_env)
        normalize_env = base_env
    else:
        normalize_env = None
    env = NoisyAction(base_env, action_noise_scale)
    diag_fns = []
    if is_gym_env:
        goal_distribution = GoalDictDistributionFromGymGoalEnv(
            env,
            desired_goal_key=desired_goal_key,
        )
        diag_fns.append(
            GenericGoalConditionedContextualDiagnostics(
                desired_goal_key=desired_goal_key,
                achieved_goal_key=achieved_goal_key,
                success_threshold=success_threshold,
            ))
    else:
        goal_distribution = GoalDictDistributionFromMultitaskEnv(
            env,
            desired_goal_keys=[desired_goal_key],
        )
        diag_fns.append(
            GoalConditionedDiagnosticsToContextualDiagnostics(
                env.goal_conditioned_diagnostics,
                desired_goal_key=desired_goal_key,
                observation_key=observation_key,
            ))
    if isinstance(stub_env, AntFullPositionGoalEnv):
        diag_fns.append(
            AntFullPositionGoalEnvDiagnostics(
                desired_goal_key=desired_goal_key,
                achieved_goal_key=achieved_goal_key,
                success_threshold=success_threshold,
                normalize_env=normalize_env,
            ))
    # if isinstance(stub_env, HopperFullPositionGoalEnv):
    #     diag_fns.append(
    #         HopperFullPositionGoalEnvDiagnostics(
    #             desired_goal_key=desired_goal_key,
    #             achieved_goal_key=achieved_goal_key,
    #             success_threshold=success_threshold,
    #         )
    #     )
    achieved_from_ob = IndexIntoAchievedGoal(achieved_goal_key)
    if reward_type == 'sparse':
        distance_fn = L2Distance(
            achieved_goal_from_observation=achieved_from_ob,
            desired_goal_key=desired_goal_key,
        )
        reward_fn = ThresholdDistanceReward(distance_fn, success_threshold)
    elif reward_type == 'negative_distance':
        reward_fn = NegativeL2Distance(
            achieved_goal_from_observation=achieved_from_ob,
            desired_goal_key=desired_goal_key,
        )
    else:
        reward_fn = ProbabilisticGoalRewardFn(
            dynamics_model,
            state_key=observation_key,
            context_key=context_key,
            reward_type=reward_type,
            discount_factor=discount_factor,
        )
    goal_distribution = PresampledDistribution(goal_distribution,
                                               num_presampled_goals)
    final_env = ContextualEnv(
        env,
        context_distribution=goal_distribution,
        reward_fn=reward_fn,
        observation_key=observation_key,
        contextual_diagnostics_fns=diag_fns,
        update_env_info_fn=delete_info,
    )
    return final_env, goal_distribution, reward_fn
def contextual_env_distrib_and_reward(mode='expl'):
    assert mode in ['expl', 'eval']
    env = get_envs(variant)
    if mode == 'expl':
        goal_sampling_mode = variant.get('expl_goal_sampling_mode', None)
    elif mode == 'eval':
        goal_sampling_mode = variant.get('eval_goal_sampling_mode', None)
    if goal_sampling_mode not in [None, 'example_set']:
        env.goal_sampling_mode = goal_sampling_mode
    mask_ids_for_training = mask_variant.get('mask_ids_for_training', None)
    if mask_conditioned:
        context_distrib = MaskDictDistribution(
            env,
            desired_goal_keys=[desired_goal_key],
            mask_format=mask_format,
            masks=masks,
            max_subtasks_to_focus_on=mask_variant.get(
                'max_subtasks_to_focus_on', None),
            prev_subtask_weight=mask_variant.get('prev_subtask_weight', None),
            mask_distr=mask_variant.get('train_mask_distr', None),
            mask_ids=mask_ids_for_training,
        )
        reward_fn = ContextualMaskingRewardFn(
            achieved_goal_from_observation=IndexIntoAchievedGoal(
                achieved_goal_key),
            desired_goal_key=desired_goal_key,
            achieved_goal_key=achieved_goal_key,
            mask_keys=mask_keys,
            mask_format=mask_format,
            use_g_for_mean=mask_variant['use_g_for_mean'],
            use_squared_reward=mask_variant.get('use_squared_reward', False),
        )
    else:
        if goal_sampling_mode == 'example_set':
            example_dataset = gen_example_sets(
                get_envs(variant), variant['example_set_variant'])
            assert len(example_dataset['list_of_waypoints']) == 1
            from rlkit.envs.contextual.set_distributions import (
                GoalDictDistributionFromSet,
            )
            context_distrib = GoalDictDistributionFromSet(
                example_dataset['list_of_waypoints'][0],
                desired_goal_keys=[desired_goal_key],
            )
        else:
            context_distrib = GoalDictDistributionFromMultitaskEnv(
                env,
                desired_goal_keys=[desired_goal_key],
            )
        reward_fn = ContextualRewardFnFromMultitaskEnv(
            env=env,
            achieved_goal_from_observation=IndexIntoAchievedGoal(
                achieved_goal_key),
            desired_goal_key=desired_goal_key,
            achieved_goal_key=achieved_goal_key,
            additional_obs_keys=variant['contextual_replay_buffer_kwargs'].get(
                'observation_keys', None),
        )
    diag_fn = GoalConditionedDiagnosticsToContextualDiagnostics(
        env.goal_conditioned_diagnostics,
        desired_goal_key=desired_goal_key,
        observation_key=observation_key,
    )
    env = ContextualEnv(
        env,
        context_distribution=context_distrib,
        reward_fn=reward_fn,
        observation_key=observation_key,
        contextual_diagnostics_fns=[diag_fn],
        update_env_info_fn=(
            delete_info if not variant.get('keep_env_infos', False) else None),
    )
    return env, context_distrib, reward_fn
def contextual_env_distrib_and_reward(mode='expl'):
    assert mode in ['expl', 'eval']
    env = get_gym_env(env_id, env_class=env_class, env_kwargs=env_kwargs)
    if mode == 'expl':
        goal_sampling_mode = expl_goal_sampling_mode
    elif mode == 'eval':
        goal_sampling_mode = eval_goal_sampling_mode
    else:
        goal_sampling_mode = None
    if goal_sampling_mode is not None:
        env.goal_sampling_mode = goal_sampling_mode
    if mask_conditioned:
        context_distrib = MaskedGoalDictDistributionFromMultitaskEnv(
            env,
            desired_goal_keys=[desired_goal_key],
            mask_keys=mask_keys,
            mask_dims=mask_dims,
            mask_format=mask_format,
            max_subtasks_to_focus_on=max_subtasks_to_focus_on,
            prev_subtask_weight=prev_subtask_weight,
            masks=masks,
            idx_masks=idx_masks,
            matrix_masks=matrix_masks,
            mask_distr=train_mask_distr,
        )
        reward_fn = ContextualRewardFnFromMultitaskEnv(
            env=env,
            achieved_goal_from_observation=IndexIntoAchievedGoal(
                achieved_goal_key),  # observation_key
            desired_goal_key=desired_goal_key,
            achieved_goal_key=achieved_goal_key,
            additional_obs_keys=contextual_replay_buffer_kwargs.get(
                'observation_keys', None),
            additional_context_keys=mask_keys,
            reward_fn=partial(
                mask_reward_fn,
                mask_format=mask_format,
                use_g_for_mean=use_g_for_mean,
            ),
        )
    else:
        context_distrib = GoalDictDistributionFromMultitaskEnv(
            env,
            desired_goal_keys=[desired_goal_key],
        )
        reward_fn = ContextualRewardFnFromMultitaskEnv(
            env=env,
            achieved_goal_from_observation=IndexIntoAchievedGoal(
                achieved_goal_key),  # observation_key
            desired_goal_key=desired_goal_key,
            achieved_goal_key=achieved_goal_key,
            additional_obs_keys=contextual_replay_buffer_kwargs.get(
                'observation_keys', None),
        )
    diag_fn = GoalConditionedDiagnosticsToContextualDiagnostics(
        env.goal_conditioned_diagnostics,
        desired_goal_key=desired_goal_key,
        observation_key=observation_key,
    )
    env = ContextualEnv(
        env,
        context_distribution=context_distrib,
        reward_fn=reward_fn,
        observation_key=observation_key,
        contextual_diagnostics_fns=[diag_fn],
        update_env_info_fn=delete_info,
    )
    return env, context_distrib, reward_fn
def contextual_env_distrib_and_reward(
        env_id,
        env_class,
        env_kwargs,
        goal_sampling_mode,
        env_mask_distribution_type,
        static_mask=None,
):
    env = get_gym_env(env_id, env_class=env_class, env_kwargs=env_kwargs)
    env.goal_sampling_mode = goal_sampling_mode
    use_static_mask = (not do_masking
                       or env_mask_distribution_type == 'static_mask')
    # Default to all ones mask if static mask isn't defined and should
    # be using static masks.
    if use_static_mask and static_mask is None:
        static_mask = np.ones(mask_dim)
    if not do_masking:
        assert env_mask_distribution_type == 'static_mask'
    goal_distribution = GoalDictDistributionFromMultitaskEnv(
        env,
        desired_goal_keys=[desired_goal_key],
    )
    goal_distribution = MaskedGoalDictDistribution(
        goal_distribution,
        mask_key=mask_key,
        mask_dim=mask_dim,
        distribution_type=env_mask_distribution_type,
        static_mask=static_mask,
    )
    reward_fn = ContextualRewardFnFromMultitaskEnv(
        env=env,
        desired_goal_key=desired_goal_key,
        achieved_goal_key=achieved_goal_key,
        achieved_goal_from_observation=IndexIntoAchievedGoal(
            observation_key),
    )
    if do_masking:
        if masking_reward_fn:
            reward_fn = partial(masking_reward_fn,
                                context_key=context_key,
                                mask_key=mask_key)
        else:
            reward_fn = partial(default_masked_reward_fn,
                                context_key=context_key,
                                mask_key=mask_key)
    state_diag_fn = GoalConditionedDiagnosticsToContextualDiagnostics(
        env.goal_conditioned_diagnostics,
        desired_goal_key=desired_goal_key,
        observation_key=observation_key,
    )
    env = ContextualEnv(
        env,
        context_distribution=goal_distribution,
        reward_fn=reward_fn,
        observation_key=observation_key,
        contextual_diagnostics_fns=[state_diag_fn],
    )
    return env, goal_distribution, reward_fn
def masking_sac_experiment(
        max_path_length,
        qf_kwargs,
        sac_trainer_kwargs,
        replay_buffer_kwargs,
        policy_kwargs,
        algo_kwargs,
        env_class=None,
        env_kwargs=None,
        observation_key='state_observation',
        desired_goal_key='state_desired_goal',
        achieved_goal_key='state_achieved_goal',
        exploration_policy_kwargs=None,
        evaluation_goal_sampling_mode=None,
        exploration_goal_sampling_mode=None,
        # Video parameters
        save_video=True,
        save_video_kwargs=None,
        renderer_kwargs=None,
        train_env_id=None,
        eval_env_id=None,
        do_masking=True,
        mask_key="masked_observation",
        masking_eval_steps=200,
        log_mask_diagnostics=True,
        mask_dim=None,
        masking_reward_fn=None,
        masking_for_exploration=True,
        rotate_masks_for_eval=False,
        rotate_masks_for_expl=False,
        mask_distribution=None,
        num_steps_per_mask_change=10,
        tag=None,
):
    if exploration_policy_kwargs is None:
        exploration_policy_kwargs = {}
    if not save_video_kwargs:
        save_video_kwargs = {}
    if not renderer_kwargs:
        renderer_kwargs = {}

    context_key = desired_goal_key

    env = get_gym_env(train_env_id, env_class=env_class, env_kwargs=env_kwargs)
    mask_dim = (mask_dim
                or env.observation_space.spaces[context_key].low.size)

    if not do_masking:
        mask_distribution = 'all_ones'
    assert mask_distribution in [
        'one_hot_masks', 'random_bit_masks', 'all_ones'
    ]
    if mask_distribution == 'all_ones':
        mask_distribution = 'static_mask'

    def contextual_env_distrib_and_reward(
            env_id,
            env_class,
            env_kwargs,
            goal_sampling_mode,
            env_mask_distribution_type,
            static_mask=None,
    ):
        env = get_gym_env(env_id, env_class=env_class, env_kwargs=env_kwargs)
        env.goal_sampling_mode = goal_sampling_mode
        use_static_mask = (not do_masking
                           or env_mask_distribution_type == 'static_mask')
        # Default to all ones mask if static mask isn't defined and should
        # be using static masks.
        if use_static_mask and static_mask is None:
            static_mask = np.ones(mask_dim)
        if not do_masking:
            assert env_mask_distribution_type == 'static_mask'
        goal_distribution = GoalDictDistributionFromMultitaskEnv(
            env,
            desired_goal_keys=[desired_goal_key],
        )
        goal_distribution = MaskedGoalDictDistribution(
            goal_distribution,
            mask_key=mask_key,
            mask_dim=mask_dim,
            distribution_type=env_mask_distribution_type,
            static_mask=static_mask,
        )
        reward_fn = ContextualRewardFnFromMultitaskEnv(
            env=env,
            desired_goal_key=desired_goal_key,
            achieved_goal_key=achieved_goal_key,
            achieved_goal_from_observation=IndexIntoAchievedGoal(
                observation_key),
        )
        if do_masking:
            if masking_reward_fn:
                reward_fn = partial(masking_reward_fn,
                                    context_key=context_key,
                                    mask_key=mask_key)
            else:
                reward_fn = partial(default_masked_reward_fn,
                                    context_key=context_key,
                                    mask_key=mask_key)
        state_diag_fn = GoalConditionedDiagnosticsToContextualDiagnostics(
            env.goal_conditioned_diagnostics,
            desired_goal_key=desired_goal_key,
            observation_key=observation_key,
        )
        env = ContextualEnv(
            env,
            context_distribution=goal_distribution,
            reward_fn=reward_fn,
            observation_key=observation_key,
            contextual_diagnostics_fns=[state_diag_fn],
        )
        return env, goal_distribution, reward_fn

    expl_env, expl_context_distrib, expl_reward = contextual_env_distrib_and_reward(
        train_env_id,
        env_class,
        env_kwargs,
        exploration_goal_sampling_mode,
        mask_distribution if masking_for_exploration else 'static_mask',
    )
    eval_env, eval_context_distrib, eval_reward = contextual_env_distrib_and_reward(
        eval_env_id, env_class, env_kwargs, evaluation_goal_sampling_mode,
        'static_mask')

    # Distribution for relabeling
    relabel_context_distrib = GoalDictDistributionFromMultitaskEnv(
        env,
        desired_goal_keys=[desired_goal_key],
    )
    relabel_context_distrib = MaskedGoalDictDistribution(
        relabel_context_distrib,
        mask_key=mask_key,
        mask_dim=mask_dim,
        distribution_type=mask_distribution,
        static_mask=None if do_masking else np.ones(mask_dim),
    )

    obs_dim = (expl_env.observation_space.spaces[observation_key].low.size
               + expl_env.observation_space.spaces[context_key].low.size
               + mask_dim)
    action_dim = expl_env.action_space.low.size

    def create_qf():
        return ConcatMlp(input_size=obs_dim + action_dim,
                         output_size=1,
                         **qf_kwargs)

    qf1 = create_qf()
    qf2 = create_qf()
    target_qf1 = create_qf()
    target_qf2 = create_qf()
    policy = TanhGaussianPolicy(obs_dim=obs_dim,
                                action_dim=action_dim,
                                **policy_kwargs)

    def context_from_obs_dict_fn(obs_dict):
        achieved_goal = obs_dict['state_achieved_goal']
        # Should the mask be randomized for future relabeling?
        # batch_size = len(achieved_goal)
        # mask = np.random.choice(2, size=(batch_size, mask_dim))
        mask = obs_dict[mask_key]
        return {
            mask_key: mask,
            context_key: achieved_goal,
        }

    def concat_context_to_obs(batch, *args, **kwargs):
        obs = batch['observations']
        next_obs = batch['next_observations']
        context = batch[context_key]
        mask = batch[mask_key]
        batch['observations'] = np.concatenate([obs, context, mask], axis=1)
        batch['next_observations'] = np.concatenate(
            [next_obs, context, mask], axis=1)
        return batch

    replay_buffer = ContextualRelabelingReplayBuffer(
        env=eval_env,
        context_keys=[context_key, mask_key],
        observation_keys=[observation_key, 'state_achieved_goal'],
        context_distribution=relabel_context_distrib,
        sample_context_from_obs_dict_fn=context_from_obs_dict_fn,
        reward_fn=eval_reward,
        post_process_batch_fn=concat_context_to_obs,
        **replay_buffer_kwargs)
    trainer = SACTrainer(env=expl_env,
                         policy=policy,
                         qf1=qf1,
                         qf2=qf2,
                         target_qf1=target_qf1,
                         target_qf2=target_qf2,
                         **sac_trainer_kwargs)

    def create_path_collector(env, policy, is_rotating):
        if is_rotating:
            assert do_masking and mask_distribution == 'one_hot_masks'
            return RotatingMaskingPathCollector(
                env,
                policy,
                observation_key=observation_key,
                context_keys_for_policy=[context_key, mask_key],
                mask_key=mask_key,
                mask_length=mask_dim,
                num_steps_per_mask_change=num_steps_per_mask_change,
            )
        else:
            return ContextualPathCollector(
                env,
                policy,
                observation_key=observation_key,
                context_keys_for_policy=[context_key, mask_key],
            )

    exploration_policy = create_exploration_policy(expl_env, policy,
                                                   **exploration_policy_kwargs)
    eval_path_collector = create_path_collector(eval_env,
                                                MakeDeterministic(policy),
                                                rotate_masks_for_eval)
    expl_path_collector = create_path_collector(expl_env, exploration_policy,
                                                rotate_masks_for_expl)

    algorithm = TorchBatchRLAlgorithm(
        trainer=trainer,
        exploration_env=expl_env,
        evaluation_env=eval_env,
        exploration_data_collector=expl_path_collector,
        evaluation_data_collector=eval_path_collector,
        replay_buffer=replay_buffer,
        max_path_length=max_path_length,
        **algo_kwargs)
    algorithm.to(ptu.device)

    if save_video:
        rollout_function = partial(
            rf.contextual_rollout,
            max_path_length=max_path_length,
            observation_key=observation_key,
            context_keys_for_policy=[context_key, mask_key],
            # Eval on everything for the base video
            obs_processor=lambda o: np.hstack(
                (o[observation_key], o[context_key], np.ones(mask_dim))))
        renderer = EnvRenderer(**renderer_kwargs)

        def add_images(env, state_distribution):
            state_env = env.env
            image_goal_distribution = AddImageDistribution(
                env=state_env,
                base_distribution=state_distribution,
                image_goal_key='image_desired_goal',
                renderer=renderer,
            )
            img_env = InsertImageEnv(state_env, renderer=renderer)
            return ContextualEnv(
                img_env,
                context_distribution=image_goal_distribution,
                reward_fn=eval_reward,
                observation_key=observation_key,
                # update_env_info_fn=DeleteOldEnvInfo(),
            )

        img_eval_env = add_images(eval_env, eval_context_distrib)
        img_expl_env = add_images(expl_env, expl_context_distrib)
        eval_video_func = get_save_video_function(
            rollout_function,
            img_eval_env,
            MakeDeterministic(policy),
            tag="eval",
            imsize=renderer.image_shape[0],
            image_format='CWH',
            **save_video_kwargs)
        expl_video_func = get_save_video_function(
            rollout_function,
            img_expl_env,
            exploration_policy,
            tag="train",
            imsize=renderer.image_shape[0],
            image_format='CWH',
            **save_video_kwargs)
        algorithm.post_train_funcs.append(eval_video_func)
        algorithm.post_train_funcs.append(expl_video_func)

    # For diagnostics, evaluate the policy on each individual dimension of the
    # mask.
    masks = []
    collectors = []
    for mask_idx in range(mask_dim):
        mask = np.zeros(mask_dim)
        mask[mask_idx] = 1
        masks.append(mask)
    for mask in masks:
        for_dim_mask = mask if do_masking else np.ones(mask_dim)
        masked_env, _, _ = contextual_env_distrib_and_reward(
            eval_env_id,
            env_class,
            env_kwargs,
            evaluation_goal_sampling_mode,
            'static_mask',
            static_mask=for_dim_mask)
        collector = ContextualPathCollector(
            masked_env,
            MakeDeterministic(policy),
            observation_key=observation_key,
            context_keys_for_policy=[context_key, mask_key],
        )
        collectors.append(collector)
    log_prefixes = [
        'mask_{}/'.format(''.join(mask.astype(int).astype(str)))
        for mask in masks
    ]

    def get_mask_diagnostics(unused):
        from rlkit.core.logging import append_log, add_prefix, OrderedDict
        from rlkit.misc import eval_util
        log = OrderedDict()
        for prefix, collector in zip(log_prefixes, collectors):
            paths = collector.collect_new_paths(
                max_path_length,
                masking_eval_steps,
                discard_incomplete_paths=True,
            )
            generic_info = add_prefix(
                eval_util.get_generic_path_information(paths),
                prefix,
            )
            append_log(log, generic_info)
        for collector in collectors:
            collector.end_epoch(0)
        return log

    if log_mask_diagnostics:
        algorithm._eval_get_diag_fns.append(get_mask_diagnostics)

    algorithm.train()
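# Hedged example invocation (not from the source): a sketch of the kind of
# variant a launcher might pass to masking_sac_experiment. The env id and all
# hyperparameter values below are illustrative assumptions, not the authors'
# settings; in particular the replay-buffer relabeling fractions are guesses
# at plausible keys and should be checked against the buffer's signature.
#
# if __name__ == '__main__':
#     masking_sac_experiment(
#         max_path_length=100,
#         qf_kwargs=dict(hidden_sizes=[400, 300]),
#         policy_kwargs=dict(hidden_sizes=[400, 300]),
#         sac_trainer_kwargs=dict(discount=0.99, reward_scale=1.0),
#         replay_buffer_kwargs=dict(
#             max_size=int(1e6),
#             fraction_future_context=0.3,          # assumed relabeling split
#             fraction_distribution_context=0.5,
#         ),
#         algo_kwargs=dict(
#             batch_size=256,
#             num_epochs=100,
#             num_eval_steps_per_epoch=1000,
#             num_expl_steps_per_train_loop=1000,
#             num_trains_per_train_loop=1000,
#             min_num_steps_before_training=1000,
#         ),
#         train_env_id='SawyerPushNIPSEasy-v0',     # hypothetical multiworld id
#         eval_env_id='SawyerPushNIPSEasy-v0',
#         do_masking=True,
#         mask_distribution='one_hot_masks',
#     )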