Ejemplo n.º 1
0
def test_product_with_single_column_table():
    """Multiply a four-column table by a one-column table, in both orders.

    Two situations are covered: the single column carries a name the other
    table does not use ("Y1"), and the single column shares a name ("X2")
    with the other table.
    """
    # --- the single column's name is new to the other table ---
    lone_column = Table({"A": 3, "B": 4, "C": 7}, names=["Y1"])
    joint = Table(sample_1, names=["X1", "X2", "X3", "X4"])

    appended = joint * lone_column
    assert all(compare(appended.names, ["X1", "X2", "X3", "X4", "Y1"]))
    assert appended["a", "y", 2, 33, "B"] == 24

    prepended = lone_column * joint
    assert all(compare(prepended.names, ["Y1", "X1", "X2", "X3", "X4"]))
    assert prepended["B", "a", "y", 2, 33] == 24

    # --- the single column's name ("X2") also exists in the other table ---
    lone_column = Table({"x": 3, "y": 4}, names=["X2"])
    joint = Table(sample_1, names=["X1", "X2", "X3", "X4"])

    merged = joint * lone_column
    assert all(compare(merged.names, ["X1", "X2", "X3", "X4"]))
    for key, expected in ((("a", "x", 2, 44), 12), (("a", "y", 2, 44), 32)):
        assert merged[key] == expected

    # With the single-column table on the left, its column leads the names.
    merged = lone_column * joint
    assert all(compare(merged.names, ["X2", "X1", "X3", "X4"]))
Ejemplo n.º 2
0
def test_table_of_table():
    """A table whose values are themselves tables returns them on lookup."""
    # Number the 16 combinations 1..16 (X3 varies fastest, then X4, X2, X1)
    # and drop ("b", "y", 2, 33), which would have been number 14.
    cells = [
        (x1, x2, x3, x4)
        for x1 in ("a", "b")
        for x2 in ("x", "y")
        for x4 in (33, 44)
        for x3 in (1, 2)
    ]
    counts = {cell: number for number, cell in enumerate(cells, start=1)}
    del counts[("b", "y", 2, 33)]

    # Four separate tables built from the same counts, keyed "t1".."t4".
    inner_tables = {
        f"t{index}": Table(counts, ["X1", "X2", "X3", "X4"])
        for index in range(1, 5)
    }

    outer = Table(inner_tables, ["t"])
    assert outer["t2"] == inner_tables["t2"]
Ejemplo n.º 3
0
def test_conditional_on_table():
    """Condition a four-column table on "X2" and inspect its child tables."""
    joint = Table(samples, names=["X1", "X2", "X3", "X4"])
    given_x2 = joint.condition_on("X2")
    assert all(compare(given_x2.names, ["X2"]))

    # Every child table holds the three columns that were not conditioned on.
    for x2_value in given_x2.keys():
        child = given_x2[x2_value]
        assert all(compare(child.names, ["X1", "X3", "X4"]))

    child_x = given_x2["x"]
    child_y = given_x2["y"]
    # (child table, key inside the child, expected value)
    expectations = [
        (child_x, ("a", 1, 33), 1 / 52),
        (child_y, ("a", 1, 33), 5 / 84),
        (child_x, ("a", 1, 44), 3 / 52),
        (child_y, ("a", 1, 44), 7 / 84),
        (child_x, ("b", 1, 33), 9 / 52),
        (child_y, ("b", 1, 33), 13 / 84),
        (child_x, ("b", 1, 44), 11 / 52),
        (child_y, ("b", 1, 44), 15 / 84),
        (child_x, ("b", 2, 44), 12 / 52),
        (child_y, ("b", 2, 33), 14 / 84),
    ]
    for child, key, expected in expectations:
        assert child[key] == expected

    # The two lookups can also be chained in a single expression.
    for x2_value, expected in (("x", 1 / 52), ("y", 5 / 84)):
        assert given_x2[x2_value]["a", 1, 33] == expected
Ejemplo n.º 4
0
def test_total_table():
    """Check ``total()`` on plain tables and on an un-normalised conditional."""
    one_column = Table({"a": 3, "b": 4, "c": 5}, names=["X1"])
    assert one_column.total() == 12

    two_columns = Table(
        {("a", "x"): 4, ("a", "y"): 4, ("b", "x"): 6, ("b", "y"): 6},
        names=["X1", "X2"],
    )
    assert two_columns.total() == 20

    # Number the 16 combinations 1..16 (X3 varies fastest, then X4, X2, X1).
    cells = [
        (x1, x2, x3, x4)
        for x1 in ("a", "b")
        for x2 in ("x", "y")
        for x4 in (33, 44)
        for x3 in (1, 2)
    ]
    counts = {cell: number for number, cell in enumerate(cells, start=1)}

    four_columns = Table(counts, names=["X1", "X2", "X3", "X4"])
    by_x1 = four_columns.condition_on("X1", normalise=False)
    per_x1_total = by_x1.total()
    # On the conditional table, total() is indexed by one-element tuples.
    for x1_value, expected in (("a", 36), ("b", 100)):
        assert per_x1_total[(x1_value,)] == expected
Ejemplo n.º 5
0
def test_product_with_two_common_vars_table():
    """Product of two tables that share the column names "X2" and "X3"."""
    left = Table(sample_1, names=["X1", "X2", "X3", "X4"])
    right = Table(sample_2, names=["X3", "X5", "X6", "X2"])

    product = left * right
    expected_names = ["X1", "X2", "X3", "X4", "X5", "X6"]
    assert all(compare(product.names, expected_names))

    # A combination available on both sides has a value.
    assert product["a", "y", 2, 33, "high", "under"] == 150

    # Combinations expected to be absent from the product:
    absent_keys = [
        # the right operand has no matching entry
        ("a", "y", 2, 33, "low", "under"),
        # the left operand has no matching entry
        ("b", "y", 2, 33, "high", "under"),
    ]
    for key in absent_keys:
        assert product[key] is None
Ejemplo n.º 6
0
def test_product_with_no_common_vars_table():
    """Product of two four-column tables with entirely different names."""
    left = Table(sample_1, names=["X1", "X2", "X3", "X4"])
    right = Table(sample_2, names=["Y1", "Y2", "Y3", "Y4"])

    product = left * right
    expected_names = ["X1", "X2", "X3", "X4", "Y1", "Y2", "Y3", "Y4"]
    assert all(compare(product.names, expected_names))

    # Product keys are the left key followed by the right key.
    expected_values = {
        ("a", "x", 1, 33, 2, "high", "normal", "x"): 10,
        ("b", "x", 1, 44, 1, "low", "over", "y"): 253,
    }
    for key, expected in expected_values.items():
        assert product[key] == expected
Ejemplo n.º 7
0
def test_product_with_one_common_var_table():
    """Product of two tables that share only the column name "X3"."""
    left = Table(sample_1, names=["X1", "X2", "X3", "X4"])
    right = Table(sample_2, names=["X3", "X5", "X6", "X7"])

    product = left * right
    expected_names = ["X1", "X2", "X3", "X4", "X5", "X6", "X7"]
    assert all(compare(product.names, expected_names))

    # A combination available on both sides has a value.
    assert product[("a", "x", 1, 33, "high", "normal", "x")] == 2

    # Combinations expected to be absent from the product:
    absent_keys = [
        # the right operand has no matching entry
        ("b", "y", 2, 44, "high", "over", "y"),
        # the left operand has no matching entry
        ("b", "y", 2, 33, "high", "normal", "y"),
    ]
    for key in absent_keys:
        assert product[key] is None
