def init(self) -> None:
        """Run the initialization phase of posterior predictive sampling.

        Retrieves and enters the currently active ``_DrawValuesContext``
        bookkeeping object (``get_context`` is called without
        ``error_if_none=False``, so a context is expected to be active). It
        then sorts every variable in ``self.vars`` into one of two groups:

        * evaluated: "fast drawable" variables (drawn right away via
          ``self.draw_value``), variables whose ``(var, self.samples)`` key is
          already cached in the context's ``drawn_vars``, and named variables
          with a ``model`` attribute whose name appears in
          ``self.trace.varnames``;
        * symbolic: everything else, which still needs to be drawn.

        Side effects:
            ``self.evaluated`` is set to a dict mapping each evaluated
            variable's position in ``self.vars`` to its value.
            ``self.symbolic_params`` is set to a list of ``(position, var)``
            pairs for the variables still to be drawn. Values taken from the
            trace are also stored in the context's ``drawn_vars`` cache.
        """
        # Named ``variables`` (not ``vars``) to avoid shadowing the builtin.
        variables: List[Any] = self.vars
        trace: _TraceDict = self.trace
        samples: int = self.samples

        # initialization phase
        context = _DrawValuesContext.get_context()
        # Narrows the type for static checkers; get_context returns the base type.
        assert isinstance(context, _DrawValuesContext)
        with context:
            drawn = context.drawn_vars
            evaluated: Dict[int, Any] = {}
            symbolic_params = []
            for i, var in enumerate(variables):
                if is_fast_drawable(var):
                    evaluated[i] = self.draw_value(var)
                    continue
                name = getattr(var, "name", None)
                key = (var, samples)
                if key in drawn:
                    # Already drawn earlier in this context: reuse the cached value.
                    evaluated[i] = drawn[key]
                # We filter out Deterministics by checking for `model` attribute.
                elif name is not None and hasattr(var, "model") and name in trace.varnames:
                    # The variable's values are in the trace: record them as
                    # both drawn (context cache) and evaluated (result).
                    drawn[key] = evaluated[i] = trace[cast(str, name)]
                else:
                    # The variable still needs to be drawn.
                    symbolic_params.append((i, var))
        self.evaluated = evaluated
        self.symbolic_params = symbolic_params
Example #2
0
def test_mixed_contexts():
    """Check nesting of ``Model``, ``_DrawValuesContext`` and
    ``_DrawValuesContextBlocker`` context managers.

    Verifies that entering and exiting each manager exposes the expected
    object through the ``get_context`` / ``modelcontext`` lookups. It also
    verifies that those lookups fail once no context of the requested kind
    is active.
    """

    def check_model(expected):
        # Both lookup paths must agree on the innermost active model.
        assert Model.get_context() == expected
        assert modelcontext(None) == expected

    outer_model = Model()
    inner_model = Model()

    # Nothing has been entered yet, so the lookup is expected to fail.
    with raises((ValueError, TypeError)):
        modelcontext(None)

    with outer_model:
        with inner_model:
            check_model(inner_model)
            draw_ctx = _DrawValuesContext()
            with draw_ctx:
                # Entering a draw-values context must not disturb the model lookup.
                check_model(inner_model)
                assert _DrawValuesContext.get_context() == draw_ctx
                blocker = _DrawValuesContextBlocker()
                with blocker:
                    assert _DrawValuesContext.get_context() == blocker
                    assert _DrawValuesContextBlocker.get_context() == blocker
                # Once the blocker exits, the enclosing context is visible
                # again, even when queried through the blocker class.
                assert _DrawValuesContext.get_context() == draw_ctx
                assert _DrawValuesContextBlocker.get_context() is draw_ctx
                check_model(inner_model)
            # No draw-values context remains active at this point.
            assert _DrawValuesContext.get_context(error_if_none=False) is None
            with raises(TypeError):
                _DrawValuesContext.get_context()
            check_model(inner_model)
        check_model(outer_model)

    # Every model has exited: lookups report "none" or raise.
    assert Model.get_context(error_if_none=False) is None
    with raises(TypeError):
        Model.get_context(error_if_none=True)
    with raises((ValueError, TypeError)):
        modelcontext(None)