def test_dynamic_supports():
    """A TransformReparam'd affine-scaled Uniform model must match the
    hand-scaled equivalent model: same optimizer params, same guide medians
    (modulo the `loc_base` reparam site) and the same ELBO loss."""
    true_coef = 0.9
    data = true_coef + random.normal(random.PRNGKey(0), (1000,))

    def actual_model(data):
        alpha = numpyro.sample("alpha", dist.Uniform(0, 1))
        with numpyro.handlers.reparam(config={"loc": TransformReparam()}):
            loc = numpyro.sample(
                "loc",
                dist.TransformedDistribution(
                    dist.Uniform(0, 1), transforms.AffineTransform(0, alpha)
                ),
            )
        with numpyro.plate("N", len(data)):
            numpyro.sample("obs", dist.Normal(loc, 0.1), obs=data)

    def expected_model(data):
        alpha = numpyro.sample("alpha", dist.Uniform(0, 1))
        loc = numpyro.sample("loc", dist.Uniform(0, 1)) * alpha
        with numpyro.plate("N", len(data)):
            numpyro.sample("obs", dist.Normal(loc, 0.1), obs=data)

    optimizer = optim.Adam(0.01)
    init_key = random.PRNGKey(1)

    def run_svi(model):
        # Fit an AutoDiagonalNormal guide and collect everything we compare on.
        guide = AutoDiagonalNormal(model)
        svi = SVI(model, guide, optimizer, Trace_ELBO())
        svi_state = svi.init(init_key, data)
        opt_params = optimizer.get_params(svi_state.optim_state)
        params = svi.get_params(svi_state)
        medians = guide.median(params)
        loss = svi.evaluate(svi_state, data)
        return opt_params, params, medians, loss

    actual_opt_params, actual_params, actual_values, actual_loss = run_svi(actual_model)
    expected_opt_params, expected_params, expected_values, expected_loss = run_svi(expected_model)

    # test auto_loc, auto_scale
    check_eq(actual_opt_params, expected_opt_params)
    check_eq(actual_params, expected_params)
    # test latent values: the reparam'd model exposes the base site "loc_base"
    assert_allclose(actual_values["alpha"], expected_values["alpha"])
    assert_allclose(actual_values["loc_base"], expected_values["loc"])
    assert_allclose(actual_loss, expected_loss)
def test_elbo_dynamic_support():
    """The ELBO evaluated at a pinned unconstrained latent must equal the
    hand-computed guide log-density minus the model log-density, including
    the change-of-variables log-det for the Uniform(0, 5) support."""
    x_prior = dist.Uniform(0, 5)
    x_unconstrained = 2.

    def model():
        numpyro.sample('x', x_prior)

    class _AutoGuide(AutoDiagonalNormal):
        # Pin the guide's latent sample to `x_unconstrained` so the loss is
        # deterministic and comparable with the manual computation below.
        def __call__(self, *args, **kwargs):
            return substitute(
                super(_AutoGuide, self).__call__,
                {'_auto_latent': x_unconstrained})(*args, **kwargs)

    guide = _AutoGuide(model)
    svi = SVI(model, guide, optim.Adam(0.01), AutoContinuousELBO())
    svi_state = svi.init(random.PRNGKey(0))
    actual_loss = svi.evaluate(svi_state)
    assert np.isfinite(actual_loss)

    # Manual ELBO: guide density at the pinned point ...
    guide_log_prob = dist.Normal(
        guide._init_latent, guide._init_scale).log_prob(x_unconstrained).sum()
    # ... minus the model density pushed to the constrained space.
    transform = transforms.biject_to(constraints.interval(0, 5))
    x = transform(x_unconstrained)
    logdet = transform.log_abs_det_jacobian(x_unconstrained, x)
    model_log_prob = x_prior.log_prob(x) + logdet
    expected_loss = guide_log_prob - model_log_prob
    assert_allclose(actual_loss, expected_loss, rtol=1e-6)
