Example #1
0
def test_dynamic_supports():
    """Compare a reparameterized scaled-Uniform site with a manually scaled one.

    ``actual_model`` draws ``loc`` from ``Uniform(0, 1)`` pushed through
    ``AffineTransform(0, alpha)`` under ``TransformReparam``;
    ``expected_model`` multiplies a plain ``Uniform(0, 1)`` draw by ``alpha``.
    Both are initialized with the same optimizer and PRNG key, and their
    optimizer parameters, SVI parameters, guide medians and evaluated losses
    are asserted to agree.
    """
    true_coef = 0.9
    data = true_coef + random.normal(random.PRNGKey(0), (1000,))

    def actual_model(data):
        alpha = numpyro.sample("alpha", dist.Uniform(0, 1))
        with numpyro.handlers.reparam(config={"loc": TransformReparam()}):
            scaled_uniform = dist.TransformedDistribution(
                dist.Uniform(0, 1), transforms.AffineTransform(0, alpha)
            )
            loc = numpyro.sample("loc", scaled_uniform)
        with numpyro.plate("N", len(data)):
            numpyro.sample("obs", dist.Normal(loc, 0.1), obs=data)

    def expected_model(data):
        alpha = numpyro.sample("alpha", dist.Uniform(0, 1))
        base = numpyro.sample("loc", dist.Uniform(0, 1))
        loc = base * alpha
        with numpyro.plate("N", len(data)):
            numpyro.sample("obs", dist.Normal(loc, 0.1), obs=data)

    optimizer = optim.Adam(0.01)
    init_key = random.PRNGKey(1)

    def snapshot(model):
        """Initialize SVI for ``model`` and collect the quantities under test."""
        guide = AutoDiagonalNormal(model)
        svi = SVI(model, guide, optimizer, Trace_ELBO())
        state = svi.init(init_key, data)
        opt_params = optimizer.get_params(state.optim_state)
        params = svi.get_params(state)
        medians = guide.median(params)
        loss = svi.evaluate(state, data)
        return opt_params, params, medians, loss

    # The actual model is processed first, then the expected one.
    (actual_opt_params, actual_params,
     actual_values, actual_loss) = snapshot(actual_model)
    (expected_opt_params, expected_params,
     expected_values, expected_loss) = snapshot(expected_model)

    # auto_loc / auto_scale must be identical for both models
    check_eq(actual_opt_params, expected_opt_params)
    check_eq(actual_params, expected_params)

    # latent values: the reparameterized site is read back as "loc_base"
    assert_allclose(actual_values["alpha"], expected_values["alpha"])
    assert_allclose(actual_values["loc_base"], expected_values["loc"])
    assert_allclose(actual_loss, expected_loss)
def test_elbo_dynamic_support():
    """Check the evaluated ELBO loss for a ``Uniform(0, 5)`` prior with an autoguide.

    The guide's ``_auto_latent`` site is substituted with a fixed unconstrained
    value so the loss can be recomputed by hand from the guide's initial
    location/scale, the prior log-density and the log-abs-det-Jacobian of the
    ``interval(0, 5)`` bijection.
    """
    x_prior = dist.Uniform(0, 5)
    x_unconstrained = 2.

    def model():
        numpyro.sample('x', x_prior)

    class _AutoGuide(AutoDiagonalNormal):
        # Forces the latent draw to ``x_unconstrained`` on every call.
        def __call__(self, *args, **kwargs):
            pinned_call = substitute(
                super().__call__, {'_auto_latent': x_unconstrained})
            return pinned_call(*args, **kwargs)

    optimizer = optim.Adam(0.01)
    guide = _AutoGuide(model)
    svi = SVI(model, guide, optimizer, AutoContinuousELBO())
    state = svi.init(random.PRNGKey(0))
    actual_loss = svi.evaluate(state)
    assert np.isfinite(actual_loss)

    # Guide term: Normal log-density of the pinned unconstrained value.
    guide_normal = dist.Normal(guide._init_latent, guide._init_scale)
    guide_log_prob = guide_normal.log_prob(x_unconstrained).sum()

    # Model term: prior log-density at the constrained value plus the
    # Jacobian correction of the unconstrained -> (0, 5) map.
    to_support = transforms.biject_to(constraints.interval(0, 5))
    x = to_support(x_unconstrained)
    logdet = to_support.log_abs_det_jacobian(x_unconstrained, x)
    model_log_prob = x_prior.log_prob(x) + logdet

    expected_loss = guide_log_prob - model_log_prob
    assert_allclose(actual_loss, expected_loss, rtol=1e-6)
