def run_inference(model, args, rng_key, X, Y):
    """Run NUTS on ``model`` using the initialization strategy named by ``args``.

    Demonstrates the different HMC initialization strategies. After sampling,
    it prints the MCMC summary and the elapsed wall-clock time.

    Args:
        model: Model passed to ``NUTS``.
        args: Namespace providing ``init_strategy`` (one of "value", "median",
            "feasible", "sample", "uniform"), ``num_warmup``, ``num_samples``
            and ``num_chains``.
        rng_key: Random key passed to ``mcmc.run``.
        X, Y: Data passed to ``mcmc.run`` (and thus to the model).

    Returns:
        The result of ``mcmc.get_samples()``.

    Raises:
        ValueError: If ``args.init_strategy`` is not a recognised name.
    """
    start = time.time()
    # demonstrate how to use different HMC initialization strategies
    if args.init_strategy == "value":
        init_strategy = init_to_value(values={
            "kernel_var": 1.0,
            "kernel_noise": 0.05,
            "kernel_length": 0.5,
        })
    elif args.init_strategy == "median":
        init_strategy = init_to_median(num_samples=10)
    elif args.init_strategy == "feasible":
        init_strategy = init_to_feasible()
    elif args.init_strategy == "sample":
        init_strategy = init_to_sample()
    elif args.init_strategy == "uniform":
        init_strategy = init_to_uniform(radius=1)
    else:
        # Without this, init_strategy would be unbound and fail obscurely below.
        raise ValueError(f"Unknown init_strategy: {args.init_strategy!r}")

    kernel = NUTS(model, init_strategy=init_strategy)
    mcmc = MCMC(
        kernel,
        args.num_warmup,
        args.num_samples,
        num_chains=args.num_chains,
        # Hide the progress bar when NUMPYRO_SPHINXBUILD is set in the environment.
        progress_bar="NUMPYRO_SPHINXBUILD" not in os.environ,
    )
    mcmc.run(rng_key, X, Y)
    mcmc.print_summary()
    print('\nMCMC elapsed time:', time.time() - start)
    return mcmc.get_samples()
def init_mcmc(
    self,
    model,
    num_warmup: int = 1000,
    num_samples: int = 1000,
    num_chains: int = 1,
    sampler="NUTS",
    sampler_kwargs=None,
    **kwargs,
) -> MCMC:
    """Initialises the MCMC sampler.

    Args:
        model (callable): Model used to build the NUTS kernel when ``sampler``
            is a string.
        num_warmup (int): Forwarded to ``MCMC`` as ``num_warmup``. Defaults to 1000.
        num_samples (int): Forwarded to ``MCMC`` as ``num_samples``. Defaults to 1000.
        num_chains (int): Forwarded to ``MCMC`` as ``num_chains``. Defaults to 1.
        sampler (str, or numpyro.infer.mcmc.MCMCKernel): Either the string
            "NUTS" (case-insensitive), or a numpyro mcmc kernel, which is used
            as-is.
        sampler_kwargs (dict, optional): Keyword arguments for the NUTS kernel
            when ``sampler`` is a string. Unless overridden here,
            ``target_accept_prob`` defaults to 0.98, ``step_size`` to 0.1, and
            ``init_strategy`` to ``init_to_median`` with ``num_samples=100``.
            The caller's dict is never modified. Ignored when ``sampler`` is a
            kernel.
        **kwargs: Keyword arguments to pass to the MCMC instance.

    Returns:
        MCMC: An MCMC instance wrapping the chosen kernel.

    Raises:
        ValueError: If ``sampler`` is a string other than "NUTS".
    """
    self._mcmc_support_warnings()
    if isinstance(sampler, str):
        sampler = sampler.lower()
        if sampler != "nuts":
            raise ValueError(f"Sampler '{sampler}' not supported.")

        # Work on a copy: the pops below must not strip keys from the caller's
        # dict (or from a shared default), which would silently drop settings
        # on any later call that reuses it.
        sampler_kwargs = dict(sampler_kwargs or {})
        target_accept_prob = sampler_kwargs.pop("target_accept_prob", 0.98)
        init_strategy = sampler_kwargs.pop(
            "init_strategy",
            lambda site=None: init_to_median(site=site, num_samples=100),
        )
        step_size = sampler_kwargs.pop("step_size", 0.1)
        sampler = NUTS(
            model,
            target_accept_prob=target_accept_prob,
            init_strategy=init_strategy,
            step_size=step_size,
            **sampler_kwargs,
        )

    return MCMC(
        sampler,
        num_warmup=num_warmup,
        num_samples=num_samples,
        num_chains=num_chains,
        **kwargs,
    )
def init_to_mean(site=None):
    """
    Initialize to the prior mean; fallback to median if mean is undefined.

    Args:
        site: A site dict with at least ``'type'``, ``'is_observed'`` and
            ``'fn'`` keys, or ``None``.

    Returns:
        If ``site`` is None, a callable equivalent to this strategy.
        For a latent (``type == 'sample'``, not observed), non-discrete site,
        the mean of ``site['fn']`` as a numpy array. If ``.mean`` raises
        ``NotImplementedError``/``ValueError`` or contains NaN, the result of
        ``init_to_median(site)`` is returned instead. For any other site,
        ``None``.
    """
    if site is None:
        return partial(init_to_mean)

    is_latent_continuous = (
        site['type'] == 'sample'
        and not site['is_observed']
        and not site['fn'].is_discrete
    )
    if not is_latent_continuous:
        return None

    try:
        value = site["fn"].mean
        # An undefined mean may come back as NaN rather than raising; use
        # np.any so that array-valued (batched) means are handled too.
        if np.any(np.isnan(value)):
            raise ValueError("prior mean is undefined (NaN)")
        if hasattr(site["fn"], "_validate_sample"):
            site["fn"]._validate_sample(value)
        return np.array(value)
    except (NotImplementedError, ValueError):
        # Fall back to a median.
        # This is required for distributions with infinite variance, e.g. Cauchy.
        return init_to_median(site)