Exemplo n.º 1
0
    def test_prog(self):
        """The layout engine chosen via ``prog`` must drive the rendered output."""
        model = StructureModel([("a", "b")])
        neato_first = str(plot_structure(model, prog="neato"))
        neato_second = str(plot_structure(model, prog="neato"))
        dot_render = str(plot_structure(model, prog="dot"))

        # Same engine -> identical layout; different engine -> different layout.
        assert neato_first == neato_second
        assert neato_first != dot_render
Exemplo n.º 2
0
    def test_graph_attributes(self):
        """Graph-level attributes appear only when explicitly supplied."""
        model = StructureModel([("a", "b")])

        plain = plot_structure(model)
        assert "label" not in plain.graph_attr.keys()

        labelled = plot_structure(model, graph_attributes={"label": "test"})
        assert labelled.graph_attr["label"] == "test"
Exemplo n.º 3
0
    def plot_dag(self, enforce_dag: bool = False, filename: str = "./graph.png"):
        """Render the fitted graph to ``filename`` and return it as an IPython image.

        Args:
            enforce_dag: If True, a deep copy of the fitted graph is thresholded
                with ``threshold_till_dag`` before plotting; ``self.graph_`` itself
                is not modified.
            filename: Path the rendered image is written to.

        Returns:
            ``IPython.display.Image`` built from ``filename``.

        Raises:
            ImportError: If IPython is not installed.
        """
        try:
            # pylint: disable=import-outside-toplevel
            from IPython.display import Image
        except ImportError as e:
            raise ImportError(
                "DAGRegressor.plot_dag method requires IPython installed."
            ) from e

        check_is_fitted(self, "graph_")

        # Work on a copy so thresholding never mutates the fitted model.
        graph = copy.deepcopy(self.graph_)
        if enforce_dag:
            graph.threshold_till_dag()

        # Silence plotting warnings only while rendering. catch_warnings restores
        # the caller's warning filters on exit, even if drawing raises, instead of
        # overwriting the global filter state.
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            viz = plot_structure(
                graph,
                graph_attributes={"scale": "0.5"},
                all_node_attributes=NODE_STYLE.WEAK,
                all_edge_attributes=EDGE_STYLE.WEAK,
            )
            viz.draw(filename)

        return Image(filename)
Exemplo n.º 4
0
def plot_graph(structural_model, layout='dot', rename_node_dict=None):
    """
    Plot a structural model graph with some default styling.

    Node labels are derived from node names: underscores become line breaks
    and the result is title-cased. Nodes whose name contains "Response" or
    "response" are filled with orange (#DF5F00).

    Args:
        structural_model: Structure model to plot. Node names must be strings
            (``str.replace`` is called on them).
        layout: Graphviz layout program, passed to ``plot_structure`` as ``prog``.
        rename_node_dict: Optional ``{node: label}`` mapping whose labels
            override the generated ones for the given nodes.

    Returns:
        ``Image`` wrapping the PNG output of ``viz.draw``.

    Raises:
        TypeError: If ``rename_node_dict`` is given but is not a dict.
    """
    # Validate before doing any work, and with a real exception: ``assert``
    # statements are stripped when Python runs with -O.
    if rename_node_dict is not None and not isinstance(rename_node_dict, dict):
        raise TypeError(
            "rename_node_dict must be a dict, got "
            f"{type(rename_node_dict).__name__}"
        )

    node_attr = _make_node_attributes(structural_model)

    for node in structural_model.nodes:
        node_attr[node]['label'] = node.replace('_', '\n').title()
        if 'Response' in node or 'response' in node:
            node_attr[node]['fillcolor'] = '#DF5F00'

    if rename_node_dict is not None:
        for node, new_name in rename_node_dict.items():
            node_attr[node]['label'] = new_name

    viz = plot_structure(
        structural_model,
        prog=layout,
        # NOTE(review): ``graph_attributes`` is not defined in this function and
        # must exist as a module-level name elsewhere in the file; confirm.
        graph_attributes=graph_attributes,
        node_attributes=node_attr,
    )
    # With no output path, pygraphviz's draw() returns the rendered data,
    # which Image wraps directly.
    f = viz.draw(format='png')
    return Image(f)
Exemplo n.º 5
0
 def test_node_color(self, test_input, expected):
     """Every node face colour should match the requested ``node_color``."""
     model = StructureModel([("a", "b")])
     _, axes, _ = plot_structure(model, node_color=test_input)
     face_colors = axes.collections[0].get_facecolors()
     for face_color in face_colors:
         # Element-wise comparison of the RGBA row against the expected colour.
         assert all(face_color == expected)