Ejemplo n.º 8
0
def test_conditional_on_conditional_table():
    """Condition an already-conditioned table and check the nested values."""

    def check(nested, cases):
        # Each case is (outer key, key inside the child table, expected value).
        for outer_key, inner_key, expected in cases:
            assert nested[outer_key][inner_key] == approx(expected)

    joint = Table(samples, names=["X1", "X2", "X3", "X4"])

    # Conditioning on X3 first and X1 second gives outer keys ordered
    # (X1, X3); conditioning the other way round gives (X3, X1).
    nested = joint.condition_on("X3").condition_on("X1")
    check(
        nested,
        [
            (("a", 1), ("x", 33), 1 / 16),
            (("a", 2), ("x", 33), 2 / 20),
            (("a", 1), ("x", 44), 3 / 16),
            (("a", 2), ("x", 44), 4 / 20),
            (("b", 1), ("x", 33), 9 / 48),
            (("b", 2), ("x", 33), 10 / 52),
            (("b", 1), ("y", 33), 13 / 48),
            (("b", 2), ("y", 44), 16 / 52),
        ],
    )

    # The inverse order of conditioning.
    nested = joint.condition_on("X1").condition_on("X3")
    check(
        nested,
        [
            ((1, "a"), ("x", 33), 1 / 16),
            ((2, "a"), ("x", 44), 4 / 20),
            ((1, "b"), ("x", 33), 9 / 48),
            ((2, "b"), ("x", 33), 10 / 52),
        ],
    )
    # 'get' takes the outer key by name, so its order no longer matters.
    assert nested.get(X1="a", X3=1)["x", 33] == approx(1 / 16)

    # Three conditioning columns in total; only X2 is left in the children.
    # Rows are (X1, X3, X4, X2, expected value).
    rows = [
        ("a", 1, 33, "x", 1 / 6),
        ("a", 2, 33, "x", 2 / 8),
        ("a", 1, 44, "x", 3 / 10),
        ("a", 2, 44, "x", 4 / 12),
        ("b", 1, 44, "x", 11 / 26),
        ("b", 2, 44, "x", 12 / 28),
        ("b", 1, 44, "y", 15 / 26),
        ("b", 2, 44, "y", 16 / 28),
    ]
    x1_x3_x4_cases = [((x1, x3, x4), x2, value) for x1, x3, x4, x2, value in rows]
    x1_x4_x3_cases = [((x1, x4, x3), x2, value) for x1, x3, x4, x2, value in rows]

    nested = joint.condition_on("X3", "X4").condition_on("X1")
    check(nested, x1_x3_x4_cases)

    nested = joint.condition_on("X4").condition_on("X1", "X3")
    check(nested, x1_x3_x4_cases)

    # Here the outer keys are ordered (X1, X4, X3).
    nested = joint.condition_on("X3").condition_on("X1", "X4")
    check(nested, x1_x4_x3_cases)
Ejemplo n.º 9
0
def test_constructor_table():
    """Build tables from several kinds of input and read the values back."""

    def assert_entries(table, expected):
        for key, value in expected.items():
            assert table[key] == value

    words = {"one": 1, "two": 2, "three": 3}
    integers = {1: 1, 2: 2, 3: 3}

    # From a dict, with an explicit column name.
    assert_entries(Table({"one": 1, "two": 2, "three": 3}, names=["Y1"]), words)

    # From a list of (key, value) pairs.
    assert_entries(Table([("two", 2), ("one", 1), ("three", 3)]), words)

    # From a Counter of raw observations.
    observations = [1, 2, 2, 3, 3, 3]
    assert_entries(Table(Counter(observations)), integers)

    # From a zip of keys and values.
    assert_entries(Table(zip(["one", "two", "three"], [1, 2, 3])), words)

    # Numeric keys: integers, then floats.
    assert_entries(Table({1: 1, 2: 2, 3: 3}), integers)
    assert_entries(Table({1.1: 1, 2.2: 2, 3.3: 3}), {1.1: 1, 2.2: 2, 3.3: 3})
Ejemplo n.º 10
0
def test_reduce_by_name_table():
    """Reduce a table by fixing named columns, with default and custom names."""

    def check_reduction(table, fixed, remaining_names, entries):
        # Fix the given columns, then check the remaining columns and values.
        reduced = table.reduce(**fixed)
        assert reduced.columns.size == len(remaining_names)
        assert all(compare(reduced.names, remaining_names))
        for key, expected in entries.items():
            assert reduced[key] == expected

    # Default column names.
    table = Table(samples)
    check_reduction(
        table,
        {"X2": "y"},
        ["X1", "X3", "X4"],
        {("a", 1, 33): 5, ("b", 2, 44): 16},
    )
    check_reduction(
        table,
        {"X2": "y", "X3": 1},
        ["X1", "X4"],
        {("a", 33): 5, ("b", 44): 15},
    )
    check_reduction(
        table,
        {"X1": "b", "X3": 1, "X4": 44},
        ["X2"],
        {"x": 11, "y": 15},
    )

    # Custom column names.
    table = Table(samples, names=["Y", "Z", "W", "X"])
    check_reduction(
        table,
        {"Z": "y"},
        ["Y", "W", "X"],
        {("a", 1, 33): 5, ("b", 2, 44): 16},
    )
    check_reduction(
        table,
        {"Z": "y", "W": 1},
        ["Y", "X"],
        {("a", 33): 5, ("b", 44): 15},
    )
    check_reduction(
        table,
        {"Y": "b", "W": 1, "X": 44},
        ["Z"],
        {"x": 11, "y": 15},
    )
Ejemplo n.º 11
0
def test_reduce_by_name_on_conditioned_table():
    """Reduce conditional tables by fixing columns of their child tables."""

    def check_reduction(conditional, fixed, size, children, cells):
        # cells holds (outer key, key inside the child, expected value).
        reduced = conditional.reduce(**fixed)
        assert reduced.columns.size == size
        assert all(compare(reduced.columns.children_names, children))
        for outer_key, inner_key, expected in cells:
            assert reduced[outer_key][inner_key] == expected

    table = Table(samples, names=["X1", "X2", "X3", "X4"])

    given_x1 = table.condition_on("X1")
    check_reduction(
        given_x1,
        {"X2": "y"},
        1,
        ["X3", "X4"],
        [("a", (1, 33), 5 / 36), ("b", (2, 44), 16 / 100)],
    )
    check_reduction(
        given_x1,
        {"X2": "y", "X3": 1},
        1,
        ["X4"],
        [("a", 33, 5 / 36), ("b", 44, 15 / 100)],
    )

    given_x1_x3 = table.condition_on("X1", "X3")
    check_reduction(
        given_x1_x3,
        {"X2": "y"},
        2,
        ["X4"],
        [(("a", 1), 33, 5 / 16), (("b", 2), 44, 16 / 52)],
    )
