Example #1
0
    def __init__(self, model, guide, optim, loss, **static_kwargs):
        """Store the model/guide pair and normalize ``optim`` into a
        ``_NumPyroOptim`` wrapper.

        :param model: the probabilistic model.
        :param guide: the variational guide for the model.
        :param optim: one of ``numpyro.optim._NumPyroOptim``,
            ``jax.example_libraries.optimizers.Optimizer``, or (if Optax is
            installed) ``optax.GradientTransformation``.
        :param loss: the loss to minimize (presumably an ELBO instance —
            only stored here, not called).
        :param static_kwargs: static keyword arguments stored for later use.
        :raises ImportError: if ``optim`` is not a recognized optimizer type
            and Optax is not installed.
        :raises TypeError: if ``optim`` is none of the supported types.
        """
        self.model = model
        self.guide = guide
        self.loss = loss
        self.static_kwargs = static_kwargs
        self.constrain_fn = None

        if isinstance(optim, _NumPyroOptim):
            # Already in the form SVI expects; keep as-is.
            wrapped = optim
        elif isinstance(optim, optimizers.Optimizer):
            # A jax.example_libraries optimizer is a triple of functions;
            # unpack it straight into the wrapper.
            wrapped = _NumPyroOptim(lambda *args: args, *optim)
        else:
            # Last resort: experimental Optax support, which needs the
            # optional optax dependency.
            try:
                import optax
            except ImportError:
                raise ImportError(
                    "It looks like you tried to use an optimizer that isn't an "
                    "instance of numpyro.optim._NumPyroOptim or "
                    "jax.example_libraries.optimizers.Optimizer. There is experimental "
                    "support for Optax optimizers, but you need to install Optax. "
                    "It can be installed with `pip install optax`.")

            if not isinstance(optim, optax.GradientTransformation):
                raise TypeError(
                    "Expected either an instance of numpyro.optim._NumPyroOptim, "
                    "jax.example_libraries.optimizers.Optimizer or "
                    "optax.GradientTransformation. Got {}".format(type(optim)))

            wrapped = optax_to_numpyro(optim)

        self.optim = wrapped
Example #2
0
File: optim.py  Project: fehiepsi/numpyro
def optax_to_numpyro(
        transformation: optax.GradientTransformation) -> _NumPyroOptim:
    """
    This function produces a ``numpyro.optim._NumPyroOptim`` instance from an
    ``optax.GradientTransformation`` so that it can be used with
    ``numpyro.infer.svi.SVI``. It is a lightweight wrapper that recreates the
    ``(init_fn, update_fn, get_params_fn)`` interface defined by
    :mod:`jax.example_libraries.optimizers`.

    :param transformation: An ``optax.GradientTransformation`` instance to wrap.
    :return: An instance of ``numpyro.optim._NumPyroOptim`` wrapping the supplied
        Optax optimizer.
    """
    def init_fn(params: _Params) -> _State:
        # Bundle the params together with Optax's own optimizer state so the
        # pair can be threaded through SVI as one opaque state object.
        opt_state = transformation.init(params)
        return params, opt_state

    def update_fn(step, grads: _Params, state: _State) -> _State:
        # ``step`` is accepted for interface compatibility but unused: Optax
        # transformations keep any step counters inside ``opt_state``.
        params, opt_state = state
        updates, opt_state = transformation.update(grads, opt_state, params)
        updated_params = optax.apply_updates(params, updates)
        return updated_params, opt_state

    def get_params_fn(state: _State) -> _Params:
        # Discard the Optax state and expose just the current parameters.
        params, _ = state
        return params

    # ``_NumPyroOptim`` expects a factory that returns the three functions;
    # the identity lambda simply passes them through unchanged.
    return _NumPyroOptim(lambda x, y, z: (x, y, z), init_fn, update_fn,
                         get_params_fn)