Example #1
import pymc3 as pm
from pymc3.util import get_untransformed_name, is_transformed_name


def __get_untransformed_vars(model=None):
    """
    Get list of non-transformed RVs in the model.
    """
    model = pm.modelcontext(model)

    # Map each transformed name (e.g. "sigma_log__") back to its
    # untransformed counterpart (e.g. "sigma").
    var_names = [str(var) for var in model.vars]
    for i, string in enumerate(var_names):
        if is_transformed_name(string):
            var_names[i] = get_untransformed_name(string)

    return [getattr(model, name) for name in var_names]
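
A minimal usage sketch for the helper above, assuming pymc3 3.x; the model and
variable names are illustrative:

import pymc3 as pm

with pm.Model() as model:
    sigma = pm.HalfNormal("sigma", 1.0)  # stored internally as "sigma_log__"
    mu = pm.Normal("mu", 0.0, 1.0)

# Returns the untransformed variables, i.e. [sigma, mu] rather than
# [sigma_log__, mu].
print(__get_untransformed_vars(model))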
Example #2
    def __init__(self, model_vars, values=None, model=None, rng=None):
        """Initialize a `TransMatConjugateStep` object."""
        # Assumes the enclosing module imports `pymc3 as pm`, `numpy as np`,
        # itertools' `chain`, theano's `Variable` and `graph_inputs`, and
        # `DiscreteMarkovChain`.

        model = pm.modelcontext(model)

        if isinstance(model_vars, Variable):
            model_vars = [model_vars]

        model_vars = list(chain.from_iterable([pm.inputvars(v) for v in model_vars]))

        # TODO: Are the rows in this matrix our `dir_priors`?
        # Keep only the variables whose untransformed counterparts are
        # Dirichlet-distributed; these are the candidate transition-matrix rows.
        dir_priors = []
        self.dir_priors_untrans = []
        for d in model_vars:
            untrans_var = model.named_vars[get_untransformed_name(d.name)]
            if isinstance(untrans_var.distribution, pm.Dirichlet):
                self.dir_priors_untrans.append(untrans_var)
                dir_priors.append(d)

        # Find the Markov-chain state sequences whose transition matrices
        # (`Gammas`) depend on all of the Dirichlet priors found above.
        state_seqs = [
            v
            for v in model.vars + model.observed_RVs
            if isinstance(v.distribution, DiscreteMarkovChain)
            and all(d in graph_inputs([v.distribution.Gammas]) for d in dir_priors)
        ]

        if not self.dir_priors_untrans or len(state_seqs) != 1:
            raise ValueError(
                "This step method requires a set of Dirichlet priors"
                " that comprise a single transition matrix"
            )

        (state_seq,) = state_seqs

        Gamma = state_seq.distribution.Gammas

        self._set_row_mappings(Gamma, dir_priors, model)

        if len(self.row_remaps) != len(dir_priors):
            raise TypeError(
                "The Dirichlet priors could not be found"
                " in the graph for {}".format(state_seq.distribution.Gammas)
            )

        if state_seq in model.observed_RVs:
            self.state_seq_obs = np.asarray(state_seq.distribution.data)

        self.rng = rng
        self.dists = list(dir_priors)
        self.state_seq_name = state_seq.name

        super().__init__(dir_priors, [], allvars=True)
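
For context, a hypothetical sketch (variable names are illustrative) of the
structure this step method looks for: a transition matrix whose rows are
separate Dirichlet priors, stacked with theano:

import numpy as np
import pymc3 as pm
import theano.tensor as tt

with pm.Model():
    p_0 = pm.Dirichlet("p_0", np.r_[1.0, 1.0])  # row 0 of the transition matrix
    p_1 = pm.Dirichlet("p_1", np.r_[1.0, 1.0])  # row 1
    P = tt.stack([p_0, p_1])  # the `Gammas` a DiscreteMarkovChain would consume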
Example #3
    def step(model):
        # `size` comes from the enclosing scope; `is_transformed_name`,
        # `get_untransformed_name`, and `get_named_nodes_and_relations`
        # come from `pymc3.util`.
        point = {}
        for var in model.basic_RVs:
            var_name = var.name
            if hasattr(var, 'distribution'):
                if is_transformed_name(var_name):
                    # Draw on the original scale, record the value under the
                    # untransformed name, then map it to the transformed space.
                    val = var.distribution.dist.random(
                        point=point, size=size)
                    untransformed_name = get_untransformed_name(var_name)
                    point[untransformed_name] = val
                    val = var.distribution.transform_used.forward_val(val)
                else:
                    val = var.distribution.random(point=point, size=size)
            else:
                # No distribution of its own: evaluate the node from the
                # values drawn so far.
                nn, _, _ = get_named_nodes_and_relations(var)
                val = var.eval({model.named_vars[v]: point[v]
                                for v in nn})

            point[var_name] = val
        return point
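
The transformed branch above can be exercised directly. A small sketch,
assuming pymc3 3.x (the HalfNormal is illustrative):

import pymc3 as pm

with pm.Model() as m:
    pm.HalfNormal("sigma", 1.0)

free = m.free_RVs[0]  # "sigma_log__", the transformed free RV
draw = free.distribution.dist.random()  # draw of sigma on the original scale
logval = free.distribution.transform_used.forward_val(draw)  # log-scale value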
Example #4
import os
import sys

from pymc3.util import get_untransformed_name, is_transformed_name


def start_optimizer(vars, verbose=True, progress_bar=True, **kwargs):
    if verbose:
        names = [
            get_untransformed_name(v.name) if is_transformed_name(v.name)
            else v.name
            for v in vars
        ]
        sys.stderr.write("optimizing logp for variables: [{0}]\n".format(
            ", ".join(names)))

        if progress_bar is True:
            # Fall back to the plain terminal bar when auto-detection is
            # disabled via the environment.
            if "EXOPLANET_NO_AUTO_PBAR" in os.environ:
                from tqdm import tqdm
            else:
                from tqdm.auto import tqdm

            progress_bar = tqdm(**kwargs)

    # Check whether the input progress bar has the expected methods
    has_progress_bar = (hasattr(progress_bar, "set_postfix")
                        and hasattr(progress_bar, "update")
                        and hasattr(progress_bar, "close"))
    return has_progress_bar, progress_bar
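
A hypothetical call, assuming a pymc3 model in scope; extra keyword arguments
(here `total`) are forwarded to the tqdm constructor:

import pymc3 as pm

with pm.Model() as model:
    pm.Normal("x", 0.0, 1.0)

has_pbar, pbar = start_optimizer(model.vars, total=1000)
if has_pbar:
    pbar.close()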
Example #5
# Note: this function relies on module-level imports (`numpy as np`,
# `pymc3 as pm`, `theano`, `tqdm`, `sys`, a `logger`) and on helpers such as
# `Point`, `inputvars`, `allinmodel`, `update_start_vals`, `ArrayOrdering`,
# `DictToArrayBijection`, `get_default_varnames`,
# `get_theano_function_for_var`, and `get_args_for_theano_function`.
def optimize(start=None,
             vars=None,
             model=None,
             return_info=False,
             verbose=True,
             **kwargs):
    """Maximize the log prob of a PyMC3 model using scipy

    All extra arguments are passed directly to the ``scipy.optimize.minimize``
    function.

    Args:
        start: The PyMC3 coordinate dictionary of the starting position
        vars: The variables to optimize
        model: The PyMC3 model
        return_info: Return both the coordinate dictionary and the result of
            ``scipy.optimize.minimize``
        verbose: Print the success flag and log probability to the screen

    """
    from scipy.optimize import minimize

    model = pm.modelcontext(model)

    # Work out the full starting coordinates
    if start is None:
        start = model.test_point
    else:
        update_start_vals(start, model.test_point, model)

    # Fit all the parameters by default
    if vars is None:
        vars = model.cont_vars
    vars = inputvars(vars)
    allinmodel(vars, model)

    # Work out the relevant bijection map
    start = Point(start, model=model)
    bij = DictToArrayBijection(ArrayOrdering(vars), start)

    # Pre-compile the theano model and gradient
    nlp = -model.logpt
    grad = theano.grad(nlp, vars, disconnected_inputs="ignore")
    func = get_theano_function_for_var([nlp] + grad, model=model)

    if verbose:
        names = [
            get_untransformed_name(v.name) if is_transformed_name(v.name)
            else v.name
            for v in vars
        ]
        sys.stderr.write("optimizing logp for variables: [{0}]\n".format(
            ", ".join(names)))
        bar = tqdm.tqdm()

    # The objective returns the negative log probability and its gradient
    def objective(vec):
        res = func(*get_args_for_theano_function(bij.rmap(vec), model=model))
        d = dict(zip((v.name for v in vars), res[1:]))
        g = bij.map(d)
        if verbose:
            bar.set_postfix(logp="{0:e}".format(-res[0]))
            bar.update()
        return res[0], g

    # Optimize using scipy.optimize
    x0 = bij.map(start)
    initial = objective(x0)[0]
    kwargs["jac"] = True
    info = minimize(objective, x0, **kwargs)

    # Only accept the result if it is finite and better than the starting point
    x = info.x if (np.isfinite(info.fun) and info.fun < initial) else x0

    # Coerce the output into the right format
    vars = get_default_varnames(model.unobserved_RVs, True)
    point = {
        var.name: value
        for var, value in zip(vars,
                              model.fastfn(vars)(bij.rmap(x)))
    }

    if verbose:
        bar.close()
        sys.stderr.write("message: {0}\n".format(info.message))
        sys.stderr.write("logp: {0} -> {1}\n".format(-initial, -info.fun))
        if not np.isfinite(info.fun):
            logger.warning("final logp not finite, returning initial point")
            logger.warning(
                "this suggests that something is wrong with the model")
            logger.debug("{0}".format(info))

    if return_info:
        return point, info
    return point
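
A usage sketch for `optimize` on a toy linear model; the data and the
`method` keyword are illustrative, and extra kwargs are passed on to
``scipy.optimize.minimize``:

import numpy as np
import pymc3 as pm

x = np.linspace(0.0, 1.0, 50)
y = 2.0 * x + 1.0 + 0.1 * np.random.randn(50)

with pm.Model():
    m = pm.Normal("m", 0.0, 10.0)
    b = pm.Normal("b", 0.0, 10.0)
    pm.Normal("obs", mu=m * x + b, sigma=1.0, observed=y)

    # Returns the optimized coordinate dictionary; pass return_info=True to
    # also get the scipy result.
    map_soln = optimize(vars=[m, b], method="L-BFGS-B")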
Example #6
    def _run_convergence_checks(self, idata: arviz.InferenceData, model):
        if not hasattr(idata, "posterior"):
            msg = "No posterior samples. Unable to run convergence checks"
            warn = SamplerWarning(WarningType.BAD_PARAMS, msg, "info")
            self._add_warnings([warn])
            return

        if idata.posterior.sizes["chain"] == 1:
            msg = ("Only one chain was sampled, this makes it impossible to "
                   "run some convergence checks")
            warn = SamplerWarning(WarningType.BAD_PARAMS, msg, "info")
            self._add_warnings([warn])
            return

        valid_name = [rv.name for rv in model.free_RVs + model.deterministics]
        varnames = []
        for rv in model.free_RVs:
            rv_name = rv.name
            if is_transformed_name(rv_name):
                rv_name2 = get_untransformed_name(rv_name)
                rv_name = rv_name2 if rv_name2 in valid_name else rv_name
            if rv_name in idata.posterior:
                varnames.append(rv_name)

        self._ess = ess = arviz.ess(idata, var_names=varnames)
        self._rhat = rhat = arviz.rhat(idata, var_names=varnames)

        # Graded rhat thresholds: > 1.4 is an error, > 1.2 a warning,
        # > 1.05 informational.
        warnings = []
        rhat_max = max(val.max() for val in rhat.values())
        if rhat_max > 1.4:
            msg = ("The rhat statistic is larger than 1.4 for some "
                   "parameters. The sampler did not converge.")
            warn = SamplerWarning(WarningType.CONVERGENCE,
                                  msg,
                                  "error",
                                  extra=rhat)
            warnings.append(warn)
        elif rhat_max > 1.2:
            msg = "The rhat statistic is larger than 1.2 for some " "parameters."
            warn = SamplerWarning(WarningType.CONVERGENCE,
                                  msg,
                                  "warn",
                                  extra=rhat)
            warnings.append(warn)
        elif rhat_max > 1.05:
            msg = ("The rhat statistic is larger than 1.05 for some "
                   "parameters. This indicates slight problems during "
                   "sampling.")
            warn = SamplerWarning(WarningType.CONVERGENCE,
                                  msg,
                                  "info",
                                  extra=rhat)
            warnings.append(warn)

        # Graded effective-sample-size thresholds, relative to the total
        # number of draws.
        eff_min = min(val.min() for val in ess.values())
        sizes = idata.posterior.sizes
        n_samples = sizes["chain"] * sizes["draw"]
        if eff_min < 200 and n_samples >= 500:
            msg = ("The estimated number of effective samples is smaller than "
                   "200 for some parameters.")
            warn = SamplerWarning(WarningType.CONVERGENCE,
                                  msg,
                                  "error",
                                  extra=ess)
            warnings.append(warn)
        elif eff_min / n_samples < 0.1:
            msg = "The number of effective samples is smaller than " "10% for some parameters."
            warn = SamplerWarning(WarningType.CONVERGENCE,
                                  msg,
                                  "warn",
                                  extra=ess)
            warnings.append(warn)
        elif eff_min / n_samples < 0.25:
            msg = "The number of effective samples is smaller than " "25% for some parameters."
            warn = SamplerWarning(WarningType.CONVERGENCE,
                                  msg,
                                  "info",
                                  extra=ess)
            warnings.append(warn)

        self._add_warnings(warnings)