def build_logp_and_deterministic_functions(
    model, observed: Optional[dict] = None, state: Optional[flow.SamplingState] = None
):
    """Build the log-probability and deterministics functions for ``model``.

    Parameters
    ----------
    model : pymc4.Model
        The model to build functions for.
    observed : Optional[dict]
        Observed values to use for the model's observed variables. Mutually
        exclusive with ``state``. This dict is never modified.
    state : Optional[flow.SamplingState]
        A sampling state to initialize from. Mutually exclusive with ``observed``.

    Returns
    -------
    tuple
        ``(logpfn, init, deterministics_callback, deterministic_names, state)``:
        ``logpfn`` returns ``collect_log_prob()`` of the evaluated model;
        ``init`` is a dict copy of the state's unobserved values;
        ``deterministics_callback`` returns the values of the model's
        deterministics (plus untransformed counterparts of transformed variables);
        ``deterministic_names`` and ``state`` come from ``initialize_sampling_state``.
        Both callables accept the unobserved values either positionally (in the
        order of ``state.all_unobserved_values``) or by name, but not both.

    Raises
    ------
    TypeError
        If ``model`` is not a ``Model``, or (at call time) if a callable receives
        both positional and keyword values.
    ValueError
        If both ``state`` and ``observed`` are given, or if the model has no
        unobserved values.
    """
    if not isinstance(model, Model):
        raise TypeError(
            "`sample` function only supports `pymc4.Model` objects, but you've passed `{}`".format(
                type(model)
            )
        )
    if state is not None and observed is not None:
        raise ValueError("Can't use both `state` and `observed` arguments")

    state, deterministic_names = initialize_sampling_state(model, observed=observed, state=state)
    if not state.all_unobserved_values:
        raise ValueError(
            f"Can not calculate a log probability: the model {model.name or ''} has no unobserved values."
        )
    unobserved_keys = tuple(state.all_unobserved_values)

    # One private observed mapping shared by both callables, so the log-probability
    # and the deterministics are always evaluated against the same data. Previously
    # ``logpfn`` received only the raw ``observed`` argument, so observed values
    # carried by ``state`` were not passed to it. The explicit ``observed`` is
    # layered on top in case ``initialize_sampling_state`` did not already fold it in
    # (NOTE(review): confirm whether it does; the update is a no-op if so).
    resolved_observed = dict(state.observed_values)
    if observed is not None:
        resolved_observed.update(observed)

    def _named_values(values, kwargs):
        """Map positional ``values`` onto ``unobserved_keys``, or pass ``kwargs`` through."""
        if kwargs and values:
            raise TypeError("Either list state should be passed or a dict one")
        if values:
            return dict(zip(unobserved_keys, values))
        return kwargs

    @tf.function(autograph=False)
    def logpfn(*values, **kwargs):
        st = flow.SamplingState.from_values(
            _named_values(values, kwargs), observed_values=resolved_observed
        )
        _, st = flow.evaluate_model_transformed(model, state=st)
        return st.collect_log_prob()

    @tf.function(autograph=False)
    def deterministics_callback(*values, **kwargs):
        st = flow.SamplingState.from_values(
            _named_values(values, kwargs), observed_values=resolved_observed
        )
        _, st = flow.evaluate_model_transformed(model, state=st)
        # For every transformed variable, move its untransformed counterpart out of
        # ``untransformed_values`` and into ``deterministics`` so it is returned too.
        for transformed_name in st.transformed_values:
            untransformed_name = NameParts.from_name(transformed_name).full_untransformed_name
            st.deterministics[untransformed_name] = st.untransformed_values.pop(untransformed_name)
        return st.deterministics.values()

    return (
        logpfn,
        dict(state.all_unobserved_values),
        deterministics_callback,
        deterministic_names,
        state,
    )
def __init__(self, model: Optional[Model] = None, random_seed: Optional[int] = None):
    """Validate ``model`` and precompute the pieces the approximation relies on.

    Parameters
    ----------
    model : pymc4.Model
        The model to work with. Although annotated ``Optional`` with a ``None``
        default, anything that is not a ``Model`` instance (``None`` included)
        is rejected with ``TypeError``.
    random_seed : Optional[int]
        Stored unchanged as ``self._seed``.

    Sets ``model``, ``_seed``, ``state``, ``deterministic_names`` and
    ``unobserved_keys``. It then stores the results of ``self._build_logfn()``
    and ``self._build_posterior()`` as ``target_log_prob`` and ``approx``.

    Raises
    ------
    TypeError
        If ``model`` is not a ``Model``.
    ValueError
        If the initialized state has no unobserved values.
    """
    if not isinstance(model, Model):
        template = "`fit` function only supports `pymc4.Model` objects, but you've passed `{}`"
        raise TypeError(template.format(type(model)))

    self.model = model
    self._seed = random_seed

    initial_state, names = initialize_sampling_state(model)
    self.state = initial_state
    self.deterministic_names = names

    if not self.state.all_unobserved_values:
        raise ValueError(
            f"Can not calculate a log probability: the model {model.name or ''} has no unobserved values."
        )
    # A keys view over the state's unobserved (free) values.
    self.unobserved_keys = self.state.all_unobserved_values.keys()

    # The builders run last because they may read the attributes assigned above.
    self.target_log_prob = self._build_logfn()
    self.approx = self._build_posterior()
