コード例 #1
0
 def test_zif_value_error(self, graph, wrong_count_zif):
     """
     Check that an unsupported zero-inflation factor is rejected.

     ``wrong_count_zif`` is supplied as the ``"count"`` entry of
     ``distributions`` while generating count data; sem_generator must raise
     a ValueError whose message mentions the unsupported factor.
     """
     bad_distributions = {"count": wrong_count_zif}
     expected_message = "Unsupported zero-inflation factor"
     with pytest.raises(ValueError, match=expected_message):
         sem_generator(
             graph,
             default_type="count",
             distributions=bad_distributions,
             seed=42,
         )
コード例 #2
0
    def test_run(self, graph, schema):
        """
        Generate data for the ``schema`` fixture and check each column.

        Binary nodes must produce exactly two distinct values. Categorical
        nodes are expanded into one ``<node>_<k>`` indicator column per class,
        each holding two distinct values. Continuous nodes should yield one
        distinct value per sample.
        """
        n_samples = 1000
        df = sem_generator(
            graph=graph,
            schema=schema,
            default_type="continuous",
            noise_std=1.0,
            n_samples=n_samples,
            intercept=False,
            seed=12,
        )

        # binary nodes
        for binary_col in (0, 2):
            assert df[binary_col].nunique() == 2

        # categorical nodes: (node id, expected number of indicator columns)
        for node, cardinality in ((1, 3), (5, 5)):
            prefix = "{}_".format(node)
            for k in range(cardinality):
                assert df[prefix + str(k)].nunique() == 2
            indicator_cols = [
                col for col in df.columns
                if isinstance(col, str) and prefix in col
            ]
            assert len(indicator_cols) == cardinality

        # continuous nodes: no repeated values expected
        for continuous_col in (3, 4):
            assert df[continuous_col].nunique() == n_samples
コード例 #3
0
 def test_missing_default_type(self, graph):
     """
     Test that an unknown ``default_type`` raises a ValueError.

     ``schema=None`` is passed explicitly, so ``default_type`` is the only
     invalid input. No ``schema`` parameter or local exists in this test,
     so referring to a bare ``schema`` name here would resolve to some
     unrelated module-level object (or fail with NameError).
     """
     with pytest.raises(ValueError, match="Unknown default data type"):
         _ = sem_generator(
             graph=graph,
             schema=None,
             default_type="unknown",
             noise_std=1.0,
             n_samples=1000,
             intercept=False,
             seed=12,
         )
コード例 #4
0
 def test_incorrect_intercept_dist(self, graph):
     """
     Check that an unrecognised intercept distribution name is rejected.

     Intercepts are switched on (``intercept=True``) while the
     ``"intercept"`` distribution is set to the bogus name ``"unknown"``.
     """
     generator_kwargs = dict(
         graph=graph,
         schema=None,
         default_type="continuous",
         distributions={"intercept": "unknown"},
         noise_std=2.0,
         n_samples=10,
         intercept=True,
         seed=10,
     )
     with pytest.raises(ValueError, match="Unknown intercept distribution"):
         _ = sem_generator(**generator_kwargs)
コード例 #5
0
    def test_only_count(self, graph, zero_inflation_pct):
        """
        Generate count-only data and check basic distributional bounds.

        Every column must be non-negative, and each column's share of zeros
        must be at least ``zero_inflation_pct`` (the zero-inflation factor
        passed as the ``"count"`` distribution).
        """
        df = sem_generator(
            graph,
            default_type="count",
            n_samples=1000,
            distributions={"count": zero_inflation_pct},
            seed=43,
        )
        column_minima = df.min()
        zero_share = (df == 0).mean()

        # counts can never be negative
        assert (column_minima >= 0).all()
        # zero inflation gives a lower bound on the fraction of zeros
        assert (zero_share >= zero_inflation_pct).all()
コード例 #6
0
 def test_not_permissible_type(self, graph):
     """
     Check that a schema entry naming an unsupported data type is rejected.
     """
     invalid_schema = {0: "unknown data type"}
     generator_kwargs = dict(
         graph=graph,
         schema=invalid_schema,
         default_type="continuous",
         noise_std=1.0,
         n_samples=1000,
         intercept=False,
         seed=12,
     )
     with pytest.raises(ValueError, match="Unknown data type"):
         _ = sem_generator(**generator_kwargs)
コード例 #7
0
 def test_missing_cardinality(self, graph):
     """
     Check that a categorical schema entry without a cardinality is rejected.

     Node 0 is declared as plain ``"categorical"``, while nodes 1 and 5 use
     the ``"categorical:<k>"`` form that specifies the number of classes.
     """
     schema = {0: "categorical", 1: "categorical:3", 5: "categorical:5"}
     expected_message = "Missing cardinality for categorical"
     with pytest.raises(ValueError, match=expected_message):
         _ = sem_generator(
             graph=graph,
             schema=schema,
             default_type="continuous",
             noise_std=1.0,
             n_samples=1000,
             intercept=False,
             seed=12,
         )
