def __init__(self, model, guide, optim, loss, **static_kwargs):
    """Store the model, guide, loss, and static kwargs, and normalize
    *optim* into a ``_NumPyroOptim`` instance.

    Accepted optimizer kinds, checked in order:
      1. an existing ``_NumPyroOptim`` — used as-is;
      2. a ``jax.example_libraries.optimizers.Optimizer`` triple — wrapped
         via ``_NumPyroOptim`` (identity constructor, triple unpacked);
      3. an ``optax.GradientTransformation`` — adapted through
         ``optax_to_numpyro`` (requires Optax to be installed).

    Raises ImportError when an unrecognized optimizer is given and Optax
    is not installed; TypeError when Optax is installed but *optim* is
    none of the supported kinds.
    """
    self.model = model
    self.guide = guide
    self.loss = loss
    self.static_kwargs = static_kwargs
    # Populated later (e.g. during setup); no constraint transform yet.
    self.constrain_fn = None

    # Case 1: already a numpyro optimizer — nothing to adapt.
    if isinstance(optim, _NumPyroOptim):
        self.optim = optim
        return

    # Case 2: a jax.example_libraries optimizer triple.
    if isinstance(optim, optimizers.Optimizer):
        self.optim = _NumPyroOptim(lambda *args: args, *optim)
        return

    # Case 3: experimental Optax support — import lazily so Optax stays
    # an optional dependency.
    try:
        import optax

        from numpyro.contrib.optim import optax_to_numpyro
    except ImportError:
        raise ImportError(
            "It looks like you tried to use an optimizer that isn't an "
            "instance of numpyro.optim._NumPyroOptim or "
            "jax.example_libraries.optimizers.Optimizer. There is experimental "
            "support for Optax optimizers, but you need to install Optax. "
            "It can be installed with `pip install optax`.")
    if not isinstance(optim, optax.GradientTransformation):
        raise TypeError(
            "Expected either an instance of numpyro.optim._NumPyroOptim, "
            "jax.example_libraries.optimizers.Optimizer or "
            "optax.GradientTransformation. Got {}".format(type(optim)))
    self.optim = optax_to_numpyro(optim)
def test_optim_multi_params(optim_class, args, kwargs):
    """Run 2000 optimization steps on a two-entry parameter dict and
    check that every parameter converges to the zero vector."""
    init_params = {
        "x": jnp.array([1.0, 1.0, 1.0]),
        "y": jnp.array([-1, -1.0, -1.0]),
    }
    opt = optax_to_numpyro(optim_class(*args, **kwargs))
    state = opt.init(init_params)
    # Fixed step budget; `step` advances the optimizer state once.
    for _ in range(2000):
        state = step(state, opt)
    # Only the values matter here — the keys are irrelevant to the check.
    for param in opt.get_params(state).values():
        assert jnp.allclose(param, jnp.zeros(3))