Example #3
0
def test_elbo_dynamic_support():
    """Check the evaluated ELBO loss when model and guide use different distributions.

    The model samples ``x`` from a ``Normal`` pushed through
    ``AffineTransform(0, 2)``, ``SigmoidTransform`` and ``AffineTransform(0, 3)``;
    the guide samples ``x`` from ``Uniform(0, 3)`` with its base value pinned
    through ``substitute(..., base_param_map=...)``.
    """
    prior_transforms = [
        AffineTransform(0, 2),
        SigmoidTransform(),
        AffineTransform(0, 3),
    ]
    x_prior = dist.TransformedDistribution(dist.Normal(), prior_transforms)
    x_guide = dist.Uniform(0, 3)

    def model():
        numpyro.sample('x', x_prior)

    def guide():
        numpyro.sample('x', x_guide)

    optimizer = optim.Adam(0.01)
    # pin the base value of the guide's `x` site to 0.9
    x_base = 0.9
    pinned_guide = substitute(guide, base_param_map={'x': x_base})
    svi = SVI(model, pinned_guide, optimizer, ELBO())
    state = svi.init(random.PRNGKey(0))
    actual_loss = svi.evaluate(state)
    assert np.isfinite(actual_loss)

    # Map the base value to the guide's support, then rebuild the loss by hand.
    x, _intermediates = x_guide.transform_with_intermediates(x_base)
    guide_log_prob = x_guide.log_prob(x)
    model_log_prob = x_prior.log_prob(x)
    expected_loss = guide_log_prob - model_log_prob
    assert_allclose(actual_loss, expected_loss)
def test_dynamic_supports():
    """Compare a ``Uniform(0, alpha)`` site with a manually scaled ``Uniform(0, 1)``.

    ``actual_model`` samples ``loc`` from ``Uniform(0, alpha)`` directly, while
    ``expected_model`` samples from ``Uniform(0, 1)`` and multiplies by
    ``alpha``. Both are initialized with the same optimizer and PRNG key; the
    optimizer parameters, SVI parameters and losses must match, and the actual
    ``loc`` median must equal the expected ``alpha * loc``.
    """
    true_coef = 0.9
    data = true_coef + random.normal(random.PRNGKey(0), (1000, ))

    def actual_model(data):
        alpha = numpyro.sample('alpha', dist.Uniform(0, 1))
        loc = numpyro.sample('loc', dist.Uniform(0, alpha))
        numpyro.sample('obs', dist.Normal(loc, 0.1), obs=data)

    def expected_model(data):
        alpha = numpyro.sample('alpha', dist.Uniform(0, 1))
        base = numpyro.sample('loc', dist.Uniform(0, 1))
        loc = base * alpha
        numpyro.sample('obs', dist.Normal(loc, 0.1), obs=data)

    optimizer = optim.Adam(0.01)
    init_key = random.PRNGKey(1)

    def snapshot(model):
        """Initialize SVI for ``model`` and collect the quantities under test."""
        guide = AutoDiagonalNormal(model)
        svi = SVI(model, guide, optimizer, AutoContinuousELBO())
        state = svi.init(init_key, data)
        opt_params = optimizer.get_params(state.optim_state)
        params = svi.get_params(state)
        medians = guide.median(params)
        loss = svi.evaluate(state, data)
        return opt_params, params, medians, loss

    # The actual model is processed first, then the expected one.
    (actual_opt_params, actual_params,
     actual_values, actual_loss) = snapshot(actual_model)
    (expected_opt_params, expected_params,
     expected_values, expected_loss) = snapshot(expected_model)

    # auto_loc / auto_scale must be identical for both models
    check_eq(actual_opt_params, expected_opt_params)
    check_eq(actual_params, expected_params)

    # latent values: the expected `loc` has to be rescaled by `alpha`
    assert_allclose(actual_values['alpha'], expected_values['alpha'])
    rescaled_loc = expected_values['alpha'] * expected_values['loc']
    assert_allclose(actual_values['loc'], rescaled_loc)
    assert_allclose(actual_loss, expected_loss)
