def __init__(self, model, guide, optim, loss, **static_kwargs):
    """Store the inference components and normalize ``optim``.

    :param model: Stored unchanged as ``self.model``.
    :param guide: Stored unchanged as ``self.guide``.
    :param optim: An instance of ``numpyro.optim._NumPyroOptim`` (used as-is),
        ``jax.example_libraries.optimizers.Optimizer`` (wrapped in a
        ``_NumPyroOptim``), or ``optax.GradientTransformation`` (converted via
        ``optax_to_numpyro``). The result is stored as ``self.optim``.
    :param loss: Stored unchanged as ``self.loss``.
    :param static_kwargs: Extra keyword arguments, stored as the dict
        ``self.static_kwargs``.
    :raises ImportError: If ``optim`` is neither a ``_NumPyroOptim`` nor an
        ``Optimizer`` and Optax cannot be imported. The underlying import
        error is chained as ``__cause__``.
    :raises TypeError: If Optax is importable but ``optim`` is not an
        ``optax.GradientTransformation``.
    """
    self.model = model
    self.guide = guide
    self.loss = loss
    self.static_kwargs = static_kwargs
    # Left as None here; NOTE(review): presumably assigned elsewhere in the class.
    self.constrain_fn = None

    if isinstance(optim, _NumPyroOptim):
        self.optim = optim
    elif isinstance(optim, optimizers.Optimizer):
        # The identity lambda returns whatever positional args it receives.
        # NOTE(review): assumes _NumPyroOptim calls it with the unpacked
        # Optimizer fields (init_fn, update_fn, params_fn).
        self.optim = _NumPyroOptim(lambda *args: args, *optim)
    else:
        # Optax is an optional dependency, so import it lazily. Chain the
        # original error so a broken (rather than missing) install stays
        # diagnosable.
        try:
            import optax
        except ImportError as err:
            raise ImportError(
                "It looks like you tried to use an optimizer that isn't an "
                "instance of numpyro.optim._NumPyroOptim or "
                "jax.example_libraries.optimizers.Optimizer. There is experimental "
                "support for Optax optimizers, but you need to install Optax. "
                "It can be installed with `pip install optax`."
            ) from err

        if not isinstance(optim, optax.GradientTransformation):
            raise TypeError(
                "Expected either an instance of numpyro.optim._NumPyroOptim, "
                "jax.example_libraries.optimizers.Optimizer or "
                "optax.GradientTransformation. Got {}".format(type(optim))
            )
        self.optim = optax_to_numpyro(optim)
def optax_to_numpyro(
    transformation: optax.GradientTransformation,
) -> _NumPyroOptim:
    """
    This function produces a ``numpyro.optim._NumPyroOptim`` instance from an
    ``optax.GradientTransformation`` so that it can be used with
    ``numpyro.infer.svi.SVI``. It is a lightweight wrapper that recreates the
    ``(init_fn, update_fn, get_params_fn)`` interface defined by
    :mod:`jax.example_libraries.optimizers` (formerly
    ``jax.experimental.optimizers``).

    The wrapped optimizer state is the pair ``(params, opt_state)``, where
    ``opt_state`` is the state returned by ``transformation.init`` /
    ``transformation.update``.

    :param transformation: An ``optax.GradientTransformation`` instance to wrap.
    :return: An instance of ``numpyro.optim._NumPyroOptim`` wrapping the supplied
        Optax optimizer.
    """

    def init_fn(params: _Params) -> _State:
        # Pair the initial params with Optax's freshly initialized state.
        opt_state = transformation.init(params)
        return params, opt_state

    def update_fn(step, grads: _Params, state: _State) -> _State:
        # ``step`` is not used here; it is accepted only to match the
        # ``update_fn(step, grads, state)`` calling convention.
        params, opt_state = state
        # Current params are passed through so transformations that need them
        # can use them.
        updates, opt_state = transformation.update(grads, opt_state, params)
        updated_params = optax.apply_updates(params, updates)
        return updated_params, opt_state

    def get_params_fn(state: _State) -> _Params:
        # Only the params half of the (params, opt_state) pair is exposed.
        params, _ = state
        return params

    # The factory lambda just returns the three functions it is given.
    return _NumPyroOptim(lambda x, y, z: (x, y, z), init_fn, update_fn, get_params_fn)