Exemplo n.º 6
0
    def test_all_nodes_exist(self):
        """Isolated nodes must be rendered alongside connected ones."""
        model = StructureModel([("a", "b")])
        model.add_node("c")  # isolated node with no incident edges
        rendered_nodes = plot_structure(model).nodes()

        for expected in ("a", "b", "c"):
            assert expected in rendered_nodes
Exemplo n.º 7
0
    def test_all_edges_exist(self):
        """Every edge of the source model must survive conversion to pygraphviz."""
        edges = []
        for src in range(2):
            for offset in range(3):
                edges.append((str(src), str(src + offset + 1)))
        model = StructureModel(edges)
        rendered_edges = plot_structure(model).edges()

        assert all(edge in rendered_edges for edge in edges)
Exemplo n.º 8
0
    def test_all_node_attributes(self):
        """``all_node_attributes`` should be applied to every node."""
        model = StructureModel([("a", "b")])
        override = "black"

        def node_colors(graph):
            # Colour attribute of each node, in the graph's node order.
            return [graph.get_node(name).attr["color"] for name in graph.nodes()]

        baseline = plot_structure(model)
        assert baseline.get_node("a").attr["color"] != override
        assert override not in node_colors(baseline)

        styled = plot_structure(model, all_node_attributes={"color": override})
        assert all(color == override for color in node_colors(styled))
Exemplo n.º 9
0
    def test_all_edge_attributes(self):
        """``all_edge_attributes`` should be applied to every edge."""
        model = StructureModel([("a", "b"), ("b", "c")])
        override = "black"

        def edge_colors(graph):
            # Colour attribute of each edge, in the graph's edge order.
            return [graph.get_edge(u, v).attr["color"] for u, v in graph.edges()]

        baseline = plot_structure(model)
        assert baseline.get_edge("a", "b").attr["color"] != override
        assert override not in edge_colors(baseline)

        styled = plot_structure(model, all_edge_attributes={"color": override})
        assert all(color == override for color in edge_colors(styled))
Exemplo n.º 10
0
    def plot_dag(
        self,
        enforce_dag: bool = False,
        plot_structure_kwargs: Dict = None,
        use_mpl: bool = True,
        ax: Axes = None,
        pixel_size_in: float = 0.01,
    ) -> Union[Tuple[Figure, Axes], Image]:
        """
        Plot the DAG learned by the fitted model.

        Args:
            enforce_dag: If True, a deep copy of the fitted graph is thresholded
            until it is a DAG; the fitted model itself is left untouched.

            plot_structure_kwargs: Keyword arguments forwarded to the causalnex
            ``plot_structure`` function. When empty or None, a default "weak"
            style with ``scale=0.5`` is used instead (the two are not merged).

            use_mpl: Render through matplotlib only when this is exactly True;
            otherwise render for IPython, and ``ax``/``pixel_size_in`` are ignored.

            ax: Matplotlib axes to draw on. If None, one is created.

            pixel_size_in: Scaling multiple for the plot.

        Returns:
            A ``(Figure, Axes)`` pair (matplotlib) or an IPython ``Image``.
        """
        check_is_fitted(self)

        # Threshold a copy so the fitted graph is never mutated.
        dag = copy.deepcopy(self.graph_)
        if enforce_dag:
            dag.threshold_till_dag()

        if plot_structure_kwargs:
            kwargs = plot_structure_kwargs
        else:
            kwargs = {
                "graph_attributes": {"scale": "0.5"},
                "all_node_attributes": NODE_STYLE.WEAK,
                "all_edge_attributes": EDGE_STYLE.WEAK,
            }
        # Layout engine handed to the display helpers; "neato" unless overridden.
        layout_prog = kwargs.get("prog", "neato")

        # Suppress plotting warnings only while building the pygraphviz graph.
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            viz = plot_structure(dag, **kwargs)

        if use_mpl is not True:
            return display_plot_ipython(viz=viz, prog=layout_prog)
        return display_plot_mpl(
            viz=viz, prog=layout_prog, ax=ax, pixel_size_in=pixel_size_in
        )