def test_elbo_dynamic_support():
    """With the guide's base sample substituted to a fixed value, the
    single-sample ELBO must equal guide log-prob minus prior log-prob at
    the corresponding constrained value."""
    x_prior = dist.TransformedDistribution(
        dist.Normal(),
        [AffineTransform(0, 2), SigmoidTransform(), AffineTransform(0, 3)])
    x_guide = dist.Uniform(0, 3)

    def model():
        numpyro.sample('x', x_prior)

    def guide():
        numpyro.sample('x', x_guide)

    # Pin the base value of x_guide to 0.9 so the loss is deterministic.
    x_base = 0.9
    guide = substitute(guide, base_param_map={'x': x_base})
    svi = SVI(model, guide, optim.Adam(0.01), ELBO())
    svi_state = svi.init(random.PRNGKey(0))
    actual_loss = svi.evaluate(svi_state)
    assert np.isfinite(actual_loss)

    # Map the base value through the guide's transform to constrained space.
    x, _ = x_guide.transform_with_intermediates(x_base)
    expected_loss = x_guide.log_prob(x) - x_prior.log_prob(x)
    assert_allclose(actual_loss, expected_loss)
def test_dynamic_supports():
    """A model with a dynamic support Uniform(0, alpha) must be equivalent
    to its manually rescaled counterpart under AutoDiagonalNormal SVI."""
    true_coef = 0.9
    data = true_coef + random.normal(random.PRNGKey(0), (1000,))

    def actual_model(data):
        alpha = numpyro.sample('alpha', dist.Uniform(0, 1))
        loc = numpyro.sample('loc', dist.Uniform(0, alpha))
        numpyro.sample('obs', dist.Normal(loc, 0.1), obs=data)

    def expected_model(data):
        alpha = numpyro.sample('alpha', dist.Uniform(0, 1))
        loc = numpyro.sample('loc', dist.Uniform(0, 1)) * alpha
        numpyro.sample('obs', dist.Normal(loc, 0.1), obs=data)

    optimizer = optim.Adam(0.01)
    init_key = random.PRNGKey(1)

    def run_svi(model):
        # Fit an AutoDiagonalNormal guide and collect everything we compare on.
        guide = AutoDiagonalNormal(model)
        svi = SVI(model, guide, optimizer, AutoContinuousELBO())
        svi_state = svi.init(init_key, data)
        opt_params = optimizer.get_params(svi_state.optim_state)
        params = svi.get_params(svi_state)
        medians = guide.median(params)
        loss = svi.evaluate(svi_state, data)
        return opt_params, params, medians, loss

    actual_opt_params, actual_params, actual_values, actual_loss = run_svi(actual_model)
    expected_opt_params, expected_params, expected_values, expected_loss = run_svi(expected_model)

    # test auto_loc, auto_scale
    check_eq(actual_opt_params, expected_opt_params)
    check_eq(actual_params, expected_params)
    # test latent values: loc in the actual model equals alpha * loc
    assert_allclose(actual_values['alpha'], expected_values['alpha'])
    assert_allclose(actual_values['loc'],
                    expected_values['alpha'] * expected_values['loc'])
    assert_allclose(actual_loss, expected_loss)
def test_param():
    """Validity of model/guide param sites whose constraints involve composed
    transforms: init must round-trip the given values, and the ELBO must
    equal guide log-prob minus model log-prob at the observation."""
    rng_keys = random.split(random.PRNGKey(0), 5)
    a_minval = 1
    c_minval = -2
    c_maxval = -1
    a_init = jnp.exp(random.normal(rng_keys[0])) + a_minval
    b_init = jnp.exp(random.normal(rng_keys[1]))
    c_init = random.uniform(rng_keys[2], minval=c_minval, maxval=c_maxval)
    d_init = random.uniform(rng_keys[3])
    obs = random.normal(rng_keys[4])

    def model():
        a = numpyro.param('a', a_init, constraint=constraints.greater_than(a_minval))
        b = numpyro.param('b', b_init, constraint=constraints.positive)
        numpyro.sample('x', dist.Normal(a, b), obs=obs)

    def guide():
        c = numpyro.param('c', c_init, constraint=constraints.interval(c_minval, c_maxval))
        d = numpyro.param('d', d_init, constraint=constraints.unit_interval)
        numpyro.sample('y', dist.Normal(c, d), obs=obs)

    svi = SVI(model, guide, optim.Adam(0.01), ELBO())
    svi_state = svi.init(random.PRNGKey(0))
    params = svi.get_params(svi_state)
    # init values must survive the constrain/unconstrain round trip
    for name, init_value in zip('abcd', (a_init, b_init, c_init, d_init)):
        assert_allclose(params[name], init_value)

    actual_loss = svi.evaluate(svi_state)
    assert jnp.isfinite(actual_loss)
    expected_loss = dist.Normal(c_init, d_init).log_prob(obs) - dist.Normal(
        a_init, b_init).log_prob(obs)
    # not exact because params go through transform / inverse-transform
    assert_allclose(actual_loss, expected_loss, rtol=1e-6)
