def initialize_sampling_state(
    model: Model,
    observed: Optional[dict] = None,
    state: Optional[flow.SamplingState] = None,
) -> Tuple[flow.SamplingState, List[str]]:
    """
    Initialize the model from the provided state and/or observed variables.

    Parameters
    ----------
    model : pymc4.Model
    observed : Optional[dict]
        Forwarded to ``flow.evaluate_model_transformed``.
    state : Optional[flow.SamplingState]
        Forwarded to ``flow.evaluate_model_transformed``.

    Returns
    -------
    state: pymc4.flow.SamplingState
        The evaluated state after conversion with ``as_sampling_state``.
    deterministic_names: List[str]
        Names of the model's deterministics, followed by the names returned
        as the second element of ``as_sampling_state()``.
    """
    # Run the model once so that its deterministics get populated.
    _, evaluated = flow.evaluate_model_transformed(model, observed=observed, state=state)
    # Capture the names before converting, in case the conversion alters the state.
    names_of_deterministics = list(evaluated.deterministics)
    converted, extra_names = evaluated.as_sampling_state()
    return converted, names_of_deterministics + extra_names
def deterministics_callback(*values, **kwargs):
    """Evaluate the model at a point and return its deterministic values.

    The point is passed either positionally (ordered like ``unobserved_keys``)
    or as keyword arguments, never both. ``unobserved_keys``, ``observed`` and
    ``model`` are free names resolved from the enclosing scope.
    """
    if values:
        if kwargs:
            raise TypeError("Either list state should be passed or a dict one")
        # Map positional arguments onto variable names.
        kwargs = dict(zip(unobserved_keys, values))
    point_state = flow.SamplingState.from_values(kwargs, observed_values=observed)
    _, point_state = flow.evaluate_model_transformed(model, state=point_state)
    return point_state.deterministics.values()
def logpfn(*values, **kwargs):
    """Return the model's collected log-probability at a point.

    The point is passed either positionally (ordered like ``unobserved_keys``)
    or as keyword arguments, never both. ``unobserved_keys``, ``observed`` and
    ``model`` are free names resolved from the enclosing scope.
    """
    if values:
        if kwargs:
            raise TypeError("Either list state should be passed or a dict one")
        # Map positional arguments onto variable names.
        kwargs = dict(zip(unobserved_keys, values))
    point_state = flow.SamplingState.from_values(kwargs, observed_values=observed)
    _, evaluated = flow.evaluate_model_transformed(model, state=point_state)
    return evaluated.collect_log_prob()
def extract_log_likelihood(values, observed_rv):
    """Return the log-probability of ``observed_rv`` evaluated at its observed data.

    ``values`` seeds a new ``SamplingState``. Its observed values come from the
    free name ``sampling_state``, and ``model`` is likewise resolved from the
    enclosing scope.
    """
    seeded = flow.SamplingState.from_values(
        values,
        observed_values=sampling_state.observed_values,
    )
    _, evaluated = flow.evaluate_model_transformed(model, state=seeded)
    # Look the variable up among continuous distributions first; fall back to discrete.
    try:
        rv_dist = evaluated.continuous_distributions[observed_rv]
    except KeyError:
        rv_dist = evaluated.discrete_distributions[observed_rv]
    return rv_dist.log_prob(rv_dist.model_info["observed"])
def deterministics_callback(q_samples):
    """Evaluate ``self.model`` at ``q_samples`` and return its deterministics dict.

    For every transformed variable, the matching untransformed value is moved
    out of ``untransformed_values`` and stored in ``deterministics`` under its
    untransformed name. ``self`` and ``NameParts`` are resolved from the
    enclosing scope.
    """
    seeded = flow.SamplingState.from_values(
        q_samples,
        observed_values=self.state.observed_values,
    )
    _, evaluated = flow.evaluate_model_transformed(self.model, state=seeded)
    for t_name in evaluated.transformed_values:
        base_name = NameParts.from_name(t_name).full_untransformed_name
        moved_value = evaluated.untransformed_values.pop(base_name)
        evaluated.deterministics[base_name] = moved_value
    return evaluated.deterministics