Ejemplo n.º 12
0
def test_getitem_table():
    """Look values up by indexing and through ``get`` with mixed arguments."""
    # --- one column ---
    one_column = {"one": 1, "two": 2, "three": 3}
    table = Table({"one": 1, "two": 2, "three": 3}, names=["Y1"])
    for y1, value in one_column.items():
        assert table[y1] == value
    for y1, value in one_column.items():
        assert table.get(Y1=y1) == value

    # --- two columns ---
    two_columns = {("one", "Red"): 1, ("two", "Green"): 2, ("three", "Blue"): 3}
    table = Table(
        {("one", "Red"): 1, ("two", "Green"): 2, ("three", "Blue"): 3},
        names=["Y1", "Y2"],
    )
    # Indexing is positional: a key in the wrong order is not found.
    assert table["Red", "one"] is None
    for key, value in two_columns.items():
        assert table[key] == value
    # With 'get', named values may come in any order and the positional
    # ones fill the columns that were not named.
    for (y1, y2), value in two_columns.items():
        assert table.get(y1, y2) == value
        assert table.get(y1, Y2=y2) == value
        assert table.get(y2, Y1=y1) == value
        assert table.get(Y1=y1, Y2=y2) == value

    # --- three columns ---
    three_columns = {
        ("one", "Red", 11): 1,
        ("two", "Green", 22): 2,
        ("three", "Blue", 33): 3,
    }
    table = Table(
        {
            ("one", "Red", 11): 1,
            ("two", "Green", 22): 2,
            ("three", "Blue", 33): 3,
        },
        names=["Y1", "Y2", "Y3"],
    )
    # Indexing is positional: keys in the wrong order are not found.
    for shuffled in (("Red", "one", 11), (11, "Red", "one"), ("one", 11, "Red")):
        assert table[shuffled] is None
    for key, value in three_columns.items():
        assert table[key] == value
    for (y1, y2, y3), value in three_columns.items():
        assert table.get(y1, y2, y3) == value
        assert table.get(y1, Y2=y2, Y3=y3) == value
        assert table.get(y2, Y1=y1, Y3=y3) == value
        assert table.get(Y1=y1, Y2=y2, Y3=y3) == value
    # The last column given positionally, the others by name in either order.
    assert table.get(11, Y1="one", Y2="Red") == 1
    assert table.get(11, Y2="Red", Y1="one") == 1
Ejemplo n.º 13
0
def test_normalise_table():
    """Normalise plain and conditional tables and check the resulting values."""
    one_column = Table({"a": 3, "b": 4, "c": 5}, names=["X1"])
    one_column.normalise()
    for key, count in (("a", 3), ("b", 4)):
        assert one_column[key] == count / 12

    pair_counts = {("a", "x"): 4, ("a", "y"): 4, ("b", "x"): 6, ("b", "y"): 6}
    two_columns = Table(
        {("a", "x"): 4, ("a", "y"): 4, ("b", "x"): 6, ("b", "y"): 6},
        names=["X1", "X2"],
    )
    two_columns.normalise()
    for key, count in pair_counts.items():
        assert two_columns[key] == count / 20

    # Number the 16 combinations 1..16 (X3 varies fastest, then X4, X2, X1).
    cells = [
        (x1, x2, x3, x4)
        for x1 in ("a", "b")
        for x2 in ("x", "y")
        for x4 in (33, 44)
        for x3 in (1, 2)
    ]
    counts = {cell: number for number, cell in enumerate(cells, start=1)}
    joint = Table(counts, names=["X1", "X2", "X3", "X4"])

    def check(conditional, cases):
        # Each case is (outer key, key inside the child, expected value).
        for outer_key, inner_key, expected in cases:
            assert conditional[outer_key][inner_key] == expected

    conditional = joint.condition_on("X1")
    conditional.normalise()
    check(
        conditional,
        [
            ("a", ("x", 1, 33), 1 / 36),
            ("a", ("y", 1, 33), 5 / 36),
            ("a", ("y", 2, 33), 6 / 36),
            ("a", ("y", 2, 44), 8 / 36),
            ("b", ("x", 1, 33), 9 / 100),
            ("b", ("y", 1, 33), 13 / 100),
            ("b", ("y", 2, 44), 16 / 100),
        ],
    )

    conditional = joint.condition_on("X2")
    conditional.normalise()
    check(
        conditional,
        [
            ("x", ("a", 1, 33), 1 / 52),
            ("x", ("a", 2, 44), 4 / 52),
            ("y", ("b", 1, 33), 13 / 84),
        ],
    )

    conditional = joint.condition_on("X1", "X2")
    conditional.normalise()
    check(
        conditional,
        [
            (("a", "x"), (1, 33), 1 / 10),
            (("a", "x"), (2, 44), 4 / 10),
            (("b", "y"), (2, 33), 14 / 58),
        ],
    )
Ejemplo n.º 14
0
def test_product_with_table_of_table_with_no_common_table():
    """Check products that come out empty but still carry the right names.

    Every product below is expected to have length 0; the checks then
    confirm the column names (and, for conditional results, the names of
    the child columns) in both operand orders.
    """

    def check_empty(product, names, children=None):
        assert len(product) == 0
        assert all(compare(product.names, names))
        if children is not None:
            assert all(compare(product.columns.children_names, children))

    joint = Table(sample_1, names=["X1", "X2", "X3", "X4"])

    # ---- second operand over X1 only ----
    prior = Table({"one": 1, "two": 2, "three": 3}, names=["X1"])
    conditional = joint.condition_on("X1")
    # P(X2, X3, X4 | X1) * P(X1) -> P(X2, X3, X4, X1)
    check_empty(conditional * prior, ["X2", "X3", "X4", "X1"])
    # P(X1) * P(X2, X3, X4 | X1) -> P(X1, X2, X3, X4)
    check_empty(prior * conditional, ["X1", "X2", "X3", "X4"])

    # ---- second operand over X1, X2 ----
    prior = Table(
        {("one", "x"): 1, ("two", "x"): 2, ("three", "y"): 3},
        names=["X1", "X2"],
    )
    conditional = joint.condition_on("X1", "X2")
    # P(X3, X4 | X1, X2) * P(X1, X2) -> P(X3, X4, X1, X2)
    check_empty(conditional * prior, ["X3", "X4", "X1", "X2"])
    # P(X1, X2) * P(X3, X4 | X1, X2) -> P(X1, X2, X3, X4)
    check_empty(prior * conditional, ["X1", "X2", "X3", "X4"])

    other = prior.condition_on("X1")
    # P(X3, X4 | X1, X2) * P(X2 | X1) -> P(X3, X4, X2 | X1)
    check_empty(conditional * other, ["X1"], ["X3", "X4", "X2"])
    # P(X2 | X1) * P(X3, X4 | X1, X2) -> P(X2, X3, X4 | X1)
    check_empty(other * conditional, ["X1"], ["X2", "X3", "X4"])

    # ---- second operand over X1, X2, X3 ----
    prior = Table(
        {("one", "x", 1): 1, ("two", "x", 1): 2, ("three", "y", 2): 3},
        names=["X1", "X2", "X3"],
    )
    conditional = joint.condition_on("X1", "X2", "X3")
    # P(X4 | X1, X2, X3) * P(X1, X2, X3) -> P(X4, X1, X2, X3)
    check_empty(conditional * prior, ["X4", "X1", "X2", "X3"])
    # P(X1, X2, X3) * P(X4 | X1, X2, X3) -> P(X1, X2, X3, X4)
    check_empty(prior * conditional, ["X1", "X2", "X3", "X4"])

    other = prior.condition_on("X1")
    # P(X4 | X1, X2, X3) * P(X2, X3 | X1) -> P(X4, X2, X3 | X1)
    check_empty(conditional * other, ["X1"], ["X4", "X2", "X3"])
    # P(X2, X3 | X1) * P(X4 | X1, X2, X3) -> P(X2, X3, X4 | X1)
    check_empty(other * conditional, ["X1"], ["X2", "X3", "X4"])

    other = prior.condition_on("X1", "X2")
    # P(X4 | X1, X2, X3) * P(X3 | X1, X2) -> P(X4, X3 | X1, X2)
    check_empty(conditional * other, ["X1", "X2"], ["X4", "X3"])
    # P(X3 | X1, X2) * P(X4 | X1, X2, X3) -> P(X3, X4 | X1, X2)
    check_empty(other * conditional, ["X1", "X2"], ["X3", "X4"])

    other = prior.condition_on("X2", "X3")
    # P(X4 | X1, X2, X3) * P(X1 | X2, X3) -> P(X4, X1 | X2, X3)
    check_empty(conditional * other, ["X2", "X3"], ["X4", "X1"])
    # P(X1 | X2, X3) * P(X4 | X1, X2, X3) -> P(X1, X4 | X2, X3)
    check_empty(other * conditional, ["X2", "X3"], ["X1", "X4"])
