Example #1
    def step(self, point):

        for name, shared_var in self.shared.items():
            shared_var.set_value(point[name])

        q = DictToArrayBijection.map(
            {v.name: point[v.name]
             for v in self.vars})

        step_res = self.astep(q)

        if self.generates_stats:
            apoint, stats = step_res
        else:
            apoint = step_res

        if not isinstance(apoint, RaveledVars):
            # We assume that the mapping has stayed the same
            apoint = RaveledVars(apoint, q.point_map_info)

        new_point = DictToArrayBijection.rmap(apoint, start_point=point)

        if self.generates_stats:
            return new_point, stats

        return new_point
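
The step() wrapper above hinges on the map/rmap round trip provided by DictToArrayBijection: map flattens a {name: value} point into a single raveled array plus the bookkeeping needed to undo it, and rmap reverses the flattening, taking any variables not present in the raveled array from start_point. A minimal round-trip sketch, assuming the pymc.blocking import path of recent PyMC versions (pymc3.blocking in older ones); the variable names are illustrative:

import numpy as np
from pymc.blocking import DictToArrayBijection, RaveledVars

point = {"mu": np.array(0.5), "sigma_log__": np.array([0.1, -0.3])}

# map: {name: array} -> RaveledVars (one flat data array + per-variable metadata)
q = DictToArrayBijection.map(point)
assert isinstance(q, RaveledVars)
print(q.data)            # 1-D concatenation of all raveled values
print(q.point_map_info)  # records needed to undo the raveling

# rmap: RaveledVars -> {name: array}; start_point supplies any untouched entries
roundtrip = DictToArrayBijection.rmap(q, start_point=point)
print(roundtrip)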
Example #2
    def step(self, point: PointType):

        partial_funcs_and_point = [
            DictToArrayBijection.mapf(x, start_point=point) for x in self.fs
        ]
        if self.allvars:
            partial_funcs_and_point.append(point)

        apoint = DictToArrayBijection.map(
            {v.name: point[v.name]
             for v in self.vars})
        step_res = self.astep(apoint, *partial_funcs_and_point)

        if self.generates_stats:
            apoint_new, stats = step_res
        else:
            apoint_new = step_res

        if not isinstance(apoint_new, RaveledVars):
            # We assume that the mapping has stayed the same
            apoint_new = RaveledVars(apoint_new, apoint.point_map_info)

        point_new = DictToArrayBijection.rmap(apoint_new, start_point=point)

        if self.generates_stats:
            return point_new, stats

        return point_new
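
Example #2 differs from Example #1 mainly in the partial functions handed to astep(): DictToArrayBijection.mapf wraps a function that expects a point dict so that it can instead be called with a RaveledVars. A hedged sketch of that wrapping, with an illustrative stand-in for one of the functions collected in self.fs:

import numpy as np
from pymc.blocking import DictToArrayBijection

def point_fn(point):
    # Stand-in for one of the dict-valued functions collected in self.fs.
    return sum(float(np.asarray(v).sum()) for v in point.values())

q = DictToArrayBijection.map({"mu": np.array([0.0, 1.0]), "tau": np.array(2.0)})

# mapf returns a new callable that accepts a RaveledVars: it rmaps the raveled
# array back to a dict (completed from start_point, if given) and calls point_fn.
raveled_fn = DictToArrayBijection.mapf(point_fn, start_point=None)
print(raveled_fn(q))   # same value as calling point_fn on the original dict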
Example #3
    def astep(self, q0):
        """Perform a single HMC iteration."""
        perf_start = time.perf_counter()
        process_start = time.process_time()

        p0 = self.potential.random()
        p0 = RaveledVars(p0, q0.point_map_info)

        start = self.integrator.compute_state(q0, p0)

        if not np.isfinite(start.energy):
            model = self._model
            check_test_point = model.point_logps()
            error_logp = check_test_point.loc[
                (np.abs(check_test_point) >= 1e20) | np.isnan(check_test_point)
            ]
            self.potential.raise_ok(q0.point_map_info)
            message_energy = (
                "Bad initial energy, check any log probabilities that "
                "are inf or -inf, nan or very small:\n{}".format(error_logp.to_string())
            )
            warning = SamplerWarning(
                WarningType.BAD_ENERGY,
                message_energy,
                "critical",
                self.iter_count,
            )
            self._warnings.append(warning)
            raise SamplingError("Bad initial energy")

        adapt_step = self.tune and self.adapt_step_size
        step_size = self.step_adapt.current(adapt_step)
        self.step_size = step_size

        if self._step_rand is not None:
            step_size = self._step_rand(step_size)

        hmc_step = self._hamiltonian_step(start, p0.data, step_size)

        perf_end = time.perf_counter()
        process_end = time.process_time()

        self.step_adapt.update(hmc_step.accept_stat, adapt_step)
        self.potential.update(hmc_step.end.q, hmc_step.end.q_grad, self.tune)
        if hmc_step.divergence_info:
            info = hmc_step.divergence_info
            point = None
            point_dest = None
            info_store = None
            if self.tune:
                kind = WarningType.TUNING_DIVERGENCE
            else:
                kind = WarningType.DIVERGENCE
                self._num_divs_sample += 1
                # We don't want to fill up all memory with divergence info
                if self._num_divs_sample < 100 and info.state is not None:
                    point = DictToArrayBijection.rmap(info.state.q)

                if self._num_divs_sample < 100 and info.state_div is not None:
                    point_dest = DictToArrayBijection.rmap(info.state_div.q)

                if self._num_divs_sample < 100:
                    info_store = info
            warning = SamplerWarning(
                kind,
                info.message,
                "debug",
                self.iter_count,
                info.exec_info,
                divergence_point_source=point,
                divergence_point_dest=point_dest,
                divergence_info=info_store,
            )

            self._warnings.append(warning)

        self.iter_count += 1
        if not self.tune:
            self._samples_after_tune += 1

        stats = {
            "tune": self.tune,
            "diverging": bool(hmc_step.divergence_info),
            "perf_counter_diff": perf_end - perf_start,
            "process_time_diff": process_end - process_start,
            "perf_counter_start": perf_start,
        }

        stats.update(hmc_step.stats)
        stats.update(self.step_adapt.stats())

        return hmc_step.end.q, [stats]
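
The stats dict assembled at the end of astep() is what ultimately surfaces per draw as sampler statistics. A hedged sketch of inspecting a few of the keys set above, assuming a recent PyMC where pm.sample() returns an ArviZ InferenceData by default; the toy model is illustrative:

import numpy as np
import pymc as pm

with pm.Model():
    mu = pm.Normal("mu", 0.0, 1.0)
    pm.Normal("obs", mu, 1.0, observed=np.random.randn(50))
    idata = pm.sample(draws=200, tune=200, chains=1, progressbar=False)

# Keys written into `stats` by astep() above show up per draw in sample_stats
# (exact stat names may vary slightly between PyMC versions):
print(int(idata.sample_stats["diverging"].sum()))
print(float(idata.sample_stats["perf_counter_diff"].mean()))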
Example #4
File: mlda.py  Project: t-triobox/pymc3
    def __call__(self, q0: RaveledVars) -> RaveledVars:
        """Returns proposed sample given the current sample
        in dictionary form (q0_dict)."""

        # Logging is reduced to avoid extensive console output
        # during multiple recursive calls of subsample()
        _log = logging.getLogger("pymc")
        _log.setLevel(logging.ERROR)

        # Convert current sample from RaveledVars ->
        # dict before feeding to subsample.
        q0_dict = DictToArrayBijection.rmap(q0)

        with self.model_below:
            # Check if the tuning flag has been set to False
            # in which case tuning is stopped. The flag is set
            # to False (by MLDA's astep) when the burn-in
            # iterations of the highest-level MLDA sampler run out.
            # The change propagates to all levels.

            if self.tune:
                # Subsample in tuning mode
                trace = subsample(
                    draws=0,
                    step=self.step_method_below,
                    start=q0_dict,
                    tune=self.subsampling_rate,
                )
            else:
                # Subsample in normal mode without tuning
                # If DEMetropolisZMLDA is the base sampler a flag is raised to
                # make sure that history is edited after tuning ends
                if self.tuning_end_trigger:
                    if isinstance(self.step_method_below, DEMetropolisZMLDA):
                        self.step_method_below.tuning_end_trigger = True
                    self.tuning_end_trigger = False

                trace = subsample(
                    draws=self.subsampling_rate,
                    step=self.step_method_below,
                    start=q0_dict,
                    tune=0,
                )

        # set logging back to normal
        _log.setLevel(logging.NOTSET)

        # return sample with index self.subchain_selection from the generated
        # sequence of length self.subsampling_rate. The index is set within
        # MLDA's astep() function
        q_dict = trace.point(self.subchain_selection)

        # Make sure output dict is ordered the same way as the input dict.
        q_dict = Point(
            {key: q_dict[key]
             for key in q0_dict.keys()},
            model=self.model_below,
            filter_model_vars=True,
        )

        return DictToArrayBijection.map(q_dict)
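
Viewed from the outside, the proposal above is a RaveledVars -> RaveledVars function: it unravels the current state, runs a short subchain on the coarser model, and re-ravels the selected point. A hedged sketch of that contract; the identity `proposal` below is only a stand-in for an instance of the class defining __call__ above, and the variable names are illustrative:

import numpy as np
from pymc.blocking import DictToArrayBijection

def proposal(q0):
    # Identity stand-in for the __call__ above; a real proposal runs subsample()
    # on the coarser model and returns the selected point re-raveled.
    return q0

start_dict = {"theta": np.array([0.1, -0.2])}   # illustrative value-variable dict

q0 = DictToArrayBijection.map(start_dict)       # current state as RaveledVars
q_new = proposal(q0)                            # proposed state, same layout
print(DictToArrayBijection.rmap(q_new))         # back to {name: value} if needed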
Example #5
def find_MAP(start=None,
             vars=None,
             method="L-BFGS-B",
             return_raw=False,
             include_transformed=True,
             progressbar=True,
             maxeval=5000,
             model=None,
             *args,
             seed: Optional[int] = None,
             **kwargs):
    """Finds the local maximum a posteriori point given a model.

    `find_MAP` should not be used to initialize the NUTS sampler. Simply call
    ``pymc.sample()`` and it will automatically initialize NUTS in a better
    way.

    Parameters
    ----------
    start: `dict` of parameter values (Defaults to `model.initial_point`)
    vars: list
        List of variables to optimize and set to optimum (Defaults to all continuous).
    method: string or callable
        Optimization algorithm. Defaults to 'L-BFGS-B' unless discrete variables
        are specified in `vars`, in which case `Powell`, which performs better,
        is used. For instructions on the use of a callable, refer to SciPy's
        documentation of `optimize.minimize`.
    return_raw: bool
        Whether to return the full output of scipy.optimize.minimize (Defaults to `False`)
    include_transformed: bool, optional defaults to True
        Flag for reporting automatically transformed variables in addition
        to original variables.
    progressbar: bool, optional defaults to True
        Whether or not to display a progress bar in the command line.
    maxeval: int, optional, defaults to 5000
        The maximum number of times the posterior distribution is evaluated.
    model: Model (optional if in `with` context)
    *args, **kwargs
        Extra args passed to scipy.optimize.minimize

    Notes
    -----
    Older code examples used `find_MAP` to initialize the NUTS sampler, but
    this is not an effective way of choosing starting values for sampling.
    NUTS initialization has since been greatly improved and is handled
    automatically inside ``pymc.sample()``, so `find_MAP` should not be used
    for that purpose.
    """
    model = modelcontext(model)

    if vars is None:
        vars = model.cont_vars
        if not vars:
            raise ValueError("Model has no unobserved continuous variables.")
    vars = inputvars(vars)
    disc_vars = list(typefilter(vars, discrete_types))
    allinmodel(vars, model)
    ipfn = make_initial_point_fn(
        model=model,
        jitter_rvs={},
        return_transformed=True,
        overrides=start,
    )
    if seed is None:
        seed = model.rng_seeder.randint(2**30, dtype=np.int64)
    start = ipfn(seed)
    model.check_start_vals(start)

    x0 = DictToArrayBijection.map(start)

    # TODO: If the mapping is fixed, we can simply create graphs for the
    # mapping and avoid all this bijection overhead
    def logp_func(x):
        return DictToArrayBijection.mapf(model.fastlogp_nojac)(RaveledVars(
            x, x0.point_map_info))

    try:
        # This might be needed for calls to `dlogp_func`
        # start_map_info = tuple((v.name, v.shape, v.dtype) for v in vars)

        def dlogp_func(x):
            return DictToArrayBijection.mapf(model.fastdlogp_nojac(vars))(
                RaveledVars(x, x0.point_map_info))

        compute_gradient = True
    except (AttributeError, NotImplementedError, tg.NullTypeGradError):
        compute_gradient = False

    if disc_vars or not compute_gradient:
        pm._log.warning(
            "Warning: gradient not available. "
            "(E.g. vars contains discrete variables.) MAP "
            "estimates may not be accurate for the default "
            "parameters. Defaulting to non-gradient minimization "
            "'Powell'."
        )
        method = "Powell"

    if compute_gradient:
        cost_func = CostFuncWrapper(maxeval, progressbar, logp_func,
                                    dlogp_func)
    else:
        cost_func = CostFuncWrapper(maxeval, progressbar, logp_func)

    try:
        opt_result = minimize(cost_func,
                              x0.data,
                              method=method,
                              jac=compute_gradient,
                              *args,
                              **kwargs)
        mx0 = opt_result["x"]  # r -> opt_result
    except (KeyboardInterrupt, StopIteration) as e:
        mx0, opt_result = cost_func.previous_x, None
        if isinstance(e, StopIteration):
            pm._log.info(e)
    finally:
        last_v = cost_func.n_eval
        if progressbar:
            assert isinstance(cost_func.progress, ProgressBar)
            cost_func.progress.total = last_v
            cost_func.progress.update(last_v)
            print(file=sys.stdout)

    mx0 = RaveledVars(mx0, x0.point_map_info)

    vars = get_default_varnames(model.unobserved_value_vars,
                                include_transformed)
    mx = {
        var.name: value
        for var, value in zip(
            vars,
            model.fastfn(vars)(DictToArrayBijection.rmap(mx0)))
    }

    if return_raw:
        return mx, opt_result
    else:
        return mx
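
A hedged usage sketch for find_MAP; the model and data are illustrative, and the import assumes the pymc package name (use `import pymc3 as pm` for PyMC3):

import numpy as np
import pymc as pm

rng = np.random.default_rng(42)
data = rng.normal(loc=1.0, scale=2.0, size=200)

with pm.Model():
    mu = pm.Normal("mu", 0.0, 10.0)
    sigma = pm.HalfNormal("sigma", 5.0)
    pm.Normal("obs", mu=mu, sigma=sigma, observed=data)

    # Dict of MAP values; transformed variables (e.g. sigma_log__) are included
    # because include_transformed defaults to True.
    map_estimate = pm.find_MAP(progressbar=False)

    # return_raw=True additionally returns the scipy.optimize.OptimizeResult.
    map_estimate, opt_result = pm.find_MAP(return_raw=True, progressbar=False)

print(map_estimate["mu"], map_estimate["sigma"])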