Example #5
0
def test_param():
    """Check ``numpyro.param`` sites with constraints in both model and guide.

    Verifies that SVI reports the initial values for parameters declared with
    ``greater_than``, ``positive``, ``interval`` and ``unit_interval``
    constraints, and that the evaluated loss equals the hand-computed
    difference of the guide and model log-densities at the observation.
    """
    rng_keys = random.split(random.PRNGKey(0), 5)
    a_minval = 1
    c_minval, c_maxval = -2, -1
    a_init = jnp.exp(random.normal(rng_keys[0])) + a_minval
    b_init = jnp.exp(random.normal(rng_keys[1]))
    c_init = random.uniform(rng_keys[2], minval=c_minval, maxval=c_maxval)
    d_init = random.uniform(rng_keys[3])
    obs = random.normal(rng_keys[4])

    def model():
        loc = numpyro.param(
            'a', a_init, constraint=constraints.greater_than(a_minval))
        scale = numpyro.param('b', b_init, constraint=constraints.positive)
        numpyro.sample('x', dist.Normal(loc, scale), obs=obs)

    def guide():
        loc = numpyro.param(
            'c', c_init, constraint=constraints.interval(c_minval, c_maxval))
        scale = numpyro.param(
            'd', d_init, constraint=constraints.unit_interval)
        numpyro.sample('y', dist.Normal(loc, scale), obs=obs)

    optimizer = optim.Adam(0.01)
    svi = SVI(model, guide, optimizer, ELBO())
    state = svi.init(random.PRNGKey(0))

    # Every parameter must come back at its initial (constrained) value.
    params = svi.get_params(state)
    expected_inits = (('a', a_init), ('b', b_init),
                      ('c', c_init), ('d', d_init))
    for site_name, init_value in expected_inits:
        assert_allclose(params[site_name], init_value)

    actual_loss = svi.evaluate(state)
    assert jnp.isfinite(actual_loss)
    guide_log_prob = dist.Normal(c_init, d_init).log_prob(obs)
    model_log_prob = dist.Normal(a_init, b_init).log_prob(obs)
    expected_loss = guide_log_prob - model_log_prob
    # loose tolerance: values pass through transform / inverse transform
    assert_allclose(actual_loss, expected_loss, rtol=1e-6)
Example #6
0
def test_param():
    """Check constrained ``numpyro.param`` sites in a model used with an autoguide.

    The autoguide's ``_auto_latent`` site is substituted with ``x_init`` so the
    evaluated loss can be compared with the hand-computed difference between
    the guide's initial Normal log-density and the model's Normal log-density
    at ``x_init``.
    """
    rng_keys = random.split(random.PRNGKey(0), 3)
    a_minval = 1
    a_init = jnp.exp(random.normal(rng_keys[0])) + a_minval
    b_init = jnp.exp(random.normal(rng_keys[1]))
    x_init = random.normal(rng_keys[2])

    def model():
        loc = numpyro.param(
            "a", a_init, constraint=constraints.greater_than(a_minval))
        scale = numpyro.param("b", b_init, constraint=constraints.positive)
        numpyro.sample("x", dist.Normal(loc, scale))

    class _AutoGuide(AutoDiagonalNormal):
        # Forces the latent draw for `x` to `x_init` on every call.
        def __call__(self, *args, **kwargs):
            pinned_call = substitute(
                super().__call__, {"_auto_latent": x_init[None]})
            return pinned_call(*args, **kwargs)

    optimizer = optim.Adam(0.01)
    init_key = random.PRNGKey(1)
    guide = _AutoGuide(model)
    svi = SVI(model, guide, optimizer, Trace_ELBO())
    state = svi.init(init_key)

    # Model params and autoguide params must come back at their initial values.
    params = svi.get_params(state)
    assert_allclose(params["a"], a_init, rtol=1e-6)
    assert_allclose(params["b"], b_init, rtol=1e-6)
    assert_allclose(params["auto_loc"], guide._init_latent, rtol=1e-6)
    expected_scale = jnp.ones(1) * guide._init_scale
    assert_allclose(params["auto_scale"], expected_scale, rtol=1e-6)

    actual_loss = svi.evaluate(state)
    assert jnp.isfinite(actual_loss)
    guide_normal = dist.Normal(guide._init_latent, guide._init_scale)
    guide_log_prob = guide_normal.log_prob(x_init)
    model_log_prob = dist.Normal(a_init, b_init).log_prob(x_init)
    expected_loss = guide_log_prob - model_log_prob
    assert_allclose(actual_loss, expected_loss, rtol=1e-6)
