def _setup_prototype(self, *args, **kwargs):
    """Trace one execution of ``self.model`` and remember how it was called.

    A PRNG key is drawn under the sample-site name
    ``"_<prefix>_rng_key_setup"`` and used to seed the model. The seeded
    model is then traced via ``handlers.trace(...).get_trace``, wrapped in
    ``handlers.block``. The resulting trace is stored as
    ``self.prototype_trace``; the call arguments are stored as
    ``self._args`` / ``self._kwargs``.
    """
    # Run the model once so its structure (sites, shapes, plates) can be
    # inspected by the rest of the guide.
    setup_site = "_{}_rng_key_setup".format(self.prefix)
    setup_key = numpyro.sample(setup_site, dist.PRNGIdentity())
    seeded_model = handlers.seed(self.model, setup_key)
    blocked_get_trace = handlers.block(handlers.trace(seeded_model).get_trace)
    self.prototype_trace = blocked_get_trace(*args, **kwargs)
    self._args, self._kwargs = args, kwargs
def _setup_prototype(self, *args, **kwargs):
    """Build the flat latent representation used by continuous guides.

    After the parent ``_setup_prototype`` has run, this method sets:

    * ``self._inv_transforms`` -- for each unobserved sample site, the
      transform ``biject_to`` returns for the site's support. When the site
      recorded ``intermediates``, the support of ``site['fn'].base_dist`` is
      used instead.
    * ``self._has_transformed_dist`` -- True when any unobserved sample site
      recorded ``intermediates``.
    * ``self._init_latent`` / ``self._unpack_latent`` -- valid initial
      parameters (from ``find_valid_initial_params``) raveled into one flat
      array, and the function that unravels such an array.
    * ``self.latent_size`` -- number of elements in ``self._init_latent``.
    * ``self.base_dist`` -- if it is still None, an ``Independent`` standard
      Normal over ``latent_size`` elements.

    Raises:
        RuntimeError: if the model has no latent variables.
    """
    super(AutoContinuous, self)._setup_prototype(*args, **kwargs)
    rng_key = numpyro.sample("_{}_rng_key_init".format(self.prefix), dist.PRNGIdentity())
    init_params, _ = handlers.block(find_valid_initial_params)(
        rng_key, self.model,
        init_strategy=self.init_strategy,
        model_args=args,
        model_kwargs=kwargs)

    self._inv_transforms = {}
    self._has_transformed_dist = False
    for name, site in self.prototype_trace.items():
        # Only latent (unobserved) sample sites need a bijection.
        if site['type'] != 'sample' or site['is_observed']:
            continue
        if site['intermediates']:
            # Transformed distribution: biject onto the base distribution's
            # support rather than the transformed one.
            transform = biject_to(site['fn'].base_dist.support)
            self._has_transformed_dist = True
        else:
            transform = biject_to(site['fn'].support)
        self._inv_transforms[name] = transform

    self._init_latent, self._unpack_latent = ravel_pytree(init_params)
    self.latent_size = np.size(self._init_latent)
    # Fail fast, before constructing a zero-dimensional base distribution.
    if self.latent_size == 0:
        raise RuntimeError(
            '{} found no latent variables; Use an empty guide instead'.
            format(type(self).__name__))
    if self.base_dist is None:
        self.base_dist = dist.Independent(
            dist.Normal(np.zeros(self.latent_size), 1.), 1)
def run(self, *args, rng_key=None, **kwargs):
    """Run the wrapped MCMC sampler starting from ``self._initial_params``.

    When ``rng_key`` is omitted, a key is drawn with
    ``numpyro.sample('mcmc.run', dist.PRNGIdentity())``. The key, then all
    positional arguments, are forwarded to ``self._mcmc.run`` together with
    ``init_params=self._initial_params`` and any extra keyword arguments.
    """
    key = numpyro.sample('mcmc.run', dist.PRNGIdentity()) if rng_key is None else rng_key
    sampler = self._mcmc
    sampler.run(key, *args, init_params=self._initial_params, **kwargs)
def step(self, *args, rng_key=None, **kwargs):
    """Take one SVI update step and return its loss.

    On the first call (``self.svi_state is None``) the state is created with
    ``self.init``. ``rng_key`` is only consulted at that point; if it is
    omitted, a key is drawn with ``numpyro.sample('svi.init', ...)``. The
    update runs through ``jit(self.update)``. The resulting parameters are
    written into the store returned by ``get_param_store()``.

    Raises:
        TypeError: a clearer message (chained to the original error) when
            JAX reports an argument as "not a valid JAX type". Any other
            TypeError propagates unchanged.
    """
    if self.svi_state is None:
        if rng_key is None:
            rng_key = numpyro.sample('svi.init', dist.PRNGIdentity())
        self.svi_state = self.init(rng_key, *args, **kwargs)
    try:
        self.svi_state, loss = jit(self.update)(self.svi_state, *args, **kwargs)
    except TypeError as e:
        if 'not a valid JAX type' in str(e):
            # Chain explicitly so JAX's original diagnostic stays attached.
            raise TypeError(
                'NumPyro backend requires args, kwargs to be arrays or tuples, '
                'dicts of arrays.') from e
        # Bare ``raise`` re-raises with the original traceback untouched.
        raise
    params = jit(super(SVI, self).get_params)(self.svi_state)
    get_param_store().update(params)
    return loss
