Example 1 (score: 0)
def initialize_sampling_state(
    model: Model,
    observed: Optional[dict] = None,
    state: Optional[flow.SamplingState] = None
) -> Tuple[flow.SamplingState, List[str]]:
    """
    Initialize the model's sampling state from a given state and/or observations.

    The model is evaluated once with ``observed`` and ``state``; the evaluated
    state is then converted with ``as_sampling_state``.

    Parameters
    ----------
    model : pymc4.Model
        The model to evaluate.
    observed : Optional[dict]
        Observed values forwarded to the model evaluation.
    state : Optional[flow.SamplingState]
        Starting state forwarded to the model evaluation.

    Returns
    -------
    state: pymc4.flow.SamplingState
        The model's sampling state
    deterministic_names: List[str]
        The names of the model's deterministics, followed by the
        ``transformed_names`` returned by ``as_sampling_state``
    """
    _, evaluated_state = flow.evaluate_model_transformed(
        model, observed=observed, state=state
    )
    # Capture the deterministic names before converting the evaluated state.
    names_of_deterministics = list(evaluated_state.deterministics)
    sampling_state, transformed_names = evaluated_state.as_sampling_state()
    return sampling_state, names_of_deterministics + transformed_names
Example 2 (score: 0)
 def deterministics_callback(*values, **kwargs):
     """Evaluate the model at a point and return its deterministic values.

     The point is supplied either positionally (paired, in order, with
     ``unobserved_keys``) or as keyword arguments -- never both.
     """
     if values:
         if kwargs:
             raise TypeError("Either list state should be passed or a dict one")
         kwargs = dict(zip(unobserved_keys, values))
     point_state = flow.SamplingState.from_values(kwargs, observed_values=observed)
     _, point_state = flow.evaluate_model_transformed(model, state=point_state)
     return point_state.deterministics.values()
Example 3 (score: 0)
 def logpfn(*values, **kwargs):
     """Evaluate the model at a point and return ``collect_log_prob()``.

     The point is supplied either positionally (paired, in order, with
     ``unobserved_keys``) or as keyword arguments -- never both.
     """
     if values:
         if kwargs:
             raise TypeError("Either list state should be passed or a dict one")
         kwargs = dict(zip(unobserved_keys, values))
     point_state = flow.SamplingState.from_values(kwargs, observed_values=observed)
     _, point_state = flow.evaluate_model_transformed(model, state=point_state)
     return point_state.collect_log_prob()
Example 4 (score: 0)
 def extract_log_likelihood(values, observed_rv):
     """Return the log probability of ``observed_rv``'s observed data.

     The model is evaluated at ``values`` (with the sampling state's observed
     values), then the distribution named ``observed_rv`` is looked up among
     the continuous distributions first and the discrete ones second.
     """
     evaluated = flow.SamplingState.from_values(
         values, observed_values=sampling_state.observed_values
     )
     _, evaluated = flow.evaluate_model_transformed(model, state=evaluated)
     try:
         rv_dist = evaluated.continuous_distributions[observed_rv]
     except KeyError:
         # Not a continuous variable: fall back to the discrete ones.
         rv_dist = evaluated.discrete_distributions[observed_rv]
     return rv_dist.log_prob(rv_dist.model_info["observed"])
Example 5 (score: 0)
 def deterministics_callback(q_samples):
     """Evaluate the model at ``q_samples`` and return its deterministics.

     For every transformed variable, the matching untransformed value is
     popped from ``untransformed_values`` and stored in ``deterministics``
     under the untransformed name.
     """
     evaluated = flow.SamplingState.from_values(
         q_samples, observed_values=self.state.observed_values
     )
     _, evaluated = flow.evaluate_model_transformed(self.model, state=evaluated)
     for name in evaluated.transformed_values:
         base_name = NameParts.from_name(name).full_untransformed_name
         base_value = evaluated.untransformed_values.pop(base_name)
         evaluated.deterministics[base_name] = base_value
     return evaluated.deterministics