コード例 #8
0
    def test_incorrect_weight_dist(self):
        """
        Check that an unrecognised weight distribution name is rejected.

        The graph has six string-labelled nodes (added in shuffled order) and
        two edges whose weights are left as ``None``. Presumably this forces
        sem_generator to draw the weights from ``distributions["weight"]``,
        which is set to the bogus name ``"unknown"``.
        """
        sm = StructureModel()
        labels = [str(i) for i in range(6)]
        np.random.shuffle(labels)
        sm.add_nodes_from(labels)

        unweighted_edges = [("0", "1", None), ("2", "4", None)]
        sm.add_weighted_edges_from(unweighted_edges)

        with pytest.raises(ValueError, match="Unknown weight distribution"):
            _ = sem_generator(
                graph=sm,
                schema=None,
                default_type="continuous",
                distributions={"weight": "unknown"},
                noise_std=2.0,
                n_samples=1000,
                intercept=False,
                seed=10,
            )
コード例 #9
0
    def test_mixed_type_independence(self, seed, n_categories,
                                     weight_distribution,
                                     intercept_distribution):
        """
        Test whether the relation is accurate, implicitly tests sequence of
        nodes.

        The graph has three edges: binary "0" -> categorical "1", binary
        "2" -> continuous "4" and binary "2" -> count "6". Node "5" is an
        unconnected categorical node and "3" an unconnected continuous node.
        Connected pairs must show a measurable association between joint and
        factored probabilities; unconnected pairs must be approximately
        independent.
        """

        def most_deviating_class(frame, node):
            # Index of the indicator column ``<node>_<k>`` whose mean is
            # furthest from uniform (1 / n_categories); ``max`` keeps the
            # first index on ties.
            uniform = 1 / n_categories
            return max(
                range(n_categories),
                key=lambda k: np.abs(
                    frame["{}_{}".format(node, k)].mean() - uniform),
            )

        np.random.seed(seed)

        sm = StructureModel()
        nodes = [str(x) for x in range(6)]
        np.random.shuffle(nodes)
        sm.add_nodes_from(nodes)
        # binary -> categorical
        sm.add_weighted_edges_from([("0", "1", 10)])
        # binary -> continuous
        sm.add_weighted_edges_from([("2", "4", None)])
        # binary -> count (node "6" is created implicitly by this edge)
        sm.add_weighted_edges_from([("2", "6", 100)])

        schema = {
            "0": "binary",
            "1": "categorical:{}".format(n_categories),
            "2": "binary",
            "4": "continuous",
            "5": "categorical:{}".format(n_categories),
            "6": "count",
        }

        df = sem_generator(
            graph=sm,
            schema=schema,
            default_type="continuous",
            distributions={
                "weight": weight_distribution,
                "intercept": intercept_distribution,
                "count": 0.05,
            },
            noise_std=2,
            n_samples=100000,
            intercept=True,
            seed=seed,
        )

        # df is not modified below, so each categorical's class is chosen once
        # and reused.
        class_1 = most_deviating_class(df, "1")
        class_5 = most_deviating_class(df, "5")

        atol = 0.05  # 5% difference between joint & factored!
        # 1. dependent links
        # 0 -> 1 (use the class with the highest deviation from uniform)
        joint_proba, factored_proba = calculate_proba(
            df, "0", "1_{}".format(class_1))
        assert not np.isclose(joint_proba, factored_proba, rtol=0, atol=atol)
        # 2 -> 4
        assert not np.isclose(
            df["4"].mean(), df["4"][df["2"] == 1].mean(), rtol=0, atol=atol)
        # binary on count
        assert not np.isclose(
            df.loc[df["2"] == 0, "6"].mean(),
            df.loc[df["2"] == 1, "6"].mean(),
            rtol=0,
            atol=atol,
        )

        tol = 0.15  # relative tolerance of +/- 15% between joint & factored
        # 2. independent links
        # binary "0" vs. categorical "5" (no path between them)
        joint_proba, factored_proba = calculate_proba(
            df, "0", "5_{}".format(class_5))
        assert np.isclose(joint_proba, factored_proba, rtol=tol, atol=0)

        # binary "0" vs. binary "2"
        joint_proba, factored_proba = calculate_proba(df, "0", "2")
        assert np.isclose(joint_proba, factored_proba, rtol=tol, atol=0)

        # categorical "1" vs. categorical "5"
        joint_proba, factored_proba = calculate_proba(
            df, "1_{}".format(class_1), "5_{}".format(class_5))
        assert np.isclose(joint_proba, factored_proba, rtol=tol, atol=0)

        # continuous "3" vs. continuous "4"
        # zero correlation is necessary for independence (and sufficient for
        # jointly gaussian variables)
        assert np.isclose(df[["3", "4"]].corr().values[0, 1], 0, atol=tol)
コード例 #10
0
    def test_graph_not_a_dag(self):
        """
        Check that sem_generator refuses a graph that contains a cycle.
        """
        cyclic_graph = StructureModel()
        # 0 -> 1 -> 2 -> 0 closes a loop, so the graph is not acyclic
        cyclic_graph.add_edges_from([(0, 1), (1, 2), (2, 0)])

        with pytest.raises(ValueError, match="Provided graph is not a DAG"):
            _ = sem_generator(graph=cyclic_graph, seed=42)