Ejemplo n.º 1
0
def run_inference(model, args, rng_key, X, Y):
    """Run NUTS on ``model`` using the initialization strategy named in ``args``.

    Args:
        model: Model callable handed to ``NUTS``.
        args: Namespace-like object; this function reads ``init_strategy``,
            ``num_warmup``, ``num_samples`` and ``num_chains`` from it.
        rng_key: Random key forwarded to ``MCMC.run``.
        X: First data argument forwarded to ``MCMC.run``.
        Y: Second data argument forwarded to ``MCMC.run``.

    Returns:
        Whatever ``mcmc.get_samples()`` returns after the run.

    Raises:
        ValueError: If ``args.init_strategy`` is not one of "value", "median",
            "feasible", "sample" or "uniform".
    """
    start = time.time()
    # demonstrate how to use different HMC initialization strategies
    if args.init_strategy == "value":
        init_strategy = init_to_value(values={
            "kernel_var": 1.0,
            "kernel_noise": 0.05,
            "kernel_length": 0.5,
        })
    elif args.init_strategy == "median":
        init_strategy = init_to_median(num_samples=10)
    elif args.init_strategy == "feasible":
        init_strategy = init_to_feasible()
    elif args.init_strategy == "sample":
        init_strategy = init_to_sample()
    elif args.init_strategy == "uniform":
        init_strategy = init_to_uniform(radius=1)
    else:
        # Without this branch an unknown name left ``init_strategy`` unbound and
        # the function failed later with a confusing UnboundLocalError.
        raise ValueError(
            f"Unknown init_strategy {args.init_strategy!r}; expected one of "
            "'value', 'median', 'feasible', 'sample', 'uniform'."
        )
    kernel = NUTS(model, init_strategy=init_strategy)
    mcmc = MCMC(
        kernel,
        args.num_warmup,
        args.num_samples,
        num_chains=args.num_chains,
        # Hide the progress bar when this example runs under the Sphinx docs build.
        progress_bar="NUMPYRO_SPHINXBUILD" not in os.environ,
    )
    mcmc.run(rng_key, X, Y)
    mcmc.print_summary()
    print('\nMCMC elapsed time:', time.time() - start)
    return mcmc.get_samples()
Ejemplo n.º 2
0
    def init_mcmc(
        self,
        model,
        num_warmup: int = 1000,
        num_samples: int = 1000,
        num_chains: int = 1,
        sampler="NUTS",
        sampler_kwargs=None,
        **kwargs,
    ) -> MCMC:
        """Initialises the MCMC sampler.

        Args:
            model (callable): Model used to build the NUTS kernel when
                ``sampler`` is given as a string.
            num_warmup (int): Forwarded to ``MCMC`` as ``num_warmup``.
                Defaults to 1000.
            num_samples (int): Forwarded to ``MCMC`` as ``num_samples``.
                Defaults to 1000.
            num_chains (int): Forwarded to ``MCMC`` as ``num_chains``.
                Defaults to 1.
            sampler (str, or numpyro.infer.mcmc.MCMCKernel): Either the
                case-insensitive name 'NUTS', or a ready-made numpyro mcmc
                kernel that is used as-is.
            sampler_kwargs (dict, optional): Keyword arguments for the NUTS
                kernel built when ``sampler`` is a string. Unless given here,
                ``target_accept_prob`` defaults to 0.98, ``step_size`` to 0.1
                and ``init_strategy`` to ``init_to_median`` with
                ``num_samples=100``. Ignored when ``sampler`` is already a
                kernel. The caller's dict is never modified.
            **kwargs: Keyword arguments to pass to the MCMC instance.

        Returns:
            MCMC: A newly constructed (not yet run) MCMC object.

        Raises:
            ValueError: If ``sampler`` is a string other than 'NUTS'.
        """
        self._mcmc_support_warnings()

        # Work on a private copy: the options below are popped, and popping from
        # the caller's dict (or a shared mutable default) would leak across calls.
        sampler_kwargs = dict(sampler_kwargs) if sampler_kwargs else {}

        if isinstance(sampler, str):
            sampler = sampler.lower()
            if sampler != "nuts":
                raise ValueError(f"Sampler '{sampler}' not supported.")
            target_accept_prob = sampler_kwargs.pop("target_accept_prob", 0.98)
            init_strategy = sampler_kwargs.pop(
                "init_strategy",
                lambda site=None: init_to_median(site=site, num_samples=100),
            )
            step_size = sampler_kwargs.pop("step_size", 0.1)
            sampler = NUTS(
                model,
                target_accept_prob=target_accept_prob,
                init_strategy=init_strategy,
                step_size=step_size,
                **sampler_kwargs,
            )

        # TODO(review): decide whether self.batch_ndims should be 2
        # (chains, then samples) when num_chains > 1.

        return MCMC(
            sampler,
            num_warmup=num_warmup,
            num_samples=num_samples,
            num_chains=num_chains,
            **kwargs,
        )
Ejemplo n.º 3
0
def init_to_mean(site=None):
    """
    Initialize to the prior mean; fall back to the median if the mean is undefined.

    Args:
        site (dict, optional): A trace site with at least the keys ``'type'``,
            ``'is_observed'`` and ``'fn'``. When ``None``, a callable form of
            this strategy is returned instead.

    Returns:
        For unobserved, non-discrete ``'sample'`` sites: ``np.array`` of the
        prior mean, or ``init_to_median(site)`` when the mean raises
        ``NotImplementedError``/``ValueError`` or contains NaN/inf.
        ``None`` for every other site.
    """
    if site is None:
        return partial(init_to_mean)

    is_latent = (
        site['type'] == 'sample'
        and not site['is_observed']
        and not site['fn'].is_discrete
    )
    if not is_latent:
        return None

    dist = site["fn"]
    try:
        # Try .mean() method.
        mean = dist.mean
        if hasattr(dist, "_validate_sample"):
            dist._validate_sample(mean)
        value = np.array(mean)
    except (NotImplementedError, ValueError):
        # Fall back to a median when the mean is not implemented/invalid.
        return init_to_median(site)

    # Some distributions report an undefined or infinite mean as NaN/inf rather
    # than raising (presumably e.g. Cauchy -- confirm for the installed version).
    # np.all makes this check safe for batched sites, where a bare truth test on
    # a multi-element array would itself raise.
    if not np.all(np.isfinite(value)):
        return init_to_median(site)
    return value