コード例 #1
0
ファイル: sampling.py プロジェクト: hstm/pymc3
def init_nuts(init='advi', n_init=500000, model=None, **kwargs):
    """Initialize a NUTS sampler and a starting point for a continuous model.

    This is a convenience function. NUTS convergence and sampling speed is
    extremely dependent on the choice of mass/scaling matrix. In our
    experience, using ADVI to estimate a diagonal covariance matrix and using
    this as the scaling matrix produces robust results over a wide class of
    continuous models.

    Parameters
    ----------
    init : str {'advi', 'advi_map', 'map', 'nuts'}
        Initialization method to use.
        * advi : Run ADVI to estimate posterior mean and diagonal covariance
          matrix; the starting point is a single draw from the ADVI fit.
        * advi_map: Initialize ADVI with MAP and use MAP as starting point.
        * map : Use the MAP as starting point and `pm.find_hessian` at the
          MAP as scaling.
        * nuts : Run NUTS and estimate posterior mean and covariance matrix.
    n_init : int
        Number of iterations of the initializer.
        If 'advi' or 'advi_map', number of ADVI iterations; if 'nuts',
        number of draws.
    model : Model (optional if in `with` context)
    **kwargs : keyword arguments
        Extra keyword arguments are forwarded to pymc3.NUTS.

    Returns
    -------
    start, nuts_sampler

    start : pymc3.model.Point
        Starting point for sampler
    nuts_sampler : pymc3.step_methods.NUTS
        Instantiated and initialized NUTS sampler object

    Raises
    ------
    NotImplementedError
        If `init` is not one of the supported initializers.
    """

    model = pm.modelcontext(model)

    pm._log.info('Initializing NUTS using {}...'.format(init))

    if init == 'advi':
        v_params = pm.variational.advi(n=n_init)
        start = pm.variational.sample_vp(v_params, 1, progressbar=False)[0]
        # ADVI reports standard deviations; NUTS is given variances.
        cov = np.power(model.dict_to_array(v_params.stds), 2)
    elif init == 'advi_map':
        start = pm.find_MAP()
        v_params = pm.variational.advi(n=n_init, start=start)
        cov = np.power(model.dict_to_array(v_params.stds), 2)
    elif init == 'map':
        start = pm.find_MAP()
        cov = pm.find_hessian(point=start)
    elif init == 'nuts':
        init_trace = pm.sample(step=pm.NUTS(), draws=n_init)
        # Only the second half of the initial run is used for the covariance.
        cov = pm.trace_cov(init_trace[n_init // 2:])
        # Average over the draw axis only, so that vector-valued variables
        # keep their shape instead of collapsing to a single scalar.
        start = {
            varname: np.mean(init_trace[varname], axis=0)
            for varname in init_trace.varnames
        }
    else:
        # `NotImplemented` is a non-callable constant; the exception class
        # is `NotImplementedError`.
        raise NotImplementedError(
            'Initializer {} is not supported.'.format(init))

    step = pm.NUTS(scaling=cov, is_cov=True, **kwargs)

    return start, step
コード例 #2
0
def test_hh_flow():
    """Check that the 'scale-hh*2-loc' flow recovers a 2-d covariance.

    Fits NFVI to a bivariate normal with a known covariance matrix, draws
    samples from the fitted approximation, and compares the sample
    covariance with the true one at a 7% relative tolerance.
    """
    true_cov = pm.floatX([[2, -1], [-1, 3]])
    with pm.Model():
        pm.MvNormal('mvN', mu=pm.floatX([0, 1]), cov=true_cov, shape=2)
        flow = NFVI('scale-hh*2-loc')
        optimizer = pm.adam(learning_rate=0.001)
        flow.fit(25000, obj_optimizer=optimizer)
        draws = flow.approx.sample(10000)
        sampled_cov = pm.trace_cov(draws)
    np.testing.assert_allclose(true_cov, sampled_cov, rtol=0.07)
コード例 #3
0
def test_hh_flow():
    """NFVI with the 'scale-hh*2-loc' formula should match the target cov.

    The target is a two-dimensional MvNormal; the covariance of 10000
    samples from the fitted approximation must agree with the target
    covariance within rtol=0.07.
    """
    target_cov = pm.floatX([[2, -1], [-1, 3]])
    target_mu = pm.floatX([0, 1])
    with pm.Model():
        pm.MvNormal('mvN', mu=target_mu, cov=target_cov, shape=2)
        inference = NFVI('scale-hh*2-loc')
        inference.fit(25000, obj_optimizer=pm.adam(learning_rate=0.001))
        approx_trace = inference.approx.sample(10000)
        approx_cov = pm.trace_cov(approx_trace)
    np.testing.assert_allclose(target_cov, approx_cov, rtol=0.07)
コード例 #4
0
def init_nuts(init='advi', n_init=500000, model=None):
    """Initialize a NUTS sampler and a starting point for a continuous model.

    This is a convenience function. NUTS convergence and sampling speed is
    extremely dependent on the choice of mass/scaling matrix. In our
    experience, using ADVI to estimate a diagonal covariance matrix and using
    this as the scaling matrix produces robust results over a wide class of
    continuous models.

    Parameters
    ----------
    init : str {'advi', 'advi_map', 'map', 'nuts'}
        Initialization method to use.
        * advi : Run ADVI to estimate posterior mean and diagonal covariance
          matrix; the starting point is a single draw from the ADVI fit.
        * advi_map: Initialize ADVI with MAP and use MAP as starting point.
        * map : Use the MAP as starting point and `pm.find_hessian` at the
          MAP as scaling.
        * nuts : Run NUTS and estimate posterior mean and covariance matrix.
    n_init : int
        Number of iterations of the initializer.
        If 'advi' or 'advi_map', number of ADVI iterations; if 'nuts',
        number of draws.
    model : Model (optional if in `with` context)

    Returns
    -------
    start, nuts_sampler

    start : pymc3.model.Point
        Starting point for sampler
    nuts_sampler : pymc3.step_methods.NUTS
        Instantiated and initialized NUTS sampler object

    Raises
    ------
    NotImplementedError
        If `init` is not one of the supported initializers.
    """

    model = pm.modelcontext(model)

    pm._log.info('Initializing NUTS using {}...'.format(init))

    if init == 'advi':
        v_params = pm.variational.advi(n=n_init)
        start = pm.variational.sample_vp(v_params, 1, progressbar=False)[0]
        # ADVI reports standard deviations; NUTS is given variances.
        cov = np.power(model.dict_to_array(v_params.stds), 2)
    elif init == 'advi_map':
        start = pm.find_MAP()
        v_params = pm.variational.advi(n=n_init, start=start)
        cov = np.power(model.dict_to_array(v_params.stds), 2)
    elif init == 'map':
        start = pm.find_MAP()
        cov = pm.find_hessian(point=start)
    elif init == 'nuts':
        init_trace = pm.sample(step=pm.NUTS(), draws=n_init)
        # Only the second half of the initial run is used for the covariance.
        cov = pm.trace_cov(init_trace[n_init // 2:])
        # Average over the draw axis only, so that vector-valued variables
        # keep their shape instead of collapsing to a single scalar.
        start = {
            varname: np.mean(init_trace[varname], axis=0)
            for varname in init_trace.varnames
        }
    else:
        # `NotImplemented` is a non-callable constant; the exception class
        # is `NotImplementedError`.
        raise NotImplementedError(
            'Initializer {} is not supported.'.format(init))

    step = pm.NUTS(scaling=cov, is_cov=True)

    return start, step
コード例 #5
0
def init_nuts(init='ADVI',
              njobs=1,
              n_init=500000,
              model=None,
              random_seed=-1,
              **kwargs):
    """Build a NUTS step method and starting point(s) for a continuous model.

    NUTS convergence and sampling speed depend heavily on the mass/scaling
    matrix. This helper estimates a covariance with the chosen initializer
    and hands it to ``pm.NUTS`` as ``scaling`` with ``is_cov=True``.

    Parameters
    ----------
    init : str {'ADVI', 'ADVI_MAP', 'MAP', 'NUTS'}
        Initialization method, compared case-insensitively.
        * ADVI : Run ADVI; the scaling is the squared ADVI standard
          deviations and the start is drawn from the ADVI fit.
        * ADVI_MAP: Initialize ADVI with MAP and use MAP as starting point.
        * MAP : Use the MAP as starting point and `pm.find_hessian` at the
          MAP as scaling.
        * NUTS : Run NUTS, discard the first half of the draws, and use the
          covariance of the remainder; the start is chosen at random from
          those draws.
    njobs : int
        Number of parallel jobs to start. For 'ADVI' and 'NUTS' this many
        starting points are produced; a single point is returned unwrapped
        when ``njobs == 1``.
    n_init : int
        Number of iterations of the initializer (ADVI iterations, or NUTS
        draws for 'NUTS').
    model : Model (optional if in `with` context)
    random_seed : int or array-like of int
        Only the first element is used; it is passed to ADVI, to the ADVI
        sampling step and to the initial NUTS run.
    **kwargs : keyword arguments
        Extra keyword arguments are forwarded to pymc3.NUTS.

    Returns
    -------
    start, nuts_sampler

    start : pymc3.model.Point
        Starting point for sampler
    nuts_sampler : pymc3.step_methods.NUTS
        Instantiated and initialized NUTS sampler object
    """

    model = pm.modelcontext(model)

    # Logged before lower-casing, so the message shows the caller's spelling.
    pm._log.info('Initializing NUTS using {}...'.format(init))

    random_seed = int(np.atleast_1d(random_seed)[0])

    method = init if init is None else init.lower()

    if method == 'advi':
        vparams = pm.variational.advi(n=n_init, random_seed=random_seed)
        draws = pm.variational.sample_vp(vparams,
                                         njobs,
                                         progressbar=False,
                                         hide_transformed=False,
                                         random_seed=random_seed)
        start = draws[0] if njobs == 1 else draws
        advi_stds = model.dict_to_array(vparams.stds)
        cov = np.power(advi_stds, 2)
    elif method == 'advi_map':
        start = pm.find_MAP()
        vparams = pm.variational.advi(n=n_init,
                                      start=start,
                                      random_seed=random_seed)
        advi_stds = model.dict_to_array(vparams.stds)
        cov = np.power(advi_stds, 2)
    elif method == 'map':
        start = pm.find_MAP()
        cov = pm.find_hessian(point=start)
    elif method == 'nuts':
        full_trace = pm.sample(step=pm.NUTS(),
                               draws=n_init,
                               random_seed=random_seed)
        # Keep only the second half of the initial run.
        init_trace = full_trace[n_init // 2:]
        cov = np.atleast_1d(pm.trace_cov(init_trace))
        picked = np.random.choice(init_trace, njobs)
        start = picked[0] if njobs == 1 else picked
    else:
        raise NotImplementedError(
            'Initializer {} is not supported.'.format(method))

    step = pm.NUTS(scaling=cov, is_cov=True, **kwargs)

    return start, step
コード例 #6
0
ファイル: sampling.py プロジェクト: hanbman/pymc3
def init_nuts(init='ADVI',
              njobs=1,
              n_init=500000,
              model=None,
              random_seed=-1,
              progressbar=True,
              **kwargs):
    """Build a NUTS step method and starting point(s) for a continuous model.

    NUTS convergence and sampling speed depend heavily on the mass/scaling
    matrix. This helper estimates a covariance with the chosen initializer
    and hands it to ``pm.NUTS`` as ``scaling`` with ``is_cov=True``.

    Parameters
    ----------
    init : str {'ADVI', 'ADVI_MAP', 'MAP', 'NUTS'}
        Initialization method, compared case-insensitively.
        * ADVI : Fit mean-field ADVI; the scaling is the squared ADVI
          standard deviations and the start is sampled from the fit.
        * ADVI_MAP: Same, but the mean-field approximation is started at
          the MAP estimate.
        * MAP : Use the MAP as starting point and `pm.find_hessian` at the
          MAP as scaling.
        * NUTS : Run NUTS and use the covariance of its trace; the start is
          chosen at random from that trace.
    njobs : int
        Number of parallel jobs to start. For every initializer except
        'MAP' this many starting points are produced; a single point is
        returned unwrapped when ``njobs == 1``.
    n_init : int
        Number of iterations of the initializer (ADVI iterations, or NUTS
        draws for 'NUTS').
    model : Model (optional if in `with` context)
    random_seed : int or array-like of int
        Only the first element is used; it is forwarded to `pm.fit` and to
        the initial NUTS run.
    progressbar : bool
        Whether or not to display a progressbar for advi sampling.
    **kwargs : keyword arguments
        Extra keyword arguments are forwarded to pymc3.NUTS.

    Returns
    -------
    start : pymc3.model.Point
        Starting point for sampler
    nuts_sampler : pymc3.step_methods.NUTS
        Instantiated and initialized NUTS sampler object
    """

    model = pm.modelcontext(model)

    # Logged before lower-casing, so the message shows the caller's spelling.
    pm._log.info('Initializing NUTS using {}...'.format(init))

    random_seed = int(np.atleast_1d(random_seed)[0])

    if init is not None:
        init = init.lower()

    # Stop ADVI early once parameters settle, in absolute and relative terms.
    cb = [
        pm.callbacks.CheckParametersConvergence(tolerance=1e-2, diff=kind)
        for kind in ('absolute', 'relative')
    ]

    if init in ('advi', 'advi_map'):
        fit_options = dict(random_seed=random_seed,
                           n=n_init,
                           callbacks=cb,
                           progressbar=progressbar,
                           obj_optimizer=pm.adagrad_window)
        if init == 'advi':
            approx = pm.fit(method='advi', model=model,
                            **fit_options)  # type: pm.MeanField
        else:
            map_point = pm.find_MAP()
            approx = pm.MeanField(model=model, start=map_point)
            pm.fit(method=pm.ADVI.from_mean_field(approx), **fit_options)
        # Both ADVI variants share the same post-processing.
        draws = approx.sample(draws=njobs)
        stds = approx.gbij.rmap(approx.std.eval())
        cov = model.dict_to_array(stds)**2
        start = draws[0] if njobs == 1 else draws
    elif init == 'map':
        start = pm.find_MAP()
        cov = pm.find_hessian(point=start)
    elif init == 'nuts':
        init_trace = pm.sample(draws=n_init,
                               step=pm.NUTS(),
                               tune=n_init // 2,
                               random_seed=random_seed)
        cov = np.atleast_1d(pm.trace_cov(init_trace))
        picked = np.random.choice(init_trace, njobs)
        start = picked[0] if njobs == 1 else picked
    else:
        raise NotImplementedError(
            'Initializer {} is not supported.'.format(init))

    step = pm.NUTS(scaling=cov, is_cov=True, **kwargs)

    return start, step
コード例 #7
0
 def train_model(self, method, prior, res=22.):
     """Estimate a sampler covariance for `prior` and store it via `write_`.

     A test model (``self.get_model(prior, test=True)``) is tuned for 8000
     steps with the requested sampler and the covariance of the resulting
     trace is computed. For 'MH' and 'HMC' that covariance is written out
     directly. For 'NUTS' a second, shorter run (500 tuning steps) on
     ``self.get_model(prior)`` is made with the first covariance as scaling,
     and the covariance of that second trace is written instead.

     Parameters
     ----------
     method : str {'MH', 'HMC', 'NUTS'}
         Sampler to tune.
     prior
         Passed through to ``self.get_model`` and ``self.write_``.
     res : float
         Target step resolution in degrees; converted to radians and
         rescaled by the model dimension ``self.n_d``.

     Raises
     ------
     ValueError
         If `method` is not one of the supported samplers.
     """
     # Fail early on an unknown method: otherwise the code would fall
     # through to the NUTS-only section and die on an unbound `cov`.
     if method not in ('MH', 'HMC', 'NUTS'):
         raise ValueError(
             "Unsupported method {!r}; expected 'MH', 'HMC' or "
             "'NUTS'.".format(method))

     tuning_kwargs = dict(cores=1,
                          chains=1,
                          draws=1,
                          discard_tuned_samples=False)

     test_model = self.get_model(prior, test=True)
     with test_model:
         # True step size is scaled down
         # by (1/n)^(1/4), so that a res of 20 deg -> 15 deg sampling in 3-d
         # therefore, we scale UP by (1/n)^(1/4) so when we specify
         # 20 deg, we get 20 deg!
         step_scale = res * np.pi / 180 / (1 / self.n_d)**0.25
         if method == 'MH':
             step = pm.Metropolis()
         elif method == 'HMC':
             step = pm.HamiltonianMC(step_scale=step_scale)
         else:
             step = pm.NUTS(step_scale=step_scale,
                            adapt_step_size=False,
                            target_accept=0.8)
         test_trace = pm.sample(step=step, tune=8000, **tuning_kwargs)
         cov = pm.trace_cov(test_trace)
         if method != 'NUTS':
             # MH and HMC only need the covariance of the test run.
             self.write_(cov, method, prior)
             return
         accept = test_trace.get_sampler_stats("mean_tree_accept")
         mean_accept = accept.mean()
         print("Mean acceptance for method", method, "and prior", prior,
               "is", mean_accept)
         print("Step size is",
               test_trace.get_sampler_stats("step_size_bar"))
         # Reuse the observed acceptance as target unless it is too low.
         t_a = mean_accept if mean_accept > 0.5 else 0.8
     training_model = self.get_model(prior)
     # Method is NUTS
     # NUTS tuning will be a little more intense:
     with training_model:
         step = pm.NUTS(scaling=cov,
                        is_cov=True,
                        step_scale=step_scale,
                        adapt_step_size=False,
                        target_accept=t_a)
         train_trace = pm.sample(step=step, tune=500, **tuning_kwargs)
         cov = pm.trace_cov(train_trace)
     self.write_(cov, method, prior)
コード例 #8
0
    # print(f"Map Estimate:{map_estimate}")

    times_consumed = []

    N = 100
    alphas = norm.rvs(size=N)
    betas = halfnorm.rvs(size=(N, 2))
    sigmas = halfnorm.rvs(size=N)

    print("=========================================")
    print(f"Tuning NUTS")
    print("=========================================")

    n_chains = 4
    init_trace = pm.sample(draws=1000, tune=1000, cores=n_chains)
    cov = np.atleast_1d(pm.trace_cov(init_trace))
    start = list(np.random.choice(init_trace, n_chains))
    potential = quadpotential.QuadPotentialFull(cov)
    step_size = init_trace.get_sampler_stats("step_size_bar")[-1]
    size = m.bijection.ordering.size
    step_scale = step_size * (size**0.25)

    # with pm.Model() as model_new:  # reset model. If you use theano.shared you can also update the value of model1 above
    for i in range(N):
        # No-U-Turn Sampler NUTS
        print("=========================================")
        print(f"Turn {i}")
        print("=========================================")
        # start = {"alpha": alphas[i], "beta": betas[i], "sigma": sigmas[i]}
        time_zero = default_timer()
        step = pm.NUTS(potential=potential,
コード例 #9
0
ファイル: sampling.py プロジェクト: DanielWeitzenfeld/pymc3
def init_nuts(init='auto',
              chains=1,
              n_init=500000,
              model=None,
              random_seed=None,
              progressbar=True,
              **kwargs):
    """Set up the mass matrix initialization for NUTS.

    NUTS convergence and sampling speed is extremely dependent on the
    choice of mass/scaling matrix. This function implements different
    methods for choosing or adapting the mass matrix.

    Parameters
    ----------
    init : str
        Initialization method to use (compared case-insensitively).

        * auto : Choose a default initialization method automatically.
          Currently, this is `'jitter+adapt_diag'`, but this can change in
          the future. If you depend on the exact behaviour, choose an
          initialization method explicitly.
        * adapt_diag : Start with a identity mass matrix and then adapt
          a diagonal based on the variance of the tuning samples. All
          chains use the test value (usually the prior mean) as starting
          point.
        * jitter+adapt_diag : Same as `adapt_diag`, but add uniform jitter
          in [-1, 1] to the starting point in each chain.
        * advi+adapt_diag : Run ADVI and then adapt the resulting diagonal
          mass matrix based on the sample variance of the tuning samples.
        * advi+adapt_diag_grad : Run ADVI and then adapt the resulting
          diagonal mass matrix based on the variance of the gradients
          during tuning. This is **experimental** and might be removed
          in a future release.
        * advi : Run ADVI to estimate posterior mean and diagonal mass
          matrix.
        * advi_map: Initialize ADVI with MAP and use MAP as starting point.
        * map : Use the MAP as starting point. This is discouraged.
        * nuts : Run NUTS and estimate posterior mean and mass matrix from
          the trace.
    chains : int
        Number of jobs to start; one starting point is produced per chain.
    n_init : int
        Number of iterations of initializer
        If 'ADVI', number of iterations, if 'nuts', number of draws.
    model : Model (optional if in `with` context)
    random_seed : int, array-like of int, or None
        When given, only the first element is used; it seeds `np.random`
        and is forwarded to `pm.fit` / the initial NUTS run.
    progressbar : bool
        Whether or not to display a progressbar for advi sampling.
    **kwargs : keyword arguments
        Extra keyword arguments are forwarded to pymc3.NUTS.

    Returns
    -------
    start : list of pymc3.model.Point
        Starting points for the sampler, one per chain
    nuts_sampler : pymc3.step_methods.NUTS
        Instantiated and initialized NUTS sampler object
    """
    model = pm.modelcontext(model)

    sampled_vars = kwargs.get('vars', model.vars)
    if set(sampled_vars) != set(model.vars):
        raise ValueError('Must use init_nuts on all variables of a model.')
    if not pm.model.all_continuous(sampled_vars):
        raise ValueError('init_nuts can only be used for models with only '
                         'continuous variables.')

    if not isinstance(init, str):
        raise TypeError('init must be a string.')

    # `init` is guaranteed to be a string from here on.
    init = init.lower()
    if init == 'auto':
        init = 'jitter+adapt_diag'

    pm._log.info('Initializing NUTS using {}...'.format(init))

    if random_seed is not None:
        random_seed = int(np.atleast_1d(random_seed)[0])
        np.random.seed(random_seed)

    # Stop ADVI early once parameters settle, in absolute and relative terms.
    cb = [
        pm.callbacks.CheckParametersConvergence(tolerance=1e-2, diff=kind)
        for kind in ('absolute', 'relative')
    ]

    if init in ('adapt_diag', 'jitter+adapt_diag'):
        if init == 'adapt_diag':
            start = [model.test_point] * chains
        else:
            start = []
            for _ in range(chains):
                point = {
                    name: value.copy()
                    for name, value in model.test_point.items()
                }
                # Shift every entry in place by uniform noise in [-1, 1).
                for value in point.values():
                    value[...] += 2 * np.random.rand(*value.shape) - 1
                start.append(point)
        mean = np.mean([model.dict_to_array(point) for point in start],
                       axis=0)
        var = np.ones_like(mean)
        potential = quadpotential.QuadPotentialDiagAdapt(
            model.ndim, mean, var, 10)
    elif init in ('advi+adapt_diag_grad', 'advi+adapt_diag', 'advi',
                  'advi_map'):
        fit_options = dict(random_seed=random_seed,
                           n=n_init,
                           callbacks=cb,
                           progressbar=progressbar,
                           obj_optimizer=pm.adagrad_window)
        if init == 'advi_map':
            map_point = pm.find_MAP(include_transformed=True)
            approx = pm.MeanField(model=model, start=map_point)
            pm.fit(method=pm.KLqp(approx), **fit_options)
        else:
            approx = pm.fit(method='advi', model=model,
                            **fit_options)  # type: pm.MeanField
        start = list(approx.sample(draws=chains))
        stds = approx.bij.rmap(approx.std.eval())
        cov = model.dict_to_array(stds)**2
        if init in ('advi', 'advi_map'):
            # Fixed diagonal mass matrix, no further adaptation.
            potential = quadpotential.QuadPotentialDiag(cov)
        else:
            mean = model.dict_to_array(
                approx.bij.rmap(approx.mean.get_value()))
            weight = 50
            if init == 'advi+adapt_diag_grad':
                potential = quadpotential.QuadPotentialDiagAdaptGrad(
                    model.ndim, mean, cov, weight)
            else:
                potential = quadpotential.QuadPotentialDiagAdapt(
                    model.ndim, mean, cov, weight)
    elif init == 'map':
        map_point = pm.find_MAP(include_transformed=True)
        cov = pm.find_hessian(point=map_point)
        start = [map_point] * chains
        potential = quadpotential.QuadPotentialFull(cov)
    elif init == 'nuts':
        init_trace = pm.sample(draws=n_init,
                               step=pm.NUTS(),
                               tune=n_init // 2,
                               random_seed=random_seed)
        cov = np.atleast_1d(pm.trace_cov(init_trace))
        start = list(np.random.choice(init_trace, chains))
        potential = quadpotential.QuadPotentialFull(cov)
    else:
        raise NotImplementedError(
            'Initializer {} is not supported.'.format(init))

    step = pm.NUTS(potential=potential, **kwargs)

    return start, step
コード例 #10
0
ファイル: sampling.py プロジェクト: springcoil/pymc3
def init_nuts(init='auto', njobs=1, n_init=500000, model=None,
              random_seed=-1, progressbar=True, **kwargs):
    """Set up the mass matrix initialization for NUTS.

    NUTS convergence and sampling speed is extremely dependent on the
    choice of mass/scaling matrix. This function implements different
    methods for choosing or adapting the mass matrix.

    Parameters
    ----------
    init : str
        Initialization method to use (compared case-insensitively).

        * auto : Choose a default initialization method automatically.
          Currently, this is `'jitter+adapt_diag'`, but this can change in
          the future. If you depend on the exact behaviour, choose an
          initialization method explicitly.
        * adapt_diag : Start with a identity mass matrix and then adapt
          a diagonal based on the variance of the tuning samples. All
          chains use the test value (usually the prior mean) as starting
          point.
        * jitter+adapt_diag : Same as `adapt_diag`, but add uniform jitter
          in [-1, 1] to the starting point in each chain.
        * advi+adapt_diag : Run ADVI and then adapt the resulting diagonal
          mass matrix based on the sample variance of the tuning samples.
        * advi+adapt_diag_grad : Run ADVI and then adapt the resulting
          diagonal mass matrix based on the variance of the gradients
          during tuning. This is **experimental** and might be removed
          in a future release.
        * advi : Run ADVI to estimate posterior mean and diagonal mass
          matrix.
        * advi_map: Initialize ADVI with MAP and use MAP as starting point.
        * map : Use the MAP as starting point. This is discouraged.
        * nuts : Run NUTS and estimate posterior mean and mass matrix from
          the trace.
    njobs : int
        Number of parallel jobs to start. One starting point is produced
        per job; with ``njobs == 1`` the single point is returned unwrapped.
    n_init : int
        Number of iterations of initializer
        If 'ADVI', number of iterations, if 'nuts', number of draws.
    model : Model (optional if in `with` context)
    random_seed : int or array-like of int
        Only the first element is used; it is forwarded to `pm.fit` and to
        the initial NUTS run.
    progressbar : bool
        Whether or not to display a progressbar for advi sampling.
    **kwargs : keyword arguments
        Extra keyword arguments are forwarded to pymc3.NUTS.

    Returns
    -------
    start : pymc3.model.Point
        Starting point for sampler
    nuts_sampler : pymc3.step_methods.NUTS
        Instantiated and initialized NUTS sampler object
    """
    model = pm.modelcontext(model)

    sampled_vars = kwargs.get('vars', model.vars)
    if set(sampled_vars) != set(model.vars):
        raise ValueError('Must use init_nuts on all variables of a model.')
    if not pm.model.all_continuous(sampled_vars):
        raise ValueError('init_nuts can only be used for models with only '
                         'continuous variables.')

    if not isinstance(init, str):
        raise TypeError('init must be a string.')

    # `init` is guaranteed to be a string from here on.
    init = init.lower()
    if init == 'auto':
        init = 'jitter+adapt_diag'

    pm._log.info('Initializing NUTS using {}...'.format(init))

    random_seed = int(np.atleast_1d(random_seed)[0])

    cb = [
        pm.callbacks.CheckParametersConvergence(tolerance=1e-2, diff=kind)
        for kind in ('absolute', 'relative')
    ]

    def _fit(**method_args):
        # Common ADVI settings; callers supply `method` (and `model`).
        return pm.fit(
            random_seed=random_seed,
            n=n_init,
            callbacks=cb,
            progressbar=progressbar,
            obj_optimizer=pm.adagrad_window,
            **method_args
        )

    def _draws_and_variance(approx):
        # One draw per job plus the squared ADVI standard deviations.
        points = list(approx.sample(draws=njobs))
        stds = approx.bij.rmap(approx.std.eval())
        return points, model.dict_to_array(stds) ** 2

    def _flat_mean(approx):
        # ADVI mean as a flat array in the model's ordering.
        return model.dict_to_array(approx.bij.rmap(approx.mean.get_value()))

    def _unit_diag_adapt(points):
        # Adaptive diagonal potential centred on the average start point.
        centre = np.mean([model.dict_to_array(p) for p in points], axis=0)
        return quadpotential.QuadPotentialDiagAdapt(
            model.ndim, centre, np.ones_like(centre), 10)

    if init == 'adapt_diag':
        start = [model.test_point] * njobs
        potential = _unit_diag_adapt(start)
    elif init == 'jitter+adapt_diag':
        start = []
        for _ in range(njobs):
            point = {
                name: value.copy()
                for name, value in model.test_point.items()
            }
            # Shift every entry in place by uniform noise in [-1, 1).
            for value in point.values():
                value[...] += 2 * np.random.rand(*value.shape) - 1
            start.append(point)
        potential = _unit_diag_adapt(start)
    elif init == 'advi+adapt_diag_grad':
        approx = _fit(method='advi', model=model)  # type: pm.MeanField
        start, cov = _draws_and_variance(approx)
        potential = quadpotential.QuadPotentialDiagAdaptGrad(
            model.ndim, _flat_mean(approx), cov, 50)
    elif init == 'advi+adapt_diag':
        approx = _fit(method='advi', model=model)  # type: pm.MeanField
        start, cov = _draws_and_variance(approx)
        potential = quadpotential.QuadPotentialDiagAdapt(
            model.ndim, _flat_mean(approx), cov, 50)
    elif init == 'advi':
        approx = _fit(method='advi', model=model)  # type: pm.MeanField
        start, cov = _draws_and_variance(approx)
        potential = quadpotential.QuadPotentialDiag(cov)
    elif init == 'advi_map':
        map_point = pm.find_MAP()
        approx = pm.MeanField(model=model, start=map_point)
        _fit(method=pm.KLqp(approx))
        start, cov = _draws_and_variance(approx)
        potential = quadpotential.QuadPotentialDiag(cov)
    elif init == 'map':
        map_point = pm.find_MAP()
        cov = pm.find_hessian(point=map_point)
        start = [map_point] * njobs
        potential = quadpotential.QuadPotentialFull(cov)
    elif init == 'nuts':
        init_trace = pm.sample(draws=n_init, step=pm.NUTS(),
                               tune=n_init // 2,
                               random_seed=random_seed)
        cov = np.atleast_1d(pm.trace_cov(init_trace))
        start = list(np.random.choice(init_trace, njobs))
        potential = quadpotential.QuadPotentialFull(cov)
    else:
        raise NotImplementedError('Initializer {} is not supported.'.format(init))

    # Every branch builds a list of points; a single job gets a bare point.
    if njobs == 1:
        start = start[0]

    step = pm.NUTS(potential=potential, **kwargs)

    return start, step
# Observed counts wrapped with `shared` (defined outside this view;
# presumably theano.shared -- TODO confirm against the file's imports).
total_tools = shared(dk.total_tools.values)

# %%
# Poisson model with a log link: the rate is exp() of an intercept `a`,
# main effects for `log_pop` and `contact_high`, and their interaction.
with pm.Model() as m_10_10:
    a = pm.Normal("a", 0, 100)
    b = pm.Normal("b", 0, 1, shape=3)
    lam = pm.math.exp(a + b[0] * log_pop + b[1] * contact_high +
                      b[2] * contact_high * log_pop)
    obs = pm.Poisson("total_tools", lam, observed=total_tools)
    trace_10_10 = pm.sample(1000, tune=1000)

# %%
# Posterior summary restricted to mean, sd and the 89% interval bounds.
summary = az.summary(trace_10_10, credible_interval=0.89)[[
    "mean", "sd", "hpd_5.5%", "hpd_94.5%"
]]
# Turn the trace covariance into a correlation matrix by scaling rows and
# columns with the inverse standard deviations (1 / sqrt of the diagonal).
trace_cov = pm.trace_cov(trace_10_10, model=m_10_10)
invD = (np.sqrt(np.diag(trace_cov))**-1)[:, None]
trace_corr = pd.DataFrame(invD * trace_cov * invD.T,
                          index=summary.index,
                          columns=summary.index)

# Notebook-style display: summary columns side by side with correlations.
summary.join(trace_corr).round(2)

# %%
az.plot_forest(trace_10_10)

# %%
# Posterior rate with log_pop fixed at 8: `lambda_high` sets contact_high
# to 1 (so b[1] and the interaction b[2] contribute), `lambda_low` sets it
# to 0 (only the intercept and the log_pop slope remain).
lambda_high = np.exp(trace_10_10["a"] + trace_10_10["b"][:, 1] +
                     (trace_10_10["b"][:, 0] + trace_10_10["b"][:, 2]) * 8)
lambda_low = np.exp(trace_10_10["a"] + trace_10_10["b"][:, 0] * 8)