Ejemplo n.º 15
0
def test_product_with_table_of_table_column_table():
    """Multiply conditional tables by matching marginals and compare to the joint.

    For several sets of conditioning columns, the conditional table is
    multiplied (in both orders) by the table obtained from marginalising
    out all the other columns, and the product is compared with the
    normalised joint table.
    """
    column_names = ("X1", "X2", "X3", "X4", "X5")
    joint = Table(sample_3, names=["X1", "X2", "X3", "X4", "X5"])

    def assert_all(actual, reference):
        # Look each key of `actual` up in `reference` by column name, so the
        # two tables may order their columns differently.
        for key in actual:
            named = actual.columns.named_key(key)
            assert actual[key] == approx(reference.get(**named))

    joint.normalise()

    conditioning_sets = [
        ("X1",),
        ("X2",),
        ("X5",),
        ("X1", "X2"),
        ("X1", "X3"),
        ("X2", "X3"),
    ]
    for given in conditioning_sets:
        others = [name for name in column_names if name not in given]
        conditional = joint.condition_on(*given)
        # Marginalising out the other columns leaves a table over `given`.
        prior = joint.marginal(*others)
        # P(others | given) * P(given) -> P(others, given)
        assert_all(conditional * prior, joint)
        # P(given) * P(others | given) -> P(given, others)
        assert_all(prior * conditional, joint)

    # Two conditional tables multiplied together.
    first = joint.marginal("X4", "X5").condition_on("X2", "X3")
    second = joint.condition_on("X1", "X2", "X3")
    # P(X1 | X2, X3) * P(X4, X5 | X1, X2, X3) -> P(X1, X4, X5 | X2, X3)
    assert_all(first * second, joint)
    # P(X4, X5 | X1, X2, X3) * P(X1 | X2, X3) -> P(X4, X5, X1 | X2, X3)
    assert_all(second * first, joint)
Ejemplo n.º 16
0
def test_marginals_names_exception_table():
    """Check that invalid ``marginal`` requests raise ``ValueError``.

    Only the call expected to fail sits inside each ``pytest.raises`` block,
    so a ``ValueError`` raised while building a fixture (or by a preparatory
    ``marginal`` call) cannot make the test pass by accident.
    """
    pair_counts = {("a", "x"): 4, ("a", "y"): 4, ("b", "x"): 6, ("b", "y"): 6}

    # One-column table with default names.
    # NOTE(review): presumably its only column is named "X1", so this request
    # would marginalise out every column -- confirm against Table's defaults.
    table = Table({"a": 3, "b": 4, "c": 5})
    with pytest.raises(ValueError):
        table.marginal("X1")

    # Two-column table with default names: unknown column names.
    table = Table(pair_counts)
    for unknown in ("X0", "X3"):
        with pytest.raises(ValueError):
            table.marginal(unknown)

    # Marginalising out the same column twice.
    remaining = table.marginal("X1")
    with pytest.raises(ValueError):
        remaining.marginal("X1")

    # Custom names: a default-style name is not a column of this table.
    table = Table(pair_counts, names=["Y", "Z"])
    with pytest.raises(ValueError):
        table.marginal("X1")

    # Custom names: marginalising out the same column twice.
    remaining = table.marginal("Y")
    with pytest.raises(ValueError):
        remaining.marginal("Y")

    # Marginalising over all columns at once.
    with pytest.raises(ValueError):
        table.marginal("Y", "Z")
Ejemplo n.º 17
0
def test_marginals_names_table():
    """Check which column names remain after ``marginal`` drops some."""

    def check(table, cases):
        # Each case is (columns passed to marginal, names expected afterwards).
        for dropped, remaining in cases:
            result = table.marginal(*dropped)
            assert all(compare(result.names, remaining))

    # Two columns.
    pair_counts = {("a", "x"): 4, ("a", "y"): 4, ("b", "x"): 6, ("b", "y"): 6}

    check(
        Table(pair_counts),
        [
            (("X1",), ["X2"]),
            (("X2",), ["X1"]),
        ],
    )
    check(
        Table(pair_counts, names=["Y", "Z"]),
        [
            (("Y",), ["Z"]),
            (("Z",), ["Y"]),
        ],
    )

    # Three columns.
    triple_counts = {
        ("a", "x", 1): 4,
        ("a", "x", 2): 4,
        ("a", "y", 1): 6,
        ("a", "y", 2): 6,
        ("b", "x", 1): 8,
        ("b", "x", 2): 8,
        ("b", "y", 1): 10,
        ("b", "y", 2): 10,
    }

    check(
        Table(triple_counts),
        [
            (("X1",), ["X2", "X3"]),
            (("X2",), ["X1", "X3"]),
            (("X3",), ["X1", "X2"]),
            (("X1", "X3"), ["X2"]),
            (("X2", "X3"), ["X1"]),
        ],
    )
    check(
        Table(triple_counts, names=["Y", "Z", "W"]),
        [
            (("Y",), ["Z", "W"]),
            (("Z",), ["Y", "W"]),
            (("W",), ["Y", "Z"]),
            (("Y", "W"), ["Z"]),
            (("Z", "W"), ["Y"]),
        ],
    )