def test_param():
    """Validity of model param sites with composed-transform constraints:
    init values survive, the auto-guide starts at its init latent/scale, and
    the ELBO at a pinned latent matches the manual computation."""
    rng_keys = random.split(random.PRNGKey(0), 3)
    a_minval = 1
    a_init = jnp.exp(random.normal(rng_keys[0])) + a_minval
    b_init = jnp.exp(random.normal(rng_keys[1]))
    x_init = random.normal(rng_keys[2])

    def model():
        a = numpyro.param("a", a_init, constraint=constraints.greater_than(a_minval))
        b = numpyro.param("b", b_init, constraint=constraints.positive)
        numpyro.sample("x", dist.Normal(a, b))

    class _AutoGuide(AutoDiagonalNormal):
        # Force the init value of latent `x` to x_init so the loss below is
        # deterministic and comparable with the manual computation.
        def __call__(self, *args, **kwargs):
            return substitute(
                super(_AutoGuide, self).__call__,
                {"_auto_latent": x_init[None]})(*args, **kwargs)

    guide = _AutoGuide(model)
    svi = SVI(model, guide, optim.Adam(0.01), Trace_ELBO())
    svi_state = svi.init(random.PRNGKey(1))
    params = svi.get_params(svi_state)
    assert_allclose(params["a"], a_init, rtol=1e-6)
    assert_allclose(params["b"], b_init, rtol=1e-6)
    assert_allclose(params["auto_loc"], guide._init_latent, rtol=1e-6)
    assert_allclose(params["auto_scale"], jnp.ones(1) * guide._init_scale,
                    rtol=1e-6)

    actual_loss = svi.evaluate(svi_state)
    assert jnp.isfinite(actual_loss)
    expected_loss = dist.Normal(
        guide._init_latent, guide._init_scale).log_prob(x_init) - dist.Normal(
        a_init, b_init).log_prob(x_init)
    assert_allclose(actual_loss, expected_loss, rtol=1e-6)
def test_elbo_dynamic_support():
    """With the guide's sample substituted to a fixed value, the ELBO must
    equal guide log-prob minus prior log-prob at that value."""
    x_prior = dist.TransformedDistribution(
        dist.Normal(),
        [AffineTransform(0, 2), SigmoidTransform(), AffineTransform(0, 3)])
    x_guide = dist.Uniform(0, 3)

    def model():
        numpyro.sample('x', x_prior)

    def guide():
        numpyro.sample('x', x_guide)

    # Pin the guide's sample so the single-particle ELBO is deterministic.
    fixed_x = 2.
    guide = substitute(guide, data={'x': fixed_x})
    svi = SVI(model, guide, optim.Adam(0.01), Trace_ELBO())
    svi_state = svi.init(random.PRNGKey(0))
    actual_loss = svi.evaluate(svi_state)
    assert jnp.isfinite(actual_loss)

    expected_loss = x_guide.log_prob(fixed_x) - x_prior.log_prob(fixed_x)
    assert_allclose(actual_loss, expected_loss)
