def check_proposal_functions(
    model: Model,
    state: Optional[flow.SamplingState] = None,
    observed: Optional[dict] = None,
) -> bool:
    """
    Report whether any free variable carries a non-default proposal function.

    The model state is initialised via ``initialize_state`` and every key of
    ``state.all_unobserved_values`` is mapped back to its distribution. The
    check succeeds as soon as one distribution exposes a callable
    ``_default_new_state_part``.

    Parameters
    ----------
    model : pymc4.Model
        Model to sample posterior for
    state : Optional[flow.SamplingState]
        Current state
    observed : Optional[Dict[str, Any]]
        Observed values (optional)

    Returns
    -------
    bool
        ``True`` if at least one distribution has a callable
        ``_default_new_state_part``, ``False`` otherwise.
    """
    (_, state, _, _, continuous_distrs, discrete_distrs) = initialize_state(
        model, observed=observed, state=state
    )
    var_keys = list(state.all_unobserved_values.keys())

    def _proposal_part(var_key):
        # Map the (possibly transformed) variable key back to the name under
        # which its distribution is registered.
        untrs_var, _unscoped = scope_remove_transformed_part_if_required(
            var_key, state.transformed_values
        )
        # Continuous distributions are consulted first; anything not found
        # there is looked up among the discrete ones.
        distr = continuous_distrs.get(untrs_var, None)
        if distr is None:
            distr = discrete_distrs[untrs_var]
        return distr._default_new_state_part

    # `any` stops at the first variable with a callable proposal function.
    return any(callable(_proposal_part(var_key)) for var_key in var_keys)
def _sample(
    self,
    *,
    num_samples: int = 1000,
    num_chains: int = 10,
    burn_in: int = 100,
    observed: Optional[dict] = None,
    state: Optional[flow.SamplingState] = None,
    use_auto_batching: bool = True,
    xla: bool = False,
    seed: Optional[int] = None,
    is_compound: bool = False,
    trace_discrete: Optional[List[str]] = None,
):
    """
    Run the chains and package the draws with ``trace_to_arviz``.

    Parameters
    ----------
    num_samples : int
        Stored on ``self._num_samples`` before the chains are run.
    num_chains : int
        Forwarded to ``build_logp_and_deterministic_functions`` and used to
        tile the initial state.
    burn_in : int
        Passed through to ``self._run_chains``.
    observed : Optional[dict]
        Observed values; mutually exclusive with ``state``.
    state : Optional[flow.SamplingState]
        Starting state; mutually exclusive with ``observed``.
    use_auto_batching : bool
        When true, the log-probability function and the deterministics
        callback are wrapped with ``vectorize_logp_function``.
    xla : bool
        When true, ``self._run_chains`` is invoked through
        ``tf.xla.experimental.compile``.
    seed : Optional[int]
        Stored on ``self.seed``.
    is_compound : bool
        When true, only the state parts selected by ``self.parent_inds`` are
        sampled and no sampler statistics are handed to ``trace_to_arviz``.
    trace_discrete : Optional[List[str]]
        Names of variables whose draws are cast to ``tf.int32`` after
        sampling.

    Returns
    -------
    The value returned by ``trace_to_arviz``.

    Raises
    ------
    ValueError
        If both ``state`` and ``observed`` are supplied.
    """
    if state is not None and observed is not None:
        raise ValueError("Can't use both `state` and `observed` arguments")

    (
        logpfn,
        init,
        _deterministics_callback,
        deterministic_names,
        state_,
    ) = build_logp_and_deterministic_functions(
        self.model,
        num_chains=num_chains,
        state=state,
        observed=observed,
        collect_reduced_log_prob=use_auto_batching,
        parent_inds=self.parent_inds if is_compound else None,
    )

    init_keys = list(init.keys())
    init_state = list(init.values())
    if is_compound:
        # Restrict both the values and their names to the selected parts.
        init_state = [init_state[idx] for idx in self.parent_inds]
        init_keys = [init_keys[idx] for idx in self.parent_inds]

    if use_auto_batching:
        self.parallel_logpfn = vectorize_logp_function(logpfn)
        self.deterministics_callback = vectorize_logp_function(_deterministics_callback)
    else:
        self.parallel_logpfn = logpfn
        self.deterministics_callback = _deterministics_callback
    # The initial state is tiled across chains in both modes.
    init_state = tile_init(init_state, num_chains)

    # TODO: problem with tf.function when passing as argument to self._run_chains
    self._num_samples = num_samples
    self.seed = seed

    if xla:
        results, sample_stats = tf.xla.experimental.compile(
            self._run_chains, inputs=[init_state, burn_in],
        )
    else:
        results, sample_stats = self._run_chains(init_state, burn_in)

    posterior = dict(zip(init_keys, results))

    if trace_discrete:
        # TODO: maybe better logic can be written here
        # Workaround: the variables listed in `trace_discrete` are cast to
        # tf.int32 only after sampling has finished. The names are matched
        # against the second element returned by
        # `scope_remove_transformed_part_if_required` for each sampled key.
        unscoped_names = [
            scope_remove_transformed_part_if_required(key, state_.transformed_values)[1]
            for key in init_keys
        ]
        keys_to_cast = [init_keys[unscoped_names.index(name)] for name in trace_discrete]
        for key in keys_to_cast:
            posterior[key] = tf.cast(posterior[key], dtype=tf.int32)

    # Keep in sync with pymc3 naming convention
    n_stats = len(self.stat_names)
    if len(sample_stats) > n_stats:
        # Entries beyond the named statistics are the deterministic values.
        deterministic_values = sample_stats[n_stats:]
        sample_stats = sample_stats[:n_stats]
    sampler_stats = dict(zip(self.stat_names, sample_stats))
    if len(deterministic_names) > 0:
        posterior.update(zip(deterministic_names, deterministic_values))

    return trace_to_arviz(
        posterior,
        sampler_stats if is_compound is False else None,
        observed_data=state_.observed_values,
    )