Exemplo n.º 11
0
    def test_nodes_exist(self):
        """Every node of a chain graph should appear in the matplotlib scatter."""
        for num_nodes in range(2, 10):
            chain = list(ascii_lowercase[:num_nodes])
            # Consecutive pairs form a simple chain a->b->c->...
            model = StructureModel(list(zip(chain, chain[1:])))
            _, axes, _ = plot_structure(model)
            offsets = axes.collections[0].get_offsets()
            assert len(offsets) == num_nodes
Exemplo n.º 12
0
 def test_edge_color(self, test_input, expected):
     """Edges should be drawn in the requested ``edge_color``."""
     model = StructureModel([("a", "b")])
     _, axes, _ = plot_structure(model, edge_color=test_input)
     arrows = list(
         filter(lambda p: isinstance(p, plt.patches.FancyArrowPatch),
                axes.patches))
     assert arrows[0].get_edgecolor() == expected
Exemplo n.º 13
0
def determine_structure(structure_model=None, threshold=0.8):
    """
    Prune a learned structure model and re-orient edges using domain knowledge.

    The model is plotted, edges below ``threshold`` are removed, it is plotted
    again, and then a fixed set of edges is reversed so that they run from
    cause to effect.

    Args:
        structure_model: Structure model to edit in place. When omitted, the
            module-level ``sm`` is used (the original behaviour).
        threshold: Passed to ``remove_edges_below_threshold``; defaults to 0.8.

    Returns:
        The edited structure model (the same object that was modified).

    Note:
        ``remove_edge`` is called unconditionally, so this presumably raises if
        one of the hard-coded edges is no longer present after thresholding.
    """
    # Only fall back to the global when no model is passed in.
    model = sm if structure_model is None else structure_model

    _, _, _ = plot_structure(model)

    model.remove_edges_below_threshold(threshold)
    _, _, _ = plot_structure(model)

    # Decide which learned relationships are right. BAD points to VALUE and
    # MORTDUE when it should be the other way round, so reverse those arrows.
    # NOTE(review): BAD -> LOAN is removed but LOAN -> BAD is not added; confirm
    # that dropping this edge entirely is intended.
    model.remove_edge("BAD", "VALUE")
    model.remove_edge("BAD", "MORTDUE")
    model.remove_edge("BAD", "LOAN")
    model.add_edge("MORTDUE", "BAD")
    model.add_edge("VALUE", "BAD")

    # DEBTINC is the debt-to-income ratio, so CLAGE, VALUE, MORTDUE and LOAN
    # feed into it, not the other way round.
    model.remove_edge("DEBTINC", "CLAGE")
    model.remove_edge("DEBTINC", "VALUE")
    model.remove_edge("DEBTINC", "MORTDUE")
    model.remove_edge("DEBTINC", "LOAN")

    model.add_edge("MORTDUE", "DEBTINC")
    model.add_edge("VALUE", "DEBTINC")
    model.add_edge("CLAGE", "DEBTINC")
    model.add_edge("LOAN", "DEBTINC")

    # NINQ is the number of inquiries, so it is driven by MORTDUE, VALUE and
    # LOAN rather than driving them.
    model.remove_edge("NINQ", "VALUE")
    model.remove_edge("NINQ", "MORTDUE")
    model.remove_edge("NINQ", "LOAN")

    model.add_edge("MORTDUE", "NINQ")
    model.add_edge("VALUE", "NINQ")
    model.add_edge("LOAN", "NINQ")

    return model
Exemplo n.º 14
0
    def test_edges_exist(self):
        """A chain of n nodes should be drawn with n - 1 arrow patches."""
        for num_nodes in range(2, 10):
            chain = list(ascii_lowercase[:num_nodes])
            model = StructureModel(list(zip(chain, chain[1:])))
            _, axes, _ = plot_structure(model)
            arrow_count = sum(
                1 for p in axes.patches
                if isinstance(p, plt.patches.FancyArrowPatch))
            assert arrow_count == num_nodes - 1
Exemplo n.º 15
0
 def test_display_importerror_mpl(self):
     """display_plot_mpl should raise a helpful ImportError without matplotlib."""
     sm = StructureModel([("a", "b")])
     viz = plot_structure(sm, prog="neato")
     try:
         # Hide matplotlib so the reloaded module cannot import it.
         with patch.dict("sys.modules", {"matplotlib": None}):
             reload(display)
             with pytest.raises(
                     ImportError,
                     match=
                     r"display_plot_mpl method requires matplotlib installed.",
             ):
                 display.display_plot_mpl(viz)
     finally:
         # The module was reloaded with matplotlib hidden; reload it again once
         # the patch is undone so later tests get a working module, even when
         # the assertion above fails.
         reload(display)
