Пример #1
0
    def test_behaves_same_as_seperate_calls(self, train_data_idx, train_data_discrete):
        """
        Check that `fit_node_states_and_cpds` gives the same result as chaining
        `fit_node_states` and `fit_cpds`.

        Two networks are built from `from_pandas(train_data_idx, w_threshold=0.3)`;
        one is fitted in two chained steps, the other with the combined call.
        Their edges, node states and CPDs must then be equal.
        """
        two_step_bn = BayesianNetwork(from_pandas(train_data_idx, w_threshold=0.3))
        one_step_bn = BayesianNetwork(from_pandas(train_data_idx, w_threshold=0.3))

        # Reference network: states first, then CPDs, through the chained API
        two_step_bn.fit_node_states(train_data_discrete).fit_cpds(train_data_discrete)
        # Network under test: single combined call
        one_step_bn.fit_node_states_and_cpds(train_data_discrete)

        assert two_step_bn.edges == one_step_bn.edges
        assert two_step_bn.node_states == one_step_bn.node_states

        expected_cpds = two_step_bn.cpds
        actual_cpds = one_step_bn.cpds

        # Same set of nodes, and an identical table for each node
        assert expected_cpds.keys() == actual_cpds.keys()

        for node, expected_table in expected_cpds.items():
            assert expected_table.equals(actual_cpds[node])
    def test_em_algorithm(self):  # pylint: disable=too-many-locals
        """
        Check that `BayesianNetwork` can fit a latent variable with the EM algorithm.

        The structure is a naive Bayes model with parents, plus one extra node
        (`cc_0`) that is not connected to the latent variable `z`.
        """

        # p0   p1  p2
        #   \  |  /
        #      z
        #   /  |  \
        # c0  c1  c2
        # |
        # cc0
        np.random.seed(22)
        atol = 1e-2  # tolerance used by the inference and metric comparisons below

        data, sm, _, true_lv_values = naive_bayes_plus_parents(
            percentage_not_missing=0.1,
            samples=1000,
            p_z=0.7,
            p_c=0.7,
        )
        # cc_0 equals c_0 where the random draw is below 0.5, otherwise (c_0 + 1) % 3
        keep_c0 = np.random.random(len(data)) < 0.5
        data["cc_0"] = np.where(keep_c0, data["c_0"], (data["c_0"] + 1) % 3)
        data.drop(columns=["z"], inplace=True)

        # Same data, but with the true latent values filled in
        complete_data = data.copy(deep=True)
        complete_data["z"] = true_lv_values

        extra_edge = ("c_0", "cc_0")
        latent_edges = list(sm.edges)
        # Every `p` connected to every `c` (the structure used before `z` is added)
        dense_edges = [(f"p_{p}", f"c_{c}") for p in range(3) for c in range(3)]

        # Baseline model: the structure of the figure trained with complete data.
        # This is what the EM-trained model is expected to reproduce.
        complete_bn = BayesianNetwork(StructureModel(latent_edges + [extra_edge]))
        complete_bn.fit_node_states_and_cpds(complete_data)

        # BN without latent variable: all `p`s connected to all `c`s + `c0` -> `cc0`
        bn = BayesianNetwork(StructureModel(dense_edges + [extra_edge]))
        bn.fit_node_states(data)
        bn.fit_cpds(data)

        # TEST 1: cc_0 does not depend on the latent variable, so its CPD already matches
        assert np.all(bn.cpds["cc_0"] == complete_bn.cpds["cc_0"])

        # BN with latent variable: add `z` with the edges of the figure above
        # and drop the direct connections between `p`s and `c`s
        bn.add_node("z", latent_edges, dense_edges)
        bn.fit_latent_cpds("z", [0, 1, 2], data, stopping_delta=0.001)

        # TEST 2: cc_0 CPD should remain untouched by the EM algorithm
        assert np.all(bn.cpds["cc_0"] == complete_bn.cpds["cc_0"])

        # TEST 3: the recovered CPDs should be close to the baseline ones
        assert bn.cpds.keys() == complete_bn.cpds.keys()
        assert self.mean_absolute_error(bn.cpds, complete_bn.cpds) < 0.01

        # TEST 4: inference over the recovered CPDs should match the observed
        # frequencies of states 0 and 1 in the complete data
        marginals = InferenceEngine(bn).query()
        n_rows = complete_data.shape[0]

        for node in marginals:
            for state in (0, 1):
                predicted = marginals[node][state]
                observed = sum(complete_data[node] == state) / n_rows
                assert np.abs(predicted - observed) < atol

        # TEST 5: metrics for predicting `z` should match those of the baseline model
        report = classification_report(bn, complete_data, "z")
        _, auc = roc_auc(bn, complete_data, "z")
        complete_report = classification_report(complete_bn, complete_data, "z")
        _, complete_auc = roc_auc(complete_bn, complete_data, "z")

        for category, metrics in report.items():
            if not isinstance(metrics, dict):
                # Scalar entry: compare directly
                assert np.abs(metrics - complete_report[category]) < atol
                continue
            # Nested entry: compare metric by metric
            for key, val in metrics.items():
                assert np.abs(val - complete_report[category][key]) < atol

        assert np.abs(auc - complete_auc) < atol