Ejemplo n.º 18
0
def test_marginal_if_table_of_table():
    """Marginalise child columns of conditional tables and check the values."""
    # Number the 16 combinations 1..16 (X3 varies fastest, then X4, X2, X1)
    # and drop ("b", "y", 2, 33), which would have been number 14.
    cells = [
        (x1, x2, x3, x4)
        for x1 in ("a", "b")
        for x2 in ("x", "y")
        for x4 in (33, 44)
        for x3 in (1, 2)
    ]
    counts = {cell: number for number, cell in enumerate(cells, start=1)}
    del counts[("b", "y", 2, 33)]

    def check_names(result, names, children):
        assert all(compare(result.names, names))
        assert all(compare(result.children_names, children))

    def check_approx(result, cases):
        # Each case is (outer key, key inside the child, expected value).
        for outer_key, inner_key, expected in cases:
            assert result[outer_key][inner_key] == approx(expected)

    joint = Table(counts, names=["X1", "X2", "X3", "X4"])

    # ---- conditioned on X1 ----
    given_x1 = joint.condition_on("X1")
    # Asking to marginalise the conditioning column itself is rejected.
    with pytest.raises(ValueError):
        given_x1.marginal("X1")

    result = given_x1.marginal("X2")
    check_names(result, ["X1"], ["X3", "X4"])
    check_approx(
        result,
        [
            ("a", (1, 33), (1 + 5) / 36),
            ("a", (1, 44), (3 + 7) / 36),
            ("a", (2, 33), (2 + 6) / 36),
            ("a", (2, 44), (4 + 8) / 36),
            ("b", (1, 33), (9 + 13) / 86),
            ("b", (1, 44), (11 + 15) / 86),
        ],
    )

    # Without normalisation the children hold plain sums of the counts.
    given_x1 = joint.condition_on("X1", normalise=False)
    result = given_x1.marginal("X2", normalise=False)
    raw_sums = [
        ("a", (1, 33), 1 + 5),
        ("a", (1, 44), 3 + 7),
        ("a", (2, 33), 2 + 6),
        ("a", (2, 44), 4 + 8),
        ("b", (1, 33), 9 + 13),
        ("b", (1, 44), 11 + 15),
    ]
    for outer_key, inner_key, expected in raw_sums:
        assert result[outer_key][inner_key] == expected

    result = given_x1.marginal("X2", "X3")
    check_names(result, ["X1"], ["X4"])
    check_approx(
        result,
        [
            ("a", 33, (1 + 5 + 2 + 6) / 36),
            ("a", 44, (3 + 7 + 4 + 8) / 36),
            ("b", 33, (9 + 10 + 13) / 86),
            ("b", 44, (11 + 12 + 15 + 16) / 86),
        ],
    )

    result = given_x1.marginal("X2", "X4")
    check_names(result, ["X1"], ["X3"])
    check_approx(
        result,
        [
            ("a", 1, (1 + 3 + 5 + 7) / 36),
            ("a", 2, (2 + 4 + 6 + 8) / 36),
            ("b", 1, (9 + 11 + 13 + 15) / 86),
            ("b", 2, (10 + 12 + 16) / 86),
        ],
    )

    # ---- conditioned on X1 and X3 ----
    given_x1_x3 = joint.condition_on("X1", "X3")
    # Requests expected to be rejected: each conditioning column, and both
    # of the child columns together.
    for rejected in (("X1",), ("X3",), ("X2", "X4")):
        with pytest.raises(ValueError):
            given_x1_x3.marginal(*rejected)

    result = given_x1_x3.marginal("X2")
    check_names(result, ["X1", "X3"], ["X4"])
    check_approx(
        result,
        [
            (("a", 1), 33, (1 + 5) / 16),
            (("a", 1), 44, (3 + 7) / 16),
            (("a", 2), 44, (4 + 8) / 20),
            (("b", 1), 33, (9 + 13) / 48),
            (("b", 2), 33, 10 / 38),
            (("b", 2), 44, (12 + 16) / 38),
        ],
    )

    # ---- conditioned on X1, X3 and X4: X2 is the only child column ----
    given_x1_x3_x4 = joint.condition_on("X1", "X3", "X4")
    with pytest.raises(ValueError):
        given_x1_x3_x4.marginal("X2")
Ejemplo n.º 19
0
def test_add_table():
    """Add pairs of equal tables and check that the values double."""
    # One column.
    samples = {"a": 3, "b": 4, "c": 5}
    left = Table(samples, names=["X1"])
    right = Table(samples, names=["X1"])
    total = left + right
    for key, count in (("a", 3), ("b", 4), ("c", 5)):
        assert total[key] == 2 * count

    # Two columns.
    samples = {("a", "x"): 4, ("a", "y"): 4, ("b", "x"): 6, ("b", "y"): 6}
    left = Table(samples, names=["X1", "X2"])
    right = Table(samples, names=["X1", "X2"])
    total = left + right
    for key, count in (
        (("a", "x"), 4),
        (("a", "y"), 4),
        (("b", "x"), 6),
        (("b", "y"), 6),
    ):
        assert total[key] == 2 * count

    # Four columns: number the 16 combinations 1..16 (X3 varies fastest,
    # then X4, X2, X1) and drop ("b", "y", 2, 33), which would have been 14.
    cells = [
        (x1, x2, x3, x4)
        for x1 in ("a", "b")
        for x2 in ("x", "y")
        for x4 in (33, 44)
        for x3 in (1, 2)
    ]
    counts = {cell: number for number, cell in enumerate(cells, start=1)}
    del counts[("b", "y", 2, 33)]

    left = Table(counts, names=["X1", "X2", "X3", "X4"])
    right = Table(counts, names=["X1", "X2", "X3", "X4"])
    total = left + right
    for key, count in (
        (("a", "x", 1, 33), 1),
        (("a", "y", 2, 33), 6),
        (("b", "x", 2, 44), 12),
        (("b", "y", 2, 44), 16),
    ):
        assert total[key] == 2 * count

    # Sum of two tables conditioned on X1.
    total = left.condition_on("X1") + right.condition_on("X1")
    for outer_key, inner_key, expected in (
        ("a", ("x", 1, 33), (2 * 1) / 36),
        ("a", ("x", 1, 44), (2 * 3) / 36),
        ("b", ("x", 1, 33), (2 * 9) / 86),
        ("b", ("x", 1, 44), (2 * 11) / 86),
    ):
        assert total[outer_key][inner_key] == expected

    # Sum of two tables conditioned on X1 and X3.
    total = left.condition_on("X1", "X3") + right.condition_on("X1", "X3")
    for outer_key, inner_key, expected in (
        (("a", 1), ("x", 33), (2 * 1) / 16),
        (("a", 2), ("y", 44), (2 * 8) / 20),
        (("b", 1), ("x", 44), (2 * 11) / 48),
    ):
        assert total[outer_key][inner_key] == expected
    # The combination removed from the counts stays absent from the sum.
    assert total["b", 2]["y", 33] is None
    assert total["b", 2]["y", 44] == (2 * 16) / 38
Ejemplo n.º 20
0
def test_statistical_independence_table():
    """Check that a joint table of independent variables factorises.

    P(x, y, z) = P(x) P(y) P(z): a joint distribution is first built as a
    product of tables, then for every column all the other columns are
    marginalised out, and the product of those per-column marginals is
    compared with the joint after both have been normalised.

    Note: the multi-column input tables must be statistically independent,
    which is why their values are written as products of per-column
    factors.
    """
    s_1 = {
        ("x", 1): 1 * 6,
        ("x", 2): 1 * 4,
        ("y", 1): 9 * 6,
        ("y", 2): 9 * 4,
    }

    table1 = Table(s_1, names=["X1", "X2"])

    s_2 = {
        (1, "high", "under", "x"): 4 * 3 * 1 * 1,
        (1, "high", "normal", "x"): 4 * 3 * 2 * 1,
        (1, "high", "over", "x"): 4 * 3 * 3 * 1,
        (1, "high", "obese", "x"): 4 * 3 * 4 * 1,
        (1, "low", "under", "x"): 4 * 2 * 1 * 1,
        (1, "low", "normal", "x"): 4 * 2 * 2 * 1,
        (1, "low", "over", "x"): 4 * 2 * 3 * 1,
        (1, "low", "obese", "x"): 4 * 2 * 4 * 1,
        (2, "high", "under", "x"): 2 * 3 * 1 * 1,
        (2, "high", "normal", "x"): 2 * 3 * 2 * 1,
        (2, "high", "over", "x"): 2 * 3 * 3 * 1,
        (2, "high", "obese", "x"): 2 * 3 * 4 * 1,
        (2, "low", "under", "x"): 2 * 2 * 1 * 1,
        (2, "low", "normal", "x"): 2 * 2 * 2 * 1,
        (2, "low", "over", "x"): 2 * 2 * 3 * 1,
        (2, "low", "obese", "x"): 2 * 2 * 4 * 1,
        (1, "high", "under", "y"): 4 * 3 * 1 * 3,
        (1, "high", "normal", "y"): 4 * 3 * 2 * 3,
        (1, "high", "over", "y"): 4 * 3 * 3 * 3,
        (1, "high", "obese", "y"): 4 * 3 * 4 * 3,
        (1, "low", "under", "y"): 4 * 2 * 1 * 3,
        (1, "low", "normal", "y"): 4 * 2 * 2 * 3,
        (1, "low", "over", "y"): 4 * 2 * 3 * 3,
        (1, "low", "obese", "y"): 4 * 2 * 4 * 3,
        (2, "high", "under", "y"): 2 * 3 * 1 * 3,
        (2, "high", "normal", "y"): 2 * 3 * 2 * 3,
        (2, "high", "over", "y"): 2 * 3 * 3 * 3,
        (2, "high", "obese", "y"): 2 * 3 * 4 * 3,
        (2, "low", "under", "y"): 2 * 2 * 1 * 3,
        (2, "low", "normal", "y"): 2 * 2 * 2 * 3,
        (2, "low", "over", "y"): 2 * 2 * 3 * 3,
        (2, "low", "obese", "y"): 2 * 2 * 4 * 3,
    }
    table2 = Table(s_2, names=["Y1", "Y2", "Y3", "Y4"])
    single_table3 = Table({11: 2, 22: 4, 33: 3}, names=["Z"])

    joint_dist = table1 * table2 * single_table3

    # One marginal per column: marginalise out every column except `name`.
    marginals = []
    for name in joint_dist.names:
        names_except_one = list(set(joint_dist.names) - {name})
        marginal = joint_dist.marginal(*names_except_one)
        marginals.append(marginal)

    # `np.product` was a deprecated alias of `np.prod` and no longer exists
    # in NumPy 2.0; `np.prod` performs the same multiplicative reduction.
    joint_dist2 = np.prod(marginals)

    # Normalise both tables so that their values are comparable.
    t1 = sum(joint_dist.values())
    t2 = sum(joint_dist2.values())

    joint_dist = Table({k: v / t1
                        for k, v in joint_dist.items()}, joint_dist.names)

    joint_dist2 = Table({k: v / t2
                         for k, v in joint_dist2.items()}, joint_dist2.names)

    for k1 in joint_dist:
        assert joint_dist[k1] == approx(joint_dist2[k1], abs=1e-16)
