class TestVariableSelection:
    """Tests for selecting a subset of variables out of a ModelGraph."""

    @pytest.mark.parametrize(
        "var_names, vars_to_plot, compute_graph",
        [
            (["c"], ["a", "b", "c"], {"c": {"a", "b"}, "a": set(), "b": set()}),
            (
                ["L"],
                ["pred", "obs", "L", "intermediate", "a", "b"],
                {
                    "pred": {"intermediate"},
                    "obs": {"L"},
                    "L": {"pred"},
                    "intermediate": {"a", "b"},
                    "a": set(),
                    "b": set(),
                },
            ),
            (
                ["obs"],
                ["pred", "obs", "L", "intermediate", "a", "b"],
                {
                    "pred": {"intermediate"},
                    "obs": {"L"},
                    "L": {"pred"},
                    "intermediate": {"a", "b"},
                    "a": set(),
                    "b": set(),
                },
            ),
            # selecting ["c", "L"] is akin to selecting the entire graph
            (
                ["c", "L"],
                ModelGraph(model_with_different_descendants()).vars_to_plot(),
                ModelGraph(model_with_different_descendants()).make_compute_graph(),
            ),
        ],
    )
    def test_subgraph(self, var_names, vars_to_plot, compute_graph):
        """Restricting the graph to ``var_names`` yields the expected subgraph.

        Two things are checked:
        * the names returned by ``vars_to_plot`` match ``vars_to_plot``,
          ignoring order (both sides are turned into sets);
        * ``make_compute_graph`` returns exactly ``compute_graph``.
        """
        graph = ModelGraph(model_with_different_descendants())

        selected = graph.vars_to_plot(var_names=var_names)
        assert set(selected) == set(vars_to_plot)

        subgraph = graph.make_compute_graph(var_names=var_names)
        assert subgraph == compute_graph
def test_subgraph(self, var_names, vars_to_plot, compute_graph):
    """Check the variable selection and compute graph produced for ``var_names``.

    Plotted variable names are compared as sets, so their order is ignored.
    The compute graph must equal ``compute_graph`` exactly.
    """
    graph = ModelGraph(model_with_different_descendants())
    plotted = set(graph.vars_to_plot(var_names=var_names))
    assert plotted == set(vars_to_plot)
    assert graph.make_compute_graph(var_names=var_names) == compute_graph
def test_get_parent_names(self, var_name, parent_names):
    """Check the parent names ModelGraph reports for a single model variable.

    ``var_name`` is looked up on the graph's model. The result of
    ``get_parent_names`` for that variable must equal ``parent_names``.
    NOTE(review): the expected type of ``parent_names`` (presumably a set of
    names) comes from the parametrization, which is not visible here; confirm.
    """
    mg = ModelGraph(model_with_different_descendants())
    actual = mg.get_parent_names(mg.model[var_name])
    # An explicit ``assert`` is required: a bare comparison expression is
    # evaluated and thrown away, so without it the test can never fail.
    assert actual == parent_names
def setup_class(cls):
    """Populate class-level fixtures shared by the tests of this class.

    ``cls.model_func()`` must return a 3-item sequence
    ``(model, compute_graph, plates)``. The items are stored as ``cls.model``,
    ``cls.compute_graph`` and ``cls.plates``, and a ``ModelGraph`` is built
    from the model and stored as ``cls.model_graph``.
    """
    model, compute_graph, plates = cls.model_func()
    cls.model = model
    cls.compute_graph = compute_graph
    cls.plates = plates
    cls.model_graph = ModelGraph(model)