def build_logp_and_deterministic_functions(
    model,
    num_chains: Optional[int] = None,
    observed: Optional[dict] = None,
    state: Optional[flow.SamplingState] = None,
    collect_reduced_log_prob: bool = True,
):
    """Build the log-probability and deterministics functions for ``model``.

    Parameters
    ----------
    model : pymc4.Model
        The model to build functions for.
    num_chains : Optional[int]
        Used only when ``collect_reduced_log_prob`` is False. In that case the
        observed values given to ``logpfn`` are tiled along a new leading axis of
        size ``num_chains``.
    observed : Optional[dict]
        Observed values to use for the model's observed variables. Mutually
        exclusive with ``state``. This dict is never modified.
    state : Optional[flow.SamplingState]
        A sampling state to initialize from. Mutually exclusive with ``observed``.
        Its observed values are never modified.
    collect_reduced_log_prob : bool
        If True, ``logpfn`` returns ``collect_log_prob()``; otherwise it returns
        ``collect_unreduced_log_prob()`` (manual batching).

    Returns
    -------
    tuple
        ``(logpfn, init, deterministics_callback, deterministic_names, state)``:
        ``init`` is a dict copy of the state's unobserved values;
        ``deterministic_names`` and ``state`` come from ``initialize_sampling_state``.
        Both callables accept the unobserved values either positionally (in the
        order of ``state.all_unobserved_values``) or by name, but not both.

    Raises
    ------
    TypeError
        If ``model`` is not a ``Model``, or (at call time) if a callable receives
        both positional and keyword values.
    ValueError
        If both ``state`` and ``observed`` are given, or if the model has no
        unobserved values.
    """
    if not isinstance(model, Model):
        raise TypeError(
            "`sample` function only supports `pymc4.Model` objects, but you've passed `{}`".format(
                type(model)
            )
        )
    if state is not None and observed is not None:
        raise ValueError("Can't use both `state` and `observed` arguments")

    state, deterministic_names = initialize_sampling_state(model, observed=observed, state=state)
    if not state.all_unobserved_values:
        raise ValueError(
            f"Can not calculate a log probability: the model {model.name or ''} has no unobserved values."
        )
    unobserved_keys = tuple(state.all_unobserved_values)

    # Private merged copy: the state's observed values with the explicit
    # ``observed`` layered on top. Earlier code tiled observed values in place,
    # mutating the caller's ``observed`` dict (or ``state.observed_values``), so
    # repeated calls with the same dict tiled it again. The data handed to the
    # deterministics also depended on which of the two had been passed.
    resolved_observed = dict(state.observed_values)
    if observed is not None:
        resolved_observed.update(observed)

    if not collect_reduced_log_prob and num_chains is not None:
        # With manual batching, the chains axis must be tiled to the left of the
        # observed tensors. Build a fresh dict; never write back into the inputs.
        logp_observed = {}
        for name, value in resolved_observed.items():
            value = tf.convert_to_tensor(value)
            logp_observed[name] = tf.tile(value[None, ...], [num_chains] + [1] * value.ndim)
    else:
        logp_observed = resolved_observed

    def _named_values(values, kwargs):
        """Map positional ``values`` onto ``unobserved_keys``, or pass ``kwargs`` through."""
        if kwargs and values:
            raise TypeError("Either list state should be passed or a dict one")
        if values:
            return dict(zip(unobserved_keys, values))
        return kwargs

    @tf.function(autograph=False)
    def logpfn(*values, **kwargs):
        st = flow.SamplingState.from_values(
            _named_values(values, kwargs), observed_values=logp_observed
        )
        _, st = flow.evaluate_model_transformed(model, state=st)
        # ``collect_reduced_log_prob`` is a Python bool, resolved at trace time.
        if collect_reduced_log_prob:
            return st.collect_log_prob()
        return st.collect_unreduced_log_prob()

    @tf.function(autograph=False)
    def deterministics_callback(*values, **kwargs):
        st = flow.SamplingState.from_values(
            _named_values(values, kwargs), observed_values=resolved_observed
        )
        _, st = flow.evaluate_model_transformed(model, state=st)
        # For every transformed variable, move its untransformed counterpart out of
        # ``untransformed_values`` and into ``deterministics`` so it is returned too.
        for transformed_name in st.transformed_values:
            untransformed_name = NameParts.from_name(transformed_name).full_untransformed_name
            st.deterministics[untransformed_name] = st.untransformed_values.pop(untransformed_name)
        return st.deterministics.values()

    return (
        logpfn,
        dict(state.all_unobserved_values),
        deterministics_callback,
        deterministic_names,
        state,
    )