コード例 #1
0
    "maze.normalization_strategies.MeanZeroStdOneObservationNormalizationStrategy",
    default_strategy_config={
        "clip_range": (None, None),
        "axis": 0
    },
    default_statistics=None,
    statistics_dump="statistics.pkl",
    sampling_policy=RandomPolicy(env.action_spaces_dict),
    exclude=None,
    manual_config=None)

# Next, estimate the normalization statistics by
# (1) collecting observations from 1000 randomly sampled environment transitions, and
# (2) computing the statistics according to the defined normalization strategy.
normalization_statistics = obtain_normalization_statistics(env, n_samples=1000)
# Hand the estimated statistics over to the environment.
env.set_normalization_statistics(normalization_statistics)

# After this step, all observations returned by the environment will be normalized.

# stable-baselines does not support dict spaces, so we have to remove them.
env = NoDictSpacesWrapper(env)

# TRAINING AND ROLLOUT (remains unchanged)
# ----------------------------------------

# Build an A2C model with the 'MlpPolicy' policy on the environment prepared
# above and train it for 10000 timesteps.
model = A2C('MlpPolicy', env, verbose=1)
model.learn(total_timesteps=10000)

# Reset the environment to obtain the first observation for the rollout loop below.
obs = env.reset()
for i in range(1000):
    action, _state = model.predict(obs, deterministic=True)
コード例 #2
0
    "statistics_dump": "statistics.pkl",
    "sampling_policy": RandomPolicy(env.action_spaces_dict),
    "exclude": None,
    "manual_config": None
}

# 1. PREPARATION: first we estimate the normalization statistics
# --------------------------------------------------------------

# Wrap the environment for observation normalization.
# NOTE(review): normalization_config is presumably the dict literal that ends
# just above this section -- its opening is not visible here; confirm.
env = ObservationNormalizationWrapper.wrap(env, **normalization_config)

# Before we can start working with normalized observations,
# we need to estimate the normalization statistics (here from 1000 samples).
normalization_statistics = obtain_normalization_statistics(env, n_samples=1000)

# 2. APPLICATION (training, rollout, deployment)
# ----------------------------------------------

# Instantiate a maze environment.
training_env = GymMazeEnv("CartPole-v0")
# Wrap the environment for observation normalization, using the same
# configuration that is passed to the wrapper during preparation.
training_env = ObservationNormalizationWrapper.wrap(training_env,
                                                    **normalization_config)

# Reuse the previously estimated statistics in our training environment(s).
training_env.set_normalization_statistics(normalization_statistics)

# After this step, the training env yields normalized observations.
normalized_obs = training_env.reset()