Exemplo n.º 16
0
    def test_return_types_mpl(self):
        """display_plot_mpl returns (Figure, Axes), or (None, Axes) given an ax."""
        sm = StructureModel([("a", "b")])
        viz = plot_structure(sm, prog="neato")
        try:
            d = display.display_plot_mpl(viz)
            assert isinstance(d, tuple)
            assert isinstance(d[0], Figure)
            assert isinstance(d[1], Axes)

            _, ax = plt.subplots()
            d = display.display_plot_mpl(viz, ax=ax)
            assert isinstance(d, tuple)
            assert d[0] is None
            assert isinstance(d[1], Axes)
        finally:
            # Close the figures created here so they don't accumulate across
            # the test session (matplotlib keeps open figures alive).
            plt.close("all")
Exemplo n.º 17
0
    def test_node_attriibutes(self):
        """Attributes keyed by node name should only restyle that node."""
        model = StructureModel([("a", "b"), ("b", "c")])
        override = "black"

        baseline = plot_structure(model)
        default_color = baseline.get_node("a").attr["color"]
        assert default_color != override
        assert all(
            baseline.get_node(name).attr["color"] == default_color
            for name in baseline.nodes())

        styled = plot_structure(model,
                                node_attributes={"a": {"color": override}})
        untouched = [name for name in styled.nodes() if name != "a"]
        for name in untouched:
            assert styled.get_node(name).attr["color"] == default_color
        assert styled.get_node("a").attr["color"] == override
Exemplo n.º 18
0
 def test_node_positions_respected(self, input_positions,
                                   expected_positions):
     """Nodes should be drawn exactly at the positions provided."""
     sm = StructureModel([("a", "b")])
     _, ax, _ = plot_structure(sm, node_positions=input_positions)
     node_coords = sorted(
         list(coord) for coord in ax.collections[0].get_offsets())
     # zip() stops at the shorter input, so compare lengths first; otherwise a
     # missing node would let the position check pass vacuously.
     assert len(node_coords) == len(expected_positions)
     assert all(
         node_x == exp_x and node_y == exp_y
         for (exp_x, exp_y), (node_x, node_y) in zip(expected_positions,
                                                     node_coords))
Exemplo n.º 19
0
    (ProcessType.var, AbsenteeismLevel.var),
    (InjuryType.var, AbsenteeismLevel.var)
])


# NOTE(review): structToGraph is defined elsewhere; presumably it renders
# carModel with another graph backend. Confirm.
structToGraph(weightedGraph = carModel)
# %% markdown [markdown]
# Now visualize:
# %% codecell
from IPython.display import Image
from causalnex.plots import plot_structure, NODE_STYLE, EDGE_STYLE

# Now visualize it:
# Build a pygraphviz graph of carModel using the causalnex "weak" node/edge
# styles, with the graphviz "scale" graph attribute set to 0.5.
viz = plot_structure(
    carModel,
    graph_attributes={"scale": "0.5"},
    all_node_attributes=NODE_STYLE.WEAK,
    all_edge_attributes=EDGE_STYLE.WEAK)

# Plain string concatenation: graphVizCurrPath presumably ends with a path
# separator. Verify.
filename_demo = graphVizCurrPath + "demo.png"


# Write the rendered graph to filename_demo.
viz.draw(filename_demo)

# As the cell's last expression, this displays the image inline in the notebook.
Image(filename_demo)


# %% markdown [markdown]
# ## Step 3: Create the Bayesian Model and Fit CPDs
# %% codecell
# Checking the structure is acyclic before passing it to bayesian network:
Exemplo n.º 20
0
import os
import pickle

genotypes = pd.concat(cultivar, axis=1)
genotype_uniq = genotypes.drop_duplicates()
# set_axis(..., inplace=True) was removed in pandas 2.0; assign the result instead.
genotype_uniq = genotype_uniq.set_axis(['genotype', 'encoding'], axis=1)
genotype_map = dict(zip(genotype_uniq.genotype, genotype_uniq.encoding))

# hardcoded seasons as dict
season_map = {'season_4': 0, 'season_6': 1}

# open() does not expand "~", so resolve the home directory explicitly.
output_dir = os.path.expanduser("~/work/phenophasebbn/bbn")

with open(os.path.join(output_dir, "genotype_map.json"), "w") as outfile:
    json.dump(genotype_map, outfile)
