Example #1
0
class TestVariableSelection:
    """Check that selecting ``var_names`` restricts a ModelGraph's output."""

    # The "L" and "obs" selections below expect exactly the same plot set and
    # compute graph, so the expectation is written once and reused.
    _CHAIN_VARS = ["pred", "obs", "L", "intermediate", "a", "b"]
    _CHAIN_GRAPH = {
        "pred": {"intermediate"},
        "obs": {"L"},
        "L": {"pred"},
        "intermediate": {"a", "b"},
        "a": set(),
        "b": set(),
    }

    @pytest.mark.parametrize(
        "var_names, vars_to_plot, compute_graph",
        [
            (["c"], ["a", "b", "c"], {"c": {"a", "b"}, "a": set(), "b": set()}),
            (["L"], _CHAIN_VARS, _CHAIN_GRAPH),
            (["obs"], _CHAIN_VARS, _CHAIN_GRAPH),
            # Selecting ["c", "L"] is akin to selecting the entire graph, so the
            # expectation is the unrestricted ModelGraph output.
            (
                ["c", "L"],
                ModelGraph(model_with_different_descendants()).vars_to_plot(),
                ModelGraph(model_with_different_descendants()).make_compute_graph(),
            ),
        ],
    )
    def test_subgraph(self, var_names, vars_to_plot, compute_graph):
        """Compare the restricted plot set and compute graph with expectations."""
        graph = ModelGraph(model_with_different_descendants())
        selected = graph.vars_to_plot(var_names=var_names)
        assert set(selected) == set(vars_to_plot)
        assert graph.make_compute_graph(var_names=var_names) == compute_graph
Example #2
0
 def test_subgraph(self, var_names, vars_to_plot, compute_graph):
     """Check the plot set and compute graph produced for ``var_names``."""
     graph = ModelGraph(model_with_different_descendants())
     selected = graph.vars_to_plot(var_names=var_names)
     # Order of plotted variables is irrelevant; compare as sets.
     assert set(selected) == set(vars_to_plot)
     restricted = graph.make_compute_graph(var_names=var_names)
     assert restricted == compute_graph
Example #3
0
 def test_get_parent_names(self, var_name, parent_names):
     """Check that ``get_parent_names`` reports the expected parents of ``var_name``.

     The variable is looked up by name in the graph's model, and the result is
     compared with ``parent_names``. The comparison assumes ``parent_names`` is
     parametrized with the same container type that ``get_parent_names`` returns;
     confirm this against the parametrize list.
     """
     mg = ModelGraph(model_with_different_descendants())
     # The comparison must be asserted: a bare expression is evaluated and
     # discarded, so the test could never fail.
     assert mg.get_parent_names(mg.model[var_name]) == parent_names
Example #4
0
 def setup_class(cls):
     """Build the model fixture once per test class and wrap it in a ModelGraph."""
     model, compute_graph, plates = cls.model_func()
     cls.model = model
     cls.compute_graph = compute_graph
     cls.plates = plates
     cls.model_graph = ModelGraph(cls.model)