def _add_stable_baselines_policies(classes):
    """Register a policy loader for each Stable Baselines class in `classes`.

    Args:
        classes: Mapping from registry key to a ``(cls_name, attr)`` pair.
            ``cls_name`` is resolved with ``registry.load_attr``. The
            resolved class and ``attr`` are passed to
            ``_load_stable_baselines``, and the result is registered under
            the key in ``policy_registry``.

    A class whose import fails with ``AttributeError`` or ``ImportError`` is
    skipped with a debug log message instead of aborting registration of the
    remaining classes.
    """
    for k, (cls_name, attr) in classes.items():
        # Only the import is expected to fail (optional dependencies). Keep
        # the ``try`` body to that single call so unrelated errors raised
        # while building or registering the loader are not misreported as a
        # missing dependency and silently skipped.
        try:
            cls = registry.load_attr(cls_name)
        except (AttributeError, ImportError):
            # We expect PPO1 load to fail if mpi4py isn't installed.
            # Stable Baselines can be installed without mpi4py.
            tf.logging.debug(f"Couldn't load {cls_name}. Skipping...")
            continue
        fn = _load_stable_baselines(cls, attr)
        policy_registry.register(k, value=fn)
def test_serialize_identity(env_name, model_cfg, normalize, tmpdir):
    """Test output actions of deserialized policy are same as original.

    Trains a small model on a single-env vec env (optionally wrapped in
    ``VecNormalize``) and records a deterministic rollout. It then saves the
    model with ``serialize.save_stable_model`` and reloads it through
    ``serialize.load_policy``. Finally it asserts that the reloaded policy
    produces the same actions from the same seed. The test is skipped if the
    model class cannot be imported.
    """
    base_venv = util.make_vec_env(env_name, n_envs=1, parallel=False)
    vec_normalize = VecNormalize(base_venv) if normalize else None
    train_venv = base_venv if vec_normalize is None else vec_normalize

    model_name, model_cls_name = model_cfg
    try:
        model_cls = registry.load_attr(model_cls_name)
    except (AttributeError, ImportError):  # pragma: no cover
        pytest.skip(
            "Couldn't load stable baselines class. "
            "(Probably because mpi4py not installed.)"
        )

    def reseed(env):
        # Same seed and a fresh reset so both rollouts start identically.
        env.env_method("seed", 0)
        env.reset()

    def deterministic_rollout(policy, env):
        return rollout.generate_transitions(
            policy,
            env,
            n_timesteps=1000,
            deterministic_policy=True,
            rng=np.random.RandomState(0),
        )

    model = model_cls("MlpPolicy", train_venv)
    model.learn(1000)

    reseed(train_venv)
    if vec_normalize is not None:
        # Freeze normalization statistics while the reference rollout runs.
        vec_normalize.training = False
    orig_rollout = deterministic_rollout(model, train_venv)

    serialize.save_stable_model(tmpdir, model, vec_normalize)

    # Pass the unwrapped env: `load_policy` itself wraps the loaded policy's
    # env in a VecNormalize, when appropriate.
    with serialize.load_policy(model_name, tmpdir, base_venv) as loaded:
        reseed(base_venv)
        new_rollout = deterministic_rollout(loaded, base_venv)

    assert np.allclose(orig_rollout.acts, new_rollout.acts)
def test_serialize_identity(env_name, model_cfg, normalize, tmpdir):
    """Test output actions of deserialized policy are same as original.

    Trains a small model on a single-env vec env (optionally wrapped in
    ``VecNormalize``) and records a deterministic rollout. It then saves the
    model with ``serialize.save_stable_model`` and reloads it with
    ``serialize.load_policy``. Finally it asserts that the reloaded policy
    reproduces the original actions from the same seed.
    """
    base_venv = util.make_vec_env(env_name, n_envs=1, parallel=False)
    vec_normalize = VecNormalize(base_venv) if normalize else None
    train_venv = base_venv if vec_normalize is None else vec_normalize

    model_name, model_cls_name = model_cfg
    model_cls = registry.load_attr(model_cls_name)

    def reseed(env):
        # Same seed and a fresh reset so both rollouts start identically.
        env.env_method("seed", 0)
        env.reset()

    def deterministic_rollout(policy, env):
        return rollout.generate_transitions(
            policy,
            env,
            n_timesteps=1000,
            deterministic_policy=True,
            rng=np.random.RandomState(0),
        )

    # FIXME(sam): verbose=1 is a hack to stop it from setting up SB logger
    model = model_cls("MlpPolicy", train_venv, verbose=1)
    model.learn(1000)

    reseed(train_venv)
    if vec_normalize is not None:
        # Freeze normalization statistics while the reference rollout runs.
        vec_normalize.training = False
    orig_rollout = deterministic_rollout(model, train_venv)

    serialize.save_stable_model(tmpdir, model, vec_normalize)

    # Pass the unwrapped env: `load_policy` itself wraps the loaded policy's
    # env in a VecNormalize, when appropriate.
    loaded = serialize.load_policy(model_name, tmpdir, base_venv)
    reseed(base_venv)
    new_rollout = deterministic_rollout(loaded, base_venv)

    assert np.allclose(orig_rollout.acts, new_rollout.acts)
def _add_stable_baselines_policies(classes):
    """Register one Stable Baselines policy loader per entry of `classes`.

    Args:
        classes: Mapping from registry key to a ``(cls_name, attr)`` pair.
            ``cls_name`` is resolved via ``registry.load_attr``. The
            resolved class and ``attr`` go to ``_load_stable_baselines``,
            whose result is stored in ``policy_registry`` under the key.

    Any error from resolving, building or registering a loader propagates
    to the caller.
    """
    for key, spec in classes.items():
        cls_path, policy_attr = spec
        loader = _load_stable_baselines(registry.load_attr(cls_path), policy_attr)
        policy_registry.register(key, value=loader)