with open(os.path.join(output_dir, "season_map.json"), "w") as outfile:
    json.dump(season_map, outfile)

# learn structure with NOTEARS, over 1000 iterations, and keep edge weights > 0.95
from causalnex.structure.notears import from_pandas
sm = from_pandas(X=dum_df, max_iter=1000, w_threshold=0.95)

# Pickle the structure model in binary mode; the context manager closes the
# file even if dumping fails.
with open(os.path.join(output_dir, "nt_sm"), "wb") as smp:
    pickle.dump(sm, smp)

# output plot of learned graph
# no need to apply thresholding, since this is taken care of in the sm with w_threshold
from causalnex.plots import plot_structure
viz = plot_structure(sm)
viz.draw("sm_plot.png")
Exemplo n.º 21
0
# * `health` $\longrightarrow$ `G1`
# %% codecell
# Add domain-knowledge edges: health is a parent of both absences and G1.
structureModel.add_edges_from([('health', 'absences'), ('health', 'G1')])

# %% markdown [markdown]
# ## Visualizing the Structure
# %% codecell
# Bare expression: the notebook displays the current edge view as cell output.
structureModel.edges
# %% codecell
structureModel.nodes
# %% codecell
from IPython.display import Image
from causalnex.plots import plot_structure, NODE_STYLE, EDGE_STYLE

# Build a pygraphviz graph with the causalnex "weak" node/edge styles and the
# graphviz "scale" graph attribute set to 0.5.
viz = plot_structure(structureModel,
                     graph_attributes={"scale": "0.5"},
                     all_node_attributes=NODE_STYLE.WEAK,
                     all_edge_attributes=EDGE_STYLE.WEAK)
# Plain string concatenation: curPath presumably ends with a path separator. Verify.
filename_first = curPath + "structure_model_first.png"

# Write the rendered graph to disk, then display it inline as the cell's
# last expression.
viz.draw(filename_first)
Image(filename_first)

# %% markdown [markdown]
# ## Learning the Structure
# CausalNex can also learn the structure model from data, which helps when the number of variables grows or domain knowledge is unavailable. (The algorithm used is the [NOTEARS algorithm](https://arxiv.org/abs/1803.01422).)
# * NOTE: a train/test split is not always necessary here, because structure learning should be a joint effort between machine learning and domain experts.
#
# The data must first be pre-processed so that the [NOTEARS algorithm](https://arxiv.org/abs/1803.01422) can be used.
#
# ## Preparing the Data for Structure Learning
# %% codecell
Exemplo n.º 22
0
def plot_pretty_structure(
    g: StructureModel,
    edges_to_highlight: Tuple[str, str],
    default_weight: float = 0.2,
    weighted: bool = False,
):
    """
    Utility function to plot our networks in a pretty format.

    Args:
        g: Structure model (directed acyclic graph).
        edges_to_highlight: Collection of ``(u, v)`` edge tuples drawn in orange
            (#DF5F00); every other edge is drawn grey (#888888). Despite the
            ``Tuple[str, str]`` annotation, a collection of edges is expected.
        default_weight: Weight used for every edge when ``weighted`` is False,
            and for edges whose ``weight`` attribute is missing or falsy
            (e.g. 0) when ``weighted`` is True.
        weighted: Whether to read each edge's ``weight`` attribute.

    Returns:
        A styled pygraphviz graph, laid out with "dot", that can be rendered
        as an image.
    """
    graph_attributes = {
        "splines": "spline",  # Use splines so that edges do not overlap
        "ordering": "out",
        "ratio": "fill",  # Necessary to control the size of the image
        "size": "16,9!",  # Final image size (a typical presentation size)
        "fontcolor": "#FFFFFFD9",
        "fontname": "Helvetica",
        "fontsize": 24,
        "labeljust": "c",
        "labelloc": "c",
        "pad": "1,1",
        "nodesep": 0.8,
        "ranksep": ".5 equally",
    }
    # Make all nodes hexagonal with black fill
    node_attributes = {
        node: {
            "shape": "hexagon",
            "width": 2.2,
            "height": 2,
            "fillcolor": "#000000",
            "penwidth": "10",
            "color": "#4a90e2d9",
            "fontsize": 24,
            "labelloc": "c",
            "labeljust": "c",
        }
        for node in g.nodes
    }
    # Customise edges
    if weighted:
        edge_weights = [(u, v, w if w else default_weight)
                        for u, v, w in g.edges(data="weight")]
    else:
        edge_weights = [(u, v, default_weight) for u, v in g.edges()]

    # Build the lookup once. Rebuilding it per edge was wasteful, and with an
    # iterator argument the first rebuild consumed it, so later edges were
    # never highlighted.
    highlighted = set(edges_to_highlight)
    # NOTE(review): weights >= 1 give arrowsize <= 0, and negative weights give
    # a negative penwidth. Confirm the expected weight range.
    edge_attributes = {
        (u, v): {
            "penwidth": w * 20 + 2,  # Edge thickness
            "weight": int(w),  # Higher "weight"s mean shorter edges
            "arrowsize": 2 - 2.0 * w,  # Avoid too large arrows
            "arrowtail": "dot",
            "color": "#DF5F00" if (u, v) in highlighted else "#888888",
        }
        for u, v, w in edge_weights
    }
    return plot_structure(
        g,
        prog="dot",
        graph_attributes=graph_attributes,
        node_attributes=node_attributes,
        edge_attributes=edge_attributes,
    )