Ejemplo n.º 21
0
def test_conditional_on_exception_table():
    """Check that invalid ``condition_on`` calls raise ``ValueError``.

    Two families of calls are exercised: conditioning a table on all of
    its columns, and conditioning an already conditional table again on
    columns it was conditioned on.

    Table construction and the valid first-stage ``condition_on`` calls are
    kept outside the ``pytest.raises`` blocks, so that a ``ValueError``
    raised by the setup cannot make a case pass by accident.
    """
    # Conditioning on every column of the table.
    table1 = Table({"a": 3, "b": 4, "c": 5}, names=["X1"])
    with pytest.raises(ValueError):
        table1.condition_on("X1")

    table1 = Table(
        {
            ("a", "x"): 4,
            ("a", "y"): 4,
            ("b", "x"): 6,
            ("b", "y"): 6
        },
        names=["X1", "X2"],
    )
    with pytest.raises(ValueError):
        table1.condition_on("X1", "X2")

    table1 = Table(samples, names=["X1", "X2", "X3", "X4"])
    with pytest.raises(ValueError):
        table1.condition_on("X1", "X2", "X3", "X4")

    # Conditioning a conditional table again on the same columns
    # (first-stage names, second-stage names).
    repeated_cases = [
        (("X1", ), ("X1", )),
        (("X1", "X2"), ("X1", "X2")),
        (("X1", "X2"), ("X2", "X1")),
        (("X1", "X2", "X3"), ("X1", "X2", "X3")),
    ]
    for first_names, second_names in repeated_cases:
        table1 = Table(samples, names=["X1", "X2", "X3", "X4"])
        con_1 = table1.condition_on(*first_names)
        with pytest.raises(ValueError):
            con_1.condition_on(*second_names)

    # Duplicated names. It is not visible from this test which of the two
    # calls is expected to raise, so both stay inside the block.
    table1 = Table(samples, names=["X1", "X2", "X3", "X4"])
    with pytest.raises(ValueError):
        con_1 = table1.condition_on("X3", "X2", "X2")
        con_1.condition_on("X1", "X3", "X1")
Ejemplo n.º 22
0
def test_reduce_exception_table():
    """Check that ``reduce`` raises ``ValueError`` when no column is left.

    The first cases fix a value for every column of a plain table; the
    remaining cases fix a value for every column of a conditional table
    that is not part of its conditioning set.

    Table construction and the ``condition_on`` calls are kept outside the
    ``pytest.raises`` blocks, so that only the ``reduce`` call can satisfy
    the expected exception.
    """
    # Reducing on every column of a plain table.
    table1 = Table({"a": 3, "b": 4, "c": 5}, names=["X1"])
    with pytest.raises(ValueError):
        table1.reduce(X1="a")

    table1 = Table(
        {
            ("a", "x"): 4,
            ("a", "y"): 4,
            ("b", "x"): 6,
            ("b", "y"): 6
        },
        names=["X1", "X2"],
    )
    with pytest.raises(ValueError):
        table1.reduce(X1="a", X2="y")

    table1 = Table(samples, names=["X1", "X2", "X3", "X4"])
    with pytest.raises(ValueError):
        table1.reduce(X1="a", X2="x", X3=1, X4=44)

    # Conditional tables: (conditioning names, values passed to reduce).
    conditional_cases = [
        (("X1", ), {"X2": "x", "X3": 1, "X4": 44}),
        (("X2", ), {"X1": "a", "X3": 1, "X4": 44}),
        (("X3", ), {"X1": "a", "X2": "x", "X4": 44}),
        (("X1", "X2"), {"X3": 1, "X4": 44}),
        (("X1", "X3"), {"X2": "x", "X4": 44}),
        (("X1", "X3", "X4"), {"X2": "x"}),
    ]
    for given_names, fixed_values in conditional_cases:
        table1 = Table(samples, names=["X1", "X2", "X3", "X4"])
        con_1 = table1.condition_on(*given_names)
        with pytest.raises(ValueError):
            con_1.reduce(**fixed_values)
Ejemplo n.º 23
0
def test_product_exceptions_table():
    """In-place product of a table with a plain string raises ``ValueError``."""
    left_operand = Table(sample_1)
    with pytest.raises(ValueError):
        # The right-hand side is a str rather than a table.
        left_operand *= "ttt"
