Example #1
0
def build_normalizer(env: EnvWrapper) -> Dict[str, NormalizationData]:
    """Return the normalization data for ``env``, keyed by ``NormalizationKey``.

    If ``env`` exposes a ``normalization_data`` attribute, that value is
    returned unchanged. Otherwise the mapping is assembled here:

    * When ``HAS_RECSIM`` is truthy and ``env`` is a ``RecSim`` instance, the
      result has ``STATE`` and ``ITEM`` entries whose index lists are sized
      from ``observation_space["user"]`` and ``observation_space["doc"]["0"]``.
    * In every other case the result has ``STATE`` and ``ACTION`` entries
      produced by ``build_state_normalizer`` and ``build_action_normalizer``.
    """
    try:
        return env.normalization_data
    except AttributeError:
        # TODO: make this a property of EnvWrapper?
        # pyre-fixme[16]: Module `envs` has no attribute `RecSim`.
        is_recsim = HAS_RECSIM and isinstance(env, RecSim)

        if is_recsim:
            obs_space = env.observation_space

            user_indices = list(range(obs_space["user"].shape[0]))
            state_data = NormalizationData(
                dense_normalization_parameters=only_continuous_normalizer(
                    user_indices
                )
            )

            # Every document is sized from the entry stored under key "0".
            doc_indices = list(range(obs_space["doc"]["0"].shape[0]))
            item_data = NormalizationData(
                dense_normalization_parameters=only_continuous_normalizer(
                    doc_indices
                )
            )

            return {
                NormalizationKey.STATE: state_data,
                NormalizationKey.ITEM: item_data,
            }

        state_data = NormalizationData(
            dense_normalization_parameters=build_state_normalizer(env)
        )
        action_data = NormalizationData(
            dense_normalization_parameters=build_action_normalizer(env)
        )
        return {
            NormalizationKey.STATE: state_data,
            NormalizationKey.ACTION: action_data,
        }
Example #2
0
 def normalization_data(self):
     """Return a one-entry mapping holding the STATE normalization data.

     The dense parameters come from ``only_continuous_normalizer`` applied to
     the indices ``0 .. self.num_arms - 1``, with ``MU_LOW`` and ``MU_HIGH``
     forwarded as its second and third arguments (presumably the lower and
     upper bounds -- confirm against ``only_continuous_normalizer``).
     """
     arm_indices = list(range(self.num_arms))
     dense_params = only_continuous_normalizer(arm_indices, MU_LOW, MU_HIGH)
     state_data = NormalizationData(dense_normalization_parameters=dense_params)
     return {NormalizationKey.STATE: state_data}
Example #3
0
def build_state_normalizer(env: EnvWrapper):
    """Build dense normalization parameters from ``env.observation_space``.

    Returns:
        For a ``spaces.Box`` observation space, the result of
        ``only_continuous_normalizer`` called with the indices
        ``0 .. shape[0] - 1`` and the space's ``low`` and ``high``.
        For a ``spaces.Dict`` observation space, ``None``.

    Raises:
        AssertionError: if a ``Box`` space's shape does not have exactly one
            dimension.
        NotImplementedError: if the observation space is neither a ``Box``
            nor a ``Dict``.
    """
    obs_space = env.observation_space

    if isinstance(obs_space, spaces.Box):
        dims = obs_space.shape
        assert (
            len(dims) == 1
        ), f"{dims} has dim > 1, and is not supported."
        feature_indices = list(range(dims[0]))
        return only_continuous_normalizer(
            feature_indices,
            obs_space.low,
            obs_space.high,
        )

    if isinstance(obs_space, spaces.Dict):
        # assuming env.observation_space is image
        return None

    raise NotImplementedError(f"{obs_space} not supported")