from causalnex.structure.notears import from_pandas
import time

# Time the structure-learning call in wall-clock seconds (time.time()).
startTime: float = time.time()

# Learn a structure model from the label-encoded data using from_pandas defaults.
carStructLearned = from_pandas(X=labelEncData)

# NOTE(review): `clock` is defined elsewhere; presumably it formats
# endTime - startTime. Confirm.
print(f"Time taken = {clock(startTime = startTime, endTime = time.time())}")

# %% codecell
from IPython.display import Image
from causalnex.plots import plot_structure, NODE_STYLE, EDGE_STYLE

# Now visualize it:
viz = plot_structure(carStructLearned,
                     graph_attributes={"scale": "0.5"},
                     all_node_attributes=NODE_STYLE.WEAK,
                     all_edge_attributes=EDGE_STYLE.WEAK)
# Plain string concatenation: curPath presumably ends with a path separator. Verify.
filename_carLearned = curPath + "car_learnedStructure.png"

# Write the rendered graph to disk, then display it inline as the cell's
# last expression.
viz.draw(filename_carLearned)
Image(filename_carLearned)

# %% markdown [markdown]
# Getting detailed view into the learned model:
# %% codecell
# Adjacency view of the learned graph, displayed as the cell output.
carStructLearned.adj

# %% codecell
# Attribute dict of a single learned edge (presumably including its weight).
carStructLearned.get_edge_data(u='uses_op', v='absenteeism_level')
# %% codecell
# Per-node degree of the learned graph.
carStructLearned.degree
Exemplo n.º 24
0
    def test_show_labels(self, test_input, expected):
        """Text labels are drawn only when ``show_labels`` is truthy."""
        model = StructureModel([("a", "b")])
        _, axes, _ = plot_structure(model, show_labels=test_input)
        has_text = len(axes.texts) > 0
        assert has_text == expected
Exemplo n.º 25
0
 def test_label_colors(self, test_input, expected):
     """Each drawn label should use the supplied ``label_color``."""
     model = StructureModel([("a", "b")])
     _, axes, _ = plot_structure(model, show_labels=True,
                                 label_color=test_input)
     label_colors = [label.get_color() for label in axes.texts]
     assert all(color == expected for color in label_colors)
Exemplo n.º 26
0
 def test_has_layout(self):
     """plot_structure should return an AGraph that already has a layout."""
     graph = plot_structure(StructureModel([("a", "b")]))
     assert graph.has_layout
Exemplo n.º 27
0
 def test_title(self, test_input, expected):
     """The axes title should reflect the ``title`` argument."""
     model = StructureModel([("a", "b")])
     _, axes, _ = plot_structure(model, title=test_input)
     actual_title = axes.get_title()
     assert actual_title == expected
Exemplo n.º 28
0
 def test_install_warning(self, mocked_to_agraph):
     """A Warning mentioning the missing pygraphviz install should be raised."""
     empty_model = StructureModel()
     with pytest.raises(Warning, match="Pygraphviz not installed"):
         plot_structure(empty_model)
     # The (mocked) conversion to AGraph must have been attempted exactly once.
     mocked_to_agraph.assert_called_once()
Exemplo n.º 29
0
 def test_return_types_ipython(self):
     """display_plot_ipython should return an IPython ``Image``."""
     viz = plot_structure(StructureModel([("a", "b")]), prog="neato")
     result = display.display_plot_ipython(viz)
     assert isinstance(result, Image)