def test_fit_invalid_lv_name(self, lv_name):
        """
        A ValueError should be raised when the latent variable name is invalid

        Args:
            lv_name: candidate latent variable name injected by pytest
                (parametrization or fixture defined outside this block)
        """

        # Build the network outside the `raises` block so that an unexpected
        # ValueError during setup cannot be mistaken for the expected one.
        df, sm, _, _ = naive_bayes_plus_parents()
        sm = StructureModel(list(sm.edges))
        bn = BayesianNetwork(sm)

        # NOTE: `match` is a regex searched in the message; the trailing " *"
        # only matches optional spaces, so this is effectively a substring check.
        with pytest.raises(ValueError, match=r"Invalid latent variable name *"):
            bn.fit_latent_cpds(lv_name, [0, 1, 2], df)
    def test_fit_lv_not_added(self):
        """
        A ValueError should be raised if the latent variable is not added to the network yet

        The network is built from the generated structure only, and node "d" is
        never added to it before `fit_latent_cpds` is called for "d".
        """

        # Build the network outside the `raises` block so that an unexpected
        # ValueError during setup cannot be mistaken for the expected one.
        df, sm, _, _ = naive_bayes_plus_parents()
        sm = StructureModel(list(sm.edges))
        bn = BayesianNetwork(sm)

        with pytest.raises(
            ValueError,
            match=r"Latent variable 'd' not added to the network",
        ):
            bn.fit_latent_cpds("d", [0, 1, 2], df)
    def test_fit_invalid_lv_states(self, lv_states):
        """
        A ValueError should be raised if the latent variable has invalid states

        Args:
            lv_states: candidate latent variable states injected by pytest
                (parametrization or fixture defined outside this block)
        """

        # Build the network outside the `raises` block so that an unexpected
        # ValueError during setup cannot be mistaken for the expected one.
        df, sm, _, _ = naive_bayes_plus_parents()
        sm = StructureModel(list(sm.edges))
        bn = BayesianNetwork(sm)
        # Add node "d" with the single edge ("z", "d") and no edges to remove,
        # so that only the states argument can be at fault below.
        bn.add_node("d", [("z", "d")], [])

        with pytest.raises(
            ValueError,
            match="Latent variable 'd' contains no states",
        ):
            bn.fit_latent_cpds("d", lv_states, df)
