import pytest

from imitation.util import registry


def test_keys():
    reg = registry.Registry()
    with pytest.raises(KeyError, match="not registered"):
        reg.get("foobar")

    reg.register(key="foobar", value="fizzbuzz")
    assert reg.get("foobar") == "fizzbuzz"

    with pytest.raises(KeyError, match="Duplicate registration"):
        reg.register(key="foobar", value="fizzbuzz")


def test_lazy():
    """Test indirect/lazy loading of registered values."""
    reg = registry.Registry()

    reg.register("nomodule", indirect="this.module.does.not.exist:foobar")
    with pytest.raises(ImportError):
        reg.get("nomodule")

    reg.register("noattribute", indirect="imitation:attr_does_not_exist")
    with pytest.raises(AttributeError):
        reg.get("noattribute")

    with pytest.raises(ValueError, match="exactly one of"):
        reg.register(key="wrongargs", value=3.14, indirect="math:pi")

    reg.register("exists", indirect="math:pi")
    val = reg.get("exists")
    import math

    assert val == math.pi
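# Illustrative sketch (not imitation's actual implementation): a minimal registry
# with the behaviour the tests above exercise -- errors on unknown and duplicate
# keys, "exactly one of" validation, and lazy "module:attr" indirect loading.
# All names below are assumptions made for the example.
import importlib
from typing import Any, Dict, Optional


class MinimalRegistry:
    def __init__(self):
        self._values: Dict[str, Any] = {}
        self._indirect: Dict[str, str] = {}

    def register(self, key: str, *, value: Any = None, indirect: Optional[str] = None):
        if key in self._values or key in self._indirect:
            raise KeyError(f"Duplicate registration for '{key}'")
        if (value is None) == (indirect is None):
            raise ValueError("Must provide exactly one of 'value' and 'indirect'")
        if value is not None:
            self._values[key] = value
        else:
            self._indirect[key] = indirect

    def get(self, key: str) -> Any:
        if key in self._values:
            return self._values[key]
        if key in self._indirect:
            # Lazy loading: import the module and fetch the attribute on first use.
            module_name, attr = self._indirect[key].split(":")
            module = importlib.import_module(module_name)  # may raise ImportError
            self._values[key] = getattr(module, attr)  # may raise AttributeError
            return self._values[key]
        raise KeyError(f"Key '{key}' is not registered")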
import contextlib
import os
import pickle
from typing import Callable, ContextManager, Iterator, Optional, Type

import tensorflow as tf
from stable_baselines.common.base_class import BaseRLModel
from stable_baselines.common.policies import BasePolicy
from stable_baselines.common.vec_env import VecEnv, VecNormalize

from imitation.policies.base import RandomPolicy, ZeroPolicy
from imitation.util import registry

PolicyLoaderFn = Callable[[str, VecEnv], ContextManager[BasePolicy]]

policy_registry: registry.Registry[PolicyLoaderFn] = registry.Registry()


class NormalizePolicy(BasePolicy):
    """Wraps a policy, normalizing its input observations.

    `VecNormalize` normalizes observations to have zero mean and unit standard
    deviation. To do this, it collects statistics on the observations. We must
    restore these statistics when we load the policy, or we will be feeding in
    observations of a different scale to those the policy was trained with.

    It is convenient to do this when loading the policy, so users of a saved
    policy are not responsible for this implementation detail.

    WARNING: This trick will not work for fine-tuning / training policies.
    """

    def __init__(self, policy: BasePolicy, vec_normalize: VecNormalize):
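# Illustrative standalone sketch (not the class's actual implementation) of the
# observation normalization the NormalizePolicy docstring describes. It assumes
# VecNormalize exposes `obs_rms` (a running mean/std), `clip_obs` and `epsilon`,
# as in stable-baselines.
import numpy as np


def _normalize_obs_sketch(obs: np.ndarray, vec_normalize: VecNormalize) -> np.ndarray:
    """Apply saved observation statistics to `obs` before prediction."""
    rms = vec_normalize.obs_rms
    scaled = (obs - rms.mean) / np.sqrt(rms.var + vec_normalize.epsilon)
    return np.clip(scaled, -vec_normalize.clip_obs, vec_normalize.clip_obs)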
from typing import Callable

import numpy as np
import torch as th
from stable_baselines3.common.vec_env import VecEnv

from imitation.rewards import common
from imitation.util import registry, util

# TODO(sam): I suspect this whole file can be replaced with th.load calls. Try
# that refactoring once I have things running.

RewardFnLoaderFn = Callable[[str, VecEnv], common.RewardFn]

reward_registry: registry.Registry[RewardFnLoaderFn] = registry.Registry()


def _load_discrim_net(path: str, venv: VecEnv) -> common.RewardFn:
    """Load test reward output from discriminator."""
    del venv  # Unused.
    discriminator = th.load(path)
    # TODO(gleave): expose train reward as well? (hard due to action probs?)
    return discriminator.predict_reward_test


def _load_reward_net_as_fn(shaped: bool) -> RewardFnLoaderFn:
    def loader(path: str, venv: VecEnv) -> common.RewardFn:
        """Load train (shaped) or test (not shaped) reward from path."""
        del venv  # Unused.
        net = th.load(str(path))
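# Illustrative sketch (an assumption, not necessarily the module's actual
# registrations): _load_discrim_net already matches RewardFnLoaderFn, and the
# factory above produces shaped / unshaped reward-net loaders, so the registry
# could be populated roughly like this. The keys are assumptions.
reward_registry.register(key="DiscrimNet", value=_load_discrim_net)
reward_registry.register(key="RewardNet_shaped", value=_load_reward_net_as_fn(shaped=True))
reward_registry.register(key="RewardNet_unshaped", value=_load_reward_net_as_fn(shaped=False))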
"""Load serialized reward functions of different types.""" import contextlib from typing import Callable, ContextManager, Iterator import numpy as np from stable_baselines.common.vec_env import VecEnv from imitation.rewards import discrim_net, reward_net from imitation.util import registry, util from imitation.util.reward_wrapper import RewardFn RewardLoaderFn = Callable[[str, VecEnv], ContextManager[RewardFn]] RewardNetLoaderFn = Callable[[str, VecEnv], reward_net.RewardNet] reward_net_registry: registry.Registry[RewardNetLoaderFn] = registry.Registry() reward_fn_registry: registry.Registry[RewardLoaderFn] = registry.Registry() def _add_reward_net_loaders(classes): for name, cls in classes.items(): loader = registry.build_loader_fn_require_path(cls.load) reward_net_registry.register(key=name, value=loader) REWARD_NETS = { "BasicRewardNet": reward_net.BasicRewardNet, "BasicShapedRewardNet": reward_net.BasicShapedRewardNet, } _add_reward_net_loaders(REWARD_NETS)