def _setup_prototype(self, *args, **kwargs):
    """Initialise the model and derive its flat latent layout.

    ``initialize_model`` is called inside ``handlers.block()`` with a PRNG
    key drawn under ``"_<prefix>_rng_key_setup"``. It supplies
    ``self._postprocess_fn`` and ``self.prototype_trace``. The first element
    of the returned initial params is raveled into ``self._init_latent``.
    The matching unravel function is wrapped in ``UnpackTransform`` as
    ``self._unpack_latent``, and ``self.latent_dim`` is the flat size.

    Raises:
        RuntimeError: if the model has no latent variables.
    """
    setup_key = numpyro.sample("_{}_rng_key_setup".format(self.prefix), dist.PRNGIdentity())
    with handlers.block():
        (initial, _,
         self._postprocess_fn,
         self.prototype_trace) = initialize_model(
            setup_key,
            self.model,
            init_strategy=self.init_strategy,
            dynamic_args=False,
            model_args=args,
            model_kwargs=kwargs,
        )
    flat_init, unravel = ravel_pytree(initial[0])
    self._init_latent = flat_init
    # UnpackTransform mirrors Pyro, where the unpack function can be applied
    # to a batch of latent samples.
    self._unpack_latent = UnpackTransform(unravel)
    self.latent_dim = jnp.size(flat_init)
    if self.latent_dim == 0:
        raise RuntimeError('{} found no latent variables; Use an empty guide instead'
                           .format(type(self).__name__))
def _setup_prototype(self, *args, **kwargs):
    """Initialise the model and record its plate structure.

    ``initialize_model`` is called inside ``handlers.block()`` with a PRNG
    key drawn under ``"_<prefix>_rng_key_setup"`` and
    ``init_strategy=self.init_loc_fn``. Afterwards this method sets:

    * ``self._postprocess_fn`` and ``self.prototype_trace`` -- as returned by
      ``initialize_model``.
    * ``self._init_locs`` -- the first element of the returned initial params.
    * ``self._prototype_frames`` -- frame name -> frame, collected from the
      ``cond_indep_stack`` of every sample site.
    * ``self._prototype_plate_sizes`` (same dict also exposed as
      ``self._prototype_frame_full_sizes``) -- plate site name ->
      ``site["args"][0]``, presumably the plate's full size.
    """
    rng_key = numpyro.sample("_{}_rng_key_setup".format(self.prefix), dist.PRNGIdentity())
    with handlers.block():
        init_params, _, self._postprocess_fn, self.prototype_trace = initialize_model(
            rng_key, self.model,
            init_strategy=self.init_loc_fn,
            dynamic_args=False,
            model_args=args,
            model_kwargs=kwargs)
    self._init_locs = init_params[0]

    self._prototype_frames = {}
    # This mapping is referred to under two attribute names. Without an
    # initialised ``_prototype_frame_full_sizes``, the plate branch below
    # raised AttributeError. Alias both names to one fresh dict so code
    # reading either attribute sees the recorded sizes.
    plate_sizes = {}
    self._prototype_plate_sizes = plate_sizes
    self._prototype_frame_full_sizes = plate_sizes
    for name, site in self.prototype_trace.items():
        if site["type"] == "sample":
            for frame in site["cond_indep_stack"]:
                self._prototype_frames[frame.name] = frame
        elif site["type"] == "plate":
            plate_sizes[name] = site["args"][0]
def _setup_prototype(self, *args, **kwargs):
    """Run the parent prototype setup, then hand a fresh key to ``find_params``.

    After ``super(AutoDelta, self)._setup_prototype`` completes, a PRNG key
    is drawn under the sample-site name ``"_<prefix>_rng_key_init"``. It is
    passed, followed by the original model arguments, to
    ``self.find_params``.
    """
    super(AutoDelta, self)._setup_prototype(*args, **kwargs)
    init_site = "_{}_rng_key_init".format(self.prefix)
    init_key = numpyro.sample(init_site, dist.PRNGIdentity())
    self.find_params(init_key, *args, **kwargs)