def initialize_state(
    model: Model,
    observed: Optional[dict] = None,
    state: Optional[flow.SamplingState] = None,
) -> Tuple[
    flow.SamplingState,
    flow.SamplingState,
    List[str],
    List[str],
    Dict[str, distribution.Distribution],
    Dict[str, distribution.Distribution],
]:
    """
    Evaluate ``model`` and split its free variables into discrete/continuous.

    Parameters
    ----------
    model : pymc4.Model
    observed : Optional[dict]
        Observed values forwarded to ``flow.evaluate_model_transformed``.
    state : Optional[flow.SamplingState]
        Initial state forwarded to ``flow.evaluate_model_transformed``.

    Returns
    -------
    state : flow.SamplingState
        The state produced by evaluating the model.
    sampling_state : flow.SamplingState
        ``state`` converted with ``as_sampling_state``.
    free_discrete_names : List[str]
        Names in ``state.discrete_distributions`` that are not observed.
    free_continuous_names : List[str]
        Names in ``state.continuous_distributions`` that are not observed.
    cont_distrs : Dict[str, distribution.Distribution]
        ``state.continuous_distributions`` (includes observed variables).
    disc_distrs : Dict[str, distribution.Distribution]
        ``state.discrete_distributions`` (includes observed variables).
    """
    # Forward ``observed``/``state``: they used to be accepted but silently ignored.
    _, state = flow.evaluate_model_transformed(model, observed=observed, state=state)

    # Set gives O(1) membership tests when filtering out observed variables.
    observed_rvs = set(state.observed_values)
    free_discrete_names = [
        name for name in state.discrete_distributions if name not in observed_rvs
    ]
    free_continuous_names = [
        name for name in state.continuous_distributions if name not in observed_rvs
    ]

    cont_distrs = state.continuous_distributions
    disc_distrs = state.discrete_distributions
    # ``as_sampling_state`` returns a pair; only the converted state is needed here.
    sampling_state, _ = state.as_sampling_state()
    return (
        state,
        sampling_state,
        free_discrete_names,
        free_continuous_names,
        cont_distrs,
        disc_distrs,
    )
def deterministics_callback(*values, **kwargs):
    """Evaluate the model at a point and return its deterministic values.

    The point is passed either positionally (ordered like ``unobserved_keys``)
    or as keyword arguments, never both. For every transformed variable, the
    matching untransformed value is moved into ``deterministics`` under its
    untransformed name. ``unobserved_keys``, ``observed_var``, ``model`` and
    ``NameParts`` are resolved from the enclosing scope.
    """
    if values:
        if kwargs:
            raise TypeError("Either list state should be passed or a dict one")
        # Map positional arguments onto variable names.
        kwargs = dict(zip(unobserved_keys, values))
    point_state = flow.SamplingState.from_values(kwargs, observed_values=observed_var)
    _, point_state = flow.evaluate_model_transformed(model, state=point_state)
    for t_name in point_state.transformed_values:
        base_name = NameParts.from_name(t_name).full_untransformed_name
        moved_value = point_state.untransformed_values.pop(base_name)
        point_state.deterministics[base_name] = moved_value
    return point_state.deterministics.values()
def build_logp_and_deterministic_functions(
    model, observed: Optional[dict] = None, state: Optional[flow.SamplingState] = None
):
    """Build the log-probability and deterministics functions for ``model``.

    Parameters
    ----------
    model : pymc4.Model
    observed : Optional[dict]
        Observed values. Mutually exclusive with ``state``.
    state : Optional[flow.SamplingState]
        Initial state. Mutually exclusive with ``observed``.

    Returns
    -------
    logpfn : callable
        ``tf.function`` returning ``collect_log_prob()`` of the model evaluated
        at a point given positionally or by keyword (not both).
    init : dict
        Copy of the sampling state's ``all_unobserved_values``.
    deterministics_callback : callable
        ``tf.function`` returning the model's deterministic values at a point.
    deterministic_names : List[str]
        Names of the model's deterministics (same order as the evaluated
        state's ``deterministics``).

    Raises
    ------
    TypeError
        If ``model`` is not a ``pymc4.Model``.
    ValueError
        If both ``state`` and ``observed`` are given, or if the model has no
        unobserved values.
    """
    if not isinstance(model, Model):
        raise TypeError(
            "`sample` function only supports `pymc4.Model` objects, but you've passed `{}`".format(
                type(model)
            )
        )
    if state is not None and observed is not None:
        raise ValueError("Can't use both `state` and `observed` arguments")

    # Evaluate once to populate the deterministics, then convert the *evaluated*
    # state. ``as_sampling_state`` returns a ``(state, names)`` pair, so unpack it.
    # The previous code unpacked ``initialize_state``'s 6-tuple into two names and
    # stored the pair itself as ``state``; both branches crashed.
    _, evaluated = flow.evaluate_model_transformed(model, observed=observed, state=state)
    deterministic_names = list(evaluated.deterministics)
    state, _ = evaluated.as_sampling_state()

    observed = state.observed_values
    unobserved = state.all_unobserved_values
    if not unobserved:
        raise ValueError(
            "Can not calculate a log probability: the model has no unobserved values."
        )
    # Order used to map the closures' positional arguments onto variable names.
    unobserved_keys = tuple(unobserved)

    def _evaluate_at(values, kwargs):
        # Shared by both closures: build a state from a positional or keyword
        # point and run the model on it.
        if kwargs and values:
            raise TypeError("Either list state should be passed or a dict one")
        elif values:
            kwargs = dict(zip(unobserved_keys, values))
        st = flow.SamplingState.from_values(kwargs, observed_values=observed)
        _, st = flow.evaluate_model_transformed(model, state=st)
        return st

    @tf.function(autograph=False)
    def logpfn(*values, **kwargs):
        return _evaluate_at(values, kwargs).collect_log_prob()

    @tf.function(autograph=False)
    def deterministics_callback(*values, **kwargs):
        return _evaluate_at(values, kwargs).deterministics.values()

    return logpfn, dict(unobserved), deterministics_callback, deterministic_names