def _assign_default_methods(
    self,
    *,
    sampler_methods: Optional[List] = None,
    state: Optional[flow.SamplingState] = None,
    observed: Optional[dict] = None,
):
    """
    Pick a sampler for every free variable and store the merged result.

    For each key of ``state.all_unobserved_values`` one sampler and one
    kwargs dict are chosen, in this order of precedence:

    1. A sampler the user supplied for the variable through
       ``sampler_methods`` (rejected if it needs gradients and the
       distribution does not support them).
    2. ``RandomWalkM`` when the distribution defines a callable
       ``_default_new_state_part``.
    3. ``NUTS`` for distributions with gradient support, ``RandomWalkM``
       otherwise.

    The per-variable choices are merged with ``self._merge_samplers``, logged
    through ``CompoundStep._log_variables`` and saved in ``self.kernel_kwargs``
    under ``"compound_samplers"`` and ``"compound_set_lengths"``.

    Parameters
    ----------
    sampler_methods : Optional[List]
        User-provided sampler assignments, converted with
        ``CompoundStep._convert_sampler_methods``.
    state : Optional[flow.SamplingState]
        Current state
    observed : Optional[dict]
        Observed values (optional)

    Raises
    ------
    ValueError
        If a user-supplied sampler requires gradients for a distribution that
        does not support them.
    """
    # Looked up by variable name below and unpacked as `(sampler, kwargs)`.
    user_samplers = CompoundStep._convert_sampler_methods(sampler_methods)

    (_, state, _, _, continuous_distrs, discrete_distrs) = initialize_state(
        self.model, observed=observed, state=state
    )
    init_keys = list(state.all_unobserved_values.keys())

    # sampler chosen for each free variable
    make_kernel_fn: list = []
    # kwargs for the matching entry of `make_kernel_fn`
    part_kernel_kwargs: list = []
    # proposal function name per variable, used only for logging
    func_names: list = []

    for var_key in init_keys:
        untrs_var, unscoped_tr_var = scope_remove_transformed_part_if_required(
            var_key, state.transformed_values
        )

        # Resolve the distribution behind the variable name: continuous
        # first, discrete as the fallback.
        distr = continuous_distrs.get(untrs_var, None)
        if distr is None:
            distr = discrete_distrs[untrs_var]

        # Distribution-specific `new_state_fn`, if it defines one.
        func = distr._default_new_state_part

        if unscoped_tr_var in user_samplers:
            sampler, kwargs = user_samplers[unscoped_tr_var]
            # The user's sampler must be able to sample from the distribution.
            if not distr._grad_support and sampler._grad:
                raise ValueError(
                    "The `{}` doesn't support gradient, please provide an appropriate sampler method".format(
                        unscoped_tr_var
                    )
                )
            chosen_kwargs: dict = {}
            chosen_kwargs.update(kwargs)
            # A user-provided proposal function overrides the default one.
            func = chosen_kwargs.get("new_state_fn", func)
            # `new_state_fn` is only supported by the RandomWalkMetropolis
            # transition kernel; the stored value is the result of calling
            # the proposal factory with no arguments.
            if func and sampler._name == "rwm":
                chosen_kwargs["new_state_fn"] = partial(func)()
        elif callable(func):
            # A distribution with its own `new_state_fn` is always sampled
            # with the RandomWalkMetropolis transition kernel.
            sampler = RandomWalkM
            chosen_kwargs = {"new_state_fn": partial(func)()}
        else:
            # Nothing was requested: NUTS when gradients are available,
            # RWM otherwise.
            sampler = NUTS if distr._grad_support else RandomWalkM
            chosen_kwargs = {}

        make_kernel_fn.append(sampler)
        part_kernel_kwargs.append(chosen_kwargs)
        func_names.append(func._name if func else "default")

    # `make_kernel_fn` holds one sampler per state part; iterating over all of
    # them on every `one_step` call of the compound kernel adds overhead, so
    # samplers are merged where possible.
    kernels, set_lengths = self._merge_samplers(make_kernel_fn, part_kernel_kwargs)

    # log the variable -> sampler mapping
    CompoundStep._log_variables(init_keys, kernels, set_lengths, self.parent_inds, func_names)

    # saved for the later initialisation of the compound kernel
    self.kernel_kwargs["compound_samplers"] = kernels
    self.kernel_kwargs["compound_set_lengths"] = set_lengths