def get_avg_auc_lvs(
    df: pd.DataFrame,
    bn: BayesianNetwork,
    lv_states: List,
    n_splits: int = 5,
    seed: int = 2021,
    markov_blanket: bool = False,
    n_cpus: int = multiprocessing.cpu_count() - 1,
) -> float:
    """
    Utility function to compute the cross-validated average AUC of a network with a latent variable

    For every fold, the CPDs of the latent variable named "LV" are fitted on the training rows.
    Then `_compute_auc_lv_stub` is run in a process pool on the test rows, once for each node
    other than "LV". The per-fold mean of those results is averaged over all folds.

    Args:
        df: Input dataframe
        bn: Bayesian network. It is refitted in place on every fold
        lv_states: the states the LV can assume
        n_splits: Number of cross-validation folds
        seed: Random seed number
        markov_blanket: Whether we predict only using the Markov blanket
            (forwarded unchanged to `_compute_auc_lv_stub`)
        n_cpus: Number of worker processes to use. Values below 1 (e.g. the default on a
            single-core machine) are raised to 1

    Returns:
        Average AUC
    """
    # `cpu_count() - 1` is 0 on a single-core machine and `Pool(0)` raises ValueError
    n_cpus = max(1, n_cpus)
    cv = KFold(n_splits=n_splits, shuffle=True, random_state=seed)
    total_auc = 0

    for fold, (train_idx, test_idx) in enumerate(cv.split(df)):
        t0 = time()
        # KFold yields positional indices, so rows must be selected by position (`iloc`)
        # rather than by label (`loc`), which is only equivalent for a default RangeIndex
        train_df = df.iloc[train_idx, :]
        test_df = df.iloc[test_idx, :]
        bn.fit_latent_cpds("LV", lv_states, train_df, n_runs=30)
        chunks = [[bn, test_df, target, markov_blanket] for target in bn.nodes
                  if target != "LV"]
        with multiprocessing.Pool(n_cpus) as p:
            result = p.starmap(_compute_auc_lv_stub, chunks)

        # Average over the targets that were actually scored in this fold
        total_auc += sum(result) / len(chunks)
        print(
            f"Processing fold {fold} using {n_cpus} cores takes {time() - t0} seconds"
        )

    return total_auc / n_splits
    def test_em_algorithm(self):  # pylint: disable=too-many-locals
        """
        Check that `BayesianNetwork` can fit a latent variable with the EM algorithm.
        The structure is a naive bayes + parents, plus one extra node that is not
        connected to the latent variable.
        """

        # p0   p1  p2
        #   \  |  /
        #      z
        #   /  |  \
        # c0  c1  c2
        # |
        # cc0
        np.random.seed(22)

        # Edge lists reused below (building them consumes no random numbers)
        extra_edge = [("c_0", "cc_0")]
        parent_child_edges = [
            (f"p_{p}", f"c_{c}") for p in range(3) for c in range(3)
        ]

        observed, true_sm, _, z_truth = naive_bayes_plus_parents(
            percentage_not_missing=0.1,
            samples=1000,
            p_z=0.7,
            p_c=0.7,
        )
        # cc_0 copies c_0 when the uniform draw is below 0.5, else it is (c_0 + 1) mod 3
        keep_mask = np.random.random(len(observed)) < 0.5
        observed["cc_0"] = np.where(
            keep_mask, observed["c_0"], (observed["c_0"] + 1) % 3
        )
        observed.drop(columns=["z"], inplace=True)

        full_data = observed.copy(deep=True)
        full_data["z"] = z_truth

        # Reference model: the structure drawn above, trained on the data
        # including the true latent values. This is what EM should reproduce.
        reference_bn = BayesianNetwork(
            StructureModel(list(true_sm.edges) + extra_edge)
        )
        reference_bn.fit_node_states_and_cpds(full_data)

        # Model without latent variable: every `p` points to every `c`, plus `c0` -> `cc0`
        em_bn = BayesianNetwork(StructureModel(parent_child_edges + extra_edge))
        em_bn.fit_node_states(observed)
        em_bn.fit_cpds(observed)

        # TEST 1: cc_0 does not depend on the latent variable, so its CPD already matches
        assert np.all(em_bn.cpds["cc_0"] == reference_bn.cpds["cc_0"])

        # Introduce the latent variable: add the edges of the drawing and
        # drop the direct `p` -> `c` connections
        em_bn.add_node("z", list(true_sm.edges), parent_child_edges)
        em_bn.fit_latent_cpds("z", [0, 1, 2], observed, stopping_delta=0.001)

        # TEST 2: the EM algorithm must leave the cc_0 CPD untouched
        assert np.all(em_bn.cpds["cc_0"] == reference_bn.cpds["cc_0"])

        # TEST 3: the learned CPDs should be very close to the reference ones
        assert em_bn.cpds.keys() == reference_bn.cpds.keys()
        assert self.mean_absolute_error(em_bn.cpds, reference_bn.cpds) < 0.01

        # TEST 4: marginals inferred from the learned CPDs should match the
        # empirical frequencies of states 0 and 1 in the complete data
        engine = InferenceEngine(em_bn)
        marginals = engine.query()
        total_rows = full_data.shape[0]

        for node in marginals:
            for state in (0, 1):
                predicted = marginals[node][state]
                empirical = sum(full_data[node] == state) / total_rows
                assert np.abs(predicted - empirical) < 1e-2

        # TEST 5: predict / predict_probability based metrics should agree
        # between the EM model and the reference model
        em_report = classification_report(em_bn, full_data, "z")
        _, em_auc = roc_auc(em_bn, full_data, "z")
        reference_report = classification_report(reference_bn, full_data, "z")
        _, reference_auc = roc_auc(reference_bn, full_data, "z")

        for category, metrics in em_report.items():
            if not isinstance(metrics, dict):
                # Scalar entry of the report: compare directly
                assert np.abs(metrics - reference_report[category]) < 1e-2
                continue
            for key, val in metrics.items():
                assert np.abs(val - reference_report[category][key]) < 1e-2

        assert np.abs(em_auc - reference_auc) < 1e-2