Example #1
def contextual_env_distrib_and_reward(
        env_id, env_class, env_kwargs, goal_sampling_mode
):
    env = get_gym_env(env_id, env_class=env_class, env_kwargs=env_kwargs)
    env.goal_sampling_mode = goal_sampling_mode
    goal_distribution = GoalDictDistributionFromMultitaskEnv(
        env,
        desired_goal_keys=[desired_goal_key],
    )
    reward_fn = ContextualRewardFnFromMultitaskEnv(
        env=env,
        achieved_goal_from_observation=IndexIntoAchievedGoal(observation_key),
        desired_goal_key=desired_goal_key,
        achieved_goal_key=achieved_goal_key,
    )
    diag_fn = GoalConditionedDiagnosticsToContextualDiagnostics(
        env.goal_conditioned_diagnostics,
        desired_goal_key=desired_goal_key,
        observation_key=observation_key,
    )
    env = ContextualEnv(
        env,
        context_distribution=goal_distribution,
        reward_fn=reward_fn,
        observation_key=observation_key,
        contextual_diagnostics_fns=[diag_fn],
        update_env_info_fn=delete_info,
    )
    return env, goal_distribution, reward_fn
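
# Usage sketch (not in the original; values are hypothetical). The function
# above reads observation_key / desired_goal_key / achieved_goal_key and its
# imports from the enclosing experiment script, so this call assumes those
# names are already bound.
expl_env, expl_goal_dist, expl_reward_fn = contextual_env_distrib_and_reward(
    env_id='MyMultitaskEnv-v0',  # hypothetical registered env id
    env_class=None,
    env_kwargs={},
    goal_sampling_mode='reset_of_env',  # mode name taken from Example #5
)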
Example #2
    def setup_env(state_env, encoder, reward_fn):
        goal_distribution = GoalDictDistributionFromMultitaskEnv(
            state_env,
            desired_goal_keys=[state_desired_goal_key],
        )
        if use_image_observations:
            goal_distribution = AddImageDistribution(
                env=state_env,
                base_distribution=goal_distribution,
                image_goal_key=img_desired_goal_key,
                renderer=env_renderer,
            )
            base_env = InsertImageEnv(state_env, renderer=env_renderer)
            goal_distribution = PresampledDistribution(
                goal_distribution, num_presampled_goals)
            goal_distribution = EncodedGoalDictDistribution(
                goal_distribution,
                encoder=encoder,
                keys_to_keep=[state_desired_goal_key, img_desired_goal_key],
                encoder_input_key=img_desired_goal_key,
                encoder_output_key=latent_desired_goal_key,
            )
        else:
            base_env = state_env
            goal_distribution = EncodedGoalDictDistribution(
                goal_distribution,
                encoder=encoder,
                keys_to_keep=[state_desired_goal_key],
                encoder_input_key=state_desired_goal_key,
                encoder_output_key=latent_desired_goal_key,
            )
        goal_distribution = MaskedGoalDictDistribution(
            goal_distribution,
            mask_key=mask_key,
            mask_dim=latent_dim,
            distribution_type='one_hot_masks',
        )

        state_diag_fn = GoalConditionedDiagnosticsToContextualDiagnostics(
            state_env.goal_conditioned_diagnostics,
            desired_goal_key=state_desired_goal_key,
            observation_key=state_observation_key,
        )
        env = ContextualEnv(
            base_env,
            context_distribution=goal_distribution,
            reward_fn=reward_fn,
            contextual_diagnostics_fns=[state_diag_fn],
            update_env_info_fn=delete_info,
            **contextual_env_kwargs,
        )
        return env, goal_distribution
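
    # Usage sketch (not in the original; names are hypothetical). setup_env
    # depends on closure variables of the enclosing experiment function
    # (use_image_observations, env_renderer, num_presampled_goals, mask_key,
    # latent_dim, contextual_env_kwargs, and the *_goal_key names).
    expl_env, expl_goal_distribution = setup_env(
        raw_state_env,   # hypothetical multitask state env
        vae_encoder,     # hypothetical encoder used for latent goals
        expl_reward_fn,  # hypothetical reward fn forwarded to ContextualEnv
    )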
Example #3
def setup_contextual_env(env_id, env_class, env_kwargs, goal_sampling_mode,
                         renderer):
    state_env = get_gym_env(env_id,
                            env_class=env_class,
                            env_kwargs=env_kwargs)
    state_env.goal_sampling_mode = goal_sampling_mode
    state_goal_distribution = GoalDictDistributionFromMultitaskEnv(
        state_env,
        desired_goal_keys=[state_desired_goal_key],
    )
    state_diag_fn = GoalConditionedDiagnosticsToContextualDiagnostics(
        state_env.goal_conditioned_diagnostics,
        desired_goal_key=state_desired_goal_key,
        observation_key=state_observation_key,
    )
    image_goal_distribution = AddImageDistribution(
        env=state_env,
        base_distribution=state_goal_distribution,
        image_goal_key=img_desired_goal_key,
        renderer=renderer,
    )
    goal_distribution = PresampledDistribution(image_goal_distribution,
                                               5000)
    img_env = InsertImageEnv(state_env, renderer=renderer)
    if reward_type == 'state_distance':
        reward_fn = ContextualRewardFnFromMultitaskEnv(
            env=state_env,
            achieved_goal_from_observation=IndexIntoAchievedGoal(
                'state_observation'),
            desired_goal_key=state_desired_goal_key,
            achieved_goal_key=state_achieved_goal_key,
        )
    elif reward_type == 'pixel_distance':
        reward_fn = NegativeL2Distance(
            achieved_goal_from_observation=IndexIntoAchievedGoal(
                img_observation_key),
            desired_goal_key=img_desired_goal_key,
        )
    else:
        raise ValueError(reward_type)
    env = ContextualEnv(
        img_env,
        context_distribution=goal_distribution,
        reward_fn=reward_fn,
        observation_key=img_observation_key,
        contextual_diagnostics_fns=[state_diag_fn],
        update_env_info_fn=delete_info,
    )
    return env, goal_distribution, reward_fn
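
# Usage sketch (not in the original; values are hypothetical). reward_type
# and the *_key names are closure variables of the enclosing script; the
# renderer follows the EnvRenderer(**renderer_kwargs) pattern used elsewhere
# in these examples.
env, goal_dist, reward_fn = setup_contextual_env(
    env_id='MyMultitaskEnv-v0',  # hypothetical
    env_class=None,
    env_kwargs={},
    goal_sampling_mode='reset_of_env',
    renderer=EnvRenderer(**renderer_kwargs),
)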
Example #4
def get_video_func(
    env,
    policy,
    tag,
):
    renderer = EnvRenderer(**renderer_kwargs)
    state_goal_distribution = GoalDictDistributionFromMultitaskEnv(
        env,
        desired_goal_keys=[desired_goal_key],
    )
    image_goal_distribution = AddImageDistribution(
        env=env,
        base_distribution=state_goal_distribution,
        image_goal_key="image_desired_goal",
        renderer=renderer,
    )
    img_env = InsertImageEnv(env, renderer=renderer)
    rollout_function = partial(
        rf.multitask_rollout,
        max_path_length=variant["max_path_length"],
        observation_key=observation_key,
        desired_goal_key=desired_goal_key,
        return_dict_obs=True,
    )
    reward_fn = ContextualRewardFnFromMultitaskEnv(
        env=env,
        achieved_goal_from_observation=IndexIntoAchievedGoal(
            observation_key),
        desired_goal_key=desired_goal_key,
        achieved_goal_key="state_achieved_goal",
    )
    contextual_env = ContextualEnv(
        img_env,
        context_distribution=image_goal_distribution,
        reward_fn=reward_fn,
        observation_key=observation_key,
    )
    video_func = get_save_video_function(
        rollout_function,
        contextual_env,
        policy,
        tag=tag,
        imsize=renderer.width,
        image_format="CWH",
        **save_video_kwargs,
    )
    return video_func
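
# Usage sketch (not in the original; mirrors the post-train hook pattern in
# Example #10, and assumes eval_env/expl_env, policy, exploration_policy,
# and algorithm are bound by the surrounding experiment):
eval_video_func = get_video_func(eval_env, MakeDeterministic(policy), tag='eval')
expl_video_func = get_video_func(expl_env, exploration_policy, tag='train')
algorithm.post_train_funcs.append(eval_video_func)
algorithm.post_train_funcs.append(expl_video_func)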
Example #5
    def contextual_env_distrib_and_reward(env_id, env_class, env_kwargs,
                                          goal_sampling_mode):
        state_env = get_gym_env(env_id,
                                env_class=env_class,
                                env_kwargs=env_kwargs)

        renderer = EnvRenderer(init_camera=init_camera, **renderer_kwargs)
        img_env = InsertImageEnv(state_env, renderer=renderer)

        encoded_env = EncoderWrappedEnv(
            img_env,
            model,
            dict(image_observation="latent_observation"),
        )
        if goal_sampling_mode == "vae_prior":
            latent_goal_distribution = PriorDistribution(
                model.representation_size,
                desired_goal_key,
            )
            diagnostics = StateImageGoalDiagnosticsFn({}, )
        elif goal_sampling_mode == "reset_of_env":
            state_goal_env = get_gym_env(env_id,
                                         env_class=env_class,
                                         env_kwargs=env_kwargs)
            state_goal_distribution = GoalDictDistributionFromMultitaskEnv(
                state_goal_env,
                desired_goal_keys=[state_goal_key],
            )
            image_goal_distribution = AddImageDistribution(
                env=state_env,
                base_distribution=state_goal_distribution,
                image_goal_key=image_goal_key,
                renderer=renderer,
            )
            latent_goal_distribution = AddLatentDistribution(
                image_goal_distribution,
                image_goal_key,
                desired_goal_key,
                model,
            )
            if hasattr(state_goal_env, 'goal_conditioned_diagnostics'):
                diagnostics = GoalConditionedDiagnosticsToContextualDiagnostics(
                    state_goal_env.goal_conditioned_diagnostics,
                    desired_goal_key=state_goal_key,
                    observation_key=state_observation_key,
                )
            else:
                diagnostics = state_goal_env.get_contextual_diagnostics
        else:
            raise NotImplementedError('unknown goal sampling method: %s' %
                                      goal_sampling_mode)

        reward_fn = DistanceRewardFn(
            observation_key=observation_key,
            desired_goal_key=desired_goal_key,
        )

        env = ContextualEnv(
            encoded_env,
            context_distribution=latent_goal_distribution,
            reward_fn=reward_fn,
            observation_key=observation_key,
            contextual_diagnostics_fns=[diagnostics],
        )
        return env, latent_goal_distribution, reward_fn
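
    # Usage sketch (not in the original; argument values hypothetical). The
    # two goal_sampling_mode branches above suggest one call per phase:
    expl_env, expl_goal_dist, expl_reward_fn = contextual_env_distrib_and_reward(
        env_id, env_class, env_kwargs, goal_sampling_mode='vae_prior')
    eval_env, eval_goal_dist, eval_reward_fn = contextual_env_distrib_and_reward(
        env_id, env_class, env_kwargs, goal_sampling_mode='reset_of_env')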
Example #6
def contextual_env_distrib_reward(_env_id,
                                  _env_class=None,
                                  _env_kwargs=None):
    base_env = get_gym_env(
        _env_id,
        env_class=env_class,
        env_kwargs=env_kwargs,
        unwrap_timed_envs=True,
    )
    if init_camera:
        base_env.initialize_camera(init_camera)
    if (isinstance(stub_env, AntFullPositionGoalEnv)
            and normalize_distances_for_full_state_ant):
        base_env = NormalizeAntFullPositionGoalEnv(base_env)
        normalize_env = base_env
    else:
        normalize_env = None
    env = NoisyAction(base_env, action_noise_scale)
    diag_fns = []
    if is_gym_env:
        goal_distribution = GoalDictDistributionFromGymGoalEnv(
            env,
            desired_goal_key=desired_goal_key,
        )
        diag_fns.append(
            GenericGoalConditionedContextualDiagnostics(
                desired_goal_key=desired_goal_key,
                achieved_goal_key=achieved_goal_key,
                success_threshold=success_threshold,
            ))
    else:
        goal_distribution = GoalDictDistributionFromMultitaskEnv(
            env,
            desired_goal_keys=[desired_goal_key],
        )
        diag_fns.append(
            GoalConditionedDiagnosticsToContextualDiagnostics(
                env.goal_conditioned_diagnostics,
                desired_goal_key=desired_goal_key,
                observation_key=observation_key,
            ))
    if isinstance(stub_env, AntFullPositionGoalEnv):
        diag_fns.append(
            AntFullPositionGoalEnvDiagnostics(
                desired_goal_key=desired_goal_key,
                achieved_goal_key=achieved_goal_key,
                success_threshold=success_threshold,
                normalize_env=normalize_env,
            ))
    # if isinstance(stub_env, HopperFullPositionGoalEnv):
    #     diag_fns.append(
    #         HopperFullPositionGoalEnvDiagnostics(
    #             desired_goal_key=desired_goal_key,
    #             achieved_goal_key=achieved_goal_key,
    #             success_threshold=success_threshold,
    #         )
    #     )
    achieved_from_ob = IndexIntoAchievedGoal(achieved_goal_key)
    if reward_type == 'sparse':
        distance_fn = L2Distance(
            achieved_goal_from_observation=achieved_from_ob,
            desired_goal_key=desired_goal_key,
        )
        reward_fn = ThresholdDistanceReward(distance_fn, success_threshold)
    elif reward_type == 'negative_distance':
        reward_fn = NegativeL2Distance(
            achieved_goal_from_observation=achieved_from_ob,
            desired_goal_key=desired_goal_key,
        )
    else:
        reward_fn = ProbabilisticGoalRewardFn(
            dynamics_model,
            state_key=observation_key,
            context_key=context_key,
            reward_type=reward_type,
            discount_factor=discount_factor,
        )
    goal_distribution = PresampledDistribution(goal_distribution,
                                               num_presampled_goals)
    final_env = ContextualEnv(
        env,
        context_distribution=goal_distribution,
        reward_fn=reward_fn,
        observation_key=observation_key,
        contextual_diagnostics_fns=diag_fns,
        update_env_info_fn=delete_info,
    )
    return final_env, goal_distribution, reward_fn
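
# Usage sketch (not in the original). Note that _env_class and _env_kwargs
# are accepted but ignored above; the body reads env_class / env_kwargs (and
# everything else) from the enclosing scope, so only the id matters here.
final_env, goal_dist, reward_fn = contextual_env_distrib_reward(env_id)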
Example #7
    def contextual_env_distrib_and_reward(mode='expl'):
        assert mode in ['expl', 'eval']
        env = get_envs(variant)

        if mode == 'expl':
            goal_sampling_mode = variant.get('expl_goal_sampling_mode', None)
        elif mode == 'eval':
            goal_sampling_mode = variant.get('eval_goal_sampling_mode', None)
        if goal_sampling_mode not in [None, 'example_set']:
            env.goal_sampling_mode = goal_sampling_mode

        mask_ids_for_training = mask_variant.get('mask_ids_for_training', None)

        if mask_conditioned:
            context_distrib = MaskDictDistribution(
                env,
                desired_goal_keys=[desired_goal_key],
                mask_format=mask_format,
                masks=masks,
                max_subtasks_to_focus_on=mask_variant.get('max_subtasks_to_focus_on', None),
                prev_subtask_weight=mask_variant.get('prev_subtask_weight', None),
                mask_distr=mask_variant.get('train_mask_distr', None),
                mask_ids=mask_ids_for_training,
            )
            reward_fn = ContextualMaskingRewardFn(
                achieved_goal_from_observation=IndexIntoAchievedGoal(achieved_goal_key),
                desired_goal_key=desired_goal_key,
                achieved_goal_key=achieved_goal_key,
                mask_keys=mask_keys,
                mask_format=mask_format,
                use_g_for_mean=mask_variant['use_g_for_mean'],
                use_squared_reward=mask_variant.get('use_squared_reward', False),
            )
        else:
            if goal_sampling_mode == 'example_set':
                example_dataset = gen_example_sets(get_envs(variant), variant['example_set_variant'])
                assert len(example_dataset['list_of_waypoints']) == 1
                from rlkit.envs.contextual.set_distributions import GoalDictDistributionFromSet
                context_distrib = GoalDictDistributionFromSet(
                    example_dataset['list_of_waypoints'][0],
                    desired_goal_keys=[desired_goal_key],
                )
            else:
                context_distrib = GoalDictDistributionFromMultitaskEnv(
                    env,
                    desired_goal_keys=[desired_goal_key],
                )
            reward_fn = ContextualRewardFnFromMultitaskEnv(
                env=env,
                achieved_goal_from_observation=IndexIntoAchievedGoal(achieved_goal_key),
                desired_goal_key=desired_goal_key,
                achieved_goal_key=achieved_goal_key,
                additional_obs_keys=variant['contextual_replay_buffer_kwargs'].get('observation_keys', None),
            )
        diag_fn = GoalConditionedDiagnosticsToContextualDiagnostics(
            env.goal_conditioned_diagnostics,
            desired_goal_key=desired_goal_key,
            observation_key=observation_key,
        )
        env = ContextualEnv(
            env,
            context_distribution=context_distrib,
            reward_fn=reward_fn,
            observation_key=observation_key,
            contextual_diagnostics_fns=[diag_fn],
            update_env_info_fn=delete_info if not variant.get('keep_env_infos', False) else None,
        )
        return env, context_distrib, reward_fn
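
    # Usage sketch (not in the original; mirrors the expl/eval call pattern
    # in Example #10):
    expl_env, expl_context_distrib, expl_reward = (
        contextual_env_distrib_and_reward(mode='expl'))
    eval_env, eval_context_distrib, eval_reward = (
        contextual_env_distrib_and_reward(mode='eval'))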
Example #8
    def contextual_env_distrib_and_reward(mode='expl'):
        assert mode in ['expl', 'eval']
        env = get_gym_env(env_id, env_class=env_class, env_kwargs=env_kwargs)

        if mode == 'expl':
            goal_sampling_mode = expl_goal_sampling_mode
        elif mode == 'eval':
            goal_sampling_mode = eval_goal_sampling_mode
        else:
            goal_sampling_mode = None
        if goal_sampling_mode is not None:
            env.goal_sampling_mode = goal_sampling_mode

        if mask_conditioned:
            context_distrib = MaskedGoalDictDistributionFromMultitaskEnv(
                env,
                desired_goal_keys=[desired_goal_key],
                mask_keys=mask_keys,
                mask_dims=mask_dims,
                mask_format=mask_format,
                max_subtasks_to_focus_on=max_subtasks_to_focus_on,
                prev_subtask_weight=prev_subtask_weight,
                masks=masks,
                idx_masks=idx_masks,
                matrix_masks=matrix_masks,
                mask_distr=train_mask_distr,
            )
            reward_fn = ContextualRewardFnFromMultitaskEnv(
                env=env,
                achieved_goal_from_observation=IndexIntoAchievedGoal(
                    achieved_goal_key),  # observation_key
                desired_goal_key=desired_goal_key,
                achieved_goal_key=achieved_goal_key,
                additional_obs_keys=contextual_replay_buffer_kwargs.get(
                    'observation_keys', None),
                additional_context_keys=mask_keys,
                reward_fn=partial(
                    mask_reward_fn,
                    mask_format=mask_format,
                    use_g_for_mean=use_g_for_mean
                ),
            )
        else:
            context_distrib = GoalDictDistributionFromMultitaskEnv(
                env,
                desired_goal_keys=[desired_goal_key],
            )
            reward_fn = ContextualRewardFnFromMultitaskEnv(
                env=env,
                achieved_goal_from_observation=IndexIntoAchievedGoal(
                    achieved_goal_key),  # observation_key
                desired_goal_key=desired_goal_key,
                achieved_goal_key=achieved_goal_key,
                additional_obs_keys=contextual_replay_buffer_kwargs.get(
                    'observation_keys', None),
            )
        diag_fn = GoalConditionedDiagnosticsToContextualDiagnostics(
            env.goal_conditioned_diagnostics,
            desired_goal_key=desired_goal_key,
            observation_key=observation_key,
        )
        env = ContextualEnv(
            env,
            context_distribution=context_distrib,
            reward_fn=reward_fn,
            observation_key=observation_key,
            contextual_diagnostics_fns=[diag_fn],
            update_env_info_fn=delete_info,
        )
        return env, context_distrib, reward_fn
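
    # Usage sketch (not in the original): as in Example #7, one
    # env/distribution/reward triple is built per phase via the mode flag.
    expl_env, expl_context_distrib, expl_reward = (
        contextual_env_distrib_and_reward(mode='expl'))
    eval_env, eval_context_distrib, eval_reward = (
        contextual_env_distrib_and_reward(mode='eval'))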
Example #9
    def contextual_env_distrib_and_reward(
        env_id,
        env_class,
        env_kwargs,
        goal_sampling_mode,
        env_mask_distribution_type,
        static_mask=None,
    ):
        env = get_gym_env(env_id, env_class=env_class, env_kwargs=env_kwargs)
        env.goal_sampling_mode = goal_sampling_mode
        use_static_mask = (not do_masking
                           or env_mask_distribution_type == 'static_mask')
        # Default to an all-ones mask when static masks are in use but no
        # static mask was provided.
        if use_static_mask and static_mask is None:
            static_mask = np.ones(mask_dim)
        if not do_masking:
            assert env_mask_distribution_type == 'static_mask'

        goal_distribution = GoalDictDistributionFromMultitaskEnv(
            env,
            desired_goal_keys=[desired_goal_key],
        )
        goal_distribution = MaskedGoalDictDistribution(
            goal_distribution,
            mask_key=mask_key,
            mask_dim=mask_dim,
            distribution_type=env_mask_distribution_type,
            static_mask=static_mask,
        )

        reward_fn = ContextualRewardFnFromMultitaskEnv(
            env=env,
            desired_goal_key=desired_goal_key,
            achieved_goal_key=achieved_goal_key,
            achieved_goal_from_observation=IndexIntoAchievedGoal(
                observation_key),
        )

        if do_masking:
            if masking_reward_fn:
                reward_fn = partial(masking_reward_fn,
                                    context_key=context_key,
                                    mask_key=mask_key)
            else:
                reward_fn = partial(default_masked_reward_fn,
                                    context_key=context_key,
                                    mask_key=mask_key)

        state_diag_fn = GoalConditionedDiagnosticsToContextualDiagnostics(
            env.goal_conditioned_diagnostics,
            desired_goal_key=desired_goal_key,
            observation_key=observation_key,
        )

        env = ContextualEnv(
            env,
            context_distribution=goal_distribution,
            reward_fn=reward_fn,
            observation_key=observation_key,
            contextual_diagnostics_fns=[state_diag_fn],
        )
        return env, goal_distribution, reward_fn
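
    # Usage sketch (not in the original): Example #10 calls this same helper,
    # e.g. with a static all-ones mask for evaluation:
    eval_env, eval_context_distrib, eval_reward = contextual_env_distrib_and_reward(
        eval_env_id, env_class, env_kwargs, evaluation_goal_sampling_mode,
        'static_mask')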
Example #10
def masking_sac_experiment(
    max_path_length,
    qf_kwargs,
    sac_trainer_kwargs,
    replay_buffer_kwargs,
    policy_kwargs,
    algo_kwargs,
    env_class=None,
    env_kwargs=None,
    observation_key='state_observation',
    desired_goal_key='state_desired_goal',
    achieved_goal_key='state_achieved_goal',
    exploration_policy_kwargs=None,
    evaluation_goal_sampling_mode=None,
    exploration_goal_sampling_mode=None,
    # Video parameters
    save_video=True,
    save_video_kwargs=None,
    renderer_kwargs=None,
    train_env_id=None,
    eval_env_id=None,
    do_masking=True,
    mask_key="masked_observation",
    masking_eval_steps=200,
    log_mask_diagnostics=True,
    mask_dim=None,
    masking_reward_fn=None,
    masking_for_exploration=True,
    rotate_masks_for_eval=False,
    rotate_masks_for_expl=False,
    mask_distribution=None,
    num_steps_per_mask_change=10,
    tag=None,
):
    if exploration_policy_kwargs is None:
        exploration_policy_kwargs = {}
    if not save_video_kwargs:
        save_video_kwargs = {}
    if not renderer_kwargs:
        renderer_kwargs = {}

    context_key = desired_goal_key
    env = get_gym_env(train_env_id, env_class=env_class, env_kwargs=env_kwargs)
    mask_dim = (mask_dim or env.observation_space.spaces[context_key].low.size)

    if not do_masking:
        mask_distribution = 'all_ones'

    assert mask_distribution in [
        'one_hot_masks', 'random_bit_masks', 'all_ones'
    ]
    if mask_distribution == 'all_ones':
        mask_distribution = 'static_mask'

    def contextual_env_distrib_and_reward(
        env_id,
        env_class,
        env_kwargs,
        goal_sampling_mode,
        env_mask_distribution_type,
        static_mask=None,
    ):
        env = get_gym_env(env_id, env_class=env_class, env_kwargs=env_kwargs)
        env.goal_sampling_mode = goal_sampling_mode
        use_static_mask = (not do_masking
                           or env_mask_distribution_type == 'static_mask')
        # Default to an all-ones mask when static masks are in use but no
        # static mask was provided.
        if use_static_mask and static_mask is None:
            static_mask = np.ones(mask_dim)
        if not do_masking:
            assert env_mask_distribution_type == 'static_mask'

        goal_distribution = GoalDictDistributionFromMultitaskEnv(
            env,
            desired_goal_keys=[desired_goal_key],
        )
        goal_distribution = MaskedGoalDictDistribution(
            goal_distribution,
            mask_key=mask_key,
            mask_dim=mask_dim,
            distribution_type=env_mask_distribution_type,
            static_mask=static_mask,
        )

        reward_fn = ContextualRewardFnFromMultitaskEnv(
            env=env,
            desired_goal_key=desired_goal_key,
            achieved_goal_key=achieved_goal_key,
            achieved_goal_from_observation=IndexIntoAchievedGoal(
                observation_key),
        )

        if do_masking:
            if masking_reward_fn:
                reward_fn = partial(masking_reward_fn,
                                    context_key=context_key,
                                    mask_key=mask_key)
            else:
                reward_fn = partial(default_masked_reward_fn,
                                    context_key=context_key,
                                    mask_key=mask_key)

        state_diag_fn = GoalConditionedDiagnosticsToContextualDiagnostics(
            env.goal_conditioned_diagnostics,
            desired_goal_key=desired_goal_key,
            observation_key=observation_key,
        )

        env = ContextualEnv(
            env,
            context_distribution=goal_distribution,
            reward_fn=reward_fn,
            observation_key=observation_key,
            contextual_diagnostics_fns=[state_diag_fn],
        )
        return env, goal_distribution, reward_fn

    expl_env, expl_context_distrib, expl_reward = contextual_env_distrib_and_reward(
        train_env_id,
        env_class,
        env_kwargs,
        exploration_goal_sampling_mode,
        mask_distribution if masking_for_exploration else 'static_mask',
    )
    eval_env, eval_context_distrib, eval_reward = contextual_env_distrib_and_reward(
        eval_env_id, env_class, env_kwargs, evaluation_goal_sampling_mode,
        'static_mask')

    # Distribution for relabeling
    relabel_context_distrib = GoalDictDistributionFromMultitaskEnv(
        env,
        desired_goal_keys=[desired_goal_key],
    )
    relabel_context_distrib = MaskedGoalDictDistribution(
        relabel_context_distrib,
        mask_key=mask_key,
        mask_dim=mask_dim,
        distribution_type=mask_distribution,
        static_mask=None if do_masking else np.ones(mask_dim),
    )

    obs_dim = (expl_env.observation_space.spaces[observation_key].low.size +
               expl_env.observation_space.spaces[context_key].low.size +
               mask_dim)
    action_dim = expl_env.action_space.low.size

    def create_qf():
        return ConcatMlp(input_size=obs_dim + action_dim,
                         output_size=1,
                         **qf_kwargs)

    qf1 = create_qf()
    qf2 = create_qf()
    target_qf1 = create_qf()
    target_qf2 = create_qf()

    policy = TanhGaussianPolicy(obs_dim=obs_dim,
                                action_dim=action_dim,
                                **policy_kwargs)

    def context_from_obs_dict_fn(obs_dict):
        achieved_goal = obs_dict['state_achieved_goal']
        # Should the mask be randomized for future relabeling?
        # batch_size = len(achieved_goal)
        # mask = np.random.choice(2, size=(batch_size, mask_dim))
        mask = obs_dict[mask_key]
        return {
            mask_key: mask,
            context_key: achieved_goal,
        }

    def concat_context_to_obs(batch, *args, **kwargs):
        obs = batch['observations']
        next_obs = batch['next_observations']

        context = batch[context_key]
        mask = batch[mask_key]

        batch['observations'] = np.concatenate([obs, context, mask], axis=1)
        batch['next_observations'] = np.concatenate([next_obs, context, mask],
                                                    axis=1)
        return batch

    replay_buffer = ContextualRelabelingReplayBuffer(
        env=eval_env,
        context_keys=[context_key, mask_key],
        observation_keys=[observation_key, 'state_achieved_goal'],
        context_distribution=relabel_context_distrib,
        sample_context_from_obs_dict_fn=context_from_obs_dict_fn,
        reward_fn=eval_reward,
        post_process_batch_fn=concat_context_to_obs,
        **replay_buffer_kwargs)
    trainer = SACTrainer(env=expl_env,
                         policy=policy,
                         qf1=qf1,
                         qf2=qf2,
                         target_qf1=target_qf1,
                         target_qf2=target_qf2,
                         **sac_trainer_kwargs)

    def create_path_collector(env, policy, is_rotating):
        if is_rotating:
            assert do_masking and mask_distribution == 'one_hot_masks'
            return RotatingMaskingPathCollector(
                env,
                policy,
                observation_key=observation_key,
                context_keys_for_policy=[context_key, mask_key],
                mask_key=mask_key,
                mask_length=mask_dim,
                num_steps_per_mask_change=num_steps_per_mask_change,
            )
        else:
            return ContextualPathCollector(
                env,
                policy,
                observation_key=observation_key,
                context_keys_for_policy=[context_key, mask_key],
            )

    exploration_policy = create_exploration_policy(expl_env, policy,
                                                   **exploration_policy_kwargs)

    eval_path_collector = create_path_collector(eval_env,
                                                MakeDeterministic(policy),
                                                rotate_masks_for_eval)
    expl_path_collector = create_path_collector(expl_env, exploration_policy,
                                                rotate_masks_for_expl)

    algorithm = TorchBatchRLAlgorithm(
        trainer=trainer,
        exploration_env=expl_env,
        evaluation_env=eval_env,
        exploration_data_collector=expl_path_collector,
        evaluation_data_collector=eval_path_collector,
        replay_buffer=replay_buffer,
        max_path_length=max_path_length,
        **algo_kwargs)
    algorithm.to(ptu.device)

    if save_video:
        rollout_function = partial(
            rf.contextual_rollout,
            max_path_length=max_path_length,
            observation_key=observation_key,
            context_keys_for_policy=[context_key, mask_key],
            # Eval on everything for the base video
            obs_processor=lambda o: np.hstack(
                (o[observation_key], o[context_key], np.ones(mask_dim))))
        renderer = EnvRenderer(**renderer_kwargs)

        def add_images(env, state_distribution):
            state_env = env.env
            image_goal_distribution = AddImageDistribution(
                env=state_env,
                base_distribution=state_distribution,
                image_goal_key='image_desired_goal',
                renderer=renderer,
            )
            img_env = InsertImageEnv(state_env, renderer=renderer)
            return ContextualEnv(
                img_env,
                context_distribution=image_goal_distribution,
                reward_fn=eval_reward,
                observation_key=observation_key,
                # update_env_info_fn=DeleteOldEnvInfo(),
            )

        img_eval_env = add_images(eval_env, eval_context_distrib)
        img_expl_env = add_images(expl_env, expl_context_distrib)
        eval_video_func = get_save_video_function(
            rollout_function,
            img_eval_env,
            MakeDeterministic(policy),
            tag="eval",
            imsize=renderer.image_shape[0],
            image_format='CWH',
            **save_video_kwargs)
        expl_video_func = get_save_video_function(
            rollout_function,
            img_expl_env,
            exploration_policy,
            tag="train",
            imsize=renderer.image_shape[0],
            image_format='CWH',
            **save_video_kwargs)

        algorithm.post_train_funcs.append(eval_video_func)
        algorithm.post_train_funcs.append(expl_video_func)

    # For diagnostics, evaluate the policy on each individual dimension of the
    # mask.
    masks = []
    collectors = []
    for mask_idx in range(mask_dim):
        mask = np.zeros(mask_dim)
        mask[mask_idx] = 1
        masks.append(mask)

    for mask in masks:
        for_dim_mask = mask if do_masking else np.ones(mask_dim)
        masked_env, _, _ = contextual_env_distrib_and_reward(
            eval_env_id,
            env_class,
            env_kwargs,
            evaluation_goal_sampling_mode,
            'static_mask',
            static_mask=for_dim_mask)

        collector = ContextualPathCollector(
            masked_env,
            MakeDeterministic(policy),
            observation_key=observation_key,
            context_keys_for_policy=[context_key, mask_key],
        )
        collectors.append(collector)
    log_prefixes = [
        'mask_{}/'.format(''.join(mask.astype(int).astype(str)))
        for mask in masks
    ]

    def get_mask_diagnostics(unused):
        from rlkit.core.logging import append_log, add_prefix, OrderedDict
        from rlkit.misc import eval_util
        log = OrderedDict()
        for prefix, collector in zip(log_prefixes, collectors):
            paths = collector.collect_new_paths(
                max_path_length,
                masking_eval_steps,
                discard_incomplete_paths=True,
            )
            generic_info = add_prefix(
                eval_util.get_generic_path_information(paths),
                prefix,
            )
            append_log(log, generic_info)

        for collector in collectors:
            collector.end_epoch(0)
        return log

    if log_mask_diagnostics:
        algorithm._eval_get_diag_fns.append(get_mask_diagnostics)

    algorithm.train()
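
# Usage sketch (not in the original; all values hypothetical and trimmed to
# the required arguments; real experiments usually fill these kwargs dicts
# from a variant/config):
masking_sac_experiment(
    max_path_length=100,
    qf_kwargs=dict(hidden_sizes=[400, 300]),
    sac_trainer_kwargs=dict(),
    replay_buffer_kwargs=dict(max_size=int(1e6)),
    policy_kwargs=dict(hidden_sizes=[400, 300]),
    algo_kwargs=dict(num_epochs=100, batch_size=128),
    train_env_id='MyMultitaskEnv-v0',  # hypothetical env id
    eval_env_id='MyMultitaskEnv-v0',
)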