Example #1
0
    def test_f1score_generated_poisson(self):
        """Poisson structure learned should have a good F1 score.

        Builds a structure with ``generate_structure(5, 3.0)``, draws count
        data from it with ``generate_count_dataframe`` (no intercept, no zero
        inflation, seed 10), fits a structure with every column declared as
        ``"poiss"``, and requires the F1 score of the recovered edge set
        against the generating structure's edges to exceed 0.7.
        """
        np.random.seed(10)
        sm = generate_structure(5, 3.0)
        df = generate_count_dataframe(
            sm, 1000, intercept=False, zero_inflation_factor=0.0, seed=10
        )
        data = np.asarray(df)

        dist_type_schema = {i: "poiss" for i in range(data.shape[1])}
        sm_fitted = from_numpy(
            data,
            dist_type_schema=dist_type_schema,
            lasso_beta=0.1,
            ridge_beta=0.0,
            w_threshold=0.1,
            use_bias=False,
        )

        true_edges = set(sm.edges)
        predicted_edges = set(sm_fitted.edges)
        n_correct = len(predicted_edges & true_edges)

        # F1 = 2PR / (P + R) simplifies to 2 * correct / (predicted + relevant).
        # This form stays defined when nothing is predicted or nothing is
        # correct, where the precision/recall form raises ZeroDivisionError
        # instead of producing a readable assertion failure.
        denominator = len(predicted_edges) + len(true_edges)
        f1_score = 2 * n_correct / denominator if denominator else 0.0

        assert f1_score > 0.7, (
            f"F1 score {f1_score:.3f} too low: {n_correct} correct out of "
            f"{len(predicted_edges)} predicted and {len(true_edges)} true edges"
        )
Example #2
0
 def test_zero_lambda(self):
     """
     Parentless nodes must still yield non-zero counts.

     A faulty initialisation could make counts for nodes without parents
     always zero. This builds a graph of 20 nodes with no edges, generates
     10000 rows of count data, and checks that no column has a mean of zero.
     """
     node_ids = list(range(20))
     structure = StructureModel()
     structure.add_nodes_from(node_ids)
     counts = generate_count_dataframe(structure, 10000)
     column_means = counts.mean()
     has_all_zero_column = np.any(column_means == 0)
     assert not has_all_zero_column
Example #3
0
    def test_dataframe(self, graph, intercept, seed, kernel,
                       zero_inflation_factor):
        """
        Two count-data generations with identical arguments must match.

        The second frame's columns are reordered to ``graph.nodes()`` order
        before comparing its values against the first result.

        NOTE(review): the original summary read "Tests equivalence of
        dataframe wrapper", but both calls go through
        generate_count_dataframe, so this only checks that two identical
        calls agree. If the intent was to compare the wrapper against a
        lower-level array generator, the first call should target that
        function instead -- confirm.
        """
        generator_kwargs = dict(
            zero_inflation_factor=zero_inflation_factor,
            seed=seed,
            intercept=intercept,
            kernel=kernel,
        )
        n_rows = 100
        reference = generate_count_dataframe(graph, n_rows, **generator_kwargs)
        candidate = generate_count_dataframe(graph, n_rows, **generator_kwargs)

        node_order = list(graph.nodes())
        assert np.array_equal(reference, candidate[node_order].values)