Example #1
def check_proposal_functions(
    model: Model,
    state: Optional[flow.SamplingState] = None,
    observed: Optional[dict] = None,
) -> bool:
    """
    Check for the non-default proposal generation functions

    Parameters
    ----------
    model : pymc4.Model
        Model to sample posterior for
    state : Optional[flow.SamplingState]
        Current state
    observed : Optional[Dict[str, Any]]
        Observed values (optional)
    """
    (_, state, _, _, continuous_distrs, discrete_distrs) = initialize_state(
        model, observed=observed, state=state
    )
    # only the variable names matter here; the values themselves are unused
    init_keys = list(state.all_unobserved_values.keys())

    for init_key in init_keys:
        untrs_var, _ = scope_remove_transformed_part_if_required(
            init_key, state.transformed_values
        )
        # get the distribution for the random variable name
        distr = continuous_distrs.get(untrs_var, None)
        if distr is None:
            distr = discrete_distrs[untrs_var]
        # a custom proposal function is signalled by a callable attribute
        func = distr._default_new_state_part
        if callable(func):
            return True
    return False
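
For context, a minimal sketch of how this check might be used. The generator-style model follows pymc4 conventions; the model body itself is a made-up example, not part of the snippet above.

import pymc4 as pm

@pm.model
def toy_model():
    # a made-up two-variable model; any pymc4 model works here
    x = yield pm.Normal("x", 0.0, 1.0)
    y = yield pm.Normal("y", x, 1.0)

# True only if some free variable's distribution defines a custom
# proposal via `_default_new_state_part` (e.g. certain discrete distributions)
has_custom_proposals = check_proposal_functions(toy_model())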
Example #2
    def _sample(
        self,
        *,
        num_samples: int = 1000,
        num_chains: int = 10,
        burn_in: int = 100,
        observed: Optional[dict] = None,
        state: Optional[flow.SamplingState] = None,
        use_auto_batching: bool = True,
        xla: bool = False,
        seed: Optional[int] = None,
        is_compound: bool = False,
        trace_discrete: Optional[List[str]] = None,
    ):
        if state is not None and observed is not None:
            raise ValueError("Can't use both `state` and `observed` arguments")
        (
            logpfn,
            init,
            _deterministics_callback,
            deterministic_names,
            state_,
        ) = build_logp_and_deterministic_functions(
            self.model,
            num_chains=num_chains,
            state=state,
            observed=observed,
            collect_reduced_log_prob=use_auto_batching,
            parent_inds=self.parent_inds if is_compound else None,
        )

        init_state = list(init.values())
        init_keys = list(init.keys())

        if is_compound:
            init_state = [init_state[i] for i in self.parent_inds]
            init_keys = [init_keys[i] for i in self.parent_inds]

        if use_auto_batching:
            self.parallel_logpfn = vectorize_logp_function(logpfn)
            self.deterministics_callback = vectorize_logp_function(
                _deterministics_callback)
        else:
            self.parallel_logpfn = logpfn
            self.deterministics_callback = _deterministics_callback
        # in both cases the initial state is tiled to one copy per chain
        init_state = tile_init(init_state, num_chains)

        # TODO: problem with tf.function when passing as argument to self._run_chains
        self._num_samples = num_samples
        self.seed = seed

        if xla:
            results, sample_stats = tf.xla.experimental.compile(
                self._run_chains,
                inputs=[init_state, burn_in],
            )
        else:
            results, sample_stats = self._run_chains(init_state, burn_in)

        posterior = dict(zip(init_keys, results))

        if trace_discrete:
            # TODO: maybe better logic can be written here
            # Workaround to cast variables post-sampling: `trace_discrete`
            # lists the variables that need to be cast to tf.int32 after
            # sampling is completed.
            init_keys_ = [
                scope_remove_transformed_part_if_required(
                    key, state_.transformed_values)[1] for key in init_keys
            ]
            discrete_indices = [init_keys_.index(name) for name in trace_discrete]
            keys_to_cast = [init_keys[i] for i in discrete_indices]
            for key in keys_to_cast:
                posterior[key] = tf.cast(posterior[key], dtype=tf.int32)

        # Keep in sync with pymc3 naming convention
        deterministic_values = []
        if len(sample_stats) > len(self.stat_names):
            deterministic_values = sample_stats[len(self.stat_names):]
            sample_stats = sample_stats[:len(self.stat_names)]
        sampler_stats = dict(zip(self.stat_names, sample_stats))
        if len(deterministic_names) > 0:
            posterior.update(
                dict(zip(deterministic_names, deterministic_values)))

        return trace_to_arviz(
            posterior,
            None if is_compound else sampler_stats,
            observed_data=state_.observed_values,
        )
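
A hedged sketch of the two helpers `_sample` leans on, `vectorize_logp_function` and `tile_init`. pymc4's actual implementations may differ in detail; this only illustrates the idea of mapping the log-prob over a leading chain axis and replicating the initial state per chain.

import tensorflow as tf

def vectorize_logp_function(logpfn):
    # map the single-chain log-prob function over the leading (chain) axis
    def vectorized_logpfn(*state):
        return tf.vectorized_map(lambda mini_state: logpfn(*mini_state), state)
    return vectorized_logpfn

def tile_init(init, num_chains):
    # replicate every initial tensor along a new leading chain dimension
    return [
        tf.tile(tf.expand_dims(tens, 0), [num_chains] + [1] * len(tens.shape))
        for tens in init
    ]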
Example #3
    def _assign_default_methods(
        self,
        *,
        sampler_methods: Optional[List] = None,
        state: Optional[flow.SamplingState] = None,
        observed: Optional[dict] = None,
    ):
        converted_sampler_methods: List = CompoundStep._convert_sampler_methods(
            sampler_methods)

        (_, state, _, _, continuous_distrs,
         discrete_distrs) = initialize_state(self.model,
                                             observed=observed,
                                             state=state)
        init_keys = list(state.all_unobserved_values.keys())

        # samplers assigned to the free variables
        make_kernel_fn: list = []
        # user-provided kwargs for each sampler in `make_kernel_fn`
        part_kernel_kwargs: list = []
        # list of proposal function names
        func_names: list = []

        for init_key in init_keys:
            untrs_var, unscoped_tr_var = scope_remove_transformed_part_if_required(
                init_key, state.transformed_values)
            # get the distribution for the random variable name

            distr = continuous_distrs.get(untrs_var, None)
            if distr is None:
                distr = discrete_distrs[untrs_var]

            # get custom `new_state_fn` for the distribution
            func = distr._default_new_state_part

            # simplest way of assigning sampling methods:
            # if `sampler_methods` was passed and contains this var,
            # the var is assigned to the given sampler, after checking
            # that the sampler supports the distribution

            # 1. If a sampler is provided by the user, we create a new sampler
            #    and add it to `make_kernel_fn`
            # 2. If the distribution has a `new_state_fn`, a new sampler must
            #    also be created, because the sampler is initialized with
            #    the `new_state_fn` argument.
            if unscoped_tr_var in converted_sampler_methods:
                sampler, kwargs = converted_sampler_methods[unscoped_tr_var]

                # check that the sampler is able to sample from the distribution
                if not distr._grad_support and sampler._grad:
                    raise ValueError(
                        "The distribution of `{}` does not support gradients; "
                        "please provide an appropriate sampler method".format(
                            unscoped_tr_var))

                # add the sampler to the list
                make_kernel_fn.append(sampler)
                # start from a copy of the user-provided kwargs
                part_kernel_kwargs.append(dict(kwargs))
                # if a proposal function is provided, it replaces the default
                func = part_kernel_kwargs[-1].get("new_state_fn", func)
                # add the default `new_state_fn` for the distribution;
                # `new_state_fn` is supported only by the RandomWalkMetropolis
                # transition kernel.
                if func and sampler._name == "rwm":
                    part_kernel_kwargs[-1]["new_state_fn"] = func()
            elif callable(func):
                # if the distribution defines a custom `new_state_fn`, we must
                # assign the `RandomWalkMetropolis` transition kernel
                make_kernel_fn.append(RandomWalkM)
                part_kernel_kwargs.append({})
                part_kernel_kwargs[-1]["new_state_fn"] = func()
            else:
                # if the user didn't provide a sampler, we choose NUTS for
                # variables with gradient support and RWM for variables
                # without it
                sampler = NUTS if distr._grad_support else RandomWalkM
                make_kernel_fn.append(sampler)
                part_kernel_kwargs.append({})
                # _log.info("Auto-assigning NUTS sampler...")
            # save proposal func names
            func_names.append(func._name if func else "default")

        # `make_kernel_fn` contains len(state) sampler methods, which could add
        # overhead when iterating at each call of `one_step` in the compound
        # step kernel, so we merge some of the samplers.
        kernels, set_lengths = self._merge_samplers(make_kernel_fn,
                                                    part_kernel_kwargs)
        # log variable sampler mapping
        CompoundStep._log_variables(init_keys, kernels, set_lengths,
                                    self.parent_inds, func_names)
        # save for later use in the compound kernel init
        self.kernel_kwargs["compound_samplers"] = kernels
        self.kernel_kwargs["compound_set_lengths"] = set_lengths