Ejemplo n.º 24
0
def test_marginals_table():
    """Check ``Table.marginal`` on tables with default column names.

    The tables are built without ``names``, and the columns are addressed
    as "X1", "X2", ... ``marginal`` removes the named columns, so the
    expected keys are made of the remaining columns only and the expected
    values are the summed counts divided by the table total. Tables with
    one to four columns are covered, as well as chained marginalisation.
    """
    # Single RV dist.: the only column cannot be marginalised out.
    # The table is built outside the block so only marginal() may raise.
    table = Table({"A": 2, "B": 3, "C": 4})
    with pytest.raises(ValueError):
        table.marginal("X1")

    # Two levels dist.
    samples = {(1, 1): 4, (1, 2): 4, (2, 1): 6, (2, 2): 6}
    table = Table(samples)
    table2 = table.marginal("X1")
    assert all(compare(table2.keys(), [(1, ), (2, )]))
    assert table2[1] == 10 / 20
    assert table2[2] == 10 / 20

    table2 = table.marginal("X2")
    assert all(compare(table2.keys(), [(1, ), (2, )]))
    assert table2[1] == 8 / 20
    assert table2[2] == 12 / 20

    # Same shape with string levels.
    samples = {("a", "x"): 4, ("a", "y"): 4, ("b", "x"): 6, ("b", "y"): 6}
    table = Table(samples)
    table2 = table.marginal("X1")
    assert all(compare(table2.keys(), [("x", ), ("y", )]))
    assert table2["x"] == 10 / 20
    assert table2["y"] == 10 / 20

    table2 = table.marginal("X2")
    assert all(compare(table2.keys(), [("a", ), ("b", )]))
    assert table2["a"] == 8 / 20
    assert table2["b"] == 12 / 20

    # Three levels dist.
    samples = {
        ("a", "x", 1): 4,
        ("a", "x", 2): 4,
        ("a", "y", 1): 6,
        ("a", "y", 2): 6,
        ("b", "x", 1): 8,
        ("b", "x", 2): 8,
        ("b", "y", 1): 10,
        ("b", "y", 2): 10,
    }
    table = Table(samples)
    table2 = table.marginal("X1")
    assert all(compare(table2.keys(), [("x", 1), ("x", 2), ("y", 1),
                                       ("y", 2)]))
    assert table2[("x", 1)] == 12 / 56
    assert table2[("x", 2)] == 12 / 56
    assert table2[("y", 1)] == 16 / 56
    assert table2[("y", 2)] == 16 / 56

    table2 = table.marginal("X2")
    assert all(compare(table2.keys(), [("a", 1), ("a", 2), ("b", 1),
                                       ("b", 2)]))
    assert table2[("a", 1)] == 10 / 56
    assert table2[("a", 2)] == 10 / 56
    assert table2[("b", 1)] == 18 / 56
    assert table2[("b", 2)] == 18 / 56

    table2 = table.marginal("X3")
    assert all(
        compare(table2.keys(), [("a", "x"), ("a", "y"), ("b", "x"),
                                ("b", "y")]))
    assert table2[("a", "x")] == 8 / 56
    assert table2[("a", "y")] == 12 / 56
    assert table2[("b", "x")] == 16 / 56
    assert table2[("b", "y")] == 20 / 56

    # Two columns removed at once: a single column is left.
    table2 = table.marginal("X1", "X2")
    assert all(compare(table2.keys(), [(1, ), (2, )]))
    assert table2[1] == 28 / 56
    assert table2[2] == 28 / 56

    table2 = table.marginal("X1", "X3")
    assert all(compare(table2.keys(), [("x", ), ("y", )]))
    assert table2["x"] == 24 / 56
    assert table2["y"] == 32 / 56

    table2 = table.marginal("X2", "X3")
    assert all(compare(table2.keys(), [("a", ), ("b", )]))
    assert table2["a"] == 20 / 56
    assert table2["b"] == 36 / 56

    # Four levels dist.
    samples = {
        ("a", "x", 1, 33): 1,
        ("a", "x", 2, 33): 2,
        ("a", "x", 1, 44): 3,
        ("a", "x", 2, 44): 4,
        ("a", "y", 1, 33): 5,
        ("a", "y", 2, 33): 6,
        ("a", "y", 1, 44): 7,
        ("a", "y", 2, 44): 8,
        ("b", "x", 1, 33): 9,
        ("b", "x", 2, 33): 10,
        ("b", "x", 1, 44): 11,
        ("b", "x", 2, 44): 12,
        ("b", "y", 1, 33): 13,
        ("b", "y", 2, 33): 14,
        ("b", "y", 1, 44): 15,
        ("b", "y", 2, 44): 16,
    }
    table = Table(samples)
    table2 = table.marginal("X3")
    assert all(
        compare(
            table2.keys(),
            [
                ("a", "x", 33),
                ("a", "x", 44),
                ("a", "y", 33),
                ("a", "y", 44),
                ("b", "x", 33),
                ("b", "x", 44),
                ("b", "y", 33),
                ("b", "y", 44),
            ],
        ))
    assert table2[("a", "x", 33)] == 3 / 136
    assert table2[("a", "x", 44)] == 7 / 136
    assert table2[("a", "y", 33)] == 11 / 136
    assert table2[("a", "y", 44)] == 15 / 136
    assert table2[("b", "x", 33)] == 19 / 136
    assert table2[("b", "x", 44)] == 23 / 136
    assert table2[("b", "y", 33)] == 27 / 136
    assert table2[("b", "y", 44)] == 31 / 136

    table2 = table.marginal("X4")
    assert all(
        compare(
            table2.keys(),
            [
                ("a", "x", 1),
                ("a", "x", 2),
                ("a", "y", 1),
                ("a", "y", 2),
                ("b", "x", 1),
                ("b", "x", 2),
                ("b", "y", 1),
                ("b", "y", 2),
            ],
        ))
    assert table2[("a", "x", 1)] == 4 / 136
    assert table2[("a", "x", 2)] == 6 / 136
    assert table2[("a", "y", 1)] == 12 / 136
    assert table2[("a", "y", 2)] == 14 / 136
    assert table2[("b", "x", 1)] == 20 / 136
    assert table2[("b", "x", 2)] == 22 / 136
    assert table2[("b", "y", 1)] == 28 / 136
    assert table2[("b", "y", 2)] == 30 / 136

    table2 = table.marginal("X1", "X4")
    assert all(compare(table2.keys(), [("x", 1), ("x", 2), ("y", 1),
                                       ("y", 2)]))
    assert table2[("x", 1)] == 24 / 136
    assert table2[("x", 2)] == 28 / 136
    assert table2[("y", 1)] == 40 / 136
    assert table2[("y", 2)] == 44 / 136

    table2 = table.marginal("X1", "X2", "X4")
    assert all(compare(table2.keys(), [(1, ), (2, )]))
    assert table2[1] == 64 / 136
    assert table2[2] == 72 / 136

    # marginalize two times: must match the one-step result above
    table2 = table.marginal("X1", "X4")
    table3 = table2.marginal("X2")
    assert all(compare(table3.keys(), [(1, ), (2, )]))
    assert table3[1] == 64 / 136
    assert table3[2] == 72 / 136

    # marginalize three times
    table2 = table.marginal("X4")
    table3 = table2.marginal("X3")
    table4 = table3.marginal("X2")
    assert all(compare(table4.keys(), [("a", ), ("b", )]))
    assert table4["a"] == 36 / 136
    assert table4["b"] == 100 / 136
Ejemplo n.º 25
0
def test_marginal_by_name_table():
    """Check ``Table.marginal`` when the columns carry explicit names.

    A four-column table named ("Age", "Sex", "Edu", "Etn") is marginalised
    over one, two and three columns, with and without normalisation, and
    in several chained steps. Expected values are summed counts, divided
    by the table total (136) when normalised.
    """
    total = 136

    # Four levels dist.: counts 1..16, keyed by (Age, Sex, Edu, Etn).
    samples = {}
    for age in ("a", "b"):
        for sex in ("x", "y"):
            for etn in (33, 44):
                for edu in (1, 2):
                    samples[(age, sex, edu, etn)] = len(samples) + 1

    def check(marginal, keys, expected, with_keys=True):
        # Optionally compare the key set, then every value in key order.
        if with_keys:
            assert all(compare(marginal.keys(), keys))
        for key, value in zip(keys, expected):
            # One-column tables are indexed by the bare level.
            lookup = key[0] if len(key) == 1 else key
            assert marginal[lookup] == value

    def normalised(counts):
        return [count / total for count in counts]

    table = Table(samples, names=["Age", "Sex", "Edu", "Etn"])

    # "Edu" removed: keys are (Age, Sex, Etn).
    without_edu_keys = [
        ("a", "x", 33),
        ("a", "x", 44),
        ("a", "y", 33),
        ("a", "y", 44),
        ("b", "x", 33),
        ("b", "x", 44),
        ("b", "y", 33),
        ("b", "y", 44),
    ]
    without_edu_counts = [3, 7, 11, 15, 19, 23, 27, 31]
    check(table.marginal("Edu"), without_edu_keys,
          normalised(without_edu_counts))

    # Same marginal without normalisation: raw summed counts, values only.
    check(
        table.marginal("Edu", normalise=False),
        without_edu_keys,
        without_edu_counts,
        with_keys=False,
    )

    # "Etn" removed: keys are (Age, Sex, Edu).
    without_etn_keys = [
        ("a", "x", 1),
        ("a", "x", 2),
        ("a", "y", 1),
        ("a", "y", 2),
        ("b", "x", 1),
        ("b", "x", 2),
        ("b", "y", 1),
        ("b", "y", 2),
    ]
    check(table.marginal("Etn"), without_etn_keys,
          normalised([4, 6, 12, 14, 20, 22, 28, 30]))

    # Two columns removed: keys are (Sex, Edu).
    check(
        table.marginal("Age", "Etn"),
        [("x", 1), ("x", 2), ("y", 1), ("y", 2)],
        normalised([24, 28, 40, 44]),
    )

    # Three columns removed: only "Edu" is left.
    edu_keys = [(1, ), (2, )]
    edu_expected = normalised([64, 72])
    check(table.marginal("Age", "Sex", "Etn"), edu_keys, edu_expected)

    # marginalize two times
    first_pass = table.marginal("Age", "Etn")
    second_pass = first_pass.marginal("Sex")
    check(second_pass, edu_keys, edu_expected)

    # marginalize three times
    first_pass = table.marginal("Etn")
    second_pass = first_pass.marginal("Edu")
    third_pass = second_pass.marginal("Sex")
    check(third_pass, [("a", ), ("b", )], normalised([36, 100]))
