Example #1
    def step(self, point: PointType):

        # Wrap each stored function so it can be called on raveled values; the
        # current point supplies the values of any variables outside the mapping
        partial_funcs_and_point = [
            DictToArrayBijection.mapf(x, start_point=point) for x in self.fs
        ]
        if self.allvars:
            partial_funcs_and_point.append(point)

        # Ravel this step method's variables into a single flat array
        apoint = DictToArrayBijection.map({v.name: point[v.name] for v in self.vars})
        step_res = self.astep(apoint, *partial_funcs_and_point)

        if self.generates_stats:
            apoint_new, stats = step_res
        else:
            apoint_new = step_res

        if not isinstance(apoint_new, RaveledVars):
            # We assume that the mapping has stayed the same
            apoint_new = RaveledVars(apoint_new, apoint.point_map_info)

        # Unravel the result back into a point dict; variables this step method
        # does not handle keep their values from `point`
        point_new = DictToArrayBijection.rmap(apoint_new, start_point=point)

        if self.generates_stats:
            return point_new, stats

        return point_new
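
The `step` method above leans on `DictToArrayBijection` to ravel a point dict into one flat array and back. Below is a minimal round-trip sketch; the variable names and values are made up for illustration, and the import path assumes a recent PyMC where `DictToArrayBijection` and `RaveledVars` live in `pymc.blocking`.

import numpy as np

from pymc.blocking import DictToArrayBijection, RaveledVars

# Hypothetical point dict with variables of different shapes
point = {"mu": np.array(0.5), "sigma_log__": np.array([0.1, -0.2])}

# map() ravels the dict into a single flat array plus the bookkeeping
# (names, shapes, dtypes) needed to invert the mapping
apoint = DictToArrayBijection.map(point)
assert isinstance(apoint, RaveledVars)
print(apoint.data)  # [ 0.5  0.1 -0.2]

# rmap() reconstructs the dict; start_point supplies values for any
# variables that are not part of the raveled array
point_back = DictToArrayBijection.rmap(apoint, start_point=point)
print(point_back["sigma_log__"])  # [ 0.1 -0.2]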
Example #2
def dlogp_func(x):
    # Wrap the model's gradient of the log-posterior (without the Jacobian term)
    # so it can be called on a flat array, using the mapping recorded in `x0`
    return DictToArrayBijection.mapf(model.fastdlogp_nojac(vars))(
        RaveledVars(x, x0.point_map_info)
    )
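
For clarity, `DictToArrayBijection.mapf(f)` composes `f` with `rmap`, so the wrapped function accepts a flat array and hands a point dict to the compiled gradient. A rough expansion of the snippet above, as a sketch that assumes the same `model`, `vars`, and `x0`:

def dlogp_func_unrolled(x):
    # Rebuild the point dict from the flat array, then evaluate the gradient on it
    point = DictToArrayBijection.rmap(RaveledVars(x, x0.point_map_info))
    return model.fastdlogp_nojac(vars)(point)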
Example #3
def find_MAP(
    start=None,
    vars=None,
    method="L-BFGS-B",
    return_raw=False,
    include_transformed=True,
    progressbar=True,
    maxeval=5000,
    model=None,
    *args,
    seed: Optional[int] = None,
    **kwargs
):
    """Finds the local maximum a posteriori point given a model.

    `find_MAP` should not be used to initialize the NUTS sampler. Simply call
    ``pymc.sample()`` and it will automatically initialize NUTS in a better
    way.

    Parameters
    ----------
    start: `dict` of parameter values (Defaults to `model.initial_point`)
    vars: list
        List of variables to optimize and set to optimum (Defaults to all continuous).
    method: string or callable
        Optimization algorithm. Defaults to 'L-BFGS-B' unless discrete variables
        are specified in `vars`, in which case `Powell`, which performs better,
        is used. For instructions on the use of a callable, refer to SciPy's
        documentation of `optimize.minimize`.
    return_raw: bool
        Whether to return the full output of scipy.optimize.minimize (Defaults to `False`)
    include_transformed: bool, optional, defaults to True
        Flag for reporting automatically transformed variables in addition
        to the original variables.
    progressbar: bool, optional, defaults to True
        Whether or not to display a progress bar in the command line.
    maxeval: int, optional, defaults to 5000
        The maximum number of times the posterior distribution is evaluated.
    model: Model (optional if in `with` context)
    *args, **kwargs
        Extra args passed to scipy.optimize.minimize

    Notes
    -----
    Older code examples used `find_MAP` to initialize the NUTS sampler,
    but this is not an effective way of choosing starting values for sampling.
    The initialization of NUTS has since been greatly improved and wrapped
    inside ``pymc.sample()``, so you should avoid using this method for that purpose.
    """
    model = modelcontext(model)

    if vars is None:
        vars = model.cont_vars
        if not vars:
            raise ValueError("Model has no unobserved continuous variables.")
    else:
        vars = [model.rvs_to_values.get(var, var) for var in vars]

    vars = inputvars(vars)
    disc_vars = list(typefilter(vars, discrete_types))
    allinmodel(vars, model)
    ipfn = make_initial_point_fn(
        model=model,
        jitter_rvs=set(),
        return_transformed=True,
        overrides=start,
    )
    start = ipfn(seed)
    model.check_start_vals(start)

    var_names = {var.name for var in vars}
    x0 = DictToArrayBijection.map(
        {var_name: value for var_name, value in start.items() if var_name in var_names}
    )

    # TODO: If the mapping is fixed, we can simply create graphs for the
    # mapping and avoid all this bijection overhead
    compiled_logp_func = DictToArrayBijection.mapf(model.compile_logp(jacobian=False), start)
    logp_func = lambda x: compiled_logp_func(RaveledVars(x, x0.point_map_info))

    rvs = [model.values_to_rvs[value] for value in vars]
    try:
        # This might be needed for calls to `dlogp_func`
        # start_map_info = tuple((v.name, v.shape, v.dtype) for v in vars)
        compiled_dlogp_func = DictToArrayBijection.mapf(
            model.compile_dlogp(rvs, jacobian=False), start
        )
        dlogp_func = lambda x: compiled_dlogp_func(RaveledVars(x, x0.point_map_info))
        compute_gradient = True
    except (AttributeError, NotImplementedError, tg.NullTypeGradError):
        compute_gradient = False

    if disc_vars or not compute_gradient:
        pm._log.warning(
            "Warning: gradient not available. "
            + "(E.g. vars contains discrete variables). MAP "
            + "estimates may not be accurate for the default "
            + "parameters. Defaulting to non-gradient minimization "
            + "'Powell'."
        )
        method = "Powell"

    if compute_gradient:
        cost_func = CostFuncWrapper(maxeval, progressbar, logp_func, dlogp_func)
    else:
        cost_func = CostFuncWrapper(maxeval, progressbar, logp_func)

    try:
        opt_result = minimize(
            cost_func, x0.data, method=method, jac=compute_gradient, *args, **kwargs
        )
        mx0 = opt_result["x"]
    except (KeyboardInterrupt, StopIteration) as e:
        mx0, opt_result = cost_func.previous_x, None
        if isinstance(e, StopIteration):
            pm._log.info(e)
    finally:
        last_v = cost_func.n_eval
        if progressbar:
            assert isinstance(cost_func.progress, ProgressBar)
            cost_func.progress.total = last_v
            cost_func.progress.update(last_v)
            print(file=sys.stdout)

    mx0 = RaveledVars(mx0, x0.point_map_info)
    unobserved_vars = get_default_varnames(model.unobserved_value_vars, include_transformed)
    unobserved_vars_values = model.compile_fn(unobserved_vars)(
        DictToArrayBijection.rmap(mx0, start)
    )
    mx = {var.name: value for var, value in zip(unobserved_vars, unobserved_vars_values)}

    if return_raw:
        return mx, opt_result
    else:
        return mx
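
A minimal usage sketch of `find_MAP`; the model below is made up for illustration and assumes the current `pymc` import layout:

import pymc as pm

with pm.Model() as model:
    mu = pm.Normal("mu", 0.0, 10.0)
    sigma = pm.HalfNormal("sigma", 1.0)
    pm.Normal("obs", mu=mu, sigma=sigma, observed=[0.1, -0.3, 0.7, 1.2])

    # MAP estimate; transformed variables (e.g. sigma_log__) are included
    # because include_transformed defaults to True
    map_estimate = pm.find_MAP(method="L-BFGS-B")
    print(map_estimate["mu"], map_estimate["sigma"])

    # return_raw=True also returns the scipy.optimize.OptimizeResult
    map_estimate, opt_result = pm.find_MAP(return_raw=True)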