Exemplo n.º 1
0
    def __init__(self, model, guide, optim, loss, **static_kwargs):
        """Store the inference components and normalize ``optim``.

        :param model: stored unchanged as ``self.model``.
        :param guide: stored unchanged as ``self.guide``.
        :param optim: the optimizer, in one of three accepted forms:

            * a ``_NumPyroOptim`` instance, which is used as is;
            * an ``optimizers.Optimizer`` (named
              ``jax.example_libraries.optimizers.Optimizer`` in the error
              messages below), whose elements are unpacked into a new
              ``_NumPyroOptim``;
            * an ``optax.GradientTransformation``, which is converted with
              ``optax_to_numpyro``.
        :param loss: stored unchanged as ``self.loss``.
        :param static_kwargs: extra keyword arguments, kept as a dict in
            ``self.static_kwargs``.
        :raises ImportError: if ``optim`` is neither of the first two forms
            and the Optax support modules cannot be imported. The original
            import failure is attached as ``__cause__``.
        :raises TypeError: if ``optim`` is none of the accepted forms.
        """
        self.model = model
        self.guide = guide
        self.loss = loss
        self.static_kwargs = static_kwargs
        self.constrain_fn = None

        if isinstance(optim, _NumPyroOptim):
            self.optim = optim
        elif isinstance(optim, optimizers.Optimizer):
            # NOTE(review): presumably `_NumPyroOptim` calls its first argument
            # with the remaining ones; the pass-through lambda would then hand
            # the already-built optimizer's elements back unchanged -- confirm
            # against `_NumPyroOptim`.
            self.optim = _NumPyroOptim(lambda *args: args, *optim)
        else:
            # Optax is imported lazily so it is only needed when an optimizer
            # of neither of the two types above is passed in.
            try:
                import optax

                from numpyro.contrib.optim import optax_to_numpyro
            except ImportError as err:
                # Chain explicitly: either of the two imports above may be the
                # one that failed, and the cause tells the user which.
                raise ImportError(
                    "It looks like you tried to use an optimizer that isn't an "
                    "instance of numpyro.optim._NumPyroOptim or "
                    "jax.example_libraries.optimizers.Optimizer. There is experimental "
                    "support for Optax optimizers, but you need to install Optax. "
                    "It can be installed with `pip install optax`."
                ) from err

            if not isinstance(optim, optax.GradientTransformation):
                raise TypeError(
                    "Expected either an instance of numpyro.optim._NumPyroOptim, "
                    "jax.example_libraries.optimizers.Optimizer or "
                    "optax.GradientTransformation. Got {}".format(type(optim))
                )

            self.optim = optax_to_numpyro(optim)
Exemplo n.º 2
0
def test_optim_multi_params(optim_class, args, kwargs):
    """Check that the wrapped optimizer brings every parameter close to zero.

    Builds ``optim_class(*args, **kwargs)``, wraps it with
    ``optax_to_numpyro``, applies ``step`` 2000 times starting from two
    length-3 parameter vectors, and asserts that each resulting parameter is
    close to the zero vector.
    """
    initial_params = {
        "x": jnp.array([1.0, 1.0, 1.0]),
        "y": jnp.array([-1, -1.0, -1.0]),
    }
    wrapped_optim = optax_to_numpyro(optim_class(*args, **kwargs))

    state = wrapped_optim.init(initial_params)
    for _ in range(2000):
        state = step(state, wrapped_optim)

    final_params = wrapped_optim.get_params(state)
    expected = jnp.zeros(3)
    # Only the values matter here; the parameter names are not checked.
    for value in final_params.values():
        assert jnp.allclose(value, expected)