Ejemplo n.º 26
0
def test_factoring_table():
    """Check that marginal * conditional rebuilds the original table.

    For one, two and three conditioning columns, the table with those
    columns' complement marginalised out is multiplied (in both operand
    orders) by the table conditioned on the columns that are left; every
    entry of the product must match the normalised original.
    """
    # Per-column weights; each count is the product of its four weights.
    y1_weight = {1: 4, 2: 2}
    y2_weight = {"high": 3, "low": 2}
    y3_weight = {"under": 1, "normal": 2, "over": 3, "obese": 4}
    y4_weight = {"x": 1, "y": 3}

    s_1 = {}
    for y4, w4 in y4_weight.items():
        for y1, w1 in y1_weight.items():
            for y2, w2 in y2_weight.items():
                for y3, w3 in y3_weight.items():
                    s_1[(y1, y2, y3, y4)] = w1 * w2 * w3 * w4

    joint = Table(s_1, names=["Y1", "Y2", "Y3", "Y4"])
    joint.normalise()

    def matches_joint(product):
        # Compare entry by entry, looking the joint up by column name so
        # that the column order of `product` does not matter.
        for key in product:
            named = product.columns.named_key(key)
            assert product[key] == approx(joint.get(**named))

    # (columns marginalised out, conditioning columns,
    #  whether the conditional is the left operand of the first product)
    scenarios = (
        (("Y2", "Y3", "Y4"), ("Y1", ), True),
        # On two columns
        (("Y3", "Y4"), ("Y1", "Y2"), False),
        # On Three columns
        (("Y4", ), ("Y1", "Y2", "Y3"), False),
    )
    for removed, given, conditional_first in scenarios:
        marginal = joint.marginal(*removed)
        conditional = joint.condition_on(*given)

        if conditional_first:
            left, right = conditional, marginal
        else:
            left, right = marginal, conditional

        matches_joint(left * right)
        matches_joint(right * left)
Ejemplo n.º 27
0
def test_add_exception_table():
    """Check that in-place addition of mismatched tables raises ValueError.

    Each case builds two tables (or conditional tables) whose column names,
    column order or conditioning columns differ, and expects ``+=`` between
    them to fail. Note that the setup of every case runs inside the
    ``pytest.raises`` block together with the addition.
    """
    # Wrong column name
    with pytest.raises(ValueError):
        samples = {"a": 3, "b": 4, "c": 5}
        table1 = Table(samples, names=["X1"])
        table2 = Table(samples, names=["X2"])
        table1 += table2

    with pytest.raises(ValueError):
        samples = {("a", "x"): 4, ("a", "y"): 4, ("b", "x"): 6, ("b", "y"): 6}
        table1 = Table(samples, names=["X1", "X2"])
        # NOTE(review): both names are "Y2" - possibly a typo for
        # ["Y1", "Y2"]; confirm the ValueError comes from the addition and
        # not from building a table with a repeated column name.
        table2 = Table(samples, names=["Y2", "Y2"])
        table1 += table2

    # Only the second column name differs
    with pytest.raises(ValueError):
        samples = {("a", "x"): 4, ("a", "y"): 4, ("b", "x"): 6, ("b", "y"): 6}
        table1 = Table(samples, names=["X1", "X2"])
        table2 = Table(samples, names=["X1", "Y2"])
        table1 += table2

    # wrong order in names
    with pytest.raises(ValueError):
        samples = {("a", "x"): 4, ("a", "y"): 4, ("b", "x"): 6, ("b", "y"): 6}
        table1 = Table(samples, names=["X1", "X2"])
        table2 = Table(samples, names=["X2", "X1"])
        table1 += table2

    # Four-column sample; the ("b", "y", 2, 33) entry is left out.
    sample_2 = {
        ("a", "x", 1, 33): 1,
        ("a", "x", 2, 33): 2,
        ("a", "x", 1, 44): 3,
        ("a", "x", 2, 44): 4,
        ("a", "y", 1, 33): 5,
        ("a", "y", 2, 33): 6,
        ("a", "y", 1, 44): 7,
        ("a", "y", 2, 44): 8,
        ("b", "x", 1, 33): 9,
        ("b", "x", 2, 33): 10,
        ("b", "x", 1, 44): 11,
        ("b", "x", 2, 44): 12,
        ("b", "y", 1, 33): 13,
        # ("b", "y", 2, 33): 14,
        ("b", "y", 1, 44): 15,
        ("b", "y", 2, 44): 16,
    }
    # Three-column sample, used below under the names X1, X2, X3.
    sample_3 = {
        ("a", 1, 33): 1,
        ("a", 2, 33): 2,
        ("a", 1, 44): 7,
        ("a", 2, 44): 8,
        ("b", 1, 33): 9,
        ("b", 2, 33): 10,
        ("b", 2, 44): 16,
    }
    # conditional: tables conditioned on different columns
    with pytest.raises(ValueError):
        table1 = Table(sample_2, names=["X1", "X2", "X3", "X4"])
        table2 = Table(sample_2, names=["X1", "X2", "X3", "X4"])
        con_1 = table1.condition_on("X1")
        con_2 = table2.condition_on("X2")
        con_1 += con_2

    # conditional: different number of conditioning columns
    with pytest.raises(ValueError):
        table1 = Table(sample_2, names=["X1", "X2", "X3", "X4"])
        table2 = Table(sample_2, names=["X1", "X2", "X3", "X4"])
        con_1 = table1.condition_on("X1", "X2")
        con_2 = table2.condition_on("X2")
        con_1 += con_2

    # conditional: same conditioning column, but the source tables have a
    # different number of columns
    with pytest.raises(ValueError):
        table1 = Table(sample_2, names=["X1", "X2", "X3", "X4"])
        table2 = Table(sample_3, names=["X1", "X2", "X3"])
        con_1 = table1.condition_on("X1")
        con_2 = table2.condition_on("X1")
        con_1 += con_2

    # conditional: as above, conditioned on two columns
    with pytest.raises(ValueError):
        table1 = Table(sample_2, names=["X1", "X2", "X3", "X4"])
        table2 = Table(sample_3, names=["X1", "X2", "X3"])
        con_1 = table1.condition_on("X1", "X2")
        con_2 = table2.condition_on("X1", "X2")
        con_1 += con_2