Example #7
0
def test_elbo_dynamic_support():
    """Check the ``Trace_ELBO`` loss when model and guide use different distributions.

    The model samples ``x`` from a ``Normal`` pushed through
    ``AffineTransform(0, 2)``, ``SigmoidTransform`` and ``AffineTransform(0, 3)``;
    the guide samples ``x`` from ``Uniform(0, 3)`` with the value pinned to 2
    through ``substitute``. The loss must equal the guide log-density minus the
    prior log-density at that value.
    """
    prior_transforms = [
        AffineTransform(0, 2),
        SigmoidTransform(),
        AffineTransform(0, 3),
    ]
    x_prior = dist.TransformedDistribution(dist.Normal(), prior_transforms)
    x_guide = dist.Uniform(0, 3)

    def model():
        numpyro.sample('x', x_prior)

    def guide():
        numpyro.sample('x', x_guide)

    optimizer = optim.Adam(0.01)
    x = 2.
    pinned_guide = substitute(guide, data={'x': x})
    svi = SVI(model, pinned_guide, optimizer, Trace_ELBO())
    state = svi.init(random.PRNGKey(0))

    actual_loss = svi.evaluate(state)
    assert jnp.isfinite(actual_loss)

    guide_log_prob = x_guide.log_prob(x)
    model_log_prob = x_prior.log_prob(x)
    expected_loss = guide_log_prob - model_log_prob
    assert_allclose(actual_loss, expected_loss)