Example 6 (score: 0)
def initialize_state(
    model: Model,
    observed: Optional[dict] = None,
    state: Optional[flow.SamplingState] = None,
) -> Tuple[
    flow.SamplingState,
    flow.SamplingState,
    List[str],
    List[str],
    Dict[str, distribution.Distribution],
    Dict[str, distribution.Distribution],
]:
    """
    Evaluate the model and split its free variables into discrete/continuous.

    Parameters
    ----------
    model : pymc4.Model
        The model to evaluate.
    observed : Optional[dict]
        Observed values forwarded to the model evaluation.
    state : Optional[flow.SamplingState]
        Starting state forwarded to the model evaluation.

    Returns
    -------
    state: flow.SamplingState
        The state obtained by evaluating the model (before conversion
        with ``as_sampling_state``)
    sampling_state: flow.SamplingState
        The model's sampling state
    free_discrete_names: List[str]
        Names of discrete distributions that are not observed
    free_continuous_names: List[str]
        Names of continuous distributions that are not observed
    cont_distrs: Dict[str, distribution.Distribution]
        All continuous distributions, keyed by name
    disc_distrs: Dict[str, distribution.Distribution]
        All discrete distributions, keyed by name
    """
    # Forward ``observed``/``state``: they used to be accepted but silently
    # ignored, so caller-supplied observations never reached the model.
    _, state = flow.evaluate_model_transformed(model, observed=observed, state=state)
    cont_distrs = state.continuous_distributions
    disc_distrs = state.discrete_distributions
    # A set gives O(1) membership tests while filtering out observed variables.
    observed_rvs = set(state.observed_values)
    free_discrete_names = [name for name in disc_distrs if name not in observed_rvs]
    free_continuous_names = [name for name in cont_distrs if name not in observed_rvs]
    sampling_state, _ = state.as_sampling_state()
    return (
        state,
        sampling_state,
        free_discrete_names,
        free_continuous_names,
        cont_distrs,
        disc_distrs,
    )
Example 7 (score: 0)
 def deterministics_callback(*values, **kwargs):
     """Evaluate the model at a point and return its deterministic values.

     The point is supplied either positionally (paired, in order, with
     ``unobserved_keys``) or as keyword arguments -- never both. For every
     transformed variable, the matching untransformed value is popped from
     ``untransformed_values`` and stored in ``deterministics``.
     """
     if values:
         if kwargs:
             raise TypeError("Either list state should be passed or a dict one")
         kwargs = dict(zip(unobserved_keys, values))
     point = flow.SamplingState.from_values(kwargs, observed_values=observed_var)
     _, point = flow.evaluate_model_transformed(model, state=point)
     for name in point.transformed_values:
         base_name = NameParts.from_name(name).full_untransformed_name
         base_value = point.untransformed_values.pop(base_name)
         point.deterministics[base_name] = base_value
     return point.deterministics.values()
Example 8 (score: 0)
def build_logp_and_deterministic_functions(
        model,
        observed: Optional[dict] = None,
        state: Optional[flow.SamplingState] = None):
    """
    Build the log-probability and deterministics callables for ``model``.

    Parameters
    ----------
    model : pymc4.Model
        The model to build functions for.
    observed : Optional[dict]
        Observed values; mutually exclusive with ``state``.
    state : Optional[flow.SamplingState]
        A pre-built state; mutually exclusive with ``observed``.

    Returns
    -------
    logpfn : callable
        ``tf.function`` returning ``collect_log_prob()`` of the model
        evaluated at a point.
    init : dict
        The sampling state's ``all_unobserved_values``.
    deterministics_callback : callable
        ``tf.function`` returning the model's deterministic values at a point.
    deterministic_names : List[str]
        Names of the model's deterministics.

    Both callables accept the point either positionally (matched, in order,
    to the keys of ``init``) or as keyword arguments, but not both.

    Raises
    ------
    TypeError
        If ``model`` is not a ``pymc4.Model``.
    ValueError
        If both ``state`` and ``observed`` are given.
    """
    if not isinstance(model, Model):
        raise TypeError(
            "`sample` function only supports `pymc4.Model` objects, but you've passed `{}`"
            .format(type(model)))
    if state is not None and observed is not None:
        raise ValueError("Can't use both `state` and `observed` arguments")
    # ``as_sampling_state`` returns ``(sampling_state, transformed_names)``;
    # only the state is needed here, so the tuple must be unpacked.
    if state is None:
        _, evaluated = flow.evaluate_model_transformed(model, observed=observed)
        deterministic_names = list(evaluated.deterministics)
        state, _ = evaluated.as_sampling_state()
    else:
        _, evaluated = flow.evaluate_model_transformed(model, state=state)
        deterministic_names = list(evaluated.deterministics)
        state, _ = state.as_sampling_state()

    observed = state.observed_values
    # Only the key order is needed to map positional arguments. ``tuple(...)``
    # also copes with an empty mapping, where ``zip(*items)`` could not unpack.
    unobserved_keys = tuple(state.all_unobserved_values)

    def _evaluate_at(values, kwargs):
        """Evaluate the model at a point given positionally or by keyword."""
        if kwargs and values:
            raise TypeError("Either list state should be passed or a dict one")
        elif values:
            kwargs = dict(zip(unobserved_keys, values))
        st = flow.SamplingState.from_values(kwargs, observed_values=observed)
        _, st = flow.evaluate_model_transformed(model, state=st)
        return st

    @tf.function(autograph=False)
    def logpfn(*values, **kwargs):
        return _evaluate_at(values, kwargs).collect_log_prob()

    @tf.function(autograph=False)
    def deterministics_callback(*values, **kwargs):
        return _evaluate_at(values, kwargs).deterministics.values()

    return logpfn, dict(state.all_unobserved_values
                        ), deterministics_callback, deterministic_names