class ModelHandler(object):
    """Bundle a numpyro model/guide pair with SVI training, state
    (de)serialization and prediction."""

    def __init__(self, model: Model, guide: Guide, rng_key: int = 0, *,
                 loss: ELBO = ELBO(num_particles=1),
                 optim_builder: optim.optimizers.optimizer = optim.Adam):
        """Handling the model and guide for training and prediction

        Args:
            model: function holding the numpyro model
            guide: function holding the numpyro guide
            rng_key: random key as int
            loss: loss to optimize
            optim_builder: builder for an optimizer
        """
        # NOTE(review): the `loss` default is one ELBO instance shared by all
        # handlers constructed without an explicit loss — harmless if ELBO is
        # stateless, worth confirming.
        self.model = model
        self.guide = guide
        self.rng_key = random.PRNGKey(rng_key)  # current random key
        self.loss = loss
        self.optim_builder = optim_builder
        self.svi = None
        self.svi_state = None
        self.optim = None
        self.log_func = print  # overwrite e.g. with logger.info(...)

    def reset_svi(self):
        """Reset the current SVI state"""
        self.svi = None
        self.svi_state = None
        return self

    def init_svi(self, X: DeviceArray, *, lr: float, **kwargs):
        """Initialize the SVI state

        Args:
            X: input data
            lr: learning rate
            kwargs: other keyword arguments for optimizer
        """
        self.optim = self.optim_builder(lr, **kwargs)
        self.svi = SVI(self.model, self.guide, self.optim, self.loss)
        svi_state = self.svi.init(self.rng_key, X)
        # keep previously accumulated training state; only adopt the freshly
        # initialized one when no state exists yet
        if self.svi_state is None:
            self.svi_state = svi_state
        return self

    @property
    def optim_state(self) -> OptimizerState:
        """Current optimizer state"""
        assert self.svi_state is not None, "'init_svi' needs to be called first"
        return self.svi_state.optim_state

    @optim_state.setter
    def optim_state(self, state: OptimizerState):
        """Set current optimizer state"""
        self.svi_state = SVIState(state, self.rng_key)

    def dump_optim_state(self, fh: IO):
        """Pickle and dump optimizer state to file handle"""
        # unpack to a pytree of plain arrays so the state pickles portably
        pickle.dump(
            optim.optimizers.unpack_optimizer_state(self.optim_state[1]), fh)
        return self

    def load_optim_state(self, fh: IO):
        """Read and unpickle optimizer state from file handle"""
        state = optim.optimizers.pack_optimizer_state(pickle.load(fh))
        iter0 = jnp.array(0)  # restored state restarts the step counter at 0
        self.optim_state = (iter0, state)
        return self

    @property
    def optim_total_steps(self) -> int:
        """Returns the number of performed iterations in total"""
        return int(self.optim_state[0])

    def _fit(self, X: DeviceArray, n_epochs) -> float:
        # Run `n_epochs` SVI updates inside one jitted fori_loop and return
        # the per-sample loss of the final epoch.
        @jit
        def train_epochs(svi_state, n_epochs):
            def train_one_epoch(_, carry):
                loss, svi_state = carry
                svi_state, loss = self.svi.update(svi_state, X)
                return loss, svi_state

            return lax.fori_loop(0, n_epochs, train_one_epoch, (0., svi_state))

        loss, self.svi_state = train_epochs(self.svi_state, n_epochs)
        return float(loss / X.shape[0])

    def _log(self, n_digits, epoch, loss):
        # Right-align the epoch number so log lines stay column-aligned.
        msg = f"epoch: {str(epoch).rjust(n_digits)} loss: {loss: 16.4f}"
        self.log_func(msg)

    def fit(self, X: DeviceArray, *, n_epochs: int, log_freq: int = 0,
            lr: float, **kwargs) -> float:
        """Train but log with a given frequency

        Args:
            X: input data
            n_epochs: total number of epochs
            log_freq: log loss every log_freq number of epochs
            lr: learning rate
            kwargs: parameters of `init_svi`

        Returns:
            final loss of last epoch
        """
        self.init_svi(X, lr=lr, **kwargs)
        if log_freq <= 0:
            self._fit(X, n_epochs)
        else:
            loss = self.svi.evaluate(self.svi_state, X) / X.shape[0]
            curr_epoch = 0
            n_digits = len(str(abs(n_epochs)))
            self._log(n_digits, curr_epoch, loss)
            # full log_freq-sized chunks ...
            for _ in range(n_epochs // log_freq):
                curr_epoch += log_freq
                loss = self._fit(X, log_freq)
                self._log(n_digits, curr_epoch, loss)
            # ... then the remainder, if any
            rest = n_epochs % log_freq
            if rest > 0:
                curr_epoch += rest
                loss = self._fit(X, rest)
                self._log(n_digits, curr_epoch, loss)
        loss = self.svi.evaluate(self.svi_state, X) / X.shape[0]
        self.rng_key = self.svi_state.rng_key  # carry the key forward
        return float(loss)

    @property
    def model_params(self) -> Optional[Dict[str, DeviceArray]]:
        """Gets model parameters

        Returns:
            dict of model parameters, or None before `init_svi` was called
        """
        if self.svi is not None:
            return self.svi.get_params(self.svi_state)
        return None

    def predict(self, X: DeviceArray, **kwargs) -> DeviceArray:
        """Predict the parameters of a model specified by `return_sites`

        Args:
            X: input data
            kwargs: keyword arguments for numpyro `Predictive`

        Returns:
            samples for all sample sites
        """
        self.init_svi(X, lr=0.)  # dummy initialization (no-op if trained)
        predictive = Predictive(self.model, guide=self.guide,
                                params=self.model_params, **kwargs)
        return predictive(self.rng_key, X)