Example #8
0
class ModelHandler(object):
    """Couple a numpyro model/guide pair with SVI training and prediction helpers.

    The handler keeps the current SVI object, SVI state, optimizer and PRNG key
    as attributes; most methods return ``self`` so calls can be chained.
    """

    def __init__(self,
                 model: Model,
                 guide: Guide,
                 rng_key: int = 0,
                 *,
                 loss: ELBO = ELBO(num_particles=1),
                 optim_builder: optim.optimizers.optimizer = optim.Adam):
        """Store the model, guide and training configuration.

        Args:
            model: function holding the numpyro model
            guide: function holding the numpyro guide
            rng_key: integer seed turned into a PRNG key
            loss: loss to optimize
            optim_builder: callable building an optimizer from a learning
                rate (plus optional keyword arguments)
        """
        # user-supplied configuration
        self.model = model
        self.guide = guide
        self.loss = loss
        self.optim_builder = optim_builder
        self.rng_key = random.PRNGKey(rng_key)  # current random key

        # populated by `init_svi`
        self.svi = None
        self.svi_state = None
        self.optim = None

        self.log_func = print  # overwrite e.g. logger.info(...)

    def reset_svi(self):
        """Drop the current SVI object and SVI state; returns ``self``."""
        self.svi = self.svi_state = None
        return self

    def init_svi(self, X: DeviceArray, *, lr: float, **kwargs):
        """Build a fresh optimizer and SVI object, keeping any existing state.

        Args:
            X: input data
            lr: learning rate
            kwargs: other keyword arguments for the optimizer builder

        Returns:
            ``self``
        """
        self.optim = self.optim_builder(lr, **kwargs)
        self.svi = SVI(self.model, self.guide, self.optim, self.loss)
        # NOTE(review): `init` runs even when a state already exists —
        # presumably it also prepares the new SVI object; confirm before
        # making this call conditional.
        fresh_state = self.svi.init(self.rng_key, X)
        if self.svi_state is None:
            self.svi_state = fresh_state
        return self

    @property
    def optim_state(self) -> OptimizerState:
        """Optimizer state held by the current SVI state."""
        current = self.svi_state
        assert current is not None, "'init_svi' needs to be called first"
        return current.optim_state

    @optim_state.setter
    def optim_state(self, state: OptimizerState):
        """Replace the SVI state with one built from ``state`` and the current key."""
        self.svi_state = SVIState(state, self.rng_key)

    def dump_optim_state(self, fh: IO):
        """Pickle the unpacked optimizer state to ``fh``; returns ``self``."""
        packed = self.optim_state[1]
        unpacked = optim.optimizers.unpack_optimizer_state(packed)
        pickle.dump(unpacked, fh)
        return self

    def load_optim_state(self, fh: IO):
        """Unpickle an optimizer state from ``fh`` and install it; returns ``self``.

        The iteration counter of the installed state starts at zero.
        """
        # NOTE: `pickle.load` can execute arbitrary code; only use it on
        # files from a trusted source.
        unpacked = pickle.load(fh)
        packed = optim.optimizers.pack_optimizer_state(unpacked)
        self.optim_state = (jnp.array(0), packed)
        return self

    @property
    def optim_total_steps(self) -> int:
        """Iteration counter stored in the optimizer state, as an ``int``."""
        step_counter = self.optim_state[0]
        return int(step_counter)

    def _fit(self, X: DeviceArray, n_epochs) -> float:
        """Run ``n_epochs`` SVI updates on ``X`` and store the resulting state.

        Returns:
            loss of the last update divided by ``X.shape[0]`` (``0.0`` when no
            update was run)
        """
        def one_epoch(_, carry):
            # only the state is carried forward; the previous loss is dropped
            _, state = carry
            state, epoch_loss = self.svi.update(state, X)
            return epoch_loss, state

        @jit
        def run_epochs(state, count):
            return lax.fori_loop(0, count, one_epoch, (0., state))

        last_loss, self.svi_state = run_epochs(self.svi_state, n_epochs)
        return float(last_loss / X.shape[0])

    def _log(self, n_digits, epoch, loss):
        """Send a formatted epoch/loss line to ``self.log_func``."""
        self.log_func(
            f"epoch: {str(epoch).rjust(n_digits)} loss: {loss: 16.4f}")

    def fit(self,
            X: DeviceArray,
            *,
            n_epochs: int,
            log_freq: int = 0,
            lr: float,
            **kwargs) -> float:
        """Train the model, optionally logging the loss at a given frequency.

        Args:
            X: input data
            n_epochs: total number of epochs
            log_freq: log the loss every ``log_freq`` epochs; values ``<= 0``
                disable logging
            lr: learning rate
            kwargs: parameters of `init_svi`

        Returns:
            loss evaluated after the last epoch, divided by ``X.shape[0]``
        """
        self.init_svi(X, lr=lr, **kwargs)
        if log_freq <= 0:
            # no logging requested: train in a single stretch
            self._fit(X, n_epochs)
        else:
            loss = self.svi.evaluate(self.svi_state, X) / X.shape[0]

            curr_epoch = 0
            n_digits = len(str(abs(n_epochs)))
            self._log(n_digits, curr_epoch, loss)

            # full stretches of `log_freq` epochs, then the remainder (if any)
            n_full, rest = divmod(n_epochs, log_freq)
            stretches = [log_freq] * n_full
            if rest > 0:
                stretches.append(rest)

            for stretch in stretches:
                curr_epoch += stretch
                loss = self._fit(X, stretch)
                self._log(n_digits, curr_epoch, loss)

        final_loss = self.svi.evaluate(self.svi_state, X) / X.shape[0]
        # continue from the key stored in the SVI state on later calls
        self.rng_key = self.svi_state.rng_key
        return float(final_loss)

    @property
    def model_params(self) -> Optional[Dict[str, DeviceArray]]:
        """Parameters of the current SVI state.

        Returns:
            dict of model parameters, or ``None`` when no SVI object exists yet
        """
        if self.svi is None:
            return None
        return self.svi.get_params(self.svi_state)

    def predict(self, X: DeviceArray, **kwargs) -> DeviceArray:
        """Draw samples with numpyro's `Predictive` using the current parameters.

        Args:
            X: input data
            kwargs: keyword arguments for numpyro `Predictive`

        Returns:
            samples returned by the `Predictive` call
        """
        self.init_svi(X, lr=0.)  # dummy initialization
        predictive = Predictive(self.model,
                                guide=self.guide,
                                params=self.model_params,
                                **kwargs)
        return predictive(self.rng_key, X)